267 lines
11 KiB
Python
267 lines
11 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
收入完整性检测算法测试演示脚本(真实数据库版)
|
||
|
||
从真实数据库中读取数据,演示不同风险场景的检测逻辑
|
||
"""
|
||
|
||
import sys
|
||
import asyncio
|
||
from datetime import datetime, date, timedelta
|
||
from decimal import Decimal
|
||
|
||
# 添加项目路径
|
||
sys.path.insert(0, "/Users/liulujian/Documents/code/deeprisk-claude-1/backend")
|
||
|
||
# 导入相关模块
|
||
from app.services.risk_detection.algorithms.revenue_integrity import RevenueIntegrityAlgorithm
|
||
from app.services.risk_detection.algorithms.base import DetectionContext
|
||
from app.models.risk_detection import RiskLevel
|
||
from app.database import AsyncSessionLocal, engine
|
||
from sqlalchemy import select, func, and_, or_
|
||
from app.models.streamer import StreamerInfo, PlatformRecharge
|
||
from app.models.tax_declaration import TaxDeclaration
|
||
from app.models.contract import RevenueSharingContract
|
||
|
||
|
||
class RevenueIntegrityDemo:
    """Demo driver for the revenue-integrity detection algorithm.

    For a given streamer and ``YYYY-MM`` period it reads the streamer record,
    successful platform recharges, tax declarations and the active
    revenue-sharing contract from the database, prints a summary of them,
    and then runs ``RevenueIntegrityAlgorithm.detect`` on the same inputs.
    """

    def __init__(self):
        self.algorithm = RevenueIntegrityAlgorithm()

    def print_header(self, title):
        """Print *title* framed by two 80-character ``=`` rules."""
        print("\n" + "=" * 80)
        print(f" {title}")
        print("=" * 80)

    def print_result(self, result):
        """Print a detection result.

        Reads ``risk_level``, ``risk_score``, ``description`` and, when it is
        truthy, ``risk_data`` from *result* (the object returned by
        ``self.algorithm.detect``).
        """
        print("\n【检测结果】")
        print(f" 风险等级: {result.risk_level}")
        print(f" 风险评分: {result.risk_score:.2f}")
        print(f" 风险描述: {result.description}")
        if result.risk_data:
            print(f" 详细数据: {result.risk_data}")

    @staticmethod
    def _cell(value, width):
        """Left-align *value* in a column *width* characters wide.

        ``None`` is rendered as ``-``. Formatting ``None`` with a width spec
        would otherwise raise ``TypeError``.
        """
        return format("-" if value is None else value, f"<{width}")

    async def demo_scenario_by_streamer(self, streamer_id: str, period: str, scenario_name: str):
        """Read one streamer's data for *period* from the database and run detection.

        :param streamer_id: streamer ID, matched against ``StreamerInfo.streamer_id``
        :param period: detection period in ``YYYY-MM`` form
        :param scenario_name: title printed above this scenario's output
        :raises ValueError: if *period* is not a ``YYYY-MM`` string
        """
        self.print_header(scenario_name)

        async with AsyncSessionLocal() as db_session:
            print("\n【查询数据】")
            print(f" 主播ID: {streamer_id}")
            print(f" 检测期间: {period}")

            streamer_stmt = select(StreamerInfo).where(StreamerInfo.streamer_id == streamer_id)
            streamer = (await db_session.execute(streamer_stmt)).scalar_one_or_none()

            if not streamer:
                print(f"\n✗ 未找到主播信息: {streamer_id}")
                return

            print(f" 主播名称: {streamer.streamer_name}")
            print(f" 实体类型: {streamer.entity_type}")

            start_date, end_date = self._parse_period(period)
            # end_date is 00:00 on the last day of the month. An inclusive
            # `<= end_date` filter on a timestamp would drop everything recorded
            # later that day, so recharges use a half-open range
            # [start_date, first day of the next month).
            next_period_start = end_date + timedelta(days=1)

            await self._print_recharge_summary(db_session, streamer_id, start_date, next_period_start)
            await self._print_declaration_summary(db_session, streamer, period)
            await self._print_contract(db_session, streamer_id, start_date, end_date)

            print("\n【执行检测】")

            context = DetectionContext(
                task_id=f"task_demo_{streamer_id}",
                rule_id="rule_demo_revenue_integrity",
                parameters={"streamer_id": streamer_id, "period": period},
                db_session=db_session,
            )

            result = await self.algorithm.detect(context)
            self.print_result(result)

    async def _print_recharge_summary(self, db_session, streamer_id, start_date, end_exclusive):
        """Print the count and total ``actual_amount_cny`` of successful recharges.

        Recharges are selected where ``start_date <= recharge_time < end_exclusive``.
        """
        # NOTE(review): this matches PlatformRecharge.user_id against the
        # streamer ID; confirm user_id is meant to carry the streamer ID.
        stmt = select(
            # Deliberately not labelled "count": Row is tuple-like, so
            # `row.count` would resolve to the Sequence.count() method instead
            # of this column. The row is unpacked positionally below.
            func.count(PlatformRecharge.id).label("record_count"),
            func.coalesce(func.sum(PlatformRecharge.actual_amount_cny), 0).label("total"),
        ).where(
            and_(
                PlatformRecharge.user_id == streamer_id,
                PlatformRecharge.recharge_time >= start_date,
                PlatformRecharge.recharge_time < end_exclusive,
                PlatformRecharge.status == "success",
            )
        )
        record_count, total = (await db_session.execute(stmt)).one()
        print(f" 充值总额: {total:,.2f}元")
        print(f" 充值记录数: {record_count}")

    async def _print_declaration_summary(self, db_session, streamer, period):
        """Print the total ``sales_revenue`` and the count of tax declarations for *period*.

        Declarations are matched on any non-empty taxpayer identifier of the
        streamer (tax registration no., unified social credit code, ID card no.).
        """
        taxpayer_ids = [
            tax_id
            for tax_id in (
                streamer.tax_registration_no,
                streamer.unified_social_credit_code,
                streamer.id_card_no,
            )
            if tax_id
        ]

        if not taxpayer_ids:
            print(" 申报收入: 0元(无税号信息)")
            print(" 申报记录数: 0")
            return

        stmt = select(TaxDeclaration).where(
            and_(
                TaxDeclaration.taxpayer_id.in_(taxpayer_ids),
                TaxDeclaration.tax_period == period,
            )
        )
        declarations = (await db_session.execute(stmt)).scalars().all()

        # A missing sales_revenue is counted as 0.
        total_declaration = sum(d.sales_revenue or 0 for d in declarations)
        print(f" 申报收入: {total_declaration:,.2f}元")
        print(f" 申报记录数: {len(declarations)}")

    async def _print_contract(self, db_session, streamer_id, start_date, end_date):
        """Print the split ratios of the latest active contract overlapping the period.

        A contract overlaps when it starts on or before *end_date* and its end
        date is either unset or on or after *start_date*.
        """
        stmt = (
            select(RevenueSharingContract)
            .where(
                and_(
                    RevenueSharingContract.streamer_id == streamer_id,
                    RevenueSharingContract.contract_start_date <= end_date,
                    or_(
                        RevenueSharingContract.contract_end_date.is_(None),
                        RevenueSharingContract.contract_end_date >= start_date,
                    ),
                    RevenueSharingContract.contract_status == "active",
                )
            )
            .order_by(RevenueSharingContract.contract_start_date.desc())
            .limit(1)
        )
        contract = (await db_session.execute(stmt)).scalar_one_or_none()

        if contract:
            print(f" 分成比例: 主播{contract.streamer_ratio}%, 平台{contract.platform_ratio}%")
        else:
            print(" 分成比例: 无协议")

    def _parse_period(self, period: str):
        """Parse a ``YYYY-MM`` period into ``(first_day, last_day)`` datetimes.

        Both values are at 00:00. ``last_day`` is midnight at the *start* of
        the month's last day, not the end of it.

        :raises ValueError: for periods without ``-``, quarter-style periods
            containing ``Q``, or malformed year/month values.
        """
        if "-" in period and "Q" not in period:
            year, month = map(int, period.split("-"))
            start_date = datetime(year, month, 1)
            if month == 12:
                end_date = datetime(year + 1, 1, 1) - timedelta(days=1)
            else:
                end_date = datetime(year, month + 1, 1) - timedelta(days=1)
            return start_date, end_date
        else:
            raise ValueError(f"不支持的期间格式: {period}")

    async def demo_scenario_1_normal(self):
        """Scenario 1: run detection on database data for a configured streamer/period."""
        # Replace with a streamer ID and period that exist in your database.
        streamer_id = "TEST_001"
        period = "2024-01"

        await self.demo_scenario_by_streamer(streamer_id, period, "场景1: 收入完整性检测(数据库数据)")

    async def demo_scenario_2_under_reporting(self):
        """Scenario 2: run detection on database data for a configured streamer/period."""
        # Replace with a streamer ID and period that exist in your database.
        streamer_id = "TEST_002"
        period = "2024-02"

        await self.demo_scenario_by_streamer(streamer_id, period, "场景2: 收入完整性检测(数据库数据)")

    async def demo_scenario_3_critical(self):
        """Scenario 3: run detection on database data for a configured streamer/period."""
        # Replace with a streamer ID and period that exist in your database.
        streamer_id = "TEST_003"
        period = "2024-03"

        await self.demo_scenario_by_streamer(streamer_id, period, "场景3: 收入完整性检测(数据库数据)")

    async def demo_scenario_4_no_data(self):
        """Scenario 4: run detection on database data for a configured streamer/period."""
        # Replace with a streamer ID and period that exist in your database.
        streamer_id = "TEST_004"
        period = "2024-04"

        await self.demo_scenario_by_streamer(streamer_id, period, "场景4: 收入完整性检测(数据库数据)")

    async def list_available_streamers(self):
        """Print up to 10 streamers from the database and return their IDs.

        :returns: list of ``streamer_id`` values; empty if no streamers exist.
        """
        self.print_header("可用主播列表")

        async with AsyncSessionLocal() as db_session:
            stmt = select(StreamerInfo).limit(10)
            streamers = (await db_session.execute(stmt)).scalars().all()

            if not streamers:
                print("\n数据库中没有找到主播信息")
                print("请先向数据库中插入测试数据")
                return []

            print("\n【数据库中的主播信息】")
            print(f"{'主播ID':<30} {'主播名称':<20} {'实体类型':<15} {'税号信息':<30}")
            print("-" * 95)

            available_streamers = []
            for streamer in streamers:
                # Show the first available taxpayer identifier, or "无" (none).
                tax_info = streamer.tax_registration_no or streamer.unified_social_credit_code or streamer.id_card_no or "无"
                print(
                    f"{self._cell(streamer.streamer_id, 30)} "
                    f"{self._cell(streamer.streamer_name, 20)} "
                    f"{self._cell(streamer.entity_type, 15)} "
                    f"{self._cell(tax_info, 30)}"
                )
                available_streamers.append(streamer.streamer_id)

            print("\n【使用方法】")
            print("将上述任意一个主播ID复制到测试场景中,例如:")
            print(" streamer_id = 'ZB_XXXX' # 替换为上述任一主播ID")
            print("\n" + "=" * 80)

            return available_streamers

    async def run_all_demos(self):
        """List available streamers and print usage and risk-level help.

        Does not run any scenario itself. The user is told to plug a listed
        streamer ID into a scenario method and re-run.
        """
        print("\n" + "=" * 80)
        print(" 收入完整性检测算法演示(真实数据库版)")
        print(" 从数据库读取真实数据,演示不同的风险场景和检测逻辑")
        print("=" * 80)

        available_streamers = await self.list_available_streamers()

        if not available_streamers:
            print("\n✗ 没有可用的主播数据,无法执行演示")
            return

        print("\n【使用说明】")
        print("请修改脚本中的主播ID为上面列出的可用主播ID,然后重新运行")
        print("例如:修改 demo_scenario_1_normal() 方法中的 streamer_id 变量")
        print("\n" + "=" * 80)
        print("\n【风险等级说明】")
        print(" NONE: 正常,无风险")
        print(" LOW: 低度风险,持续监控")
        print(" MEDIUM: 中度风险,需要核查")
        print(" HIGH: 高度风险,重点关注")
        print(" CRITICAL: 严重风险,立即处理")
        print("=" * 80 + "\n")


async def main():
    """Script entry point: run demo scenario 2 against the database.

    To list the streamers available in the database first, await
    ``RevenueIntegrityDemo().run_all_demos()`` instead.
    """
    runner = RevenueIntegrityDemo()
    scenario = runner.demo_scenario_2_under_reporting
    await scenario()


if __name__ == "__main__":
    asyncio.run(main())