"""Unit tests for the rule engine."""

from typing import List
from unittest.mock import Mock, patch, MagicMock, AsyncMock

import pytest

from app.models.risk_detection import DetectionRule, RiskLevel
from app.services.risk_detection.engine.rule_engine import (
    AlgorithmRegistry,
    RuleEngine,
    get_rule_engine,
)
from app.services.risk_detection.engine.execution_plan import ExecutionMode
from app.services.risk_detection.algorithms.base import (
    RiskDetectionAlgorithm,
    DetectionContext,
    DetectionResult,
)


# Algorithm code that the default registry is expected to provide. The rule
# fixtures and several tests below require it to be registered.
DEFAULT_ALGORITHM_CODE = "REVENUE_INTEGRITY_CHECK"


class MockAlgorithm(RiskDetectionAlgorithm):
    """Stub detection algorithm used by the tests.

    Always returns a fixed LOW-risk ``DetectionResult`` (score 25.0) for the
    task/rule ids found in the supplied context.
    """

    def __init__(self, algorithm_code: str = "MOCK_ALGO"):
        super().__init__()
        self._code = algorithm_code

    def get_algorithm_code(self) -> str:
        return self._code

    def get_algorithm_name(self) -> str:
        return f"模拟算法{self._code}"

    async def _do_detect(self, context: DetectionContext) -> DetectionResult:
        return DetectionResult(
            task_id=context.task_id,
            rule_id=context.rule_id,
            entity_id="test_entity",
            entity_type="test_type",
            risk_level=RiskLevel.LOW,
            risk_score=25.0,
            description="模拟检测结果",
            suggestion="模拟建议",
        )


class TestAlgorithmRegistry:
    """Tests for AlgorithmRegistry."""

    @staticmethod
    def _empty_registry() -> AlgorithmRegistry:
        """Return a registry with the auto-registered defaults removed."""
        registry = AlgorithmRegistry()
        registry._algorithms.clear()
        registry._instances.clear()
        return registry

    def test_init(self):
        """A new registry auto-registers the default algorithms."""
        registry = AlgorithmRegistry()

        assert len(registry._algorithms) > 0

    def test_register(self):
        """Registering a class stores it under its algorithm code."""
        registry = self._empty_registry()

        registry.register(MockAlgorithm)

        assert "MOCK_ALGO" in registry._algorithms
        assert registry._algorithms["MOCK_ALGO"] is MockAlgorithm

    def test_register_without_get_algorithm_code(self):
        """Registering a class lacking get_algorithm_code raises ValueError."""
        registry = AlgorithmRegistry()

        class InvalidAlgorithm:
            pass

        with pytest.raises(ValueError, match="必须实现 get_algorithm_code"):
            registry.register(InvalidAlgorithm)

    def test_get_algorithm(self):
        """get_algorithm returns an instance and reuses it on later calls."""
        registry = self._empty_registry()
        registry.register(MockAlgorithm)

        algorithm = registry.get_algorithm("MOCK_ALGO")

        assert isinstance(algorithm, MockAlgorithm)
        assert algorithm.get_algorithm_code() == "MOCK_ALGO"
        # A second lookup must return the same cached instance.
        assert registry.get_algorithm("MOCK_ALGO") is algorithm

    def test_get_algorithm_not_registered(self):
        """Looking up an unknown code raises ValueError."""
        registry = AlgorithmRegistry()

        with pytest.raises(ValueError, match="未注册的算法"):
            registry.get_algorithm("NONEXISTENT")

    def test_get_all_algorithms(self):
        """get_all_algorithms returns a non-empty list."""
        registry = AlgorithmRegistry()

        algorithms = registry.get_all_algorithms()

        assert isinstance(algorithms, list)
        assert len(algorithms) > 0

    def test_is_registered(self):
        """is_registered is True for a default algorithm, False otherwise."""
        registry = AlgorithmRegistry()

        first_algo = registry.get_all_algorithms()[0]

        assert registry.is_registered(first_algo) is True
        assert registry.is_registered("NONEXISTENT") is False


