package repository import ( "testing" "github.com/topfans/backend/pkg/database" "github.com/topfans/backend/pkg/models" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" ) // setupTestDB 设置测试数据库 func setupTestDB(t *testing.T) *gorm.DB { config := database.Config{ Host: "localhost", Port: 5432, User: "haihuizhu", Password: "admin", DBName: "top-fans", // 使用主数据库进行测试 SSLMode: "disable", TimeZone: "Asia/Shanghai", } // 注意:测试时需要创建测试数据库或使用现有的测试数据库 if err := database.Init(config); err != nil { t.Skipf("Skipping test: failed to connect to test database: %v", err) } return database.GetDB() } // cleanupTestDB 清理测试数据 func cleanupTestDB(t *testing.T, db *gorm.DB) { // 清理测试数据(注意外键约束,先删除fan_profiles) db.Exec("DELETE FROM fan_profiles WHERE user_id IN (SELECT id FROM users WHERE mobile LIKE '138%')") db.Exec("DELETE FROM users WHERE mobile LIKE '138%'") db.Exec("DELETE FROM stars WHERE identity_id LIKE 'test_%'") } func TestUserRepository_Create(t *testing.T) { db := setupTestDB(t) defer cleanupTestDB(t, db) repo := NewUserRepository() hashedPassword, err := HashPassword("password123") if err != nil { t.Fatalf("HashPassword failed: %v", err) } user := &models.User{ Mobile: "13800000001", PasswordHash: hashedPassword, IsActive: true, } err = repo.Create(user) if err != nil { t.Fatalf("Create failed: %v", err) } if user.ID == 0 { t.Fatal("User ID should not be zero after creation") } if user.CreatedAt == 0 { t.Fatal("CreatedAt should be set") } if user.UpdatedAt == 0 { t.Fatal("UpdatedAt should be set") } // 验证用户已创建 retrieved, err := repo.GetByID(user.ID) if err != nil { t.Fatalf("GetByID failed: %v", err) } if retrieved.Mobile != user.Mobile { t.Errorf("Mobile mismatch: expected %s, got %s", user.Mobile, retrieved.Mobile) } } func TestUserRepository_GetByID(t *testing.T) { db := setupTestDB(t) defer cleanupTestDB(t, db) repo := NewUserRepository() hashedPassword, _ := HashPassword("password123") user := &models.User{ Mobile: "13800000002", PasswordHash: hashedPassword, IsActive: true, } repo.Create(user) // 测试正常查询 retrieved, err := repo.GetByID(user.ID) if err != nil { t.Fatalf("GetByID failed: %v", err) } if retrieved.ID != user.ID { t.Errorf("ID mismatch: expected %d, got %d", user.ID, retrieved.ID) } // 测试不存在的用户 _, err = repo.GetByID(99999999) if err == nil { t.Fatal("Expected error for non-existent user") } // 测试无效ID _, err = repo.GetByID(0) if err == nil { t.Fatal("Expected error for invalid user id") } } func TestUserRepository_GetByMobile(t *testing.T) { db := setupTestDB(t) defer cleanupTestDB(t, db) repo := NewUserRepository() hashedPassword, _ := HashPassword("password123") user := &models.User{ Mobile: "13800000003", PasswordHash: hashedPassword, IsActive: true, } repo.Create(user) // 测试正常查询 retrieved, err := repo.GetByMobile(user.Mobile) if err != nil { t.Fatalf("GetByMobile failed: %v", err) } if retrieved.Mobile != user.Mobile { t.Errorf("Mobile mismatch: expected %s, got %s", user.Mobile, retrieved.Mobile) } // 测试不存在的手机号 _, err = repo.GetByMobile("13999999999") if err == nil { t.Fatal("Expected error for non-existent mobile") } // 测试空手机号 _, err = repo.GetByMobile("") if err == nil { t.Fatal("Expected error for empty mobile") } } func TestUserRepository_Update(t *testing.T) { db := setupTestDB(t) defer cleanupTestDB(t, db) repo := NewUserRepository() hashedPassword, _ := HashPassword("password123") user := &models.User{ Mobile: "13800000004", PasswordHash: hashedPassword, IsActive: true, } repo.Create(user) // 更新用户信息 avatarURL := "https://example.com/avatar.jpg" user.AvatarURL = &avatarURL err := repo.Update(user) if err != nil { t.Fatalf("Update failed: %v", err) } // 验证更新 retrieved, err := repo.GetByID(user.ID) if err != nil { t.Fatalf("GetByID failed: %v", err) } if retrieved.AvatarURL == nil || *retrieved.AvatarURL != avatarURL { t.Errorf("AvatarURL mismatch: expected %s, got %v", avatarURL, retrieved.AvatarURL) } } func TestUserRepository_UpdateToken(t *testing.T) { db := setupTestDB(t) defer cleanupTestDB(t, db) repo := NewUserRepository() hashedPassword, _ := HashPassword("password123") user := &models.User{ Mobile: "13800000005", PasswordHash: hashedPassword, IsActive: true, } repo.Create(user) token := "test.jwt.token" expiresAt := int64(1704672000000) // 2024-01-08 00:00:00 err := repo.UpdateToken(user.ID, token, expiresAt) if err != nil { t.Fatalf("UpdateToken failed: %v", err) } // 验证更新 retrieved, err := repo.GetByID(user.ID) if err != nil { t.Fatalf("GetByID failed: %v", err) } if retrieved.AccessToken == nil || *retrieved.AccessToken != token { t.Errorf("AccessToken mismatch: expected %s, got %v", token, retrieved.AccessToken) } if retrieved.TokenExpiresAt == nil || *retrieved.TokenExpiresAt != expiresAt { t.Errorf("TokenExpiresAt mismatch: expected %d, got %v", expiresAt, retrieved.TokenExpiresAt) } } func TestUserRepository_ClearToken(t *testing.T) { db := setupTestDB(t) defer cleanupTestDB(t, db) repo := NewUserRepository() hashedPassword, _ := HashPassword("password123") user := &models.User{ Mobile: "13800000006", PasswordHash: hashedPassword, IsActive: true, } repo.Create(user) // 先设置Token token := "test.jwt.token" expiresAt := int64(1704672000000) repo.UpdateToken(user.ID, token, expiresAt) // 清除Token err := repo.ClearToken(user.ID) if err != nil { t.Fatalf("ClearToken failed: %v", err) } // 验证Token已清除 retrieved, err := repo.GetByID(user.ID) if err != nil { t.Fatalf("GetByID failed: %v", err) } if retrieved.AccessToken != nil { t.Errorf("AccessToken should be nil after ClearToken, got %v", retrieved.AccessToken) } if retrieved.TokenExpiresAt != nil && *retrieved.TokenExpiresAt != 0 { t.Errorf("TokenExpiresAt should be 0 after ClearToken, got %v", retrieved.TokenExpiresAt) } } func TestUserRepository_VerifyPassword(t *testing.T) { repo := NewUserRepository() password := "password123" hashedPassword, err := HashPassword(password) if err != nil { t.Fatalf("HashPassword failed: %v", err) } user := &models.User{ PasswordHash: hashedPassword, } // 测试正确密码 if !repo.VerifyPassword(user, password) { t.Fatal("VerifyPassword should return true for correct password") } // 测试错误密码 if repo.VerifyPassword(user, "wrongpassword") { t.Fatal("VerifyPassword should return false for wrong password") } // 测试空密码 if repo.VerifyPassword(user, "") { t.Fatal("VerifyPassword should return false for empty password") } // 测试nil用户 if repo.VerifyPassword(nil, password) { t.Fatal("VerifyPassword should return false for nil user") } } func TestHashPassword(t *testing.T) { password := "password123" hashed1, err := HashPassword(password) if err != nil { t.Fatalf("HashPassword failed: %v", err) } if hashed1 == "" { t.Fatal("Hashed password should not be empty") } if hashed1 == password { t.Fatal("Hashed password should not equal plain password") } // 测试每次加密结果不同(因为salt) hashed2, err := HashPassword(password) if err != nil { t.Fatalf("HashPassword failed: %v", err) } if hashed1 == hashed2 { t.Fatal("Different hashes should be generated for the same password") } // 但都能验证通过 if err := bcrypt.CompareHashAndPassword([]byte(hashed1), []byte(password)); err != nil { t.Fatal("First hash should verify correctly") } if err := bcrypt.CompareHashAndPassword([]byte(hashed2), []byte(password)); err != nil { t.Fatal("Second hash should verify correctly") } // 测试空密码 _, err = HashPassword("") if err == nil { t.Fatal("HashPassword should return error for empty password") } }