#!/usr/bin/env python3
"""
Test-data import script for the revenue-integrity detection algorithm.

Loads the JSON fixtures (streamers, revenue-sharing contracts, platform
recharges and tax declarations) and writes them into the database.
"""
import os
import sys
import json
import asyncio
from datetime import datetime
from typing import List, Dict, Any, Optional

# Backend root that contains the ``app`` package and the ``test_data`` fixtures.
# It can be overridden through the environment so the script is not tied to one
# developer machine; the default is the path the script originally hard-coded.
BACKEND_DIR = os.environ.get(
    "DEEPRISK_BACKEND_DIR",
    "/Users/liulujian/Documents/code/deeprisk-claude-1/backend",
)
# Directory holding streamers.json, contracts.json, recharges.json and
# tax_declarations.json.
DEFAULT_DATA_DIR = os.path.join(BACKEND_DIR, "test_data", "revenue_test")

# Must happen before the ``app.*`` imports below can resolve.
sys.path.append(BACKEND_DIR)

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy import text, DateTime, func, select
from loguru import logger

# Project models
from app.models.streamer import StreamerInfo, PlatformRecharge
from app.models.contract import RevenueSharingContract
from app.models.tax_declaration import TaxDeclaration

logger.add("import_data.log", rotation="100 MB", level="INFO")


