400 lines
12 KiB
Python
400 lines
12 KiB
Python
"""
规则引擎单元测试 — unit tests for the rule engine.
"""
|
|
import pytest
|
|
from unittest.mock import Mock, patch, MagicMock, AsyncMock
|
|
from typing import List
|
|
|
|
from app.models.risk_detection import DetectionRule, RiskLevel
|
|
from app.services.risk_detection.engine.rule_engine import (
|
|
AlgorithmRegistry,
|
|
RuleEngine,
|
|
get_rule_engine,
|
|
)
|
|
from app.services.risk_detection.engine.execution_plan import ExecutionMode
|
|
from app.services.risk_detection.algorithms.base import RiskDetectionAlgorithm, DetectionContext, DetectionResult
|
|
|
|
|
|
class MockAlgorithm(RiskDetectionAlgorithm):
    """Stub detection algorithm with a configurable code, for use in tests.

    Every call to ``_do_detect`` yields the same canned LOW-risk result,
    stamped with the task and rule ids read from the supplied context.
    """

    def __init__(self, algorithm_code: str = "MOCK_ALGO"):
        super().__init__()
        # Value reported by get_algorithm_code(); presumably used by
        # AlgorithmRegistry as the lookup key (see the registry tests).
        self._code = algorithm_code

    def get_algorithm_code(self) -> str:
        """Return the code this stub was constructed with."""
        code = self._code
        return code

    def get_algorithm_name(self) -> str:
        """Return a display name built from the algorithm code."""
        return "模拟算法{}".format(self._code)

    async def _do_detect(self, context: DetectionContext) -> DetectionResult:
        """Return a fixed LOW-risk result tagged with ``context``'s ids."""
        task_id, rule_id = context.task_id, context.rule_id
        result = DetectionResult(
            task_id=task_id,
            rule_id=rule_id,
            entity_id="test_entity",
            entity_type="test_type",
            risk_level=RiskLevel.LOW,
            risk_score=25.0,
            description="模拟检测结果",
            suggestion="模拟建议",
        )
        return result
|
|
|
|
|
|
class TestAlgorithmRegistry:
    """Tests for AlgorithmRegistry: registration, lookup and membership checks."""

    @staticmethod
    def _make_empty_registry() -> AlgorithmRegistry:
        """Return a fresh registry with its auto-registered defaults removed.

        The constructor registers default algorithms (see ``test_init``);
        clearing both the class map and the instance cache lets a test
        observe only what it registers itself.
        """
        # NOTE(review): assumes ``_algorithms`` / ``_instances`` are
        # per-instance dicts; if they are class-level, clearing them would
        # leak into other tests — confirm against AlgorithmRegistry.
        registry = AlgorithmRegistry()
        registry._algorithms.clear()
        registry._instances.clear()
        return registry

    def test_init(self):
        """A new registry auto-registers the default algorithms."""
        registry = AlgorithmRegistry()

        assert len(registry._algorithms) > 0

    def test_register(self):
        """register() stores the algorithm class under its algorithm code."""
        registry = self._make_empty_registry()

        registry.register(MockAlgorithm)

        assert "MOCK_ALGO" in registry._algorithms
        assert registry._algorithms["MOCK_ALGO"] == MockAlgorithm

    def test_register_without_get_algorithm_code(self):
        """Registering a class lacking get_algorithm_code raises ValueError."""
        registry = AlgorithmRegistry()

        class InvalidAlgorithm:
            pass

        with pytest.raises(ValueError, match="必须实现 get_algorithm_code"):
            registry.register(InvalidAlgorithm)

    def test_get_algorithm(self):
        """get_algorithm() instantiates a registered class and caches the instance."""
        registry = self._make_empty_registry()
        registry.register(MockAlgorithm)

        algorithm = registry.get_algorithm("MOCK_ALGO")

        assert isinstance(algorithm, MockAlgorithm)
        assert algorithm.get_algorithm_code() == "MOCK_ALGO"

        # A second lookup must return the very same instance (one per code).
        algorithm2 = registry.get_algorithm("MOCK_ALGO")
        assert algorithm is algorithm2

    def test_get_algorithm_not_registered(self):
        """Looking up an unknown code raises ValueError."""
        registry = AlgorithmRegistry()

        with pytest.raises(ValueError, match="未注册的算法"):
            registry.get_algorithm("NONEXISTENT")

    def test_get_all_algorithms(self):
        """get_all_algorithms() returns a non-empty list for a default registry."""
        registry = AlgorithmRegistry()

        algorithms = registry.get_all_algorithms()

        assert isinstance(algorithms, list)
        assert len(algorithms) > 0

    def test_is_registered(self):
        """is_registered() is True for a default algorithm and False for unknown codes."""
        registry = AlgorithmRegistry()

        algorithms = registry.get_all_algorithms()
        # Fail with a clear message instead of an IndexError when no defaults exist.
        assert algorithms, "expected the registry to have default algorithms"
        assert registry.is_registered(algorithms[0]) is True

        assert registry.is_registered("NONEXISTENT") is False
