deep-risk/backend/app/services/risk_detection/algorithms/base.py
2025-12-14 20:08:27 +08:00

210 lines
6.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
风险检测算法基类
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from loguru import logger
from app.models.risk_detection import RiskLevel
class DetectionContext:
    """State shared by a single detection invocation.

    Bundles the task and rule identifiers, a mutable parameter mapping,
    an optional async database session, and the creation timestamp.
    """

    def __init__(
        self,
        task_id: str,
        rule_id: str,
        parameters: Optional[Dict[str, Any]] = None,
        db_session: Optional[AsyncSession] = None,
    ):
        """
        Args:
            task_id: Identifier of the task this run belongs to.
            rule_id: Identifier of the rule being evaluated.
            parameters: Initial parameters; a falsy value (None or an empty
                mapping) is replaced by a new empty dict.
            db_session: Optional async SQLAlchemy session, stored as-is.
        """
        self.task_id = task_id
        self.rule_id = rule_id
        self.parameters = parameters if parameters else {}
        self.db_session = db_session
        # Naive local timestamp taken when the context is constructed.
        self.start_time = datetime.now()

    def get_parameter(self, key: str, default: Any = None) -> Any:
        """Look up ``key`` in the parameters, falling back to ``default``."""
        params = self.parameters
        return params.get(key, default)

    def set_parameter(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` in the parameters (overwrites)."""
        params = self.parameters
        params[key] = value
class RiskEvidence:
    """A single piece of evidence supporting a detected risk.

    Instances are plain attribute holders; ``to_dict`` returns the
    JSON-friendly representation used when serializing results.
    """

    def __init__(
        self,
        evidence_type: str,
        description: str,
        data: Any,
        metadata: Optional[Dict[str, str]] = None,
    ):
        """
        Args:
            evidence_type: Category of the evidence, stored as ``self.type``.
            description: Human-readable explanation of the evidence.
            data: Arbitrary payload backing the evidence.
            metadata: Optional extra key/value information; a falsy value
                becomes an empty dict.
        """
        self.type = evidence_type
        self.description = description
        self.data = data
        self.metadata = metadata or {}

    def to_dict(self) -> Dict[str, Any]:
        """Return the evidence as a plain dict.

        Keys: ``type``, ``description``, ``data``, ``metadata``.
        """
        return {
            "type": self.type,
            "description": self.description,
            "data": self.data,
            "metadata": self.metadata,
        }

    def __repr__(self):
        """Short debug representation showing type and description."""
        return f"RiskEvidence(type={self.type}, description={self.description})"

    # NOTE: ``__dict__`` is deliberately NOT overridden with a property.
    # The default instance dict already holds exactly the keys ``to_dict``
    # emits (same names, same order), so ``obj.__dict__`` / ``vars(obj)``
    # still give a JSON-ready mapping. Replacing it with a property made
    # ``copy.copy``/``copy.deepcopy``/``pickle.loads`` fail: they read
    # ``__dict__`` on a freshly allocated, attribute-less instance, which
    # invoked ``to_dict()`` and raised ``AttributeError``.
class DetectionResult:
    """Outcome of running one detection algorithm against one entity."""

    def __init__(
        self,
        task_id: str,
        rule_id: str,
        entity_id: str,
        entity_type: str,
        risk_level: RiskLevel,
        risk_score: float,
        description: str,
        suggestion: str,
        risk_data: Optional[Dict[str, Any]] = None,
        evidence: Optional[List[RiskEvidence]] = None,
    ):
        """
        Args:
            task_id / rule_id: Identifiers of the originating task and rule.
            entity_id / entity_type: The entity that was evaluated.
            risk_level: Assessed level; usually a ``RiskLevel`` member, but
                ``to_dict`` also accepts a plain value.
            risk_score: Numeric risk score.
            description: Human-readable summary of the finding.
            suggestion: Recommended follow-up action.
            risk_data: Extra structured data; a falsy value becomes ``{}``.
            evidence: Supporting evidence; a falsy value becomes ``[]``.
        """
        self.task_id = task_id
        self.rule_id = rule_id
        self.entity_id = entity_id
        self.entity_type = entity_type
        self.risk_level = risk_level
        self.risk_score = risk_score
        self.description = description
        self.suggestion = suggestion
        self.risk_data = risk_data if risk_data else {}
        self.evidence = evidence if evidence else []
        # Naive local timestamp of when this result object was created.
        self.detected_at = datetime.now()

    def add_evidence(self, evidence: RiskEvidence):
        """Append one evidence item to this result."""
        items = self.evidence
        items.append(evidence)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the result into a plain, JSON-friendly dict.

        Enum risk levels are reduced to their ``.value``; evidence items are
        serialized through their own ``to_dict``; ``detected_at`` is emitted
        as an ISO-8601 string.
        """
        level = self.risk_level
        return {
            "task_id": self.task_id,
            "rule_id": self.rule_id,
            "entity_id": self.entity_id,
            "entity_type": self.entity_type,
            "risk_level": level.value if isinstance(level, RiskLevel) else level,
            "risk_score": self.risk_score,
            "description": self.description,
            "suggestion": self.suggestion,
            "risk_data": self.risk_data,
            "evidence": [item.to_dict() for item in self.evidence],
            "detected_at": self.detected_at.isoformat(),
        }
class RiskDetectionAlgorithm(ABC):
    """Abstract base class for risk detection algorithms.

    Subclasses must implement ``get_algorithm_code``, ``get_algorithm_name``
    and ``_do_detect``. The ``prepare_data`` / ``validate_data`` /
    ``post_process`` hooks default to no-ops and may be overridden.
    """

    def __init__(self):
        # Algorithm configuration; replaced wholesale by ``init``.
        self.config: Dict[str, Any] = {}

    @abstractmethod
    def get_algorithm_code(self) -> str:
        """Return the algorithm's code (identifier)."""

    @abstractmethod
    def get_algorithm_name(self) -> str:
        """Return the algorithm's display name."""

    def get_description(self) -> str:
        """Return a description; defaults to one derived from the name."""
        return f"算法: {self.get_algorithm_name()}"

    def init(self, config: Dict[str, Any]):
        """Initialize the algorithm with ``config`` (a falsy value becomes ``{}``)."""
        self.config = config or {}

    async def detect(self, context: DetectionContext) -> DetectionResult:
        """Run the full detection pipeline (async).

        Steps: ``prepare_data`` -> ``validate_data`` -> ``_do_detect`` ->
        ``post_process``. Any ``Exception`` raised along the way is logged
        with its traceback and converted into a ``RiskLevel.UNKNOWN`` result
        instead of propagating to the caller.

        Args:
            context: Per-run context (task/rule ids, parameters, session).

        Returns:
            The result produced by ``_do_detect`` or, on failure, an error
            result built from ``context``.
        """
        try:
            await self.prepare_data(context)
            await self.validate_data(context)
            result = await self._do_detect(context)
            await self.post_process(result, context)

            # Elapsed seconds measured from when the context was created.
            execution_time = (datetime.now() - context.start_time).total_seconds()
            # risk_level may be a plain value instead of a RiskLevel member;
            # don't let building the log line turn a successful run into an
            # error result.
            risk_level = getattr(result.risk_level, "value", result.risk_level)
            logger.info(
                f"算法 {self.get_algorithm_code()} 执行完成,"
                f"实体ID: {result.entity_id},风险等级: {risk_level},"
                f"耗时: {execution_time:.2f}秒"
            )
            return result
        except Exception as e:
            # ``logger`` is loguru: stdlib's ``exc_info=True`` is not honored
            # (it is captured into ``extra`` and also forces ``str.format`` on
            # the message, which raises if the error text contains braces).
            # ``logger.exception`` attaches the traceback with no formatting.
            logger.exception(f"算法 {self.get_algorithm_code()} 执行失败: {str(e)}")
            return DetectionResult(
                task_id=context.task_id,
                rule_id=context.rule_id,
                entity_id=context.get_parameter("entity_id", ""),
                entity_type=context.get_parameter("entity_type", "unknown"),
                risk_level=RiskLevel.UNKNOWN,
                risk_score=0.0,
                description=f"检测执行失败: {str(e)}",
                suggestion="请检查系统日志或联系管理员",
            )

    @abstractmethod
    async def _do_detect(self, context: DetectionContext) -> DetectionResult:
        """Perform the actual detection logic (must be implemented by subclasses)."""

    async def prepare_data(self, context: DetectionContext):
        """Hook for loading/preparing data before validation (no-op by default)."""

    async def validate_data(self, context: DetectionContext):
        """Hook for validating prepared data (no-op by default)."""

    async def post_process(self, result: DetectionResult, context: DetectionContext):
        """Hook for adjusting the result after detection (no-op by default)."""