345 lines
8.0 KiB
Go
345 lines
8.0 KiB
Go
package repository
|
||
|
||
import (
|
||
"testing"
|
||
|
||
"github.com/topfans/backend/pkg/database"
|
||
"github.com/topfans/backend/pkg/models"
|
||
"golang.org/x/crypto/bcrypt"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// setupTestDB 设置测试数据库
|
||
func setupTestDB(t *testing.T) *gorm.DB {
|
||
config := database.Config{
|
||
Host: "localhost",
|
||
Port: 5432,
|
||
User: "haihuizhu",
|
||
Password: "admin",
|
||
DBName: "top-fans", // 使用主数据库进行测试
|
||
SSLMode: "disable",
|
||
TimeZone: "Asia/Shanghai",
|
||
}
|
||
|
||
// 注意:测试时需要创建测试数据库或使用现有的测试数据库
|
||
if err := database.Init(config); err != nil {
|
||
t.Skipf("Skipping test: failed to connect to test database: %v", err)
|
||
}
|
||
|
||
return database.GetDB()
|
||
}
|
||
|
||
// cleanupTestDB 清理测试数据
|
||
func cleanupTestDB(t *testing.T, db *gorm.DB) {
|
||
// 清理测试数据(注意外键约束,先删除fan_profiles)
|
||
db.Exec("DELETE FROM fan_profiles WHERE user_id IN (SELECT id FROM users WHERE mobile LIKE '138%')")
|
||
db.Exec("DELETE FROM users WHERE mobile LIKE '138%'")
|
||
db.Exec("DELETE FROM stars WHERE identity_id LIKE 'test_%'")
|
||
}
|
||
|
||
func TestUserRepository_Create(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
defer cleanupTestDB(t, db)
|
||
|
||
repo := NewUserRepository()
|
||
|
||
hashedPassword, err := HashPassword("password123")
|
||
if err != nil {
|
||
t.Fatalf("HashPassword failed: %v", err)
|
||
}
|
||
|
||
user := &models.User{
|
||
Mobile: "13800000001",
|
||
PasswordHash: hashedPassword,
|
||
IsActive: true,
|
||
}
|
||
|
||
err = repo.Create(user)
|
||
if err != nil {
|
||
t.Fatalf("Create failed: %v", err)
|
||
}
|
||
|
||
if user.ID == 0 {
|
||
t.Fatal("User ID should not be zero after creation")
|
||
}
|
||
|
||
if user.CreatedAt == 0 {
|
||
t.Fatal("CreatedAt should be set")
|
||
}
|
||
|
||
if user.UpdatedAt == 0 {
|
||
t.Fatal("UpdatedAt should be set")
|
||
}
|
||
|
||
// 验证用户已创建
|
||
retrieved, err := repo.GetByID(user.ID)
|
||
if err != nil {
|
||
t.Fatalf("GetByID failed: %v", err)
|
||
}
|
||
|
||
if retrieved.Mobile != user.Mobile {
|
||
t.Errorf("Mobile mismatch: expected %s, got %s", user.Mobile, retrieved.Mobile)
|
||
}
|
||
}
|
||
|
||
func TestUserRepository_GetByID(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
defer cleanupTestDB(t, db)
|
||
|
||
repo := NewUserRepository()
|
||
|
||
hashedPassword, _ := HashPassword("password123")
|
||
user := &models.User{
|
||
Mobile: "13800000002",
|
||
PasswordHash: hashedPassword,
|
||
IsActive: true,
|
||
}
|
||
repo.Create(user)
|
||
|
||
// 测试正常查询
|
||
retrieved, err := repo.GetByID(user.ID)
|
||
if err != nil {
|
||
t.Fatalf("GetByID failed: %v", err)
|
||
}
|
||
|
||
if retrieved.ID != user.ID {
|
||
t.Errorf("ID mismatch: expected %d, got %d", user.ID, retrieved.ID)
|
||
}
|
||
|
||
// 测试不存在的用户
|
||
_, err = repo.GetByID(99999999)
|
||
if err == nil {
|
||
t.Fatal("Expected error for non-existent user")
|
||
}
|
||
|
||
// 测试无效ID
|
||
_, err = repo.GetByID(0)
|
||
if err == nil {
|
||
t.Fatal("Expected error for invalid user id")
|
||
}
|
||
}
|
||
|
||
func TestUserRepository_GetByMobile(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
defer cleanupTestDB(t, db)
|
||
|
||
repo := NewUserRepository()
|
||
|
||
hashedPassword, _ := HashPassword("password123")
|
||
user := &models.User{
|
||
Mobile: "13800000003",
|
||
PasswordHash: hashedPassword,
|
||
IsActive: true,
|
||
}
|
||
repo.Create(user)
|
||
|
||
// 测试正常查询
|
||
retrieved, err := repo.GetByMobile(user.Mobile)
|
||
if err != nil {
|
||
t.Fatalf("GetByMobile failed: %v", err)
|
||
}
|
||
|
||
if retrieved.Mobile != user.Mobile {
|
||
t.Errorf("Mobile mismatch: expected %s, got %s", user.Mobile, retrieved.Mobile)
|
||
}
|
||
|
||
// 测试不存在的手机号
|
||
_, err = repo.GetByMobile("13999999999")
|
||
if err == nil {
|
||
t.Fatal("Expected error for non-existent mobile")
|
||
}
|
||
|
||
// 测试空手机号
|
||
_, err = repo.GetByMobile("")
|
||
if err == nil {
|
||
t.Fatal("Expected error for empty mobile")
|
||
}
|
||
}
|
||
|
||
func TestUserRepository_Update(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
defer cleanupTestDB(t, db)
|
||
|
||
repo := NewUserRepository()
|
||
|
||
hashedPassword, _ := HashPassword("password123")
|
||
user := &models.User{
|
||
Mobile: "13800000004",
|
||
PasswordHash: hashedPassword,
|
||
IsActive: true,
|
||
}
|
||
repo.Create(user)
|
||
|
||
// 更新用户信息
|
||
avatarURL := "https://example.com/avatar.jpg"
|
||
user.AvatarURL = &avatarURL
|
||
|
||
err := repo.Update(user)
|
||
if err != nil {
|
||
t.Fatalf("Update failed: %v", err)
|
||
}
|
||
|
||
// 验证更新
|
||
retrieved, err := repo.GetByID(user.ID)
|
||
if err != nil {
|
||
t.Fatalf("GetByID failed: %v", err)
|
||
}
|
||
|
||
if retrieved.AvatarURL == nil || *retrieved.AvatarURL != avatarURL {
|
||
t.Errorf("AvatarURL mismatch: expected %s, got %v", avatarURL, retrieved.AvatarURL)
|
||
}
|
||
}
|
||
|
||
func TestUserRepository_UpdateToken(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
defer cleanupTestDB(t, db)
|
||
|
||
repo := NewUserRepository()
|
||
|
||
hashedPassword, _ := HashPassword("password123")
|
||
user := &models.User{
|
||
Mobile: "13800000005",
|
||
PasswordHash: hashedPassword,
|
||
IsActive: true,
|
||
}
|
||
repo.Create(user)
|
||
|
||
token := "test.jwt.token"
|
||
expiresAt := int64(1704672000000) // 2024-01-08 00:00:00
|
||
|
||
err := repo.UpdateToken(user.ID, token, expiresAt)
|
||
if err != nil {
|
||
t.Fatalf("UpdateToken failed: %v", err)
|
||
}
|
||
|
||
// 验证更新
|
||
retrieved, err := repo.GetByID(user.ID)
|
||
if err != nil {
|
||
t.Fatalf("GetByID failed: %v", err)
|
||
}
|
||
|
||
if retrieved.AccessToken == nil || *retrieved.AccessToken != token {
|
||
t.Errorf("AccessToken mismatch: expected %s, got %v", token, retrieved.AccessToken)
|
||
}
|
||
|
||
if retrieved.TokenExpiresAt == nil || *retrieved.TokenExpiresAt != expiresAt {
|
||
t.Errorf("TokenExpiresAt mismatch: expected %d, got %v", expiresAt, retrieved.TokenExpiresAt)
|
||
}
|
||
}
|
||
|
||
func TestUserRepository_ClearToken(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
defer cleanupTestDB(t, db)
|
||
|
||
repo := NewUserRepository()
|
||
|
||
hashedPassword, _ := HashPassword("password123")
|
||
user := &models.User{
|
||
Mobile: "13800000006",
|
||
PasswordHash: hashedPassword,
|
||
IsActive: true,
|
||
}
|
||
repo.Create(user)
|
||
|
||
// 先设置Token
|
||
token := "test.jwt.token"
|
||
expiresAt := int64(1704672000000)
|
||
repo.UpdateToken(user.ID, token, expiresAt)
|
||
|
||
// 清除Token
|
||
err := repo.ClearToken(user.ID)
|
||
if err != nil {
|
||
t.Fatalf("ClearToken failed: %v", err)
|
||
}
|
||
|
||
// 验证Token已清除
|
||
retrieved, err := repo.GetByID(user.ID)
|
||
if err != nil {
|
||
t.Fatalf("GetByID failed: %v", err)
|
||
}
|
||
|
||
if retrieved.AccessToken != nil {
|
||
t.Errorf("AccessToken should be nil after ClearToken, got %v", retrieved.AccessToken)
|
||
}
|
||
|
||
if retrieved.TokenExpiresAt != nil && *retrieved.TokenExpiresAt != 0 {
|
||
t.Errorf("TokenExpiresAt should be 0 after ClearToken, got %v", retrieved.TokenExpiresAt)
|
||
}
|
||
}
|
||
|
||
func TestUserRepository_VerifyPassword(t *testing.T) {
|
||
repo := NewUserRepository()
|
||
|
||
password := "password123"
|
||
hashedPassword, err := HashPassword(password)
|
||
if err != nil {
|
||
t.Fatalf("HashPassword failed: %v", err)
|
||
}
|
||
|
||
user := &models.User{
|
||
PasswordHash: hashedPassword,
|
||
}
|
||
|
||
// 测试正确密码
|
||
if !repo.VerifyPassword(user, password) {
|
||
t.Fatal("VerifyPassword should return true for correct password")
|
||
}
|
||
|
||
// 测试错误密码
|
||
if repo.VerifyPassword(user, "wrongpassword") {
|
||
t.Fatal("VerifyPassword should return false for wrong password")
|
||
}
|
||
|
||
// 测试空密码
|
||
if repo.VerifyPassword(user, "") {
|
||
t.Fatal("VerifyPassword should return false for empty password")
|
||
}
|
||
|
||
// 测试nil用户
|
||
if repo.VerifyPassword(nil, password) {
|
||
t.Fatal("VerifyPassword should return false for nil user")
|
||
}
|
||
}
|
||
|
||
func TestHashPassword(t *testing.T) {
|
||
password := "password123"
|
||
|
||
hashed1, err := HashPassword(password)
|
||
if err != nil {
|
||
t.Fatalf("HashPassword failed: %v", err)
|
||
}
|
||
|
||
if hashed1 == "" {
|
||
t.Fatal("Hashed password should not be empty")
|
||
}
|
||
|
||
if hashed1 == password {
|
||
t.Fatal("Hashed password should not equal plain password")
|
||
}
|
||
|
||
// 测试每次加密结果不同(因为salt)
|
||
hashed2, err := HashPassword(password)
|
||
if err != nil {
|
||
t.Fatalf("HashPassword failed: %v", err)
|
||
}
|
||
|
||
if hashed1 == hashed2 {
|
||
t.Fatal("Different hashes should be generated for the same password")
|
||
}
|
||
|
||
// 但都能验证通过
|
||
if err := bcrypt.CompareHashAndPassword([]byte(hashed1), []byte(password)); err != nil {
|
||
t.Fatal("First hash should verify correctly")
|
||
}
|
||
|
||
if err := bcrypt.CompareHashAndPassword([]byte(hashed2), []byte(password)); err != nil {
|
||
t.Fatal("Second hash should verify correctly")
|
||
}
|
||
|
||
// 测试空密码
|
||
_, err = HashPassword("")
|
||
if err == nil {
|
||
t.Fatal("HashPassword should return error for empty password")
|
||
}
|
||
}
|