""" 风险检测任务调度器 负责定时任务调度和自动检测任务执行 """ from typing import List, Dict, Any, Optional from datetime import datetime, timedelta from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select from loguru import logger from app.models.risk_detection import DetectionTask, TaskType, TaskStatus from .task_manager import TaskManager class DetectionScheduler: """风险检测任务调度器""" def __init__(self, db_session: AsyncSession): self.db_session = db_session self.task_manager = TaskManager(db_session) async def schedule_periodic_detection( self, entity_ids: List[str], entity_type: str, period: str, auto_execute: bool = True, ) -> DetectionTask: """ 调度定期检测任务 Args: entity_ids: 实体ID列表 entity_type: 实体类型 period: 检测期间 auto_execute: 是否自动执行 Returns: 创建的任务 """ logger.info( f"调度定期检测: 实体数={len(entity_ids)}, 期间={period}" ) # 创建任务 task = await self.task_manager.create_task( task_name=f"定期检测-{period}", task_type=TaskType.PERIODIC, entity_ids=entity_ids, entity_type=entity_type, period=period, parameters={"schedule_type": "periodic"}, ) # 如果自动执行,立即执行 if auto_execute: await self._execute_task_async(task.task_id) return task async def schedule_on_demand_detection( self, entity_id: str, entity_type: str, period: str, rule_ids: Optional[List[str]] = None, auto_execute: bool = True, ) -> DetectionTask: """ 调度按需检测任务 Args: entity_id: 实体ID entity_type: 实体类型 period: 检测期间 rule_ids: 指定规则ID列表 auto_execute: 是否自动执行 Returns: 创建的任务 """ logger.info( f"调度按需检测: entity_id={entity_id}, 期间={period}" ) # 创建任务 task = await self.task_manager.create_task( task_name=f"按需检测-{entity_id}-{period}", task_type=TaskType.ON_DEMAND, entity_ids=[entity_id], entity_type=entity_type, period=period, rule_ids=rule_ids, parameters={"schedule_type": "on_demand"}, ) # 如果自动执行,立即执行 if auto_execute: await self._execute_task_async(task.task_id) return task async def schedule_batch_detection( self, entity_ids: List[str], entity_type: str, period: str, rule_ids: Optional[List[str]] = None, auto_execute: bool = False, ) -> DetectionTask: """ 调度批量检测任务 Args: entity_ids: 实体ID列表 entity_type: 实体类型 period: 检测期间 rule_ids: 指定规则ID列表 auto_execute: 是否自动执行 Returns: 创建的任务 """ logger.info( f"调度批量检测: 实体数={len(entity_ids)}, 期间={period}" ) # 创建任务 task = await self.task_manager.create_task( task_name=f"批量检测-{period}", task_type=TaskType.BATCH, entity_ids=entity_ids, entity_type=entity_type, period=period, rule_ids=rule_ids, parameters={ "schedule_type": "batch", "batch_size": len(entity_ids), }, ) # 批量任务不自动执行,需要手动触发 if auto_execute: await self._execute_task_async(task.task_id) return task async def check_pending_tasks(self) -> List[DetectionTask]: """检查待执行的任务""" stmt = select(DetectionTask).where( DetectionTask.status == TaskStatus.PENDING ) result = await self.db_session.execute(stmt) tasks = list(result.scalars().all()) logger.info(f"发现 {len(tasks)} 个待执行任务") return tasks async def retry_failed_tasks( self, max_retries: int = 3, ) -> List[DetectionTask]: """重试失败的任务""" stmt = select(DetectionTask).where( and_( DetectionTask.status == TaskStatus.FAILED, DetectionTask.retry_count < max_retries, ) ) result = await self.db_session.execute(stmt) failed_tasks = list(result.scalars().all()) logger.info(f"发现 {len(failed_tasks)} 个失败任务需要重试") retried_tasks = [] for task in failed_tasks: try: # 重置任务状态 task.status = TaskStatus.PENDING task.retry_count = (task.retry_count or 0) + 1 await self.db_session.flush() # 重新执行 await self._execute_task_async(task.task_id) retried_tasks.append(task) except Exception as e: logger.error( f"任务重试失败: {task.task_id}, 错误: {str(e)}", exc_info=True, ) return retried_tasks async def cleanup_old_tasks( self, days: int = 30, ) -> int: """清理旧任务""" cutoff_date = datetime.now() - timedelta(days=days) stmt = select(DetectionTask).where( and_( DetectionTask.created_at < cutoff_date, DetectionTask.status.in_([TaskStatus.COMPLETED, TaskStatus.FAILED]), ) ) result = await self.db_session.execute(stmt) old_tasks = list(result.scalars().all()) count = len(old_tasks) if count > 0: logger.info(f"清理 {count} 个旧任务") for task in old_tasks: await self.db_session.delete(task) await self.db_session.commit() return count async def _execute_task_async(self, task_id: str): """异步执行任务(在后台运行)""" import asyncio # 在后台运行任务执行 asyncio.create_task(self.task_manager.execute_task(task_id)) logger.info(f"任务已提交后台执行: {task_id}")