deep-risk/backend/tests/test_risk_detection_workflow.py
2025-12-14 20:08:27 +08:00

283 lines
8.7 KiB
Python

#!/usr/bin/env python3
"""
风控检测模块测试脚本
演示完整的检测流程:创建任务 -> 执行任务 -> 查看结果
"""
import asyncio
import json
import sys
from datetime import datetime
from pathlib import Path
# 添加项目根目录到Python路径
sys.path.insert(0, str(Path(__file__).parent))
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, text
from app.database import AsyncSessionLocal, engine
from app.models.user import User
from app.models.streamer import StreamerInfo
from app.models.risk_detection import DetectionTask, TaskType, TaskStatus
from app.services.risk_detection.task_manager.task_manager import TaskManager
from app.utils.helpers import get_password_hash
from loguru import logger
async def create_test_user(db: AsyncSession) -> User:
"""创建测试用户"""
logger.info("检查是否存在测试用户...")
# 查询用户
result = await db.execute(select(User).where(User.username == "test_user"))
user = result.scalar_one_or_none()
if user:
logger.info(f"✅ 测试用户已存在: {user.username}")
return user
# 创建测试用户
logger.info("创建测试用户...")
user = User(
username="test_user",
password=get_password_hash("test123"),
nickname="测试用户",
email="test@example.com",
phone="13800138000",
status=True,
is_superuser=False,
entity_id="TEST_001",
entity_type="streamer"
)
db.add(user)
await db.flush()
await db.refresh(user)
logger.info(f"✅ 测试用户创建成功: {user.username} (ID: {user.id})")
return user
async def create_test_streamer(db: AsyncSession) -> StreamerInfo:
"""创建测试主播数据"""
logger.info("检查是否存在测试主播...")
# 查询主播
result = await db.execute(select(StreamerInfo).where(StreamerInfo.streamer_id == "TEST_001"))
streamer = result.scalar_one_or_none()
if streamer:
logger.info(f"✅ 测试主播已存在: {streamer.streamer_id}")
return streamer
# 创建测试主播
logger.info("创建测试主播...")
streamer = StreamerInfo(
streamer_id="TEST_001",
streamer_name="测试主播001",
entity_type="individual",
tax_registration_no="TAX_TEST_001",
id_card_no="110101199001011234",
phone_number="13800138001",
bank_account_no="6222021234567890123",
bank_name="中国工商银行",
status="active"
)
db.add(streamer)
await db.flush()
await db.refresh(streamer)
logger.info(f"✅ 测试主播创建成功: {streamer.streamer_id}")
return streamer
async def test_create_task(db: AsyncSession, user: User) -> DetectionTask:
"""测试创建任务"""
logger.info("\n" + "="*60)
logger.info("测试1: 创建检测任务")
logger.info("="*60)
task_manager = TaskManager(db)
try:
task = await task_manager.create_task(
task_name="测试任务 - TEST_001收入完整性检测",
task_type=TaskType.ON_DEMAND,
entity_ids=["TEST_001"],
entity_type="streamer",
period="2024-01",
rule_ids=["REVENUE_INTEGRITY_CHECK"],
parameters={
"streamer_id": "TEST_001",
"comparison_type": "monthly"
}
)
logger.info(f"✅ 任务创建成功!")
logger.info(f" 任务ID: {task.task_id}")
logger.info(f" 任务名称: {task.task_name}")
logger.info(f" 任务类型: {task.task_type.value}")
logger.info(f" 实体类型: {task.entity_type}")
logger.info(f" 检测期间: {task.period}")
logger.info(f" 状态: {task.status.value}")
return task
except Exception as e:
logger.error(f"❌ 任务创建失败: {e}")
raise
async def test_execute_task(db: AsyncSession, task_id: str):
"""测试执行任务"""
logger.info("\n" + "="*60)
logger.info("测试2: 执行检测任务")
logger.info("="*60)
task_manager = TaskManager(db)
try:
logger.info(f"开始执行任务: {task_id}")
result = await task_manager.execute_task(task_id)
logger.info(f"✅ 任务执行完成!")
logger.info(f" 任务ID: {result['task_id']}")
logger.info(f" 状态: {result['status']}")
logger.info(f" 执行时间: {result['executed_at']}")
if 'summary' in result:
summary = result['summary']
logger.info(f"\n📊 检测结果汇总:")
logger.info(f" 总实体数: {summary.get('total_entities', 0)}")
logger.info(f" 总检测数: {summary.get('total_detections', 0)}")
logger.info(f" 平均风险评分: {summary.get('avg_score', 0)}")
risk_dist = summary.get('risk_distribution', {})
if risk_dist:
logger.info(f" 风险分布:")
for level, count in risk_dist.items():
logger.info(f" - {level}: {count}")
return result
except Exception as e:
logger.error(f"❌ 任务执行失败: {e}")
import traceback
traceback.print_exc()
raise
async def test_list_tasks(db: AsyncSession):
"""测试查询任务列表"""
logger.info("\n" + "="*60)
logger.info("测试3: 查询任务列表")
logger.info("="*60)
task_manager = TaskManager(db)
try:
tasks = await task_manager.list_tasks(limit=10)
logger.info(f"✅ 查询到 {len(tasks)} 个任务:")
for task in tasks:
logger.info(f"\n 任务ID: {task.task_id}")
logger.info(f" 任务名称: {task.task_name}")
logger.info(f" 状态: {task.status.value}")
logger.info(f" 创建时间: {task.created_at}")
except Exception as e:
logger.error(f"❌ 查询任务列表失败: {e}")
raise
async def test_get_task_detail(db: AsyncSession, task_id: str):
"""测试获取任务详情"""
logger.info("\n" + "="*60)
logger.info("测试4: 获取任务详情")
logger.info("="*60)
task_manager = TaskManager(db)
try:
task = await task_manager.get_task(task_id)
if not task:
logger.error(f"❌ 任务不存在: {task_id}")
return None
logger.info(f"✅ 任务详情:")
logger.info(f" 任务ID: {task.task_id}")
logger.info(f" 任务名称: {task.task_name}")
logger.info(f" 任务类型: {task.task_type.value}")
logger.info(f" 状态: {task.status.value}")
logger.info(f" 实体类型: {task.entity_type}")
logger.info(f" 检测期间: {task.period}")
logger.info(f" 总实体数: {task.total_entities}")
logger.info(f" 已处理实体数: {task.processed_entities}")
logger.info(f" 结果数量: {task.result_count}")
logger.info(f" 创建时间: {task.created_at}")
logger.info(f" 开始时间: {task.started_at}")
logger.info(f" 完成时间: {task.completed_at}")
if task.summary:
logger.info(f"\n📊 任务汇总:")
logger.info(json.dumps(task.summary, indent=2, ensure_ascii=False))
if task.error_message:
logger.info(f"\n⚠️ 错误信息:")
logger.info(task.error_message)
return task
except Exception as e:
logger.error(f"❌ 获取任务详情失败: {e}")
raise
async def main():
"""主函数"""
logger.info("="*60)
logger.info("风控检测模块测试")
logger.info("主播: TEST_001")
logger.info("算法: 收入完整性检测")
logger.info("="*60)
async with AsyncSessionLocal() as db:
try:
# 1. 创建测试用户
user = await create_test_user(db)
await db.commit()
# 2. 创建测试主播
streamer = await create_test_streamer(db)
await db.commit()
# 3. 创建任务
task = await test_create_task(db, user)
await db.commit()
# 4. 执行任务
result = await test_execute_task(db, task.task_id)
await db.commit()
# 5. 查询任务列表
await test_list_tasks(db)
# 6. 获取任务详情
await test_get_task_detail(db, task.task_id)
logger.info("\n" + "="*60)
logger.info("✅ 所有测试完成!")
logger.info("="*60)
except Exception as e:
logger.error(f"\n❌ 测试失败: {e}")
import traceback
traceback.print_exc()
await db.rollback()
sys.exit(1)
if __name__ == "__main__":
asyncio.run(main())