529 lines
19 KiB
Python
529 lines
19 KiB
Python
"""
检测报告查看功能测试用例

覆盖场景:
1. 查询检测结果列表(按任务、实体、风险等级过滤)
2. 获取检测结果详情
3. 获取检测结果汇总报告
4. 风险分布统计
5. 报告导出功能(模拟;本文件暂未包含该场景的用例)
"""
|
||
import pytest
|
||
from datetime import datetime, date, timedelta
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select, and_, func
|
||
|
||
from app.models.risk_detection import (
|
||
DetectionResult,
|
||
DetectionTask,
|
||
RiskLevel,
|
||
)
|
||
from app.api.v1.endpoints.risk_detection import router as risk_detection_router
|
||
from fastapi.testclient import TestClient
|
||
from app.main import app
|
||
|
||
|
||
class TestDetectionReport:
    """Tests for viewing detection reports via the risk-detection endpoint functions.

    The endpoint coroutines are called directly with a mocked ``AsyncSession``;
    each test configures what the awaited ``db.execute(...)`` result returns.
    """

    @pytest.fixture
    def mock_db_session(self):
        """Create a mocked async database session.

        ``execute`` is an ``AsyncMock`` whose awaited result is a plain
        ``MagicMock``. The default ``return_value`` of an ``AsyncMock`` is
        itself an ``AsyncMock``, so without this ``result.scalars()`` and
        ``result.scalar_one_or_none()`` would return un-awaited coroutines
        instead of the values the tests configure.
        """
        session = AsyncMock(spec=AsyncSession)
        # The awaited result exposes synchronous accessors
        # (scalars().all(), scalar_one_or_none()), so it must be a MagicMock.
        session.execute = AsyncMock(return_value=MagicMock())
        return session

    @pytest.fixture
    def mock_detection_results(self):
        """Return three sample ``DetectionResult`` objects.

        - id=1: task_test_001 / ZB_TEST_001, HIGH, score 85.0, two evidence items
        - id=2: task_test_001 / ZB_TEST_001, MEDIUM, score 65.0, one evidence item
        - id=3: task_test_002 / ZB_TEST_002, LOW, score 30.0, no evidence
        """
        return [
            DetectionResult(
                id=1,
                task_id="task_test_001",
                rule_id="rule_revenue_001",
                entity_id="ZB_TEST_001",
                entity_type="streamer",
                risk_level=RiskLevel.HIGH,
                risk_score=85.0,
                risk_category="收入完整性",
                description="平台充值与申报收入存在重大不匹配",
                suggestion="请核实收入申报的准确性,补缴相应税款",
                risk_data={
                    "recharge_amount": 100000.0,
                    "declared_amount": 80000.0,
                    "discrepancy": 20000.0,
                    "discrepancy_rate": 0.2,
                },
                evidence=[
                    {
                        "type": "bank_transaction",
                        "description": "银行流水显示收入100,000元",
                        "data": {"amount": 100000.0, "date": "2024-01-15"},
                    },
                    {
                        "type": "tax_declaration",
                        "description": "申报收入仅80,000元",
                        "data": {"amount": 80000.0, "date": "2024-02-15"},
                    },
                ],
                status="active",
                is_false_positive=False,
                detected_at=datetime(2024, 2, 20, 10, 30, 0),
            ),
            DetectionResult(
                id=2,
                task_id="task_test_001",
                rule_id="rule_private_001",
                entity_id="ZB_TEST_001",
                entity_type="streamer",
                risk_level=RiskLevel.MEDIUM,
                risk_score=65.0,
                risk_category="私户收款",
                description="发现使用私人账户进行资金转账",
                suggestion="建议规范资金管理,使用对公账户",
                risk_data={
                    "private_account_amount": 50000.0,
                    "total_transactions": 100000.0,
                    "private_account_ratio": 0.5,
                },
                evidence=[
                    {
                        "type": "bank_transaction",
                        "description": "私人账户转账记录",
                        "data": {"amount": 50000.0, "account": "6222021234567890123"},
                    }
                ],
                status="active",
                is_false_positive=False,
                detected_at=datetime(2024, 2, 20, 11, 0, 0),
            ),
            DetectionResult(
                id=3,
                task_id="task_test_002",
                rule_id="rule_expense_001",
                entity_id="ZB_TEST_002",
                entity_type="streamer",
                risk_level=RiskLevel.LOW,
                risk_score=30.0,
                risk_category="费用异常",
                description="费用增长略高于平均水平",
                suggestion="关注费用增长趋势",
                risk_data={
                    "current_expense": 50000.0,
                    "average_expense": 45000.0,
                    "growth_rate": 0.11,
                },
                evidence=[],
                status="active",
                is_false_positive=False,
                detected_at=datetime(2024, 2, 21, 14, 20, 0),
            ),
        ]

    @pytest.mark.asyncio
    async def test_list_results_with_task_filter(self, mock_db_session, mock_detection_results):
        """Listing results filtered by task_id returns only that task's results."""
        from app.api.v1.endpoints.risk_detection import list_results

        mock_db_session.execute.return_value.scalars.return_value.all.return_value = [
            r for r in mock_detection_results if r.task_id == "task_test_001"
        ]

        results = await list_results(
            task_id="task_test_001",
            db=mock_db_session
        )

        assert len(results) == 2
        assert all(r["task_id"] == "task_test_001" for r in results)

    @pytest.mark.asyncio
    async def test_list_results_with_entity_filter(self, mock_db_session, mock_detection_results):
        """Listing results filtered by entity_id returns only that entity's results."""
        from app.api.v1.endpoints.risk_detection import list_results

        mock_db_session.execute.return_value.scalars.return_value.all.return_value = [
            r for r in mock_detection_results if r.entity_id == "ZB_TEST_001"
        ]

        results = await list_results(
            entity_id="ZB_TEST_001",
            db=mock_db_session
        )

        assert len(results) == 2
        # NOTE(review): the get_result tests read entity_id from the nested
        # result["result"] dict, while this reads it at the top level —
        # confirm which shape list_results actually returns.
        assert all(r["entity_id"] == "ZB_TEST_001" for r in results)

    @pytest.mark.asyncio
    async def test_list_results_with_risk_level_filter(self, mock_db_session, mock_detection_results):
        """Listing results filtered by risk level returns only matching results."""
        from app.api.v1.endpoints.risk_detection import list_results

        mock_db_session.execute.return_value.scalars.return_value.all.return_value = [
            r for r in mock_detection_results if r.risk_level == RiskLevel.HIGH
        ]

        results = await list_results(
            risk_level="HIGH",
            db=mock_db_session
        )

        assert len(results) == 1
        # NOTE(review): risk fields are read from a nested "result" dict here,
        # unlike the top-level access in the task/entity filter tests — confirm.
        assert results[0]["result"]["risk_level"] == "HIGH"
        assert results[0]["result"]["risk_score"] == 85.0

    @pytest.mark.asyncio
    async def test_list_results_with_multiple_filters(self, mock_db_session, mock_detection_results):
        """Combining task_id, entity_id and limit filters returns the intersection."""
        from app.api.v1.endpoints.risk_detection import list_results

        mock_db_session.execute.return_value.scalars.return_value.all.return_value = [
            r for r in mock_detection_results
            if r.task_id == "task_test_001" and r.entity_id == "ZB_TEST_001"
        ]

        results = await list_results(
            task_id="task_test_001",
            entity_id="ZB_TEST_001",
            limit=100,
            db=mock_db_session
        )

        assert len(results) == 2
        assert all(r["task_id"] == "task_test_001" and r["entity_id"] == "ZB_TEST_001" for r in results)

    @pytest.mark.asyncio
    async def test_get_result_detail(self, mock_db_session, mock_detection_results):
        """get_result serializes a single detection result with its details."""
        from app.api.v1.endpoints.risk_detection import get_result

        mock_db_session.execute.return_value.scalar_one_or_none.return_value = mock_detection_results[0]

        result = await get_result(
            result_id=1,
            db=mock_db_session
        )

        assert result["id"] == 1
        assert result["task_id"] == "task_test_001"
        assert result["rule_id"] == "rule_revenue_001"
        assert result["result"]["entity_id"] == "ZB_TEST_001"
        assert result["result"]["risk_level"] == "HIGH"
        assert result["result"]["risk_score"] == 85.0
        assert result["result"]["description"] == "平台充值与申报收入存在重大不匹配"
        assert len(result["result"]["evidence"]) == 2

    @pytest.mark.asyncio
    async def test_get_result_not_found(self, mock_db_session):
        """get_result raises HTTP 404 when no row matches the id."""
        from app.api.v1.endpoints.risk_detection import get_result
        from fastapi import HTTPException

        mock_db_session.execute.return_value.scalar_one_or_none.return_value = None

        with pytest.raises(HTTPException) as exc_info:
            await get_result(result_id=9999, db=mock_db_session)
        assert exc_info.value.status_code == 404

    @pytest.mark.asyncio
    async def test_get_detection_summary(self, mock_db_session, mock_detection_results):
        """The summary aggregates counts, distribution, average score and advice."""
        from app.api.v1.endpoints.risk_detection import get_detection_summary

        mock_db_session.execute.return_value.scalars.return_value.all.return_value = [
            r for r in mock_detection_results if r.entity_id == "ZB_TEST_001"
        ]

        summary = await get_detection_summary(
            entity_id="ZB_TEST_001",
            entity_type="streamer",
            period="2024-01",
            db=mock_db_session
        )

        assert summary["entity_id"] == "ZB_TEST_001"
        assert summary["entity_type"] == "streamer"
        assert summary["period"] == "2024-01"
        assert summary["total_detections"] == 2
        assert "HIGH" in summary["risk_distribution"]
        assert summary["risk_distribution"]["HIGH"] == 1
        assert "MEDIUM" in summary["risk_distribution"]
        assert summary["risk_distribution"]["MEDIUM"] == 1
        # (85.0 + 65.0) / 2; approx tolerates rounding in the implementation.
        assert summary["avg_risk_score"] == pytest.approx(75.0)
        assert summary["high_risk_count"] == 1
        assert len(summary["recommendations"]) > 0

    @pytest.mark.asyncio
    async def test_get_detection_summary_no_results(self, mock_db_session):
        """An entity with no detections yields an empty summary and a 'no risk' note."""
        from app.api.v1.endpoints.risk_detection import get_detection_summary

        mock_db_session.execute.return_value.scalars.return_value.all.return_value = []

        summary = await get_detection_summary(
            entity_id="ZB_TEST_999",
            entity_type="streamer",
            period="2024-01",
            db=mock_db_session
        )

        assert summary["entity_id"] == "ZB_TEST_999"
        assert summary["total_detections"] == 0
        assert summary["risk_distribution"] == {}
        assert summary["avg_risk_score"] == 0.0
        assert summary["high_risk_count"] == 0
        assert len(summary["recommendations"]) > 0
        assert "未发现明显风险" in summary["recommendations"][0]

    @pytest.mark.asyncio
    async def test_get_detection_summary_critical_risk(self, mock_db_session):
        """A CRITICAL result is counted and leads the recommendations."""
        from app.api.v1.endpoints.risk_detection import get_detection_summary

        critical_results = [
            DetectionResult(
                id=1,
                task_id="task_test_001",
                rule_id="rule_revenue_001",
                entity_id="ZB_TEST_001",
                entity_type="streamer",
                risk_level=RiskLevel.CRITICAL,
                risk_score=95.0,
                risk_category="收入完整性",
                description="重大收入隐瞒",
                suggestion="立即处理",
                risk_data={},
                evidence=[],
                status="active",
                is_false_positive=False,
                # Fixed timestamp keeps the fixture deterministic.
                detected_at=datetime(2024, 1, 15, 9, 0, 0),
            )
        ]

        mock_db_session.execute.return_value.scalars.return_value.all.return_value = critical_results

        summary = await get_detection_summary(
            entity_id="ZB_TEST_001",
            entity_type="streamer",
            period="2024-01",
            db=mock_db_session
        )

        assert summary["risk_distribution"]["CRITICAL"] == 1
        assert "极高风险项目" in summary["recommendations"][0]
        assert summary["high_risk_count"] == 1

    @pytest.mark.asyncio
    async def test_risk_distribution_statistics(self, mock_db_session, mock_detection_results):
        """Distribution, average score and high-risk count over all three samples."""
        from app.api.v1.endpoints.risk_detection import get_detection_summary

        mock_db_session.execute.return_value.scalars.return_value.all.return_value = mock_detection_results

        summary = await get_detection_summary(
            entity_id="ZB_TEST_ALL",
            entity_type="streamer",
            period="2024-01",
            db=mock_db_session
        )

        assert summary["total_detections"] == 3

        risk_dist = summary["risk_distribution"]
        assert risk_dist.get("HIGH", 0) == 1
        assert risk_dist.get("MEDIUM", 0) == 1
        assert risk_dist.get("LOW", 0) == 1

        expected_avg = (85.0 + 65.0 + 30.0) / 3
        assert summary["avg_risk_score"] == pytest.approx(expected_avg)

        # High-risk count covers CRITICAL + HIGH; only id=1 is HIGH here.
        assert summary["high_risk_count"] == 1

    @pytest.mark.asyncio
    async def test_list_results_with_limit(self, mock_db_session, mock_detection_results):
        """list_results accepts a limit and issues a query."""
        from app.api.v1.endpoints.risk_detection import list_results

        mock_db_session.execute.return_value.scalars.return_value.all.return_value = mock_detection_results

        await list_results(
            limit=2,
            db=mock_db_session
        )

        # The mock ignores SQL LIMIT, so this only verifies that the query was
        # actually awaited; it does not check that the limit was applied.
        mock_db_session.execute.assert_awaited()

    @pytest.mark.asyncio
    async def test_result_evidence_structure(self, mock_db_session, mock_detection_results):
        """Each evidence item carries type, description and data."""
        from app.api.v1.endpoints.risk_detection import get_result

        mock_db_session.execute.return_value.scalar_one_or_none.return_value = mock_detection_results[0]

        result = await get_result(
            result_id=1,
            db=mock_db_session
        )

        evidence = result["result"]["evidence"]
        assert len(evidence) == 2

        for ev in evidence:
            assert "type" in ev
            assert "description" in ev
            assert "data" in ev
            assert ev["type"] in ["bank_transaction", "tax_declaration"]

    @pytest.mark.asyncio
    async def test_result_risk_data_structure(self, mock_db_session, mock_detection_results):
        """risk_data exposes the revenue-discrepancy fields as numbers."""
        from app.api.v1.endpoints.risk_detection import get_result

        mock_db_session.execute.return_value.scalar_one_or_none.return_value = mock_detection_results[0]

        result = await get_result(
            result_id=1,
            db=mock_db_session
        )

        risk_data = result["result"]["risk_data"]
        assert "recharge_amount" in risk_data
        assert "declared_amount" in risk_data
        assert "discrepancy" in risk_data
        assert "discrepancy_rate" in risk_data

        assert isinstance(risk_data["recharge_amount"], (int, float))
        assert isinstance(risk_data["discrepancy_rate"], (int, float))

    @pytest.mark.asyncio
    async def test_false_positive_filtering(self, mock_db_session, mock_detection_results):
        """Listing with a false-positive row present (filtering not asserted)."""
        from app.api.v1.endpoints.risk_detection import list_results

        results_with_fp = mock_detection_results + [
            DetectionResult(
                id=4,
                task_id="task_test_003",
                rule_id="rule_test_001",
                entity_id="ZB_TEST_003",
                entity_type="streamer",
                risk_level=RiskLevel.LOW,
                risk_score=20.0,
                risk_category="测试",
                description="误报结果",
                suggestion="忽略",
                risk_data={},
                evidence=[],
                status="active",
                is_false_positive=True,
                # Fixed timestamp keeps the fixture deterministic.
                detected_at=datetime(2024, 2, 22, 9, 0, 0),
            )
        ]

        mock_db_session.execute.return_value.scalars.return_value.all.return_value = results_with_fp

        results = await list_results(
            db=mock_db_session
        )

        # The current implementation may not filter false positives; this only
        # checks that the three genuine results are still returned.
        assert len(results) >= 3
