210 lines
6.1 KiB
Python
210 lines
6.1 KiB
Python
"""
|
||
风险检测算法基类
|
||
"""
|
||
from abc import ABC, abstractmethod
|
||
from typing import Dict, Any, Optional, List
|
||
from datetime import datetime
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from loguru import logger
|
||
from app.models.risk_detection import RiskLevel
|
||
|
||
|
||
class DetectionContext:
    """Bundle of inputs for one detection run (检测上下文).

    Carries the task/rule identifiers, a free-form parameter mapping, an
    optional async DB session, and the timestamp at which the context was
    created.
    """

    def __init__(
        self,
        task_id: str,
        rule_id: str,
        parameters: Optional[Dict[str, Any]] = None,
        db_session: Optional[AsyncSession] = None,
    ):
        self.task_id = task_id
        self.rule_id = rule_id
        # A falsy ``parameters`` (None or an empty mapping) is replaced by a
        # fresh dict; a non-empty mapping is kept by reference, so
        # set_parameter() writes through to the caller's object.
        self.parameters = parameters if parameters else {}
        self.db_session = db_session
        # Naive local creation time; RiskDetectionAlgorithm.detect computes
        # its logged elapsed time relative to this value.
        self.start_time = datetime.now()

    def get_parameter(self, key: str, default: Any = None) -> Any:
        """Return the parameter stored under ``key``, or ``default`` if absent."""
        params = self.parameters
        return params.get(key, default)

    def set_parameter(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``, overwriting any existing entry."""
        params = self.parameters
        params[key] = value
|
||
|
||
|
||
class RiskEvidence:
    """A single piece of evidence attached to a detection result (风险证据).

    Attributes:
        type: Evidence category, taken from the ``evidence_type`` argument.
        description: Textual explanation of the evidence.
        data: Arbitrary payload backing the evidence.
        metadata: Extra string key/value annotations; ``{}`` when not given.

    The instance attributes are named exactly like the keys produced by
    :meth:`to_dict`, so ``vars(evidence)`` / ``evidence.__dict__`` already
    yield the same mapping for generic serializers. ``__dict__`` is
    deliberately NOT overridden with a computed property: ``copy.copy``,
    ``copy.deepcopy`` and ``pickle`` restore state by writing into
    ``obj.__dict__`` of a not-yet-initialised instance, which a property that
    calls ``to_dict()`` turns into an ``AttributeError`` (and whose writes
    would be lost in a throwaway dict anyway).
    """

    def __init__(
        self,
        evidence_type: str,
        description: str,
        data: Any,
        metadata: Optional[Dict[str, str]] = None,
    ):
        self.type = evidence_type
        self.description = description
        self.data = data
        # Falsy metadata (None / empty) becomes a fresh dict per instance.
        self.metadata = metadata or {}

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the evidence (values are not copied).

        Returns:
            ``{"type", "description", "data", "metadata"}`` mapped to the
            corresponding attributes.
        """
        return {
            "type": self.type,
            "description": self.description,
            "data": self.data,
            "metadata": self.metadata,
        }

    def __repr__(self):
        """Return a short debug representation showing type and description."""
        return f"RiskEvidence(type={self.type}, description={self.description})"
|
||
|
||
|
||
class DetectionResult:
    """Outcome of applying one rule to one entity within a task (检测结果)."""

    def __init__(
        self,
        task_id: str,
        rule_id: str,
        entity_id: str,
        entity_type: str,
        risk_level: RiskLevel,
        risk_score: float,
        description: str,
        suggestion: str,
        risk_data: Optional[Dict[str, Any]] = None,
        evidence: Optional[List[RiskEvidence]] = None,
    ):
        """Store the result fields and stamp ``detected_at`` with local time.

        Falsy ``risk_data`` / ``evidence`` arguments are replaced by fresh
        empty containers, so results never share a default container.
        """
        # Which task/rule produced this result and what it examined.
        self.task_id = task_id
        self.rule_id = rule_id
        self.entity_id = entity_id
        self.entity_type = entity_type
        # The assessment itself.
        self.risk_level = risk_level
        self.risk_score = risk_score
        self.description = description
        self.suggestion = suggestion
        # Supporting detail.
        self.risk_data = risk_data if risk_data else {}
        self.evidence = evidence if evidence else []
        # Naive local timestamp taken when the result object is built.
        self.detected_at = datetime.now()

    def add_evidence(self, evidence: RiskEvidence):
        """Append one piece of evidence to this result."""
        self.evidence.append(evidence)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict form of the result.

        ``risk_level`` is emitted as its ``.value`` when it is a ``RiskLevel``
        member and passed through unchanged otherwise; each evidence item is
        serialised via its own ``to_dict()``; ``detected_at`` is ISO-8601.
        """
        level = self.risk_level
        if isinstance(level, RiskLevel):
            level = level.value
        return {
            "task_id": self.task_id,
            "rule_id": self.rule_id,
            "entity_id": self.entity_id,
            "entity_type": self.entity_type,
            "risk_level": level,
            "risk_score": self.risk_score,
            "description": self.description,
            "suggestion": self.suggestion,
            "risk_data": self.risk_data,
            "evidence": [item.to_dict() for item in self.evidence],
            "detected_at": self.detected_at.isoformat(),
        }
|
||
|
||
|
||
class RiskDetectionAlgorithm(ABC):
    """Base class for risk detection algorithms (风险检测算法基类).

    Subclasses must implement :meth:`get_algorithm_code`,
    :meth:`get_algorithm_name` and :meth:`_do_detect`. The hooks
    :meth:`prepare_data`, :meth:`validate_data` and :meth:`post_process` are
    optional no-ops. :meth:`detect` runs them in a fixed order and turns any
    ``Exception`` into an error ``DetectionResult`` instead of raising.
    """

    def __init__(self):
        # Algorithm configuration; replaced wholesale by init().
        self.config: Dict[str, Any] = {}

    @abstractmethod
    def get_algorithm_code(self) -> str:
        """Return the algorithm code (used to tag detect()'s log messages)."""
        pass

    @abstractmethod
    def get_algorithm_name(self) -> str:
        """Return the algorithm name (used by the default get_description)."""
        pass

    def get_description(self) -> str:
        """Return a description of the algorithm; defaults to "算法: <name>"."""
        return f"算法: {self.get_algorithm_name()}"

    def init(self, config: Optional[Dict[str, Any]]):
        """Initialise the algorithm with ``config``.

        A falsy ``config`` (None or empty) resets ``self.config`` to ``{}``.
        """
        self.config = config or {}

    async def detect(self, context: DetectionContext) -> DetectionResult:
        """Run the full detection pipeline for ``context``.

        Steps:
            1. prepare_data  – data preparation hook
            2. validate_data – data validation hook
            3. _do_detect    – algorithm-specific detection
            4. post_process  – result post-processing hook

        Returns:
            The result of ``_do_detect`` on success. If any step (or the
            success log) raises an ``Exception``, the traceback is logged and
            an error result with ``RiskLevel.UNKNOWN`` and score ``0.0`` is
            returned instead; ``entity_id`` / ``entity_type`` are then taken
            from the context parameters when present.
        """
        try:
            await self.prepare_data(context)
            await self.validate_data(context)
            result = await self._do_detect(context)
            await self.post_process(result, context)

            # Elapsed time is measured from context creation, not detect() entry.
            execution_time = (datetime.now() - context.start_time).total_seconds()
            # risk_level may be a RiskLevel member or an already-raw value
            # (DetectionResult.to_dict accepts both). An unguarded ``.value``
            # here would raise and throw away an otherwise successful result.
            risk_level = result.risk_level
            level_text = (
                risk_level.value if isinstance(risk_level, RiskLevel) else risk_level
            )
            logger.info(
                f"算法 {self.get_algorithm_code()} 执行完成,"
                f"实体ID: {result.entity_id},风险等级: {level_text},"
                f"耗时: {execution_time:.2f}秒"
            )
            return result

        except Exception as e:
            # loguru does not understand ``exc_info``: extra kwargs are fed to
            # str.format(), so the traceback was never recorded and a message
            # containing braces made this handler raise. logger.exception()
            # attaches the active traceback and, with no extra args, leaves the
            # pre-formatted message untouched.
            logger.exception(f"算法 {self.get_algorithm_code()} 执行失败: {str(e)}")
            # Return an error result rather than propagating the exception.
            return DetectionResult(
                task_id=context.task_id,
                rule_id=context.rule_id,
                entity_id=context.get_parameter("entity_id", ""),
                entity_type=context.get_parameter("entity_type", "unknown"),
                risk_level=RiskLevel.UNKNOWN,
                risk_score=0.0,
                description=f"检测执行失败: {str(e)}",
                suggestion="请检查系统日志或联系管理员",
            )

    @abstractmethod
    async def _do_detect(self, context: DetectionContext) -> DetectionResult:
        """Run the algorithm-specific detection logic (subclasses must implement)."""
        pass

    async def prepare_data(self, context: DetectionContext):
        """Data-preparation hook run first by detect(); no-op by default."""
        pass

    async def validate_data(self, context: DetectionContext):
        """Data-validation hook run after prepare_data(); no-op by default."""
        pass

    async def post_process(self, result: DetectionResult, context: DetectionContext):
        """Result post-processing hook run after _do_detect(); no-op by default."""
        pass
|
||
|