topfans/backend/services/userService/repository/user_repository_test.go
2026-04-07 22:29:48 +08:00

345 lines
8.0 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"testing"
"github.com/topfans/backend/pkg/database"
"github.com/topfans/backend/pkg/models"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// setupTestDB initializes the database connection used by the repository
// tests and returns the handle obtained from database.GetDB.
//
// The connection settings are hard-coded for a local PostgreSQL instance, and
// DBName points at the main "top-fans" database rather than a dedicated test
// database, so the tests operate on that database's tables. If database.Init
// returns an error the calling test is skipped rather than failed.
func setupTestDB(t *testing.T) *gorm.DB {
	// Mark this function as a helper so the skip message is attributed to the
	// calling test instead of to this file position.
	t.Helper()

	config := database.Config{
		Host:     "localhost",
		Port:     5432,
		User:     "haihuizhu",
		Password: "admin",
		DBName:   "top-fans", // the main database is used for testing
		SSLMode:  "disable",
		TimeZone: "Asia/Shanghai",
	}

	// NOTE: a test database (or an existing database usable for tests) must be
	// available before running these tests; otherwise they are skipped.
	if err := database.Init(config); err != nil {
		t.Skipf("Skipping test: failed to connect to test database: %v", err)
	}
	return database.GetDB()
}
// cleanupTestDB deletes the rows these tests create: users whose mobile starts
// with "138", the fan_profiles rows that reference them, and stars whose
// identity_id starts with "test_".
//
// Cleanup is best-effort: a statement that fails is logged through t.Logf and
// the remaining statements still run, so a cleanup problem never fails the
// test but is no longer silently discarded.
//
// NOTE(review): setupTestDB connects to the main "top-fans" database, so the
// '138%' pattern would also match any non-test user whose mobile starts with
// 138 — confirm this is acceptable or switch to a dedicated test database.
func cleanupTestDB(t *testing.T, db *gorm.DB) {
	t.Helper()

	// Order matters: fan_profiles rows are removed before the users they
	// reference because of the foreign-key constraint.
	statements := []string{
		"DELETE FROM fan_profiles WHERE user_id IN (SELECT id FROM users WHERE mobile LIKE '138%')",
		"DELETE FROM users WHERE mobile LIKE '138%'",
		"DELETE FROM stars WHERE identity_id LIKE 'test_%'",
	}
	for _, stmt := range statements {
		if err := db.Exec(stmt).Error; err != nil {
			t.Logf("cleanup statement %q failed: %v", stmt, err)
		}
	}
}
// TestUserRepository_Create checks that Create succeeds for a new user, that
// the ID, CreatedAt and UpdatedAt fields of the passed struct are non-zero
// afterwards, and that the record can be read back by ID with the same mobile.
func TestUserRepository_Create(t *testing.T) {
	testDB := setupTestDB(t)
	defer cleanupTestDB(t, testDB)

	userRepo := NewUserRepository()

	pwHash, err := HashPassword("password123")
	if err != nil {
		t.Fatalf("HashPassword failed: %v", err)
	}

	newUser := &models.User{
		Mobile:       "13800000001",
		PasswordHash: pwHash,
		IsActive:     true,
	}
	if err = userRepo.Create(newUser); err != nil {
		t.Fatalf("Create failed: %v", err)
	}

	// The generated fields must be populated on the struct that was passed in.
	// t.Fatal stops the test, so only the first failing check is reported.
	switch {
	case newUser.ID == 0:
		t.Fatal("User ID should not be zero after creation")
	case newUser.CreatedAt == 0:
		t.Fatal("CreatedAt should be set")
	case newUser.UpdatedAt == 0:
		t.Fatal("UpdatedAt should be set")
	}

	// Read the record back to confirm the user was actually stored.
	stored, err := userRepo.GetByID(newUser.ID)
	if err != nil {
		t.Fatalf("GetByID failed: %v", err)
	}
	if stored.Mobile != newUser.Mobile {
		t.Errorf("Mobile mismatch: expected %s, got %s", newUser.Mobile, stored.Mobile)
	}
}
// TestUserRepository_GetByID checks that GetByID returns a previously created
// user, and that it returns an error both for an ID with no matching user and
// for the ID 0.
func TestUserRepository_GetByID(t *testing.T) {
	db := setupTestDB(t)
	defer cleanupTestDB(t, db)

	repo := NewUserRepository()

	hashedPassword, err := HashPassword("password123")
	if err != nil {
		t.Fatalf("HashPassword failed: %v", err)
	}
	user := &models.User{
		Mobile:       "13800000002",
		PasswordHash: hashedPassword,
		IsActive:     true,
	}
	// Stop here if the fixture cannot be inserted, so the failure is reported
	// as a Create problem rather than as a confusing lookup failure below.
	if err = repo.Create(user); err != nil {
		t.Fatalf("Create failed: %v", err)
	}

	// Lookup of an existing user.
	retrieved, err := repo.GetByID(user.ID)
	if err != nil {
		t.Fatalf("GetByID failed: %v", err)
	}
	if retrieved.ID != user.ID {
		t.Errorf("ID mismatch: expected %d, got %d", user.ID, retrieved.ID)
	}

	// Lookup of a user that does not exist.
	_, err = repo.GetByID(99999999)
	if err == nil {
		t.Fatal("Expected error for non-existent user")
	}

	// Lookup with an invalid ID.
	_, err = repo.GetByID(0)
	if err == nil {
		t.Fatal("Expected error for invalid user id")
	}
}
// TestUserRepository_GetByMobile checks that GetByMobile returns a previously
// created user, and that it returns an error both for a mobile number with no
// matching user and for the empty string.
func TestUserRepository_GetByMobile(t *testing.T) {
	db := setupTestDB(t)
	defer cleanupTestDB(t, db)

	repo := NewUserRepository()

	hashedPassword, err := HashPassword("password123")
	if err != nil {
		t.Fatalf("HashPassword failed: %v", err)
	}
	user := &models.User{
		Mobile:       "13800000003",
		PasswordHash: hashedPassword,
		IsActive:     true,
	}
	// Stop here if the fixture cannot be inserted, so the failure is reported
	// as a Create problem rather than as a confusing lookup failure below.
	if err = repo.Create(user); err != nil {
		t.Fatalf("Create failed: %v", err)
	}

	// Lookup of an existing mobile number.
	retrieved, err := repo.GetByMobile(user.Mobile)
	if err != nil {
		t.Fatalf("GetByMobile failed: %v", err)
	}
	if retrieved.Mobile != user.Mobile {
		t.Errorf("Mobile mismatch: expected %s, got %s", user.Mobile, retrieved.Mobile)
	}

	// Lookup of a mobile number that does not exist.
	_, err = repo.GetByMobile("13999999999")
	if err == nil {
		t.Fatal("Expected error for non-existent mobile")
	}

	// Lookup with an empty mobile number.
	_, err = repo.GetByMobile("")
	if err == nil {
		t.Fatal("Expected error for empty mobile")
	}
}
// TestUserRepository_Update checks that a change to AvatarURL saved through
// Update is visible when the user is read back with GetByID.
func TestUserRepository_Update(t *testing.T) {
	db := setupTestDB(t)
	defer cleanupTestDB(t, db)

	repo := NewUserRepository()

	hashedPassword, err := HashPassword("password123")
	if err != nil {
		t.Fatalf("HashPassword failed: %v", err)
	}
	user := &models.User{
		Mobile:       "13800000004",
		PasswordHash: hashedPassword,
		IsActive:     true,
	}
	// Stop here if the fixture cannot be inserted, so the failure is reported
	// as a Create problem rather than as an Update or lookup failure below.
	if err = repo.Create(user); err != nil {
		t.Fatalf("Create failed: %v", err)
	}

	// Modify the user and persist the change.
	avatarURL := "https://example.com/avatar.jpg"
	user.AvatarURL = &avatarURL
	if err = repo.Update(user); err != nil {
		t.Fatalf("Update failed: %v", err)
	}

	// Read the user back and verify the stored value.
	retrieved, err := repo.GetByID(user.ID)
	if err != nil {
		t.Fatalf("GetByID failed: %v", err)
	}
	// Report the pointed-to value (or nil) rather than the pointer itself, so
	// the failure message shows the actual stored URL instead of an address.
	switch {
	case retrieved.AvatarURL == nil:
		t.Errorf("AvatarURL mismatch: expected %s, got %v", avatarURL, nil)
	case *retrieved.AvatarURL != avatarURL:
		t.Errorf("AvatarURL mismatch: expected %s, got %v", avatarURL, *retrieved.AvatarURL)
	}
}
// TestUserRepository_UpdateToken checks that the token and expiry passed to
// UpdateToken are returned in AccessToken and TokenExpiresAt when the user is
// read back with GetByID.
func TestUserRepository_UpdateToken(t *testing.T) {
	db := setupTestDB(t)
	defer cleanupTestDB(t, db)

	repo := NewUserRepository()

	hashedPassword, err := HashPassword("password123")
	if err != nil {
		t.Fatalf("HashPassword failed: %v", err)
	}
	user := &models.User{
		Mobile:       "13800000005",
		PasswordHash: hashedPassword,
		IsActive:     true,
	}
	// Stop here if the fixture cannot be inserted, so the failure is reported
	// as a Create problem rather than as an UpdateToken failure below.
	if err = repo.Create(user); err != nil {
		t.Fatalf("Create failed: %v", err)
	}

	token := "test.jwt.token"
	// Corresponds to 2024-01-08 00:00:00 UTC when read as Unix milliseconds.
	expiresAt := int64(1704672000000)
	if err = repo.UpdateToken(user.ID, token, expiresAt); err != nil {
		t.Fatalf("UpdateToken failed: %v", err)
	}

	// Read the user back and verify both stored fields.
	retrieved, err := repo.GetByID(user.ID)
	if err != nil {
		t.Fatalf("GetByID failed: %v", err)
	}
	// Report pointed-to values (or nil) rather than the pointers themselves,
	// so failure messages show the stored data instead of addresses.
	switch {
	case retrieved.AccessToken == nil:
		t.Errorf("AccessToken mismatch: expected %s, got %v", token, nil)
	case *retrieved.AccessToken != token:
		t.Errorf("AccessToken mismatch: expected %s, got %v", token, *retrieved.AccessToken)
	}
	switch {
	case retrieved.TokenExpiresAt == nil:
		t.Errorf("TokenExpiresAt mismatch: expected %d, got %v", expiresAt, nil)
	case *retrieved.TokenExpiresAt != expiresAt:
		t.Errorf("TokenExpiresAt mismatch: expected %d, got %v", expiresAt, *retrieved.TokenExpiresAt)
	}
}
// TestUserRepository_ClearToken checks that, after a token has been stored
// with UpdateToken, ClearToken leaves AccessToken nil and TokenExpiresAt
// either nil or 0 when the user is read back with GetByID.
func TestUserRepository_ClearToken(t *testing.T) {
	db := setupTestDB(t)
	defer cleanupTestDB(t, db)

	repo := NewUserRepository()

	hashedPassword, err := HashPassword("password123")
	if err != nil {
		t.Fatalf("HashPassword failed: %v", err)
	}
	user := &models.User{
		Mobile:       "13800000006",
		PasswordHash: hashedPassword,
		IsActive:     true,
	}
	// Stop here if the fixture cannot be inserted, so the failure is reported
	// as a Create problem rather than as a token failure below.
	if err = repo.Create(user); err != nil {
		t.Fatalf("Create failed: %v", err)
	}

	// Store a token first. This precondition must succeed, otherwise the
	// assertions below could pass without ClearToken having cleared anything.
	token := "test.jwt.token"
	expiresAt := int64(1704672000000)
	if err = repo.UpdateToken(user.ID, token, expiresAt); err != nil {
		t.Fatalf("UpdateToken failed: %v", err)
	}

	// Clear the token.
	if err = repo.ClearToken(user.ID); err != nil {
		t.Fatalf("ClearToken failed: %v", err)
	}

	// Read the user back and verify the token fields were cleared.
	retrieved, err := repo.GetByID(user.ID)
	if err != nil {
		t.Fatalf("GetByID failed: %v", err)
	}
	// Both pointers are non-nil inside their branches, so dereference them to
	// report the leftover value instead of a pointer address.
	if retrieved.AccessToken != nil {
		t.Errorf("AccessToken should be nil after ClearToken, got %v", *retrieved.AccessToken)
	}
	if retrieved.TokenExpiresAt != nil && *retrieved.TokenExpiresAt != 0 {
		t.Errorf("TokenExpiresAt should be 0 after ClearToken, got %v", *retrieved.TokenExpiresAt)
	}
}
// TestUserRepository_VerifyPassword checks that VerifyPassword returns true
// only for the password the hash was created from, and false for a wrong
// password, an empty password, and a nil user. It does not touch the database.
func TestUserRepository_VerifyPassword(t *testing.T) {
	repo := NewUserRepository()

	password := "password123"
	hashedPassword, err := HashPassword(password)
	if err != nil {
		t.Fatalf("HashPassword failed: %v", err)
	}
	user := &models.User{PasswordHash: hashedPassword}

	// Candidate passwords checked against the same user, in order: the
	// correct one, a wrong one, and the empty string.
	attempts := []struct {
		candidate string
		want      bool
		failure   string
	}{
		{password, true, "VerifyPassword should return true for correct password"},
		{"wrongpassword", false, "VerifyPassword should return false for wrong password"},
		{"", false, "VerifyPassword should return false for empty password"},
	}
	for _, attempt := range attempts {
		if repo.VerifyPassword(user, attempt.candidate) != attempt.want {
			t.Fatal(attempt.failure)
		}
	}

	// A nil user must never verify, even with the correct password.
	if repo.VerifyPassword(nil, password) {
		t.Fatal("VerifyPassword should return false for nil user")
	}
}
// TestHashPassword checks that HashPassword produces a non-empty hash that
// differs from the plain text, that hashing the same password twice gives two
// different hashes which both verify with bcrypt, and that an empty password
// is rejected with an error.
func TestHashPassword(t *testing.T) {
	plain := "password123"

	first, err := HashPassword(plain)
	if err != nil {
		t.Fatalf("HashPassword failed: %v", err)
	}
	switch {
	case first == "":
		t.Fatal("Hashed password should not be empty")
	case first == plain:
		t.Fatal("Hashed password should not equal plain password")
	}

	// Hashing the same input again must give a different result because of
	// the salt.
	second, err := HashPassword(plain)
	if err != nil {
		t.Fatalf("HashPassword failed: %v", err)
	}
	if first == second {
		t.Fatal("Different hashes should be generated for the same password")
	}

	// Both hashes must nevertheless verify against the original password.
	verifications := []struct {
		hash    string
		failure string
	}{
		{first, "First hash should verify correctly"},
		{second, "Second hash should verify correctly"},
	}
	for _, v := range verifications {
		if err := bcrypt.CompareHashAndPassword([]byte(v.hash), []byte(plain)); err != nil {
			t.Fatal(v.failure)
		}
	}

	// An empty password must be rejected.
	if _, err = HashPassword(""); err == nil {
		t.Fatal("HashPassword should return error for empty password")
	}
}