class RevenueTestDataImporter:
    """Importer for the revenue-integrity detection test data set."""

    def __init__(
        self,
        database_url: str = "sqlite+aiosqlite:///./test_revenue.db",
        data_dir: Optional[str] = None,
    ):
        """
        Args:
            database_url: SQLAlchemy async database URL to import into.
            data_dir: Directory containing the JSON fixture files. Defaults to
                ``DEFAULT_DATA_DIR``.
        """
        self.database_url = database_url
        self.data_dir = data_dir if data_dir is not None else DEFAULT_DATA_DIR
        self.engine = create_async_engine(
            database_url,
            echo=False,
            future=True
        )
        self.SessionLocal = async_sessionmaker(
            bind=self.engine,
            class_=AsyncSession,
            expire_on_commit=False
        )

    def _load_json(self, name: str) -> Any:
        """Read ``<data_dir>/<name>.json`` as UTF-8 and return the parsed JSON."""
        path = os.path.join(self.data_dir, f"{name}.json")
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    async def import_all(self) -> Dict[str, int]:
        """Recreate the schema, import every fixture file and verify the result.

        WARNING: ``create_tables`` drops all tables on ``Base.metadata`` first,
        so any existing data in the target database is lost.

        Returns:
            The number of records loaded from each fixture file, keyed by
            ``streamers``, ``contracts``, ``recharges`` and ``tax_declarations``.

        Raises:
            Any exception from loading, inserting or verifying; the session is
            rolled back before the exception propagates.
        """
        logger.info("=" * 80)
        logger.info("开始导入收入完整性检测专项测试数据")
        logger.info("=" * 80)

        # Load the JSON fixtures
        streamers = self._load_json("streamers")
        contracts = self._load_json("contracts")
        recharges = self._load_json("recharges")
        tax_declarations = self._load_json("tax_declarations")

        # ``async with`` closes the session on exit, so no explicit close() is needed.
        async with self.SessionLocal() as session:
            try:
                # (Re)create all tables
                await self.create_tables(session)

                # Import the data
                logger.info("导入主播信息...")
                await self.import_streamers(session, streamers)

                logger.info("导入分成协议...")
                await self.import_contracts(session, contracts)

                logger.info("导入充值记录...")
                await self.import_recharges(session, recharges)

                logger.info("导入税务申报...")
                await self.import_tax_declarations(session, tax_declarations)

                # Commit the transaction
                await session.commit()
                logger.info("✅ 所有数据导入成功!")

                # Verify what was written
                await self.verify_data(session)

            except Exception as e:
                # loguru ignores the stdlib ``exc_info`` kwarg (and would try to
                # str.format the message with it); logger.exception records the
                # traceback and leaves braces in the error text untouched.
                logger.exception(f"导入失败: {str(e)}")
                await session.rollback()
                raise

        return {
            "streamers": len(streamers),
            "contracts": len(contracts),
            "recharges": len(recharges),
            "tax_declarations": len(tax_declarations),
        }

    async def create_tables(self, session: AsyncSession):
        """Drop and recreate every table registered on ``Base.metadata``.

        ``session`` is not used; the DDL runs on its own engine connection.
        """
        logger.info("创建数据库表...")

        from app.models import Base
        # Presumably imported so its table is registered on Base.metadata; the
        # other models are already imported at module level.
        from app.models.streamer import McnAgency  # noqa: F401

        async with self.engine.begin() as conn:
            # Drop all tables (if they exist), then create them afresh.
            await conn.run_sync(Base.metadata.drop_all)
            await conn.run_sync(Base.metadata.create_all)

        logger.info("✅ 数据库表创建完成")

    async def import_streamers(self, session: AsyncSession, streamers: List[Dict]):
        """Stage one ``StreamerInfo`` row per fixture record (not committed here).

        Contact, bank and status fields are filled with fixed placeholder
        values; tax identifiers are copied from ``tax_no`` when present.
        """
        logger.info(f"准备导入 {len(streamers)} 条主播信息...")

        for streamer_data in streamers:
            try:
                streamer = StreamerInfo(
                    streamer_id=streamer_data['streamer_id'],
                    streamer_name=streamer_data['streamer_name'],
                    entity_type='individual',
                    phone_number='13800138000',
                    bank_account_no='6222' + '0' * 14,
                    bank_name='中国银行',
                    status='active',
                    registration_date=datetime.now().date()
                )

                # Set the tax identifiers from tax_no when provided
                if 'tax_no' in streamer_data:
                    streamer.unified_social_credit_code = streamer_data['tax_no']
                    streamer.tax_registration_no = streamer_data['tax_no']

                session.add(streamer)

            except Exception as e:
                # .get() so a missing id key does not raise a second KeyError here.
                logger.error(f"导入主播失败 {streamer_data.get('streamer_id')}: {str(e)}")
                raise

        logger.info("✅ 主播信息导入完成")

    async def import_contracts(self, session: AsyncSession, contracts: List[Dict]):
        """Stage one ``RevenueSharingContract`` row per fixture record.

        Dates are parsed as ``YYYY-MM-DD``; share ratios are converted to float.
        """
        logger.info(f"准备导入 {len(contracts)} 条分成协议...")

        for contract_data in contracts:
            try:
                contract = RevenueSharingContract(
                    contract_id=contract_data['contract_id'],
                    contract_no=contract_data['contract_id'],
                    contract_type='tip_sharing',
                    streamer_id=contract_data['streamer_id'],
                    streamer_name=contract_data['streamer_id'],
                    streamer_entity_type='individual',
                    platform_party=contract_data['platform'],
                    platform_credit_code='91110000000000000X',
                    revenue_type='tip',
                    streamer_ratio=float(contract_data['share_ratio']),
                    platform_ratio=float(contract_data['platform_share_ratio']),
                    settlement_cycle='monthly',
                    contract_start_date=datetime.strptime(contract_data['start_date'], '%Y-%m-%d').date(),
                    contract_end_date=datetime.strptime(contract_data['end_date'], '%Y-%m-%d').date(),
                    contract_status=contract_data['status']
                )

                session.add(contract)

            except Exception as e:
                logger.error(f"导入分成协议失败 {contract_data.get('contract_id')}: {str(e)}")
                raise

        logger.info("✅ 分成协议导入完成")

    async def import_recharges(self, session: AsyncSession, recharges: List[Dict]):
        """Stage one ``PlatformRecharge`` row per fixture record.

        The fixture's ``streamer_id`` is stored as the recharging user, and
        ``total_coins`` is derived as ten times the recharge amount.
        """
        logger.info(f"准备导入 {len(recharges)} 条充值记录...")

        for recharge_data in recharges:
            try:
                amount = float(recharge_data['recharge_amount'])
                recharge = PlatformRecharge(
                    recharge_id=recharge_data['recharge_id'],
                    user_id=recharge_data['streamer_id'],
                    user_name=recharge_data['streamer_id'],
                    recharge_amount=amount,
                    recharge_time=datetime.strptime(recharge_data['recharge_date'], '%Y-%m-%d'),
                    payment_method=recharge_data['payment_method'],
                    transaction_no='TXN_' + recharge_data['recharge_id'],
                    platform_order_no='PO_' + recharge_data['recharge_id'],
                    actual_amount_cny=amount,
                    total_coins=amount * 10,
                    status=recharge_data['payment_status'],
                    withdrawal_status='not_withdrawn'
                )

                session.add(recharge)

            except Exception as e:
                logger.error(f"导入充值记录失败 {recharge_data.get('recharge_id')}: {str(e)}")
                raise

        logger.info("✅ 充值记录导入完成")

    async def import_tax_declarations(self, session: AsyncSession, declarations: List[Dict]):
        """Stage one ``TaxDeclaration`` row per fixture record.

        The declared amount populates both revenue fields and the tax amount
        populates the output/payable/to-pay fields; input tax and refund are 0.
        """
        logger.info(f"准备导入 {len(declarations)} 条税务申报...")

        for decl_data in declarations:
            try:
                declared = float(decl_data['declared_amount'])
                tax_amount = float(decl_data['tax_amount'])
                declaration = TaxDeclaration(
                    vat_declaration_id=decl_data['declaration_id'],
                    taxpayer_name=decl_data['streamer_id'],
                    taxpayer_id=decl_data['tax_no'],
                    tax_period=decl_data['declaration_period'],
                    declaration_date=datetime.strptime(decl_data['declaration_date'], '%Y-%m-%d').date(),
                    tax_authority_code='TAX001',
                    tax_authority_name='税务局',
                    taxpayer_type='small_scale',
                    tax_rate=float(decl_data['tax_rate']),
                    sales_revenue=declared,
                    sales_revenue_taxable=declared,
                    output_tax=tax_amount,
                    input_tax=0,
                    input_tax_deductible=0,
                    tax_payable=tax_amount,
                    tax_to_pay=tax_amount,
                    refund_amount=0,
                    declaration_status=decl_data['status']
                )

                session.add(declaration)

            except Exception as e:
                logger.error(f"导入税务申报失败 {decl_data.get('declaration_id')}: {str(e)}")
                raise

        logger.info("✅ 税务申报导入完成")

    async def verify_data(self, session: AsyncSession):
        """Log row counts per table and recharge totals for up to 3 test streamers."""
        logger.info("=" * 80)
        logger.info("验证导入的数据")
        logger.info("=" * 80)

        # Row count per table. Table names come from the mapped models, so they
        # always match the schema create_tables() just built.
        for model in (StreamerInfo, PlatformRecharge, RevenueSharingContract, TaxDeclaration):
            result = await session.execute(select(func.count()).select_from(model))
            count = result.scalar()
            logger.info(f"{model.__table__.name}: {count} 条记录")

        # Test streamers: ids starting with the literal prefix "TEST_".
        # autoescape=True stops "_" from acting as a LIKE single-char wildcard.
        result = await session.execute(
            select(StreamerInfo.streamer_id).where(
                StreamerInfo.streamer_id.startswith("TEST_", autoescape=True)
            )
        )
        streamer_ids = list(result.scalars().all())
        logger.info(f"测试主播ID: {streamer_ids}")

        # Recharge totals (bound parameters rather than string-built SQL)
        for streamer_id in streamer_ids[:3]:  # only the first 3
            result = await session.execute(
                select(func.sum(PlatformRecharge.recharge_amount)).where(
                    PlatformRecharge.user_id == streamer_id
                )
            )
            total = result.scalar() or 0
            logger.info(f"{streamer_id} 充值总额: ¥{total:,.2f}")

        logger.info("=" * 80)
        logger.info("✅ 数据验证完成")
        logger.info("=" * 80)

    async def close(self):
        """Dispose of the engine and its connection pool."""
        await self.engine.dispose()


