442 lines
14 KiB
Python
442 lines
14 KiB
Python
"""
依赖解析器

解析检测规则之间的依赖关系,生成执行顺序

功能:
1. 构建规则依赖图(有向无环图 DAG)
2. 检测循环依赖
3. 拓扑排序生成执行顺序
4. 按层级分组(支持并行执行优化)
"""
|
||
from collections import deque
from dataclasses import dataclass, field
from typing import List, Dict, Set, Optional, Tuple

from loguru import logger

from app.models.risk_detection import DetectionRule
|
||
|
||
|
||
@dataclass
class DependencyNode:
    """
    One detection rule inside a DependencyGraph.

    Attributes:
        rule_id: ID of the detection rule represented by this node.
        algorithm_code: Algorithm code of that rule.
        dependencies: IDs of rules this rule depends on (they run first).
        dependents: IDs of rules that depend on this rule (they run later).
        level: Execution level; a smaller value means earlier execution.
    """

    rule_id: str
    algorithm_code: str
    # Rule IDs that must run before this rule.
    dependencies: List[str] = field(default_factory=list)
    # Rule IDs that must run after this rule.
    dependents: List[str] = field(default_factory=list)
    # Execution level; 0 until DependencyGraph.calculate_levels() assigns it.
    level: int = field(default=0)
