deep-risk/backend/app/utils/helpers.py
2025-12-14 20:08:27 +08:00

215 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

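"""
辅助函数和工具
包括密码处理、JWT token处理等

Helper functions and utilities: password hashing, JWT access/refresh token
creation and verification, FastAPI authentication dependencies, input-format
validators (email, Chinese mobile number, Chinese ID card number, unified
social credit code), pagination parameters and file-size formatting.
"""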
from datetime import datetime, timedelta, timezone
from typing import Any, Optional, Union

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.database import get_async_session
from app.models.user import User
# Password hashing context shared by verify_password / get_password_hash.
# NOTE: pbkdf2_sha256 is used instead of bcrypt to avoid compatibility
# problems with Python 3.14 (per the original author's note).
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
# OAuth2 bearer-token scheme used as the token dependency of get_current_user.
# tokenUrl points at this API's login endpoint ({API_V1_STR}/auth/login).
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl=f"{settings.API_V1_STR}/auth/login"
)
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Check a plaintext password against a stored hash.

    Args:
        plain_password: The password as supplied by the user.
        hashed_password: The stored hash produced by ``get_password_hash``.

    Returns:
        Whether ``plain_password`` matches ``hashed_password``, as reported by
        the module-level ``pwd_context``.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password: str) -> str:
    """
    Hash a plaintext password for storage.

    Args:
        password: The plaintext password.

    Returns:
        The hash string produced by the module-level ``pwd_context``.
    """
    hashed = pwd_context.hash(password)
    return hashed
def create_access_token(
    subject: Union[str, Any], expires_delta: Optional[timedelta] = None
) -> str:
    """
    Create a signed JWT access token.

    Args:
        subject: Identity stored in the ``sub`` claim; converted with ``str()``.
        expires_delta: Token lifetime. When ``None`` the lifetime is
            ``settings.JWT_EXPIRE_MINUTES`` minutes. An explicit
            ``timedelta(0)`` is honoured (it is no longer replaced by the
            default).

    Returns:
        The encoded JWT with claims ``exp``, ``sub`` and ``type="access"``,
        signed with ``settings.JWT_SECRET_KEY`` / ``settings.JWT_ALGORITHM``.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
    # Timezone-aware UTC "now"; datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode = {"exp": expire, "sub": str(subject), "type": "access"}
    return jwt.encode(
        to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
    )
def create_refresh_token(subject: Union[str, Any]) -> str:
    """
    Create a signed JWT refresh token.

    Args:
        subject: Identity stored in the ``sub`` claim; converted with ``str()``.

    Returns:
        The encoded JWT with claims ``exp`` (now + ``JWT_REFRESH_EXPIRE_DAYS``
        days), ``sub`` and ``type="refresh"``, signed with
        ``settings.JWT_SECRET_KEY`` / ``settings.JWT_ALGORITHM``.
    """
    # Timezone-aware UTC "now"; datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + timedelta(
        days=settings.JWT_REFRESH_EXPIRE_DAYS
    )
    to_encode = {"exp": expire, "sub": str(subject), "type": "refresh"}
    return jwt.encode(
        to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
    )
def verify_token(token: str, token_type: str = "access") -> Optional[str]:
    """
    Decode a JWT and return its subject if it is valid and of the expected type.

    Args:
        token: The encoded JWT.
        token_type: Expected value of the token's ``type`` claim
            (``"access"`` or ``"refresh"`` for tokens created in this module).

    Returns:
        The ``sub`` claim, or ``None`` when decoding raises ``JWTError``,
        the ``sub`` claim is missing, or the ``type`` claim does not equal
        ``token_type``.
    """
    try:
        claims = jwt.decode(
            token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
        )
    except JWTError:
        return None

    sub = claims.get("sub")
    if sub is None:
        return None
    if claims.get("type") != token_type:
        return None
    return sub
async def get_current_user(
    db: AsyncSession = Depends(get_async_session), token: str = Depends(oauth2_scheme)
) -> User:
    """
    FastAPI dependency: resolve the logged-in user from a bearer access token.

    Args:
        db: Async database session (injected).
        token: Bearer token extracted by ``oauth2_scheme`` (injected).

    Returns:
        The ``User`` whose id is the token's ``sub`` claim.

    Raises:
        HTTPException: 401 if the token fails ``verify_token(token, "access")``,
            its subject is not an integer id, or no user has that id;
            400 ("Inactive user") if ``user.status`` is falsy
            (presumably an active/enabled flag — confirm against the model).
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    subject = verify_token(token, "access")
    if subject is None:
        raise credentials_exception
    try:
        user_id = int(subject)
    except (TypeError, ValueError):
        # A correctly signed token whose subject is not an integer id must be
        # rejected as bad credentials (401), not crash with a 500.
        raise credentials_exception from None
    user = await User.get_by_id(db, user_id)
    if user is None:
        raise credentials_exception
    if not user.status:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user"
        )
    return user
async def get_current_active_superuser(
    current_user: User = Depends(get_current_user),
) -> User:
    """
    FastAPI dependency: require the current user to be a superuser.

    Args:
        current_user: The user resolved by ``get_current_user`` (injected).

    Returns:
        ``current_user`` when its ``is_superuser`` attribute is truthy.

    Raises:
        HTTPException: 400 with detail "The user doesn't have enough
            privileges" otherwise.
    """
    if current_user.is_superuser:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="The user doesn't have enough privileges",
    )
def is_valid_email(email: str) -> bool:
    """
    Check that a string looks like an email address (basic format check).

    Accepts ``local@domain.tld`` where local is ``[A-Za-z0-9._%+-]+``, domain
    is ``[A-Za-z0-9.-]+`` and the TLD has at least two ASCII letters.

    Args:
        email: The candidate address.

    Returns:
        True if the WHOLE string matches the pattern, otherwise False.
    """
    import re

    # fullmatch instead of match + "$": "$" also matches just before a
    # trailing "\n", so "a@b.com\n" used to be accepted.
    pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
    return re.fullmatch(pattern, email) is not None
def is_valid_phone(phone: str) -> bool:
    """
    Check that a string is a mainland-China mobile number format.

    Accepts exactly 11 ASCII digits starting with ``1`` followed by ``3``-``9``.

    Args:
        phone: The candidate number, without country code or separators.

    Returns:
        True if the WHOLE string matches, otherwise False.
    """
    import re

    # fullmatch rejects a trailing "\n" that "$" would allow; re.ASCII keeps
    # "\d" from matching non-ASCII Unicode digits.
    pattern = r"1[3-9]\d{9}"
    return re.fullmatch(pattern, phone, re.ASCII) is not None
def is_valid_id_card(id_card: str) -> bool:
    """
    Check that a string has the format of an 18-character Chinese ID card number.

    Format only: 6-digit region code (not starting with 0), birth year
    (18xx, 19xx, 2xxx or 3xxx), month 01-12, day 01-31, 3 sequence digits and a
    final digit or ``X``/``x``. The check digit itself is NOT verified.

    Args:
        id_card: The candidate ID number.

    Returns:
        True if the WHOLE string matches the format, otherwise False.
    """
    import re

    # fullmatch rejects a trailing "\n" that "$" would allow; re.ASCII keeps
    # "\d" from matching non-ASCII Unicode digits.
    pattern = r"[1-9]\d{5}(18|19|([23]\d))\d{2}((0[1-9])|(10|11|12))(([0-2][1-9])|10|20|30|31)\d{3}[0-9Xx]"
    return re.fullmatch(pattern, id_card, re.ASCII) is not None
def is_valid_social_credit_code(code: str) -> bool:
    """
    Check that a string has the format of a Chinese unified social credit code.

    Format only: 18 characters — 2 from the allowed alphabet (digits and
    uppercase letters excluding I, O, S, V, Z), 6 ASCII digits, then 10 more
    from the allowed alphabet. The check character is NOT verified.

    Args:
        code: The candidate code (uppercase).

    Returns:
        True if the WHOLE string matches the format, otherwise False.
    """
    import re

    # fullmatch rejects a trailing "\n" that "$" would allow; re.ASCII keeps
    # "\d" from matching non-ASCII Unicode digits.
    pattern = r"[0-9A-HJ-NPQRTUWXY]{2}\d{6}[0-9A-HJ-NPQRTUWXY]{10}"
    return re.fullmatch(pattern, code, re.ASCII) is not None
def paginate_params(page: int = 1, size: int = 10) -> tuple[int, int]:
    """
    Normalize page-number pagination into an offset/limit pair.

    ``page`` values below 1 become 1. A ``size`` below 1 falls back to 10, and
    sizes above 100 are capped at 100.

    Args:
        page: 1-based page number.
        size: Requested page size.

    Returns:
        ``(skip, limit)`` where ``skip = (page - 1) * limit``.
    """
    page = max(page, 1)
    if size < 1:
        size = 10
    limit = min(size, 100)
    return (page - 1) * limit, limit
def format_file_size(size_bytes: int) -> str:
    """
    Format a byte count as a human-readable string using 1024-based units.

    Examples: 0 -> "0B", 1 -> "1.0 B", 1536 -> "1.5 KB", 1024**5 -> "1.0 PB".

    Args:
        size_bytes: Number of bytes. Negative values keep a leading "-" and
            their magnitude is scaled.

    Returns:
        The value divided by the largest fitting power of 1024 (up to YB),
        rounded to two decimals, followed by a space and the unit. Zero is
        returned as "0B" (no space), exactly as in the original implementation.
    """
    if size_bytes == 0:
        return "0B"
    size_names = ["B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"]
    sign = "-" if size_bytes < 0 else ""
    magnitude = abs(size_bytes)
    # Choose the unit with exact integer comparisons instead of math.log:
    # log raised ValueError for negatives, can misjudge exact powers of 1024
    # through float rounding, and the old unit list overflowed (IndexError)
    # at >= 1024**5 bytes. Sizes past the last unit stay in that unit.
    i = 0
    while i < len(size_names) - 1 and magnitude >= 1024 ** (i + 1):
        i += 1
    s = round(magnitude / 1024 ** i, 2)
    return f"{sign}{s} {size_names[i]}"
def get_client_ip(request) -> str:
    """
    Return the address of the directly connected client.

    Args:
        request: A Starlette/FastAPI ``Request`` (anything exposing ``.client``
            that is either ``None`` or has a ``.host`` attribute).

    Returns:
        ``request.client.host``, or ``"unknown"`` when ``request.client`` is
        ``None`` (Starlette leaves it unset when the server supplies no peer
        address), instead of raising ``AttributeError``.

    NOTE: this is the socket peer. Behind a reverse proxy it is the proxy's
    address. Proxy headers such as X-Forwarded-For are deliberately not read
    here, because clients can forge them unless a trusted proxy sets them.
    """
    client = request.client
    if client is None:
        return "unknown"
    return client.host