386 lines
13 KiB
Python
386 lines
13 KiB
Python
"""
风险检测任务管理器 (risk-detection task manager)

负责创建、管理和执行检测任务 — creates, manages and executes detection tasks.
"""
|
||
import uuid
|
||
from typing import List, Dict, Any, Optional
|
||
from datetime import datetime
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select, and_
|
||
from loguru import logger
|
||
|
||
from app.models.risk_detection import (
|
||
DetectionTask,
|
||
DetectionRule,
|
||
TaskType,
|
||
TaskStatus,
|
||
)
|
||
from app.services.risk_detection.engine.rule_engine import RuleEngine
|
||
|
||
|
||
class TaskManager:
    """Risk-detection task manager.

    Creates detection tasks (one ``DetectionTask`` plus one ``TaskExecution``
    per entity), runs them through a ``RuleEngine`` and records progress,
    per-entity status and an aggregated summary. Every write goes through the
    injected ``AsyncSession`` and is only ``flush``-ed here; committing or
    rolling back is left to the caller.
    """

    def __init__(self, db_session: AsyncSession):
        # Session used for every query and flush issued by this manager.
        self.db_session = db_session

    async def create_task(
        self,
        task_name: str,
        task_type: TaskType,
        entity_ids: List[str],
        entity_type: str,
        period: str,
        rule_ids: Optional[List[str]] = None,
        parameters: Optional[Dict[str, Any]] = None,
    ) -> DetectionTask:
        """Create a PENDING detection task plus one pending execution per entity.

        Args:
            task_name: Task name.
            task_type: Task type.
            entity_ids: IDs of the entities to check (one execution each).
            entity_type: Entity type stored on the task and every execution.
            period: Detection period.
            rule_ids: For ``TaskType.ON_DEMAND`` tasks these are algorithm codes,
                stored in each execution's parameters under ``"rule_ids"``.
                Otherwise they are ``DetectionRule`` IDs; ``None``/empty means
                every enabled rule.
            parameters: Extra parameters. The caller's dict is never mutated;
                every stored record receives its own copy.

        Returns:
            The created task (flushed and refreshed).
        """
        logger.info(
            f"创建检测任务: {task_name}, 类型: {task_type.value}, "
            f"实体数: {len(entity_ids)}"
        )

        # Short random task identifier.
        task_id = f"task_{uuid.uuid4().hex[:12]}"
        # Private copy: nothing below may write into the caller's dict.
        base_params: Dict[str, Any] = dict(parameters or {})

        task = DetectionTask(
            task_id=task_id,
            task_name=task_name,
            task_type=task_type,
            status=TaskStatus.PENDING,
            entity_type=entity_type,
            period=period,
            total_entities=len(entity_ids),
            processed_entities=0,
            parameters=base_params,
            created_at=datetime.now(),
        )

        self.db_session.add(task)
        await self.db_session.flush()
        await self.db_session.refresh(task)

        # On-demand mode: rule_ids are algorithm codes, resolved later from the
        # execution parameters by _get_rules_from_parameters.
        on_demand = task_type == TaskType.ON_DEMAND and bool(rule_ids)

        # Task mode: the DB rules are identical for every entity, so query once.
        rules: List[DetectionRule] = []
        if not on_demand and entity_ids:
            rules = await self._get_rules(rule_ids)

        task_executions = []
        for entity_id in entity_ids:
            # Fresh dict per record so executions never share state with each
            # other, with the task row, or with the caller.
            exec_params = dict(base_params)
            if on_demand:
                exec_params["rule_ids"] = list(rule_ids)
            execution = await self._create_task_execution(
                task.task_id, entity_id, entity_type, period, rules, exec_params
            )
            task_executions.append(execution)

        logger.info(
            f"任务创建完成: {task_id}, 执行记录数: {len(task_executions)}"
        )

        return task

    async def execute_task(self, task_id: str) -> Dict[str, Any]:
        """Run every execution record of a task and store the aggregated result.

        Marks the task RUNNING and registers the built-in algorithms on a fresh
        ``RuleEngine``. Each execution is then run and its status, completion
        time and result count are recorded. Finally the task summary and
        COMPLETED status are stored. A failure for one entity is logged and
        written to that execution record without stopping the remaining
        entities. Entities with no runnable rules are skipped.

        Args:
            task_id: ID of the task to run.

        Returns:
            A dict with these keys:

            - ``task_id``
            - ``status``: the status enum value.
            - ``summary``: see ``_summarize_task_results``.
            - ``results``: one engine result per successfully processed entity.
            - ``executed_at``: an ISO timestamp.

        Raises:
            ValueError: if no task with ``task_id`` exists.
            Exception: any other task-level error is re-raised unchanged. Before
                re-raising, the task is marked FAILED (best effort).
        """
        logger.info(f"开始执行任务: {task_id}")

        try:
            task = await self._get_task(task_id)
            if not task:
                raise ValueError(f"任务不存在: {task_id}")

            task.status = TaskStatus.RUNNING
            task.started_at = datetime.now()
            await self.db_session.flush()

            executions = await self._get_task_executions(task_id)

            rule_engine = RuleEngine(self.db_session)
            self._register_algorithms(rule_engine)

            total_results: List[Dict[str, Any]] = []
            completed_count = 0
            # Task-mode executions of one task normally carry the same rule-ID
            # string, so each distinct string is looked up only once.
            rules_cache: Dict[str, List[DetectionRule]] = {}

            for execution in executions:
                try:
                    rules = await self._resolve_execution_rules(execution, rules_cache)

                    if not rules:
                        logger.warning(
                            f"实体 {execution.entity_id} 没有可执行的规则,跳过"
                        )
                        continue

                    result = await rule_engine.execute_detection(
                        task_id=task_id,
                        rules=rules,
                        entity_id=execution.entity_id,
                        entity_type=execution.entity_type,
                        period=execution.period,
                        parameters=execution.parameters,
                    )

                    total_results.append(result)

                    execution.status = "completed"
                    execution.completed_at = datetime.now()
                    execution.result_count = len(result.get("results") or [])
                    await self.db_session.flush()

                    completed_count += 1
                    task.processed_entities = completed_count
                    await self.db_session.flush()

                    logger.debug(
                        f"实体检测完成: {execution.entity_id}, "
                        f"进度: {completed_count}/{len(executions)}"
                    )

                except Exception as e:
                    # loguru: exception() attaches the traceback (exc_info=True is
                    # not a loguru option). Passing values as positional {} args
                    # keeps braces inside the error text from being parsed as
                    # format fields.
                    logger.exception(
                        "实体检测失败: {}, 错误: {}", execution.entity_id, str(e)
                    )
                    execution.status = "failed"
                    execution.error_message = str(e)
                    execution.completed_at = datetime.now()
                    await self.db_session.flush()

            summary = self._summarize_task_results(total_results)

            task.status = TaskStatus.COMPLETED
            task.completed_at = datetime.now()
            task.result_count = len(total_results)
            task.summary = summary
            await self.db_session.flush()

            logger.info(
                f"任务执行完成: {task_id}, 结果数: {len(total_results)}"
            )

            return {
                "task_id": task_id,
                "status": task.status.value,
                "summary": summary,
                "results": total_results,
                "executed_at": datetime.now().isoformat(),
            }

        except Exception as e:
            logger.exception("任务执行失败: {}, 错误: {}", task_id, str(e))
            await self._mark_task_failed(task_id, str(e))
            # Bare raise re-raises the original task error, not anything
            # swallowed inside _mark_task_failed.
            raise

    async def _mark_task_failed(self, task_id: str, error_message: str) -> None:
        """Best-effort: set FAILED status, error message and completion time.

        Does nothing if the task does not exist. Errors raised while doing this
        (for example a session that must be rolled back first) are logged and
        swallowed, so the caller can re-raise the original exception.
        """
        try:
            task = await self._get_task(task_id)
            if task:
                task.status = TaskStatus.FAILED
                task.error_message = error_message
                task.completed_at = datetime.now()
                await self.db_session.flush()
        except Exception:
            logger.exception("标记任务失败状态时出错: {}", task_id)

    async def _resolve_execution_rules(
        self, execution: Any, rules_cache: Dict[str, List[DetectionRule]]
    ) -> List[DetectionRule]:
        """Return the rules to run for one execution record.

        Task-mode records store comma-separated ``DetectionRule`` IDs in
        ``execution.rule_ids``. Those IDs are loaded via ``_get_rules`` and
        memoised in ``rules_cache``, keyed by the raw string.

        On-demand records have an empty ``rule_ids``. Their algorithm codes are
        read from ``execution.parameters`` and turned into temporary rules.
        """
        if execution.rule_ids:
            if execution.rule_ids not in rules_cache:
                rule_ids = [
                    rid.strip() for rid in execution.rule_ids.split(",") if rid.strip()
                ]
                rules_cache[execution.rule_ids] = await self._get_rules(rule_ids)
            # Copy so one execution cannot alter the list seen by the next.
            return list(rules_cache[execution.rule_ids])
        return await self._get_rules_from_parameters(execution.parameters or {})

    async def get_task(self, task_id: str) -> Optional[DetectionTask]:
        """Return the task with ``task_id``, or ``None`` if there is none."""
        stmt = select(DetectionTask).where(DetectionTask.task_id == task_id)
        result = await self.db_session.execute(stmt)
        return result.scalar_one_or_none()

    async def list_tasks(
        self,
        status: Optional[TaskStatus] = None,
        task_type: Optional[TaskType] = None,
        limit: int = 100,
    ) -> List[DetectionTask]:
        """List tasks, newest first (by ``created_at``).

        Args:
            status: Only return tasks with this status, when given.
            task_type: Only return tasks of this type, when given.
            limit: Maximum number of tasks returned.
        """
        stmt = select(DetectionTask)

        # Compare against None explicitly: mixed-in (e.g. str) enum members
        # follow their base type's truthiness.
        if status is not None:
            stmt = stmt.where(DetectionTask.status == status)

        if task_type is not None:
            stmt = stmt.where(DetectionTask.task_type == task_type)

        stmt = stmt.order_by(DetectionTask.created_at.desc()).limit(limit)

        result = await self.db_session.execute(stmt)
        return list(result.scalars().all())

    async def _get_task(self, task_id: str) -> Optional[DetectionTask]:
        """Internal alias of ``get_task``, kept for existing call sites."""
        return await self.get_task(task_id)

    async def _get_rules(self, rule_ids: Optional[List[str]]) -> List[DetectionRule]:
        """Load rules by ID, or every enabled rule when ``rule_ids`` is None/empty.

        Rules requested by explicit ID are returned regardless of ``is_enabled``.
        """
        if not rule_ids:
            stmt = select(DetectionRule).where(DetectionRule.is_enabled == True)  # noqa: E712 (SQL expression)
        else:
            stmt = select(DetectionRule).where(DetectionRule.rule_id.in_(rule_ids))

        result = await self.db_session.execute(stmt)
        return list(result.scalars().all())

    async def _get_rules_from_parameters(
        self, parameters: Optional[Dict[str, Any]]
    ) -> List[DetectionRule]:
        """Build temporary ``DetectionRule`` objects for on-demand detection.

        Each algorithm code listed in ``parameters["rule_ids"]`` becomes an
        enabled rule with a random ``temp_`` ID. The full ``parameters`` dict is
        attached to each rule. The rules are not added to the session here.
        Returns an empty list, and logs a warning, when no codes are present.
        """
        rule_codes = (parameters or {}).get('rule_ids', [])
        if not rule_codes:
            logger.warning("参数中没有找到 rule_ids")
            return []

        rules = [
            DetectionRule(
                rule_id=f"temp_{code}_{uuid.uuid4().hex[:8]}",
                rule_name=f"临时规则-{code}",
                algorithm_code=code,
                description=f"即时检测临时规则: {code}",
                parameters=parameters,
                is_enabled=True,
                created_at=datetime.now(),
            )
            for code in rule_codes
        ]

        logger.info(f"从参数创建了 {len(rules)} 个临时规则")
        return rules

    def _register_algorithms(self, rule_engine: RuleEngine):
        """Register the five built-in detection algorithms on the engine's registry."""
        from app.services.risk_detection.algorithms import (
            RevenueIntegrityAlgorithm,
            PrivateAccountDetectionAlgorithm,
            InvoiceFraudDetectionAlgorithm,
            ExpenseAnomalyDetectionAlgorithm,
            TaxRiskAssessmentAlgorithm,
        )

        rule_engine.algorithm_registry.register(RevenueIntegrityAlgorithm)
        rule_engine.algorithm_registry.register(PrivateAccountDetectionAlgorithm)
        rule_engine.algorithm_registry.register(InvoiceFraudDetectionAlgorithm)
        rule_engine.algorithm_registry.register(ExpenseAnomalyDetectionAlgorithm)
        rule_engine.algorithm_registry.register(TaxRiskAssessmentAlgorithm)

    async def _create_task_execution(
        self,
        task_id: str,
        entity_id: str,
        entity_type: str,
        period: str,
        rules: List[DetectionRule],
        parameters: Dict[str, Any],
    ):
        """Add and flush a pending ``TaskExecution`` row, then return it.

        The IDs of ``rules`` are stored as one comma-separated string; an empty
        ``rules`` list stores ``""`` (the on-demand marker checked at run time).
        """
        from app.models.risk_detection import TaskExecution

        rule_ids_str = ",".join(r.rule_id for r in rules or [])

        execution = TaskExecution(
            execution_id=f"exec_{uuid.uuid4().hex[:12]}",
            task_id=task_id,
            entity_id=entity_id,
            entity_type=entity_type,
            period=period,
            rule_ids=rule_ids_str,
            parameters=parameters,
            status="pending",
            created_at=datetime.now(),
        )

        self.db_session.add(execution)
        await self.db_session.flush()

        return execution

    async def _get_task_executions(self, task_id: str):
        """Return all ``TaskExecution`` rows of ``task_id`` (no explicit ordering)."""
        from app.models.risk_detection import TaskExecution

        stmt = select(TaskExecution).where(TaskExecution.task_id == task_id)
        result = await self.db_session.execute(stmt)
        return list(result.scalars().all())

    def _summarize_task_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Aggregate per-entity engine results into a task-level summary.

        Assumes each result is a dict with an optional ``"results"`` list and an
        optional ``"summary"`` dict. That summary may hold
        ``"risk_distribution"`` (level -> count) and ``"avg_score"``. Missing or
        ``None`` values are treated as empty or 0.

        Returns:
            A dict with these keys:

            - ``total_entities``: the number of results.
            - ``total_detections``: the summed ``"results"`` lengths.
            - ``risk_distribution``: counts summed per level.
            - ``avg_score``: the unweighted mean of per-entity ``avg_score``,
              rounded to 2 decimals.

            Empty input yields the same keys with zero values.
        """
        risk_distribution: Dict[Any, int] = {}
        all_scores: List[float] = []
        total_detections = 0

        for result in results:
            total_detections += len(result.get("results") or [])
            summary = result.get("summary") or {}

            for level, count in (summary.get("risk_distribution") or {}).items():
                risk_distribution[level] = risk_distribution.get(level, 0) + count

            all_scores.append(summary.get("avg_score") or 0)

        avg_score = sum(all_scores) / len(all_scores) if all_scores else 0

        return {
            "total_entities": len(results),
            "total_detections": total_detections,
            "risk_distribution": risk_distribution,
            "avg_score": round(avg_score, 2),
        }
|