deep-risk/backend/tests/services/risk_detection/engine/test_rule_engine.py
2025-12-14 20:08:27 +08:00

400 lines
12 KiB
Python

"""
规则引擎单元测试
"""
import pytest
from unittest.mock import Mock, patch, MagicMock, AsyncMock
from typing import List
from app.models.risk_detection import DetectionRule, RiskLevel
from app.services.risk_detection.engine.rule_engine import (
AlgorithmRegistry,
RuleEngine,
get_rule_engine,
)
from app.services.risk_detection.engine.execution_plan import ExecutionMode
from app.services.risk_detection.algorithms.base import RiskDetectionAlgorithm, DetectionContext, DetectionResult
class MockAlgorithm(RiskDetectionAlgorithm):
    """Stub detection algorithm used throughout these tests.

    It always yields the same low-risk ``DetectionResult`` for whatever task and
    rule the detection context carries, so engine behaviour can be exercised
    without any real data source.
    """

    def __init__(self, algorithm_code: str = "MOCK_ALGO"):
        super().__init__()
        # Kept privately and exposed through get_algorithm_code().
        self._code = algorithm_code

    def get_algorithm_code(self) -> str:
        """Return the code this stub was constructed with."""
        return self._code

    def get_algorithm_name(self) -> str:
        """Return a display name derived from the algorithm code."""
        return "模拟算法{}".format(self._code)

    async def _do_detect(self, context: DetectionContext) -> DetectionResult:
        """Build a fixed low-risk result bound to the context's task and rule ids."""
        canned_result = DetectionResult(
            task_id=context.task_id,
            rule_id=context.rule_id,
            entity_id="test_entity",
            entity_type="test_type",
            risk_level=RiskLevel.LOW,
            risk_score=25.0,
            description="模拟检测结果",
            suggestion="模拟建议",
        )
        return canned_result
class TestAlgorithmRegistry:
    """Tests for ``AlgorithmRegistry`` registration and lookup."""

    @staticmethod
    def _empty_registry():
        """Return a registry whose default algorithms and cached instances are wiped."""
        registry = AlgorithmRegistry()
        registry._algorithms.clear()
        registry._instances.clear()
        return registry

    def test_init(self):
        """A freshly built registry registers its default algorithms on its own."""
        assert AlgorithmRegistry()._algorithms

    def test_register(self):
        """Registering a class indexes it under its algorithm code."""
        registry = self._empty_registry()
        registry.register(MockAlgorithm)
        assert "MOCK_ALGO" in registry._algorithms
        assert registry._algorithms["MOCK_ALGO"] == MockAlgorithm

    def test_register_without_get_algorithm_code(self):
        """A class lacking get_algorithm_code is rejected with ValueError."""

        class InvalidAlgorithm:
            pass

        registry = AlgorithmRegistry()
        with pytest.raises(ValueError, match="必须实现 get_algorithm_code"):
            registry.register(InvalidAlgorithm)

    def test_get_algorithm(self):
        """Lookup instantiates the algorithm once and then reuses that instance."""
        registry = self._empty_registry()
        registry.register(MockAlgorithm)

        first = registry.get_algorithm("MOCK_ALGO")
        assert isinstance(first, MockAlgorithm)
        assert first.get_algorithm_code() == "MOCK_ALGO"
        # A second lookup must return the very same object (singleton per code).
        assert registry.get_algorithm("MOCK_ALGO") is first

    def test_get_algorithm_not_registered(self):
        """Looking up an unknown code raises ValueError."""
        registry = AlgorithmRegistry()
        with pytest.raises(ValueError, match="未注册的算法"):
            registry.get_algorithm("NONEXISTENT")

    def test_get_all_algorithms(self):
        """The registry reports a non-empty list of registered algorithms."""
        registered = AlgorithmRegistry().get_all_algorithms()
        assert isinstance(registered, list)
        assert len(registered) > 0

    def test_is_registered(self):
        """is_registered is True for a default entry and False for an unknown code."""
        registry = AlgorithmRegistry()
        known = registry.get_all_algorithms()[0]
        assert registry.is_registered(known) is True
        assert registry.is_registered("NONEXISTENT") is False
