#!/usr/bin/env python3 """ 风控检测模块测试脚本 演示完整的检测流程:创建任务 -> 执行任务 -> 查看结果 """ import asyncio import json import sys from datetime import datetime from pathlib import Path # 添加项目根目录到Python路径 sys.path.insert(0, str(Path(__file__).parent)) from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, text from app.database import AsyncSessionLocal, engine from app.models.user import User from app.models.streamer import StreamerInfo from app.models.risk_detection import DetectionTask, TaskType, TaskStatus from app.services.risk_detection.task_manager.task_manager import TaskManager from app.utils.helpers import get_password_hash from loguru import logger async def create_test_user(db: AsyncSession) -> User: """创建测试用户""" logger.info("检查是否存在测试用户...") # 查询用户 result = await db.execute(select(User).where(User.username == "test_user")) user = result.scalar_one_or_none() if user: logger.info(f"✅ 测试用户已存在: {user.username}") return user # 创建测试用户 logger.info("创建测试用户...") user = User( username="test_user", password=get_password_hash("test123"), nickname="测试用户", email="test@example.com", phone="13800138000", status=True, is_superuser=False, entity_id="TEST_001", entity_type="streamer" ) db.add(user) await db.flush() await db.refresh(user) logger.info(f"✅ 测试用户创建成功: {user.username} (ID: {user.id})") return user async def create_test_streamer(db: AsyncSession) -> StreamerInfo: """创建测试主播数据""" logger.info("检查是否存在测试主播...") # 查询主播 result = await db.execute(select(StreamerInfo).where(StreamerInfo.streamer_id == "TEST_001")) streamer = result.scalar_one_or_none() if streamer: logger.info(f"✅ 测试主播已存在: {streamer.streamer_id}") return streamer # 创建测试主播 logger.info("创建测试主播...") streamer = StreamerInfo( streamer_id="TEST_001", streamer_name="测试主播001", entity_type="individual", tax_registration_no="TAX_TEST_001", id_card_no="110101199001011234", phone_number="13800138001", bank_account_no="6222021234567890123", bank_name="中国工商银行", status="active" ) db.add(streamer) await db.flush() await db.refresh(streamer) logger.info(f"✅ 测试主播创建成功: {streamer.streamer_id}") return streamer async def test_create_task(db: AsyncSession, user: User) -> DetectionTask: """测试创建任务""" logger.info("\n" + "="*60) logger.info("测试1: 创建检测任务") logger.info("="*60) task_manager = TaskManager(db) try: task = await task_manager.create_task( task_name="测试任务 - TEST_001收入完整性检测", task_type=TaskType.ON_DEMAND, entity_ids=["TEST_001"], entity_type="streamer", period="2024-01", rule_ids=["REVENUE_INTEGRITY_CHECK"], parameters={ "streamer_id": "TEST_001", "comparison_type": "monthly" } ) logger.info(f"✅ 任务创建成功!") logger.info(f" 任务ID: {task.task_id}") logger.info(f" 任务名称: {task.task_name}") logger.info(f" 任务类型: {task.task_type.value}") logger.info(f" 实体类型: {task.entity_type}") logger.info(f" 检测期间: {task.period}") logger.info(f" 状态: {task.status.value}") return task except Exception as e: logger.error(f"❌ 任务创建失败: {e}") raise async def test_execute_task(db: AsyncSession, task_id: str): """测试执行任务""" logger.info("\n" + "="*60) logger.info("测试2: 执行检测任务") logger.info("="*60) task_manager = TaskManager(db) try: logger.info(f"开始执行任务: {task_id}") result = await task_manager.execute_task(task_id) logger.info(f"✅ 任务执行完成!") logger.info(f" 任务ID: {result['task_id']}") logger.info(f" 状态: {result['status']}") logger.info(f" 执行时间: {result['executed_at']}") if 'summary' in result: summary = result['summary'] logger.info(f"\n📊 检测结果汇总:") logger.info(f" 总实体数: {summary.get('total_entities', 0)}") logger.info(f" 总检测数: {summary.get('total_detections', 0)}") logger.info(f" 平均风险评分: {summary.get('avg_score', 0)}") risk_dist = summary.get('risk_distribution', {}) if risk_dist: logger.info(f" 风险分布:") for level, count in risk_dist.items(): logger.info(f" - {level}: {count}") return result except Exception as e: logger.error(f"❌ 任务执行失败: {e}") import traceback traceback.print_exc() raise async def test_list_tasks(db: AsyncSession): """测试查询任务列表""" logger.info("\n" + "="*60) logger.info("测试3: 查询任务列表") logger.info("="*60) task_manager = TaskManager(db) try: tasks = await task_manager.list_tasks(limit=10) logger.info(f"✅ 查询到 {len(tasks)} 个任务:") for task in tasks: logger.info(f"\n 任务ID: {task.task_id}") logger.info(f" 任务名称: {task.task_name}") logger.info(f" 状态: {task.status.value}") logger.info(f" 创建时间: {task.created_at}") except Exception as e: logger.error(f"❌ 查询任务列表失败: {e}") raise async def test_get_task_detail(db: AsyncSession, task_id: str): """测试获取任务详情""" logger.info("\n" + "="*60) logger.info("测试4: 获取任务详情") logger.info("="*60) task_manager = TaskManager(db) try: task = await task_manager.get_task(task_id) if not task: logger.error(f"❌ 任务不存在: {task_id}") return None logger.info(f"✅ 任务详情:") logger.info(f" 任务ID: {task.task_id}") logger.info(f" 任务名称: {task.task_name}") logger.info(f" 任务类型: {task.task_type.value}") logger.info(f" 状态: {task.status.value}") logger.info(f" 实体类型: {task.entity_type}") logger.info(f" 检测期间: {task.period}") logger.info(f" 总实体数: {task.total_entities}") logger.info(f" 已处理实体数: {task.processed_entities}") logger.info(f" 结果数量: {task.result_count}") logger.info(f" 创建时间: {task.created_at}") logger.info(f" 开始时间: {task.started_at}") logger.info(f" 完成时间: {task.completed_at}") if task.summary: logger.info(f"\n📊 任务汇总:") logger.info(json.dumps(task.summary, indent=2, ensure_ascii=False)) if task.error_message: logger.info(f"\n⚠️ 错误信息:") logger.info(task.error_message) return task except Exception as e: logger.error(f"❌ 获取任务详情失败: {e}") raise async def main(): """主函数""" logger.info("="*60) logger.info("风控检测模块测试") logger.info("主播: TEST_001") logger.info("算法: 收入完整性检测") logger.info("="*60) async with AsyncSessionLocal() as db: try: # 1. 创建测试用户 user = await create_test_user(db) await db.commit() # 2. 创建测试主播 streamer = await create_test_streamer(db) await db.commit() # 3. 创建任务 task = await test_create_task(db, user) await db.commit() # 4. 执行任务 result = await test_execute_task(db, task.task_id) await db.commit() # 5. 查询任务列表 await test_list_tasks(db) # 6. 获取任务详情 await test_get_task_detail(db, task.task_id) logger.info("\n" + "="*60) logger.info("✅ 所有测试完成!") logger.info("="*60) except Exception as e: logger.error(f"\n❌ 测试失败: {e}") import traceback traceback.print_exc() await db.rollback() sys.exit(1) if __name__ == "__main__": asyncio.run(main())