215 lines
5.2 KiB
Python
215 lines
5.2 KiB
Python
"""
|
||
辅助函数和工具
|
||
包括密码处理、JWT token处理等
|
||
"""
|
||
from datetime import datetime, timedelta
|
||
from typing import Any, Optional, Union
|
||
|
||
from fastapi import Depends, HTTPException, status
|
||
from fastapi.security import OAuth2PasswordBearer
|
||
from jose import JWTError, jwt
|
||
from passlib.context import CryptContext
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.config import settings
|
||
from app.database import get_async_session
|
||
from app.models.user import User
|
||
|
||
# Password hashing context.
# NOTE: pbkdf2_sha256 is used instead of bcrypt to avoid compatibility
# problems with Python 3.14.
pwd_context = CryptContext(deprecated="auto", schemes=["pbkdf2_sha256"])

# OAuth2 bearer-token scheme; its token URL is the API's login endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login")
|
||
|
||
|
||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Check a plaintext password against a stored password hash.

    Args:
        plain_password: The password as supplied by the user.
        hashed_password: The stored hash to compare against.

    Returns:
        True if ``pwd_context`` reports a match, otherwise False.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
|
||
|
||
|
||
def get_password_hash(password: str) -> str:
    """
    Hash a plaintext password using the module's ``pwd_context``.

    Args:
        password: The plaintext password to hash.

    Returns:
        The hash string produced by ``pwd_context.hash``.
    """
    hashed = pwd_context.hash(password)
    return hashed
|
||
|
||
|
||
def create_access_token(
    subject: Union[str, Any], expires_delta: Optional[timedelta] = None
) -> str:
    """
    Create a signed JWT access token.

    Args:
        subject: Value stored in the ``sub`` claim (converted with ``str()``).
        expires_delta: Token lifetime. When ``None`` the lifetime is
            ``settings.JWT_EXPIRE_MINUTES`` minutes. An explicit value, even
            ``timedelta(0)``, is honoured rather than treated as "not given".

    Returns:
        The encoded token carrying ``exp``, ``sub`` and ``type="access"``.
    """
    # Local import keeps this block self-contained. A timezone-aware "now"
    # replaces datetime.utcnow(), which is deprecated since Python 3.12.
    from datetime import timezone

    # Compare against None explicitly: a zero timedelta is falsy but is a
    # deliberate caller choice, not a request for the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.JWT_EXPIRE_MINUTES)
    expire = datetime.now(timezone.utc) + expires_delta

    to_encode = {"exp": expire, "sub": str(subject), "type": "access"}
    return jwt.encode(
        to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
    )
|
||
|
||
|
||
def create_refresh_token(
    subject: Union[str, Any], expires_delta: Optional[timedelta] = None
) -> str:
    """
    Create a signed JWT refresh token.

    Args:
        subject: Value stored in the ``sub`` claim (converted with ``str()``).
        expires_delta: Optional token lifetime. When ``None`` (the default)
            the lifetime is ``settings.JWT_REFRESH_EXPIRE_DAYS`` days.

    Returns:
        The encoded token carrying ``exp``, ``sub`` and ``type="refresh"``.
    """
    # Local import keeps this block self-contained. A timezone-aware "now"
    # replaces datetime.utcnow(), which is deprecated since Python 3.12.
    from datetime import timezone

    if expires_delta is None:
        expires_delta = timedelta(days=settings.JWT_REFRESH_EXPIRE_DAYS)
    expire = datetime.now(timezone.utc) + expires_delta

    to_encode = {"exp": expire, "sub": str(subject), "type": "refresh"}
    return jwt.encode(
        to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
    )
|
||
|
||
|
||
def verify_token(token: str, token_type: str = "access") -> Optional[str]:
    """
    Decode a JWT and return its subject if it has the expected type.

    Args:
        token: The encoded JWT.
        token_type: Required value of the token's ``type`` claim
            (this module issues ``"access"`` and ``"refresh"`` tokens).

    Returns:
        The ``sub`` claim, or None when decoding raises ``JWTError``, the
        ``sub`` claim is missing, or the ``type`` claim does not match.
    """
    try:
        claims = jwt.decode(
            token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
        )
    except JWTError:
        return None

    subject = claims.get("sub")
    if subject is not None and claims.get("type") == token_type:
        return subject
    return None
