"""
Helper functions and utilities.

Includes password hashing, JWT token creation/verification, FastAPI
authentication dependencies, input-format validators and small formatting
helpers.
"""
import re
from datetime import datetime, timedelta, timezone
from typing import Any, Optional, Union

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.database import get_async_session
from app.models.user import User

# Password hashing context.
# NOTE: pbkdf2_sha256 is used instead of bcrypt to avoid compatibility
# problems on Python 3.14.
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")

# OAuth2 password-flow bearer scheme; tokenUrl points at the login endpoint.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl=f"{settings.API_V1_STR}/auth/login"
)

# Validator patterns, compiled once. They are used with ``fullmatch`` (so a
# trailing "\n" cannot slip past a ``$`` anchor) and ``re.ASCII`` (so ``\d``
# only matches 0-9, not full-width or other Unicode digits).
_EMAIL_RE = re.compile(
    r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", re.ASCII
)
_PHONE_RE = re.compile(r"1[3-9]\d{9}", re.ASCII)
_ID_CARD_RE = re.compile(
    r"[1-9]\d{5}(18|19|([23]\d))\d{2}((0[1-9])|(10|11|12))"
    r"(([0-2][1-9])|10|20|30|31)\d{3}[0-9Xx]",
    re.ASCII,
)
_SOCIAL_CREDIT_CODE_RE = re.compile(
    r"[0-9A-HJ-NPQRTUWXY]{2}\d{6}[0-9A-HJ-NPQRTUWXY]{10}", re.ASCII
)

# Units used by format_file_size, smallest first.
_SIZE_UNITS = ("B", "KB", "MB", "GB", "TB", "PB")


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if *plain_password* matches *hashed_password*."""
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
    """Return the hash of *password* produced by ``pwd_context``."""
    return pwd_context.hash(password)


def _create_token(
    subject: Union[str, Any], lifetime: timedelta, token_type: str
) -> str:
    """Encode a signed JWT with ``sub``, ``type`` and ``exp`` = now + *lifetime*."""
    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + lifetime
    to_encode = {"exp": expire, "sub": str(subject), "type": token_type}
    return jwt.encode(
        to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
    )


def create_access_token(
    subject: Union[str, Any], expires_delta: Optional[timedelta] = None
) -> str:
    """
    Create an access token for *subject*.

    Args:
        subject: Value stored (as ``str``) in the ``sub`` claim.
        expires_delta: Token lifetime; when None, defaults to
            ``settings.JWT_EXPIRE_MINUTES`` minutes.

    Returns:
        The encoded JWT with ``type`` set to ``"access"``.
    """
    # Compare against None explicitly so an explicit timedelta(0) is honoured
    # rather than silently replaced by the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
    return _create_token(subject, expires_delta, "access")


def create_refresh_token(subject: Union[str, Any]) -> str:
    """
    Create a refresh token for *subject*.

    The lifetime is ``settings.JWT_REFRESH_EXPIRE_DAYS`` days and the
    ``type`` claim is ``"refresh"``.
    """
    return _create_token(
        subject, timedelta(days=settings.JWT_REFRESH_EXPIRE_DAYS), "refresh"
    )


def verify_token(token: str, token_type: str = "access") -> Optional[str]:
    """
    Decode *token* and return its subject.

    Returns None when ``jwt.decode`` raises ``JWTError``, when the ``sub``
    claim is missing, or when the ``type`` claim differs from *token_type*
    (so a refresh token cannot be used where an access token is expected).
    """
    try:
        payload = jwt.decode(
            token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
        )
    except JWTError:
        return None
    subject: Optional[str] = payload.get("sub")
    if subject is None or payload.get("type") != token_type:
        return None
    return subject


async def get_current_user(
    db: AsyncSession = Depends(get_async_session),
    token: str = Depends(oauth2_scheme)
) -> User:
    """
    FastAPI dependency returning the user identified by the bearer token.

    Raises:
        HTTPException: 401 if the token is invalid, its subject is not a
            numeric user id, or no such user exists; 400 if the user's
            ``status`` is falsy (inactive).
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    subject = verify_token(token, "access")
    if subject is None:
        raise credentials_exception
    try:
        user_id = int(subject)
    except ValueError:
        # A correctly signed token with a non-numeric subject is bad
        # credentials, not a server error.
        raise credentials_exception from None
    user = await User.get_by_id(db, user_id)
    if user is None:
        raise credentials_exception
    if not user.status:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user"
        )
    return user


async def get_current_active_superuser(
    current_user: User = Depends(get_current_user),
) -> User:
    """
    FastAPI dependency returning the current user only if they are a superuser.

    Raises:
        HTTPException: 400 when ``current_user.is_superuser`` is falsy.
    """
    if not current_user.is_superuser:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="The user doesn't have enough privileges",
        )
    return current_user


def is_valid_email(email: str) -> bool:
    """Return True if *email* matches a simple ASCII e-mail pattern."""
    return _EMAIL_RE.fullmatch(email) is not None


def is_valid_phone(phone: str) -> bool:
    """Return True if *phone* looks like an 11-digit mainland China mobile number."""
    return _PHONE_RE.fullmatch(phone) is not None


def is_valid_id_card(id_card: str) -> bool:
    """
    Return True if *id_card* has the format of an 18-character Chinese
    resident ID number.

    Only the format is checked; the final check character is not verified.
    """
    return _ID_CARD_RE.fullmatch(id_card) is not None


def is_valid_social_credit_code(code: str) -> bool:
    """
    Return True if *code* has the format of an 18-character Chinese unified
    social credit code.

    Only the format is checked; the final check character is not verified.
    """
    return _SOCIAL_CREDIT_CODE_RE.fullmatch(code) is not None


def paginate_params(page: int = 1, size: int = 10) -> tuple[int, int]:
    """
    Normalise pagination parameters and convert them to ``(skip, limit)``.

    ``page`` below 1 becomes 1; ``size`` below 1 becomes 10 and is capped
    at 100.
    """
    if page < 1:
        page = 1
    if size < 1:
        size = 10
    if size > 100:
        size = 100
    skip = (page - 1) * size
    limit = size
    return skip, limit


def format_file_size(size_bytes: int) -> str:
    """
    Format a byte count as a human-readable string, e.g. ``"1.5 KB"``.

    Zero is rendered as ``"0B"`` (kept for backward compatibility). Values
    beyond the largest unit are expressed in that unit.

    Raises:
        ValueError: if *size_bytes* is negative.
    """
    if size_bytes < 0:
        raise ValueError(f"size_bytes must be non-negative, got {size_bytes}")
    if size_bytes == 0:
        return "0B"
    # Integer comparisons instead of math.log avoid float rounding at exact
    # powers of 1024; the bound on the index prevents an IndexError for very
    # large sizes.
    index = 0
    while index < len(_SIZE_UNITS) - 1 and size_bytes >= 1024 ** (index + 1):
        index += 1
    value = round(size_bytes / 1024 ** index, 2)
    return f"{value} {_SIZE_UNITS[index]}"


def get_client_ip(request) -> str:
    """
    Return the client host of *request*.

    Returns ``"unknown"`` when ``request.client`` is None (the server did not
    provide peer information) instead of raising AttributeError.
    """
    client = request.client
    if client is None:
        return "unknown"
    return client.host