package service import ( "context" "crypto/rand" "fmt" "math/big" "time" "github.com/redis/go-redis/v9" "github.com/topfans/backend/pkg/database" ) const ( SMSCodeKeyPrefix = "sms:register:" // 验证码:sms:register:{mobile} SMSVerifyTokenPrefix = "verify:register:" // 验证Token:verify:register:{mobile} SMSLimitMobilePrefix = "sms:limit:mobile:" // 手机号频率:sms:limit:mobile:register:{mobile} SMSLimitIPPrefix = "sms:limit:ip:send:" // IP频率:sms:limit:ip:send:{ip} SMSBlacklistIPPrefix = "sms:blacklist:ip:" // IP黑名单:sms:blacklist:ip:{ip} MaxMobileAttemptsPerHour = 10 MaxIPAttemptsPerHour = 30 ) type SMSCodeData struct { Code string `json:"code"` CreatedAt int64 `json:"created_at"` Attempts int `json:"attempts"` Used bool `json:"used"` } // GenerateCode generates a 6-digit random SMS code func GenerateCode() (string, error) { code := make([]byte, 6) for i := range code { n, err := rand.Int(rand.Reader, big.NewInt(10)) if err != nil { return "", fmt.Errorf("failed to generate random number: %w", err) } code[i] = byte('0' + n.Int64()) } return string(code), nil } // GetRedisClient returns the Redis client instance func GetRedisClient() *redis.Client { return database.GetRedis() } // smsCodeKey generates the Redis key for SMS code func smsCodeKey(mobile string) string { return SMSCodeKeyPrefix + mobile } // verifyTokenKey generates the Redis key for verify token func verifyTokenKey(mobile string) string { return SMSVerifyTokenPrefix + mobile } // mobileLimitKey generates the Redis key for mobile rate limit func mobileLimitKey(mobile string) string { return SMSLimitMobilePrefix + mobile } // ipLimitKey generates the Redis key for IP rate limit func ipLimitKey(ip string) string { return SMSLimitIPPrefix + ip } // blacklistIPKey generates the Redis key for IP blacklist func blacklistIPKey(ip string) string { return SMSBlacklistIPPrefix + ip } // SaveSMSCode saves the SMS code as a Hash with fields: code, created_at, attempts, used func SaveSMSCode(ctx context.Context, mobile, code string, ttl time.Duration) error { client := GetRedisClient() if client == nil { return fmt.Errorf("redis client is not initialized") } key := smsCodeKey(mobile) now := time.Now().Unix() pipe := client.Pipeline() pipe.HSet(ctx, key, map[string]interface{}{ "code": code, "created_at": now, "attempts": 0, "used": false, }) pipe.Expire(ctx, key, ttl) _, err := pipe.Exec(ctx) return err } // GetSMSCode retrieves the SMS code data for a mobile number func GetSMSCode(ctx context.Context, mobile string) (*SMSCodeData, error) { client := GetRedisClient() if client == nil { return nil, fmt.Errorf("redis client is not initialized") } key := smsCodeKey(mobile) data, err := client.HGetAll(ctx, key).Result() if err != nil { return nil, err } if len(data) == 0 { return nil, nil } smsData := &SMSCodeData{} if code, ok := data["code"]; ok { smsData.Code = code } if createdAt, ok := data["created_at"]; ok { fmt.Sscanf(createdAt, "%d", &smsData.CreatedAt) } if attempts, ok := data["attempts"]; ok { fmt.Sscanf(attempts, "%d", &smsData.Attempts) } if used, ok := data["used"]; ok { smsData.Used = used == "true" || used == "1" } return smsData, nil } // DeleteSMSCode deletes the SMS code key func DeleteSMSCode(ctx context.Context, mobile string) error { client := GetRedisClient() if client == nil { return fmt.Errorf("redis client is not initialized") } key := smsCodeKey(mobile) return client.Del(ctx, key).Err() } // IncrementAttempts increments the attempts counter and returns the new count func IncrementAttempts(ctx context.Context, mobile string) (int, error) { client := GetRedisClient() if client == nil { return 0, fmt.Errorf("redis client is not initialized") } key := smsCodeKey(mobile) count, err := client.HIncrBy(ctx, key, "attempts", 1).Result() if err != nil { return 0, err } return int(count), nil } // MarkCodeUsed marks the SMS code as used func MarkCodeUsed(ctx context.Context, mobile string) error { client := GetRedisClient() if client == nil { return fmt.Errorf("redis client is not initialized") } key := smsCodeKey(mobile) return client.HSet(ctx, key, "used", "true").Err() } // SaveVerifyToken saves the verify token as a String func SaveVerifyToken(ctx context.Context, mobile, token string, ttl time.Duration) error { client := GetRedisClient() if client == nil { return fmt.Errorf("redis client is not initialized") } key := verifyTokenKey(mobile) return client.Set(ctx, key, token, ttl).Err() } // GetVerifyToken retrieves the verify token for a mobile number func GetVerifyToken(ctx context.Context, mobile string) (string, error) { client := GetRedisClient() if client == nil { return "", fmt.Errorf("redis client is not initialized") } key := verifyTokenKey(mobile) token, err := client.Get(ctx, key).Result() if err == redis.Nil { return "", nil } return token, err } // DeleteVerifyToken deletes the verify token func DeleteVerifyToken(ctx context.Context, mobile string) error { client := GetRedisClient() if client == nil { return fmt.Errorf("redis client is not initialized") } key := verifyTokenKey(mobile) return client.Del(ctx, key).Err() } // CheckRateLimitMobile checks if the mobile has sent less than 10 times in the current hour func CheckRateLimitMobile(ctx context.Context, mobile string) (bool, error) { client := GetRedisClient() if client == nil { return false, fmt.Errorf("redis client is not initialized") } key := mobileLimitKey(mobile) count, err := client.Get(ctx, key).Int() if err == redis.Nil { return true, nil } if err != nil { return false, err } return count < MaxMobileAttemptsPerHour, nil } // IncrMobileCount increments the mobile send count with 1 hour TTL func IncrMobileCount(ctx context.Context, mobile string) (int, error) { client := GetRedisClient() if client == nil { return 0, fmt.Errorf("redis client is not initialized") } key := mobileLimitKey(mobile) pipe := client.Pipeline() incr := pipe.Incr(ctx, key) pipe.Expire(ctx, key, time.Hour) _, err := pipe.Exec(ctx) if err != nil { return 0, err } return int(incr.Val()), nil } // CheckRateLimitIP checks if the IP has sent less than 30 times in the current hour func CheckRateLimitIP(ctx context.Context, ip string) (bool, error) { client := GetRedisClient() if client == nil { return false, fmt.Errorf("redis client is not initialized") } key := ipLimitKey(ip) count, err := client.Get(ctx, key).Int() if err == redis.Nil { return true, nil } if err != nil { return false, err } return count < MaxIPAttemptsPerHour, nil } // IncrIPCount increments the IP send count with 1 hour TTL func IncrIPCount(ctx context.Context, ip string) (int, error) { client := GetRedisClient() if client == nil { return 0, fmt.Errorf("redis client is not initialized") } key := ipLimitKey(ip) pipe := client.Pipeline() incr := pipe.Incr(ctx, key) pipe.Expire(ctx, key, time.Hour) _, err := pipe.Exec(ctx) if err != nil { return 0, err } return int(incr.Val()), nil } // CheckIPBlacklist checks if the IP is blacklisted func CheckIPBlacklist(ctx context.Context, ip string) (bool, error) { client := GetRedisClient() if client == nil { return false, fmt.Errorf("redis client is not initialized") } key := blacklistIPKey(ip) exists, err := client.Exists(ctx, key).Result() if err != nil { return false, err } return exists > 0, nil }