# File: deep-risk/backend/app/tests/test_task_manager.py
# Last modified: 2025-12-14 20:08:27 +08:00
#
# Listed size: 585 lines, 21 KiB
# Language: Python

"""
任务管理器测试用例
覆盖场景:
1. 创建检测任务(按需、定期、批量)
2. 查询任务列表
3. 执行检测任务
4. 任务状态管理
5. 错误处理和重试机制
"""
import pytest
from datetime import datetime, date, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_
from app.services.risk_detection.task_manager import TaskManager
from app.services.risk_detection.task_manager.scheduler import DetectionScheduler
from app.models.risk_detection import (
DetectionTask,
DetectionRule,
TaskType,
TaskStatus,
RiskLevel,
)
from app.models.risk_detection import TaskExecution
class TestTaskManager:
    """Tests for ``TaskManager`` running against a mocked ``AsyncSession``."""

    @staticmethod
    def _make_result(scalar=None, rows=None):
        """Build a stand-in for the ``Result`` produced by ``await session.execute(...)``.

        The ``Result`` API is synchronous, so this is a ``MagicMock``. An
        ``AsyncMock`` here would make ``scalar_one_or_none()`` and
        ``scalars().all()`` return coroutines instead of the configured data.

        Args:
            scalar: Value returned by ``scalar_one_or_none()``.
            rows: Items returned by ``scalars().all()``; ``None`` means empty.

        Returns:
            A ``MagicMock`` configured as a query result.
        """
        result = MagicMock()
        result.scalar_one_or_none.return_value = scalar
        result.scalars.return_value.all.return_value = [] if rows is None else list(rows)
        return result

    @pytest.fixture
    def mock_db_session(self):
        """Create a mocked async database session.

        ``execute`` is awaited, but what it resolves to is used synchronously.
        The child mocks of a bare ``AsyncMock()`` (including its
        ``return_value``) are ``AsyncMock``s. Pinning ``return_value`` to a
        ``MagicMock`` lets tests configure
        ``execute.return_value.scalar_one_or_none.return_value`` and
        ``execute.return_value.scalars.return_value.all.return_value``
        and get plain values back.
        """
        session = AsyncMock(spec=AsyncSession)
        session.execute = AsyncMock(return_value=MagicMock())
        session.flush = AsyncMock()
        session.refresh = AsyncMock()
        session.add = MagicMock()  # Session.add is synchronous
        session.commit = AsyncMock()
        return session

    @pytest.fixture
    def task_manager(self, mock_db_session):
        """Create a TaskManager bound to the mocked session."""
        return TaskManager(mock_db_session)

    @pytest.fixture
    def mock_detection_rules(self):
        """Two enabled detection rules used as the mocked rule query result."""
        return [
            DetectionRule(
                rule_id="rule_revenue_001",
                rule_name="收入完整性检测",
                algorithm_code="REVENUE_INTEGRITY_CHECK",
                description="检测平台充值与申报收入的匹配度",
                is_enabled=True,
                parameters={"threshold": 0.1},
                created_at=datetime.now(),
            ),
            DetectionRule(
                rule_id="rule_private_001",
                rule_name="私户收款检测",
                algorithm_code="PRIVATE_ACCOUNT_DETECTION",
                description="识别使用私人账户收款的风险",
                is_enabled=True,
                parameters={"threshold_amount": 10000},
                created_at=datetime.now(),
            ),
        ]

    @pytest.mark.asyncio
    async def test_create_task_on_demand(self, task_manager, mock_db_session, mock_detection_rules):
        """Creating an on-demand task returns a PENDING task with the given fields."""
        # Configure the mocked query results
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = None
        mock_db_session.execute.return_value.scalars.return_value.all.return_value = mock_detection_rules

        task = await task_manager.create_task(
            task_name="测试按需检测任务",
            task_type=TaskType.ON_DEMAND,
            entity_ids=["ZB_TEST_001"],
            entity_type="streamer",
            period="2024-01",
            rule_ids=["rule_revenue_001"],
            parameters={"description": "测试任务"},
        )

        assert task.task_id is not None
        assert task.task_name == "测试按需检测任务"
        assert task.task_type == TaskType.ON_DEMAND
        assert task.status == TaskStatus.PENDING
        assert task.entity_type == "streamer"
        assert task.period == "2024-01"
        assert task.total_entities == 1
        assert task.parameters["description"] == "测试任务"
        # Records (task / execution records) should have been added to the session
        mock_db_session.add.assert_called()

    @pytest.mark.asyncio
    async def test_create_task_periodic(self, task_manager, mock_db_session, mock_detection_rules):
        """Creating a periodic task counts every supplied entity."""
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = None
        mock_db_session.execute.return_value.scalars.return_value.all.return_value = mock_detection_rules

        task = await task_manager.create_task(
            task_name="测试定期检测任务",
            task_type=TaskType.PERIODIC,
            entity_ids=["ZB_TEST_001", "ZB_TEST_002"],
            entity_type="streamer",
            period="2024-01",
            parameters={"schedule_type": "periodic"},
        )

        assert task.task_id is not None
        assert task.task_name == "测试定期检测任务"
        assert task.task_type == TaskType.PERIODIC
        assert task.total_entities == 2

    @pytest.mark.asyncio
    async def test_create_task_batch(self, task_manager, mock_db_session, mock_detection_rules):
        """Creating a batch task counts every supplied entity."""
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = None
        mock_db_session.execute.return_value.scalars.return_value.all.return_value = mock_detection_rules

        task = await task_manager.create_task(
            task_name="测试批量检测任务",
            task_type=TaskType.BATCH,
            entity_ids=["ZB_TEST_001", "ZB_TEST_002", "ZB_TEST_003"],
            entity_type="mcn",
            period="2024-01",
            rule_ids=["rule_private_001"],
            parameters={"batch_size": 3},
        )

        assert task.task_id is not None
        assert task.task_name == "测试批量检测任务"
        assert task.task_type == TaskType.BATCH
        assert task.total_entities == 3

    @pytest.mark.asyncio
    async def test_get_task(self, task_manager, mock_db_session):
        """get_task returns the task the query yields."""
        mock_task = DetectionTask(
            task_id="task_test_001",
            task_name="测试任务",
            task_type=TaskType.ON_DEMAND,
            status=TaskStatus.PENDING,
            entity_type="streamer",
            period="2024-01",
            total_entities=1,
            processed_entities=0,
            parameters={},
            created_at=datetime.now(),
        )
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = mock_task

        task = await task_manager.get_task("task_test_001")

        assert task is not None
        assert task.task_id == "task_test_001"
        assert task.task_name == "测试任务"
        assert task.status == TaskStatus.PENDING

    @pytest.mark.asyncio
    async def test_get_task_not_found(self, task_manager, mock_db_session):
        """get_task returns None when the query finds nothing."""
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = None

        task = await task_manager.get_task("task_not_exist")

        assert task is None

    @pytest.mark.asyncio
    async def test_list_tasks(self, task_manager, mock_db_session):
        """list_tasks returns the rows produced by the query, with or without filters."""
        mock_tasks = [
            DetectionTask(
                task_id="task_001",
                task_name="任务1",
                task_type=TaskType.ON_DEMAND,
                status=TaskStatus.PENDING,
                entity_type="streamer",
                period="2024-01",
                total_entities=1,
                processed_entities=0,
                parameters={},
                created_at=datetime.now(),
            ),
            DetectionTask(
                task_id="task_002",
                task_name="任务2",
                task_type=TaskType.PERIODIC,
                status=TaskStatus.COMPLETED,
                entity_type="mcn",
                period="2024-01",
                total_entities=5,
                processed_entities=5,
                parameters={},
                created_at=datetime.now(),
            ),
        ]
        mock_db_session.execute.return_value.scalars.return_value.all.return_value = mock_tasks

        # No filters
        tasks = await task_manager.list_tasks()
        assert len(tasks) == 2

        # Filtering is presumably done in the SQL query, which this mock
        # bypasses, so both rows still come back; these calls only check that
        # the filter arguments are accepted.
        tasks = await task_manager.list_tasks(status=TaskStatus.PENDING)
        assert len(tasks) == 2

        tasks = await task_manager.list_tasks(task_type=TaskType.ON_DEMAND)
        assert len(tasks) == 2

    @pytest.mark.asyncio
    async def test_execute_task_success(self, task_manager, mock_db_session):
        """A successful run reports 'completed' with the rule engine's summary."""
        mock_task = DetectionTask(
            task_id="task_test_001",
            task_name="测试任务",
            task_type=TaskType.ON_DEMAND,
            status=TaskStatus.PENDING,
            entity_type="streamer",
            period="2024-01",
            total_entities=1,
            processed_entities=0,
            parameters={},
            created_at=datetime.now(),
        )
        mock_executions = [
            TaskExecution(
                execution_id="exec_001",
                task_id="task_test_001",
                entity_id="ZB_TEST_001",
                entity_type="streamer",
                period="2024-01",
                rule_ids="rule_revenue_001",
                parameters={},
                status="pending",
                created_at=datetime.now(),
            )
        ]
        # One scripted Result per awaited execute() call. The order is assumed
        # to match execute_task's query order; confirm if that code changes.
        mock_db_session.execute.side_effect = [
            self._make_result(scalar=mock_task),  # fetch the task
            self._make_result(rows=mock_executions),  # fetch its execution records
            self._make_result(rows=[]),  # fetch the rules
        ]

        with patch('app.services.risk_detection.task_manager.task_manager.RuleEngine') as MockRuleEngine:
            mock_rule_engine = AsyncMock()
            mock_rule_engine.execute_detection = AsyncMock(return_value={
                "results": [
                    {
                        "entity_id": "ZB_TEST_001",
                        "risk_level": RiskLevel.MEDIUM,
                        "risk_score": 65.0,
                        "description": "发现收入不匹配问题",
                        "suggestion": "请检查收入申报是否准确",
                    }
                ],
                "summary": {
                    "total_detections": 1,
                    "risk_distribution": {"MEDIUM": 1},
                    "avg_score": 65.0,
                },
            })
            MockRuleEngine.return_value = mock_rule_engine

            # Patch out algorithm registration
            with patch('app.services.risk_detection.task_manager.task_manager.RevenueIntegrityAlgorithm'):
                result = await task_manager.execute_task("task_test_001")

        assert result["task_id"] == "task_test_001"
        assert result["status"] == "completed"
        assert result["summary"]["total_entities"] == 1
        assert result["summary"]["total_detections"] == 1

    @pytest.mark.asyncio
    async def test_execute_task_not_found(self, task_manager, mock_db_session):
        """Executing an unknown task raises ValueError."""
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = None

        with pytest.raises(ValueError, match="任务不存在"):
            await task_manager.execute_task("task_not_exist")

    @pytest.mark.asyncio
    async def test_execute_task_with_failure(self, task_manager, mock_db_session):
        """An exception from the rule engine propagates out of execute_task."""
        mock_task = DetectionTask(
            task_id="task_test_001",
            task_name="测试任务",
            task_type=TaskType.ON_DEMAND,
            status=TaskStatus.PENDING,
            entity_type="streamer",
            period="2024-01",
            total_entities=1,
            processed_entities=0,
            parameters={},
            created_at=datetime.now(),
        )
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = mock_task

        # Make the rule engine fail
        with patch('app.services.risk_detection.task_manager.task_manager.RuleEngine') as MockRuleEngine:
            mock_rule_engine = AsyncMock()
            mock_rule_engine.execute_detection = AsyncMock(side_effect=Exception("执行失败"))
            MockRuleEngine.return_value = mock_rule_engine

            with pytest.raises(Exception, match="执行失败"):
                await task_manager.execute_task("task_test_001")
