"""Test cases for the risk-detection API endpoints.

Covered scenarios:
1. Task creation API
2. Task execution API
3. Task query API
4. Result query API
5. Summary report API
6. Algorithm / rule list API
"""
import pytest
from datetime import datetime, date
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient

from app.main import app
from app.models.risk_detection import TaskType, TaskStatus, RiskLevel


def _make_mock_task(*, task_id, task_name, **overrides):
    """Return a MagicMock shaped like a task record produced by TaskManager.

    The defaults describe a freshly created, not-yet-started on-demand task
    for one streamer in period 2024-01. Every keyword in ``overrides``
    replaces (or adds) an attribute. Attributes not set here (e.g.
    ``summary``, ``error_message``) fall back to MagicMock's auto-created
    children, exactly as a bare ``MagicMock()`` would.
    """
    attrs = {
        "task_id": task_id,
        "task_name": task_name,
        "task_type": TaskType.ON_DEMAND,
        "status": TaskStatus.PENDING,
        "entity_type": "streamer",
        "period": "2024-01",
        "total_entities": 1,
        "processed_entities": 0,
        "result_count": 0,
        "parameters": {},
        "created_at": datetime.now(),
        "started_at": None,
        "completed_at": None,
    }
    attrs.update(overrides)
    task = MagicMock()
    # setattr (rather than MagicMock(**attrs)) avoids any clash with Mock's
    # own constructor keywords.
    for attr_name, value in attrs.items():
        setattr(task, attr_name, value)
    return task


def _enum_value(value):
    """Return the ``.value`` of an enum member, or ``value`` itself otherwise."""
    return getattr(value, "value", value)


def _filtering_list_tasks(tasks):
    """Build a synchronous ``list_tasks`` side effect that applies filters.

    The returned callable accepts any arguments but only inspects the
    ``status`` and ``task_type`` keywords. ``None`` (or an absent keyword)
    means "no filter". Filter values may be enum members or their raw
    string values.
    """
    def _list_tasks(*args, status=None, task_type=None, **kwargs):
        selected = list(tasks)
        if status is not None:
            wanted = _enum_value(status)
            selected = [t for t in selected if _enum_value(t.status) == wanted]
        if task_type is not None:
            wanted = _enum_value(task_type)
            selected = [t for t in selected if _enum_value(t.task_type) == wanted]
        return selected

    return _list_tasks


def _mock_execute(*, rows=(), row=None):
    """Return an AsyncMock standing in for an awaited ``execute()`` call.

    Awaiting the mock yields a result object whose ``scalars().all()``
    returns ``list(rows)`` and whose ``scalar_one_or_none()`` returns
    ``row``. The values must live on the awaited result (the AsyncMock's
    ``return_value``); configuring them on the AsyncMock itself has no
    effect on code that does ``result = await ...execute()``.
    """
    result = MagicMock()
    result.scalars.return_value.all.return_value = list(rows)
    result.scalar_one_or_none.return_value = row
    return AsyncMock(return_value=result)