|
||
|
||
|
||
class DependencyGraph:
    """
    Dependency graph between detection rules (expected to be a DAG).

    Edges are kept as adjacency lists in both directions:
    - graph[rule_id]: IDs of rules that depend on ``rule_id`` (its dependents).
      Topological traversal (Kahn's algorithm) walks these lists.
    - reverse_graph[rule_id]: IDs of rules that ``rule_id`` depends on
      (its dependencies).
    - in_degree[rule_id]: number of distinct dependencies of ``rule_id``.
    """

    def __init__(self):
        self.nodes: Dict[str, DependencyNode] = {}  # rule_id -> node
        self.graph: Dict[str, List[str]] = {}  # rule_id -> dependent rule_ids
        self.reverse_graph: Dict[str, List[str]] = {}  # rule_id -> dependency rule_ids
        self.in_degree: Dict[str, int] = {}  # rule_id -> number of dependencies

    def add_node(self, rule_id: str, algorithm_code: str):
        """
        Register a rule node.

        Adding a ``rule_id`` that already exists is a no-op: the existing node
        (including its algorithm_code and edges) is kept unchanged.
        """
        if rule_id in self.nodes:
            return
        self.nodes[rule_id] = DependencyNode(
            rule_id=rule_id,
            algorithm_code=algorithm_code,
        )
        self.graph[rule_id] = []
        self.reverse_graph[rule_id] = []
        self.in_degree[rule_id] = 0

    def add_edge(self, from_rule: str, to_rule: str):
        """
        Add a dependency edge: ``from_rule`` depends on ``to_rule``, i.e.
        ``to_rule`` must execute before ``from_rule``.

        Duplicate edges are ignored. A self-edge is accepted here and is
        later reported by has_cycle().

        Args:
            from_rule: ID of the depending rule.
            to_rule: ID of the rule being depended on.

        Raises:
            ValueError: If either rule has not been added with add_node().
        """
        if from_rule not in self.nodes or to_rule not in self.nodes:
            raise ValueError(f"规则不存在:{from_rule} 或 {to_rule}")

        if to_rule in self.nodes[from_rule].dependencies:
            return  # edge already present

        self.nodes[from_rule].dependencies.append(to_rule)
        self.nodes[to_rule].dependents.append(from_rule)

        # graph stores the "is depended on by" direction,
        # reverse_graph stores the "depends on" direction.
        self.graph[to_rule].append(from_rule)
        self.reverse_graph[from_rule].append(to_rule)

        # in_degree counts unfinished prerequisites for Kahn's algorithm.
        self.in_degree[from_rule] += 1

    def get_node(self, rule_id: str) -> Optional[DependencyNode]:
        """Return the node for ``rule_id``, or None if it is unknown."""
        return self.nodes.get(rule_id)

    def get_all_nodes(self) -> List[DependencyNode]:
        """Return all nodes in insertion order."""
        return list(self.nodes.values())

    def has_cycle(self) -> Tuple[bool, Optional[List[str]]]:
        """
        Detect circular dependencies using Kahn's algorithm.

        Returns:
            (has_cycle, cycle_path)
            - has_cycle: True if at least one cycle exists.
            - cycle_path: one cycle as a list of rule IDs where each element
              depends on the next and the first ID is repeated at the end;
              falls back to all unresolved rule IDs if no explicit path is
              found. None when there is no cycle.
        """
        # Work on a copy so the graph's own in-degree table stays intact.
        remaining = dict(self.in_degree)
        queue = deque(rule_id for rule_id, degree in remaining.items() if degree == 0)
        processed_count = 0

        while queue:
            rule_id = queue.popleft()
            processed_count += 1
            for dependent in self.graph[rule_id]:
                remaining[dependent] -= 1
                if remaining[dependent] == 0:
                    queue.append(dependent)

        if processed_count == len(self.nodes):
            return False, None

        # Nodes never released are on a cycle or downstream of one.
        cycle_nodes = [rule_id for rule_id, degree in remaining.items() if degree > 0]
        return True, self._find_cycle_path(cycle_nodes)

    def _find_cycle_path(self, cycle_nodes: List[str]) -> List[str]:
        """
        Find one concrete cycle among ``cycle_nodes``.

        Runs an iterative depth-first search along dependency edges
        (reverse_graph), restricted to ``cycle_nodes``. Iteration avoids
        Python's recursion limit on long dependency chains.

        Returns:
            The cycle as [a, b, ..., a] (each element depends on the next),
            or ``cycle_nodes`` itself if no cycle is found.
        """
        candidates: Set[str] = set(cycle_nodes)
        visited: Set[str] = set()

        for start in cycle_nodes:
            if start in visited:
                continue

            path: List[str] = [start]
            position: Dict[str, int] = {start: 0}  # node on current path -> index
            visited.add(start)
            stack = [iter(self.reverse_graph.get(start, []))]

            while stack:
                descended = False
                for dep in stack[-1]:
                    if dep not in candidates:
                        continue
                    if dep in position:
                        # Back edge to the current path: cycle found.
                        return path[position[dep]:] + [dep]
                    if dep in visited:
                        continue
                    visited.add(dep)
                    position[dep] = len(path)
                    path.append(dep)
                    stack.append(iter(self.reverse_graph.get(dep, [])))
                    descended = True
                    break
                if not descended:
                    # All dependencies of the path tail explored; backtrack.
                    stack.pop()
                    del position[path.pop()]

        return cycle_nodes  # no explicit path found; report all candidates

    def calculate_levels(self):
        """
        Assign an execution level to every node.

        Kahn's algorithm with a FIFO queue processes nodes in non-decreasing
        level order. A node is released by its deepest prerequisite, so its
        level is the length of the longest dependency chain leading to it.
        Level 0 holds rules without dependencies.

        Nodes that are part of (or downstream of) a cycle are never released
        and keep their previous level, so callers should check has_cycle()
        first.
        """
        remaining = dict(self.in_degree)
        queue = deque(
            (rule_id, 0) for rule_id, degree in remaining.items() if degree == 0
        )

        while queue:
            rule_id, level = queue.popleft()
            self.nodes[rule_id].level = level
            for dependent in self.graph[rule_id]:
                remaining[dependent] -= 1
                if remaining[dependent] == 0:
                    # A dependent runs one level after its last prerequisite.
                    queue.append((dependent, level + 1))

    def get_levels(self) -> List[List[str]]:
        """
        Group rule IDs by execution level.

        Returns:
            [[level_0_rules], [level_1_rules], ...]; rules within the same
            level can run in parallel. An empty graph yields [].
        """
        if not self.nodes:
            return []

        grouped: Dict[int, List[str]] = {}
        for rule_id, node in self.nodes.items():
            grouped.setdefault(node.level, []).append(rule_id)

        return [grouped.get(i, []) for i in range(max(grouped) + 1)]

    def to_dict(self) -> Dict:
        """
        Serialize the graph to plain dicts and lists.

        Lists are copied so callers cannot mutate the graph's internal state.
        """
        return {
            "nodes": [
                {
                    "rule_id": node.rule_id,
                    "algorithm_code": node.algorithm_code,
                    "dependencies": list(node.dependencies),
                    "dependents": list(node.dependents),
                    "level": node.level,
                }
                for node in self.nodes.values()
            ],
            "levels": self.get_levels(),
        }
