package repository

import (
	"errors"

	"github.com/topfans/backend/pkg/database"
	appErrors "github.com/topfans/backend/pkg/errors"
	"github.com/topfans/backend/pkg/models"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)

// UserRepository is the data-access interface for users.
type UserRepository interface {
	// Create inserts a new user.
	Create(user *models.User) error
	// GetByID looks up an active user by ID.
	GetByID(id int64) (*models.User, error)
	// GetByMobile looks up an active user by mobile number.
	GetByMobile(mobile string) (*models.User, error)
	// ExistsByMobile reports whether an active user with the mobile number exists.
	ExistsByMobile(mobile string) (bool, error)
	// Update updates an existing user.
	Update(user *models.User) error
	// UpdateToken stores the user's access token and its expiry.
	UpdateToken(userID int64, token string, expiresAt int64) error
	// ClearToken clears the user's access token (logout).
	ClearToken(userID int64) error
	// VerifyPassword checks a plaintext password against the user's bcrypt hash.
	VerifyPassword(user *models.User, password string) bool
	// UpdateAvatar updates the user's avatar URL.
	UpdateAvatar(userID int64, avatarURL string) error
	// GetAccountStatus returns the user's account status record, if any.
	GetAccountStatus(userID int64) (*models.UserAccountStatus, error)
}

// userRepository is the GORM-backed implementation of UserRepository.
type userRepository struct {
	db *gorm.DB
}

// NewUserRepository creates a UserRepository that uses the shared database
// handle returned by database.GetDB.
func NewUserRepository() UserRepository {
	return &userRepository{
		db: database.GetDB(),
	}
}

// Create persists a new user via GORM's Create.
// It returns an error if user is nil or the insert fails.
func (r *userRepository) Create(user *models.User) error {
	if user == nil {
		return errors.New("user cannot be nil")
	}
	return r.db.Create(user).Error
}

// GetByID returns the active user (is_active = true) with the given ID.
// It returns appErrors.ErrUserNotFound when no such active user exists, and
// an error for a non-positive id.
func (r *userRepository) GetByID(id int64) (*models.User, error) {
	if id <= 0 {
		return nil, errors.New("invalid user id")
	}

	var user models.User
	if err := r.db.Where("id = ? AND is_active = ?", id, true).First(&user).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, appErrors.ErrUserNotFound
		}
		return nil, err
	}
	return &user, nil
}

// GetByMobile returns the active user (is_active = true) with the given
// mobile number. It returns appErrors.ErrUserNotFound when no such active
// user exists, and an error for an empty mobile.
func (r *userRepository) GetByMobile(mobile string) (*models.User, error) {
	if mobile == "" {
		return nil, errors.New("mobile cannot be empty")
	}

	var user models.User
	if err := r.db.Where("mobile = ? AND is_active = ?", mobile, true).First(&user).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, appErrors.ErrUserNotFound
		}
		return nil, err
	}
	return &user, nil
}

// ExistsByMobile reports whether at least one active user (is_active = true)
// has the given mobile number. Inactive users are not counted.
func (r *userRepository) ExistsByMobile(mobile string) (bool, error) {
	if mobile == "" {
		return false, errors.New("mobile cannot be empty")
	}

	var count int64
	if err := r.db.Model(&models.User{}).Where("mobile = ? AND is_active = ?", mobile, true).Count(&count).Error; err != nil {
		return false, err
	}
	return count > 0, nil
}

// Update writes user's fields to the row identified by user.ID.
//
// NOTE: GORM's Updates with a struct argument only writes non-zero fields,
// so this method cannot set a column to false, 0 or "". Use a targeted
// column update (as UpdateAvatar does) when a zero value must be stored.
func (r *userRepository) Update(user *models.User) error {
	if user == nil {
		return errors.New("user cannot be nil")
	}
	if user.ID == 0 {
		return errors.New("user id cannot be zero")
	}
	return r.db.Model(user).Updates(user).Error
}

