# Source listing metadata: 505 lines, 20 KiB, Python
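"""
Risk-detection API endpoint tests (风险检测API端点测试用例).

Covered scenarios:
    1. Task creation API
    2. Task execution API
    3. Task query API
    4. Result query API
    5. Summary report API
    6. Algorithm list API
    7. Rule list API, plus validation and not-found error paths
"""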
|
|
import pytest
|
|
from datetime import datetime, date
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from fastapi.testclient import TestClient
|
|
|
|
from app.main import app
|
|
from app.models.risk_detection import TaskType, TaskStatus, RiskLevel
|
|
|
|
|
|
class TestRiskDetectionAPI:
    """Endpoint tests for ``/api/v1/risk-detection``.

    Collaborators (``TaskManager``, ``DetectionResult``, ``select``) are
    patched at their lookup path inside
    ``app.api.v1.endpoints.risk_detection`` so each test controls what the
    endpoint receives from them.
    """

    @pytest.fixture
    def client(self):
        """Provide a ``TestClient`` for the FastAPI ``app``."""
        return TestClient(app)

    @pytest.fixture
    def mock_task_data(self):
        """Provide a complete request body for ``POST /tasks``."""
        return {
            "task_name": "API测试任务",
            "task_type": "on_demand",
            "entity_ids": ["ZB_TEST_001"],
            "entity_type": "streamer",
            "period": "2024-01",
            "rule_ids": ["REVENUE_INTEGRITY_CHECK"],
            "parameters": {"description": "API测试"},
        }

    def test_create_task_api(self, client, mock_task_data):
        """Creating a task returns 201 and echoes the created task's fields."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            mock_task = MagicMock()
            mock_task.task_id = "task_api_001"
            mock_task.task_name = "API测试任务"
            mock_task.task_type = TaskType.ON_DEMAND
            mock_task.status = TaskStatus.PENDING
            mock_task.entity_type = "streamer"
            mock_task.period = "2024-01"
            mock_task.total_entities = 1
            mock_task.processed_entities = 0
            mock_task.result_count = 0
            mock_task.parameters = {"description": "API测试"}
            mock_task.created_at = datetime.now()
            mock_task.started_at = None
            mock_task.completed_at = None

            MockTaskManager.return_value.create_task = AsyncMock(return_value=mock_task)

            response = client.post("/api/v1/risk-detection/tasks", json=mock_task_data)

            assert response.status_code == 201
            data = response.json()
            assert data["task_id"] == "task_api_001"
            assert data["task_name"] == "API测试任务"
            assert data["task_type"] == "on_demand"
            assert data["status"] == "pending"

    def test_create_task_with_auto_binding(self, client, mock_task_data):
        """A task can be created without entity_ids / entity_type in the payload."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            mock_task = MagicMock()
            mock_task.task_id = "task_auto_001"
            mock_task.task_name = "自动绑定任务"
            mock_task.task_type = TaskType.ON_DEMAND
            mock_task.status = TaskStatus.PENDING
            mock_task.entity_type = "streamer"
            mock_task.period = "2024-01"
            mock_task.total_entities = 1
            mock_task.processed_entities = 0
            mock_task.result_count = 0
            mock_task.parameters = {}
            mock_task.created_at = datetime.now()
            mock_task.started_at = None
            mock_task.completed_at = None

            MockTaskManager.return_value.create_task = AsyncMock(return_value=mock_task)

            # Entity fields are deliberately omitted to exercise automatic entity binding.
            request_data = {
                "task_name": "自动绑定任务",
                "task_type": "on_demand",
                "period": "2024-01",
                "rule_ids": ["REVENUE_INTEGRITY_CHECK"],
            }

            response = client.post("/api/v1/risk-detection/tasks", json=request_data)

            assert response.status_code == 201
            data = response.json()
            assert data["task_id"] == "task_auto_001"

    def test_execute_task_api(self, client):
        """Executing an existing task returns 200 with status and result count."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            MockTaskManager.return_value.execute_task = AsyncMock(return_value={
                "task_id": "task_exec_001",
                "status": "completed",
                "summary": {
                    "total_entities": 1,
                    "total_detections": 2,
                    "risk_distribution": {"HIGH": 1, "MEDIUM": 1},
                    "avg_score": 75.0,
                },
                "results": [],
                "executed_at": datetime.now().isoformat(),
            })

            response = client.post("/api/v1/risk-detection/tasks/task_exec_001/execute")

            assert response.status_code == 200
            data = response.json()
            assert data["task_id"] == "task_exec_001"
            assert data["status"] == "completed"
            # The mocked execution returned an empty "results" list.
            assert data["result_count"] == 0

    def test_list_tasks_api(self, client):
        """Listing tasks returns all tasks; status / task_type filters narrow the list."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            mock_tasks = []
            for i in range(3):
                task = MagicMock()
                task.task_id = f"task_list_{i:03d}"
                task.task_name = f"任务{i+1}"
                # Tasks 0 and 2 are on-demand, task 1 is periodic.
                task.task_type = TaskType.ON_DEMAND if i % 2 == 0 else TaskType.PERIODIC
                # Only task 0 is still pending.
                task.status = TaskStatus.COMPLETED if i > 0 else TaskStatus.PENDING
                task.entity_type = "streamer"
                task.period = "2024-01"
                task.total_entities = 1
                task.processed_entities = 1 if i > 0 else 0
                task.result_count = i
                task.parameters = {}
                task.created_at = datetime.now()
                task.started_at = datetime.now()
                task.completed_at = datetime.now() if i > 0 else None
                mock_tasks.append(task)

            list_tasks = AsyncMock(return_value=mock_tasks)
            MockTaskManager.return_value.list_tasks = list_tasks

            # No filters: every task is returned.
            response = client.get("/api/v1/risk-detection/tasks")
            assert response.status_code == 200
            data = response.json()
            assert len(data) == 3

            # Filtering may happen inside TaskManager.list_tasks (mocked here) or in
            # the endpoint itself. Having the mock return the expected subset for each
            # query keeps the assertions valid in either case; a fixed three-task
            # return value would only pass if the endpoint re-filtered in Python.

            # Filter by status.
            list_tasks.return_value = [t for t in mock_tasks if t.status == TaskStatus.PENDING]
            response = client.get("/api/v1/risk-detection/tasks?status=pending")
            assert response.status_code == 200
            data = response.json()
            assert len(data) == 1
            assert data[0]["task_id"] == "task_list_000"

            # Filter by task type.
            list_tasks.return_value = [t for t in mock_tasks if t.task_type == TaskType.ON_DEMAND]
            response = client.get("/api/v1/risk-detection/tasks?task_type=on_demand")
            assert response.status_code == 200
            data = response.json()
            assert len(data) == 2
            assert {t["task_id"] for t in data} == {"task_list_000", "task_list_002"}

    def test_get_task_detail_api(self, client):
        """Fetching a task by id returns its progress counters and summary."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            mock_task = MagicMock()
            mock_task.task_id = "task_detail_001"
            mock_task.task_name = "任务详情测试"
            mock_task.task_type = TaskType.ON_DEMAND
            mock_task.status = TaskStatus.COMPLETED
            mock_task.entity_type = "streamer"
            mock_task.period = "2024-01"
            mock_task.total_entities = 5
            mock_task.processed_entities = 5
            mock_task.result_count = 8
            mock_task.summary = {
                "total_entities": 5,
                "total_detections": 8,
                "risk_distribution": {"HIGH": 2, "MEDIUM": 3, "LOW": 3},
                "avg_score": 60.0,
            }
            mock_task.parameters = {"test": "value"}
            mock_task.error_message = None
            mock_task.created_at = datetime.now()
            mock_task.started_at = datetime.now()
            mock_task.completed_at = datetime.now()

            MockTaskManager.return_value.get_task = AsyncMock(return_value=mock_task)

            response = client.get("/api/v1/risk-detection/tasks/task_detail_001")

            assert response.status_code == 200
            data = response.json()
            assert data["task_id"] == "task_detail_001"
            assert data["task_name"] == "任务详情测试"
            assert data["status"] == "completed"
            assert data["total_entities"] == 5
            assert data["processed_entities"] == 5
            assert data["result_count"] == 8
            assert "summary" in data

    def test_get_task_not_found_api(self, client):
        """Fetching an unknown task id returns 404 with a "task not found" detail."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            MockTaskManager.return_value.get_task = AsyncMock(return_value=None)

            response = client.get("/api/v1/risk-detection/tasks/task_not_exist")

            assert response.status_code == 404
            data = response.json()
            assert "任务不存在" in data["detail"]

    def test_execute_detection_api(self, client):
        """The one-shot /execute endpoint creates and runs a task in one call."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            created_task = MagicMock()
            created_task.task_id = "task_detection_001"

            # Every other TaskManager method in this class (create_task included,
            # see test_create_task_api) is mocked as an AsyncMock because it is
            # awaited; a plain MagicMock is not awaitable.
            MockTaskManager.return_value.create_task = AsyncMock(return_value=created_task)
            MockTaskManager.return_value.execute_task = AsyncMock(return_value={
                "task_id": "task_detection_001",
                "status": "completed",
                "summary": {
                    "total_entities": 1,
                    "total_detections": 1,
                    "risk_distribution": {"HIGH": 1},
                    "avg_score": 85.0,
                },
                "results": [],
                "executed_at": datetime.now().isoformat(),
            })

            request_data = {
                "task_name": "即时检测任务",
                "entity_id": "ZB_TEST_001",
                "entity_type": "streamer",
                "period": "2024-01",
                "rule_ids": ["REVENUE_INTEGRITY_CHECK", "PRIVATE_ACCOUNT_DETECTION"],
                "parameters": {},
            }

            response = client.post("/api/v1/risk-detection/execute", json=request_data)

            assert response.status_code == 200
            data = response.json()
            assert data["task_id"] == "task_detection_001"
            assert data["status"] == "completed"
            assert data["result_count"] == 0

    def test_list_results_api(self, client):
        """Listing detection results without filters returns every mocked row."""
        with patch('app.api.v1.endpoints.risk_detection.DetectionResult') as MockDetectionResult:
            mock_results = []
            for i in range(3):
                result = MagicMock()
                result.id = i + 1
                result.task_id = f"task_result_{i%2+1:03d}"
                result.rule_id = f"rule_{i}"
                result.entity_id = f"ZB_TEST_{i%2+1:03d}"
                result.entity_type = "streamer"
                # One row per risk level: HIGH, MEDIUM, LOW.
                result.risk_level = RiskLevel.HIGH if i == 0 else (RiskLevel.MEDIUM if i == 1 else RiskLevel.LOW)
                result.risk_score = 85.0 if i == 0 else (65.0 if i == 1 else 30.0)
                result.description = f"风险描述{i+1}"
                result.suggestion = f"处理建议{i+1}"
                result.risk_data = {"test": f"data_{i}"}
                result.evidence = []
                result.detected_at = datetime.now()
                mock_results.append(result)

            MockDetectionResult.__table__.columns = MagicMock()
            MockDetectionResult.task_id = MagicMock()
            MockDetectionResult.entity_id = MagicMock()
            MockDetectionResult.entity_type = MagicMock()
            MockDetectionResult.risk_level = MagicMock()

            with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
                # NOTE(review): these mocks replace `.execute` on the statement built
                # by `select(...)`, and `scalars` is configured on the AsyncMock itself
                # rather than on its awaited result (`mock_execute.return_value`). If
                # the endpoint runs the query through a database session (e.g.
                # `await db.execute(stmt)`), none of this is consulted — confirm
                # against the endpoint implementation. The same applies to the other
                # result/summary tests in this class.
                mock_execute = AsyncMock()
                mock_execute.scalars.return_value.all.return_value = mock_results

                MockSelect.return_value.where.return_value.order_by.return_value.limit.return_value.execute = mock_execute

                response = client.get("/api/v1/risk-detection/results")

                assert response.status_code == 200
                data = response.json()
                assert len(data) == 3

    def test_get_result_detail_api(self, client):
        """Fetching a result by id returns the row with a nested ``result`` section."""
        with patch('app.api.v1.endpoints.risk_detection.DetectionResult') as MockDetectionResult:
            result = MagicMock()
            result.id = 1
            result.task_id = "task_001"
            result.rule_id = "rule_revenue_001"
            result.entity_id = "ZB_TEST_001"
            result.entity_type = "streamer"
            result.risk_level = RiskLevel.HIGH
            result.risk_score = 85.0
            result.description = "高风险问题"
            result.suggestion = "立即处理"
            result.risk_data = {"key": "value"}
            result.evidence = [{"type": "test", "description": "test"}]
            result.detected_at = datetime.now()

            MockDetectionResult.__table__.columns = MagicMock()

            with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
                # NOTE(review): see the mock-wiring note in test_list_results_api.
                mock_execute = AsyncMock()
                mock_execute.scalar_one_or_none.return_value = result

                MockSelect.return_value.where.return_value.execute = mock_execute

                response = client.get("/api/v1/risk-detection/results/1")

                assert response.status_code == 200
                data = response.json()
                assert data["id"] == 1
                assert data["task_id"] == "task_001"
                assert data["result"]["risk_level"] == "HIGH"
                assert data["result"]["risk_score"] == 85.0

    def test_get_result_not_found_api(self, client):
        """Fetching an unknown result id returns 404 with a "result not found" detail."""
        with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
            # NOTE(review): see the mock-wiring note in test_list_results_api.
            mock_execute = AsyncMock()
            mock_execute.scalar_one_or_none.return_value = None

            MockSelect.return_value.where.return_value.execute = mock_execute

            response = client.get("/api/v1/risk-detection/results/9999")

            assert response.status_code == 404
            data = response.json()
            assert "结果不存在" in data["detail"]

    def test_get_detection_summary_api(self, client):
        """The summary aggregates count, high-risk count and mean score per entity/period."""
        # Two HIGH (85), two MEDIUM (65) and one LOW (30) result.
        mock_results = []
        for i in range(5):
            result = MagicMock()
            result.risk_level = RiskLevel.HIGH if i < 2 else (RiskLevel.MEDIUM if i < 4 else RiskLevel.LOW)
            result.risk_score = 85.0 if i < 2 else (65.0 if i < 4 else 30.0)
            mock_results.append(result)

        with patch('app.api.v1.endpoints.risk_detection.DetectionResult') as MockDetectionResult:
            MockDetectionResult.__table__.columns = MagicMock()

            with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
                # NOTE(review): see the mock-wiring note in test_list_results_api.
                mock_execute = AsyncMock()
                mock_execute.scalars.return_value.all.return_value = mock_results

                MockSelect.return_value.where.return_value.execute = mock_execute

                response = client.get(
                    "/api/v1/risk-detection/summary"
                    "?entity_id=ZB_TEST_001"
                    "&entity_type=streamer"
                    "&period=2024-01"
                )

                assert response.status_code == 200
                data = response.json()
                assert data["entity_id"] == "ZB_TEST_001"
                assert data["entity_type"] == "streamer"
                assert data["period"] == "2024-01"
                assert data["total_detections"] == 5
                assert data["high_risk_count"] == 2
                # Mean of the mocked scores: (85*2 + 65*2 + 30*1) / 5 == 330 / 5 == 66.0
                assert data["avg_risk_score"] == 66.0
                assert len(data["recommendations"]) > 0

    def test_get_detection_summary_empty(self, client):
        """An entity with no detections yields an all-zero summary."""
        with patch('app.api.v1.endpoints.risk_detection.DetectionResult') as MockDetectionResult:
            MockDetectionResult.__table__.columns = MagicMock()

            with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
                # NOTE(review): see the mock-wiring note in test_list_results_api.
                mock_execute = AsyncMock()
                mock_execute.scalars.return_value.all.return_value = []

                MockSelect.return_value.where.return_value.execute = mock_execute

                response = client.get(
                    "/api/v1/risk-detection/summary"
                    "?entity_id=ZB_TEST_999"
                    "&entity_type=streamer"
                    "&period=2024-01"
                )

                assert response.status_code == 200
                data = response.json()
                assert data["total_detections"] == 0
                assert data["risk_distribution"] == {}
                assert data["avg_risk_score"] == 0.0
                assert data["high_risk_count"] == 0

    def test_get_algorithms_api(self, client):
        """The algorithm catalogue lists at least the five built-in algorithms."""
        response = client.get("/api/v1/risk-detection/algorithms")

        assert response.status_code == 200
        data = response.json()
        assert len(data) >= 5

        # Every entry exposes the same descriptive fields.
        for algo in data:
            assert "code" in algo
            assert "name" in algo
            assert "description" in algo
            assert "parameters" in algo

        # The specific built-in algorithms are present.
        algo_codes = [algo["code"] for algo in data]
        assert "REVENUE_INTEGRITY_CHECK" in algo_codes
        assert "PRIVATE_ACCOUNT_DETECTION" in algo_codes
        assert "INVOICE_FRAUD_DETECTION" in algo_codes
        assert "EXPENSE_ANOMALY_DETECTION" in algo_codes
        assert "TAX_RISK_ASSESSMENT" in algo_codes

    def test_get_rules_api(self, client):
        """The rule list is a JSON array of at least five rule descriptors."""
        response = client.get("/api/v1/risk-detection/rules")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        assert len(data) >= 5

        # Every rule exposes the same descriptive fields.
        for rule in data:
            assert "rule_id" in rule
            assert "rule_name" in rule
            assert "algorithm_code" in rule
            assert "is_enabled" in rule

    def test_create_task_validation_error(self, client):
        """A payload missing required fields is rejected by request validation."""
        # task_type, period, etc. are intentionally missing.
        invalid_data = {
            "task_name": "测试任务",
        }

        response = client.post("/api/v1/risk-detection/tasks", json=invalid_data)

        # FastAPI reports request-validation failures as 422.
        assert response.status_code == 422

    def test_execute_task_not_found(self, client):
        """A failure raised by TaskManager.execute_task is reported as a 500."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            MockTaskManager.return_value.execute_task = AsyncMock(
                side_effect=Exception("任务不存在")
            )

            response = client.post("/api/v1/risk-detection/tasks/task_not_exist/execute")

            assert response.status_code == 500
            data = response.json()
            assert "执行任务失败" in data["detail"]

    def test_list_results_with_filters(self, client):
        """The results endpoint accepts task, entity and risk-level filters."""
        with patch('app.api.v1.endpoints.risk_detection.DetectionResult') as MockDetectionResult:
            MockDetectionResult.__table__.columns = MagicMock()

            with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
                # NOTE(review): see the mock-wiring note in test_list_results_api.
                mock_execute = AsyncMock()
                mock_execute.scalars.return_value.all.return_value = []

                MockSelect.return_value.where.return_value.order_by.return_value.limit.return_value.execute = mock_execute

                # Filter by task id (with an explicit page size).
                response = client.get(
                    "/api/v1/risk-detection/results"
                    "?task_id=task_filter_001"
                    "&limit=50"
                )
                assert response.status_code == 200

                # Filter by entity.
                response = client.get(
                    "/api/v1/risk-detection/results"
                    "?entity_id=ZB_TEST_001"
                    "&entity_type=streamer"
                )
                assert response.status_code == 200

                # Filter by risk level.
                response = client.get(
                    "/api/v1/risk-detection/results"
                    "?risk_level=HIGH"
                )
                assert response.status_code == 200
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests directly. pytest.main returns an exit code;
    # propagating it makes the process exit non-zero when any test fails,
    # instead of always reporting success.
    raise SystemExit(pytest.main([__file__, "-v"]))
|