class TestDetectionScheduler:
    """Tests for ``DetectionScheduler`` running against a mocked ``AsyncSession``."""

    @pytest.fixture
    def mock_db_session(self):
        """Create a mocked async database session.

        ``execute`` is awaited, but what it resolves to is used synchronously.
        The child mocks of a bare ``AsyncMock()`` (including its
        ``return_value``) are ``AsyncMock``s. Pinning ``return_value`` to a
        ``MagicMock`` lets tests configure
        ``execute.return_value.scalars.return_value.all.return_value``
        and get a plain list back. Other session methods (e.g. ``commit``)
        are auto-created from the ``AsyncSession`` spec.
        """
        session = AsyncMock(spec=AsyncSession)
        session.execute = AsyncMock(return_value=MagicMock())
        return session

    @pytest.fixture
    def scheduler(self, mock_db_session):
        """Create a DetectionScheduler bound to the mocked session."""
        return DetectionScheduler(mock_db_session)

    @pytest.mark.asyncio
    async def test_schedule_periodic_detection(self, scheduler, mock_db_session):
        """Periodic scheduling delegates to TaskManager.create_task exactly once."""
        mock_task = DetectionTask(
            task_id="task_periodic_001",
            task_name="定期检测-2024-01",
            task_type=TaskType.PERIODIC,
            status=TaskStatus.PENDING,
            entity_type="streamer",
            period="2024-01",
            total_entities=2,
            processed_entities=0,
            parameters={"schedule_type": "periodic"},
            created_at=datetime.now(),
        )
        with patch.object(scheduler.task_manager, 'create_task', new_callable=AsyncMock) as mock_create:
            mock_create.return_value = mock_task

            task = await scheduler.schedule_periodic_detection(
                entity_ids=["ZB_TEST_001", "ZB_TEST_002"],
                entity_type="streamer",
                period="2024-01",
                auto_execute=False,
            )

            assert task.task_id == "task_periodic_001"
            assert task.task_type == TaskType.PERIODIC
            mock_create.assert_called_once()

    @pytest.mark.asyncio
    async def test_schedule_on_demand_detection(self, scheduler, mock_db_session):
        """On-demand scheduling delegates to TaskManager.create_task exactly once."""
        mock_task = DetectionTask(
            task_id="task_ondemand_001",
            task_name="按需检测-ZB_TEST_001-2024-01",
            task_type=TaskType.ON_DEMAND,
            status=TaskStatus.PENDING,
            entity_type="streamer",
            period="2024-01",
            total_entities=1,
            processed_entities=0,
            parameters={"schedule_type": "on_demand"},
            created_at=datetime.now(),
        )
        with patch.object(scheduler.task_manager, 'create_task', new_callable=AsyncMock) as mock_create:
            mock_create.return_value = mock_task

            task = await scheduler.schedule_on_demand_detection(
                entity_id="ZB_TEST_001",
                entity_type="streamer",
                period="2024-01",
                auto_execute=False,
            )

            assert task.task_id == "task_ondemand_001"
            assert task.task_type == TaskType.ON_DEMAND
            mock_create.assert_called_once()

    @pytest.mark.asyncio
    async def test_schedule_batch_detection(self, scheduler, mock_db_session):
        """Batch scheduling delegates to TaskManager.create_task exactly once."""
        mock_task = DetectionTask(
            task_id="task_batch_001",
            task_name="批量检测-2024-01",
            task_type=TaskType.BATCH,
            status=TaskStatus.PENDING,
            entity_type="streamer",
            period="2024-01",
            total_entities=10,
            processed_entities=0,
            parameters={"schedule_type": "batch", "batch_size": 10},
            created_at=datetime.now(),
        )
        with patch.object(scheduler.task_manager, 'create_task', new_callable=AsyncMock) as mock_create:
            mock_create.return_value = mock_task

            task = await scheduler.schedule_batch_detection(
                entity_ids=[f"ZB_TEST_{i:03d}" for i in range(1, 11)],
                entity_type="streamer",
                period="2024-01",
                auto_execute=False,
            )

            assert task.task_id == "task_batch_001"
            assert task.task_type == TaskType.BATCH
            assert task.total_entities == 10
            mock_create.assert_called_once()

    @pytest.mark.asyncio
    async def test_check_pending_tasks(self, scheduler, mock_db_session):
        """check_pending_tasks returns the pending tasks the query yields."""
        mock_tasks = [
            DetectionTask(
                task_id="task_001",
                task_name="待执行任务1",
                task_type=TaskType.PERIODIC,
                status=TaskStatus.PENDING,
                entity_type="streamer",
                period="2024-01",
                total_entities=1,
                processed_entities=0,
                parameters={},
                created_at=datetime.now(),
            ),
            DetectionTask(
                task_id="task_002",
                task_name="待执行任务2",
                task_type=TaskType.BATCH,
                status=TaskStatus.PENDING,
                entity_type="mcn",
                period="2024-01",
                total_entities=5,
                processed_entities=0,
                parameters={},
                created_at=datetime.now(),
            ),
        ]
        mock_db_session.execute.return_value.scalars.return_value.all.return_value = mock_tasks

        tasks = await scheduler.check_pending_tasks()

        assert len(tasks) == 2
        assert all(task.status == TaskStatus.PENDING for task in tasks)

    @pytest.mark.asyncio
    async def test_retry_failed_tasks(self, scheduler, mock_db_session):
        """A failed task under max_retries is reset to PENDING with retry_count bumped."""
        mock_tasks = [
            DetectionTask(
                task_id="task_failed_001",
                task_name="失败任务1",
                task_type=TaskType.ON_DEMAND,
                status=TaskStatus.FAILED,
                entity_type="streamer",
                period="2024-01",
                total_entities=1,
                processed_entities=0,
                retry_count=1,
                parameters={},
                created_at=datetime.now(),
            )
        ]
        mock_db_session.execute.return_value.scalars.return_value.all.return_value = mock_tasks

        # Stub out the actual task execution
        with patch.object(scheduler, '_execute_task_async', new_callable=AsyncMock):
            retried_tasks = await scheduler.retry_failed_tasks(max_retries=3)

        assert len(retried_tasks) == 1
        assert retried_tasks[0].status == TaskStatus.PENDING
        assert retried_tasks[0].retry_count == 2

    @pytest.mark.asyncio
    async def test_cleanup_old_tasks(self, scheduler, mock_db_session):
        """Old tasks returned by the query are deleted and the session committed once."""
        old_date = datetime.now() - timedelta(days=35)
        mock_tasks = [
            DetectionTask(
                task_id="task_old_001",
                task_name="旧任务1",
                task_type=TaskType.ON_DEMAND,
                status=TaskStatus.COMPLETED,
                entity_type="streamer",
                period="2023-01",
                total_entities=1,
                processed_entities=1,
                parameters={},
                created_at=old_date,
            ),
            DetectionTask(
                task_id="task_old_002",
                task_name="旧任务2",
                task_type=TaskType.ON_DEMAND,
                status=TaskStatus.FAILED,
                entity_type="mcn",
                period="2023-01",
                total_entities=1,
                processed_entities=0,
                parameters={},
                created_at=old_date,
            ),
        ]
        mock_db_session.execute.return_value.scalars.return_value.all.return_value = mock_tasks
        # AsyncSession.delete is a coroutine method, so awaiting a plain
        # MagicMock would raise TypeError. An AsyncMock's call_count counts
        # every call, awaited or not.
        mock_db_session.delete = AsyncMock()

        count = await scheduler.cleanup_old_tasks(days=30)

        assert count == 2
        assert mock_db_session.delete.call_count == 2
        mock_db_session.commit.assert_called_once()
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file directly reports
    # failures to the shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))