233 lines
6.6 KiB
Python
233 lines
6.6 KiB
Python
"""
|
|
风险检测任务调度器
|
|
负责定时任务调度和自动检测任务执行
|
|
"""
|
|
from typing import List, Dict, Any, Optional
|
|
from datetime import datetime, timedelta
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
from loguru import logger
|
|
|
|
from app.models.risk_detection import DetectionTask, TaskType, TaskStatus
|
|
from .task_manager import TaskManager
|
|
|
|
|
|
class DetectionScheduler:
    """Scheduler for risk-detection tasks.

    Creates DetectionTask records through TaskManager and can submit them for
    background execution on the running asyncio event loop.
    """

    # Strong references to in-flight background executions. The asyncio event
    # loop keeps only weak references to tasks, so a fire-and-forget task with
    # no other reference can be garbage-collected before it finishes. Kept at
    # class level so the references outlive a short-lived scheduler instance.
    _background_tasks: set = set()

    def __init__(self, db_session: AsyncSession):
        """Bind the scheduler and its TaskManager to one database session.

        Args:
            db_session: Async session used for queries here and handed to
                TaskManager.
        """
        self.db_session = db_session
        # NOTE(review): TaskManager shares this session, and background
        # executions go through self.task_manager. An AsyncSession must not be
        # used by two coroutines at once; confirm callers stop using the
        # session while submitted executions may still be running.
        self.task_manager = TaskManager(db_session)

    async def schedule_periodic_detection(
        self,
        entity_ids: List[str],
        entity_type: str,
        period: str,
        auto_execute: bool = True,
    ) -> DetectionTask:
        """Create a periodic detection task covering several entities.

        Args:
            entity_ids: IDs of the entities to check.
            entity_type: Entity type, passed through to TaskManager.create_task.
            period: Detection period; also embedded in the task name.
            auto_execute: If True (default), submit the task for background
                execution immediately after creating it.

        Returns:
            The created task. When auto_execute is True, execution has only
            been submitted, not finished, when this returns.
        """
        logger.info(
            f"调度定期检测: 实体数={len(entity_ids)}, 期间={period}"
        )

        task = await self.task_manager.create_task(
            task_name=f"定期检测-{period}",
            task_type=TaskType.PERIODIC,
            entity_ids=entity_ids,
            entity_type=entity_type,
            period=period,
            parameters={"schedule_type": "periodic"},
        )

        if auto_execute:
            await self._execute_task_async(task.task_id)

        return task

    async def schedule_on_demand_detection(
        self,
        entity_id: str,
        entity_type: str,
        period: str,
        rule_ids: Optional[List[str]] = None,
        auto_execute: bool = True,
    ) -> DetectionTask:
        """Create an on-demand detection task for a single entity.

        Args:
            entity_id: ID of the entity to check.
            entity_type: Entity type, passed through to TaskManager.create_task.
            period: Detection period; also embedded in the task name.
            rule_ids: Optional list of rule IDs to restrict the detection to.
            auto_execute: If True (default), submit the task for background
                execution immediately after creating it.

        Returns:
            The created task. When auto_execute is True, execution has only
            been submitted, not finished, when this returns.
        """
        logger.info(
            f"调度按需检测: entity_id={entity_id}, 期间={period}"
        )

        task = await self.task_manager.create_task(
            task_name=f"按需检测-{entity_id}-{period}",
            task_type=TaskType.ON_DEMAND,
            entity_ids=[entity_id],
            entity_type=entity_type,
            period=period,
            rule_ids=rule_ids,
            parameters={"schedule_type": "on_demand"},
        )

        if auto_execute:
            await self._execute_task_async(task.task_id)

        return task

    async def schedule_batch_detection(
        self,
        entity_ids: List[str],
        entity_type: str,
        period: str,
        rule_ids: Optional[List[str]] = None,
        auto_execute: bool = False,
    ) -> DetectionTask:
        """Create a batch detection task covering several entities.

        Args:
            entity_ids: IDs of the entities to check; its length is recorded
                as ``batch_size`` in the task parameters.
            entity_type: Entity type, passed through to TaskManager.create_task.
            period: Detection period; also embedded in the task name.
            rule_ids: Optional list of rule IDs to restrict the detection to.
            auto_execute: If True, submit the task for background execution
                immediately. Defaults to False for batch tasks.

        Returns:
            The created task.
        """
        logger.info(
            f"调度批量检测: 实体数={len(entity_ids)}, 期间={period}"
        )

        task = await self.task_manager.create_task(
            task_name=f"批量检测-{period}",
            task_type=TaskType.BATCH,
            entity_ids=entity_ids,
            entity_type=entity_type,
            period=period,
            rule_ids=rule_ids,
            parameters={
                "schedule_type": "batch",
                "batch_size": len(entity_ids),
            },
        )

        # Batch tasks default to manual triggering (auto_execute=False), but
        # the caller may still opt in to immediate execution.
        if auto_execute:
            await self._execute_task_async(task.task_id)

        return task

    async def check_pending_tasks(self) -> List[DetectionTask]:
        """Return all tasks whose status is PENDING."""
        stmt = select(DetectionTask).where(
            DetectionTask.status == TaskStatus.PENDING
        )
        result = await self.db_session.execute(stmt)
        tasks = list(result.scalars().all())

        logger.info(f"发现 {len(tasks)} 个待执行任务")

        return tasks

    async def retry_failed_tasks(
        self,
        max_retries: int = 3,
    ) -> List[DetectionTask]:
        """Reset eligible FAILED tasks to PENDING and resubmit them.

        A task is eligible when its status is FAILED and its retry_count is
        NULL or below ``max_retries``. Each eligible task gets status PENDING
        and retry_count incremented (NULL counts as 0). The change is flushed
        but not committed. Execution of every task whose reset flushed
        successfully is submitted after all resets are done.

        Args:
            max_retries: Tasks whose retry_count has reached this value are
                not retried.

        Returns:
            The tasks that were reset and resubmitted.
        """
        from sqlalchemy import or_

        stmt = select(DetectionTask).where(
            DetectionTask.status == TaskStatus.FAILED,
            # A bare `retry_count < max_retries` is never true for NULL in
            # SQL, which would exclude tasks that were never retried; the
            # `or 0` below shows NULL is an expected value.
            or_(
                DetectionTask.retry_count.is_(None),
                DetectionTask.retry_count < max_retries,
            ),
        )
        result = await self.db_session.execute(stmt)
        failed_tasks = list(result.scalars().all())

        logger.info(f"发现 {len(failed_tasks)} 个失败任务需要重试")

        retried_tasks: List[DetectionTask] = []
        task_ids_to_submit: List[str] = []
        for task in failed_tasks:
            try:
                task.status = TaskStatus.PENDING
                task.retry_count = (task.retry_count or 0) + 1
                await self.db_session.flush()
            except Exception as e:
                # logger.exception attaches the traceback (loguru ignores a
                # stdlib-style exc_info= kwarg). Passing no format args also
                # keeps braces in the error text from being re-parsed by
                # str.format.
                logger.exception(
                    f"任务重试失败: {task.task_id}, 错误: {str(e)}"
                )
                continue
            retried_tasks.append(task)
            task_ids_to_submit.append(task.task_id)

        # Submit only after every reset has been flushed, so background
        # executions (which go through self.task_manager, presumably on the
        # same session) do not overlap with the flushes above.
        for task_id in task_ids_to_submit:
            await self._execute_task_async(task_id)

        return retried_tasks

    async def cleanup_old_tasks(
        self,
        days: int = 30,
    ) -> int:
        """Delete COMPLETED or FAILED tasks created more than ``days`` ago.

        Args:
            days: Age threshold in days, measured back from ``datetime.now()``.

        Returns:
            The number of tasks deleted. The session is committed only when
            at least one task was deleted.
        """
        # NOTE(review): datetime.now() is naive local time; confirm created_at
        # is stored in the same convention (not UTC / timezone-aware).
        cutoff_date = datetime.now() - timedelta(days=days)

        stmt = select(DetectionTask).where(
            DetectionTask.created_at < cutoff_date,
            DetectionTask.status.in_([TaskStatus.COMPLETED, TaskStatus.FAILED]),
        )
        result = await self.db_session.execute(stmt)
        old_tasks = list(result.scalars().all())

        count = len(old_tasks)
        if count > 0:
            logger.info(f"清理 {count} 个旧任务")
            for task in old_tasks:
                await self.db_session.delete(task)
            await self.db_session.commit()

        return count

    async def _execute_task_async(self, task_id: str) -> "asyncio.Task":
        """Submit ``task_manager.execute_task(task_id)`` to run in the background.

        Returns immediately without awaiting the execution. A strong reference
        to the asyncio task is held until it finishes. Failures and
        cancellations are logged instead of being left unretrieved.

        Args:
            task_id: ID of the task to execute.

        Returns:
            The asyncio task running the execution (callers may ignore it).
        """
        import asyncio

        bg_task = asyncio.create_task(self.task_manager.execute_task(task_id))
        self._background_tasks.add(bg_task)

        def _on_done(finished: "asyncio.Task") -> None:
            self._background_tasks.discard(finished)
            # exception() raises CancelledError on a cancelled task, so check
            # for cancellation first.
            if finished.cancelled():
                logger.warning(f"后台任务被取消: {task_id}")
                return
            exc = finished.exception()
            if exc is not None:
                logger.opt(exception=exc).error(f"后台任务执行失败: {task_id}")

        bg_task.add_done_callback(_on_done)

        logger.info(f"任务已提交后台执行: {task_id}")

        return bg_task