deep-risk/backend/scripts/import_revenue_test_data.py
2025-12-14 20:08:27 +08:00

326 lines
12 KiB
Python

#!/usr/bin/env python3
"""
收入完整性检测算法测试数据导入脚本
将JSON格式的测试数据导入到数据库中
"""
import os
import sys
import json
import asyncio
from datetime import datetime
from typing import List, Dict, Any
sys.path.append('/Users/liulujian/Documents/code/deeprisk-claude-1/backend')
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy import text, DateTime
from loguru import logger
# 导入模型
from app.models.streamer import StreamerInfo, PlatformRecharge
from app.models.contract import RevenueSharingContract
from app.models.tax_declaration import TaxDeclaration
# Register a log file sink: records at INFO level and above are written to
# import_data.log, with rotation configured at "100 MB".
logger.add("import_data.log", rotation="100 MB", level="INFO")
class RevenueTestDataImporter:
    """Importer for the revenue-integrity detection test data set.

    Reads four JSON fixture files (streamers, contracts, recharges and tax
    declarations), recreates the database schema and inserts one ORM object
    per JSON record inside a single transaction.
    """

    # Directory the JSON fixtures are read from when no ``data_dir`` is given.
    DEFAULT_DATA_DIR = "/Users/liulujian/Documents/code/deeprisk-claude-1/backend/test_data/revenue_test"

    def __init__(
        self,
        database_url: str = "sqlite+aiosqlite:///./test_revenue.db",
        data_dir: str = DEFAULT_DATA_DIR,
    ):
        """Create the async engine and session factory.

        Args:
            database_url: SQLAlchemy async database URL.
            data_dir: Directory containing ``streamers.json``,
                ``contracts.json``, ``recharges.json`` and
                ``tax_declarations.json``.
        """
        self.database_url = database_url
        self.data_dir = data_dir
        self.engine = create_async_engine(
            database_url,
            echo=False,
            future=True
        )
        self.SessionLocal = async_sessionmaker(
            bind=self.engine,
            class_=AsyncSession,
            expire_on_commit=False
        )

    def _load_json(self, filename: str) -> List[Dict[str, Any]]:
        """Read and parse one JSON fixture file located in ``self.data_dir``."""
        with open(os.path.join(self.data_dir, filename), 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _parse_date(value: str):
        """Parse a ``YYYY-MM-DD`` string into a ``datetime.date``."""
        return datetime.strptime(value, '%Y-%m-%d').date()

    async def import_all(self):
        """Load every fixture file and import it in one transaction.

        The schema is dropped and recreated first. On any error the session
        is rolled back and the exception is re-raised to the caller.

        Raises:
            Exception: whatever loading, inserting, committing or verifying
                raised; it is logged with its traceback before propagating.
        """
        logger.info("=" * 80)
        logger.info("开始导入收入完整性检测专项测试数据")
        logger.info("=" * 80)

        # Load all JSON fixtures up front so a missing or malformed file
        # aborts before the database is touched.
        streamers = self._load_json("streamers.json")
        contracts = self._load_json("contracts.json")
        recharges = self._load_json("recharges.json")
        tax_declarations = self._load_json("tax_declarations.json")

        # The ``async with`` block closes the session on exit, so no explicit
        # ``session.close()`` is needed.
        async with self.SessionLocal() as session:
            try:
                # Recreate the schema.
                await self.create_tables(session)

                # Stage all rows on the session.
                logger.info("导入主播信息...")
                await self.import_streamers(session, streamers)
                logger.info("导入分成协议...")
                await self.import_contracts(session, contracts)
                logger.info("导入充值记录...")
                await self.import_recharges(session, recharges)
                logger.info("导入税务申报...")
                await self.import_tax_declarations(session, tax_declarations)

                # Commit the transaction.
                await session.commit()
                logger.info("✅ 所有数据导入成功!")

                # Verification runs after the commit; a failure here is still
                # reported as an import failure although the rows are stored.
                await self.verify_data(session)
            except Exception as e:
                # loguru has no ``exc_info`` option: extra keyword arguments
                # trigger ``str.format`` on the message, which breaks when the
                # error text contains braces. ``logger.exception`` logs at
                # ERROR level with the traceback, and the error is passed as a
                # format argument so its text is never used as a template.
                logger.exception("导入失败: {}", e)
                await session.rollback()
                raise

    async def create_tables(self, session: AsyncSession):
        """Drop and recreate every table registered on ``Base.metadata``.

        Destructive: existing tables and their data are removed.

        Args:
            session: Not used; the DDL runs on a separate connection obtained
                from ``self.engine``.
        """
        logger.info("创建数据库表...")
        # Model modules are imported here before the metadata is used,
        # presumably so that their tables are registered on Base.metadata.
        from app.models import Base
        from app.models.streamer import StreamerInfo, PlatformRecharge, McnAgency
        from app.models.contract import RevenueSharingContract
        from app.models.tax_declaration import TaxDeclaration

        async with self.engine.begin() as conn:
            # Drop all tables (if they exist), then create them again.
            await conn.run_sync(Base.metadata.drop_all)
            await conn.run_sync(Base.metadata.create_all)
        logger.info("✅ 数据库表创建完成")

    async def import_streamers(self, session: AsyncSession, streamers: List[Dict]):
        """Add one ``StreamerInfo`` row per record to the session.

        Only ``streamer_id``, ``streamer_name`` and the optional ``tax_no``
        come from the record; the remaining fields are fixed placeholders.

        Args:
            session: Session the new objects are added to (not committed here).
            streamers: Parsed contents of ``streamers.json``.

        Raises:
            Exception: re-raised after logging when a record cannot be built.
        """
        from app.models.streamer import StreamerInfo

        logger.info(f"准备导入 {len(streamers)} 条主播信息...")
        for streamer_data in streamers:
            try:
                streamer = StreamerInfo(
                    streamer_id=streamer_data['streamer_id'],
                    streamer_name=streamer_data['streamer_name'],
                    entity_type='individual',
                    phone_number='13800138000',
                    bank_account_no='6222' + '0' * 14,
                    bank_name='中国银行',
                    status='active',
                    registration_date=datetime.now().date()
                )
                # The same tax number is stored in both tax-related fields.
                if 'tax_no' in streamer_data:
                    streamer.unified_social_credit_code = streamer_data['tax_no']
                    streamer.tax_registration_no = streamer_data['tax_no']
                session.add(streamer)
            except Exception as e:
                # ``.get`` keeps the handler from raising when the id itself
                # is the missing key, which would mask the original error.
                logger.error("导入主播失败 {}: {}", streamer_data.get('streamer_id'), e)
                raise
        logger.info("✅ 主播信息导入完成")

    async def import_contracts(self, session: AsyncSession, contracts: List[Dict]):
        """Add one ``RevenueSharingContract`` row per record to the session.

        Args:
            session: Session the new objects are added to (not committed here).
            contracts: Parsed contents of ``contracts.json``.

        Raises:
            Exception: re-raised after logging when a record cannot be built.
        """
        from app.models.contract import RevenueSharingContract

        logger.info(f"准备导入 {len(contracts)} 条分成协议...")
        for contract_data in contracts:
            try:
                contract = RevenueSharingContract(
                    contract_id=contract_data['contract_id'],
                    # The contract id doubles as the contract number.
                    contract_no=contract_data['contract_id'],
                    contract_type='tip_sharing',
                    streamer_id=contract_data['streamer_id'],
                    # The streamer id is also stored as the streamer name.
                    streamer_name=contract_data['streamer_id'],
                    streamer_entity_type='individual',
                    platform_party=contract_data['platform'],
                    platform_credit_code='91110000000000000X',
                    revenue_type='tip',
                    streamer_ratio=float(contract_data['share_ratio']),
                    platform_ratio=float(contract_data['platform_share_ratio']),
                    settlement_cycle='monthly',
                    contract_start_date=self._parse_date(contract_data['start_date']),
                    contract_end_date=self._parse_date(contract_data['end_date']),
                    contract_status=contract_data['status']
                )
                session.add(contract)
            except Exception as e:
                logger.error("导入分成协议失败 {}: {}", contract_data.get('contract_id'), e)
                raise
        logger.info("✅ 分成协议导入完成")

    async def import_recharges(self, session: AsyncSession, recharges: List[Dict]):
        """Add one ``PlatformRecharge`` row per record to the session.

        Args:
            session: Session the new objects are added to (not committed here).
            recharges: Parsed contents of ``recharges.json``.

        Raises:
            Exception: re-raised after logging when a record cannot be built.
        """
        from app.models.streamer import PlatformRecharge

        logger.info(f"准备导入 {len(recharges)} 条充值记录...")
        for recharge_data in recharges:
            try:
                recharge_id = recharge_data['recharge_id']
                amount = float(recharge_data['recharge_amount'])
                recharge = PlatformRecharge(
                    recharge_id=recharge_id,
                    # The streamer id is stored as both user id and user name.
                    user_id=recharge_data['streamer_id'],
                    user_name=recharge_data['streamer_id'],
                    recharge_amount=amount,
                    recharge_time=datetime.strptime(recharge_data['recharge_date'], '%Y-%m-%d'),
                    payment_method=recharge_data['payment_method'],
                    transaction_no='TXN_' + recharge_id,
                    platform_order_no='PO_' + recharge_id,
                    actual_amount_cny=amount,
                    # Coins are recorded as ten times the recharge amount.
                    total_coins=amount * 10,
                    status=recharge_data['payment_status'],
                    withdrawal_status='not_withdrawn'
                )
                session.add(recharge)
            except Exception as e:
                logger.error("导入充值记录失败 {}: {}", recharge_data.get('recharge_id'), e)
                raise
        logger.info("✅ 充值记录导入完成")

    async def import_tax_declarations(self, session: AsyncSession, declarations: List[Dict]):
        """Add one ``TaxDeclaration`` row per record to the session.

        Args:
            session: Session the new objects are added to (not committed here).
            declarations: Parsed contents of ``tax_declarations.json``.

        Raises:
            Exception: re-raised after logging when a record cannot be built.
        """
        from app.models.tax_declaration import TaxDeclaration

        logger.info(f"准备导入 {len(declarations)} 条税务申报...")
        for decl_data in declarations:
            try:
                declared_amount = float(decl_data['declared_amount'])
                tax_amount = float(decl_data['tax_amount'])
                declaration = TaxDeclaration(
                    vat_declaration_id=decl_data['declaration_id'],
                    # The streamer id is stored as the taxpayer name.
                    taxpayer_name=decl_data['streamer_id'],
                    taxpayer_id=decl_data['tax_no'],
                    tax_period=decl_data['declaration_period'],
                    declaration_date=self._parse_date(decl_data['declaration_date']),
                    tax_authority_code='TAX001',
                    tax_authority_name='税务局',
                    taxpayer_type='small_scale',
                    tax_rate=float(decl_data['tax_rate']),
                    sales_revenue=declared_amount,
                    sales_revenue_taxable=declared_amount,
                    # With input tax fixed at 0, output tax, tax payable and
                    # tax to pay all carry the same fixture amount.
                    output_tax=tax_amount,
                    input_tax=0,
                    input_tax_deductible=0,
                    tax_payable=tax_amount,
                    tax_to_pay=tax_amount,
                    refund_amount=0,
                    declaration_status=decl_data['status']
                )
                session.add(declaration)
            except Exception as e:
                logger.error("导入税务申报失败 {}: {}", decl_data.get('declaration_id'), e)
                raise
        logger.info("✅ 税务申报导入完成")

    async def verify_data(self, session: AsyncSession):
        """Log row counts and a few recharge totals as a sanity check.

        Logs the row count of each imported table, the streamer ids starting
        with ``TEST_``, and the summed recharge amount for the first three of
        those streamers.

        Args:
            session: Session used to run the read-only queries.
        """
        logger.info("=" * 80)
        logger.info("验证导入的数据")
        logger.info("=" * 80)

        # Row count per table. Table names cannot be bound as parameters;
        # interpolation is safe here because they are fixed literals.
        for table_name in ('streamer_info', 'platform_recharge', 'contract', 'tax_declaration'):
            result = await session.execute(text(f"SELECT COUNT(*) FROM {table_name}"))
            count = result.scalar()
            logger.info(f"{table_name}: {count} 条记录")

        # Collect the imported test streamer ids.
        result = await session.execute(
            text("SELECT streamer_id FROM streamer_info WHERE streamer_id LIKE 'TEST_%'")
        )
        streamer_ids = [row[0] for row in result.fetchall()]
        logger.info(f"测试主播ID: {streamer_ids}")

        # Recharge totals for the first three streamers only. The id is passed
        # as a bound parameter instead of being spliced into the SQL text.
        total_stmt = text(
            "SELECT SUM(recharge_amount) FROM platform_recharge WHERE user_id = :user_id"
        )
        for streamer_id in streamer_ids[:3]:
            result = await session.execute(total_stmt, {"user_id": streamer_id})
            total = result.scalar() or 0
            logger.info(f"{streamer_id} 充值总额: ¥{total:,.2f}")

        logger.info("=" * 80)
        logger.info("✅ 数据验证完成")
        logger.info("=" * 80)

    async def close(self):
        """Dispose of the engine and its connections."""
        await self.engine.dispose()
async def main():
    """Command-line entry point: print a banner, run the import, report the result.

    Exits the process with status 1 when the import raises; the importer's
    engine is disposed in every case.
    """
    rule = "=" * 80
    intro_lines = (
        rule,
        "收入完整性检测算法测试数据导入工具",
        rule,
        "",
        "本工具将把JSON格式的测试数据导入到数据库中",
        "导入完成后即可运行算法测试",
        rule,
        "",
    )
    for text_line in intro_lines:
        print(text_line)

    # Summary and next-step hints shown after a successful import.
    success_lines = (
        "\n✅ 测试数据导入成功!",
        "\n📊 导入摘要:",
        " - 数据库: test_revenue.db",
        " - 主播信息: 8条",
        " - 分成协议: 24条",
        " - 充值记录: 240条",
        " - 税务申报: 8条",
        "\n🔬 下一步:",
        " 1. 运行算法测试: python scripts/test_revenue_algorithm.py",
        " 2. 使用API测试: curl -X POST http://localhost:8000/api/v1/risk-detection/execute ...",
        " 3. 使用前端测试: http://localhost:3000/risk-detection/execute",
    )

    importer = RevenueTestDataImporter()
    try:
        await importer.import_all()
        # Kept inside the try so a failure while printing is handled the same
        # way as a failure of the import itself.
        for text_line in success_lines:
            print(text_line)
    except Exception as exc:
        print(f"\n❌ 导入失败: {str(exc)}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
    finally:
        # Runs on success, on failure and on the SystemExit raised above.
        await importer.close()


# Run the importer only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())