|
||
|
||
|
||
async def get_current_user(
    db: AsyncSession = Depends(get_async_session), token: str = Depends(oauth2_scheme)
) -> User:
    """
    Resolve the currently authenticated user from a bearer access token.

    Args:
        db: Async database session provided by ``get_async_session``.
        token: Bearer token provided by ``oauth2_scheme``.

    Returns:
        The ``User`` whose id equals the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when the token is not a valid access token, its
            subject is not an integer id, or no matching user exists;
            400 ("Inactive user") when the user's ``status`` is falsy.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    subject = verify_token(token, "access")
    if subject is None:
        raise credentials_exception

    # A correctly signed token whose subject is not numeric would otherwise
    # raise ValueError here and surface as a 500; treat it as bad credentials.
    try:
        user_id = int(subject)
    except (TypeError, ValueError):
        raise credentials_exception from None

    user = await User.get_by_id(db, user_id)
    if user is None:
        raise credentials_exception

    if not user.status:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user"
        )

    return user
|
||
|
||
|
||
async def get_current_active_superuser(
    current_user: User = Depends(get_current_user),
) -> User:
    """
    Return the current user, requiring superuser rights.

    Args:
        current_user: The authenticated user provided by ``get_current_user``.

    Returns:
        ``current_user`` when its ``is_superuser`` flag is truthy.

    Raises:
        HTTPException: 400 when ``current_user.is_superuser`` is falsy.
    """
    if current_user.is_superuser:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="The user doesn't have enough privileges",
    )
|
||
|
||
|
||
def is_valid_email(email: str) -> bool:
    """
    Check whether ``email`` has a plausible e-mail address format.

    This is a pragmatic pattern check (local part, ``@``, a domain and an
    alphabetic top-level label of at least two letters), not full RFC 5322
    validation.

    Args:
        email: The string to check.

    Returns:
        True if the entire string matches the pattern.
    """
    import re

    # fullmatch instead of match + "$": "$" also matches just before a
    # trailing newline, so "a@b.com\n" used to be accepted.
    pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
    return re.fullmatch(pattern, email) is not None
|
||
|
||
|
||
def is_valid_phone(phone: str) -> bool:
    """
    Check whether ``phone`` looks like a mainland-China mobile number.

    The accepted format is exactly 11 ASCII digits: ``1``, then a digit
    from 3-9, then nine more digits.

    Args:
        phone: The string to check.

    Returns:
        True if the entire string matches the format.
    """
    import re

    # fullmatch rejects a trailing newline (which "$" would tolerate), and
    # re.ASCII stops \d from accepting non-ASCII digits such as full-width ones.
    pattern = r"1[3-9]\d{9}"
    return re.fullmatch(pattern, phone, re.ASCII) is not None
|
||
|
||
|
||
def is_valid_id_card(id_card: str) -> bool:
    """
    Check the format of an 18-character Chinese resident ID number.

    The pattern requires:

    - a 6-digit prefix whose first digit is 1-9;
    - a year starting with 18, 19, 2 or 3;
    - a month 01-12;
    - a day 01-31 (month lengths are not checked);
    - a 3-digit sequence;
    - a final character 0-9, ``X`` or ``x``.

    The check character is not verified arithmetically.

    Args:
        id_card: The string to check.

    Returns:
        True if the entire string matches the format.
    """
    import re

    # fullmatch rejects a trailing newline (which "$" would tolerate), and
    # re.ASCII stops \d from accepting non-ASCII digits.
    pattern = r"[1-9]\d{5}(18|19|([23]\d))\d{2}((0[1-9])|(10|11|12))(([0-2][1-9])|10|20|30|31)\d{3}[0-9Xx]"
    return re.fullmatch(pattern, id_card, re.ASCII) is not None
|
||
|
||
|
||
def is_valid_social_credit_code(code: str) -> bool:
    """
    Check the format of an 18-character Chinese unified social credit code.

    The code is made of three parts:

    - two characters from the allowed set;
    - six ASCII digits;
    - ten more characters from the allowed set.

    The allowed set is digits plus uppercase letters except I, O, S, V and Z.
    The check character is not verified arithmetically.

    Args:
        code: The string to check.

    Returns:
        True if the entire string matches the format.
    """
    import re

    # fullmatch rejects a trailing newline (which "$" would tolerate), and
    # re.ASCII stops \d from accepting non-ASCII digits.
    pattern = r"[0-9A-HJ-NPQRTUWXY]{2}\d{6}[0-9A-HJ-NPQRTUWXY]{10}"
    return re.fullmatch(pattern, code, re.ASCII) is not None
|
||
|
||
|
||
def paginate_params(page: int = 1, size: int = 10) -> tuple[int, int]:
    """
    Normalize pagination inputs and convert them to an offset/limit pair.

    A ``page`` below 1 is treated as 1. A ``size`` below 1 falls back to 10,
    and a ``size`` above 100 is capped at 100.

    Args:
        page: 1-based page number.
        size: Requested page size.

    Returns:
        ``(skip, limit)``: the number of rows to skip and the page size.
    """
    current_page = max(page, 1)
    page_size = 10 if size < 1 else min(size, 100)
    return (current_page - 1) * page_size, page_size
|
||
|
||
|
||
def format_file_size(size_bytes: int) -> str:
    """
    Format a byte count as a human-readable string with 1024-based units.

    Examples: ``0`` -> ``"0B"``, ``1536`` -> ``"1.5 KB"``,
    ``1024 ** 2`` -> ``"1.0 MB"``. Values are rounded to two decimals.

    Args:
        size_bytes: Number of bytes. Negative values get a leading ``-``.
            Values beyond the largest unit are expressed in that unit.

    Returns:
        The formatted size.
    """
    if size_bytes == 0:
        return "0B"

    size_names = ["B", "KB", "MB", "GB", "TB", "PB", "EB"]
    magnitude = abs(size_bytes)

    # Pick the largest unit whose base does not exceed the value. Exact powers
    # are compared instead of using math.log, which raised for negatives, is
    # subject to floating-point error, and yielded index -1 for values below 1.
    # Clamp to the last unit instead of running past the end of size_names.
    index = 0
    while index < len(size_names) - 1 and magnitude >= 1024 ** (index + 1):
        index += 1

    value = round(magnitude / 1024 ** index, 2)
    sign = "-" if size_bytes < 0 else ""
    return f"{sign}{value} {size_names[index]}"
|
||
|
||
|
||
def get_client_ip(request) -> str:
    """
    Return the client host recorded on the request.

    Args:
        request: Presumably a Starlette/FastAPI ``Request``. Any object with
            a ``client`` attribute exposing ``host`` works.

    Returns:
        ``request.client.host``, or ``"unknown"`` when no client address is
        available. Starlette's ``Request.client`` may be None.

    Note:
        This is the directly connected peer. Proxy headers such as
        ``X-Forwarded-For`` are not consulted.
    """
    client = getattr(request, "client", None)
    if client is None:
        return "unknown"
    return client.host
|