topfans/backend/services/userService/repository/user_repository.go
2026-04-07 22:29:48 +08:00

231 lines
5.2 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"errors"
"github.com/topfans/backend/pkg/database"
appErrors "github.com/topfans/backend/pkg/errors"
"github.com/topfans/backend/pkg/models"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// UserRepository is the persistence contract for user records.
//
// The implementation in this file is backed by GORM; its lookups by ID or
// mobile number only match rows whose is_active column is true.
type UserRepository interface {
	// Create inserts a new user record.
	Create(user *models.User) error

	// GetByID looks a user up by primary key.
	GetByID(id int64) (*models.User, error)

	// GetByMobile looks a user up by mobile number.
	GetByMobile(mobile string) (*models.User, error)

	// ExistsByMobile reports whether a user with the given mobile number exists.
	ExistsByMobile(mobile string) (bool, error)

	// Update persists changes to an existing user.
	Update(user *models.User) error

	// UpdateToken stores a new access token and its expiry for the user.
	UpdateToken(userID int64, token string, expiresAt int64) error

	// ClearToken removes the user's stored token (logout).
	ClearToken(userID int64) error

	// VerifyPassword checks a plaintext password against the user's bcrypt hash.
	VerifyPassword(user *models.User, password string) bool

	// UpdateAvatar sets the user's avatar URL.
	UpdateAvatar(userID int64, avatarURL string) error
}
// userRepository is the GORM-backed implementation of UserRepository.
type userRepository struct {
	db *gorm.DB // handle through which every query of this repository is issued
}
// NewUserRepository returns a UserRepository whose queries run on the
// database handle obtained from database.GetDB.
func NewUserRepository() UserRepository {
	repo := new(userRepository)
	repo.db = database.GetDB()
	return repo
}
// Create inserts user into the database.
//
// It returns an error when user is nil; otherwise it returns whatever error
// the insert reports, which is nil on success.
func (r *userRepository) Create(user *models.User) error {
	if user == nil {
		return errors.New("user cannot be nil")
	}
	return r.db.Create(user).Error
}
// GetByID fetches the active user whose primary key equals id.
//
// A non-positive id is rejected up front. When no row with that id and
// is_active = true exists, appErrors.ErrUserNotFound is returned; any other
// database error is passed through unchanged.
func (r *userRepository) GetByID(id int64) (*models.User, error) {
	if id <= 0 {
		return nil, errors.New("invalid user id")
	}

	found := new(models.User)
	err := r.db.Where("id = ? AND is_active = ?", id, true).First(found).Error
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// Translate the ORM's "no rows" signal into the application error.
		return nil, appErrors.ErrUserNotFound
	default:
		return nil, err
	}
}
// GetByMobile fetches the active user registered under the given mobile number.
//
// An empty mobile is rejected up front. When no row with that mobile and
// is_active = true exists, appErrors.ErrUserNotFound is returned; any other
// database error is passed through unchanged.
func (r *userRepository) GetByMobile(mobile string) (*models.User, error) {
	if len(mobile) == 0 {
		return nil, errors.New("mobile cannot be empty")
	}

	var record models.User
	query := r.db.Where("mobile = ? AND is_active = ?", mobile, true)
	if lookupErr := query.First(&record).Error; lookupErr != nil {
		// Translate the ORM's "no rows" signal into the application error.
		if errors.Is(lookupErr, gorm.ErrRecordNotFound) {
			lookupErr = appErrors.ErrUserNotFound
		}
		return nil, lookupErr
	}
	return &record, nil
}
// ExistsByMobile reports whether at least one active user is registered under
// the given mobile number.
//
// An empty mobile is rejected with an error. Only rows with is_active = true
// are counted.
func (r *userRepository) ExistsByMobile(mobile string) (bool, error) {
	if mobile == "" {
		return false, errors.New("mobile cannot be empty")
	}

	var matches int64
	err := r.db.
		Model(&models.User{}).
		Where("mobile = ? AND is_active = ?", mobile, true).
		Count(&matches).
		Error
	if err != nil {
		return false, err
	}
	return matches > 0, nil
}
// Update writes the fields of user back to its database row.
//
// It returns an error when user is nil or when user.ID is zero; otherwise it
// returns whatever error the update reports, which is nil on success. The
// number of affected rows is not inspected, so a missing row is not reported
// as an error here.
//
// NOTE(review): GORM documents that Updates with a struct argument skips
// zero-value fields — confirm no caller relies on this method to clear a field.
func (r *userRepository) Update(user *models.User) error {
	switch {
	case user == nil:
		return errors.New("user cannot be nil")
	case user.ID == 0:
		return errors.New("user id cannot be zero")
	}
	return r.db.Model(user).Updates(user).Error
}
// UpdateToken stores a new access token and its expiry for the user with the
// given ID.
//
// UpdateColumns is used instead of Updates on purpose: it avoids triggering
// the BeforeUpdate hook, because refreshing a token must not change
// updated_at (updated_at is used to verify token validity).
//
// A non-positive userID is rejected up front. appErrors.ErrUserNotFound is
// returned when the statement affects no rows.
func (r *userRepository) UpdateToken(userID int64, token string, expiresAt int64) error {
	if userID <= 0 {
		return errors.New("invalid user id")
	}

	fields := map[string]interface{}{
		"access_token":     token,
		"token_expires_at": expiresAt,
	}
	res := r.db.Model(&models.User{}).Where("id = ?", userID).UpdateColumns(fields)
	if err := res.Error; err != nil {
		return err
	}
	if res.RowsAffected == 0 {
		return appErrors.ErrUserNotFound
	}
	return nil
}
// ClearToken removes the stored token of the user with the given ID (logout).
//
// access_token is set to NULL and token_expires_at to 0. A non-positive
// userID is rejected up front, and appErrors.ErrUserNotFound is returned when
// the statement affects no rows.
//
// NOTE(review): this goes through Updates, whereas UpdateToken deliberately
// uses UpdateColumns to leave updated_at alone — confirm that touching
// updated_at on logout is intended.
func (r *userRepository) ClearToken(userID int64) error {
	if userID <= 0 {
		return errors.New("invalid user id")
	}

	cleared := map[string]interface{}{
		"access_token":     nil,
		"token_expires_at": 0,
	}
	outcome := r.db.Model(&models.User{}).Where("id = ?", userID).Updates(cleared)
	switch {
	case outcome.Error != nil:
		return outcome.Error
	case outcome.RowsAffected == 0:
		return appErrors.ErrUserNotFound
	default:
		return nil
	}
}
// VerifyPassword reports whether password matches the bcrypt hash stored in
// user.PasswordHash.
//
// It returns false when user is nil, when password is empty, or when the
// bcrypt comparison fails for any reason.
func (r *userRepository) VerifyPassword(user *models.User, password string) bool {
	if user == nil {
		return false
	}
	if password == "" {
		return false
	}
	hash := []byte(user.PasswordHash)
	return bcrypt.CompareHashAndPassword(hash, []byte(password)) == nil
}
// HashPassword returns the bcrypt hash of password, computed at
// bcrypt.DefaultCost. It is a helper intended for use when creating a user.
//
// An empty password is rejected with an error; an error from bcrypt is
// returned unchanged together with an empty string.
func HashPassword(password string) (string, error) {
	if len(password) == 0 {
		return "", errors.New("password cannot be empty")
	}

	digest, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return "", err
	}
	return string(digest), nil
}
// UpdateAvatar sets the avatar_url column of the user with the given ID.
//
// A non-positive userID or an empty avatarURL is rejected up front.
// appErrors.ErrUserNotFound is returned when the statement affects no rows.
func (r *userRepository) UpdateAvatar(userID int64, avatarURL string) error {
	switch {
	case userID <= 0:
		return errors.New("invalid user id")
	case avatarURL == "":
		return errors.New("avatar_url cannot be empty")
	}

	tx := r.db.Model(&models.User{}).Where("id = ?", userID).Update("avatar_url", avatarURL)
	if tx.Error != nil {
		return tx.Error
	}
	if tx.RowsAffected == 0 {
		return appErrors.ErrUserNotFound
	}
	return nil
}