# Listing metadata from the code viewer: 585 lines, 21 KiB, Python
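"""
Test cases for the task manager.

Scenarios covered:
1. Creating detection tasks (on-demand, periodic, batch)
2. Querying the task list
3. Executing detection tasks
4. Task status management
5. Error handling and retry mechanism
"""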
|
|
import pytest
|
|
from datetime import datetime, date, timedelta
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, and_
|
|
|
|
from app.services.risk_detection.task_manager import TaskManager
|
|
from app.services.risk_detection.task_manager.scheduler import DetectionScheduler
|
|
from app.models.risk_detection import (
|
|
DetectionTask,
|
|
DetectionRule,
|
|
TaskType,
|
|
TaskStatus,
|
|
RiskLevel,
|
|
)
|
|
from app.models.risk_detection import TaskExecution
|
|
|
|
|
|
class TestTaskManager:
    """Tests for ``TaskManager`` driven through a mocked ``AsyncSession``.

    No database is touched: ``session.execute`` is an ``AsyncMock`` whose
    awaited result is a plain ``MagicMock``, so each test can configure the
    synchronous result accessors (``scalar_one_or_none()``,
    ``scalars().all()``) it needs.
    """

    # Module in which TaskManager looks up the collaborators patched below.
    _TM_MODULE = "app.services.risk_detection.task_manager.task_manager"

    @staticmethod
    def _make_task(**overrides):
        """Build a ``DetectionTask`` with sensible defaults.

        Keyword arguments replace the matching default field; fields that
        are not defaults (e.g. ``retry_count``) are passed through as-is.
        """
        fields = {
            "task_id": "task_test_001",
            "task_name": "测试任务",
            "task_type": TaskType.ON_DEMAND,
            "status": TaskStatus.PENDING,
            "entity_type": "streamer",
            "period": "2024-01",
            "total_entities": 1,
            "processed_entities": 0,
            "parameters": {},
            "created_at": datetime.now(),
        }
        fields.update(overrides)
        return DetectionTask(**fields)

    @staticmethod
    def _scalar_result(value):
        """Return a fake query result whose ``scalar_one_or_none()`` is *value*."""
        result = MagicMock()
        result.scalar_one_or_none.return_value = value
        return result

    @staticmethod
    def _rows_result(rows):
        """Return a fake query result whose ``scalars().all()`` is *rows*."""
        result = MagicMock()
        result.scalars.return_value.all.return_value = rows
        return result

    @pytest.fixture
    def mock_db_session(self):
        """Provide a mocked database session."""
        session = AsyncMock(spec=AsyncSession)
        # The awaited result must be a MagicMock: children of an AsyncMock
        # are AsyncMocks too, so result.scalar_one_or_none() / .scalars()
        # would otherwise hand back coroutines instead of configured values.
        session.execute = AsyncMock(return_value=MagicMock())
        session.flush = AsyncMock()
        session.refresh = AsyncMock()
        session.add = MagicMock()
        session.commit = AsyncMock()
        return session

    @pytest.fixture
    def task_manager(self, mock_db_session):
        """Provide a ``TaskManager`` bound to the mocked session."""
        return TaskManager(mock_db_session)

    @pytest.fixture
    def mock_detection_rules(self):
        """Provide two enabled detection rules used as query results."""
        revenue_rule = DetectionRule(
            rule_id="rule_revenue_001",
            rule_name="收入完整性检测",
            algorithm_code="REVENUE_INTEGRITY_CHECK",
            description="检测平台充值与申报收入的匹配度",
            is_enabled=True,
            parameters={"threshold": 0.1},
            created_at=datetime.now(),
        )
        private_account_rule = DetectionRule(
            rule_id="rule_private_001",
            rule_name="私户收款检测",
            algorithm_code="PRIVATE_ACCOUNT_DETECTION",
            description="识别使用私人账户收款的风险",
            is_enabled=True,
            parameters={"threshold_amount": 10000},
            created_at=datetime.now(),
        )
        return [revenue_rule, private_account_rule]

    @pytest.mark.asyncio
    async def test_create_task_on_demand(self, task_manager, mock_db_session, mock_detection_rules):
        """Creating an on-demand detection task."""
        # Every query returns "no single row" and the rule list.
        query_result = mock_db_session.execute.return_value
        query_result.scalar_one_or_none.return_value = None
        query_result.scalars.return_value.all.return_value = mock_detection_rules

        task = await task_manager.create_task(
            task_name="测试按需检测任务",
            task_type=TaskType.ON_DEMAND,
            entity_ids=["ZB_TEST_001"],
            entity_type="streamer",
            period="2024-01",
            rule_ids=["rule_revenue_001"],
            parameters={"description": "测试任务"},
        )

        assert task.task_id is not None
        assert task.task_name == "测试按需检测任务"
        assert task.task_type == TaskType.ON_DEMAND
        assert task.status == TaskStatus.PENDING
        assert task.entity_type == "streamer"
        assert task.period == "2024-01"
        assert task.total_entities == 1
        assert task.parameters["description"] == "测试任务"

        # Something (task and/or execution records) was added to the session.
        mock_db_session.add.assert_called()

    @pytest.mark.asyncio
    async def test_create_task_periodic(self, task_manager, mock_db_session, mock_detection_rules):
        """Creating a periodic detection task."""
        query_result = mock_db_session.execute.return_value
        query_result.scalar_one_or_none.return_value = None
        query_result.scalars.return_value.all.return_value = mock_detection_rules

        task = await task_manager.create_task(
            task_name="测试定期检测任务",
            task_type=TaskType.PERIODIC,
            entity_ids=["ZB_TEST_001", "ZB_TEST_002"],
            entity_type="streamer",
            period="2024-01",
            parameters={"schedule_type": "periodic"},
        )

        assert task.task_id is not None
        assert task.task_name == "测试定期检测任务"
        assert task.task_type == TaskType.PERIODIC
        assert task.total_entities == 2

    @pytest.mark.asyncio
    async def test_create_task_batch(self, task_manager, mock_db_session, mock_detection_rules):
        """Creating a batch detection task."""
        query_result = mock_db_session.execute.return_value
        query_result.scalar_one_or_none.return_value = None
        query_result.scalars.return_value.all.return_value = mock_detection_rules

        task = await task_manager.create_task(
            task_name="测试批量检测任务",
            task_type=TaskType.BATCH,
            entity_ids=["ZB_TEST_001", "ZB_TEST_002", "ZB_TEST_003"],
            entity_type="mcn",
            period="2024-01",
            rule_ids=["rule_private_001"],
            parameters={"batch_size": 3},
        )

        assert task.task_id is not None
        assert task.task_name == "测试批量检测任务"
        assert task.task_type == TaskType.BATCH
        assert task.total_entities == 3

    @pytest.mark.asyncio
    async def test_get_task(self, task_manager, mock_db_session):
        """Fetching an existing task by id."""
        stored_task = self._make_task()
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = stored_task

        task = await task_manager.get_task("task_test_001")

        assert task is not None
        assert task.task_id == "task_test_001"
        assert task.task_name == "测试任务"
        assert task.status == TaskStatus.PENDING

    @pytest.mark.asyncio
    async def test_get_task_not_found(self, task_manager, mock_db_session):
        """Fetching a task that does not exist yields ``None``."""
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = None

        task = await task_manager.get_task("task_not_exist")

        assert task is None

    @pytest.mark.asyncio
    async def test_list_tasks(self, task_manager, mock_db_session):
        """Listing tasks with and without filters."""
        stored_tasks = [
            self._make_task(task_id="task_001", task_name="任务1"),
            self._make_task(
                task_id="task_002",
                task_name="任务2",
                task_type=TaskType.PERIODIC,
                status=TaskStatus.COMPLETED,
                entity_type="mcn",
                total_entities=5,
                processed_entities=5,
            ),
        ]
        mock_db_session.execute.return_value.scalars.return_value.all.return_value = stored_tasks

        # No filter.
        tasks = await task_manager.list_tasks()
        assert len(tasks) == 2

        # Filter by status. The mocked session returns the same rows whatever
        # the query is, so the count stays 2; real filtering happens in the DB.
        tasks = await task_manager.list_tasks(status=TaskStatus.PENDING)
        assert len(tasks) == 2

        # Filter by task type (same caveat as above).
        tasks = await task_manager.list_tasks(task_type=TaskType.ON_DEMAND)
        assert len(tasks) == 2

    @pytest.mark.asyncio
    async def test_execute_task_success(self, task_manager, mock_db_session):
        """Executing a task successfully."""
        stored_task = self._make_task()
        executions = [
            TaskExecution(
                execution_id="exec_001",
                task_id="task_test_001",
                entity_id="ZB_TEST_001",
                entity_type="streamer",
                period="2024-01",
                rule_ids="rule_revenue_001",
                parameters={},
                status="pending",
                created_at=datetime.now(),
            )
        ]

        # One result per expected query, in call order.
        mock_db_session.execute.side_effect = [
            self._scalar_result(stored_task),  # load the task
            self._rows_result(executions),     # load its execution records
            self._rows_result([]),             # load the rules
        ]

        detection_output = {
            "results": [
                {
                    "entity_id": "ZB_TEST_001",
                    "risk_level": RiskLevel.MEDIUM,
                    "risk_score": 65.0,
                    "description": "发现收入不匹配问题",
                    "suggestion": "请检查收入申报是否准确",
                }
            ],
            "summary": {
                "total_detections": 1,
                "risk_distribution": {"MEDIUM": 1},
                "avg_score": 65.0,
            },
        }
        fake_engine = AsyncMock()
        fake_engine.execute_detection = AsyncMock(return_value=detection_output)

        # Replace the rule engine and the algorithm class that gets registered.
        engine_target = f"{self._TM_MODULE}.RuleEngine"
        algorithm_target = f"{self._TM_MODULE}.RevenueIntegrityAlgorithm"
        with patch(engine_target) as engine_cls, patch(algorithm_target):
            engine_cls.return_value = fake_engine
            result = await task_manager.execute_task("task_test_001")

        assert result["task_id"] == "task_test_001"
        assert result["status"] == "completed"
        assert result["summary"]["total_entities"] == 1
        assert result["summary"]["total_detections"] == 1

    @pytest.mark.asyncio
    async def test_execute_task_not_found(self, task_manager, mock_db_session):
        """Executing a task that does not exist raises ``ValueError``."""
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = None

        with pytest.raises(ValueError, match="任务不存在"):
            await task_manager.execute_task("task_not_exist")

    @pytest.mark.asyncio
    async def test_execute_task_with_failure(self, task_manager, mock_db_session):
        """A failure inside the rule engine propagates to the caller."""
        stored_task = self._make_task()
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = stored_task

        failing_engine = AsyncMock()
        failing_engine.execute_detection = AsyncMock(side_effect=Exception("执行失败"))

        with patch(f"{self._TM_MODULE}.RuleEngine") as engine_cls:
            engine_cls.return_value = failing_engine

            with pytest.raises(Exception, match="执行失败"):
                await task_manager.execute_task("task_test_001")
|
|
|
|
|
|
class TestDetectionScheduler:
    """Tests for ``DetectionScheduler`` driven through a mocked ``AsyncSession``.

    Task creation is stubbed out via ``scheduler.task_manager.create_task``
    and database reads return canned ``DetectionTask`` objects.
    """

    @staticmethod
    def _make_task(task_id, task_name, task_type, **overrides):
        """Build a ``DetectionTask`` with sensible defaults.

        Keyword arguments replace the matching default field; fields that
        are not defaults (e.g. ``retry_count``) are passed through as-is.
        """
        fields = {
            "task_id": task_id,
            "task_name": task_name,
            "task_type": task_type,
            "status": TaskStatus.PENDING,
            "entity_type": "streamer",
            "period": "2024-01",
            "total_entities": 1,
            "processed_entities": 0,
            "parameters": {},
            "created_at": datetime.now(),
        }
        fields.update(overrides)
        return DetectionTask(**fields)

    @pytest.fixture
    def mock_db_session(self):
        """Provide a mocked database session."""
        session = AsyncMock(spec=AsyncSession)
        # The awaited result must be a MagicMock: children of an AsyncMock
        # are AsyncMocks too, so result.scalars() would otherwise return a
        # coroutine instead of the configured value.
        session.execute = AsyncMock(return_value=MagicMock())
        return session

    @pytest.fixture
    def scheduler(self, mock_db_session):
        """Provide a ``DetectionScheduler`` bound to the mocked session."""
        return DetectionScheduler(mock_db_session)

    @pytest.mark.asyncio
    async def test_schedule_periodic_detection(self, scheduler, mock_db_session):
        """Scheduling a periodic detection."""
        created_task = self._make_task(
            "task_periodic_001",
            "定期检测-2024-01",
            TaskType.PERIODIC,
            total_entities=2,
            parameters={"schedule_type": "periodic"},
        )

        with patch.object(scheduler.task_manager, "create_task", new_callable=AsyncMock) as create_stub:
            create_stub.return_value = created_task

            task = await scheduler.schedule_periodic_detection(
                entity_ids=["ZB_TEST_001", "ZB_TEST_002"],
                entity_type="streamer",
                period="2024-01",
                auto_execute=False,
            )

            assert task.task_id == "task_periodic_001"
            assert task.task_type == TaskType.PERIODIC
            create_stub.assert_called_once()

    @pytest.mark.asyncio
    async def test_schedule_on_demand_detection(self, scheduler, mock_db_session):
        """Scheduling an on-demand detection."""
        created_task = self._make_task(
            "task_ondemand_001",
            "按需检测-ZB_TEST_001-2024-01",
            TaskType.ON_DEMAND,
            parameters={"schedule_type": "on_demand"},
        )

        with patch.object(scheduler.task_manager, "create_task", new_callable=AsyncMock) as create_stub:
            create_stub.return_value = created_task

            task = await scheduler.schedule_on_demand_detection(
                entity_id="ZB_TEST_001",
                entity_type="streamer",
                period="2024-01",
                auto_execute=False,
            )

            assert task.task_id == "task_ondemand_001"
            assert task.task_type == TaskType.ON_DEMAND
            create_stub.assert_called_once()

    @pytest.mark.asyncio
    async def test_schedule_batch_detection(self, scheduler, mock_db_session):
        """Scheduling a batch detection."""
        created_task = self._make_task(
            "task_batch_001",
            "批量检测-2024-01",
            TaskType.BATCH,
            total_entities=10,
            parameters={"schedule_type": "batch", "batch_size": 10},
        )

        with patch.object(scheduler.task_manager, "create_task", new_callable=AsyncMock) as create_stub:
            create_stub.return_value = created_task

            task = await scheduler.schedule_batch_detection(
                entity_ids=[f"ZB_TEST_{i:03d}" for i in range(1, 11)],
                entity_type="streamer",
                period="2024-01",
                auto_execute=False,
            )

            assert task.task_id == "task_batch_001"
            assert task.task_type == TaskType.BATCH
            assert task.total_entities == 10
            create_stub.assert_called_once()

    @pytest.mark.asyncio
    async def test_check_pending_tasks(self, scheduler, mock_db_session):
        """Checking for tasks that are waiting to run."""
        pending_tasks = [
            self._make_task("task_001", "待执行任务1", TaskType.PERIODIC),
            self._make_task(
                "task_002",
                "待执行任务2",
                TaskType.BATCH,
                entity_type="mcn",
                total_entities=5,
            ),
        ]
        mock_db_session.execute.return_value.scalars.return_value.all.return_value = pending_tasks

        tasks = await scheduler.check_pending_tasks()

        assert len(tasks) == 2
        assert all(task.status == TaskStatus.PENDING for task in tasks)

    @pytest.mark.asyncio
    async def test_retry_failed_tasks(self, scheduler, mock_db_session):
        """Retrying failed tasks resets them to pending and bumps the counter."""
        failed_tasks = [
            self._make_task(
                "task_failed_001",
                "失败任务1",
                TaskType.ON_DEMAND,
                status=TaskStatus.FAILED,
                retry_count=1,
            )
        ]
        mock_db_session.execute.return_value.scalars.return_value.all.return_value = failed_tasks

        # Keep the scheduler from actually running the task.
        with patch.object(scheduler, "_execute_task_async", new_callable=AsyncMock):
            retried_tasks = await scheduler.retry_failed_tasks(max_retries=3)

            assert len(retried_tasks) == 1
            assert retried_tasks[0].status == TaskStatus.PENDING
            assert retried_tasks[0].retry_count == 2

    @pytest.mark.asyncio
    async def test_cleanup_old_tasks(self, scheduler, mock_db_session):
        """Cleaning up tasks older than the retention window."""
        # 35 days old, i.e. past the 30-day cutoff requested below.
        old_date = datetime.now() - timedelta(days=35)
        old_tasks = [
            self._make_task(
                "task_old_001",
                "旧任务1",
                TaskType.ON_DEMAND,
                status=TaskStatus.COMPLETED,
                period="2023-01",
                processed_entities=1,
                created_at=old_date,
            ),
            self._make_task(
                "task_old_002",
                "旧任务2",
                TaskType.ON_DEMAND,
                status=TaskStatus.FAILED,
                entity_type="mcn",
                period="2023-01",
                created_at=old_date,
            ),
        ]
        mock_db_session.execute.return_value.scalars.return_value.all.return_value = old_tasks
        # AsyncSession.delete is a coroutine; an AsyncMock can be awaited and
        # still counts every call, whereas a MagicMock result cannot be awaited.
        mock_db_session.delete = AsyncMock()

        count = await scheduler.cleanup_old_tasks(days=30)

        assert count == 2
        assert mock_db_session.delete.call_count == 2
        mock_db_session.commit.assert_called_once()
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this file directly hands it to pytest in verbose mode.
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)
|