#!/usr/bin/env python3
"""Database data validation script.

Checks the integrity and consistency of the stored data: foreign-key
references, field formats and basic business rules. It also logs per-table
row counts and a few sample rows.
"""

import asyncio
import sys
import traceback
from pathlib import Path

from loguru import logger
from sqlalchemy import text

from app.database import AsyncSessionLocal

logger.remove()
logger.add(sys.stderr, level="INFO", format="{time:HH:mm:ss} | {level: <8} | {message}")

# Layout of the sample-row preview printed by show_samples().
_SAMPLE_COLUMNS = 5
_SAMPLE_CELL_WIDTH = 15


def _quote_ident(name: str) -> str:
    """Return *name* as a double-quoted SQL identifier.

    Needed because some table names used here (``order``) are reserved words
    and cannot appear unquoted in a query. Dot-separated parts are quoted
    individually, embedded double quotes are doubled, and parts that are
    already wrapped in double quotes are left untouched.

    NOTE(review): quoted identifiers are matched case-sensitively on
    PostgreSQL; every name passed in this file is lowercase, so the result is
    equivalent to the unquoted form.
    """
    quoted_parts = []
    for part in name.split("."):
        if len(part) >= 2 and part.startswith('"') and part.endswith('"'):
            quoted_parts.append(part)
        else:
            quoted_parts.append('"' + part.replace('"', '""') + '"')
    return ".".join(quoted_parts)


async def _scalar_count(session, sql: str) -> int:
    """Execute a ``SELECT COUNT(*)``-style statement and return its value."""
    result = await session.execute(text(sql))
    # COUNT(*) always yields a number; the fallback only guards against an
    # unexpected empty result so the callers' comparisons never see None.
    return result.scalar() or 0


async def check_foreign_keys(session) -> list:
    """Check that foreign-key columns only reference existing rows.

    For each (table, column, referenced table) triple, counts rows whose
    non-NULL column value does not appear in the referenced table's ``id``
    column.

    Args:
        session: Async database session used to run the queries.

    Returns:
        A list of human-readable issue messages; empty when all checks pass.
    """
    logger.info("检查外键约束...")

    checks = [
        ("streamer_info", "mcn_agency_id", "mcn_agency"),
        ("platform_recharge", "streamer_id", "streamer_info"),
        ("contract", "streamer_id", "streamer_info"),
        ("order", "streamer_id", "streamer_info"),
    ]

    issues = []
    for table, fk_col, ref_table in checks:
        # Quote every identifier: "order" is a reserved word and would make
        # the statement a syntax error if interpolated bare.
        q_table = _quote_ident(table)
        q_col = _quote_ident(fk_col)
        q_ref = _quote_ident(ref_table)
        orphaned = await _scalar_count(
            session,
            f"""
            SELECT COUNT(*) FROM {q_table}
            WHERE {q_col} IS NOT NULL
            AND {q_col} NOT IN (SELECT id FROM {q_ref})
            """,
        )
        if orphaned > 0:
            issues.append(f"❌ {table}.{fk_col} 有 {orphaned} 条孤儿记录")
        else:
            logger.info(f"✅ {table}.{fk_col} 外键正常")

    return issues


async def check_data_consistency(session) -> list:
    """Check field formats and value ranges.

    Verifies that streamer phone numbers are exactly 11 digits, that ID card
    numbers (when present) are 17 digits followed by a digit or ``X``, and
    that no recharge amount is negative.

    Args:
        session: Async database session used to run the queries.

    Returns:
        A list of human-readable issue messages; empty when all checks pass.
    """
    logger.info("\n检查数据一致性...")

    issues = []

    # Streamer phone number format.
    invalid_phones = await _scalar_count(
        session,
        """
        SELECT COUNT(*) FROM streamer_info
        WHERE phone_number !~ '^[0-9]{11}$'
        """,
    )
    if invalid_phones > 0:
        issues.append(f"❌ 有 {invalid_phones} 个主播的手机号格式不正确")
    else:
        logger.info("✅ 主播手机号格式正确")

    # ID card number format.
    invalid_ids = await _scalar_count(
        session,
        """
        SELECT COUNT(*) FROM streamer_info
        WHERE id_card_no IS NOT NULL
        AND id_card_no !~ '^[0-9]{17}[0-9X]$'
        """,
    )
    if invalid_ids > 0:
        issues.append(f"❌ 有 {invalid_ids} 个主播的身份证号格式不正确")
    else:
        logger.info("✅ 主播身份证号格式正确")

    # Recharge amounts must not be negative.
    negative_amounts = await _scalar_count(
        session,
        """
        SELECT COUNT(*) FROM platform_recharge
        WHERE recharge_amount < 0
        """,
    )
    if negative_amounts > 0:
        issues.append(f"❌ 有 {negative_amounts} 条充值记录的金额为负数")
    else:
        logger.info("✅ 充值金额数据正常")

    return issues


