topfans/backend/services/userService/service/sms_redis.go
2026-05-26 13:23:04 +08:00

297 lines
7.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
	"context"
	"crypto/rand"
	"fmt"
	"math/big"
	"strconv"
	"time"

	"github.com/redis/go-redis/v9"
	"github.com/topfans/backend/pkg/database"
)
const (
	// SMSCodeKeyPrefix prefixes the Redis hash holding a registration SMS code.
	// Key format: sms:register:{mobile}
	SMSCodeKeyPrefix = "sms:register:"

	// SMSVerifyTokenPrefix prefixes the Redis string holding a verify token.
	// Key format: verify:register:{mobile}
	SMSVerifyTokenPrefix = "verify:register:"

	// SMSLimitMobilePrefix prefixes the per-mobile SMS send counter.
	// Key format: sms:limit:mobile:{mobile}
	SMSLimitMobilePrefix = "sms:limit:mobile:"

	// SMSLimitIPPrefix prefixes the per-IP SMS send counter.
	// Key format: sms:limit:ip:send:{ip}
	SMSLimitIPPrefix = "sms:limit:ip:send:"

	// SMSBlacklistIPPrefix prefixes IP blacklist entries; the mere existence of
	// the key marks the IP as blacklisted.
	// Key format: sms:blacklist:ip:{ip}
	SMSBlacklistIPPrefix = "sms:blacklist:ip:"

	// MaxMobileAttemptsPerHour is the send limit per mobile: a send is allowed
	// only while the mobile's counter is below this value. The counter expires
	// one hour after its most recent increment.
	MaxMobileAttemptsPerHour = 10

	// MaxIPAttemptsPerHour is the send limit per IP: a send is allowed only
	// while the IP's counter is below this value. The counter expires one hour
	// after its most recent increment.
	MaxIPAttemptsPerHour = 30
)
// SMSCodeData is the decoded form of the Redis hash that SaveSMSCode writes
// for a mobile number.
type SMSCodeData struct {
	// Code is the stored verification code.
	Code string `json:"code"`
	// CreatedAt is the Unix time, in seconds, at which the code was saved.
	CreatedAt int64 `json:"created_at"`
	// Attempts is the counter that IncrementAttempts increases; SaveSMSCode resets it to 0.
	Attempts int `json:"attempts"`
	// Used reports whether MarkCodeUsed has flagged the code.
	Used bool `json:"used"`
}
// GenerateCode returns a 6-character numeric SMS code whose digits are drawn
// independently from crypto/rand. Leading zeros are kept, so the result always
// has exactly six characters.
func GenerateCode() (string, error) {
	const length = 6
	ten := big.NewInt(10)
	var digits [length]byte
	for idx := 0; idx < length; idx++ {
		d, err := rand.Int(rand.Reader, ten)
		if err != nil {
			return "", fmt.Errorf("failed to generate random number: %w", err)
		}
		digits[idx] = '0' + byte(d.Int64())
	}
	return string(digits[:]), nil
}
// GetRedisClient hands back the Redis client held by the database package.
// Every helper in this file treats a nil result as "redis client is not
// initialized".
func GetRedisClient() *redis.Client {
	shared := database.GetRedis()
	return shared
}
// smsCodeKey returns the hash key that stores mobile's SMS code,
// i.e. "sms:register:{mobile}".
func smsCodeKey(mobile string) (key string) {
	key = SMSCodeKeyPrefix + mobile
	return key
}
// verifyTokenKey returns the string key that stores mobile's verify token,
// i.e. "verify:register:{mobile}".
func verifyTokenKey(mobile string) (key string) {
	key = SMSVerifyTokenPrefix + mobile
	return key
}
// mobileLimitKey returns the counter key used to rate-limit sends to mobile,
// i.e. "sms:limit:mobile:{mobile}".
func mobileLimitKey(mobile string) (key string) {
	key = SMSLimitMobilePrefix + mobile
	return key
}
// ipLimitKey returns the counter key used to rate-limit sends from ip,
// i.e. "sms:limit:ip:send:{ip}".
func ipLimitKey(ip string) (key string) {
	key = SMSLimitIPPrefix + ip
	return key
}
// blacklistIPKey returns the key whose presence marks ip as blacklisted,
// i.e. "sms:blacklist:ip:{ip}".
func blacklistIPKey(ip string) (key string) {
	key = SMSBlacklistIPPrefix + ip
	return key
}
// SaveSMSCode stores code for mobile as a Redis hash with the fields code,
// created_at (Unix seconds), attempts (reset to 0) and used (false), and makes
// the key expire after ttl. Saving again for the same mobile overwrites these
// fields and restarts the expiry.
//
// HSET and EXPIRE are queued in a MULTI/EXEC transaction so they are applied
// together; a code is never left in Redis without an expiry. ttl must be
// positive: Redis deletes a key immediately when EXPIRE receives a
// non-positive timeout, which would silently discard the code just written.
func SaveSMSCode(ctx context.Context, mobile, code string, ttl time.Duration) error {
	client := GetRedisClient()
	if client == nil {
		return fmt.Errorf("redis client is not initialized")
	}
	if ttl <= 0 {
		return fmt.Errorf("sms code ttl must be positive, got %s", ttl)
	}
	key := smsCodeKey(mobile)
	now := time.Now().Unix()
	pipe := client.TxPipeline()
	pipe.HSet(ctx, key, map[string]interface{}{
		"code":       code,
		"created_at": now,
		"attempts":   0,
		"used":       false,
	})
	pipe.Expire(ctx, key, ttl)
	_, err := pipe.Exec(ctx)
	return err
}
// GetSMSCode loads the SMS code hash stored for mobile.
//
// It returns (nil, nil) when no usable code exists: either the key is absent
// (never saved or expired) or the hash has no "code" field. The second case
// arises when HINCRBY/HSET is issued against an already-expired key. Redis
// then creates a fresh hash holding only that field, and reporting it as a
// code would make callers compare input against an empty string.
//
// A numeric field that cannot be parsed yields an error rather than silently
// decoding as zero, so a corrupted attempts counter cannot reset itself.
func GetSMSCode(ctx context.Context, mobile string) (*SMSCodeData, error) {
	client := GetRedisClient()
	if client == nil {
		return nil, fmt.Errorf("redis client is not initialized")
	}
	data, err := client.HGetAll(ctx, smsCodeKey(mobile)).Result()
	if err != nil {
		return nil, err
	}
	return parseSMSCodeHash(data)
}