async def main():
    """Command-line entry point: import the fixtures and print a summary.

    Exits the process with status 1 if the import fails.
    """
    print("=" * 80)
    print("收入完整性检测算法测试数据导入工具")
    print("=" * 80)
    print()
    print("本工具将把JSON格式的测试数据导入到数据库中")
    print("导入完成后即可运行算法测试")
    print("=" * 80)
    print()

    importer = RevenueTestDataImporter()

    try:
        counts = await importer.import_all()

        # Report what was actually loaded rather than fixed expected counts.
        print("\n✅ 测试数据导入成功!")
        print("\n📊 导入摘要:")
        print(f" - 数据库: {importer.database_url}")
        print(f" - 主播信息: {counts['streamers']}条")
        print(f" - 分成协议: {counts['contracts']}条")
        print(f" - 充值记录: {counts['recharges']}条")
        print(f" - 税务申报: {counts['tax_declarations']}条")
        print("\n🔬 下一步:")
        print(" 1. 运行算法测试: python scripts/test_revenue_algorithm.py")
        print(" 2. 使用API测试: curl -X POST http://localhost:8000/api/v1/risk-detection/execute ...")
        print(" 3. 使用前端测试: http://localhost:3000/risk-detection/execute")

    except Exception as e:
        print(f"\n❌ 导入失败: {str(e)}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    finally:
        await importer.close()


if __name__ == "__main__":
    asyncio.run(main())