deep-risk/backend/app/services/risk_detection/task_manager/task_manager.py
2025-12-14 20:08:27 +08:00

386 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
风险检测任务管理器
负责创建、管理和执行检测任务
"""
import uuid
from typing import List, Dict, Any, Optional
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_
from loguru import logger
from app.models.risk_detection import (
DetectionTask,
DetectionRule,
TaskType,
TaskStatus,
)
from app.services.risk_detection.engine.rule_engine import RuleEngine
class TaskManager:
    """Create, execute and query risk-detection tasks.

    All persistence goes through the ``AsyncSession`` supplied at
    construction time. This class only ever calls ``flush()`` on that
    session; no ``commit()`` or ``rollback()`` is issued here, so transaction
    boundaries are left to the caller.
    """

    def __init__(self, db_session: AsyncSession):
        """Store the session used for every query and flush in this manager."""
        self.db_session = db_session

    async def create_task(
        self,
        task_name: str,
        task_type: TaskType,
        entity_ids: List[str],
        entity_type: str,
        period: str,
        rule_ids: Optional[List[str]] = None,
        parameters: Optional[Dict[str, Any]] = None,
    ) -> DetectionTask:
        """Create a detection task plus one execution record per entity.

        Args:
            task_name: Display name of the task.
            task_type: Task type. ``TaskType.ON_DEMAND`` combined with a
                non-empty ``rule_ids`` selects the on-demand mode below.
            entity_ids: Entities to check; one execution record is created
                for each of them.
            entity_type: Entity type stored on the task and each execution.
            period: Detection period stored on the task and each execution.
            rule_ids: In on-demand mode these are treated as algorithm codes
                and stored under ``parameters['rule_ids']`` of every
                execution. Otherwise they are rule ids looked up in the
                database; when empty or ``None`` all enabled rules are used.
            parameters: Extra parameters. The caller's dict is never
                modified.

        Returns:
            The newly created (flushed and refreshed) task.
        """
        logger.info(
            f"创建检测任务: {task_name}, 类型: {task_type.value}, "
            f"实体数: {len(entity_ids)}"
        )

        task_id = f"task_{uuid.uuid4().hex[:12]}"

        # Work on a private copy so neither the task row nor the execution
        # rows alias (or mutate) the dict owned by the caller.
        base_params: Dict[str, Any] = dict(parameters or {})

        task = DetectionTask(
            task_id=task_id,
            task_name=task_name,
            task_type=task_type,
            status=TaskStatus.PENDING,
            entity_type=entity_type,
            period=period,
            total_entities=len(entity_ids),
            processed_entities=0,
            parameters=dict(base_params),
            created_at=datetime.now(),
        )
        self.db_session.add(task)
        await self.db_session.flush()
        await self.db_session.refresh(task)

        # On-demand mode: rule_ids are algorithm codes carried in the
        # execution parameters instead of references to stored rules.
        on_demand = task_type == TaskType.ON_DEMAND and bool(rule_ids)

        # Task mode: the rule set is identical for every entity, so query it
        # once instead of once per entity.
        rules: List[DetectionRule] = []
        if entity_ids and not on_demand:
            rules = await self._get_rules(rule_ids)

        task_executions = []
        for entity_id in entity_ids:
            if on_demand:
                exec_rules: List[DetectionRule] = []
                exec_params = {**base_params, "rule_ids": rule_ids}
            else:
                exec_rules = rules
                exec_params = dict(base_params)
            execution = await self._create_task_execution(
                task.task_id, entity_id, entity_type, period, exec_rules, exec_params
            )
            task_executions.append(execution)

        logger.info(
            f"任务创建完成: {task_id}, 执行记录数: {len(task_executions)}"
        )
        return task

    async def execute_task(self, task_id: str) -> Dict[str, Any]:
        """Run every execution record of a task and summarise the outcome.

        Each execution is processed independently: a failure is logged,
        recorded on that execution (``status="failed"``, ``error_message``)
        and does not stop the remaining executions. Executions for which no
        rule can be resolved are skipped and their status is left unchanged.
        After the loop the task is marked ``COMPLETED`` regardless of
        individual execution failures.

        Args:
            task_id: Identifier of the task to run.

        Returns:
            A dict with ``task_id``, ``status``, ``summary``, ``results`` and
            ``executed_at`` (ISO-8601 string).

        Raises:
            ValueError: If no task with ``task_id`` exists.
            Exception: Any other error raised outside the per-execution
                handler is re-raised after the task has been marked
                ``FAILED``.
        """
        logger.info(f"开始执行任务: {task_id}")
        try:
            task = await self._get_task(task_id)
            if not task:
                raise ValueError(f"任务不存在: {task_id}")

            task.status = TaskStatus.RUNNING
            task.started_at = datetime.now()
            await self.db_session.flush()

            executions = await self._get_task_executions(task_id)

            rule_engine = RuleEngine(self.db_session)
            self._register_algorithms(rule_engine)

            total_results = []
            completed_count = 0
            for execution in executions:
                try:
                    if execution.rule_ids:
                        # Task mode: rule ids were stored comma-separated.
                        rule_ids = [
                            rid.strip()
                            for rid in execution.rule_ids.split(",")
                            if rid.strip()
                        ]
                        rules = await self._get_rules(rule_ids)
                    else:
                        # On-demand mode: build temporary rules from the
                        # algorithm codes kept in the execution parameters.
                        rules = await self._get_rules_from_parameters(
                            execution.parameters
                        )

                    if not rules:
                        logger.warning(
                            f"实体 {execution.entity_id} 没有可执行的规则,跳过"
                        )
                        continue

                    result = await rule_engine.execute_detection(
                        task_id=task_id,
                        rules=rules,
                        entity_id=execution.entity_id,
                        entity_type=execution.entity_type,
                        period=execution.period,
                        parameters=execution.parameters,
                    )
                    total_results.append(result)

                    execution.status = "completed"
                    execution.completed_at = datetime.now()
                    execution.result_count = len(result.get("results", []))

                    completed_count += 1
                    task.processed_entities = completed_count
                    await self.db_session.flush()

                    logger.debug(
                        f"实体检测完成: {execution.entity_id}, "
                        f"进度: {completed_count}/{len(executions)}"
                    )
                except Exception as e:
                    # loguru has no ``exc_info`` switch: keyword arguments are
                    # fed to ``str.format`` on the message, so a pre-rendered
                    # message containing braces could itself raise here.
                    # ``logger.exception`` attaches the traceback, and the
                    # values are passed as positional format arguments.
                    logger.exception(
                        "实体检测失败: {}, 错误: {}", execution.entity_id, e
                    )
                    execution.status = "failed"
                    execution.error_message = str(e)
                    execution.completed_at = datetime.now()
                    await self.db_session.flush()

            summary = self._summarize_task_results(total_results)

            task.status = TaskStatus.COMPLETED
            task.completed_at = datetime.now()
            task.result_count = len(total_results)
            task.summary = summary
            await self.db_session.flush()

            logger.info(
                f"任务执行完成: {task_id}, 结果数: {len(total_results)}"
            )
            return {
                "task_id": task_id,
                "status": task.status.value,
                "summary": summary,
                "results": total_results,
                "executed_at": datetime.now().isoformat(),
            }
        except Exception as e:
            # Log with traceback (see the note on loguru above).
            logger.exception("任务执行失败: {}, 错误: {}", task_id, str(e))
            # Record the failure on the task, if it exists, then re-raise.
            task = await self._get_task(task_id)
            if task:
                task.status = TaskStatus.FAILED
                task.error_message = str(e)
                task.completed_at = datetime.now()
                await self.db_session.flush()
            raise

    async def get_task(self, task_id: str) -> Optional[DetectionTask]:
        """Return the task with ``task_id``, or ``None`` if there is none."""
        return await self._get_task(task_id)

    async def list_tasks(
        self,
        status: Optional[TaskStatus] = None,
        task_type: Optional[TaskType] = None,
        limit: int = 100,
    ) -> List[DetectionTask]:
        """List tasks, newest first.

        Args:
            status: When truthy, only tasks with this status are returned.
            task_type: When truthy, only tasks of this type are returned.
            limit: Maximum number of tasks to return.

        Returns:
            Tasks ordered by ``created_at`` descending.
        """
        stmt = select(DetectionTask)
        if status:
            stmt = stmt.where(DetectionTask.status == status)
        if task_type:
            stmt = stmt.where(DetectionTask.task_type == task_type)
        stmt = stmt.order_by(DetectionTask.created_at.desc()).limit(limit)
        result = await self.db_session.execute(stmt)
        # Materialise as a real list, matching the annotation and the other
        # query helpers in this class.
        return list(result.scalars().all())

    async def _get_task(self, task_id: str) -> Optional[DetectionTask]:
        """Fetch a single task by its ``task_id`` (``None`` when absent)."""
        stmt = select(DetectionTask).where(DetectionTask.task_id == task_id)
        result = await self.db_session.execute(stmt)
        return result.scalar_one_or_none()

    async def _get_rules(self, rule_ids: Optional[List[str]]) -> List[DetectionRule]:
        """Load rules from the database.

        Args:
            rule_ids: Rule ids to load. When empty or ``None`` every enabled
                rule is returned instead; when given, matching rules are
                returned without filtering on ``is_enabled``.
        """
        if not rule_ids:
            # ``== True`` is intentional: it builds a SQL expression.
            stmt = select(DetectionRule).where(DetectionRule.is_enabled == True)  # noqa: E712
        else:
            stmt = select(DetectionRule).where(DetectionRule.rule_id.in_(rule_ids))
        result = await self.db_session.execute(stmt)
        return list(result.scalars().all())

    async def _get_rules_from_parameters(
        self, parameters: Optional[Dict[str, Any]]
    ) -> List[DetectionRule]:
        """Build temporary ``DetectionRule`` objects for on-demand detection.

        The algorithm codes are read from ``parameters['rule_ids']``. The
        created rules are not added to the session by this method.

        Args:
            parameters: Execution parameters; ``None`` is treated as empty.

        Returns:
            One temporary rule per algorithm code, or ``[]`` when none are
            configured.
        """
        rule_codes = (parameters or {}).get("rule_ids", [])
        if not rule_codes:
            logger.warning("参数中没有找到 rule_ids")
            return []

        rules = [
            DetectionRule(
                # Random suffix keeps ids distinct across temporary rules.
                rule_id=f"temp_{code}_{uuid.uuid4().hex[:8]}",
                rule_name=f"临时规则-{code}",
                algorithm_code=code,
                description=f"即时检测临时规则: {code}",
                parameters=parameters,
                is_enabled=True,
                created_at=datetime.now(),
            )
            for code in rule_codes
        ]
        logger.info(f"从参数创建了 {len(rules)} 个临时规则")
        return rules

    def _register_algorithms(self, rule_engine: RuleEngine):
        """Register the built-in detection algorithms on ``rule_engine``."""
        # Imported here rather than at module level, as in the original code.
        from app.services.risk_detection.algorithms import (
            RevenueIntegrityAlgorithm,
            PrivateAccountDetectionAlgorithm,
            InvoiceFraudDetectionAlgorithm,
            ExpenseAnomalyDetectionAlgorithm,
            TaxRiskAssessmentAlgorithm,
        )

        for algorithm in (
            RevenueIntegrityAlgorithm,
            PrivateAccountDetectionAlgorithm,
            InvoiceFraudDetectionAlgorithm,
            ExpenseAnomalyDetectionAlgorithm,
            TaxRiskAssessmentAlgorithm,
        ):
            rule_engine.algorithm_registry.register(algorithm)

    async def _create_task_execution(
        self,
        task_id: str,
        entity_id: str,
        entity_type: str,
        period: str,
        rules: List[DetectionRule],
        parameters: Dict[str, Any],
    ):
        """Create and flush one ``TaskExecution`` row in ``"pending"`` state.

        The ids of ``rules`` are stored as a comma-separated string; an empty
        string means the execution carries no stored rules.
        """
        from app.models.risk_detection import TaskExecution

        rule_ids_str = ",".join(r.rule_id for r in rules) if rules else ""
        execution = TaskExecution(
            execution_id=f"exec_{uuid.uuid4().hex[:12]}",
            task_id=task_id,
            entity_id=entity_id,
            entity_type=entity_type,
            period=period,
            rule_ids=rule_ids_str,
            parameters=parameters,
            status="pending",
            created_at=datetime.now(),
        )
        self.db_session.add(execution)
        await self.db_session.flush()
        return execution

    async def _get_task_executions(self, task_id: str):
        """Return all ``TaskExecution`` rows belonging to ``task_id``."""
        from app.models.risk_detection import TaskExecution

        stmt = select(TaskExecution).where(TaskExecution.task_id == task_id)
        result = await self.db_session.execute(stmt)
        return list(result.scalars().all())

    def _summarize_task_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Aggregate per-entity detection results into one task summary.

        Args:
            results: One dict per entity. ``results`` (a list) and
                ``summary`` (with ``risk_distribution`` and ``avg_score``)
                are read from each; missing or ``None`` values count as
                empty / zero.

        Returns:
            ``{"total_entities": 0}`` for empty input; otherwise a dict with
            ``total_entities``, ``total_detections``, the summed
            ``risk_distribution`` and ``avg_score`` (mean of the per-entity
            averages, rounded to two decimals).
        """
        if not results:
            return {"total_entities": 0}

        total_detections = sum(len(r.get("results") or []) for r in results)

        risk_distribution: Dict[str, Any] = {}
        all_scores = []
        for result in results:
            summary = result.get("summary") or {}
            for level, count in (summary.get("risk_distribution") or {}).items():
                risk_distribution[level] = risk_distribution.get(level, 0) + count
            all_scores.append(summary.get("avg_score") or 0)

        # ``results`` is non-empty here, so ``all_scores`` is too.
        avg_score = sum(all_scores) / len(all_scores)
        return {
            "total_entities": len(results),
            "total_detections": total_detections,
            "risk_distribution": risk_distribution,
            "avg_score": round(avg_score, 2),
        }