deep-risk/backend/tests/demo_test.py
2025-12-14 20:08:27 +08:00

267 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
收入完整性检测算法测试演示脚本(真实数据库版)
从真实数据库中读取数据,演示不同风险场景的检测逻辑
"""
import sys
import asyncio
from datetime import datetime, date, timedelta
from decimal import Decimal
# Put the backend root on sys.path so the `app.*` imports below resolve.
# This file lives at backend/tests/demo_test.py, so the backend root is the
# grandparent directory of the resolved file path (parents[1]). Deriving it
# from __file__ replaces a hard-coded, machine-specific absolute path.
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
# 导入相关模块
from app.services.risk_detection.algorithms.revenue_integrity import RevenueIntegrityAlgorithm
from app.services.risk_detection.algorithms.base import DetectionContext
from app.models.risk_detection import RiskLevel
from app.database import AsyncSessionLocal, engine
from sqlalchemy import select, func, and_, or_
from app.models.streamer import StreamerInfo, PlatformRecharge
from app.models.tax_declaration import TaxDeclaration
from app.models.contract import RevenueSharingContract
class RevenueIntegrityDemo:
    """Demo driver for the revenue-integrity detection algorithm against a real database.

    Each scenario looks up one streamer, prints the recharge, tax-declaration and
    revenue-sharing-contract data found for a monthly period, then runs
    ``RevenueIntegrityAlgorithm.detect`` and prints the result.
    """

    def __init__(self):
        self.algorithm = RevenueIntegrityAlgorithm()

    def print_header(self, title):
        """Print ``title`` framed by 80-character ``=`` separator lines."""
        print("\n" + "=" * 80)
        print(f" {title}")
        print("=" * 80)

    def print_result(self, result):
        """Print risk level, score and description of a detection result.

        ``result.risk_data`` is printed only when it is truthy.
        """
        print("\n【检测结果】")
        print(f" 风险等级: {result.risk_level}")
        print(f" 风险评分: {result.risk_score:.2f}")
        print(f" 风险描述: {result.description}")
        if result.risk_data:
            print(f" 详细数据: {result.risk_data}")

    async def demo_scenario_by_streamer(self, streamer_id: str, period: str, scenario_name: str):
        """Read one streamer's data for ``period`` from the database and run detection.

        :param streamer_id: streamer ID, matched against ``StreamerInfo.streamer_id``
        :param period: period to check, formatted ``YYYY-MM``
        :param scenario_name: title printed above this scenario's output
        :raises ValueError: if ``period`` is not in ``YYYY-MM`` form (raised after the
            streamer lookup, as before)
        """
        self.print_header(scenario_name)
        async with AsyncSessionLocal() as db_session:
            print("\n【查询数据】")
            print(f" 主播ID: {streamer_id}")
            print(f" 检测期间: {period}")

            streamer_result = await db_session.execute(
                select(StreamerInfo).where(StreamerInfo.streamer_id == streamer_id)
            )
            streamer = streamer_result.scalar_one_or_none()
            if not streamer:
                print(f"\n✗ 未找到主播信息: {streamer_id}")
                return
            print(f" 主播名称: {streamer.streamer_name}")
            print(f" 实体类型: {streamer.entity_type}")

            start_date, end_date = self._parse_period(period)
            await self._print_recharge_summary(db_session, streamer_id, start_date, end_date)
            await self._print_declaration_summary(db_session, streamer, period)
            await self._print_contract_summary(db_session, streamer_id, start_date, end_date)

            print("\n【执行检测】")
            context = DetectionContext(
                task_id=f"task_demo_{streamer_id}",
                rule_id="rule_demo_revenue_integrity",
                parameters={"streamer_id": streamer_id, "period": period},
                db_session=db_session,
            )
            result = await self.algorithm.detect(context)
            self.print_result(result)

    async def _print_recharge_summary(self, db_session, streamer_id, start_date, end_date):
        """Print count and CNY total of successful recharges from ``start_date`` through the
        whole of ``end_date`` (the last day of the period)."""
        # Half-open range [start_date, day after end_date). The old `<= end_date` bound
        # stopped at 00:00 on the period's last day, so (assuming recharge_time stores full
        # timestamps) recharges made later that day were silently excluded.
        next_period_start = end_date + timedelta(days=1)
        # NOTE(review): recharges are matched on PlatformRecharge.user_id == streamer_id;
        # confirm that is the intended join key.
        stmt = select(
            func.count(PlatformRecharge.id).label("count"),
            func.coalesce(func.sum(PlatformRecharge.actual_amount_cny), 0).label("total"),
        ).where(
            and_(
                PlatformRecharge.user_id == streamer_id,
                PlatformRecharge.recharge_time >= start_date,
                PlatformRecharge.recharge_time < next_period_start,
                PlatformRecharge.status == "success",
            )
        )
        summary = (await db_session.execute(stmt)).one()
        print(f" 充值总额: {summary.total:,.2f}")
        print(f" 充值记录数: {summary.count}")

    async def _print_declaration_summary(self, db_session, streamer, period):
        """Print total declared ``sales_revenue`` for ``period`` over all of the streamer's
        non-empty tax identifiers."""
        # A declaration may be filed under any of these identifiers.
        taxpayer_ids = [
            identifier
            for identifier in (
                streamer.tax_registration_no,
                streamer.unified_social_credit_code,
                streamer.id_card_no,
            )
            if identifier
        ]
        if not taxpayer_ids:
            print(" 申报收入: 0元无税号信息")
            print(" 申报记录数: 0")
            return
        stmt = select(TaxDeclaration).where(
            and_(
                TaxDeclaration.taxpayer_id.in_(taxpayer_ids),
                TaxDeclaration.tax_period == period,
            )
        )
        declarations = (await db_session.execute(stmt)).scalars().all()
        # Missing sales_revenue values count as 0.
        total_declaration = sum(d.sales_revenue or 0 for d in declarations)
        print(f" 申报收入: {total_declaration:,.2f}")
        print(f" 申报记录数: {len(declarations)}")

    async def _print_contract_summary(self, db_session, streamer_id, start_date, end_date):
        """Print split ratios of the latest-starting active contract overlapping the period.

        A contract overlaps when it starts on or before ``end_date`` and has either no end
        date or an end date on or after ``start_date``.
        """
        stmt = (
            select(RevenueSharingContract)
            .where(
                and_(
                    RevenueSharingContract.streamer_id == streamer_id,
                    RevenueSharingContract.contract_start_date <= end_date,
                    or_(
                        RevenueSharingContract.contract_end_date.is_(None),
                        RevenueSharingContract.contract_end_date >= start_date,
                    ),
                    RevenueSharingContract.contract_status == "active",
                )
            )
            .order_by(RevenueSharingContract.contract_start_date.desc())
            .limit(1)
        )
        contract = (await db_session.execute(stmt)).scalar_one_or_none()
        if contract:
            print(f" 分成比例: 主播{contract.streamer_ratio}%, 平台{contract.platform_ratio}%")
        else:
            print(" 分成比例: 无协议")

    def _parse_period(self, period: str):
        """Parse a ``YYYY-MM`` period into ``(first_day, last_day)`` datetimes at 00:00.

        :raises ValueError: for quarter periods (containing ``Q``), strings without ``-``,
            or values ``int``/``datetime`` reject.
        """
        if "-" not in period or "Q" in period:
            raise ValueError(f"不支持的期间格式: {period}")
        year, month = map(int, period.split("-"))
        start_date = datetime(year, month, 1)
        # Last day of the month = first day of the following month minus one day.
        if month == 12:
            next_month_start = datetime(year + 1, 1, 1)
        else:
            next_month_start = datetime(year, month + 1, 1)
        return start_date, next_month_start - timedelta(days=1)

    @staticmethod
    def _cell(value, width):
        """Left-align ``value`` in a ``width``-wide column; ``None`` renders as blank.

        Non-None values are formatted exactly as ``f"{value:<width}"`` would.
        """
        return format("" if value is None else value, f"<{width}")

    async def demo_scenario_1_normal(self):
        """Scenario 1: run detection on database data for a configured streamer/period."""
        # Edit these to a streamer ID and period that exist in your database.
        streamer_id = "TEST_001"
        period = "2024-01"
        await self.demo_scenario_by_streamer(streamer_id, period, "场景1: 收入完整性检测(数据库数据)")

    async def demo_scenario_2_under_reporting(self):
        """Scenario 2: run detection on database data for a configured streamer/period."""
        # Edit these to a streamer ID and period that exist in your database.
        streamer_id = "TEST_002"
        period = "2024-02"
        await self.demo_scenario_by_streamer(streamer_id, period, "场景2: 收入完整性检测(数据库数据)")

    async def demo_scenario_3_critical(self):
        """Scenario 3: run detection on database data for a configured streamer/period."""
        # Edit these to a streamer ID and period that exist in your database.
        streamer_id = "TEST_003"
        period = "2024-03"
        await self.demo_scenario_by_streamer(streamer_id, period, "场景3: 收入完整性检测(数据库数据)")

    async def demo_scenario_4_no_data(self):
        """Scenario 4: run detection on database data for a configured streamer/period."""
        # Edit these to a streamer ID and period that exist in your database.
        streamer_id = "TEST_004"
        period = "2024-04"
        await self.demo_scenario_by_streamer(streamer_id, period, "场景4: 收入完整性检测(数据库数据)")

    async def list_available_streamers(self):
        """Print up to 10 streamers from the database and return their IDs.

        :return: list of ``streamer_id`` values; empty when the table has no rows
        """
        self.print_header("可用主播列表")
        async with AsyncSessionLocal() as db_session:
            result = await db_session.execute(select(StreamerInfo).limit(10))
            streamers = result.scalars().all()
            if not streamers:
                print("\n数据库中没有找到主播信息")
                print("请先向数据库中插入测试数据")
                return []
            print("\n【数据库中的主播信息】")
            print(f"{'主播ID':<30} {'主播名称':<20} {'实体类型':<15} {'税号信息':<30}")
            print("-" * 95)
            available_streamers = []
            for streamer in streamers:
                tax_info = (
                    streamer.tax_registration_no
                    or streamer.unified_social_credit_code
                    or streamer.id_card_no
                    or ""
                )
                # _cell guards against None name/type, which would crash a bare `:<N` format.
                print(
                    f"{self._cell(streamer.streamer_id, 30)} {self._cell(streamer.streamer_name, 20)} "
                    f"{self._cell(streamer.entity_type, 15)} {self._cell(tax_info, 30)}"
                )
                available_streamers.append(streamer.streamer_id)
        print("\n【使用方法】")
        print("将上述任意一个主播ID复制到测试场景中例如")
        print(" streamer_id = 'ZB_XXXX' # 替换为上述任一主播ID")
        print("\n" + "=" * 80)
        return available_streamers

    async def run_all_demos(self):
        """List available streamers and print usage and risk-level reference text.

        Does not execute any scenario itself; returns early when no streamers exist.
        """
        print("\n" + "=" * 80)
        print(" 收入完整性检测算法演示(真实数据库版)")
        print(" 从数据库读取真实数据,演示不同的风险场景和检测逻辑")
        print("=" * 80)
        available_streamers = await self.list_available_streamers()
        if not available_streamers:
            print("\n✗ 没有可用的主播数据,无法执行演示")
            return
        print("\n【使用说明】")
        print("请修改脚本中的主播ID为上面列出的可用主播ID然后重新运行")
        print("例如:修改 demo_scenario_1_normal() 方法中的 streamer_id 变量")
        print("\n" + "=" * 80)
        print("\n【风险等级说明】")
        print(" NONE: 正常,无风险")
        print(" LOW: 低度风险,持续监控")
        print(" MEDIUM: 中度风险,需要核查")
        print(" HIGH: 高度风险,重点关注")
        print(" CRITICAL: 严重风险,立即处理")
        print("=" * 80 + "\n")
async def main():
    """Script entry point: run the currently selected demo scenario.

    To list the streamers available in the database (plus usage notes) instead,
    replace the call below with ``await runner.run_all_demos()``.
    """
    runner = RevenueIntegrityDemo()
    await runner.demo_scenario_2_under_reporting()


if __name__ == "__main__":
    asyncio.run(main())