deep-risk/backend/app/services/risk_detection/engine/rule_engine.py
2025-12-14 20:08:27 +08:00

760 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
规则引擎
风险检测系统的核心引擎,负责协调各个组件
功能:
1. 管理和注册检测算法
2. 加载和验证检测规则
3. 生成执行计划
4. 调度和执行检测任务
5. 聚合和处理检测结果
6. 提供完整的检测服务接口
"""
from typing import Dict, Any, List, Optional, Type, Set, Tuple
import asyncio
from datetime import datetime
from loguru import logger
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.risk_detection import DetectionRule, RiskLevel
from .dependency_resolver import DependencyResolver
from .execution_plan import ExecutionPlan, ExecutionPlanner, ExecutionMode
from .result_processor import ResultProcessor, DetectionResult as DetectionResultType
from ..algorithms.base import DetectionContext, RiskDetectionAlgorithm
from ..algorithms.revenue_integrity import RevenueIntegrityAlgorithm
# TODO: 导入其他算法类
# from ..algorithms.private_account import PrivateAccountDetectionAlgorithm
# from ..algorithms.invoice_fraud import InvoiceFraudDetectionAlgorithm
class AlgorithmRegistry:
    """
    Algorithm registry.

    Maps each algorithm code (as reported by the class's
    ``get_algorithm_code()``) to its algorithm class, and hands out one cached
    instance per code.
    """

    def __init__(self):
        # algorithm code -> algorithm class
        self._algorithms: Dict[str, Type[RiskDetectionAlgorithm]] = {}
        # algorithm code -> cached instance returned by get_algorithm()
        self._instances: Dict[str, RiskDetectionAlgorithm] = {}
        self._initialize_default_algorithms()

    def _initialize_default_algorithms(self):
        """Register the built-in detection algorithms."""
        # Revenue-integrity detection algorithm
        self.register(RevenueIntegrityAlgorithm)
        # TODO: register the other algorithms once they are implemented
        # self.register(PrivateAccountDetectionAlgorithm)
        # self.register(InvoiceFraudDetectionAlgorithm)
        logger.info(f"算法注册表初始化完成,共注册 {len(self._algorithms)} 个算法")

    def register(self, algorithm_class: Type[RiskDetectionAlgorithm]):
        """
        Register an algorithm class under its algorithm code.

        The class is instantiated once to read its code. That instance is
        kept as the cached instance, so the class is not constructed a second
        time on first lookup. Registering a class for a code that is already
        registered replaces both the class and the cached instance.

        Args:
            algorithm_class: Algorithm class; must subclass
                RiskDetectionAlgorithm and provide ``get_algorithm_code()``.

        Raises:
            ValueError: If the class has no ``get_algorithm_code`` attribute.
        """
        if not hasattr(algorithm_class, 'get_algorithm_code'):
            raise ValueError(
                f"算法类 {algorithm_class.__name__} 必须实现 get_algorithm_code() 方法"
            )
        instance = algorithm_class()
        code = instance.get_algorithm_code()

        previous_class = self._algorithms.get(code)
        if previous_class is not None and previous_class is not algorithm_class:
            logger.warning(
                f"算法代码 {code} 已注册为 {previous_class.__name__},"
                f"将被 {algorithm_class.__name__} 覆盖"
            )

        self._algorithms[code] = algorithm_class
        # Always overwrite the cached instance. Otherwise an instance of a
        # previously registered class for the same code would keep being
        # returned by get_algorithm().
        self._instances[code] = instance
        logger.debug(f"注册算法:{code} ({algorithm_class.__name__})")

    def get_algorithm(self, algorithm_code: str) -> RiskDetectionAlgorithm:
        """
        Return the cached instance for an algorithm code.

        Args:
            algorithm_code: Algorithm code.

        Returns:
            The algorithm instance. It is created on demand if no instance
            is cached yet.

        Raises:
            ValueError: If the algorithm code is not registered.
        """
        algorithm_class = self._algorithms.get(algorithm_code)
        if algorithm_class is None:
            raise ValueError(f"未注册的算法:{algorithm_code}")
        instance = self._instances.get(algorithm_code)
        if instance is None:
            instance = algorithm_class()
            self._instances[algorithm_code] = instance
        return instance

    def get_all_algorithms(self) -> List[str]:
        """Return the codes of all registered algorithms, in registration order."""
        return list(self._algorithms.keys())

    def is_registered(self, algorithm_code: str) -> bool:
        """Return True if ``algorithm_code`` has a registered algorithm class."""
        return algorithm_code in self._algorithms
class RuleEngine:
    """
    Rule engine: the core coordinator of the risk-detection system.

    For one detection task it:
    1. validates the rules against the algorithm registry (disabled rules are
       skipped);
    2. asks the ExecutionPlanner for a staged execution plan;
    3. runs every stage, either sequentially or with bounded concurrency;
    4. hands the per-rule results to the ResultProcessor for aggregation;
    5. persists each successful per-rule result when a database session was
       supplied.
    """

    def __init__(self, db_session: Optional[AsyncSession] = None):
        """
        Args:
            db_session: Optional async database session. It is passed to every
                algorithm through DetectionContext. When set, each successful
                per-rule result is also saved through it.
        """
        self.db_session = db_session
        self.algorithm_registry = AlgorithmRegistry()
        self.dependency_resolver = DependencyResolver()
        self.execution_planner = ExecutionPlanner()
        self.result_processor = ResultProcessor()
        # Serializes this engine's writes to ``db_session``. An AsyncSession
        # must not be used by several concurrent tasks, yet parallel stages
        # save results from concurrently running coroutines. The lock is
        # created lazily inside a coroutine, so that older Python versions
        # (which bind asyncio primitives to a loop at construction) bind it
        # to the running loop.
        self._db_lock: Optional[asyncio.Lock] = None
        logger.info("规则引擎初始化完成")

    async def execute_detection(
        self,
        task_id: str,
        rules: List[DetectionRule],
        entity_id: str,
        entity_type: str,
        period: str,
        parameters: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Run risk detection (simplified entry point used by TaskManager).

        Delegates to :meth:`execute_detection_task` with the default execution
        mode, rule weights and concurrency limit.

        Args:
            task_id: Task ID.
            rules: Detection rules to run.
            entity_id: ID of the entity being checked.
            entity_type: Type of the entity being checked.
            period: Detection period.
            parameters: Extra detection parameters (``None`` means none).

        Returns:
            The detection report produced by execute_detection_task.
        """
        logger.info(
            f"开始执行风险检测task_id={task_id}, entity={entity_id}, "
            f"entity_type={entity_type}, period={period}, 规则数={len(rules)}"
        )
        return await self.execute_detection_task(
            task_id=task_id,
            entity_id=entity_id,
            entity_type=entity_type,
            period=period,
            rules=rules,
            parameters=parameters or {}
        )

    async def execute_detection_task(
        self,
        task_id: str,
        entity_id: str,
        entity_type: str,
        period: str,
        rules: List[DetectionRule],
        execution_mode: ExecutionMode = ExecutionMode.HYBRID,
        parameters: Optional[Dict[str, Any]] = None,
        rule_weights: Optional[Dict[str, float]] = None,
        max_concurrent: int = 10
    ) -> Dict[str, Any]:
        """
        Run a full detection task and build its report.

        Args:
            task_id: Task ID.
            entity_id: ID of the entity being checked.
            entity_type: Type of the entity being checked.
            period: Detection period.
            rules: Detection rules. Disabled rules are skipped.
            execution_mode: Execution mode passed to the planner.
            parameters: Task-level parameters. They are applied to every rule
                and override a rule's own parameters with the same key.
            rule_weights: Per-rule weights passed to the ResultProcessor.
            max_concurrent: Maximum number of rules run concurrently inside a
                parallel stage (values below 1 are treated as 1).

        Returns:
            The ResultProcessor report, extended with ``execution_plan`` and
            ``execution_stats``.

        Raises:
            ValueError: If the rules are invalid or stage dependencies are
                not met. Any other error is logged and re-raised.
        """
        logger.info(
            f"开始执行检测任务task_id={task_id}, entity={entity_id}, "
            f"规则数={len(rules)}, 模式={execution_mode.value}"
        )
        start_time = datetime.now()
        parameters = parameters or {}
        try:
            # 1. Validate rules/algorithms; only enabled rules are planned.
            rules = self._validate_rules(rules)

            # 2. Build the execution plan.
            execution_plan = self.execution_planner.create_plan(
                task_id=task_id,
                entity_id=entity_id,
                entity_type=entity_type,
                period=period,
                rules=rules,
                execution_mode=execution_mode
            )
            # Push the task-level parameters into each stage context the
            # planner created. A getattr/isinstance check is used instead of
            # inspecting __dict__, so contexts exposed via properties or slots
            # are handled as well.
            for stage in execution_plan.stages:
                stage_context = getattr(stage, 'context', None)
                if isinstance(stage_context, dict):
                    stage_context['parameters'] = parameters
            logger.info(
                f"执行计划生成完成:{len(execution_plan.stages)} 个阶段,"
                f"最大层级={execution_plan.max_level}"
            )

            # Task-level context. It fills in whatever a stage's own context
            # lacks, so rules never run with empty task/entity identifiers.
            base_context = {
                "task_id": task_id,
                "entity_id": entity_id,
                "entity_type": entity_type,
                "period": period,
                "parameters": parameters,
            }

            # 3. Execute all stages.
            results = await self._execute_detection(
                execution_plan,
                max_concurrent,
                base_context=base_context,
            )

            # 4. Aggregate results.
            final_report = self.result_processor.process_results(
                results=results,
                task_id=task_id,
                entity_id=entity_id,
                entity_type=entity_type,
                period=period,
                rule_weights=rule_weights
            )

            # 5. Attach the execution plan.
            final_report["execution_plan"] = execution_plan.to_dict()

            # 6. Attach execution statistics.
            end_time = datetime.now()
            execution_duration = (end_time - start_time).total_seconds()
            final_report["execution_stats"] = {
                "task_id": task_id,
                "start_time": start_time.isoformat(),
                "end_time": end_time.isoformat(),
                "duration_seconds": execution_duration,
                "mode": execution_mode.value,
                "total_stages": len(execution_plan.stages),
                "max_concurrent": max_concurrent,
            }

            # Read the summary defensively: a malformed summary must not turn
            # an already completed task into a failure through this log line.
            summary = final_report.get("summary") or {}
            overall_score = summary.get("overall_risk_score")
            score_text = (
                f"{overall_score:.2f}"
                if isinstance(overall_score, (int, float))
                else str(overall_score)
            )
            logger.info(
                f"检测任务执行完成task_id={task_id}, "
                f"耗时={execution_duration:.2f}秒, "
                f"总体风险={summary.get('overall_risk_level')}, "
                f"综合评分={score_text}"
            )
            return final_report
        except Exception as e:
            # loguru has no ``exc_info`` parameter (extra kwargs are treated as
            # str.format arguments); logger.exception attaches the traceback.
            logger.exception(f"检测任务执行失败task_id={task_id}, 错误={str(e)}")
            raise

    async def _execute_detection(
        self,
        execution_plan: ExecutionPlan,
        max_concurrent: int,
        base_context: Optional[Dict[str, Any]] = None,
    ) -> List[DetectionResultType]:
        """
        Run every stage of the plan in order.

        Args:
            execution_plan: Execution plan.
            max_concurrent: Maximum concurrency for parallel stages.
            base_context: Task-level context used as a fallback for stages
                without their own context values.

        Returns:
            Per-rule results of all stages, in stage order.

        Raises:
            ValueError: If a stage's dependencies have not completed yet.
        """
        results: List[DetectionResultType] = []
        completed_stages: List[str] = []
        for stage in execution_plan.stages:
            logger.info(
                f"开始执行阶段:{stage.stage_id} ({stage.stage_name}), "
                f"规则数={stage.rule_count}, 模式={stage.execution_mode.value}"
            )
            if not self._are_stage_dependencies_met(stage, completed_stages):
                raise ValueError(
                    f"阶段依赖不满足:{stage.stage_id}, "
                    f"依赖={stage.depends_on}, 已完成={completed_stages}"
                )
            stage_results = await self._execute_stage(
                stage,
                max_concurrent,
                base_context,
            )
            results.extend(stage_results)
            completed_stages.append(stage.stage_id)
            logger.info(
                f"阶段执行完成:{stage.stage_id}, "
                f"成功={len(stage_results)}, "
                f"累计={len(results)}"
            )
        return results

    async def _execute_stage(
        self,
        stage,
        max_concurrent: int,
        base_context: Optional[Dict[str, Any]] = None,
    ) -> List[DetectionResultType]:
        """
        Run one stage.

        SEQUENTIAL stages run rule by rule; every other mode runs in parallel.

        Args:
            stage: Execution stage.
            max_concurrent: Maximum concurrency for parallel execution.
            base_context: Task-level fallback context.

        Returns:
            Per-rule results of the stage.
        """
        if stage.execution_mode == ExecutionMode.SEQUENTIAL:
            return await self._execute_sequential(stage, base_context)
        return await self._execute_parallel(stage, max_concurrent, base_context)

    @staticmethod
    def _resolve_stage_context(
        stage,
        base_context: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Merge the task-level context with the stage's own ``context`` dict.

        Non-None values from the stage context take precedence, except
        ``parameters``: the task-level parameters always win, matching how
        execute_detection_task overwrites each stage's parameters.
        """
        resolved: Dict[str, Any] = dict(base_context or {})
        stage_context = getattr(stage, 'context', None)
        if isinstance(stage_context, dict):
            resolved.update(
                {key: value for key, value in stage_context.items() if value is not None}
            )
        if base_context and 'parameters' in base_context:
            resolved['parameters'] = base_context['parameters']
        return resolved

    @staticmethod
    def _build_error_result(
        task_id: str,
        rule_id: str,
        entity_id: str,
        entity_type: str,
        description: str,
    ) -> DetectionResultType:
        """Build the UNKNOWN-risk placeholder result for a rule that failed."""
        return DetectionResultType(
            task_id=task_id,
            rule_id=rule_id,
            entity_id=entity_id,
            entity_type=entity_type,
            risk_level=RiskLevel.UNKNOWN,
            risk_score=0.0,
            description=description,
            suggestion="请检查系统日志或联系管理员",
        )

    async def _execute_sequential(
        self,
        stage,
        base_context: Optional[Dict[str, Any]] = None,
    ) -> List[DetectionResultType]:
        """
        Run the stage's rules one after another.

        A failing rule is logged and replaced by an error result; it does not
        stop the remaining rules.
        """
        context = self._resolve_stage_context(stage, base_context)
        task_id = context.get('task_id', '')
        entity_id = context.get('entity_id', '')
        entity_type = context.get('entity_type', '')
        period = context.get('period', '')
        parameters = context.get('parameters', {})

        results: List[DetectionResultType] = []
        for node in stage.nodes:
            try:
                result = await self._execute_single_rule(
                    node.rule,
                    node.rule_id,
                    task_id=task_id,
                    entity_id=entity_id,
                    entity_type=entity_type,
                    period=period,
                    parameters=parameters
                )
            except Exception as e:
                logger.exception(f"规则执行失败:{node.rule_id}, 错误={str(e)}")
                result = self._build_error_result(
                    task_id, node.rule_id, entity_id, entity_type,
                    f"规则执行失败:{str(e)}",
                )
            results.append(result)
        return results

    async def _execute_parallel(
        self,
        stage,
        max_concurrent: int,
        base_context: Optional[Dict[str, Any]] = None,
    ) -> List[DetectionResultType]:
        """
        Run the stage's rules concurrently, at most ``max_concurrent`` at once.

        Results are returned in the order of ``stage.nodes``. A failing rule,
        or a cancelled rule, yields an error result instead of aborting the
        stage.
        """
        context = self._resolve_stage_context(stage, base_context)
        task_id = context.get('task_id', '')
        entity_id = context.get('entity_id', '')
        entity_type = context.get('entity_type', '')
        period = context.get('period', '')
        parameters = context.get('parameters', {})

        # Semaphore(0) would block every rule forever and a negative value
        # raises, so allow at least one rule at a time.
        semaphore = asyncio.Semaphore(max(1, max_concurrent))

        async def execute_with_semaphore(node):
            async with semaphore:
                try:
                    return await self._execute_single_rule(
                        node.rule,
                        node.rule_id,
                        task_id=task_id,
                        entity_id=entity_id,
                        entity_type=entity_type,
                        period=period,
                        parameters=parameters
                    )
                except Exception as e:
                    logger.exception(f"规则执行失败:{node.rule_id}, 错误={str(e)}")
                    return self._build_error_result(
                        task_id, node.rule_id, entity_id, entity_type,
                        f"规则执行失败:{str(e)}",
                    )

        nodes = stage.nodes
        outcomes = await asyncio.gather(
            *(execute_with_semaphore(node) for node in nodes),
            return_exceptions=True,
        )

        final_results: List[DetectionResultType] = []
        for node, outcome in zip(nodes, outcomes):
            # Ordinary exceptions are already converted above. What can still
            # arrive here are BaseExceptions such as a child's CancelledError,
            # which is not an Exception subclass.
            if isinstance(outcome, BaseException):
                logger.error(f"规则执行异常:{node.rule_id}, 错误={str(outcome)}")
                final_results.append(
                    self._build_error_result(
                        task_id, node.rule_id, entity_id, entity_type,
                        f"规则执行异常:{str(outcome)}",
                    )
                )
            else:
                final_results.append(outcome)
        return final_results

    async def _execute_single_rule(
        self,
        rule: DetectionRule,
        rule_id: str,
        task_id: str = "",
        entity_id: str = "",
        entity_type: str = "",
        period: str = "",
        parameters: Optional[Dict[str, Any]] = None,
    ) -> DetectionResultType:
        """
        Run one rule's algorithm and return its result.

        Parameter precedence, from lowest to highest: the rule's own
        parameters, then entity_id/entity_type/period, then ``parameters``.

        Args:
            rule: Detection rule; its ``algorithm_code`` selects the algorithm.
            rule_id: Rule ID.
            task_id: Task ID.
            entity_id: Entity ID.
            entity_type: Entity type.
            period: Detection period.
            parameters: Task-level parameters.

        Returns:
            The algorithm's result, stamped with the task and entity fields.

        Raises:
            ValueError: If the rule's algorithm is not registered. Errors from
                the algorithm itself propagate to the caller.
        """
        algorithm = self.algorithm_registry.get_algorithm(rule.algorithm_code)

        rule_parameters = rule.parameters or {}
        if isinstance(rule_parameters, dict):
            # Copy, so the per-run values set below cannot end up in the
            # rule's own parameter dict if DetectionContext keeps a reference.
            rule_parameters = dict(rule_parameters)
        context = DetectionContext(
            task_id=task_id,
            rule_id=rule_id,
            parameters=rule_parameters,
            db_session=self.db_session
        )
        context.set_parameter("entity_id", entity_id)
        context.set_parameter("entity_type", entity_type)
        context.set_parameter("period", period)
        if parameters:
            for key, value in parameters.items():
                context.set_parameter(key, value)

        result = await algorithm.detect(context)

        result.task_id = task_id
        result.entity_id = entity_id
        result.entity_type = entity_type

        if self.db_session:
            await self._save_detection_result(result, rule_id)
        return result

    def _are_stage_dependencies_met(self, stage, completed_stages: List[str]) -> bool:
        """Return True if every stage ID in ``stage.depends_on`` has completed."""
        return all(dep in completed_stages for dep in stage.depends_on)

    def _validate_rules(self, rules: List[DetectionRule]) -> List[DetectionRule]:
        """
        Validate the rules and return those that should run.

        Every rule's algorithm must be registered (disabled rules included).
        Rules whose ``is_enabled`` is explicitly False are skipped. A value of
        None is treated as enabled; presumably this is an unflushed ORM object
        whose column default has not been applied yet.
        NOTE(review): if an enabled rule depends on a skipped one, plan
        creation may fail; confirm how ExecutionPlanner handles a missing
        dependency.

        Args:
            rules: Detection rules.

        Returns:
            The enabled rules, in their original order.

        Raises:
            ValueError: If the list is empty, an algorithm is not registered,
                or no enabled rule remains.
        """
        if not rules:
            raise ValueError("规则列表不能为空")

        enabled_rules: List[DetectionRule] = []
        for rule in rules:
            if not self.algorithm_registry.is_registered(rule.algorithm_code):
                raise ValueError(
                    f"规则 {rule.rule_id} 的算法 {rule.algorithm_code} 未注册"
                )
            if rule.is_enabled is False:
                logger.warning(f"规则 {rule.rule_id} 已禁用,将跳过")
                continue
            enabled_rules.append(rule)

        if not enabled_rules:
            raise ValueError("没有已启用的规则可执行")
        logger.info(f"规则验证通过,共 {len(enabled_rules)} 个规则")
        return enabled_rules

    def get_algorithm_info(self) -> Dict[str, Dict[str, Any]]:
        """
        Describe all registered algorithms.

        Returns:
            Mapping of algorithm code to ``{"name", "description", "code"}``.
            Algorithms whose info cannot be read are logged and left out.
        """
        info = {}
        for code in self.algorithm_registry.get_all_algorithms():
            try:
                algorithm = self.algorithm_registry.get_algorithm(code)
                info[code] = {
                    "name": algorithm.get_algorithm_name(),
                    "description": algorithm.get_description(),
                    "code": code,
                }
            except Exception as e:
                logger.error(f"获取算法信息失败:{code}, 错误={str(e)}")
        return info

    def validate_execution_plan(
        self,
        execution_plan: ExecutionPlan
    ) -> Tuple[bool, Optional[str]]:
        """
        Check that an execution plan is runnable.

        The plan must have at least one stage. Every stage's dependencies must
        appear earlier in the stage list. The stage rule counts must add up to
        ``total_rules``.

        Args:
            execution_plan: Execution plan.

        Returns:
            ``(True, None)`` if valid, otherwise ``(False, error_message)``.
        """
        try:
            if len(execution_plan.stages) == 0:
                return False, "执行计划没有阶段"

            completed = []
            for stage in execution_plan.stages:
                if not self._are_stage_dependencies_met(stage, completed):
                    return False, f"阶段 {stage.stage_id} 的依赖不满足"
                completed.append(stage.stage_id)

            total_rules_in_stages = sum(stage.rule_count for stage in execution_plan.stages)
            if total_rules_in_stages != execution_plan.total_rules:
                return False, "阶段规则数量与总规则数量不一致"

            return True, None
        except Exception as e:
            return False, f"执行计划验证失败:{str(e)}"

    async def dry_run(
        self,
        entity_id: str,
        entity_type: str,
        period: str,
        rules: List[DetectionRule],
        execution_mode: ExecutionMode = ExecutionMode.HYBRID
    ) -> Dict[str, Any]:
        """
        Build and validate an execution plan without running any rule.

        Args:
            entity_id: Entity ID.
            entity_type: Entity type.
            period: Detection period.
            rules: Detection rules; disabled rules are skipped.
            execution_mode: Execution mode.

        Returns:
            The serialized plan, a summary with an estimated duration, and the
            registered algorithm info.

        Raises:
            ValueError: If the rules or the generated plan are invalid.
        """
        logger.info(
            f"开始试运行entity={entity_id}, 规则数={len(rules)}, 模式={execution_mode.value}"
        )
        rules = self._validate_rules(rules)

        execution_plan = self.execution_planner.create_plan(
            task_id="dry_run",
            entity_id=entity_id,
            entity_type=entity_type,
            period=period,
            rules=rules,
            execution_mode=execution_mode
        )

        is_valid, error = self.validate_execution_plan(execution_plan)
        if not is_valid:
            raise ValueError(f"执行计划无效:{error}")

        return {
            "dry_run": True,
            "execution_plan": execution_plan.to_dict(),
            "summary": {
                "total_rules": execution_plan.total_rules,
                "total_stages": len(execution_plan.stages),
                "max_level": execution_plan.max_level,
                "execution_mode": execution_plan.execution_mode.value,
                "estimated_duration": self._estimate_execution_time(execution_plan),
            },
            "algorithm_info": self.get_algorithm_info(),
        }

    def _estimate_execution_time(self, execution_plan: ExecutionPlan) -> float:
        """
        Roughly estimate the plan's run time in seconds.

        The estimate assumes 2 s per rule. PARALLEL plans get a 0.7 factor.
        HYBRID plans with ``max_level > 0`` get ``0.8 - 0.1 * parallel_stages``,
        floored at 0.1 so that many parallel stages cannot produce a zero or
        negative estimate.

        Args:
            execution_plan: Execution plan.

        Returns:
            Estimated duration in seconds.
        """
        base_time_per_rule = 2.0
        parallel_factor = 0.7
        min_hybrid_factor = 0.1

        estimated_time = execution_plan.total_rules * base_time_per_rule
        if execution_plan.execution_mode == ExecutionMode.PARALLEL:
            estimated_time *= parallel_factor
        elif execution_plan.execution_mode == ExecutionMode.HYBRID:
            if execution_plan.max_level > 0:
                parallel_stages = sum(
                    1 for stage in execution_plan.stages
                    if stage.execution_mode == ExecutionMode.PARALLEL
                )
                estimated_time *= max(min_hybrid_factor, 0.8 - 0.1 * parallel_stages)
        return estimated_time

    @staticmethod
    def _evidence_to_dict(evidence: Any) -> Any:
        """Convert one evidence item into something JSON-storable."""
        if hasattr(evidence, 'to_dict'):
            return evidence.to_dict()
        if isinstance(evidence, dict):
            # Plain dicts have no __dict__; store them as they are.
            return evidence
        return evidence.__dict__

    async def _save_detection_result(self, result: DetectionResultType, rule_id: str):
        """
        Persist one detection result through ``self.db_session``.

        Failures are logged and swallowed, so that persistence problems do not
        abort detection.
        NOTE(review): after a failed flush the session usually needs a
        rollback before it can flush again; this method does not roll back,
        since that would also discard the caller's pending work.
        """
        from app.models.risk_detection import DetectionResult

        try:
            risk_level_enum = None
            if result.risk_level:
                try:
                    risk_level_enum = RiskLevel(result.risk_level)
                except ValueError:
                    risk_level_enum = RiskLevel.UNKNOWN

            evidence_list = [
                self._evidence_to_dict(evidence) for evidence in (result.evidence or [])
            ]

            # Temporary rules (instant detection) are not stored in the rules
            # table; use None to avoid a foreign-key violation.
            db_rule_id = None if str(rule_id).startswith('temp_') else rule_id
            db_result = DetectionResult(
                task_id=result.task_id,
                rule_id=db_rule_id,
                entity_id=result.entity_id,
                entity_type=result.entity_type,
                risk_level=risk_level_enum,
                risk_score=result.risk_score,
                description=result.description,
                suggestion=result.suggestion,
                risk_data=result.risk_data or {},
                evidence=evidence_list,
                detected_at=datetime.now()
            )

            # Check-and-create has no await in between, so it is atomic
            # within the event loop.
            if getattr(self, '_db_lock', None) is None:
                self._db_lock = asyncio.Lock()
            # NOTE(review): algorithms also receive this session through
            # DetectionContext; if they query it during parallel stages, that
            # use is not covered by this lock.
            async with self._db_lock:
                self.db_session.add(db_result)
                await self.db_session.flush()

            logger.debug(
                f"保存检测结果成功: task_id={result.task_id}, "
                f"rule_id={rule_id}, entity_id={result.entity_id}"
            )
        except Exception as e:
            logger.exception(
                f"保存检测结果失败: task_id={result.task_id}, "
                f"rule_id={rule_id}, 错误={str(e)}"
            )
            # Deliberately not re-raised, so the main flow is not affected.
# Process-wide rule engine, created on first use by get_rule_engine().
_rule_engine_instance: Optional[RuleEngine] = None


def get_rule_engine() -> RuleEngine:
    """
    Return the process-wide RuleEngine, creating it on the first call.

    The shared engine is built without a database session, so it does not
    persist detection results.

    Returns:
        The cached RuleEngine instance.
    """
    global _rule_engine_instance
    if _rule_engine_instance is not None:
        return _rule_engine_instance
    _rule_engine_instance = RuleEngine()
    logger.info("初始化全局规则引擎实例")
    return _rule_engine_instance