async def check_business_logic(session) -> list:
    """Check contract-level business rules.

    Verifies that each contract's platform and streamer ratios add up to 100
    and that no contract starts after it ends.

    Args:
        session: Async database session used to run the queries.

    Returns:
        A list of human-readable issue messages; empty when all checks pass.
    """
    logger.info("\n检查业务逻辑...")

    issues = []

    # Revenue-share ratios must sum to 100.
    wrong_ratios = await _scalar_count(
        session,
        """
        SELECT COUNT(*) FROM contract
        WHERE (platform_ratio + streamer_ratio) != 100
        """,
    )
    if wrong_ratios > 0:
        issues.append(f"❌ 有 {wrong_ratios} 份合同的分成比例总和不为100%")
    else:
        logger.info("✅ 分成比例正确")

    # Contract start date must not be later than the end date.
    wrong_dates = await _scalar_count(
        session,
        """
        SELECT COUNT(*) FROM contract
        WHERE contract_start_date > contract_end_date
        """,
    )
    if wrong_dates > 0:
        issues.append(f"❌ 有 {wrong_dates} 份合同的开始日期晚于结束日期")
    else:
        logger.info("✅ 合同日期正确")

    return issues


async def show_statistics(session) -> None:
    """Log the row count of each known table.

    A failing count query (for example, a missing table) is logged as a
    warning and does not stop the remaining counts.

    Args:
        session: Async database session used to run the queries.
    """
    logger.info("\n数据统计:")
    logger.info("=" * 60)

    stats = [
        ("MCN机构数量", "mcn_agency"),
        ("主播数量", "streamer_info"),
        ("平台充值记录", "platform_recharge"),
        ("分成协议", "contract"),
        ("电商订单", "order"),
        ("结算单", "settlement"),
        ("风险检测规则", "detection_rule"),
    ]

    for name, table in stats:
        try:
            result = await session.execute(
                text(f"SELECT COUNT(*) FROM {_quote_ident(table)}")
            )
            count = result.scalar()
            logger.info(f" {name:20}: {count:6}")
        except Exception as e:
            logger.warning(f" {name:20}: 查询失败 - {str(e)[:50]}")
            # A failed statement can leave the surrounding transaction in an
            # aborted state (PostgreSQL does this); roll back so the queries
            # that follow on this session can still run.
            await session.rollback()


async def show_samples(session, table: str, limit: int = 3) -> None:
    """Log up to *limit* sample rows from *table*.

    Only the first five columns are shown, each truncated/padded to 15
    characters. Query failures are logged as a warning rather than raised.

    Args:
        session: Async database session used to run the query.
        table: Name of the table to sample.
        limit: Maximum number of rows to fetch and display.
    """
    logger.info(f"\n{table} 示例数据:")
    try:
        # The row limit is passed as a bind parameter instead of being
        # formatted into the SQL text.
        result = await session.execute(
            text(f"SELECT * FROM {_quote_ident(table)} LIMIT :limit"),
            {"limit": int(limit)},
        )
        rows = result.fetchall()

        if not rows:
            logger.info(" (空表)")
            return

        # Header line built from the column names.
        cols = list(result.keys())
        logger.info(
            " "
            + " | ".join(
                f"{c[:_SAMPLE_CELL_WIDTH]:{_SAMPLE_CELL_WIDTH}}"
                for c in cols[:_SAMPLE_COLUMNS]
            )
        )

        # The query already caps the row count, so show every fetched row;
        # this makes a caller-supplied limit above 3 take effect.
        for row in rows:
            cells = tuple(row)[:_SAMPLE_COLUMNS]
            logger.info(
                " "
                + " | ".join(
                    f"{str(value)[:_SAMPLE_CELL_WIDTH]:{_SAMPLE_CELL_WIDTH}}"
                    for value in cells
                )
            )
    except Exception as e:
        logger.warning(f" 查询失败: {e}")
        # See show_statistics(): recover the session after a failed statement.
        await session.rollback()


async def main() -> bool:
    """Run every validation step and log a summary.

    Returns:
        True when the validation run completed (even if data issues were
        reported), False when an unexpected exception interrupted it.
    """
    logger.info("\n" + "=" * 60)
    logger.info("🔍 数据库数据验证")
    logger.info("=" * 60 + "\n")

    async with AsyncSessionLocal() as session:
        try:
            # 1. Row counts per table.
            await show_statistics(session)

            # 2. Foreign-key references.
            fk_issues = await check_foreign_keys(session)

            # 3. Field formats and value ranges.
            consistency_issues = await check_data_consistency(session)

            # 4. Business rules.
            business_issues = await check_business_logic(session)

            # 5. Sample rows.
            logger.info("\n" + "=" * 60)
            logger.info("示例数据")
            logger.info("=" * 60)

            for sample_table in ("mcn_agency", "streamer_info", "platform_recharge"):
                await show_samples(session, sample_table)

            # 6. Summary of everything that was found.
            logger.info("\n" + "=" * 60)
            logger.info("验证结果")
            logger.info("=" * 60)

            all_issues = fk_issues + consistency_issues + business_issues

            if all_issues:
                logger.warning(f"\n⚠️ 发现 {len(all_issues)} 个问题:")
                for issue in all_issues:
                    logger.warning(f" {issue}")
                logger.warning("\n建议修复这些问题以确保数据质量")
            else:
                logger.info("\n✅ 所有检查通过!数据质量良好")

        except Exception as e:
            logger.error(f"验证过程出错: {e}")
            traceback.print_exc()
            return False
        finally:
            await session.close()

    return True


if __name__ == "__main__":
    success = asyncio.run(main())
    sys.exit(0 if success else 1)