|
||
|
||
|
||
class TestReportAPIRoutes:
    """HTTP-level tests for the risk-detection report routes."""

    @pytest.fixture
    def client(self):
        """Return a ``TestClient`` bound to the application."""
        test_client = TestClient(app)
        return test_client

    def test_get_algorithms(self, client):
        """GET /algorithms lists at least five algorithms with the expected fields."""
        resp = client.get("/api/v1/risk-detection/algorithms")
        assert resp.status_code == 200

        algorithms = resp.json()
        # At least five algorithms are expected.
        assert len(algorithms) >= 5

        # Every algorithm entry must expose these keys.
        for algorithm in algorithms:
            for key in ("code", "name", "description", "parameters"):
                assert key in algorithm

        # Locate the revenue-integrity algorithm explicitly.
        revenue_algo = None
        for algorithm in algorithms:
            if algorithm["code"] == "REVENUE_INTEGRITY_CHECK":
                revenue_algo = algorithm
                break
        assert revenue_algo is not None
        assert revenue_algo["name"] == "收入完整性检测"

    def test_get_rules(self, client):
        """GET /rules returns a list of rules with the expected fields."""
        resp = client.get("/api/v1/risk-detection/rules")
        assert resp.status_code == 200

        rules = resp.json()
        assert isinstance(rules, list)

        for rule in rules:
            for key in ("rule_id", "rule_name", "algorithm_code", "is_enabled"):
                assert key in rule
|
||
|
||
|
||
if __name__ == "__main__":
    # pytest.main returns an exit code; propagate it so a direct run of this
    # file reports test failures to the shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|