// parseSMSCodeHash decodes an HGETALL field map into SMSCodeData, returning
// (nil, nil) when the map has no "code" field.
func parseSMSCodeHash(data map[string]string) (*SMSCodeData, error) {
	code, ok := data["code"]
	if !ok {
		return nil, nil
	}
	smsData := &SMSCodeData{Code: code}
	if raw, ok := data["created_at"]; ok {
		createdAt, err := strconv.ParseInt(raw, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("parsing sms code created_at %q: %w", raw, err)
		}
		smsData.CreatedAt = createdAt
	}
	if raw, ok := data["attempts"]; ok {
		attempts, err := strconv.Atoi(raw)
		if err != nil {
			return nil, fmt.Errorf("parsing sms code attempts %q: %w", raw, err)
		}
		smsData.Attempts = attempts
	}
	if used, ok := data["used"]; ok {
		// SaveSMSCode's bool false is stored by go-redis as "0"; MarkCodeUsed writes "true".
		smsData.Used = used == "true" || used == "1"
	}
	return smsData, nil
}
// DeleteSMSCode removes the SMS code hash stored for mobile. Removing a key
// that is already gone is not an error (DEL just reports zero keys removed).
func DeleteSMSCode(ctx context.Context, mobile string) error {
	if client := GetRedisClient(); client != nil {
		return client.Del(ctx, smsCodeKey(mobile)).Err()
	}
	return fmt.Errorf("redis client is not initialized")
}
// IncrementAttempts atomically increments the "attempts" field of mobile's SMS
// code hash and returns the new count.
//
// A bare HINCRBY on a missing key would create a brand-new hash with no TTL,
// leaving a stray key that never expires. The increment therefore runs in a
// Lua script that only touches the key when it still exists. When no code is
// stored (never saved or already expired), nothing is written and the result
// is 0 with a nil error; callers should read 0 as "no active code".
func IncrementAttempts(ctx context.Context, mobile string) (int, error) {
	const script = `
if redis.call("EXISTS", KEYS[1]) == 0 then
	return 0
end
return redis.call("HINCRBY", KEYS[1], "attempts", 1)
`
	client := GetRedisClient()
	if client == nil {
		return 0, fmt.Errorf("redis client is not initialized")
	}
	count, err := client.Eval(ctx, script, []string{smsCodeKey(mobile)}).Int64()
	if err != nil {
		return 0, err
	}
	return int(count), nil
}
// MarkCodeUsed sets the "used" field of mobile's SMS code hash to "true".
//
// A bare HSET on an expired key would recreate it as a hash holding only
// "used", with no TTL. The write therefore runs in a Lua script that only
// touches the key when it still exists. If no code is stored, nothing is
// written and nil is returned.
func MarkCodeUsed(ctx context.Context, mobile string) error {
	// The script must return a non-nil reply; a nil reply surfaces as redis.Nil.
	const script = `
if redis.call("EXISTS", KEYS[1]) == 0 then
	return 0
end
redis.call("HSET", KEYS[1], "used", "true")
return 1
`
	client := GetRedisClient()
	if client == nil {
		return fmt.Errorf("redis client is not initialized")
	}
	return client.Eval(ctx, script, []string{smsCodeKey(mobile)}).Err()
}
// SaveVerifyToken writes token for mobile as a plain Redis string, using ttl
// as the key's expiration passed to SET.
func SaveVerifyToken(ctx context.Context, mobile, token string, ttl time.Duration) error {
	if client := GetRedisClient(); client != nil {
		return client.Set(ctx, verifyTokenKey(mobile), token, ttl).Err()
	}
	return fmt.Errorf("redis client is not initialized")
}
// GetVerifyToken reads the verify token stored for mobile. A missing key is
// not an error: it yields an empty token with a nil error.
func GetVerifyToken(ctx context.Context, mobile string) (string, error) {
	client := GetRedisClient()
	if client == nil {
		return "", fmt.Errorf("redis client is not initialized")
	}
	stored, err := client.Get(ctx, verifyTokenKey(mobile)).Result()
	switch {
	case err == redis.Nil:
		return "", nil
	case err != nil:
		return stored, err
	default:
		return stored, nil
	}
}
// DeleteVerifyToken removes the verify token stored for mobile. Removing a
// key that is already gone is not an error.
func DeleteVerifyToken(ctx context.Context, mobile string) error {
	if client := GetRedisClient(); client != nil {
		return client.Del(ctx, verifyTokenKey(mobile)).Err()
	}
	return fmt.Errorf("redis client is not initialized")
}
// CheckRateLimitMobile reports whether another SMS may be sent to mobile. It
// returns true when the mobile's send counter is absent or still below
// MaxMobileAttemptsPerHour; it does not modify the counter.
func CheckRateLimitMobile(ctx context.Context, mobile string) (bool, error) {
	client := GetRedisClient()
	if client == nil {
		return false, fmt.Errorf("redis client is not initialized")
	}
	sent, err := client.Get(ctx, mobileLimitKey(mobile)).Int()
	switch {
	case err == redis.Nil:
		return true, nil
	case err != nil:
		return false, err
	}
	return sent < MaxMobileAttemptsPerHour, nil
}
// IncrMobileCount adds one to mobile's send counter and returns the new value.
// INCR and EXPIRE are pipelined together (not in a MULTI transaction), so
// every increment also resets the counter's TTL to one hour from now.
func IncrMobileCount(ctx context.Context, mobile string) (int, error) {
	client := GetRedisClient()
	if client == nil {
		return 0, fmt.Errorf("redis client is not initialized")
	}
	limitKey := mobileLimitKey(mobile)
	batch := client.Pipeline()
	counter := batch.Incr(ctx, limitKey)
	batch.Expire(ctx, limitKey, time.Hour)
	if _, err := batch.Exec(ctx); err != nil {
		return 0, err
	}
	return int(counter.Val()), nil
}
// CheckRateLimitIP reports whether another SMS may be sent on behalf of ip. It
// returns true when the IP's send counter is absent or still below
// MaxIPAttemptsPerHour; it does not modify the counter.
func CheckRateLimitIP(ctx context.Context, ip string) (bool, error) {
	client := GetRedisClient()
	if client == nil {
		return false, fmt.Errorf("redis client is not initialized")
	}
	sent, err := client.Get(ctx, ipLimitKey(ip)).Int()
	switch {
	case err == redis.Nil:
		return true, nil
	case err != nil:
		return false, err
	}
	return sent < MaxIPAttemptsPerHour, nil
}
// IncrIPCount adds one to ip's send counter and returns the new value.
// INCR and EXPIRE are pipelined together (not in a MULTI transaction), so
// every increment also resets the counter's TTL to one hour from now.
func IncrIPCount(ctx context.Context, ip string) (int, error) {
	client := GetRedisClient()
	if client == nil {
		return 0, fmt.Errorf("redis client is not initialized")
	}
	limitKey := ipLimitKey(ip)
	batch := client.Pipeline()
	counter := batch.Incr(ctx, limitKey)
	batch.Expire(ctx, limitKey, time.Hour)
	if _, err := batch.Exec(ctx); err != nil {
		return 0, err
	}
	return int(counter.Val()), nil
}
// CheckIPBlacklist reports whether ip is blacklisted. Only the existence of
// the blacklist key is checked; its value is ignored.
func CheckIPBlacklist(ctx context.Context, ip string) (bool, error) {
	client := GetRedisClient()
	if client == nil {
		return false, fmt.Errorf("redis client is not initialized")
	}
	found, err := client.Exists(ctx, blacklistIPKey(ip)).Result()
	if err != nil {
		return false, err
	}
	blacklisted := found > 0
	return blacklisted, nil
}