// UpdateToken stores token and its expiry timestamp expiresAt for the user.
//
// UpdateColumns is used instead of Updates so that the BeforeUpdate hook is
// not triggered: updating the token must not change updated_at, because
// updated_at is used to validate token validity.
//
// It returns appErrors.ErrUserNotFound if no user with userID exists.
func (r *userRepository) UpdateToken(userID int64, token string, expiresAt int64) error {
	if userID <= 0 {
		return errors.New("invalid user id")
	}

	result := r.db.Model(&models.User{}).
		Where("id = ?", userID).
		UpdateColumns(map[string]interface{}{
			"access_token":     token,
			"token_expires_at": expiresAt,
		})
	return r.checkUserUpdated(result, userID)
}

// ClearToken clears the user's access token (sets it to NULL) and resets
// token_expires_at to 0, i.e. logs the user out.
//
// NOTE(review): unlike UpdateToken this uses Updates rather than
// UpdateColumns, so the BeforeUpdate hook runs and updated_at may change.
// Confirm this is intended for logout.
//
// It returns appErrors.ErrUserNotFound if no user with userID exists.
func (r *userRepository) ClearToken(userID int64) error {
	if userID <= 0 {
		return errors.New("invalid user id")
	}

	result := r.db.Model(&models.User{}).
		Where("id = ?", userID).
		Updates(map[string]interface{}{
			"access_token":     nil,
			"token_expires_at": 0,
		})
	return r.checkUserUpdated(result, userID)
}

// VerifyPassword reports whether password matches user's bcrypt
// PasswordHash. A nil user or empty password never matches.
func (r *userRepository) VerifyPassword(user *models.User, password string) bool {
	if user == nil || password == "" {
		return false
	}
	err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
	return err == nil
}

// HashPassword returns the bcrypt hash of password at bcrypt.DefaultCost.
// It is a helper for creating users. An empty password is rejected; any
// error from bcrypt (e.g. input it cannot hash) is returned unchanged.
func HashPassword(password string) (string, error) {
	if password == "" {
		return "", errors.New("password cannot be empty")
	}

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return "", err
	}
	return string(hashedPassword), nil
}

// UpdateAvatar sets the user's avatar_url column.
// It returns appErrors.ErrUserNotFound if no user with userID exists.
func (r *userRepository) UpdateAvatar(userID int64, avatarURL string) error {
	if userID <= 0 {
		return errors.New("invalid user id")
	}
	if avatarURL == "" {
		return errors.New("avatar_url cannot be empty")
	}

	result := r.db.Model(&models.User{}).
		Where("id = ?", userID).
		Update("avatar_url", avatarURL)
	return r.checkUserUpdated(result, userID)
}

// GetAccountStatus returns the account status record for userID.
//
// A missing record is not an error: it means the account is in a normal
// state, and the method returns (nil, nil). Callers must nil-check the result.
func (r *userRepository) GetAccountStatus(userID int64) (*models.UserAccountStatus, error) {
	if userID <= 0 {
		return nil, errors.New("invalid user id")
	}

	var status models.UserAccountStatus
	if err := r.db.Where("user_id = ?", userID).First(&status).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil // no record means the account is normal
		}
		return nil, err
	}
	return &status, nil
}

// checkUserUpdated turns the result of a single-user UPDATE (filtered by
// id = userID) into an error.
//
// If the UPDATE failed, its error is returned. If it reported zero affected
// rows, a separate existence query decides between "not found" and "nothing
// changed". Some drivers (e.g. MySQL without clientFoundRows) count changed
// rows rather than matched rows, so re-writing identical values reports 0
// even though the user exists. That case must not become ErrUserNotFound.
func (r *userRepository) checkUserUpdated(result *gorm.DB, userID int64) error {
	if result.Error != nil {
		return result.Error
	}
	if result.RowsAffected > 0 {
		return nil
	}

	var count int64
	if err := r.db.Model(&models.User{}).Where("id = ?", userID).Count(&count).Error; err != nil {
		return err
	}
	if count == 0 {
		return appErrors.ErrUserNotFound
	}
	return nil
}