# deep-risk/backend/app/tests/test_risk_detection_api.py
# 2025-12-14 20:08:27 +08:00
#
# 505 lines
# 20 KiB
# Python

"""
Test cases for the risk-detection API endpoints.

Scenarios covered:
1. Task creation API (explicit entities and automatic entity binding)
2. Task execution API (including the instant-detection endpoint)
3. Task query API (list with filters, detail, not-found)
4. Detection result query API (list with filters, detail, not-found)
5. Detection summary report API (populated and empty)
6. Algorithm list API and rule list API
7. Request validation and error responses
"""
import pytest
from datetime import datetime, date
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient
from app.main import app
from app.models.risk_detection import TaskType, TaskStatus, RiskLevel
class TestRiskDetectionAPI:
    """Tests for the ``/api/v1/risk-detection`` endpoints.

    ``TaskManager`` (and, for the result/summary endpoints, ``DetectionResult``
    and ``select``) are patched inside the endpoint module so the tests do not
    run real detection work.
    """

    @pytest.fixture
    def client(self):
        """Return a FastAPI ``TestClient`` bound to the application."""
        return TestClient(app)

    @pytest.fixture
    def mock_task_data(self):
        """Return a valid request payload for creating an on-demand task."""
        return {
            "task_name": "API测试任务",
            "task_type": "on_demand",
            "entity_ids": ["ZB_TEST_001"],
            "entity_type": "streamer",
            "period": "2024-01",
            "rule_ids": ["REVENUE_INTEGRITY_CHECK"],
            "parameters": {"description": "API测试"},
        }

    @staticmethod
    def _make_mock_task(**fields):
        """Return a MagicMock standing in for a task record.

        ``parameters``, ``summary``, ``error_message`` and the timestamps get
        concrete defaults, so they are never auto-generated MagicMock objects
        (which response serialisation cannot handle). Every keyword in
        ``fields`` overrides the attribute of the same name.
        """
        task = MagicMock()
        task.parameters = {}
        task.summary = None
        task.error_message = None
        task.created_at = datetime.now()
        task.started_at = None
        task.completed_at = None
        for attr, value in fields.items():
            setattr(task, attr, value)
        return task

    def test_create_task_api(self, client, mock_task_data):
        """Creating a task returns 201 and echoes the created task."""
        mock_task = self._make_mock_task(
            task_id="task_api_001",
            task_name="API测试任务",
            task_type=TaskType.ON_DEMAND,
            status=TaskStatus.PENDING,
            entity_type="streamer",
            period="2024-01",
            total_entities=1,
            processed_entities=0,
            result_count=0,
            parameters={"description": "API测试"},
        )
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            MockTaskManager.return_value.create_task = AsyncMock(return_value=mock_task)
            response = client.post("/api/v1/risk-detection/tasks", json=mock_task_data)
            assert response.status_code == 201
            data = response.json()
            assert data["task_id"] == "task_api_001"
            assert data["task_name"] == "API测试任务"
            assert data["task_type"] == "on_demand"
            assert data["status"] == "pending"

    def test_create_task_with_auto_binding(self, client):
        """A task can be created without entity ids (entities bound automatically)."""
        mock_task = self._make_mock_task(
            task_id="task_auto_001",
            task_name="自动绑定任务",
            task_type=TaskType.ON_DEMAND,
            status=TaskStatus.PENDING,
            entity_type="streamer",
            period="2024-01",
            total_entities=1,
            processed_entities=0,
            result_count=0,
        )
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            MockTaskManager.return_value.create_task = AsyncMock(return_value=mock_task)
            # No entity_ids / entity_type in the payload: exercises auto-binding.
            request_data = {
                "task_name": "自动绑定任务",
                "task_type": "on_demand",
                "period": "2024-01",
                "rule_ids": ["REVENUE_INTEGRITY_CHECK"],
            }
            response = client.post("/api/v1/risk-detection/tasks", json=request_data)
            assert response.status_code == 201
            data = response.json()
            assert data["task_id"] == "task_auto_001"

    def test_execute_task_api(self, client):
        """Executing an existing task returns its status and result count."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            MockTaskManager.return_value.execute_task = AsyncMock(return_value={
                "task_id": "task_exec_001",
                "status": "completed",
                "summary": {
                    "total_entities": 1,
                    "total_detections": 2,
                    "risk_distribution": {"HIGH": 1, "MEDIUM": 1},
                    "avg_score": 75.0,
                },
                "results": [],
                "executed_at": datetime.now().isoformat(),
            })
            response = client.post("/api/v1/risk-detection/tasks/task_exec_001/execute")
            assert response.status_code == 200
            data = response.json()
            assert data["task_id"] == "task_exec_001"
            assert data["status"] == "completed"
            assert data["result_count"] == 0

    def test_list_tasks_api(self, client):
        """Listing tasks works without filters and with status / type filters."""
        mock_tasks = [
            self._make_mock_task(
                task_id=f"task_list_{i:03d}",
                task_name=f"任务{i+1}",
                task_type=TaskType.ON_DEMAND if i % 2 == 0 else TaskType.PERIODIC,
                status=TaskStatus.COMPLETED if i > 0 else TaskStatus.PENDING,
                entity_type="streamer",
                period="2024-01",
                total_entities=1,
                processed_entities=1 if i > 0 else 0,
                result_count=i,
                started_at=datetime.now(),
                completed_at=datetime.now() if i > 0 else None,
            )
            for i in range(3)
        ]
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            manager = MockTaskManager.return_value
            manager.list_tasks = AsyncMock(return_value=mock_tasks)

            # No filters.
            response = client.get("/api/v1/risk-detection/tasks")
            assert response.status_code == 200
            assert len(response.json()) == 3

            # The mocked list_tasks ignores its arguments, so filtering (presumably
            # done inside TaskManager) must be simulated by returning the subset a
            # real filter would produce before each filtered request.
            manager.list_tasks.return_value = [
                t for t in mock_tasks if t.status == TaskStatus.PENDING
            ]
            response = client.get("/api/v1/risk-detection/tasks?status=pending")
            assert response.status_code == 200
            assert len(response.json()) == 1

            manager.list_tasks.return_value = [
                t for t in mock_tasks if t.task_type == TaskType.ON_DEMAND
            ]
            response = client.get("/api/v1/risk-detection/tasks?task_type=on_demand")
            assert response.status_code == 200
            assert len(response.json()) == 2

    def test_get_task_detail_api(self, client):
        """Fetching a task by id returns its full detail including the summary."""
        mock_task = self._make_mock_task(
            task_id="task_detail_001",
            task_name="任务详情测试",
            task_type=TaskType.ON_DEMAND,
            status=TaskStatus.COMPLETED,
            entity_type="streamer",
            period="2024-01",
            total_entities=5,
            processed_entities=5,
            result_count=8,
            summary={
                "total_entities": 5,
                "total_detections": 8,
                "risk_distribution": {"HIGH": 2, "MEDIUM": 3, "LOW": 3},
                "avg_score": 60.0,
            },
            parameters={"test": "value"},
            started_at=datetime.now(),
            completed_at=datetime.now(),
        )
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            MockTaskManager.return_value.get_task = AsyncMock(return_value=mock_task)
            response = client.get("/api/v1/risk-detection/tasks/task_detail_001")
            assert response.status_code == 200
            data = response.json()
            assert data["task_id"] == "task_detail_001"
            assert data["task_name"] == "任务详情测试"
            assert data["status"] == "completed"
            assert data["total_entities"] == 5
            assert data["processed_entities"] == 5
            assert data["result_count"] == 8
            assert "summary" in data

    def test_get_task_not_found_api(self, client):
        """Fetching an unknown task id returns 404."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            MockTaskManager.return_value.get_task = AsyncMock(return_value=None)
            response = client.get("/api/v1/risk-detection/tasks/task_not_exist")
            assert response.status_code == 404
            data = response.json()
            assert "任务不存在" in data["detail"]

    def test_execute_detection_api(self, client):
        """The instant-detection endpoint creates and executes a task in one call."""
        created_task = self._make_mock_task(
            task_id="task_detection_001",
            task_name="即时检测任务",
            task_type=TaskType.ON_DEMAND,
            status=TaskStatus.PENDING,
            entity_type="streamer",
            period="2024-01",
            total_entities=1,
            processed_entities=0,
            result_count=0,
        )
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            manager = MockTaskManager.return_value
            # create_task is async (see test_create_task_api); a plain MagicMock
            # would not be awaitable.
            manager.create_task = AsyncMock(return_value=created_task)
            manager.execute_task = AsyncMock(return_value={
                "task_id": "task_detection_001",
                "status": "completed",
                "summary": {
                    "total_entities": 1,
                    "total_detections": 1,
                    "risk_distribution": {"HIGH": 1},
                    "avg_score": 85.0,
                },
                "results": [],
                "executed_at": datetime.now().isoformat(),
            })
            request_data = {
                "task_name": "即时检测任务",
                "entity_id": "ZB_TEST_001",
                "entity_type": "streamer",
                "period": "2024-01",
                "rule_ids": ["REVENUE_INTEGRITY_CHECK", "PRIVATE_ACCOUNT_DETECTION"],
                "parameters": {},
            }
            response = client.post("/api/v1/risk-detection/execute", json=request_data)
            assert response.status_code == 200
            data = response.json()
            assert data["task_id"] == "task_detection_001"
            assert data["status"] == "completed"
            assert data["result_count"] == 0

    # NOTE(review): the result/summary tests below patch ``select`` and attach an
    # ``execute`` mock to the statement chain. If the endpoint runs queries via an
    # injected session (``await db.execute(stmt)``), these mocks are never
    # consulted -- confirm against the endpoint implementation.

    def test_list_results_api(self, client):
        """Listing detection results returns every stored result."""
        with patch('app.api.v1.endpoints.risk_detection.DetectionResult') as MockDetectionResult:
            levels = [RiskLevel.HIGH, RiskLevel.MEDIUM, RiskLevel.LOW]
            scores = [85.0, 65.0, 30.0]
            mock_results = []
            for i in range(3):
                result = MagicMock()
                result.id = i + 1
                result.task_id = f"task_result_{i%2+1:03d}"
                result.rule_id = f"rule_{i}"
                result.entity_id = f"ZB_TEST_{i%2+1:03d}"
                result.entity_type = "streamer"
                result.risk_level = levels[i]
                result.risk_score = scores[i]
                result.description = f"风险描述{i+1}"
                result.suggestion = f"处理建议{i+1}"
                result.risk_data = {"test": f"data_{i}"}
                result.evidence = []
                result.detected_at = datetime.now()
                mock_results.append(result)
            MockDetectionResult.__table__.columns = MagicMock()
            MockDetectionResult.task_id = MagicMock()
            MockDetectionResult.entity_id = MagicMock()
            MockDetectionResult.entity_type = MagicMock()
            MockDetectionResult.risk_level = MagicMock()
            with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
                mock_execute = AsyncMock()
                mock_execute.scalars.return_value.all.return_value = mock_results
                MockSelect.return_value.where.return_value.order_by.return_value.limit.return_value.execute = mock_execute
                response = client.get("/api/v1/risk-detection/results")
                assert response.status_code == 200
                data = response.json()
                assert len(data) == 3

    def test_get_result_detail_api(self, client):
        """Fetching a result by id returns its detail with nested risk info."""
        with patch('app.api.v1.endpoints.risk_detection.DetectionResult') as MockDetectionResult:
            result = MagicMock()
            result.id = 1
            result.task_id = "task_001"
            result.rule_id = "rule_revenue_001"
            result.entity_id = "ZB_TEST_001"
            result.entity_type = "streamer"
            result.risk_level = RiskLevel.HIGH
            result.risk_score = 85.0
            result.description = "高风险问题"
            result.suggestion = "立即处理"
            result.risk_data = {"key": "value"}
            result.evidence = [{"type": "test", "description": "test"}]
            result.detected_at = datetime.now()
            MockDetectionResult.__table__.columns = MagicMock()
            with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
                mock_execute = AsyncMock()
                mock_execute.scalar_one_or_none.return_value = result
                MockSelect.return_value.where.return_value.execute = mock_execute
                response = client.get("/api/v1/risk-detection/results/1")
                assert response.status_code == 200
                data = response.json()
                assert data["id"] == 1
                assert data["task_id"] == "task_001"
                assert data["result"]["risk_level"] == "HIGH"
                assert data["result"]["risk_score"] == 85.0

    def test_get_result_not_found_api(self, client):
        """Fetching an unknown result id returns 404."""
        with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
            mock_execute = AsyncMock()
            mock_execute.scalar_one_or_none.return_value = None
            MockSelect.return_value.where.return_value.execute = mock_execute
            response = client.get("/api/v1/risk-detection/results/9999")
            assert response.status_code == 404
            data = response.json()
            assert "结果不存在" in data["detail"]

    def test_get_detection_summary_api(self, client):
        """The summary aggregates counts, average score and recommendations."""
        # 2 HIGH (85.0), 2 MEDIUM (65.0), 1 LOW (30.0).
        mock_results = []
        for i in range(5):
            result = MagicMock()
            result.risk_level = RiskLevel.HIGH if i < 2 else (RiskLevel.MEDIUM if i < 4 else RiskLevel.LOW)
            result.risk_score = 85.0 if i < 2 else (65.0 if i < 4 else 30.0)
            mock_results.append(result)
        with patch('app.api.v1.endpoints.risk_detection.DetectionResult') as MockDetectionResult:
            MockDetectionResult.__table__.columns = MagicMock()
            with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
                mock_execute = AsyncMock()
                mock_execute.scalars.return_value.all.return_value = mock_results
                MockSelect.return_value.where.return_value.execute = mock_execute
                response = client.get(
                    "/api/v1/risk-detection/summary"
                    "?entity_id=ZB_TEST_001"
                    "&entity_type=streamer"
                    "&period=2024-01"
                )
                assert response.status_code == 200
                data = response.json()
                assert data["entity_id"] == "ZB_TEST_001"
                assert data["entity_type"] == "streamer"
                assert data["period"] == "2024-01"
                assert data["total_detections"] == 5
                assert data["high_risk_count"] == 2
                # (85*2 + 65*2 + 30*1) / 5 = 330 / 5 = 66.0
                assert data["avg_risk_score"] == pytest.approx(66.0)
                assert len(data["recommendations"]) > 0

    def test_get_detection_summary_empty(self, client):
        """The summary for an entity with no detections is all zeros."""
        with patch('app.api.v1.endpoints.risk_detection.DetectionResult') as MockDetectionResult:
            MockDetectionResult.__table__.columns = MagicMock()
            with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
                mock_execute = AsyncMock()
                mock_execute.scalars.return_value.all.return_value = []
                MockSelect.return_value.where.return_value.execute = mock_execute
                response = client.get(
                    "/api/v1/risk-detection/summary"
                    "?entity_id=ZB_TEST_999"
                    "&entity_type=streamer"
                    "&period=2024-01"
                )
                assert response.status_code == 200
                data = response.json()
                assert data["total_detections"] == 0
                assert data["risk_distribution"] == {}
                assert data["avg_risk_score"] == 0.0
                assert data["high_risk_count"] == 0

    def test_get_algorithms_api(self, client):
        """The algorithm list contains the five built-in detection algorithms."""
        response = client.get("/api/v1/risk-detection/algorithms")
        assert response.status_code == 200
        data = response.json()
        assert len(data) >= 5
        # Every entry exposes the same descriptive fields.
        for algo in data:
            assert "code" in algo
            assert "name" in algo
            assert "description" in algo
            assert "parameters" in algo
        algo_codes = {algo["code"] for algo in data}
        for expected in (
            "REVENUE_INTEGRITY_CHECK",
            "PRIVATE_ACCOUNT_DETECTION",
            "INVOICE_FRAUD_DETECTION",
            "EXPENSE_ANOMALY_DETECTION",
            "TAX_RISK_ASSESSMENT",
        ):
            assert expected in algo_codes

    def test_get_rules_api(self, client):
        """The rule list is a list of rule descriptors."""
        response = client.get("/api/v1/risk-detection/rules")
        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        assert len(data) >= 5
        for rule in data:
            assert "rule_id" in rule
            assert "rule_name" in rule
            assert "algorithm_code" in rule
            assert "is_enabled" in rule

    def test_create_task_validation_error(self, client):
        """A payload missing required fields is rejected with 422."""
        # task_type, period, etc. are deliberately omitted.
        invalid_data = {
            "task_name": "测试任务",
        }
        response = client.post("/api/v1/risk-detection/tasks", json=invalid_data)
        assert response.status_code == 422

    def test_execute_task_not_found(self, client):
        """An exception from TaskManager.execute_task surfaces as a 500."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            MockTaskManager.return_value.execute_task = AsyncMock(
                side_effect=Exception("任务不存在")
            )
            response = client.post("/api/v1/risk-detection/tasks/task_not_exist/execute")
            assert response.status_code == 500
            data = response.json()
            assert "执行任务失败" in data["detail"]

    def test_list_results_with_filters(self, client):
        """The result list accepts task, entity and risk-level filters."""
        with patch('app.api.v1.endpoints.risk_detection.DetectionResult') as MockDetectionResult:
            MockDetectionResult.__table__.columns = MagicMock()
            with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
                mock_execute = AsyncMock()
                mock_execute.scalars.return_value.all.return_value = []
                MockSelect.return_value.where.return_value.order_by.return_value.limit.return_value.execute = mock_execute
                # Filter by task id with an explicit limit.
                response = client.get(
                    "/api/v1/risk-detection/results"
                    "?task_id=task_filter_001"
                    "&limit=50"
                )
                assert response.status_code == 200
                # Filter by entity.
                response = client.get(
                    "/api/v1/risk-detection/results"
                    "?entity_id=ZB_TEST_001"
                    "&entity_type=streamer"
                )
                assert response.status_code == 200
                # Filter by risk level.
                response = client.get(
                    "/api/v1/risk-detection/results"
                    "?risk_level=HIGH"
                )
                assert response.status_code == 200
if __name__ == "__main__":
    # Propagate pytest's exit code so `python <this file>` fails when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))