|
||
|
||
|
||
class DependencyResolver:
    """
    Dependency resolver for detection rules.

    Builds a DependencyGraph from a list of rules, rejects cyclic
    configurations, and exposes the resulting serial order and parallel
    execution levels.
    """

    # Built-in algorithm dependencies (optional, for quick configuration):
    # algorithm_code -> algorithm_codes that must run first.
    DEFAULT_ALGORITHM_DEPENDENCIES = {
        "REVENUE_INTEGRITY_CHECK": [],  # revenue integrity check: no prerequisites
        "PRIVATE_ACCOUNT_DETECTION": [],  # private-account receipt detection: no prerequisites
        "INVOICE_FRAUD_DETECTION": [],  # fraudulent invoice detection: no prerequisites
        "EXPENSE_ANOMALY_DETECTION": [],  # expense anomaly detection: no prerequisites
        "TAX_RATE_CHECK": [],  # tax rate error check: no prerequisites
        "TAX_RISK_ASSESSMENT": [  # overall assessment depends on all checks above
            "REVENUE_INTEGRITY_CHECK",
            "PRIVATE_ACCOUNT_DETECTION",
            "INVOICE_FRAUD_DETECTION",
            "EXPENSE_ANOMALY_DETECTION",
            "TAX_RATE_CHECK",
        ],
    }

    def __init__(self):
        # Graph from the most recent successful analyze() call.
        self.graph: Optional[DependencyGraph] = None

    def analyze(
        self,
        rules: List[DetectionRule],
        use_default_dependencies: bool = True
    ) -> DependencyGraph:
        """
        Analyze rule dependencies and build the dependency graph.

        On success the graph is also stored in ``self.graph``. On failure
        ``self.graph`` keeps its previous value.

        Args:
            rules: Rules to analyze.
            use_default_dependencies: Whether to fall back to
                DEFAULT_ALGORITHM_DEPENDENCIES for rules that declare no
                explicit "dependencies" parameter.

        Returns:
            DependencyGraph with execution levels computed.

        Raises:
            ValueError: If a circular dependency is detected.
        """
        logger.info(f"开始分析 {len(rules)} 个规则的依赖关系")

        # 1. Create the graph.
        graph = DependencyGraph()

        # 2. Add every rule as a node and index rules by algorithm code.
        rule_dict: Dict[str, str] = {}  # algorithm_code -> rule_id
        for rule in rules:
            if rule.rule_id in graph.nodes:
                logger.warning(f"规则ID {rule.rule_id} 重复,仅保留首次出现的节点")
            previous_rule_id = rule_dict.get(rule.algorithm_code)
            if previous_rule_id is not None and previous_rule_id != rule.rule_id:
                # The last rule wins; dependencies declared by algorithm code
                # will point only to it.
                logger.warning(
                    f"算法 {rule.algorithm_code} 对应多个规则"
                    f"({previous_rule_id}、{rule.rule_id}),"
                    f"按算法编码声明的依赖将指向 {rule.rule_id}"
                )
            graph.add_node(rule.rule_id, rule.algorithm_code)
            rule_dict[rule.algorithm_code] = rule.rule_id

        # 3. Add dependency edges.
        for rule in rules:
            dependencies = self._extract_dependencies(
                rule,
                rule_dict,
                use_default_dependencies
            )
            for dep_rule_id in dependencies:
                if dep_rule_id in graph.nodes:
                    graph.add_edge(rule.rule_id, dep_rule_id)
                    logger.debug(
                        f"规则 {rule.rule_id} 依赖 {dep_rule_id}"
                    )
                else:
                    logger.warning(
                        f"规则 {rule.rule_id} 依赖的规则 {dep_rule_id} 不存在,已忽略"
                    )

        # 4. Reject circular dependencies.
        has_cycle, cycle_path = graph.has_cycle()
        if has_cycle:
            cycle_str = " -> ".join(cycle_path) if cycle_path else "未知"
            error_msg = f"检测到循环依赖,无法生成执行计划:{cycle_str}"
            logger.error(error_msg)
            raise ValueError(error_msg)

        # 5. Compute execution levels.
        graph.calculate_levels()

        # 6. Log the result.
        levels = graph.get_levels()
        logger.info(
            f"依赖分析完成,共 {len(levels)} 个执行层级,"
            f"最大并行度 {max(len(level) for level in levels) if levels else 0}"
        )
        for i, level in enumerate(levels):
            logger.info(f" 层级 {i}:{len(level)} 个规则 - {level}")

        self.graph = graph
        return graph

    def _extract_dependencies(
        self,
        rule: DetectionRule,
        rule_dict: Dict[str, str],
        use_default_dependencies: bool
    ) -> List[str]:
        """
        Resolve the rule IDs that ``rule`` depends on.

        Priority:
        1. The "dependencies" entry of ``rule.parameters``, if present. Even an
           empty or invalid value disables the defaults.
        2. DEFAULT_ALGORITHM_DEPENDENCIES, if enabled.
        3. An empty list.

        Entries starting with "RULE_" are taken as rule IDs. Any other string
        is treated as an algorithm code and mapped through ``rule_dict``.

        Args:
            rule: Rule being analyzed.
            rule_dict: algorithm_code -> rule_id mapping.
            use_default_dependencies: Whether the built-in defaults apply.

        Returns:
            Rule IDs this rule depends on. Rule IDs are not checked for
            existence here; the caller does that.
        """
        dependencies: List[str] = []

        # 1. Explicit dependencies from the rule parameters.
        if rule.parameters and "dependencies" in rule.parameters:
            param_deps = rule.parameters["dependencies"]

            if not isinstance(param_deps, list):
                if param_deps is not None:
                    logger.warning(
                        f"规则 {rule.rule_id} 的 dependencies 参数不是列表,已忽略"
                    )
                return dependencies

            for dep in param_deps:
                if not isinstance(dep, str):
                    logger.warning(
                        f"规则 {rule.rule_id} 的依赖项 {dep!r} 不是字符串,已忽略"
                    )
                    continue
                if dep.startswith("RULE_"):
                    # Already a rule ID.
                    dependencies.append(dep)
                    continue
                # Otherwise resolve as an algorithm code.
                dep_rule_id = rule_dict.get(dep)
                if dep_rule_id:
                    dependencies.append(dep_rule_id)
                else:
                    logger.warning(
                        f"规则 {rule.rule_id} 依赖的算法 {dep} 不存在"
                    )
            return dependencies

        # 2. No explicit dependencies: fall back to the defaults. Default
        #    prerequisites that are not among the analyzed rules are skipped.
        if use_default_dependencies:
            default_deps = self.DEFAULT_ALGORITHM_DEPENDENCIES.get(
                rule.algorithm_code, []
            )
            for dep_algorithm in default_deps:
                dep_rule_id = rule_dict.get(dep_algorithm)
                if dep_rule_id:
                    dependencies.append(dep_rule_id)

        return dependencies

    def get_execution_order(self) -> List[str]:
        """
        Return a serial execution order (a topological order).

        Returns:
            Rule IDs with lower levels first.

        Raises:
            ValueError: If analyze() has not completed successfully yet.
        """
        if not self.graph:
            raise ValueError("尚未进行依赖分析,请先调用 analyze() 方法")

        # Flatten the levels into a single serial order.
        levels = self.graph.get_levels()
        return [rule_id for level in levels for rule_id in level]

    def get_execution_levels(self) -> List[List[str]]:
        """
        Return the parallel execution levels.

        Returns:
            [[level_0_rules], [level_1_rules], ...]; rules within the same
            level can run in parallel.

        Raises:
            ValueError: If analyze() has not completed successfully yet.
        """
        if not self.graph:
            raise ValueError("尚未进行依赖分析,请先调用 analyze() 方法")

        return self.graph.get_levels()

    def validate(self, rules: List[DetectionRule]) -> Tuple[bool, Optional[str]]:
        """
        Check whether the rules' dependencies are valid.

        Side effect: on success this replaces ``self.graph`` via analyze().

        Args:
            rules: Rules to validate.

        Returns:
            (is_valid, error_message); error_message is None when valid.
        """
        try:
            self.analyze(rules)
        except ValueError as e:
            return False, str(e)
        except Exception as e:
            # loguru has no stdlib-style ``exc_info`` kwarg: extra kwargs are
            # fed to str.format() and fail on braces in the error text.
            # logger.exception logs at ERROR level with the traceback.
            logger.exception(f"依赖验证失败: {str(e)}")
            return False, f"依赖验证失败: {str(e)}"
        return True, None
|