231 lines
5.2 KiB
Go
231 lines
5.2 KiB
Go
package repository
|
||
|
||
import (
|
||
"errors"
|
||
|
||
"github.com/topfans/backend/pkg/database"
|
||
appErrors "github.com/topfans/backend/pkg/errors"
|
||
"github.com/topfans/backend/pkg/models"
|
||
"golang.org/x/crypto/bcrypt"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// UserRepository 用户Repository接口
|
||
type UserRepository interface {
|
||
// Create 创建用户
|
||
Create(user *models.User) error
|
||
|
||
// GetByID 根据ID查询
|
||
GetByID(id int64) (*models.User, error)
|
||
|
||
// GetByMobile 根据手机号查询
|
||
GetByMobile(mobile string) (*models.User, error)
|
||
|
||
// ExistsByMobile 检查手机号是否存在
|
||
ExistsByMobile(mobile string) (bool, error)
|
||
|
||
// Update 更新用户
|
||
Update(user *models.User) error
|
||
|
||
// UpdateToken 更新Token
|
||
UpdateToken(userID int64, token string, expiresAt int64) error
|
||
|
||
// ClearToken 清除Token(登出)
|
||
ClearToken(userID int64) error
|
||
|
||
// VerifyPassword 验证密码(bcrypt比对)
|
||
VerifyPassword(user *models.User, password string) bool
|
||
|
||
// UpdateAvatar 更新用户头像
|
||
UpdateAvatar(userID int64, avatarURL string) error
|
||
}
|
||
|
||
// userRepository 用户Repository实现
|
||
type userRepository struct {
|
||
db *gorm.DB
|
||
}
|
||
|
||
// NewUserRepository 创建用户Repository实例
|
||
func NewUserRepository() UserRepository {
|
||
return &userRepository{
|
||
db: database.GetDB(),
|
||
}
|
||
}
|
||
|
||
// Create 创建用户
|
||
func (r *userRepository) Create(user *models.User) error {
|
||
if user == nil {
|
||
return errors.New("user cannot be nil")
|
||
}
|
||
|
||
if err := r.db.Create(user).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetByID 根据ID查询
|
||
func (r *userRepository) GetByID(id int64) (*models.User, error) {
|
||
if id <= 0 {
|
||
return nil, errors.New("invalid user id")
|
||
}
|
||
|
||
var user models.User
|
||
if err := r.db.Where("id = ? AND is_active = ?", id, true).First(&user).Error; err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, appErrors.ErrUserNotFound
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
return &user, nil
|
||
}
|
||
|
||
// GetByMobile 根据手机号查询
|
||
func (r *userRepository) GetByMobile(mobile string) (*models.User, error) {
|
||
if mobile == "" {
|
||
return nil, errors.New("mobile cannot be empty")
|
||
}
|
||
|
||
var user models.User
|
||
if err := r.db.Where("mobile = ? AND is_active = ?", mobile, true).First(&user).Error; err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, appErrors.ErrUserNotFound
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
return &user, nil
|
||
}
|
||
|
||
// ExistsByMobile 检查手机号是否存在
|
||
func (r *userRepository) ExistsByMobile(mobile string) (bool, error) {
|
||
if mobile == "" {
|
||
return false, errors.New("mobile cannot be empty")
|
||
}
|
||
|
||
var count int64
|
||
if err := r.db.Model(&models.User{}).Where("mobile = ? AND is_active = ?", mobile, true).Count(&count).Error; err != nil {
|
||
return false, err
|
||
}
|
||
|
||
return count > 0, nil
|
||
}
|
||
|
||
// Update 更新用户
|
||
func (r *userRepository) Update(user *models.User) error {
|
||
if user == nil {
|
||
return errors.New("user cannot be nil")
|
||
}
|
||
|
||
if user.ID == 0 {
|
||
return errors.New("user id cannot be zero")
|
||
}
|
||
|
||
if err := r.db.Model(user).Updates(user).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// UpdateToken 更新Token
|
||
// 注意:使用 UpdateColumns 而不是 Updates,避免触发 BeforeUpdate 钩子
|
||
// 因为更新 Token 不应该改变 updated_at(updated_at 用于验证 Token 有效性)
|
||
func (r *userRepository) UpdateToken(userID int64, token string, expiresAt int64) error {
|
||
if userID <= 0 {
|
||
return errors.New("invalid user id")
|
||
}
|
||
|
||
result := r.db.Model(&models.User{}).
|
||
Where("id = ?", userID).
|
||
UpdateColumns(map[string]interface{}{
|
||
"access_token": token,
|
||
"token_expires_at": expiresAt,
|
||
})
|
||
|
||
if result.Error != nil {
|
||
return result.Error
|
||
}
|
||
|
||
if result.RowsAffected == 0 {
|
||
return appErrors.ErrUserNotFound
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ClearToken 清除Token(登出)
|
||
func (r *userRepository) ClearToken(userID int64) error {
|
||
if userID <= 0 {
|
||
return errors.New("invalid user id")
|
||
}
|
||
|
||
result := r.db.Model(&models.User{}).
|
||
Where("id = ?", userID).
|
||
Updates(map[string]interface{}{
|
||
"access_token": nil,
|
||
"token_expires_at": 0,
|
||
})
|
||
|
||
if result.Error != nil {
|
||
return result.Error
|
||
}
|
||
|
||
if result.RowsAffected == 0 {
|
||
return appErrors.ErrUserNotFound
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// VerifyPassword 验证密码(bcrypt比对)
|
||
func (r *userRepository) VerifyPassword(user *models.User, password string) bool {
|
||
if user == nil || password == "" {
|
||
return false
|
||
}
|
||
|
||
err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
|
||
return err == nil
|
||
}
|
||
|
||
// HashPassword 加密密码(辅助函数,用于创建用户时)
|
||
func HashPassword(password string) (string, error) {
|
||
if password == "" {
|
||
return "", errors.New("password cannot be empty")
|
||
}
|
||
|
||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
return string(hashedPassword), nil
|
||
}
|
||
|
||
// UpdateAvatar 更新用户头像
|
||
func (r *userRepository) UpdateAvatar(userID int64, avatarURL string) error {
|
||
if userID <= 0 {
|
||
return errors.New("invalid user id")
|
||
}
|
||
|
||
if avatarURL == "" {
|
||
return errors.New("avatar_url cannot be empty")
|
||
}
|
||
|
||
result := r.db.Model(&models.User{}).
|
||
Where("id = ?", userID).
|
||
Update("avatar_url", avatarURL)
|
||
|
||
if result.Error != nil {
|
||
return result.Error
|
||
}
|
||
|
||
if result.RowsAffected == 0 {
|
||
return appErrors.ErrUserNotFound
|
||
}
|
||
|
||
return nil
|
||
}
|