class TestRiskDetectionAPI:
    """Endpoint-level tests for the risk-detection API."""

    @pytest.fixture
    def client(self):
        """Create a test client for the FastAPI application."""
        return TestClient(app)

    @pytest.fixture
    def mock_task_data(self):
        """Request payload for creating an on-demand task."""
        return {
            "task_name": "API测试任务",
            "task_type": "on_demand",
            "entity_ids": ["ZB_TEST_001"],
            "entity_type": "streamer",
            "period": "2024-01",
            "rule_ids": ["REVENUE_INTEGRITY_CHECK"],
            "parameters": {"description": "API测试"},
        }

    def test_create_task_api(self, client, mock_task_data):
        """Creating a task returns 201 and echoes the created task."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            mock_task = _make_mock_task(
                task_id="task_api_001",
                task_name="API测试任务",
                parameters={"description": "API测试"},
            )
            MockTaskManager.return_value.create_task = AsyncMock(return_value=mock_task)

            response = client.post("/api/v1/risk-detection/tasks", json=mock_task_data)

            assert response.status_code == 201
            data = response.json()
            assert data["task_id"] == "task_api_001"
            assert data["task_name"] == "API测试任务"
            assert data["task_type"] == "on_demand"
            assert data["status"] == "pending"

    def test_create_task_with_auto_binding(self, client, mock_task_data):
        """Creating a task without entity info relies on automatic entity binding."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            mock_task = _make_mock_task(task_id="task_auto_001", task_name="自动绑定任务")
            MockTaskManager.return_value.create_task = AsyncMock(return_value=mock_task)

            # No entity_ids / entity_type supplied: exercises automatic binding.
            request_data = {
                "task_name": "自动绑定任务",
                "task_type": "on_demand",
                "period": "2024-01",
                "rule_ids": ["REVENUE_INTEGRITY_CHECK"],
            }

            response = client.post("/api/v1/risk-detection/tasks", json=request_data)

            assert response.status_code == 201
            data = response.json()
            assert data["task_id"] == "task_auto_001"

    def test_execute_task_api(self, client):
        """Executing an existing task returns its completion summary."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            MockTaskManager.return_value.execute_task = AsyncMock(return_value={
                "task_id": "task_exec_001",
                "status": "completed",
                "summary": {
                    "total_entities": 1,
                    "total_detections": 2,
                    "risk_distribution": {"HIGH": 1, "MEDIUM": 1},
                    "avg_score": 75.0,
                },
                "results": [],
                "executed_at": datetime.now().isoformat(),
            })

            response = client.post("/api/v1/risk-detection/tasks/task_exec_001/execute")

            assert response.status_code == 200
            data = response.json()
            assert data["task_id"] == "task_exec_001"
            assert data["status"] == "completed"
            assert data["result_count"] == 0

    def test_list_tasks_api(self, client):
        """Listing tasks supports unfiltered, status-filtered and type-filtered queries."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            # Task 0: on_demand/pending, task 1: periodic/completed,
            # task 2: on_demand/completed.
            mock_tasks = [
                _make_mock_task(
                    task_id=f"task_list_{i:03d}",
                    task_name=f"任务{i+1}",
                    task_type=TaskType.ON_DEMAND if i % 2 == 0 else TaskType.PERIODIC,
                    status=TaskStatus.COMPLETED if i > 0 else TaskStatus.PENDING,
                    processed_entities=1 if i > 0 else 0,
                    result_count=i,
                    started_at=datetime.now(),
                    completed_at=datetime.now() if i > 0 else None,
                )
                for i in range(3)
            ]
            # Honour status/task_type filters so the filtered counts below hold
            # whether the endpoint filters in memory or delegates to list_tasks.
            MockTaskManager.return_value.list_tasks = AsyncMock(
                side_effect=_filtering_list_tasks(mock_tasks)
            )

            # No filter.
            response = client.get("/api/v1/risk-detection/tasks")
            assert response.status_code == 200
            data = response.json()
            assert len(data) == 3

            # Filter by status.
            response = client.get("/api/v1/risk-detection/tasks?status=pending")
            assert response.status_code == 200
            data = response.json()
            assert len(data) == 1

            # Filter by task type.
            response = client.get("/api/v1/risk-detection/tasks?task_type=on_demand")
            assert response.status_code == 200
            data = response.json()
            assert len(data) == 2

    def test_get_task_detail_api(self, client):
        """Fetching a task returns its progress counters and summary."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            mock_task = _make_mock_task(
                task_id="task_detail_001",
                task_name="任务详情测试",
                status=TaskStatus.COMPLETED,
                total_entities=5,
                processed_entities=5,
                result_count=8,
                summary={
                    "total_entities": 5,
                    "total_detections": 8,
                    "risk_distribution": {"HIGH": 2, "MEDIUM": 3, "LOW": 3},
                    "avg_score": 60.0,
                },
                parameters={"test": "value"},
                error_message=None,
                started_at=datetime.now(),
                completed_at=datetime.now(),
            )
            MockTaskManager.return_value.get_task = AsyncMock(return_value=mock_task)

            response = client.get("/api/v1/risk-detection/tasks/task_detail_001")

            assert response.status_code == 200
            data = response.json()
            assert data["task_id"] == "task_detail_001"
            assert data["task_name"] == "任务详情测试"
            assert data["status"] == "completed"
            assert data["total_entities"] == 5
            assert data["processed_entities"] == 5
            assert data["result_count"] == 8
            assert "summary" in data

    def test_get_task_not_found_api(self, client):
        """Fetching an unknown task returns 404."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            MockTaskManager.return_value.get_task = AsyncMock(return_value=None)

            response = client.get("/api/v1/risk-detection/tasks/task_not_exist")

            assert response.status_code == 404
            data = response.json()
            assert "任务不存在" in data["detail"]

    def test_execute_detection_api(self, client):
        """Ad-hoc detection creates a task and executes it immediately."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            # create_task is awaited (see the other tests), so it must be an
            # AsyncMock; a plain MagicMock result is not awaitable.
            created_task = _make_mock_task(
                task_id="task_detection_001",
                task_name="即时检测任务",
            )
            MockTaskManager.return_value.create_task = AsyncMock(return_value=created_task)
            MockTaskManager.return_value.execute_task = AsyncMock(return_value={
                "task_id": "task_detection_001",
                "status": "completed",
                "summary": {
                    "total_entities": 1,
                    "total_detections": 1,
                    "risk_distribution": {"HIGH": 1},
                    "avg_score": 85.0,
                },
                "results": [],
                "executed_at": datetime.now().isoformat(),
            })

            request_data = {
                "task_name": "即时检测任务",
                "entity_id": "ZB_TEST_001",
                "entity_type": "streamer",
                "period": "2024-01",
                "rule_ids": ["REVENUE_INTEGRITY_CHECK", "PRIVATE_ACCOUNT_DETECTION"],
                "parameters": {},
            }

            response = client.post("/api/v1/risk-detection/execute", json=request_data)

            assert response.status_code == 200
            data = response.json()
            assert data["task_id"] == "task_detection_001"
            assert data["status"] == "completed"
            assert data["result_count"] == 0

    def test_list_results_api(self, client):
        """Listing detection results returns every stored result."""
        with patch('app.api.v1.endpoints.risk_detection.DetectionResult') as MockDetectionResult:
            level_and_score = [
                (RiskLevel.HIGH, 85.0),
                (RiskLevel.MEDIUM, 65.0),
                (RiskLevel.LOW, 30.0),
            ]
            mock_results = []
            for i, (level, score) in enumerate(level_and_score):
                result = MagicMock()
                result.id = i + 1
                result.task_id = f"task_result_{i%2+1:03d}"
                result.rule_id = f"rule_{i}"
                result.entity_id = f"ZB_TEST_{i%2+1:03d}"
                result.entity_type = "streamer"
                result.risk_level = level
                result.risk_score = score
                result.description = f"风险描述{i+1}"
                result.suggestion = f"处理建议{i+1}"
                result.risk_data = {"test": f"data_{i}"}
                result.evidence = []
                result.detected_at = datetime.now()
                mock_results.append(result)

            MockDetectionResult.__table__.columns = MagicMock()
            MockDetectionResult.task_id = MagicMock()
            MockDetectionResult.entity_id = MagicMock()
            MockDetectionResult.entity_type = MagicMock()
            MockDetectionResult.risk_level = MagicMock()

            with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
                query = MockSelect.return_value.where.return_value.order_by.return_value.limit.return_value
                query.execute = _mock_execute(rows=mock_results)

                response = client.get("/api/v1/risk-detection/results")

                assert response.status_code == 200
                data = response.json()
                assert len(data) == 3

    def test_get_result_detail_api(self, client):
        """Fetching a single result returns it with the nested risk fields."""
        with patch('app.api.v1.endpoints.risk_detection.DetectionResult') as MockDetectionResult:
            result = MagicMock()
            result.id = 1
            result.task_id = "task_001"
            result.rule_id = "rule_revenue_001"
            result.entity_id = "ZB_TEST_001"
            result.entity_type = "streamer"
            result.risk_level = RiskLevel.HIGH
            result.risk_score = 85.0
            result.description = "高风险问题"
            result.suggestion = "立即处理"
            result.risk_data = {"key": "value"}
            result.evidence = [{"type": "test", "description": "test"}]
            result.detected_at = datetime.now()

            MockDetectionResult.__table__.columns = MagicMock()

            with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
                MockSelect.return_value.where.return_value.execute = _mock_execute(row=result)

                response = client.get("/api/v1/risk-detection/results/1")

                assert response.status_code == 200
                data = response.json()
                assert data["id"] == 1
                assert data["task_id"] == "task_001"
                assert data["result"]["risk_level"] == "HIGH"
                assert data["result"]["risk_score"] == 85.0

    def test_get_result_not_found_api(self, client):
        """Fetching an unknown result returns 404."""
        with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
            MockSelect.return_value.where.return_value.execute = _mock_execute(row=None)

            response = client.get("/api/v1/risk-detection/results/9999")

            assert response.status_code == 404
            data = response.json()
            assert "结果不存在" in data["detail"]

    def test_get_detection_summary_api(self, client):
        """The summary aggregates counts, the average score and recommendations."""
        # Two HIGH (85), two MEDIUM (65), one LOW (30).
        mock_results = []
        for i in range(5):
            result = MagicMock()
            result.risk_level = RiskLevel.HIGH if i < 2 else (RiskLevel.MEDIUM if i < 4 else RiskLevel.LOW)
            result.risk_score = 85.0 if i < 2 else (65.0 if i < 4 else 30.0)
            mock_results.append(result)

        with patch('app.api.v1.endpoints.risk_detection.DetectionResult') as MockDetectionResult:
            MockDetectionResult.__table__.columns = MagicMock()

            with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
                MockSelect.return_value.where.return_value.execute = _mock_execute(rows=mock_results)

                response = client.get(
                    "/api/v1/risk-detection/summary"
                    "?entity_id=ZB_TEST_001"
                    "&entity_type=streamer"
                    "&period=2024-01"
                )

                assert response.status_code == 200
                data = response.json()
                assert data["entity_id"] == "ZB_TEST_001"
                assert data["entity_type"] == "streamer"
                assert data["period"] == "2024-01"
                assert data["total_detections"] == 5
                assert data["high_risk_count"] == 2
                # (85*2 + 65*2 + 30*1) / 5 = 330 / 5 = 66.0
                assert data["avg_risk_score"] == pytest.approx(66.0)
                assert len(data["recommendations"]) > 0

    def test_get_detection_summary_empty(self, client):
        """With no results, the summary reports zeros and an empty distribution."""
        with patch('app.api.v1.endpoints.risk_detection.DetectionResult') as MockDetectionResult:
            MockDetectionResult.__table__.columns = MagicMock()

            with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
                MockSelect.return_value.where.return_value.execute = _mock_execute(rows=[])

                response = client.get(
                    "/api/v1/risk-detection/summary"
                    "?entity_id=ZB_TEST_999"
                    "&entity_type=streamer"
                    "&period=2024-01"
                )

                assert response.status_code == 200
                data = response.json()
                assert data["total_detections"] == 0
                assert data["risk_distribution"] == {}
                assert data["avg_risk_score"] == 0.0
                assert data["high_risk_count"] == 0

    def test_get_algorithms_api(self, client):
        """The algorithm list exposes at least the five built-in algorithms."""
        response = client.get("/api/v1/risk-detection/algorithms")

        assert response.status_code == 200
        data = response.json()
        assert len(data) >= 5

        # Every entry carries the expected fields.
        for algo in data:
            assert "code" in algo
            assert "name" in algo
            assert "description" in algo
            assert "parameters" in algo

        # The built-in algorithms are all present.
        algo_codes = [algo["code"] for algo in data]
        assert "REVENUE_INTEGRITY_CHECK" in algo_codes
        assert "PRIVATE_ACCOUNT_DETECTION" in algo_codes
        assert "INVOICE_FRAUD_DETECTION" in algo_codes
        assert "EXPENSE_ANOMALY_DETECTION" in algo_codes
        assert "TAX_RISK_ASSESSMENT" in algo_codes

    def test_get_rules_api(self, client):
        """The rule list is a list of at least five well-formed rules."""
        response = client.get("/api/v1/risk-detection/rules")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        assert len(data) >= 5

        # Every rule carries the expected fields.
        for rule in data:
            assert "rule_id" in rule
            assert "rule_name" in rule
            assert "algorithm_code" in rule
            assert "is_enabled" in rule

    def test_create_task_validation_error(self, client):
        """A payload missing required fields is rejected by request validation."""
        invalid_data = {
            "task_name": "测试任务"
            # task_type, period, etc. are deliberately missing
        }

        response = client.post("/api/v1/risk-detection/tasks", json=invalid_data)

        # 422 signals a request-validation error.
        assert response.status_code == 422

    def test_execute_task_not_found(self, client):
        """An exception raised during execution surfaces as a 500 error."""
        with patch('app.api.v1.endpoints.risk_detection.TaskManager') as MockTaskManager:
            # A generic Exception (not a dedicated not-found error) is raised,
            # so the endpoint's catch-all error path is what is exercised here.
            MockTaskManager.return_value.execute_task = AsyncMock(
                side_effect=Exception("任务不存在")
            )

            response = client.post("/api/v1/risk-detection/tasks/task_not_exist/execute")

            assert response.status_code == 500
            data = response.json()
            assert "执行任务失败" in data["detail"]

    def test_list_results_with_filters(self, client):
        """Result listing accepts task, entity and risk-level filters."""
        with patch('app.api.v1.endpoints.risk_detection.DetectionResult') as MockDetectionResult:
            MockDetectionResult.__table__.columns = MagicMock()

            with patch('app.api.v1.endpoints.risk_detection.select') as MockSelect:
                query = MockSelect.return_value.where.return_value.order_by.return_value.limit.return_value
                query.execute = _mock_execute(rows=[])

                # Filter by task ID.
                response = client.get(
                    "/api/v1/risk-detection/results"
                    "?task_id=task_filter_001"
                    "&limit=50"
                )
                assert response.status_code == 200

                # Filter by entity.
                response = client.get(
                    "/api/v1/risk-detection/results"
                    "?entity_id=ZB_TEST_001"
                    "&entity_type=streamer"
                )
                assert response.status_code == 200

                # Filter by risk level.
                response = client.get(
                    "/api/v1/risk-detection/results"
                    "?risk_level=HIGH"
                )
                assert response.status_code == 200


if __name__ == "__main__":
    pytest.main([__file__, "-v"])