deep-risk/backend/app/api/v1/endpoints/contract.py
2025-12-14 20:08:27 +08:00

291 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
分成协议API路由
"""
from typing import Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select, func, and_, or_
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import datetime
from loguru import logger
from app.database import get_async_session
from app.models.contract import Contract
from app.schemas.contract import (
ContractCreate,
ContractUpdate,
ContractResponse,
ContractListResponse,
)
# Router holding the revenue-sharing contract endpoints defined below
# (list / detail / create / update / soft-delete).
# NOTE(review): presumably included under a URL prefix by the v1 API aggregator — confirm.
router = APIRouter()
@router.get("", response_model=ContractListResponse)
async def list_contracts(
    page: int = Query(1, ge=1, description="页码"),
    size: int = Query(10, ge=1, le=100, description="每页数量"),
    contract_id: Optional[str] = Query(None, description="协议ID"),
    streamer_id: Optional[str] = Query(None, description="主播ID"),
    contract_type: Optional[str] = Query(None, description="协议类型"),
    contract_status: Optional[str] = Query(None, description="协议状态"),
    db: AsyncSession = Depends(get_async_session),
):
    """List revenue-sharing contracts with optional filters (paginated).

    Args:
        page: 1-based page number.
        size: Number of records per page (1-100).
        contract_id: Case-insensitive substring filter on ``Contract.contract_id``.
            ``%`` and ``_`` in the input are matched literally, not as wildcards.
        streamer_id: Exact-match filter on ``Contract.streamer_id``.
        contract_type: Exact-match filter on ``Contract.contract_type``.
        contract_status: Exact-match filter on ``Contract.contract_status``.
            Soft-deleted contracts (status ``"terminated"``) are still listed
            unless excluded via this filter.
        db: Async database session (injected).

    Returns:
        ContractListResponse holding the requested page of records and the
        total number of contracts matching the filters.
    """
    logger.info(f"获取分成协议列表: page={page}, size={size}")

    # Collect filter conditions; absent/empty query params add no condition.
    conditions = []
    if contract_id:
        # Escape LIKE metacharacters ("/" is the escape character) so the
        # user's text is matched literally inside the %...% substring pattern.
        escaped = (
            contract_id.replace("/", "//").replace("%", "/%").replace("_", "/_")
        )
        conditions.append(Contract.contract_id.ilike(f"%{escaped}%", escape="/"))
    if streamer_id:
        conditions.append(Contract.streamer_id == streamer_id)
    if contract_type:
        conditions.append(Contract.contract_type == contract_type)
    if contract_status:
        conditions.append(Contract.contract_status == contract_status)

    query = select(Contract)
    count_query = select(func.count()).select_from(Contract)
    if conditions:
        query = query.where(and_(*conditions))
        count_query = count_query.where(and_(*conditions))

    # COUNT(*) always yields exactly one row.
    total = (await db.execute(count_query)).scalar_one()

    # Without ORDER BY the database may return rows in any order, so
    # OFFSET/LIMIT pages could overlap or skip rows between requests.
    # Order by Contract.id (presumably the primary key — confirm) to make
    # pagination deterministic.
    query = query.order_by(Contract.id).offset((page - 1) * size).limit(size)
    records = (await db.execute(query)).scalars().all()

    def _iso(value):
        # Serialize a date/datetime as ISO-8601; falsy values become None.
        return value.isoformat() if value else None

    response_records = [
        {
            "id": c.id,
            "contract_id": c.contract_id,
            "contract_no": c.contract_no,
            "contract_type": c.contract_type,
            "streamer_id": c.streamer_id,
            "streamer_name": c.streamer_name,
            "streamer_entity_type": c.streamer_entity_type,
            "platform_party": c.platform_party,
            "platform_credit_code": c.platform_credit_code,
            "revenue_type": c.revenue_type,
            "platform_ratio": c.platform_ratio,
            "streamer_ratio": c.streamer_ratio,
            "settlement_cycle": c.settlement_cycle,
            "contract_start_date": _iso(c.contract_start_date),
            "contract_end_date": _iso(c.contract_end_date),
            "contract_status": c.contract_status,
            "created_at": _iso(c.created_at),
        }
        for c in records
    ]

    return ContractListResponse(
        records=response_records,
        total=total,
        page=page,
        size=size,
    )
@router.get("/{contract_id}", response_model=ContractResponse)
async def get_contract(contract_id: str, db: AsyncSession = Depends(get_async_session)):
    """Return the full details of one revenue-sharing contract.

    Args:
        contract_id: Business identifier matched exactly against
            ``Contract.contract_id``.
        db: Async database session (injected).

    Returns:
        A dict shaped like ContractResponse. Date/datetime fields are rendered
        as ISO-8601 strings, or None when unset.

    Raises:
        HTTPException: 404 when no contract has the given ``contract_id``.
    """
    logger.info(f"获取分成协议详情: {contract_id}")

    stmt = select(Contract).where(Contract.contract_id == contract_id)
    found = (await db.execute(stmt)).scalar_one_or_none()
    if not found:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"分成协议不存在: {contract_id}",
        )

    def as_iso(value):
        # ISO-8601 text for a date/datetime; falsy values pass through as None.
        return value.isoformat() if value else None

    return {
        "id": found.id,
        "contract_id": found.contract_id,
        "contract_no": found.contract_no,
        "contract_type": found.contract_type,
        "streamer_id": found.streamer_id,
        "streamer_name": found.streamer_name,
        "streamer_entity_type": found.streamer_entity_type,
        "platform_party": found.platform_party,
        "platform_credit_code": found.platform_credit_code,
        "revenue_type": found.revenue_type,
        "platform_ratio": found.platform_ratio,
        "streamer_ratio": found.streamer_ratio,
        "settlement_cycle": found.settlement_cycle,
        "contract_start_date": as_iso(found.contract_start_date),
        "contract_end_date": as_iso(found.contract_end_date),
        "contract_status": found.contract_status,
        "created_at": as_iso(found.created_at),
    }
@router.post("", response_model=ContractResponse, status_code=status.HTTP_201_CREATED)
async def create_contract(contract: ContractCreate, db: AsyncSession = Depends(get_async_session)):
    """Create a new revenue-sharing contract.

    Server-generated values:
      * ``contract_id``: ``"CTR"`` + current year + 6-digit sequence
        (number of existing contracts + 1).
      * ``contract_no`` (only if not supplied): ``"CTR"`` + ``YYYYMMDD`` +
        4-digit sequence.
      * ``streamer_id`` (only if not supplied): ``"STR"`` + 5 digits derived
        from a stable SHA-256 digest of ``streamer_name``.

    Defaults for missing optional fields: ``streamer_entity_type`` ->
    ``"individual"``, ``platform_credit_code`` -> ``"PLACEHOLDER"``,
    ``contract_status`` -> ``"active"``.

    Args:
        contract: Validated creation payload.
        db: Async database session (injected).

    Returns:
        A dict shaped like ContractResponse for the newly persisted row.
    """
    import hashlib  # local import: only needed for the stable streamer-id fallback

    logger.info(f"创建分成协议: {contract.streamer_name}")

    # NOTE(review): count-based sequencing is racy. Two concurrent requests can
    # read the same count and generate the same IDs. A DB sequence, or a unique
    # constraint plus retry, would be safer.
    count = (await db.execute(select(func.count()).select_from(Contract))).scalar_one()
    seq = count + 1
    now = datetime.now()

    # Take the year from the clock (it was a hard-coded literal) so contract_id
    # agrees with the date prefix used for contract_no.
    contract_id = f"CTR{now.year}{seq:06d}"
    contract_no = contract.contract_no or f"CTR{now.strftime('%Y%m%d')}{seq:04d}"

    if contract.streamer_id:
        streamer_id = contract.streamer_id
    else:
        # Built-in hash() of a str is salted per process (PYTHONHASHSEED), so it
        # yields different IDs for the same name across restarts and workers.
        # A SHA-256 digest is stable everywhere.
        # NOTE(review): 5 digits (100000 values) makes collisions between
        # different names plausible.
        digest = hashlib.sha256((contract.streamer_name or "").encode("utf-8")).hexdigest()
        streamer_id = f"STR{int(digest, 16) % 100000:05d}"

    streamer_entity_type = contract.streamer_entity_type or "individual"
    platform_credit_code = contract.platform_credit_code or "PLACEHOLDER"
    contract_status = contract.contract_status or "active"

    new_contract = Contract(
        contract_id=contract_id,
        contract_no=contract_no,
        contract_type=contract.contract_type,
        streamer_id=streamer_id,
        streamer_name=contract.streamer_name,
        streamer_entity_type=streamer_entity_type,
        platform_party=contract.platform_party,
        platform_credit_code=platform_credit_code,
        revenue_type=contract.revenue_type,
        platform_ratio=contract.platform_ratio,
        streamer_ratio=contract.streamer_ratio,
        settlement_cycle=contract.settlement_cycle,
        contract_start_date=contract.contract_start_date,
        contract_end_date=contract.contract_end_date,
        contract_status=contract_status,
        remark=contract.remark,
    )
    db.add(new_contract)
    await db.commit()
    # Reload so DB-populated columns (e.g. id, created_at) are available.
    await db.refresh(new_contract)

    def _iso(value):
        # Serialize a date/datetime as ISO-8601; falsy values become None.
        return value.isoformat() if value else None

    return {
        "id": new_contract.id,
        "contract_id": new_contract.contract_id,
        "contract_no": new_contract.contract_no,
        "contract_type": new_contract.contract_type,
        "streamer_id": new_contract.streamer_id,
        "streamer_name": new_contract.streamer_name,
        "streamer_entity_type": new_contract.streamer_entity_type,
        "platform_party": new_contract.platform_party,
        "platform_credit_code": new_contract.platform_credit_code,
        "revenue_type": new_contract.revenue_type,
        "platform_ratio": new_contract.platform_ratio,
        "streamer_ratio": new_contract.streamer_ratio,
        "settlement_cycle": new_contract.settlement_cycle,
        "contract_start_date": _iso(new_contract.contract_start_date),
        "contract_end_date": _iso(new_contract.contract_end_date),
        "contract_status": new_contract.contract_status,
        "created_at": _iso(new_contract.created_at),
    }
@router.put("/{contract_id}", response_model=ContractResponse)
async def update_contract(
    contract_id: str,
    contract_update: ContractUpdate,
    db: AsyncSession = Depends(get_async_session),
):
    """Partially update a revenue-sharing contract.

    Only fields explicitly present in the request body (``exclude_unset``) are
    written. Keys that are not attributes of the ORM model are silently
    ignored.

    Args:
        contract_id: Business identifier matched exactly against
            ``Contract.contract_id``.
        contract_update: Partial update payload.
        db: Async database session (injected).

    Returns:
        A dict shaped like ContractResponse reflecting the committed row.

    Raises:
        HTTPException: 404 when no contract has the given ``contract_id``.
    """
    logger.info(f"更新分成协议: {contract_id}")

    lookup = select(Contract).where(Contract.contract_id == contract_id)
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"分成协议不存在: {contract_id}",
        )

    changes = contract_update.model_dump(exclude_unset=True)
    for field_name, new_value in changes.items():
        # Skip payload keys the model does not define.
        if not hasattr(target, field_name):
            continue
        setattr(target, field_name, new_value)

    await db.commit()
    await db.refresh(target)

    def to_iso(value):
        # ISO-8601 text for a date/datetime; falsy values pass through as None.
        return value.isoformat() if value else None

    return {
        "id": target.id,
        "contract_id": target.contract_id,
        "contract_no": target.contract_no,
        "contract_type": target.contract_type,
        "streamer_id": target.streamer_id,
        "streamer_name": target.streamer_name,
        "streamer_entity_type": target.streamer_entity_type,
        "platform_party": target.platform_party,
        "platform_credit_code": target.platform_credit_code,
        "revenue_type": target.revenue_type,
        "platform_ratio": target.platform_ratio,
        "streamer_ratio": target.streamer_ratio,
        "settlement_cycle": target.settlement_cycle,
        "contract_start_date": to_iso(target.contract_start_date),
        "contract_end_date": to_iso(target.contract_end_date),
        "contract_status": target.contract_status,
        "created_at": to_iso(target.created_at),
    }
@router.delete("/{contract_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_contract(contract_id: str, db: AsyncSession = Depends(get_async_session)):
    """Soft-delete a revenue-sharing contract.

    The row is kept. Its ``contract_status`` is set to ``"terminated"`` and
    committed. Repeating the call on an already-terminated contract succeeds
    again.

    Args:
        contract_id: Business identifier matched exactly against
            ``Contract.contract_id``.
        db: Async database session (injected).

    Raises:
        HTTPException: 404 when no contract has the given ``contract_id``.
    """
    logger.info(f"删除分成协议: {contract_id}")

    stmt = select(Contract).where(Contract.contract_id == contract_id)
    record = (await db.execute(stmt)).scalar_one_or_none()
    if record:
        # Soft delete: mark as terminated instead of removing the row.
        record.contract_status = "terminated"
        await db.commit()
        return None

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"分成协议不存在: {contract_id}",
    )