326 lines
12 KiB
Python
326 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
收入完整性检测算法测试数据导入脚本
|
|
将JSON格式的测试数据导入到数据库中
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import asyncio
|
|
from datetime import datetime
|
|
from typing import List, Dict, Any
|
|
|
|
# Make the backend package (``app.*``) importable. The default below is an
# absolute path from the original author's machine; set the
# DEEPRISK_BACKEND_DIR environment variable to point at your checkout instead.
sys.path.append(
    os.environ.get(
        'DEEPRISK_BACKEND_DIR',
        '/Users/liulujian/Documents/code/deeprisk-claude-1/backend',
    )
)
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
|
from sqlalchemy import text, DateTime
|
|
from loguru import logger
|
|
|
|
# 导入模型
|
|
from app.models.streamer import StreamerInfo, PlatformRecharge
|
|
from app.models.contract import RevenueSharingContract
|
|
from app.models.tax_declaration import TaxDeclaration
|
|
|
|
# Mirror INFO-and-above log records into a local file, rotating it once it
# reaches 100 MB.
logger.add(
    sink="import_data.log",
    level="INFO",
    rotation="100 MB",
)
|
|
|
|
class RevenueTestDataImporter:
    """Importer for the revenue-integrity detection test fixtures.

    Reads four JSON fixture files, drops and recreates every table
    registered on ``app.models.Base``, inserts the records through the ORM
    models in a single transaction, then logs a few sanity checks.
    """

    # Directory containing the JSON fixture files. NOTE(review): this is an
    # absolute path from the original author's machine; pass ``data_dir`` to
    # ``import_all`` to use a different location.
    DEFAULT_DATA_DIR = "/Users/liulujian/Documents/code/deeprisk-claude-1/backend/test_data/revenue_test"

    def __init__(self, database_url: str = "sqlite+aiosqlite:///./test_revenue.db"):
        """Create the async engine and the session factory.

        Args:
            database_url: SQLAlchemy async database URL. Defaults to the
                SQLite file ``test_revenue.db`` in the current directory.
        """
        self.database_url = database_url
        self.engine = create_async_engine(
            database_url,
            echo=False,
            future=True
        )
        # expire_on_commit=False keeps loaded attribute values usable after
        # commit() instead of expiring them.
        self.SessionLocal = async_sessionmaker(
            bind=self.engine,
            class_=AsyncSession,
            expire_on_commit=False
        )

    @staticmethod
    def _load_json(data_dir: str, filename: str) -> Any:
        """Read and parse one UTF-8 JSON fixture file located in ``data_dir``."""
        with open(os.path.join(data_dir, filename), 'r', encoding='utf-8') as f:
            return json.load(f)

    async def import_all(self, data_dir: str = DEFAULT_DATA_DIR):
        """Recreate the schema and import every fixture file.

        All inserts share one session and are committed together. On any
        error the session is rolled back and the exception is re-raised.

        Args:
            data_dir: Directory containing ``streamers.json``,
                ``contracts.json``, ``recharges.json`` and
                ``tax_declarations.json``.

        Raises:
            OSError, json.JSONDecodeError: If a fixture file cannot be read.
            Exception: Any error raised by schema creation, the imports or
                the verification step.
        """
        logger.info("=" * 80)
        logger.info("开始导入收入完整性检测专项测试数据")
        logger.info("=" * 80)

        # Load every fixture before touching the database, so a missing or
        # malformed file fails before the destructive schema reset.
        streamers = self._load_json(data_dir, "streamers.json")
        contracts = self._load_json(data_dir, "contracts.json")
        recharges = self._load_json(data_dir, "recharges.json")
        tax_declarations = self._load_json(data_dir, "tax_declarations.json")

        # ``async with`` closes the session on exit, so no explicit close().
        async with self.SessionLocal() as session:
            try:
                # Drop and recreate all tables.
                await self.create_tables(session)

                logger.info("导入主播信息...")
                await self.import_streamers(session, streamers)

                logger.info("导入分成协议...")
                await self.import_contracts(session, contracts)

                logger.info("导入充值记录...")
                await self.import_recharges(session, recharges)

                logger.info("导入税务申报...")
                await self.import_tax_declarations(session, tax_declarations)

                # Commit all inserts as one transaction.
                await session.commit()
                logger.info("✅ 所有数据导入成功!")

                # Log row counts and spot checks.
                await self.verify_data(session)

            except Exception as e:
                # logger.exception attaches the traceback. Loguru has no
                # ``exc_info`` keyword: such kwargs are not used for exception
                # capture and would make loguru str.format() the message,
                # which breaks on any '{' in the error text.
                logger.exception(f"导入失败: {str(e)}")
                await session.rollback()
                raise

    async def create_tables(self, session: AsyncSession):
        """Drop and recreate every table registered on ``Base.metadata``.

        Destructive: all existing rows in those tables are deleted. The DDL
        runs on its own engine connection; ``session`` is not used.
        """
        logger.info("创建数据库表...")

        # Import the models so their tables are registered on Base.metadata
        # before drop_all/create_all run. (McnAgency is presumably imported
        # for its table/foreign keys -- TODO confirm.)
        from app.models import Base
        from app.models.streamer import StreamerInfo, PlatformRecharge, McnAgency
        from app.models.contract import RevenueSharingContract
        from app.models.tax_declaration import TaxDeclaration

        async with self.engine.begin() as conn:
            # Drop every table (if it exists) ...
            await conn.run_sync(Base.metadata.drop_all)
            # ... then create them again from the model definitions.
            await conn.run_sync(Base.metadata.create_all)

        logger.info("✅ 数据库表创建完成")

    async def import_streamers(self, session: AsyncSession, streamers: List[Dict]):
        """Add one ``StreamerInfo`` row per fixture record (not committed here).

        Contact and bank fields are filled with fixed placeholder values. The
        optional ``tax_no`` is copied into both the credit-code and
        tax-registration fields.
        """
        from app.models.streamer import StreamerInfo

        logger.info(f"准备导入 {len(streamers)} 条主播信息...")

        for streamer_data in streamers:
            try:
                streamer = StreamerInfo(
                    streamer_id=streamer_data['streamer_id'],
                    streamer_name=streamer_data['streamer_name'],
                    entity_type='individual',
                    phone_number='13800138000',
                    bank_account_no='6222' + '0' * 14,
                    bank_name='中国银行',
                    status='active',
                    registration_date=datetime.now().date()
                )

                # Set the tax identifiers only when the record provides tax_no.
                if 'tax_no' in streamer_data:
                    streamer.unified_social_credit_code = streamer_data['tax_no']
                    streamer.tax_registration_no = streamer_data['tax_no']

                session.add(streamer)

            except Exception as e:
                # .get(): a missing id must not raise a second KeyError here.
                logger.error(f"导入主播失败 {streamer_data.get('streamer_id')}: {str(e)}")
                raise

        logger.info("✅ 主播信息导入完成")

    async def import_contracts(self, session: AsyncSession, contracts: List[Dict]):
        """Add one ``RevenueSharingContract`` row per fixture record (not committed).

        Dates are parsed from ``YYYY-MM-DD`` strings. The contract id doubles
        as the contract number, and the streamer id doubles as the streamer
        name.
        """
        from app.models.contract import RevenueSharingContract

        logger.info(f"准备导入 {len(contracts)} 条分成协议...")

        for contract_data in contracts:
            try:
                contract = RevenueSharingContract(
                    contract_id=contract_data['contract_id'],
                    contract_no=contract_data['contract_id'],
                    contract_type='tip_sharing',
                    streamer_id=contract_data['streamer_id'],
                    streamer_name=contract_data['streamer_id'],
                    streamer_entity_type='individual',
                    platform_party=contract_data['platform'],
                    platform_credit_code='91110000000000000X',
                    revenue_type='tip',
                    streamer_ratio=float(contract_data['share_ratio']),
                    platform_ratio=float(contract_data['platform_share_ratio']),
                    settlement_cycle='monthly',
                    contract_start_date=datetime.strptime(contract_data['start_date'], '%Y-%m-%d').date(),
                    contract_end_date=datetime.strptime(contract_data['end_date'], '%Y-%m-%d').date(),
                    contract_status=contract_data['status']
                )

                session.add(contract)

            except Exception as e:
                logger.error(f"导入分成协议失败 {contract_data.get('contract_id')}: {str(e)}")
                raise

        logger.info("✅ 分成协议导入完成")

    async def import_recharges(self, session: AsyncSession, recharges: List[Dict]):
        """Add one ``PlatformRecharge`` row per fixture record (not committed).

        The recharge's ``streamer_id`` is stored as the paying user. The
        transaction and order numbers are derived from the recharge id, and
        ``total_coins`` is computed as 10x the amount.
        """
        from app.models.streamer import PlatformRecharge

        logger.info(f"准备导入 {len(recharges)} 条充值记录...")

        for recharge_data in recharges:
            try:
                recharge = PlatformRecharge(
                    recharge_id=recharge_data['recharge_id'],
                    user_id=recharge_data['streamer_id'],
                    user_name=recharge_data['streamer_id'],
                    recharge_amount=float(recharge_data['recharge_amount']),
                    recharge_time=datetime.strptime(recharge_data['recharge_date'], '%Y-%m-%d'),
                    payment_method=recharge_data['payment_method'],
                    transaction_no='TXN_' + recharge_data['recharge_id'],
                    platform_order_no='PO_' + recharge_data['recharge_id'],
                    actual_amount_cny=float(recharge_data['recharge_amount']),
                    total_coins=float(recharge_data['recharge_amount']) * 10,
                    status=recharge_data['payment_status'],
                    withdrawal_status='not_withdrawn'
                )

                session.add(recharge)

            except Exception as e:
                logger.error(f"导入充值记录失败 {recharge_data.get('recharge_id')}: {str(e)}")
                raise

        logger.info("✅ 充值记录导入完成")

    async def import_tax_declarations(self, session: AsyncSession, declarations: List[Dict]):
        """Add one ``TaxDeclaration`` row per fixture record (not committed).

        ``declared_amount`` fills both revenue fields and ``tax_amount`` fills
        the output/payable/to-pay tax fields. Input tax and refund are 0.
        """
        from app.models.tax_declaration import TaxDeclaration

        logger.info(f"准备导入 {len(declarations)} 条税务申报...")

        for decl_data in declarations:
            try:
                declaration = TaxDeclaration(
                    vat_declaration_id=decl_data['declaration_id'],
                    taxpayer_name=decl_data['streamer_id'],
                    taxpayer_id=decl_data['tax_no'],
                    tax_period=decl_data['declaration_period'],
                    declaration_date=datetime.strptime(decl_data['declaration_date'], '%Y-%m-%d').date(),
                    tax_authority_code='TAX001',
                    tax_authority_name='税务局',
                    taxpayer_type='small_scale',
                    tax_rate=float(decl_data['tax_rate']),
                    sales_revenue=float(decl_data['declared_amount']),
                    sales_revenue_taxable=float(decl_data['declared_amount']),
                    output_tax=float(decl_data['tax_amount']),
                    input_tax=0,
                    input_tax_deductible=0,
                    tax_payable=float(decl_data['tax_amount']),
                    tax_to_pay=float(decl_data['tax_amount']),
                    refund_amount=0,
                    declaration_status=decl_data['status']
                )

                session.add(declaration)

            except Exception as e:
                logger.error(f"导入税务申报失败 {decl_data.get('declaration_id')}: {str(e)}")
                raise

        logger.info("✅ 税务申报导入完成")

    async def verify_data(self, session: AsyncSession):
        """Log row counts and a few spot checks for the imported data.

        Queries are built with the ORM rather than hand-written SQL. Table
        names therefore come from the mapped models, and values are sent as
        bound parameters instead of being interpolated into the SQL text.
        """
        from sqlalchemy import func, select
        from app.models.streamer import StreamerInfo, PlatformRecharge
        from app.models.contract import RevenueSharingContract
        from app.models.tax_declaration import TaxDeclaration

        logger.info("=" * 80)
        logger.info("验证导入的数据")
        logger.info("=" * 80)

        # Row count per model; keys are the labels used in the log output.
        models = {
            'streamer_info': StreamerInfo,
            'platform_recharge': PlatformRecharge,
            'contract': RevenueSharingContract,
            'tax_declaration': TaxDeclaration,
        }

        for table_name, model in models.items():
            result = await session.execute(select(func.count()).select_from(model))
            count = result.scalar()
            logger.info(f"{table_name}: {count} 条记录")

        # Streamers whose id literally starts with "TEST_". autoescape=True
        # escapes the "_" so it is not a LIKE single-character wildcard.
        result = await session.execute(
            select(StreamerInfo.streamer_id).where(
                StreamerInfo.streamer_id.startswith("TEST_", autoescape=True)
            )
        )
        streamer_ids = list(result.scalars().all())
        logger.info(f"测试主播ID: {streamer_ids}")

        # Spot-check the recharge total for only the first three test streamers.
        for streamer_id in streamer_ids[:3]:
            result = await session.execute(
                select(func.sum(PlatformRecharge.recharge_amount)).where(
                    PlatformRecharge.user_id == streamer_id
                )
            )
            total = result.scalar() or 0
            logger.info(f"{streamer_id} 充值总额: ¥{total:,.2f}")

        logger.info("=" * 80)
        logger.info("✅ 数据验证完成")
        logger.info("=" * 80)

    async def close(self):
        """Dispose of the engine and its connection pool."""
        await self.engine.dispose()
