deep-risk/backend/app/services/risk_detection/task_manager/scheduler.py
2025-12-14 20:08:27 +08:00

233 lines
6.6 KiB
Python

"""
风险检测任务调度器
负责定时任务调度和自动检测任务执行
"""
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from loguru import logger
from app.models.risk_detection import DetectionTask, TaskType, TaskStatus
from .task_manager import TaskManager
class DetectionScheduler:
"""风险检测任务调度器"""
def __init__(self, db_session: AsyncSession):
self.db_session = db_session
self.task_manager = TaskManager(db_session)
async def schedule_periodic_detection(
self,
entity_ids: List[str],
entity_type: str,
period: str,
auto_execute: bool = True,
) -> DetectionTask:
"""
调度定期检测任务
Args:
entity_ids: 实体ID列表
entity_type: 实体类型
period: 检测期间
auto_execute: 是否自动执行
Returns:
创建的任务
"""
logger.info(
f"调度定期检测: 实体数={len(entity_ids)}, 期间={period}"
)
# 创建任务
task = await self.task_manager.create_task(
task_name=f"定期检测-{period}",
task_type=TaskType.PERIODIC,
entity_ids=entity_ids,
entity_type=entity_type,
period=period,
parameters={"schedule_type": "periodic"},
)
# 如果自动执行,立即执行
if auto_execute:
await self._execute_task_async(task.task_id)
return task
async def schedule_on_demand_detection(
self,
entity_id: str,
entity_type: str,
period: str,
rule_ids: Optional[List[str]] = None,
auto_execute: bool = True,
) -> DetectionTask:
"""
调度按需检测任务
Args:
entity_id: 实体ID
entity_type: 实体类型
period: 检测期间
rule_ids: 指定规则ID列表
auto_execute: 是否自动执行
Returns:
创建的任务
"""
logger.info(
f"调度按需检测: entity_id={entity_id}, 期间={period}"
)
# 创建任务
task = await self.task_manager.create_task(
task_name=f"按需检测-{entity_id}-{period}",
task_type=TaskType.ON_DEMAND,
entity_ids=[entity_id],
entity_type=entity_type,
period=period,
rule_ids=rule_ids,
parameters={"schedule_type": "on_demand"},
)
# 如果自动执行,立即执行
if auto_execute:
await self._execute_task_async(task.task_id)
return task
async def schedule_batch_detection(
self,
entity_ids: List[str],
entity_type: str,
period: str,
rule_ids: Optional[List[str]] = None,
auto_execute: bool = False,
) -> DetectionTask:
"""
调度批量检测任务
Args:
entity_ids: 实体ID列表
entity_type: 实体类型
period: 检测期间
rule_ids: 指定规则ID列表
auto_execute: 是否自动执行
Returns:
创建的任务
"""
logger.info(
f"调度批量检测: 实体数={len(entity_ids)}, 期间={period}"
)
# 创建任务
task = await self.task_manager.create_task(
task_name=f"批量检测-{period}",
task_type=TaskType.BATCH,
entity_ids=entity_ids,
entity_type=entity_type,
period=period,
rule_ids=rule_ids,
parameters={
"schedule_type": "batch",
"batch_size": len(entity_ids),
},
)
# 批量任务不自动执行,需要手动触发
if auto_execute:
await self._execute_task_async(task.task_id)
return task
async def check_pending_tasks(self) -> List[DetectionTask]:
"""检查待执行的任务"""
stmt = select(DetectionTask).where(
DetectionTask.status == TaskStatus.PENDING
)
result = await self.db_session.execute(stmt)
tasks = list(result.scalars().all())
logger.info(f"发现 {len(tasks)} 个待执行任务")
return tasks
async def retry_failed_tasks(
self,
max_retries: int = 3,
) -> List[DetectionTask]:
"""重试失败的任务"""
stmt = select(DetectionTask).where(
and_(
DetectionTask.status == TaskStatus.FAILED,
DetectionTask.retry_count < max_retries,
)
)
result = await self.db_session.execute(stmt)
failed_tasks = list(result.scalars().all())
logger.info(f"发现 {len(failed_tasks)} 个失败任务需要重试")
retried_tasks = []
for task in failed_tasks:
try:
# 重置任务状态
task.status = TaskStatus.PENDING
task.retry_count = (task.retry_count or 0) + 1
await self.db_session.flush()
# 重新执行
await self._execute_task_async(task.task_id)
retried_tasks.append(task)
except Exception as e:
logger.error(
f"任务重试失败: {task.task_id}, 错误: {str(e)}",
exc_info=True,
)
return retried_tasks
async def cleanup_old_tasks(
self,
days: int = 30,
) -> int:
"""清理旧任务"""
cutoff_date = datetime.now() - timedelta(days=days)
stmt = select(DetectionTask).where(
and_(
DetectionTask.created_at < cutoff_date,
DetectionTask.status.in_([TaskStatus.COMPLETED, TaskStatus.FAILED]),
)
)
result = await self.db_session.execute(stmt)
old_tasks = list(result.scalars().all())
count = len(old_tasks)
if count > 0:
logger.info(f"清理 {count} 个旧任务")
for task in old_tasks:
await self.db_session.delete(task)
await self.db_session.commit()
return count
async def _execute_task_async(self, task_id: str):
"""异步执行任务(在后台运行)"""
import asyncio
# 在后台运行任务执行
asyncio.create_task(self.task_manager.execute_task(task_id))
logger.info(f"任务已提交后台执行: {task_id}")