|
|
|
|
|
|
class TestRuleEngine:
    """Tests for RuleEngine: validation, stage execution, planning and estimates."""

    @staticmethod
    def _make_rule(rule_id: str) -> Mock:
        """Build an enabled DetectionRule mock that uses REVENUE_INTEGRITY_CHECK."""
        rule = Mock(spec=DetectionRule)
        rule.rule_id = rule_id
        rule.algorithm_code = "REVENUE_INTEGRITY_CHECK"
        rule.parameters = {}
        rule.is_enabled = True
        return rule

    @staticmethod
    def _make_estimate_plan(execution_mode: ExecutionMode, max_level: int = 0) -> Mock:
        """Build a five-rule execution-plan mock for _estimate_execution_time."""
        plan = Mock()
        plan.total_rules = 5
        plan.execution_mode = execution_mode
        plan.max_level = max_level
        return plan

    def setup_method(self):
        """Create a fresh engine and two enabled rules before every test."""
        self.engine = RuleEngine()

        self.rule1 = self._make_rule("rule_1")
        self.rule2 = self._make_rule("rule_2")

    def test_init(self):
        """A default engine wires up all of its collaborators."""
        assert self.engine is not None
        assert self.engine.algorithm_registry is not None
        assert self.engine.dependency_resolver is not None
        assert self.engine.execution_planner is not None
        assert self.engine.result_processor is not None

    def test_init_with_db_session(self):
        """A database session passed to the constructor is kept on the engine."""
        mock_session = Mock()
        engine = RuleEngine(db_session=mock_session)

        assert engine.db_session == mock_session

    def test_validate_rules_empty(self):
        """An empty rule list is rejected."""
        with pytest.raises(ValueError, match="规则列表不能为空"):
            self.engine._validate_rules([])

    def test_validate_rules_unregistered_algorithm(self):
        """A rule that references an unregistered algorithm is rejected."""
        rule = Mock(spec=DetectionRule)
        rule.rule_id = "rule_1"
        rule.algorithm_code = "NONEXISTENT_ALGO"
        rule.is_enabled = True

        with pytest.raises(ValueError, match="未注册"):
            self.engine._validate_rules([rule])

    def test_validate_rules_disabled(self):
        """A disabled rule does not make validation raise."""
        rule = Mock(spec=DetectionRule)
        rule.rule_id = "rule_1"
        rule.algorithm_code = "REVENUE_INTEGRITY_CHECK"
        rule.is_enabled = False

        # Must not raise; per the original intent the engine only warns here.
        self.engine._validate_rules([rule])

    def test_validate_rules_valid(self):
        """A well-formed, enabled rule passes validation without raising."""
        self.engine._validate_rules([self.rule1])

    def test_are_stage_dependencies_met(self):
        """Dependencies are met only when every id in depends_on is in the given list."""
        stage = Mock()
        stage.depends_on = ["stage_0", "stage_1"]

        assert self.engine._are_stage_dependencies_met(stage, ["stage_0", "stage_1"]) is True
        assert self.engine._are_stage_dependencies_met(stage, ["stage_0"]) is False

    @pytest.mark.asyncio
    async def test_execute_single_rule(self):
        """_execute_single_rule runs one rule and returns its DetectionResult."""
        # Uses the same pytest-asyncio marker as the other async tests instead
        # of spinning up a private event loop with asyncio.run().
        result = await self.engine._execute_single_rule(self.rule1, "rule_1")

        assert result is not None
        assert isinstance(result, DetectionResult)
        assert result.rule_id == "rule_1"

    @pytest.mark.asyncio
    async def test_execute_sequential_stage(self):
        """A sequential stage yields one result per node."""
        stage = Mock()
        stage.execution_mode = ExecutionMode.SEQUENTIAL
        stage.nodes = [
            Mock(node_id="node_1", rule=self.rule1),
            Mock(node_id="node_2", rule=self.rule2),
        ]

        results = await self.engine._execute_sequential(stage)

        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_execute_parallel_stage(self):
        """A parallel stage yields one result per node."""
        stage = Mock()
        stage.execution_mode = ExecutionMode.PARALLEL
        stage.nodes = [
            Mock(node_id="node_1", rule=self.rule1),
            Mock(node_id="node_2", rule=self.rule2),
        ]

        results = await self.engine._execute_parallel(stage, max_concurrent=2)

        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_execute_detection(self):
        """A plan with no stages produces no results."""
        plan = Mock()
        plan.stages = []

        results = await self.engine._execute_detection(plan, max_concurrent=2)

        assert results == []

    def test_get_algorithm_info(self):
        """get_algorithm_info() returns a non-empty dict of per-algorithm metadata."""
        info = self.engine.get_algorithm_info()

        assert isinstance(info, dict)
        assert len(info) > 0

        # Check the revenue-integrity algorithm's metadata when it is present.
        if "REVENUE_INTEGRITY_CHECK" in info:
            algo_info = info["REVENUE_INTEGRITY_CHECK"]
            assert "name" in algo_info
            assert "description" in algo_info
            assert "code" in algo_info
            assert algo_info["code"] == "REVENUE_INTEGRITY_CHECK"

    def test_validate_execution_plan_valid(self):
        """A single-stage plan with a matching rule count is valid."""
        plan = Mock()
        plan.stages = [
            Mock(stage_id="stage_0", depends_on=[], rule_count=1),
        ]
        plan.total_rules = 1

        is_valid, error = self.engine.validate_execution_plan(plan)

        assert is_valid is True
        assert error is None

    def test_validate_execution_plan_empty_stages(self):
        """A plan with no stages is invalid."""
        plan = Mock()
        plan.stages = []
        plan.total_rules = 0

        is_valid, error = self.engine.validate_execution_plan(plan)

        assert is_valid is False
        assert "没有阶段" in error

    def test_validate_execution_plan_invalid_dependencies(self):
        """A plan whose stage dependencies are unmet is invalid."""
        plan = Mock()
        plan.stages = [
            Mock(stage_id="stage_0", depends_on=[], rule_count=1),
            Mock(stage_id="stage_1", depends_on=["stage_0"], rule_count=1),
        ]
        plan.total_rules = 2

        # Force the dependency check to fail; patch.object restores it on exit.
        with patch.object(self.engine, "_are_stage_dependencies_met", return_value=False):
            is_valid, error = self.engine.validate_execution_plan(plan)

        assert is_valid is False
        assert "依赖不满足" in error

    def test_validate_execution_plan_inconsistent_count(self):
        """A plan whose stage rule counts disagree with total_rules is invalid."""
        plan = Mock()
        plan.stages = [
            Mock(stage_id="stage_0", depends_on=[], rule_count=1),
        ]
        plan.total_rules = 2  # deliberately inconsistent with the stage counts

        # Make dependencies pass so only the count mismatch can fail validation.
        with patch.object(self.engine, "_are_stage_dependencies_met", return_value=True):
            is_valid, error = self.engine.validate_execution_plan(plan)

        assert is_valid is False
        assert "规则数量不一致" in error

    @pytest.mark.asyncio
    async def test_dry_run(self):
        """dry_run() reports the plan, summary and algorithm info without executing."""
        dry_run_result = await self.engine.dry_run(
            entity_id="entity_1",
            entity_type="type_1",
            period="2024-01",
            rules=[self.rule1, self.rule2],
            execution_mode=ExecutionMode.HYBRID,
        )

        assert "dry_run" in dry_run_result
        assert dry_run_result["dry_run"] is True
        assert "execution_plan" in dry_run_result
        assert "summary" in dry_run_result
        assert "algorithm_info" in dry_run_result

    def test_estimate_execution_time(self):
        """A sequential plan gets a positive float estimate."""
        plan = self._make_estimate_plan(ExecutionMode.SEQUENTIAL)

        estimated = self.engine._estimate_execution_time(plan)

        assert estimated > 0
        assert isinstance(estimated, float)

    def test_estimate_execution_time_parallel(self):
        """A parallel plan is estimated positive and no slower than sequential."""
        parallel_plan = self._make_estimate_plan(ExecutionMode.PARALLEL)
        sequential_plan = self._make_estimate_plan(ExecutionMode.SEQUENTIAL)

        parallel_estimate = self.engine._estimate_execution_time(parallel_plan)
        sequential_estimate = self.engine._estimate_execution_time(sequential_plan)

        assert parallel_estimate > 0
        # The same rules run in parallel should not be estimated slower.
        assert parallel_estimate <= sequential_estimate

    def test_estimate_execution_time_hybrid(self):
        """A hybrid plan with parallel stages gets a positive estimate."""
        plan = self._make_estimate_plan(ExecutionMode.HYBRID, max_level=2)
        plan.stages = [
            Mock(execution_mode=ExecutionMode.PARALLEL),
            Mock(execution_mode=ExecutionMode.PARALLEL),
        ]

        estimated = self.engine._estimate_execution_time(plan)

        assert estimated > 0
|
|
|
|
|
|
class TestGetRuleEngine:
    """Checks for the module-level get_rule_engine() accessor."""

    def test_get_rule_engine(self):
        """Repeated calls hand back one shared engine object."""
        first, second = get_rule_engine(), get_rule_engine()

        # Identity, not mere equality: the accessor behaves as a singleton.
        assert first is second

    def test_get_rule_engine_initialization(self):
        """The accessor yields a real RuleEngine instance."""
        shared_engine = get_rule_engine()

        assert shared_engine is not None
        assert isinstance(shared_engine, RuleEngine)
|