128 lines
3.0 KiB
Python
128 lines
3.0 KiB
Python
"""
|
|
基础模型类
|
|
包含所有表的通用字段和方法
|
|
"""
|
|
from datetime import datetime
|
|
from typing import Any, Dict, Optional
|
|
|
|
from sqlalchemy import (
|
|
Column,
|
|
DateTime,
|
|
Integer,
|
|
String,
|
|
Text,
|
|
func,
|
|
select,
|
|
)
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
|
|
|
from app.database import Base
|
|
|
|
|
|
class TimestampedMixin:
    """
    Mixin that adds creation and last-update timestamp columns.

    Both columns are timezone-aware and NOT NULL. Their initial value is
    supplied by the database (``server_default=func.now()``); ``updated_at``
    is additionally set to ``func.now()`` whenever SQLAlchemy emits an
    UPDATE for the row without an explicit value for that column.
    """

    # Row creation time, filled in by the database on INSERT.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        comment="创建时间",
        nullable=False,
        server_default=func.now(),
    )

    # Last modification time: defaulted by the database on INSERT and
    # refreshed to now() on UPDATEs issued through SQLAlchemy.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        comment="更新时间",
        nullable=False,
        onupdate=func.now(),
        server_default=func.now(),
    )
|
|
|
|
|
|
class BaseModel(Base, TimestampedMixin):
    """
    Abstract base class for all ORM models.

    Declares an integer primary key ``id`` and a nullable free-text
    ``remark`` column, inherits ``created_at`` / ``updated_at`` from
    ``TimestampedMixin``, and provides small async CRUD helpers that operate
    on a caller-supplied ``AsyncSession``. None of the helpers commit;
    transaction control is left to the caller.
    """

    # No table is created for this class itself; concrete subclasses get one.
    __abstract__ = True

    # NOTE(review): ``index=True`` on a primary key creates an extra index on
    # top of the one most databases already build for the PK constraint.
    # Left unchanged because removing it would alter the generated schema.
    id: Mapped[int] = mapped_column(
        Integer, primary_key=True, index=True, comment="ID"
    )
    remark: Mapped[Optional[str]] = mapped_column(
        Text, nullable=True, comment="备注"
    )

    async def to_dict(self, exclude: Optional[list] = None) -> Dict[str, Any]:
        """
        Convert the model's table columns to a plain dict.

        Kept ``async`` for interface compatibility; the method itself awaits
        nothing.

        NOTE(review): values are read with ``getattr``. If a column is expired
        or deferred on an instance bound to an ``AsyncSession``, that read
        would attempt an implicit lazy load. Confirm callers only use this on
        loaded instances.

        :param exclude: column names to omit. A single string is treated as
            one column name (not as a collection of characters/substrings).
        :return: mapping of column name to value, with ``datetime`` values
            rendered via ``isoformat()``.
        """
        if exclude is None:
            excluded = set()
        elif isinstance(exclude, str):
            # A bare name would otherwise be matched as a substring.
            excluded = {exclude}
        else:
            excluded = set(exclude)

        result: Dict[str, Any] = {}
        for column in self.__table__.columns:
            if column.name in excluded:
                continue

            value = getattr(self, column.name)

            # Render datetimes as ISO-8601 strings so the dict is JSON-friendly.
            if isinstance(value, datetime):
                value = value.isoformat()

            result[column.name] = value

        return result

    @classmethod
    async def get_by_id(
        cls, session: AsyncSession, id: int
    ) -> Optional["BaseModel"]:
        """
        Fetch a single row by primary key.

        :param session: the async session to query with.
        :param id: primary-key value to look up.
        :return: the matching instance, or ``None`` if no row matches.
        """
        result = await session.execute(select(cls).where(cls.id == id))
        return result.scalar_one_or_none()

    @classmethod
    async def get_all(
        cls, session: AsyncSession, limit: int = 100, offset: int = 0
    ) -> list["BaseModel"]:
        """
        Fetch one page of rows, newest ``id`` first.

        :param session: the async session to query with.
        :param limit: maximum number of rows to return.
        :param offset: number of rows to skip before the page starts.
        :return: a list of instances ordered by ``id`` descending.
        """
        stmt = (
            select(cls)
            .order_by(cls.id.desc())
            .limit(limit)
            .offset(offset)
        )
        result = await session.execute(stmt)
        # ScalarResult.all() is typed as a Sequence; return a real list so the
        # value matches this method's annotation.
        return list(result.scalars().all())

    async def save(self, session: AsyncSession) -> "BaseModel":
        """
        Add this instance to the session, flush it, and reload its state.

        The flush emits the INSERT/UPDATE without committing. The refresh
        loads server-generated values such as ``id`` and the timestamps.

        :param session: the async session to persist through.
        :return: ``self``, for chaining.
        """
        session.add(self)
        await session.flush()
        await session.refresh(self)
        return self

    async def delete(self, session: AsyncSession) -> None:
        """
        Mark this instance for deletion in the given session.

        No flush or commit is performed here, so the DELETE statement is
        emitted on the session's next flush/commit.

        :param session: the async session the instance belongs to.
        """
        await session.delete(self)

    async def update(self, session: AsyncSession, **kwargs) -> "BaseModel":
        """
        Assign the given attributes, flush, and reload this instance.

        Keys that are not attributes of the model class are silently ignored.

        :param session: the async session to persist through.
        :param kwargs: attribute name/value pairs to assign.
        :return: ``self``, for chaining.
        """
        model_cls = type(self)
        for key, value in kwargs.items():
            # Check the class rather than the instance. Instance-level hasattr()
            # reads the attribute, which for an expired column would trigger an
            # implicit lazy load that fails under asyncio.
            if hasattr(model_cls, key):
                setattr(self, key, value)

        # Mirror save(): ensure the instance belongs to this session (a no-op
        # if it already does) so a detached instance is flushed and refresh()
        # does not fail.
        session.add(self)
        await session.flush()
        await session.refresh(self)
        return self
|