283 lines
8.7 KiB
Python
283 lines
8.7 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
风控检测模块测试脚本
|
|
演示完整的检测流程:创建任务 -> 执行任务 -> 查看结果
|
|
"""
|
|
import asyncio
|
|
import json
|
|
import sys
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
# 添加项目根目录到Python路径
|
|
sys.path.insert(0, str(Path(__file__).parent))
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, text
|
|
from app.database import AsyncSessionLocal, engine
|
|
from app.models.user import User
|
|
from app.models.streamer import StreamerInfo
|
|
from app.models.risk_detection import DetectionTask, TaskType, TaskStatus
|
|
from app.services.risk_detection.task_manager.task_manager import TaskManager
|
|
from app.utils.helpers import get_password_hash
|
|
from loguru import logger
|
|
|
|
|
|
async def create_test_user(db: AsyncSession) -> User:
|
|
"""创建测试用户"""
|
|
logger.info("检查是否存在测试用户...")
|
|
|
|
# 查询用户
|
|
result = await db.execute(select(User).where(User.username == "test_user"))
|
|
user = result.scalar_one_or_none()
|
|
|
|
if user:
|
|
logger.info(f"✅ 测试用户已存在: {user.username}")
|
|
return user
|
|
|
|
# 创建测试用户
|
|
logger.info("创建测试用户...")
|
|
user = User(
|
|
username="test_user",
|
|
password=get_password_hash("test123"),
|
|
nickname="测试用户",
|
|
email="test@example.com",
|
|
phone="13800138000",
|
|
status=True,
|
|
is_superuser=False,
|
|
entity_id="TEST_001",
|
|
entity_type="streamer"
|
|
)
|
|
|
|
db.add(user)
|
|
await db.flush()
|
|
await db.refresh(user)
|
|
|
|
logger.info(f"✅ 测试用户创建成功: {user.username} (ID: {user.id})")
|
|
return user
|
|
|
|
|
|
async def create_test_streamer(db: AsyncSession) -> StreamerInfo:
|
|
"""创建测试主播数据"""
|
|
logger.info("检查是否存在测试主播...")
|
|
|
|
# 查询主播
|
|
result = await db.execute(select(StreamerInfo).where(StreamerInfo.streamer_id == "TEST_001"))
|
|
streamer = result.scalar_one_or_none()
|
|
|
|
if streamer:
|
|
logger.info(f"✅ 测试主播已存在: {streamer.streamer_id}")
|
|
return streamer
|
|
|
|
# 创建测试主播
|
|
logger.info("创建测试主播...")
|
|
streamer = StreamerInfo(
|
|
streamer_id="TEST_001",
|
|
streamer_name="测试主播001",
|
|
entity_type="individual",
|
|
tax_registration_no="TAX_TEST_001",
|
|
id_card_no="110101199001011234",
|
|
phone_number="13800138001",
|
|
bank_account_no="6222021234567890123",
|
|
bank_name="中国工商银行",
|
|
status="active"
|
|
)
|
|
|
|
db.add(streamer)
|
|
await db.flush()
|
|
await db.refresh(streamer)
|
|
|
|
logger.info(f"✅ 测试主播创建成功: {streamer.streamer_id}")
|
|
return streamer
|
|
|
|
|
|
async def test_create_task(db: AsyncSession, user: User) -> DetectionTask:
|
|
"""测试创建任务"""
|
|
logger.info("\n" + "="*60)
|
|
logger.info("测试1: 创建检测任务")
|
|
logger.info("="*60)
|
|
|
|
task_manager = TaskManager(db)
|
|
|
|
try:
|
|
task = await task_manager.create_task(
|
|
task_name="测试任务 - TEST_001收入完整性检测",
|
|
task_type=TaskType.ON_DEMAND,
|
|
entity_ids=["TEST_001"],
|
|
entity_type="streamer",
|
|
period="2024-01",
|
|
rule_ids=["REVENUE_INTEGRITY_CHECK"],
|
|
parameters={
|
|
"streamer_id": "TEST_001",
|
|
"comparison_type": "monthly"
|
|
}
|
|
)
|
|
|
|
logger.info(f"✅ 任务创建成功!")
|
|
logger.info(f" 任务ID: {task.task_id}")
|
|
logger.info(f" 任务名称: {task.task_name}")
|
|
logger.info(f" 任务类型: {task.task_type.value}")
|
|
logger.info(f" 实体类型: {task.entity_type}")
|
|
logger.info(f" 检测期间: {task.period}")
|
|
logger.info(f" 状态: {task.status.value}")
|
|
|
|
return task
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ 任务创建失败: {e}")
|
|
raise
|
|
|
|
|
|
async def test_execute_task(db: AsyncSession, task_id: str):
|
|
"""测试执行任务"""
|
|
logger.info("\n" + "="*60)
|
|
logger.info("测试2: 执行检测任务")
|
|
logger.info("="*60)
|
|
|
|
task_manager = TaskManager(db)
|
|
|
|
try:
|
|
logger.info(f"开始执行任务: {task_id}")
|
|
result = await task_manager.execute_task(task_id)
|
|
|
|
logger.info(f"✅ 任务执行完成!")
|
|
logger.info(f" 任务ID: {result['task_id']}")
|
|
logger.info(f" 状态: {result['status']}")
|
|
logger.info(f" 执行时间: {result['executed_at']}")
|
|
|
|
if 'summary' in result:
|
|
summary = result['summary']
|
|
logger.info(f"\n📊 检测结果汇总:")
|
|
logger.info(f" 总实体数: {summary.get('total_entities', 0)}")
|
|
logger.info(f" 总检测数: {summary.get('total_detections', 0)}")
|
|
logger.info(f" 平均风险评分: {summary.get('avg_score', 0)}")
|
|
|
|
risk_dist = summary.get('risk_distribution', {})
|
|
if risk_dist:
|
|
logger.info(f" 风险分布:")
|
|
for level, count in risk_dist.items():
|
|
logger.info(f" - {level}: {count}")
|
|
|
|
return result
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ 任务执行失败: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
raise
|
|
|
|
|
|
async def test_list_tasks(db: AsyncSession):
|
|
"""测试查询任务列表"""
|
|
logger.info("\n" + "="*60)
|
|
logger.info("测试3: 查询任务列表")
|
|
logger.info("="*60)
|
|
|
|
task_manager = TaskManager(db)
|
|
|
|
try:
|
|
tasks = await task_manager.list_tasks(limit=10)
|
|
|
|
logger.info(f"✅ 查询到 {len(tasks)} 个任务:")
|
|
for task in tasks:
|
|
logger.info(f"\n 任务ID: {task.task_id}")
|
|
logger.info(f" 任务名称: {task.task_name}")
|
|
logger.info(f" 状态: {task.status.value}")
|
|
logger.info(f" 创建时间: {task.created_at}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ 查询任务列表失败: {e}")
|
|
raise
|
|
|
|
|
|
async def test_get_task_detail(db: AsyncSession, task_id: str):
|
|
"""测试获取任务详情"""
|
|
logger.info("\n" + "="*60)
|
|
logger.info("测试4: 获取任务详情")
|
|
logger.info("="*60)
|
|
|
|
task_manager = TaskManager(db)
|
|
|
|
try:
|
|
task = await task_manager.get_task(task_id)
|
|
|
|
if not task:
|
|
logger.error(f"❌ 任务不存在: {task_id}")
|
|
return None
|
|
|
|
logger.info(f"✅ 任务详情:")
|
|
logger.info(f" 任务ID: {task.task_id}")
|
|
logger.info(f" 任务名称: {task.task_name}")
|
|
logger.info(f" 任务类型: {task.task_type.value}")
|
|
logger.info(f" 状态: {task.status.value}")
|
|
logger.info(f" 实体类型: {task.entity_type}")
|
|
logger.info(f" 检测期间: {task.period}")
|
|
logger.info(f" 总实体数: {task.total_entities}")
|
|
logger.info(f" 已处理实体数: {task.processed_entities}")
|
|
logger.info(f" 结果数量: {task.result_count}")
|
|
logger.info(f" 创建时间: {task.created_at}")
|
|
logger.info(f" 开始时间: {task.started_at}")
|
|
logger.info(f" 完成时间: {task.completed_at}")
|
|
|
|
if task.summary:
|
|
logger.info(f"\n📊 任务汇总:")
|
|
logger.info(json.dumps(task.summary, indent=2, ensure_ascii=False))
|
|
|
|
if task.error_message:
|
|
logger.info(f"\n⚠️ 错误信息:")
|
|
logger.info(task.error_message)
|
|
|
|
return task
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ 获取任务详情失败: {e}")
|
|
raise
|
|
|
|
|
|
async def main():
|
|
"""主函数"""
|
|
logger.info("="*60)
|
|
logger.info("风控检测模块测试")
|
|
logger.info("主播: TEST_001")
|
|
logger.info("算法: 收入完整性检测")
|
|
logger.info("="*60)
|
|
|
|
async with AsyncSessionLocal() as db:
|
|
try:
|
|
# 1. 创建测试用户
|
|
user = await create_test_user(db)
|
|
await db.commit()
|
|
|
|
# 2. 创建测试主播
|
|
streamer = await create_test_streamer(db)
|
|
await db.commit()
|
|
|
|
# 3. 创建任务
|
|
task = await test_create_task(db, user)
|
|
await db.commit()
|
|
|
|
# 4. 执行任务
|
|
result = await test_execute_task(db, task.task_id)
|
|
await db.commit()
|
|
|
|
# 5. 查询任务列表
|
|
await test_list_tasks(db)
|
|
|
|
# 6. 获取任务详情
|
|
await test_get_task_detail(db, task.task_id)
|
|
|
|
logger.info("\n" + "="*60)
|
|
logger.info("✅ 所有测试完成!")
|
|
logger.info("="*60)
|
|
|
|
except Exception as e:
|
|
logger.error(f"\n❌ 测试失败: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
await db.rollback()
|
|
sys.exit(1)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
asyncio.run(main())
|