|
|
|
|
async def main():
    """Command-line entry point for the fixture import.

    Prints an introductory banner, then runs ``RevenueTestDataImporter.import_all``.
    On success it prints a summary and suggested next steps. On failure it
    prints the error and its traceback and exits with status 1. The
    importer's engine is disposed on every path.
    """
    divider = "=" * 80
    intro_lines = [
        divider,
        "收入完整性检测算法测试数据导入工具",
        divider,
        "",
        "本工具将把JSON格式的测试数据导入到数据库中",
        "导入完成后即可运行算法测试",
        divider,
        "",
    ]
    # NOTE(review): these record counts are hard-coded, not read back from
    # the imported data; keep them in sync with the JSON fixtures.
    summary_lines = [
        "\n✅ 测试数据导入成功!",
        "\n📊 导入摘要:",
        " - 数据库: test_revenue.db",
        " - 主播信息: 8条",
        " - 分成协议: 24条",
        " - 充值记录: 240条",
        " - 税务申报: 8条",
        "\n🔬 下一步:",
        " 1. 运行算法测试: python scripts/test_revenue_algorithm.py",
        " 2. 使用API测试: curl -X POST http://localhost:8000/api/v1/risk-detection/execute ...",
        " 3. 使用前端测试: http://localhost:3000/risk-detection/execute",
    ]

    for text_line in intro_lines:
        print(text_line)

    importer = RevenueTestDataImporter()

    try:
        await importer.import_all()
        for text_line in summary_lines:
            print(text_line)
    except Exception as exc:
        print(f"\n❌ 导入失败: {str(exc)}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
    finally:
        # Runs even when sys.exit() raises SystemExit above.
        await importer.close()


if __name__ == "__main__":
    asyncio.run(main())
|