class TestRuleEngine:
    """Tests for ``RuleEngine``: rule validation, stage execution, plan checks and estimates."""

    @staticmethod
    def _make_rule(rule_id, algorithm_code="REVENUE_INTEGRITY_CHECK", *, is_enabled=True):
        """Build a ``DetectionRule`` mock carrying the attributes the engine reads."""
        rule = Mock(spec=DetectionRule)
        rule.rule_id = rule_id
        rule.algorithm_code = algorithm_code
        rule.parameters = {}
        rule.is_enabled = is_enabled
        return rule

    def setup_method(self):
        """Create a fresh engine and two enabled rules before every test."""
        self.engine = RuleEngine()
        self.rule1 = self._make_rule("rule_1")
        self.rule2 = self._make_rule("rule_2")

    def test_init(self):
        """The engine wires up all of its collaborators on construction."""
        assert self.engine is not None
        assert self.engine.algorithm_registry is not None
        assert self.engine.dependency_resolver is not None
        assert self.engine.execution_planner is not None
        assert self.engine.result_processor is not None

    def test_init_with_db_session(self):
        """A database session passed to the constructor is stored on the engine."""
        mock_session = Mock()
        engine = RuleEngine(db_session=mock_session)
        assert engine.db_session == mock_session

    def test_validate_rules_empty(self):
        """An empty rule list is rejected."""
        with pytest.raises(ValueError, match="规则列表不能为空"):
            self.engine._validate_rules([])

    def test_validate_rules_unregistered_algorithm(self):
        """A rule referencing an unregistered algorithm is rejected."""
        rule = self._make_rule("rule_1", "NONEXISTENT_ALGO")
        with pytest.raises(ValueError, match="未注册"):
            self.engine._validate_rules([rule])

    def test_validate_rules_disabled(self):
        """A disabled rule only triggers a warning; no exception is expected."""
        rule = self._make_rule("rule_1", is_enabled=False)
        self.engine._validate_rules([rule])

    def test_validate_rules_valid(self):
        """A valid, enabled rule passes validation without raising."""
        self.engine._validate_rules([self.rule1])

    def test_are_stage_dependencies_met(self):
        """A stage is runnable only once every stage it depends on has completed."""
        stage = Mock()
        stage.depends_on = ["stage_0", "stage_1"]
        assert self.engine._are_stage_dependencies_met(stage, ["stage_0", "stage_1"]) is True
        assert self.engine._are_stage_dependencies_met(stage, ["stage_0"]) is False

    @pytest.mark.asyncio
    async def test_execute_single_rule(self):
        """Executing one rule yields a DetectionResult tagged with that rule's id."""
        # Uses the same pytest-asyncio mechanism as the other async tests here
        # instead of a hand-rolled asyncio.run() wrapper.
        result = await self.engine._execute_single_rule(self.rule1, "rule_1")
        assert result is not None
        assert isinstance(result, DetectionResult)
        assert result.rule_id == "rule_1"

    @pytest.mark.asyncio
    async def test_execute_sequential_stage(self):
        """A sequential stage produces one result per node."""
        stage = Mock()
        stage.execution_mode = ExecutionMode.SEQUENTIAL
        stage.nodes = [
            Mock(node_id="node_1", rule=self.rule1),
            Mock(node_id="node_2", rule=self.rule2),
        ]
        results = await self.engine._execute_sequential(stage)
        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_execute_parallel_stage(self):
        """A parallel stage produces one result per node."""
        stage = Mock()
        stage.execution_mode = ExecutionMode.PARALLEL
        stage.nodes = [
            Mock(node_id="node_1", rule=self.rule1),
            Mock(node_id="node_2", rule=self.rule2),
        ]
        results = await self.engine._execute_parallel(stage, max_concurrent=2)
        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_execute_detection(self):
        """A plan without stages yields no results."""
        plan = Mock()
        plan.stages = []
        results = await self.engine._execute_detection(plan, max_concurrent=2)
        assert results == []

    def test_get_algorithm_info(self):
        """Algorithm info is keyed by code and describes each algorithm."""
        info = self.engine.get_algorithm_info()
        assert isinstance(info, dict)
        assert len(info) > 0
        # The default registry is expected to contain REVENUE_INTEGRITY_CHECK
        # (the rules built in setup_method rely on it). Assert that explicitly
        # rather than skipping the checks when it is missing.
        assert "REVENUE_INTEGRITY_CHECK" in info
        algo_info = info["REVENUE_INTEGRITY_CHECK"]
        assert "name" in algo_info
        assert "description" in algo_info
        assert "code" in algo_info
        assert algo_info["code"] == "REVENUE_INTEGRITY_CHECK"

    def test_validate_execution_plan_valid(self):
        """A single-stage plan whose rule count matches is accepted."""
        plan = Mock()
        plan.stages = [
            Mock(stage_id="stage_0", depends_on=[], rule_count=1),
        ]
        plan.total_rules = 1
        is_valid, error = self.engine.validate_execution_plan(plan)
        assert is_valid is True
        assert error is None

    def test_validate_execution_plan_empty_stages(self):
        """A plan without any stages is rejected."""
        plan = Mock()
        plan.stages = []
        plan.total_rules = 0
        is_valid, error = self.engine.validate_execution_plan(plan)
        assert is_valid is False
        assert "没有阶段" in error

    def test_validate_execution_plan_invalid_dependencies(self):
        """A plan whose stage dependencies cannot be satisfied is rejected."""
        plan = Mock()
        plan.stages = [
            Mock(stage_id="stage_0", depends_on=[], rule_count=1),
            Mock(stage_id="stage_1", depends_on=["stage_0"], rule_count=1),
        ]
        plan.total_rules = 2
        # Force the dependency check to fail; patch.object restores the method afterwards.
        with patch.object(self.engine, "_are_stage_dependencies_met", return_value=False):
            is_valid, error = self.engine.validate_execution_plan(plan)
        assert is_valid is False
        assert "依赖不满足" in error

    def test_validate_execution_plan_inconsistent_count(self):
        """A plan whose stage rule counts disagree with total_rules is rejected."""
        plan = Mock()
        plan.stages = [
            Mock(stage_id="stage_0", depends_on=[], rule_count=1),
        ]
        plan.total_rules = 2  # deliberately inconsistent with the stage counts
        # Let dependencies pass so only the count check can fail.
        with patch.object(self.engine, "_are_stage_dependencies_met", return_value=True):
            is_valid, error = self.engine.validate_execution_plan(plan)
        assert is_valid is False
        assert "规则数量不一致" in error

    @pytest.mark.asyncio
    async def test_dry_run(self):
        """A dry run reports the plan, summary and algorithm info without executing."""
        dry_run_result = await self.engine.dry_run(
            entity_id="entity_1",
            entity_type="type_1",
            period="2024-01",
            rules=[self.rule1, self.rule2],
            execution_mode=ExecutionMode.HYBRID,
        )
        assert "dry_run" in dry_run_result
        assert dry_run_result["dry_run"] is True
        assert "execution_plan" in dry_run_result
        assert "summary" in dry_run_result
        assert "algorithm_info" in dry_run_result

    def test_estimate_execution_time(self):
        """A sequential plan gets a positive float estimate."""
        plan = Mock()
        plan.total_rules = 5
        plan.execution_mode = ExecutionMode.SEQUENTIAL
        plan.max_level = 0
        estimated = self.engine._estimate_execution_time(plan)
        assert estimated > 0
        assert isinstance(estimated, float)

    def test_estimate_execution_time_parallel(self):
        """A parallel plan is estimated positive and no slower than the same plan run sequentially."""
        sequential_plan = Mock()
        sequential_plan.total_rules = 5
        sequential_plan.execution_mode = ExecutionMode.SEQUENTIAL
        sequential_plan.max_level = 0

        parallel_plan = Mock()
        parallel_plan.total_rules = 5
        parallel_plan.execution_mode = ExecutionMode.PARALLEL
        parallel_plan.max_level = 0

        sequential_time = self.engine._estimate_execution_time(sequential_plan)
        parallel_time = self.engine._estimate_execution_time(parallel_plan)
        assert parallel_time > 0
        # Running rules concurrently should never be estimated as slower than
        # running the same rules one after another.
        assert parallel_time <= sequential_time

    def test_estimate_execution_time_hybrid(self):
        """A hybrid plan with parallel stages gets a positive estimate."""
        plan = Mock()
        plan.total_rules = 5
        plan.execution_mode = ExecutionMode.HYBRID
        plan.max_level = 2
        plan.stages = [
            Mock(execution_mode=ExecutionMode.PARALLEL),
            Mock(execution_mode=ExecutionMode.PARALLEL),
        ]
        estimated = self.engine._estimate_execution_time(plan)
        assert estimated > 0
class TestGetRuleEngine:
    """Tests for the module-level ``get_rule_engine`` accessor."""

    def test_get_rule_engine(self):
        """Repeated calls hand back one shared engine object (singleton)."""
        first, second = get_rule_engine(), get_rule_engine()
        assert first is second

    def test_get_rule_engine_initialization(self):
        """The accessor yields a constructed ``RuleEngine``."""
        shared = get_rule_engine()
        assert shared is not None
        assert isinstance(shared, RuleEngine)