404 lines
14 KiB
Python
404 lines
14 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
HTTP接口端到端测试脚本
|
||
测试风险检测模块的完整API流程
|
||
"""
|
||
import asyncio
|
||
import json
|
||
import sys
|
||
from pathlib import Path
|
||
from typing import Dict, Any, Optional
|
||
|
||
import httpx
|
||
from loguru import logger
|
||
|
||
# 添加项目根目录到Python路径
|
||
sys.path.insert(0, str(Path(__file__).parent))
|
||
|
||
from app.config import settings
|
||
|
||
|
||
class RiskDetectionAPITester:
|
||
"""风险检测API测试器"""
|
||
|
||
def __init__(self, base_url: str = "http://localhost:8000"):
|
||
self.base_url = base_url.rstrip("/")
|
||
self.api_v1 = f"{self.base_url}/api/v1"
|
||
self.client = httpx.AsyncClient(timeout=30.0)
|
||
self.access_token: Optional[str] = None
|
||
self.headers: Dict[str, str] = {
|
||
"Content-Type": "application/json",
|
||
"Accept": "application/json"
|
||
}
|
||
|
||
async def __aenter__(self):
|
||
return self
|
||
|
||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||
await self.client.aclose()
|
||
|
||
async def login(self, username: str = "test_user", password: str = "test123") -> bool:
|
||
"""
|
||
用户登录获取token
|
||
|
||
Args:
|
||
username: 用户名
|
||
password: 密码
|
||
|
||
Returns:
|
||
登录是否成功
|
||
"""
|
||
logger.info("="*70)
|
||
logger.info("步骤1: 用户登录")
|
||
logger.info("="*70)
|
||
|
||
try:
|
||
# 使用 OAuth2 密码表单格式
|
||
data = {
|
||
"username": username,
|
||
"password": password,
|
||
"grant_type": "password", # OAuth2 需要此字段,且必须为"password"
|
||
"scope": "",
|
||
"client_id": "",
|
||
"client_secret": ""
|
||
}
|
||
|
||
response = await self.client.post(
|
||
f"{self.api_v1}/auth/login",
|
||
data=data,
|
||
headers={"Content-Type": "application/x-www-form-urlencoded"}
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
result = response.json()
|
||
self.access_token = result.get("access_token")
|
||
self.headers["Authorization"] = f"Bearer {self.access_token}"
|
||
|
||
logger.info("✅ 登录成功!")
|
||
logger.info(f" 用户ID: {result['user_info']['id']}")
|
||
logger.info(f" 用户名: {result['user_info']['username']}")
|
||
logger.info(f" Token: {self.access_token[:20]}...")
|
||
logger.info(f" 过期时间: {result['expires_in']}秒")
|
||
return True
|
||
else:
|
||
logger.error(f"❌ 登录失败: {response.status_code}")
|
||
logger.error(f" 响应: {response.text}")
|
||
return False
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 登录异常: {e}")
|
||
return False
|
||
|
||
async def get_algorithms(self) -> bool:
|
||
"""
|
||
获取算法列表
|
||
|
||
Returns:
|
||
是否成功
|
||
"""
|
||
logger.info("\n" + "="*70)
|
||
logger.info("步骤2: 获取算法列表")
|
||
logger.info("="*70)
|
||
|
||
try:
|
||
response = await self.client.get(
|
||
f"{self.api_v1}/risk-detection/algorithms",
|
||
headers=self.headers
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
algorithms = response.json()
|
||
logger.info(f"✅ 获取算法列表成功,共 {len(algorithms)} 个算法:")
|
||
for algo in algorithms:
|
||
logger.info(f" - {algo['code']}: {algo['name']}")
|
||
return True
|
||
else:
|
||
logger.error(f"❌ 获取算法列表失败: {response.status_code}")
|
||
logger.error(f" 响应: {response.text}")
|
||
return False
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 获取算法列表异常: {e}")
|
||
return False
|
||
|
||
async def execute_detection(self, entity_id: str = "TEST_001", period: str = "2024-01") -> Optional[str]:
|
||
"""
|
||
执行即时检测
|
||
|
||
Args:
|
||
entity_id: 实体ID
|
||
period: 检测期间
|
||
|
||
Returns:
|
||
任务ID,失败返回None
|
||
"""
|
||
logger.info("\n" + "="*70)
|
||
logger.info("步骤3: 执行即时检测")
|
||
logger.info("="*70)
|
||
|
||
try:
|
||
request_data = {
|
||
"task_name": f"HTTP接口测试-{entity_id}-{period}",
|
||
"entity_id": entity_id,
|
||
"entity_type": "streamer",
|
||
"period": period,
|
||
"rule_ids": ["REVENUE_INTEGRITY_CHECK"],
|
||
"parameters": {
|
||
"streamer_id": entity_id,
|
||
"comparison_type": "monthly"
|
||
}
|
||
}
|
||
|
||
logger.info(f"发送检测请求...")
|
||
logger.info(f" 实体ID: {entity_id}")
|
||
logger.info(f" 检测期间: {period}")
|
||
logger.info(f" 算法: REVENUE_INTEGRITY_CHECK")
|
||
|
||
response = await self.client.post(
|
||
f"{self.api_v1}/risk-detection/execute",
|
||
json=request_data,
|
||
headers=self.headers
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
result = response.json()
|
||
task_id = result.get("task_id")
|
||
summary = result.get("summary", {})
|
||
|
||
logger.info("✅ 检测执行成功!")
|
||
logger.info(f" 任务ID: {task_id}")
|
||
logger.info(f" 状态: {result['status']}")
|
||
logger.info(f" 检测结果数: {result['result_count']}")
|
||
|
||
if summary:
|
||
logger.info(f" 总实体数: {summary.get('total_entities', 0)}")
|
||
logger.info(f" 总检测数: {summary.get('total_detections', 0)}")
|
||
logger.info(f" 平均风险评分: {summary.get('avg_score', 0)}")
|
||
|
||
risk_dist = summary.get('risk_distribution', {})
|
||
if risk_dist:
|
||
logger.info(f" 风险分布: {risk_dist}")
|
||
|
||
return task_id
|
||
else:
|
||
logger.error(f"❌ 检测执行失败: {response.status_code}")
|
||
logger.error(f" 响应: {response.text}")
|
||
return None
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 检测执行异常: {e}")
|
||
return None
|
||
|
||
async def get_task_details(self, task_id: str) -> bool:
|
||
"""
|
||
获取任务详情
|
||
|
||
Args:
|
||
task_id: 任务ID
|
||
|
||
Returns:
|
||
是否成功
|
||
"""
|
||
logger.info("\n" + "="*70)
|
||
logger.info("步骤4: 获取任务详情")
|
||
logger.info("="*70)
|
||
|
||
try:
|
||
response = await self.client.get(
|
||
f"{self.api_v1}/risk-detection/tasks/{task_id}",
|
||
headers=self.headers
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
task = response.json()
|
||
logger.info(f"✅ 获取任务详情成功:")
|
||
logger.info(f" 任务ID: {task['task_id']}")
|
||
logger.info(f" 任务名称: {task['task_name']}")
|
||
logger.info(f" 状态: {task['status']}")
|
||
logger.info(f" 实体类型: {task['entity_type']}")
|
||
logger.info(f" 检测期间: {task['period']}")
|
||
logger.info(f" 总实体数: {task['total_entities']}")
|
||
logger.info(f" 已处理实体数: {task['processed_entities']}")
|
||
logger.info(f" 结果数量: {task.get('result_count', 0)}")
|
||
|
||
if task.get('summary'):
|
||
logger.info(f"\n📊 任务汇总:")
|
||
logger.info(json.dumps(task['summary'], indent=2, ensure_ascii=False))
|
||
|
||
if task.get('error_message'):
|
||
logger.info(f"\n⚠️ 错误信息: {task['error_message']}")
|
||
|
||
return True
|
||
else:
|
||
logger.error(f"❌ 获取任务详情失败: {response.status_code}")
|
||
logger.error(f" 响应: {response.text}")
|
||
return False
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 获取任务详情异常: {e}")
|
||
return False
|
||
|
||
async def list_tasks(self, limit: int = 10) -> bool:
|
||
"""
|
||
列出任务列表
|
||
|
||
Args:
|
||
limit: 返回数量限制
|
||
|
||
Returns:
|
||
是否成功
|
||
"""
|
||
logger.info("\n" + "="*70)
|
||
logger.info("步骤5: 列出任务列表")
|
||
logger.info("="*70)
|
||
|
||
try:
|
||
response = await self.client.get(
|
||
f"{self.api_v1}/risk-detection/tasks?limit={limit}",
|
||
headers=self.headers
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
tasks = response.json()
|
||
logger.info(f"✅ 获取任务列表成功,共 {len(tasks)} 个任务:")
|
||
|
||
for i, task in enumerate(tasks, 1):
|
||
logger.info(f"\n 任务 {i}:")
|
||
logger.info(f" ID: {task['task_id']}")
|
||
logger.info(f" 名称: {task['task_name']}")
|
||
logger.info(f" 状态: {task['status']}")
|
||
logger.info(f" 期间: {task['period']}")
|
||
|
||
return True
|
||
else:
|
||
logger.error(f"❌ 获取任务列表失败: {response.status_code}")
|
||
logger.error(f" 响应: {response.text}")
|
||
return False
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 获取任务列表异常: {e}")
|
||
return False
|
||
|
||
async def get_results(self, task_id: str, limit: int = 10) -> bool:
|
||
"""
|
||
获取检测结果
|
||
|
||
Args:
|
||
task_id: 任务ID
|
||
limit: 返回数量限制
|
||
|
||
Returns:
|
||
是否成功
|
||
"""
|
||
logger.info("\n" + "="*70)
|
||
logger.info("步骤6: 获取检测结果")
|
||
logger.info("="*70)
|
||
|
||
try:
|
||
response = await self.client.get(
|
||
f"{self.api_v1}/risk-detection/results?task_id={task_id}&limit={limit}",
|
||
headers=self.headers
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
results = response.json()
|
||
logger.info(f"✅ 获取检测结果成功,共 {len(results)} 条结果:")
|
||
|
||
for i, result in enumerate(results, 1):
|
||
logger.info(f"\n 结果 {i}:")
|
||
logger.info(f" 任务ID: {result['task_id']}")
|
||
logger.info(f" 实体ID: {result['result']['entity_id']}")
|
||
logger.info(f" 风险等级: {result['result']['risk_level']}")
|
||
logger.info(f" 风险评分: {result['result']['risk_score']}")
|
||
logger.info(f" 描述: {result['result']['description'][:50]}...")
|
||
|
||
return True
|
||
else:
|
||
logger.error(f"❌ 获取检测结果失败: {response.status_code}")
|
||
logger.error(f" 响应: {response.text}")
|
||
return False
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 获取检测结果异常: {e}")
|
||
return False
|
||
|
||
async def test_complete_workflow(self) -> bool:
|
||
"""
|
||
测试完整的API流程
|
||
|
||
Returns:
|
||
是否成功
|
||
"""
|
||
logger.info("\n" + "="*70)
|
||
logger.info("🚀 开始HTTP接口端到端测试")
|
||
logger.info("="*70)
|
||
logger.info(f"API地址: {self.base_url}")
|
||
logger.info(f"测试实体: TEST_001")
|
||
logger.info(f"测试算法: 收入完整性检测")
|
||
|
||
# 1. 登录
|
||
if not await self.login():
|
||
logger.error("❌ 测试失败:登录失败")
|
||
return False
|
||
|
||
# 2. 获取算法列表
|
||
if not await self.get_algorithms():
|
||
logger.warning("⚠️ 警告:获取算法列表失败")
|
||
|
||
# 3. 执行检测
|
||
task_id = await self.execute_detection()
|
||
if not task_id:
|
||
logger.error("❌ 测试失败:执行检测失败")
|
||
return False
|
||
|
||
# 等待任务完成
|
||
logger.info("\n⏳ 等待任务完成...")
|
||
await asyncio.sleep(2)
|
||
|
||
# 4. 获取任务详情
|
||
if not await self.get_task_details(task_id):
|
||
logger.warning("⚠️ 警告:获取任务详情失败")
|
||
|
||
# 5. 列出任务列表
|
||
if not await self.list_tasks():
|
||
logger.warning("⚠️ 警告:获取任务列表失败")
|
||
|
||
# 6. 获取检测结果
|
||
if not await self.get_results(task_id):
|
||
logger.warning("⚠️ 警告:获取检测结果失败")
|
||
|
||
# 测试完成
|
||
logger.info("\n" + "="*70)
|
||
logger.info("✨ HTTP接口端到端测试完成!")
|
||
logger.info("="*70)
|
||
return True
|
||
|
||
|
||
async def main():
|
||
"""主函数"""
|
||
logger.info("风险检测模块 - HTTP接口端到端测试")
|
||
logger.info("="*70)
|
||
|
||
# 检查服务是否启动
|
||
logger.info("检查后端服务状态...")
|
||
try:
|
||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||
response = await client.get("http://localhost:8000/health")
|
||
if response.status_code == 200:
|
||
logger.info("✅ 后端服务运行正常")
|
||
else:
|
||
logger.warning(f"⚠️ 后端服务状态异常: {response.status_code}")
|
||
except Exception as e:
|
||
logger.error(f"❌ 无法连接到后端服务: {e}")
|
||
logger.error("请先启动后端服务: python -m uvicorn app.main:app --host 0.0.0.0 --port 8000")
|
||
sys.exit(1)
|
||
|
||
# 运行测试
|
||
async with RiskDetectionAPITester() as tester:
|
||
success = await tester.test_complete_workflow()
|
||
sys.exit(0 if success else 1)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
asyncio.run(main())
|