deep-risk/backend/app/services/risk_detection/engine/dependency_resolver.py
2025-12-14 20:08:27 +08:00

442 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
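Dependency resolver.

Resolves the dependencies between detection rules and derives their execution order.

Features:
1. Build the rule dependency graph (a directed acyclic graph, DAG)
2. Detect circular dependencies
3. Produce the execution order with a topological sort
4. Group rules by level (rules in the same level can be executed in parallel)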
"""
from collections import deque
from dataclasses import dataclass, field
from typing import List, Dict, Set, Optional, Tuple

from loguru import logger

from app.models.risk_detection import DetectionRule


@dataclass
class DependencyNode:
    """A single rule in the dependency graph together with its direct links."""

    rule_id: str
    algorithm_code: str
    # IDs of the rules this rule depends on.
    dependencies: List[str] = field(default_factory=list)
    # IDs of the rules that depend on this rule.
    dependents: List[str] = field(default_factory=list)
    # Execution level; a smaller level is executed earlier.
    level: int = 0


class DependencyGraph:
    """Rule dependency graph stored as adjacency lists.

    Attributes:
        nodes: rule_id -> node object for that rule.
        graph: rule_id -> IDs of the rules that depend on it (its dependents).
            These are the edges followed in execution order.
        reverse_graph: rule_id -> IDs of the rules it depends on.
        in_degree: rule_id -> number of distinct rules it depends on.
    """

    def __init__(self):
        self.nodes: Dict[str, "DependencyNode"] = {}
        self.graph: Dict[str, List[str]] = {}
        self.reverse_graph: Dict[str, List[str]] = {}
        self.in_degree: Dict[str, int] = {}

    def add_node(self, rule_id: str, algorithm_code: str):
        """Register a rule node; a rule_id that already exists is left untouched."""
        if rule_id in self.nodes:
            return
        self.nodes[rule_id] = DependencyNode(
            rule_id=rule_id,
            algorithm_code=algorithm_code,
        )
        self.graph[rule_id] = []
        self.reverse_graph[rule_id] = []
        self.in_degree[rule_id] = 0

    def add_edge(self, from_rule: str, to_rule: str):
        """Record that ``from_rule`` depends on ``to_rule``.

        ``to_rule`` therefore has to run before ``from_rule``. Adding the same
        edge twice is a no-op.

        Args:
            from_rule: ID of the depending rule.
            to_rule: ID of the rule being depended on.

        Raises:
            ValueError: If either rule has not been added as a node.
        """
        # Report only the IDs that are really missing, each exactly once.
        missing = list(dict.fromkeys(
            rule for rule in (from_rule, to_rule) if rule not in self.nodes
        ))
        if missing:
            raise ValueError(f"规则不存在:{', '.join(missing)}")

        # The node's dependency list doubles as the duplicate-edge guard.
        if to_rule in self.nodes[from_rule].dependencies:
            return

        self.nodes[from_rule].dependencies.append(to_rule)
        self.nodes[to_rule].dependents.append(from_rule)
        # graph holds the "is depended on by" direction.
        self.graph[to_rule].append(from_rule)
        self.reverse_graph[from_rule].append(to_rule)
        self.in_degree[from_rule] += 1

    def get_node(self, rule_id: str) -> Optional["DependencyNode"]:
        """Return the node for ``rule_id``, or None when it is unknown."""
        return self.nodes.get(rule_id)

    def get_all_nodes(self) -> List["DependencyNode"]:
        """Return all nodes as a new list."""
        return list(self.nodes.values())

    def has_cycle(self) -> Tuple[bool, Optional[List[str]]]:
        """Detect circular dependencies with Kahn's algorithm.

        Returns:
            ``(has_cycle, cycle_path)``. ``cycle_path`` is None when the graph
            is acyclic; otherwise it is the result of ``_find_cycle_path``.
        """
        # Work on a copy so the stored in-degrees stay intact.
        remaining = dict(self.in_degree)
        ready = deque(rule_id for rule_id, degree in remaining.items() if degree == 0)
        processed = 0

        while ready:
            rule_id = ready.popleft()
            processed += 1
            for dependent in self.graph[rule_id]:
                remaining[dependent] -= 1
                if remaining[dependent] == 0:
                    ready.append(dependent)

        # Every node gets processed exactly when there is no cycle.
        if processed >= len(self.nodes):
            return False, None

        # Nodes that never became ready are on a cycle or downstream of one.
        blocked = [rule_id for rule_id, degree in remaining.items() if degree > 0]
        return True, self._find_cycle_path(blocked)

    def _find_cycle_path(self, cycle_nodes: List[str]) -> List[str]:
        """Find one concrete cycle among ``cycle_nodes``.

        Runs a depth-first search along dependency edges (``reverse_graph``),
        restricted to ``cycle_nodes``. The search is iterative so that long
        dependency chains cannot hit the interpreter recursion limit.

        Returns:
            A path whose first and last element are the same rule ID, where each
            rule depends on the one after it. If no cycle is found, the
            ``cycle_nodes`` list itself is returned.
        """
        candidates = set(cycle_nodes)
        visited: Set[str] = set()

        for start in cycle_nodes:
            if start in visited:
                continue
            visited.add(start)
            path = [start]
            on_path = {start}
            # One iterator per path element, so a node resumes where it left off.
            stack = [iter(self.reverse_graph.get(start, []))]

            while stack:
                for dep in stack[-1]:
                    if dep not in candidates:
                        continue
                    if dep in on_path:
                        # Back edge: the cycle is the path tail starting at dep.
                        return path[path.index(dep):] + [dep]
                    if dep not in visited:
                        visited.add(dep)
                        path.append(dep)
                        on_path.add(dep)
                        stack.append(iter(self.reverse_graph.get(dep, [])))
                        break
                else:
                    # All dependencies of the current node explored: backtrack.
                    stack.pop()
                    on_path.discard(path.pop())

        return cycle_nodes

    def calculate_levels(self):
        """Assign an execution level to every rule reachable without a cycle.

        Rules with no dependencies get level 0. Any other rule is queued once
        its last unprocessed dependency has been handled and gets that
        dependency's level plus one. The queue is FIFO, so levels are dequeued
        in non-decreasing order and that last dependency is the deepest one.
        """
        remaining = dict(self.in_degree)
        queue = deque(
            (rule_id, 0) for rule_id, degree in remaining.items() if degree == 0
        )

        while queue:
            rule_id, level = queue.popleft()
            self.nodes[rule_id].level = level
            for dependent in self.graph[rule_id]:
                remaining[dependent] -= 1
                if remaining[dependent] == 0:
                    queue.append((dependent, level + 1))

    def get_levels(self) -> List[List[str]]:
        """Group rule IDs by their current ``level`` value.

        Returns:
            ``[[level_0_rules], [level_1_rules], ...]``; rules inside one level
            can be executed in parallel. An empty graph yields ``[]``.
        """
        if not self.nodes:
            return []

        by_level: Dict[int, List[str]] = {}
        for rule_id, node in self.nodes.items():
            by_level.setdefault(node.level, []).append(rule_id)

        # Fill gaps with empty lists so the list index always equals the level.
        return [by_level.get(level, []) for level in range(max(by_level) + 1)]

    def to_dict(self) -> Dict:
        """Convert the graph into plain dicts and lists for serialization."""
        return {
            "nodes": [
                {
                    "rule_id": node.rule_id,
                    "algorithm_code": node.algorithm_code,
                    "dependencies": node.dependencies,
                    "dependents": node.dependents,
                    "level": node.level,
                }
                for node in self.nodes.values()
            ],
            "levels": self.get_levels(),
        }


class DependencyResolver:
    """Analyze rule dependencies and derive the execution levels.

    ``analyze`` builds a dependency graph from a list of rules and stores it on
    ``self.graph``; the ``get_execution_*`` accessors read that stored graph.
    """

    # Predefined algorithm-level dependencies (optional shortcut configuration).
    # Maps an algorithm code to the algorithm codes it depends on.
    DEFAULT_ALGORITHM_DEPENDENCIES = {
        "REVENUE_INTEGRITY_CHECK": [],    # revenue integrity check: no prerequisites
        "PRIVATE_ACCOUNT_DETECTION": [],  # private-account receipt detection: no prerequisites
        "INVOICE_FRAUD_DETECTION": [],    # invoice fraud detection: no prerequisites
        "EXPENSE_ANOMALY_DETECTION": [],  # expense anomaly detection: no prerequisites
        "TAX_RATE_CHECK": [],             # tax rate check: no prerequisites
        "TAX_RISK_ASSESSMENT": [          # overall assessment depends on every check above
            "REVENUE_INTEGRITY_CHECK",
            "PRIVATE_ACCOUNT_DETECTION",
            "INVOICE_FRAUD_DETECTION",
            "EXPENSE_ANOMALY_DETECTION",
            "TAX_RATE_CHECK",
        ],
    }

    def __init__(self):
        # Graph produced by the most recent successful analyze() call.
        self.graph: Optional["DependencyGraph"] = None

    def analyze(
        self,
        rules: List["DetectionRule"],
        use_default_dependencies: bool = True
    ) -> "DependencyGraph":
        """Build the dependency graph for ``rules`` and compute execution levels.

        Args:
            rules: Rules to analyze.
            use_default_dependencies: Whether rules without an explicit
                ``dependencies`` parameter fall back to
                ``DEFAULT_ALGORITHM_DEPENDENCIES``.

        Returns:
            The resulting graph, which is also stored on ``self.graph``.

        Raises:
            ValueError: If a circular dependency is detected. ``self.graph`` is
                left unchanged in that case.
        """
        logger.info(f"开始分析 {len(rules)} 个规则的依赖关系")

        graph = DependencyGraph()

        # Register every rule first so that edges can reference any of them.
        # NOTE: when several rules share an algorithm_code, the last one wins
        # in this lookup table.
        rule_dict = {}  # algorithm_code -> rule_id
        for rule in rules:
            graph.add_node(rule.rule_id, rule.algorithm_code)
            rule_dict[rule.algorithm_code] = rule.rule_id

        for rule in rules:
            dependencies = self._extract_dependencies(
                rule,
                rule_dict,
                use_default_dependencies
            )
            for dep_rule_id in dependencies:
                if dep_rule_id not in graph.nodes:
                    # Unknown dependencies are skipped, not treated as errors.
                    logger.warning(
                        f"规则 {rule.rule_id} 依赖的规则 {dep_rule_id} 不存在,已忽略"
                    )
                    continue
                graph.add_edge(rule.rule_id, dep_rule_id)
                logger.debug(f"规则 {rule.rule_id} 依赖 {dep_rule_id}")

        has_cycle, cycle_path = graph.has_cycle()
        if has_cycle:
            cycle_str = " -> ".join(cycle_path) if cycle_path else "未知"
            error_msg = f"检测到循环依赖,无法生成执行计划:{cycle_str}"
            logger.error(error_msg)
            raise ValueError(error_msg)

        graph.calculate_levels()

        levels = graph.get_levels()
        logger.info(
            f"依赖分析完成,共 {len(levels)} 个执行层级,"
            f"最大并行度 {max((len(level) for level in levels), default=0)}"
        )
        for i, level in enumerate(levels):
            # Separator keeps the level index and the rule count apart.
            logger.info(f" 层级 {i}: {len(level)} 个规则 - {level}")

        self.graph = graph
        return graph

    def _extract_dependencies(
        self,
        rule: "DetectionRule",
        rule_dict: Dict[str, str],
        use_default_dependencies: bool
    ) -> List[str]:
        """Resolve the rule IDs that ``rule`` depends on.

        Priority:
            1. The ``"dependencies"`` entry of ``rule.parameters``. When the key
               is present it wins even if the list is empty.
            2. ``DEFAULT_ALGORITHM_DEPENDENCIES`` (if enabled).
            3. An empty list.

        Args:
            rule: The detection rule.
            rule_dict: Mapping algorithm_code -> rule_id.
            use_default_dependencies: Whether the default table may be used.

        Returns:
            List of rule IDs. Entries starting with ``"RULE_"`` are passed
            through unverified.
        """
        params = rule.parameters
        if params and "dependencies" in params:
            declared = params["dependencies"]
            if not isinstance(declared, list):
                # A malformed value still suppresses the defaults; make it visible.
                logger.warning(
                    f"规则 {rule.rule_id} 的 dependencies 参数不是列表,已忽略"
                )
                return []

            resolved = []
            for dep in declared:
                if not isinstance(dep, str):
                    continue
                # Two accepted formats: a rule_id or an algorithm_code.
                if dep.startswith("RULE_"):
                    resolved.append(dep)
                    continue
                dep_rule_id = rule_dict.get(dep)
                if dep_rule_id:
                    resolved.append(dep_rule_id)
                else:
                    logger.warning(
                        f"规则 {rule.rule_id} 依赖的算法 {dep} 不存在"
                    )
            return resolved

        if not use_default_dependencies:
            return []

        default_deps = self.DEFAULT_ALGORITHM_DEPENDENCIES.get(
            rule.algorithm_code, []
        )
        # Default dependencies whose algorithm is not in this rule set are dropped.
        return [
            rule_dict[dep_algorithm]
            for dep_algorithm in default_deps
            if rule_dict.get(dep_algorithm)
        ]

    def get_execution_order(self) -> List[str]:
        """Return the serial execution order (levels flattened in order).

        Raises:
            ValueError: If ``analyze()`` has not produced a graph yet.
        """
        if self.graph is None:
            raise ValueError("尚未进行依赖分析,请先调用 analyze() 方法")
        return [
            rule_id
            for level in self.graph.get_levels()
            for rule_id in level
        ]

    def get_execution_levels(self) -> List[List[str]]:
        """Return the rules grouped by execution level.

        Returns:
            ``[[level_0_rules], [level_1_rules], ...]``; rules in the same level
            can be executed in parallel.

        Raises:
            ValueError: If ``analyze()`` has not produced a graph yet.
        """
        if self.graph is None:
            raise ValueError("尚未进行依赖分析,请先调用 analyze() 方法")
        return self.graph.get_levels()

    def validate(self, rules: List["DetectionRule"]) -> Tuple[bool, Optional[str]]:
        """Check whether the dependencies of ``rules`` can be resolved.

        Runs ``analyze(rules)``, so a successful validation also replaces
        ``self.graph``.

        Returns:
            ``(is_valid, error_message)``; ``error_message`` is None on success.
        """
        try:
            self.analyze(rules)
        except ValueError as e:
            return False, str(e)
        except Exception as e:
            # Pass the error as a format argument: loguru runs str.format on the
            # message when extra arguments are given, so braces in the error
            # text must not end up inside the template. logger.exception also
            # attaches the traceback.
            logger.exception("依赖验证失败: {}", e)
            return False, f"依赖验证失败: {str(e)}"
        return True, None