class TestRuleEngine:
    """Tests for RuleEngine."""

    @staticmethod
    def _make_rule(
        rule_id: str,
        algorithm_code: str = DEFAULT_ALGORITHM_CODE,
        is_enabled: bool = True,
    ) -> Mock:
        """Build a DetectionRule mock with the attributes the engine reads."""
        rule = Mock(spec=DetectionRule)
        rule.rule_id = rule_id
        rule.algorithm_code = algorithm_code
        rule.parameters = {}
        rule.is_enabled = is_enabled
        return rule

    def setup_method(self):
        """Create a fresh engine and two enabled rules before each test."""
        self.engine = RuleEngine()
        self.rule1 = self._make_rule("rule_1")
        self.rule2 = self._make_rule("rule_2")

    def _two_node_stage(self, mode: ExecutionMode) -> Mock:
        """Build a stage mock holding rule1 and rule2 as its nodes."""
        stage = Mock()
        stage.execution_mode = mode
        stage.nodes = [
            Mock(node_id="node_1", rule=self.rule1),
            Mock(node_id="node_2", rule=self.rule2),
        ]
        return stage

    def test_init(self):
        """The engine wires up all of its collaborators."""
        assert self.engine is not None
        assert self.engine.algorithm_registry is not None
        assert self.engine.dependency_resolver is not None
        assert self.engine.execution_planner is not None
        assert self.engine.result_processor is not None

    def test_init_with_db_session(self):
        """A db_session passed to the constructor is stored on the engine."""
        mock_session = Mock()

        engine = RuleEngine(db_session=mock_session)

        assert engine.db_session == mock_session

    def test_validate_rules_empty(self):
        """An empty rule list is rejected."""
        with pytest.raises(ValueError, match="规则列表不能为空"):
            self.engine._validate_rules([])

    def test_validate_rules_unregistered_algorithm(self):
        """A rule referencing an unknown algorithm is rejected."""
        rule = self._make_rule("rule_1", algorithm_code="NONEXISTENT_ALGO")

        with pytest.raises(ValueError, match="未注册"):
            self.engine._validate_rules([rule])

    def test_validate_rules_disabled(self):
        """A disabled rule does not raise (the engine only warns)."""
        rule = self._make_rule("rule_1", is_enabled=False)

        self.engine._validate_rules([rule])

    def test_validate_rules_valid(self):
        """A valid, enabled rule passes validation without raising."""
        self.engine._validate_rules([self.rule1])

    def test_are_stage_dependencies_met(self):
        """Dependencies are met only when every depends_on stage completed."""
        stage = Mock()
        stage.depends_on = ["stage_0", "stage_1"]

        assert self.engine._are_stage_dependencies_met(stage, ["stage_0", "stage_1"]) is True
        assert self.engine._are_stage_dependencies_met(stage, ["stage_0"]) is False

    @pytest.mark.asyncio
    async def test_execute_single_rule(self):
        """Executing one rule yields a DetectionResult tagged with that rule id."""
        result = await self.engine._execute_single_rule(self.rule1, "rule_1")

        assert isinstance(result, DetectionResult)
        assert result.rule_id == "rule_1"

    @pytest.mark.asyncio
    async def test_execute_sequential_stage(self):
        """A sequential stage returns one result per node."""
        stage = self._two_node_stage(ExecutionMode.SEQUENTIAL)

        results = await self.engine._execute_sequential(stage)

        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_execute_parallel_stage(self):
        """A parallel stage returns one result per node."""
        stage = self._two_node_stage(ExecutionMode.PARALLEL)

        results = await self.engine._execute_parallel(stage, max_concurrent=2)

        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_execute_detection(self):
        """A plan with no stages produces no results."""
        plan = Mock()
        plan.stages = []

        results = await self.engine._execute_detection(plan, max_concurrent=2)

        assert results == []

    def test_get_algorithm_info(self):
        """Algorithm info includes name/description/code for the default algorithm."""
        info = self.engine.get_algorithm_info()

        assert isinstance(info, dict)
        assert len(info) > 0
        # Other tests in this class already require the default algorithm to
        # be registered, so check its metadata unconditionally instead of
        # letting this test pass vacuously when the entry is missing.
        assert DEFAULT_ALGORITHM_CODE in info
        algo_info = info[DEFAULT_ALGORITHM_CODE]
        assert "name" in algo_info
        assert "description" in algo_info
        assert "code" in algo_info
        assert algo_info["code"] == DEFAULT_ALGORITHM_CODE

    def test_validate_execution_plan_valid(self):
        """A consistent single-stage plan is valid."""
        plan = Mock()
        plan.stages = [Mock(stage_id="stage_0", depends_on=[], rule_count=1)]
        plan.total_rules = 1

        is_valid, error = self.engine.validate_execution_plan(plan)

        assert is_valid is True
        assert error is None

    def test_validate_execution_plan_empty_stages(self):
        """A plan without stages is invalid."""
        plan = Mock()
        plan.stages = []
        plan.total_rules = 0

        is_valid, error = self.engine.validate_execution_plan(plan)

        assert is_valid is False
        assert "没有阶段" in error

    def test_validate_execution_plan_invalid_dependencies(self):
        """A plan whose stage dependencies are unmet is invalid."""
        plan = Mock()
        plan.stages = [
            Mock(stage_id="stage_0", depends_on=[], rule_count=1),
            Mock(stage_id="stage_1", depends_on=["stage_0"], rule_count=1),
        ]
        plan.total_rules = 2

        # Force the dependency check to fail for the duration of the call only.
        with patch.object(self.engine, "_are_stage_dependencies_met", return_value=False):
            is_valid, error = self.engine.validate_execution_plan(plan)

        assert is_valid is False
        assert "依赖不满足" in error

    def test_validate_execution_plan_inconsistent_count(self):
        """A plan whose stage rule counts disagree with total_rules is invalid."""
        plan = Mock()
        plan.stages = [Mock(stage_id="stage_0", depends_on=[], rule_count=1)]
        plan.total_rules = 2  # deliberately inconsistent with the stage count

        with patch.object(self.engine, "_are_stage_dependencies_met", return_value=True):
            is_valid, error = self.engine.validate_execution_plan(plan)

        assert is_valid is False
        assert "规则数量不一致" in error

    @pytest.mark.asyncio
    async def test_dry_run(self):
        """dry_run reports the plan, a summary and algorithm info without executing."""
        dry_run_result = await self.engine.dry_run(
            entity_id="entity_1",
            entity_type="type_1",
            period="2024-01",
            rules=[self.rule1, self.rule2],
            execution_mode=ExecutionMode.HYBRID,
        )

        assert "dry_run" in dry_run_result
        assert dry_run_result["dry_run"] is True
        for key in ("execution_plan", "summary", "algorithm_info"):
            assert key in dry_run_result

    def test_estimate_execution_time(self):
        """A sequential plan yields a positive float estimate."""
        plan = Mock()
        plan.total_rules = 5
        plan.execution_mode = ExecutionMode.SEQUENTIAL
        plan.max_level = 0

        estimated = self.engine._estimate_execution_time(plan)

        assert estimated > 0
        assert isinstance(estimated, float)

    def test_estimate_execution_time_parallel(self):
        """A parallel plan yields a positive estimate."""
        plan = Mock()
        plan.total_rules = 5
        plan.execution_mode = ExecutionMode.PARALLEL
        plan.max_level = 0

        estimated = self.engine._estimate_execution_time(plan)

        assert estimated > 0
        # TODO: the parallel estimate is expected to be shorter than the
        # sequential one; assert that once the estimation formula is pinned.

    def test_estimate_execution_time_hybrid(self):
        """A hybrid plan with parallel stages yields a positive estimate."""
        plan = Mock()
        plan.total_rules = 5
        plan.execution_mode = ExecutionMode.HYBRID
        plan.max_level = 2
        plan.stages = [
            Mock(execution_mode=ExecutionMode.PARALLEL),
            Mock(execution_mode=ExecutionMode.PARALLEL),
        ]

        estimated = self.engine._estimate_execution_time(plan)

        assert estimated > 0


class TestGetRuleEngine:
    """Tests for the module-level get_rule_engine accessor."""

    def test_get_rule_engine(self):
        """Repeated calls return the same engine instance."""
        engine1 = get_rule_engine()
        engine2 = get_rule_engine()

        assert engine1 is engine2

    def test_get_rule_engine_initialization(self):
        """The accessor returns a RuleEngine."""
        engine = get_rule_engine()

        assert engine is not None
        assert isinstance(engine, RuleEngine)