""" 风险检测任务管理器 负责创建、管理和执行检测任务 """ import uuid from typing import List, Dict, Any, Optional from datetime import datetime from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, and_ from loguru import logger from app.models.risk_detection import ( DetectionTask, DetectionRule, TaskType, TaskStatus, ) from app.services.risk_detection.engine.rule_engine import RuleEngine class TaskManager: """风险检测任务管理器""" def __init__(self, db_session: AsyncSession): self.db_session = db_session async def create_task( self, task_name: str, task_type: TaskType, entity_ids: List[str], entity_type: str, period: str, rule_ids: Optional[List[str]] = None, parameters: Optional[Dict[str, Any]] = None, ) -> DetectionTask: """ 创建检测任务 Args: task_name: 任务名称 task_type: 任务类型 entity_ids: 实体ID列表 entity_type: 实体类型 period: 检测期间 rule_ids: 规则ID列表 parameters: 额外参数 Returns: 创建的任务 """ logger.info( f"创建检测任务: {task_name}, 类型: {task_type.value}, " f"实体数: {len(entity_ids)}" ) # 生成任务ID task_id = f"task_{uuid.uuid4().hex[:12]}" # 创建任务 task = DetectionTask( task_id=task_id, task_name=task_name, task_type=task_type, status=TaskStatus.PENDING, entity_type=entity_type, period=period, total_entities=len(entity_ids), processed_entities=0, parameters=parameters or {}, created_at=datetime.now(), ) self.db_session.add(task) await self.db_session.flush() await self.db_session.refresh(task) # 为每个实体创建任务执行记录 task_executions = [] for entity_id in entity_ids: # 如果是即时检测(task_type == ON_DEMAND)且有rule_ids,说明是算法代码 if task_type == TaskType.ON_DEMAND and rule_ids: # 即时检测模式:rule_ids是算法代码,直接保存在parameters中 exec_params = parameters or {} exec_params['rule_ids'] = rule_ids execution = await self._create_task_execution( task.task_id, entity_id, entity_type, period, [], exec_params ) else: # 任务模式:从数据库获取规则 rules = await self._get_rules(rule_ids) execution = await self._create_task_execution( task.task_id, entity_id, entity_type, period, rules, parameters or {} ) task_executions.append(execution) logger.info( f"任务创建完成: {task_id}, 执行记录数: {len(task_executions)}" ) return task async def execute_task(self, task_id: str) -> Dict[str, Any]: """ 执行检测任务 Args: task_id: 任务ID Returns: 执行结果 """ logger.info(f"开始执行任务: {task_id}") try: # 获取任务 task = await self._get_task(task_id) if not task: raise ValueError(f"任务不存在: {task_id}") # 更新任务状态 task.status = TaskStatus.RUNNING task.started_at = datetime.now() await self.db_session.flush() # 获取任务执行记录 executions = await self._get_task_executions(task_id) # 创建规则引擎 rule_engine = RuleEngine(self.db_session) # 注册算法 self._register_algorithms(rule_engine) # 执行每个实体的检测 total_results = [] completed_count = 0 for execution in executions: try: # 检查是否有规则ID if execution.rule_ids: # 从数据库获取规则 rule_ids = [rid.strip() for rid in execution.rule_ids.split(',') if rid.strip()] rules = await self._get_rules(rule_ids) else: # 即时检测模式:从参数中获取算法代码 rules = await self._get_rules_from_parameters(execution.parameters) if not rules: logger.warning( f"实体 {execution.entity_id} 没有可执行的规则,跳过" ) continue # 执行检测 result = await rule_engine.execute_detection( task_id=task_id, rules=rules, entity_id=execution.entity_id, entity_type=execution.entity_type, period=execution.period, parameters=execution.parameters, ) total_results.append(result) # 更新执行记录 execution.status = "completed" execution.completed_at = datetime.now() execution.result_count = len(result.get("results", [])) await self.db_session.flush() completed_count += 1 task.processed_entities = completed_count await self.db_session.flush() logger.debug( f"实体检测完成: {execution.entity_id}, " f"进度: {completed_count}/{len(executions)}" ) except Exception as e: logger.error( f"实体检测失败: {execution.entity_id}, 错误: {str(e)}", exc_info=True, ) execution.status = "failed" execution.error_message = str(e) execution.completed_at = datetime.now() await self.db_session.flush() # 汇总任务结果 summary = self._summarize_task_results(total_results) # 更新任务状态 task.status = TaskStatus.COMPLETED task.completed_at = datetime.now() task.result_count = len(total_results) task.summary = summary await self.db_session.flush() logger.info( f"任务执行完成: {task_id}, 结果数: {len(total_results)}" ) return { "task_id": task_id, "status": task.status.value, "summary": summary, "results": total_results, "executed_at": datetime.now().isoformat(), } except Exception as e: logger.error("任务执行失败: {}, 错误: {}", task_id, str(e), exc_info=True) # 更新任务状态为失败 task = await self._get_task(task_id) if task: task.status = TaskStatus.FAILED task.error_message = str(e) task.completed_at = datetime.now() await self.db_session.flush() raise async def get_task(self, task_id: str) -> Optional[DetectionTask]: """获取任务""" stmt = select(DetectionTask).where(DetectionTask.task_id == task_id) result = await self.db_session.execute(stmt) return result.scalar_one_or_none() async def list_tasks( self, status: Optional[TaskStatus] = None, task_type: Optional[TaskType] = None, limit: int = 100, ) -> List[DetectionTask]: """查询任务列表""" stmt = select(DetectionTask) if status: stmt = stmt.where(DetectionTask.status == status) if task_type: stmt = stmt.where(DetectionTask.task_type == task_type) stmt = stmt.order_by(DetectionTask.created_at.desc()).limit(limit) result = await self.db_session.execute(stmt) return result.scalars().all() async def _get_task(self, task_id: str) -> Optional[DetectionTask]: """获取任务""" stmt = select(DetectionTask).where(DetectionTask.task_id == task_id) result = await self.db_session.execute(stmt) return result.scalar_one_or_none() async def _get_rules(self, rule_ids: Optional[List[str]]) -> List[DetectionRule]: """获取规则""" if not rule_ids: # 如果没有指定规则,获取所有启用的规则 stmt = select(DetectionRule).where(DetectionRule.is_enabled == True) else: stmt = select(DetectionRule).where(DetectionRule.rule_id.in_(rule_ids)) result = await self.db_session.execute(stmt) return list(result.scalars().all()) async def _get_rules_from_parameters(self, parameters: Dict[str, Any]) -> List[DetectionRule]: """ 从参数中获取算法代码并创建临时的DetectionRule对象 用于即时检测模式 """ from app.models.risk_detection import RiskLevel # 从parameters中获取算法代码列表 rule_codes = parameters.get('rule_ids', []) if not rule_codes: logger.warning("参数中没有找到 rule_ids") return [] rules = [] for code in rule_codes: # 创建临时的DetectionRule对象 rule = DetectionRule( rule_id=f"temp_{code}_{uuid.uuid4().hex[:8]}", rule_name=f"临时规则-{code}", algorithm_code=code, description=f"即时检测临时规则: {code}", parameters=parameters, is_enabled=True, created_at=datetime.now(), ) rules.append(rule) logger.info(f"从参数创建了 {len(rules)} 个临时规则") return rules def _register_algorithms(self, rule_engine: RuleEngine): """注册算法""" from app.services.risk_detection.algorithms import ( RevenueIntegrityAlgorithm, PrivateAccountDetectionAlgorithm, InvoiceFraudDetectionAlgorithm, ExpenseAnomalyDetectionAlgorithm, TaxRiskAssessmentAlgorithm, ) rule_engine.algorithm_registry.register(RevenueIntegrityAlgorithm) rule_engine.algorithm_registry.register(PrivateAccountDetectionAlgorithm) rule_engine.algorithm_registry.register(InvoiceFraudDetectionAlgorithm) rule_engine.algorithm_registry.register(ExpenseAnomalyDetectionAlgorithm) rule_engine.algorithm_registry.register(TaxRiskAssessmentAlgorithm) async def _create_task_execution( self, task_id: str, entity_id: str, entity_type: str, period: str, rules: List[DetectionRule], parameters: Dict[str, Any], ): """创建任务执行记录""" from app.models.risk_detection import TaskExecution # 将rule_ids列表转换为逗号分隔的字符串 rule_ids_str = ",".join([r.rule_id for r in rules]) if rules else "" execution = TaskExecution( execution_id=f"exec_{uuid.uuid4().hex[:12]}", task_id=task_id, entity_id=entity_id, entity_type=entity_type, period=period, rule_ids=rule_ids_str, parameters=parameters, status="pending", created_at=datetime.now(), ) self.db_session.add(execution) await self.db_session.flush() return execution async def _get_task_executions(self, task_id: str): """获取任务执行记录""" from app.models.risk_detection import TaskExecution stmt = select(TaskExecution).where(TaskExecution.task_id == task_id) result = await self.db_session.execute(stmt) return list(result.scalars().all()) def _summarize_task_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]: """汇总任务结果""" if not results: return {"total_entities": 0} total_entities = len(results) total_detections = sum(len(r.get("results", [])) for r in results) # 统计风险分布 risk_distribution = {} all_scores = [] for result in results: summary = result.get("summary", {}) dist = summary.get("risk_distribution", {}) for level, count in dist.items(): risk_distribution[level] = risk_distribution.get(level, 0) + count all_scores.append(summary.get("avg_score", 0)) avg_score = sum(all_scores) / len(all_scores) if all_scores else 0 return { "total_entities": total_entities, "total_detections": total_detections, "risk_distribution": risk_distribution, "avg_score": round(avg_score, 2), }