""" 任务管理器测试用例 覆盖场景: 1. 创建检测任务(按需、定期、批量) 2. 查询任务列表 3. 执行检测任务 4. 任务状态管理 5. 错误处理和重试机制 """ import pytest from datetime import datetime, date, timedelta from unittest.mock import AsyncMock, MagicMock, patch from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, and_ from app.services.risk_detection.task_manager import TaskManager from app.services.risk_detection.task_manager.scheduler import DetectionScheduler from app.models.risk_detection import ( DetectionTask, DetectionRule, TaskType, TaskStatus, RiskLevel, ) from app.models.risk_detection import TaskExecution class TestTaskManager: """任务管理器测试类""" @pytest.fixture def mock_db_session(self): """创建Mock数据库会话""" session = AsyncMock(spec=AsyncSession) session.execute = AsyncMock() session.flush = AsyncMock() session.refresh = AsyncMock() session.add = MagicMock() session.commit = AsyncMock() return session @pytest.fixture def task_manager(self, mock_db_session): """创建任务管理器实例""" return TaskManager(mock_db_session) @pytest.fixture def mock_detection_rules(self): """Mock检测规则数据""" return [ DetectionRule( rule_id="rule_revenue_001", rule_name="收入完整性检测", algorithm_code="REVENUE_INTEGRITY_CHECK", description="检测平台充值与申报收入的匹配度", is_enabled=True, parameters={"threshold": 0.1}, created_at=datetime.now(), ), DetectionRule( rule_id="rule_private_001", rule_name="私户收款检测", algorithm_code="PRIVATE_ACCOUNT_DETECTION", description="识别使用私人账户收款的风险", is_enabled=True, parameters={"threshold_amount": 10000}, created_at=datetime.now(), ), ] @pytest.mark.asyncio async def test_create_task_on_demand(self, task_manager, mock_db_session, mock_detection_rules): """测试创建按需检测任务""" # 设置mock返回值 mock_db_session.execute.return_value.scalar_one_or_none.return_value = None mock_db_session.execute.return_value.scalars.return_value.all.return_value = mock_detection_rules # 创建任务 task = await task_manager.create_task( task_name="测试按需检测任务", task_type=TaskType.ON_DEMAND, entity_ids=["ZB_TEST_001"], entity_type="streamer", period="2024-01", rule_ids=["rule_revenue_001"], parameters={"description": "测试任务"} ) # 验证结果 assert task.task_id is not None assert task.task_name == "测试按需检测任务" assert task.task_type == TaskType.ON_DEMAND assert task.status == TaskStatus.PENDING assert task.entity_type == "streamer" assert task.period == "2024-01" assert task.total_entities == 1 assert task.parameters["description"] == "测试任务" # 验证执行记录被创建 mock_db_session.add.assert_called() @pytest.mark.asyncio async def test_create_task_periodic(self, task_manager, mock_db_session, mock_detection_rules): """测试创建定期检测任务""" # 设置mock返回值 mock_db_session.execute.return_value.scalar_one_or_none.return_value = None mock_db_session.execute.return_value.scalars.return_value.all.return_value = mock_detection_rules # 创建任务 task = await task_manager.create_task( task_name="测试定期检测任务", task_type=TaskType.PERIODIC, entity_ids=["ZB_TEST_001", "ZB_TEST_002"], entity_type="streamer", period="2024-01", parameters={"schedule_type": "periodic"} ) # 验证结果 assert task.task_id is not None assert task.task_name == "测试定期检测任务" assert task.task_type == TaskType.PERIODIC assert task.total_entities == 2 @pytest.mark.asyncio async def test_create_task_batch(self, task_manager, mock_db_session, mock_detection_rules): """测试创建批量检测任务""" # 设置mock返回值 mock_db_session.execute.return_value.scalar_one_or_none.return_value = None mock_db_session.execute.return_value.scalars.return_value.all.return_value = mock_detection_rules # 创建任务 task = await task_manager.create_task( task_name="测试批量检测任务", task_type=TaskType.BATCH, entity_ids=["ZB_TEST_001", "ZB_TEST_002", "ZB_TEST_003"], entity_type="mcn", period="2024-01", rule_ids=["rule_private_001"], parameters={"batch_size": 3} ) # 验证结果 assert task.task_id is not None assert task.task_name == "测试批量检测任务" assert task.task_type == TaskType.BATCH assert task.total_entities == 3 @pytest.mark.asyncio async def test_get_task(self, task_manager, mock_db_session): """测试获取任务""" # 创建mock任务 mock_task = DetectionTask( task_id="task_test_001", task_name="测试任务", task_type=TaskType.ON_DEMAND, status=TaskStatus.PENDING, entity_type="streamer", period="2024-01", total_entities=1, processed_entities=0, parameters={}, created_at=datetime.now(), ) # 设置mock返回值 mock_db_session.execute.return_value.scalar_one_or_none.return_value = mock_task # 执行测试 task = await task_manager.get_task("task_test_001") # 验证结果 assert task is not None assert task.task_id == "task_test_001" assert task.task_name == "测试任务" assert task.status == TaskStatus.PENDING @pytest.mark.asyncio async def test_get_task_not_found(self, task_manager, mock_db_session): """测试获取不存在的任务""" # 设置mock返回值为None mock_db_session.execute.return_value.scalar_one_or_none.return_value = None # 执行测试 task = await task_manager.get_task("task_not_exist") # 验证结果 assert task is None @pytest.mark.asyncio async def test_list_tasks(self, task_manager, mock_db_session): """测试查询任务列表""" # 创建mock任务列表 mock_tasks = [ DetectionTask( task_id="task_001", task_name="任务1", task_type=TaskType.ON_DEMAND, status=TaskStatus.PENDING, entity_type="streamer", period="2024-01", total_entities=1, processed_entities=0, parameters={}, created_at=datetime.now(), ), DetectionTask( task_id="task_002", task_name="任务2", task_type=TaskType.PERIODIC, status=TaskStatus.COMPLETED, entity_type="mcn", period="2024-01", total_entities=5, processed_entities=5, parameters={}, created_at=datetime.now(), ), ] # 设置mock返回值 mock_db_session.execute.return_value.scalars.return_value.all.return_value = mock_tasks # 执行测试 - 无过滤条件 tasks = await task_manager.list_tasks() assert len(tasks) == 2 # 执行测试 - 按状态过滤 tasks = await task_manager.list_tasks(status=TaskStatus.PENDING) assert len(tasks) == 2 # 实际项目中这里应该过滤 # 执行测试 - 按类型过滤 tasks = await task_manager.list_tasks(task_type=TaskType.ON_DEMAND) assert len(tasks) == 2 @pytest.mark.asyncio async def test_execute_task_success(self, task_manager, mock_db_session): """测试成功执行任务""" # 创建mock任务 mock_task = DetectionTask( task_id="task_test_001", task_name="测试任务", task_type=TaskType.ON_DEMAND, status=TaskStatus.PENDING, entity_type="streamer", period="2024-01", total_entities=1, processed_entities=0, parameters={}, created_at=datetime.now(), ) # 创建mock任务执行记录 mock_executions = [ TaskExecution( execution_id="exec_001", task_id="task_test_001", entity_id="ZB_TEST_001", entity_type="streamer", period="2024-01", rule_ids="rule_revenue_001", parameters={}, status="pending", created_at=datetime.now(), ) ] # 设置mock返回值 mock_db_session.execute.side_effect = [ AsyncMock(scalar_one_or_none=lambda: mock_task), # 获取任务 AsyncMock(scalars=lambda: AsyncMock(all=lambda: mock_executions)), # 获取执行记录 AsyncMock(scalars=lambda: AsyncMock(all=lambda: [])), # 获取规则 ] # Mock RuleEngine with patch('app.services.risk_detection.task_manager.task_manager.RuleEngine') as MockRuleEngine: mock_rule_engine = AsyncMock() mock_rule_engine.execute_detection = AsyncMock(return_value={ "results": [ { "entity_id": "ZB_TEST_001", "risk_level": RiskLevel.MEDIUM, "risk_score": 65.0, "description": "发现收入不匹配问题", "suggestion": "请检查收入申报是否准确", } ], "summary": { "total_detections": 1, "risk_distribution": {"MEDIUM": 1}, "avg_score": 65.0, } }) MockRuleEngine.return_value = mock_rule_engine # Mock算法注册 with patch('app.services.risk_detection.task_manager.task_manager.RevenueIntegrityAlgorithm'): # 执行任务 result = await task_manager.execute_task("task_test_001") # 验证结果 assert result["task_id"] == "task_test_001" assert result["status"] == "completed" assert result["summary"]["total_entities"] == 1 assert result["summary"]["total_detections"] == 1 @pytest.mark.asyncio async def test_execute_task_not_found(self, task_manager, mock_db_session): """测试执行不存在的任务""" # 设置mock返回值为None mock_db_session.execute.return_value.scalar_one_or_none.return_value = None # 验证抛出异常 with pytest.raises(ValueError, match="任务不存在"): await task_manager.execute_task("task_not_exist") @pytest.mark.asyncio async def test_execute_task_with_failure(self, task_manager, mock_db_session): """测试任务执行失败""" # 创建mock任务 mock_task = DetectionTask( task_id="task_test_001", task_name="测试任务", task_type=TaskType.ON_DEMAND, status=TaskStatus.PENDING, entity_type="streamer", period="2024-01", total_entities=1, processed_entities=0, parameters={}, created_at=datetime.now(), ) # 设置mock返回值 mock_db_session.execute.return_value.scalar_one_or_none.return_value = mock_task # Mock RuleEngine抛出异常 with patch('app.services.risk_detection.task_manager.task_manager.RuleEngine') as MockRuleEngine: mock_rule_engine = AsyncMock() mock_rule_engine.execute_detection = AsyncMock(side_effect=Exception("执行失败")) MockRuleEngine.return_value = mock_rule_engine # 验证抛出异常 with pytest.raises(Exception, match="执行失败"): await task_manager.execute_task("task_test_001") class TestDetectionScheduler: """任务调度器测试类""" @pytest.fixture def mock_db_session(self): """创建Mock数据库会话""" session = AsyncMock(spec=AsyncSession) session.execute = AsyncMock() return session @pytest.fixture def scheduler(self, mock_db_session): """创建调度器实例""" return DetectionScheduler(mock_db_session) @pytest.mark.asyncio async def test_schedule_periodic_detection(self, scheduler, mock_db_session): """测试调度定期检测""" # Mock TaskManager.create_task mock_task = DetectionTask( task_id="task_periodic_001", task_name="定期检测-2024-01", task_type=TaskType.PERIODIC, status=TaskStatus.PENDING, entity_type="streamer", period="2024-01", total_entities=2, processed_entities=0, parameters={"schedule_type": "periodic"}, created_at=datetime.now(), ) with patch.object(scheduler.task_manager, 'create_task', new_callable=AsyncMock) as mock_create: mock_create.return_value = mock_task # 执行测试 task = await scheduler.schedule_periodic_detection( entity_ids=["ZB_TEST_001", "ZB_TEST_002"], entity_type="streamer", period="2024-01", auto_execute=False ) # 验证结果 assert task.task_id == "task_periodic_001" assert task.task_type == TaskType.PERIODIC mock_create.assert_called_once() @pytest.mark.asyncio async def test_schedule_on_demand_detection(self, scheduler, mock_db_session): """测试调度按需检测""" # Mock TaskManager.create_task mock_task = DetectionTask( task_id="task_ondemand_001", task_name="按需检测-ZB_TEST_001-2024-01", task_type=TaskType.ON_DEMAND, status=TaskStatus.PENDING, entity_type="streamer", period="2024-01", total_entities=1, processed_entities=0, parameters={"schedule_type": "on_demand"}, created_at=datetime.now(), ) with patch.object(scheduler.task_manager, 'create_task', new_callable=AsyncMock) as mock_create: mock_create.return_value = mock_task # 执行测试 task = await scheduler.schedule_on_demand_detection( entity_id="ZB_TEST_001", entity_type="streamer", period="2024-01", auto_execute=False ) # 验证结果 assert task.task_id == "task_ondemand_001" assert task.task_type == TaskType.ON_DEMAND mock_create.assert_called_once() @pytest.mark.asyncio async def test_schedule_batch_detection(self, scheduler, mock_db_session): """测试调度批量检测""" # Mock TaskManager.create_task mock_task = DetectionTask( task_id="task_batch_001", task_name="批量检测-2024-01", task_type=TaskType.BATCH, status=TaskStatus.PENDING, entity_type="streamer", period="2024-01", total_entities=10, processed_entities=0, parameters={"schedule_type": "batch", "batch_size": 10}, created_at=datetime.now(), ) with patch.object(scheduler.task_manager, 'create_task', new_callable=AsyncMock) as mock_create: mock_create.return_value = mock_task # 执行测试 task = await scheduler.schedule_batch_detection( entity_ids=[f"ZB_TEST_{i:03d}" for i in range(1, 11)], entity_type="streamer", period="2024-01", auto_execute=False ) # 验证结果 assert task.task_id == "task_batch_001" assert task.task_type == TaskType.BATCH assert task.total_entities == 10 mock_create.assert_called_once() @pytest.mark.asyncio async def test_check_pending_tasks(self, scheduler, mock_db_session): """测试检查待执行任务""" # 创建mock待执行任务 mock_tasks = [ DetectionTask( task_id="task_001", task_name="待执行任务1", task_type=TaskType.PERIODIC, status=TaskStatus.PENDING, entity_type="streamer", period="2024-01", total_entities=1, processed_entities=0, parameters={}, created_at=datetime.now(), ), DetectionTask( task_id="task_002", task_name="待执行任务2", task_type=TaskType.BATCH, status=TaskStatus.PENDING, entity_type="mcn", period="2024-01", total_entities=5, processed_entities=0, parameters={}, created_at=datetime.now(), ), ] # 设置mock返回值 mock_db_session.execute.return_value.scalars.return_value.all.return_value = mock_tasks # 执行测试 tasks = await scheduler.check_pending_tasks() # 验证结果 assert len(tasks) == 2 assert all(task.status == TaskStatus.PENDING for task in tasks) @pytest.mark.asyncio async def test_retry_failed_tasks(self, scheduler, mock_db_session): """测试重试失败任务""" # 创建mock失败任务 mock_tasks = [ DetectionTask( task_id="task_failed_001", task_name="失败任务1", task_type=TaskType.ON_DEMAND, status=TaskStatus.FAILED, entity_type="streamer", period="2024-01", total_entities=1, processed_entities=0, retry_count=1, parameters={}, created_at=datetime.now(), ) ] # 设置mock返回值 mock_db_session.execute.return_value.scalars.return_value.all.return_value = mock_tasks # Mock任务执行 with patch.object(scheduler, '_execute_task_async', new_callable=AsyncMock): # 执行测试 retried_tasks = await scheduler.retry_failed_tasks(max_retries=3) # 验证结果 assert len(retried_tasks) == 1 assert retried_tasks[0].status == TaskStatus.PENDING assert retried_tasks[0].retry_count == 2 @pytest.mark.asyncio async def test_cleanup_old_tasks(self, scheduler, mock_db_session): """测试清理旧任务""" # 创建mock旧任务 old_date = datetime.now() - timedelta(days=35) mock_tasks = [ DetectionTask( task_id="task_old_001", task_name="旧任务1", task_type=TaskType.ON_DEMAND, status=TaskStatus.COMPLETED, entity_type="streamer", period="2023-01", total_entities=1, processed_entities=1, parameters={}, created_at=old_date, ), DetectionTask( task_id="task_old_002", task_name="旧任务2", task_type=TaskType.ON_DEMAND, status=TaskStatus.FAILED, entity_type="mcn", period="2023-01", total_entities=1, processed_entities=0, parameters={}, created_at=old_date, ), ] # 设置mock返回值 mock_db_session.execute.return_value.scalars.return_value.all.return_value = mock_tasks mock_db_session.delete = MagicMock() # 执行测试 count = await scheduler.cleanup_old_tasks(days=30) # 验证结果 assert count == 2 assert mock_db_session.delete.call_count == 2 mock_db_session.commit.assert_called_once() if __name__ == "__main__": pytest.main([__file__, "-v"])