topfans/backend/services/userService/service/sms_redis.go
2026-06-15 14:15:24 +08:00

331 lines
8.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"crypto/rand"
"fmt"
"math/big"
"time"
appErrors "github.com/topfans/backend/pkg/errors"
"github.com/redis/go-redis/v9"
"github.com/topfans/backend/pkg/database"
)
// Redis key prefixes used by the SMS verification flow. Full key layouts:
//   - SMS code hash:      sms:{scene}:{mobile}
//   - verify token:       verify:{scene}:{mobile}
//   - per-mobile limit:   sms:limit:mobile:{scene}:{mobile}
//   - per-IP send limit:  sms:limit:ip:send:{ip}
//   - IP blacklist:       sms:blacklist:ip:{ip}
//
// NOTE(review): SMSCodeKeyPrefix ("sms:") is itself a prefix of the limit and
// blacklist prefixes, so a scene named "limit" or "blacklist" would share key
// space with them. Presumably scenes come from a fixed whitelist — confirm.
const (
	SMSCodeKeyPrefix     = "sms:"
	SMSVerifyTokenPrefix = "verify:"
	SMSLimitMobilePrefix = "sms:limit:mobile:"
	SMSLimitIPPrefix     = "sms:limit:ip:send:"
	SMSBlacklistIPPrefix = "sms:blacklist:ip:"
)

// Send quotas compared against the hourly counters by the rate-limit checks.
const (
	MaxMobileAttemptsPerHour = 10 // per {scene, mobile}
	MaxIPAttemptsPerHour     = 30 // per client IP
)
// SMSCodeData is the decoded form of the Redis hash written by SaveSMSCode
// under the sms:{scene}:{mobile} key.
type SMSCodeData struct {
	// Code is the verification code value stored in the "code" field.
	Code string `json:"code"`
	// CreatedAt is the Unix time in seconds recorded when the code was saved.
	CreatedAt int64 `json:"created_at"`
	// Attempts is the counter bumped by IncrementAttempts.
	Attempts int `json:"attempts"`
	// Used is set by MarkCodeUsed once the code has been consumed.
	Used bool `json:"used"`
}
// GenerateCode returns a 6-character numeric SMS code. Every digit is drawn
// independently from crypto/rand in the range 0-9, so leading zeros are
// possible (e.g. "012345"). An error is returned only if the system random
// source fails.
func GenerateCode() (string, error) {
	const codeLen = 6
	ten := big.NewInt(10) // rand.Int only reads its max argument, so one value can be reused
	var digits [codeLen]byte
	for pos := range digits {
		d, err := rand.Int(rand.Reader, ten)
		if err != nil {
			return "", fmt.Errorf("failed to generate random number: %w", err)
		}
		digits[pos] = '0' + byte(d.Int64())
	}
	return string(digits[:]), nil
}
// GetRedisClient hands back the shared client from the database package.
// Every helper in this file treats a nil result as "redis client is not
// initialized" and returns an error instead of using it.
func GetRedisClient() *redis.Client {
	shared := database.GetRedis()
	return shared
}
// smsCodeKey builds the key of the SMS code hash: sms:{scene}:{mobile}.
func smsCodeKey(scene, mobile string) string {
	scoped := SMSCodeKeyPrefix + scene
	return scoped + ":" + mobile
}
// verifyTokenKey builds the key of the verify token string:
// verify:{scene}:{mobile}.
func verifyTokenKey(scene, mobile string) string {
	scoped := SMSVerifyTokenPrefix + scene
	return scoped + ":" + mobile
}
// mobileLimitKey builds the per-mobile send counter key:
// sms:limit:mobile:{scene}:{mobile}.
func mobileLimitKey(scene, mobile string) string {
	scoped := SMSLimitMobilePrefix + scene
	return scoped + ":" + mobile
}
// ipLimitKey builds the per-IP send counter key: sms:limit:ip:send:{ip}.
func ipLimitKey(ip string) string {
	key := SMSLimitIPPrefix + ip
	return key
}
// blacklistIPKey builds the IP blacklist marker key: sms:blacklist:ip:{ip}.
func blacklistIPKey(ip string) string {
	key := SMSBlacklistIPPrefix + ip
	return key
}
// SaveSMSCode stores a freshly issued code for (scene, mobile) as a Redis hash.
// The hash has the fields code, created_at (Unix seconds), attempts (reset to
// 0) and used (reset to false), and it expires after ttl.
//
// HSET and EXPIRE run inside MULTI/EXEC so the hash is never left in Redis
// without an expiry; a code without a TTL would stay valid indefinitely.
// A non-positive ttl is rejected because EXPIRE with a non-positive timeout
// deletes the key, which would silently discard the code just issued.
func SaveSMSCode(ctx context.Context, scene, mobile, code string, ttl time.Duration) error {
	client := GetRedisClient()
	if client == nil {
		return fmt.Errorf("redis client is not initialized")
	}
	if ttl <= 0 {
		return fmt.Errorf("invalid sms code ttl: %v", ttl)
	}
	key := smsCodeKey(scene, mobile)
	fields := map[string]interface{}{
		"code":       code,
		"created_at": time.Now().Unix(),
		"attempts":   0,
		"used":       false,
	}
	_, err := client.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
		pipe.HSet(ctx, key, fields)
		pipe.Expire(ctx, key, ttl)
		return nil
	})
	return err
}
// GetSMSCode loads the code hash stored for (scene, mobile).
//
// It returns (nil, nil) when no usable code exists. That covers two cases:
// the key is absent, or the hash has no non-empty "code" field. The second case
// arises when a field-level write (HINCRBY/HSET) lands on a key that already
// expired and recreates a partial hash. Returning such a hash as data with an
// empty Code could let an empty submitted code compare equal to it.
// Numeric fields that fail to parse are reported as errors rather than
// silently decoding as zero.
func GetSMSCode(ctx context.Context, scene, mobile string) (*SMSCodeData, error) {
	client := GetRedisClient()
	if client == nil {
		return nil, fmt.Errorf("redis client is not initialized")
	}
	key := smsCodeKey(scene, mobile)
	data, err := client.HGetAll(ctx, key).Result()
	if err != nil {
		return nil, err
	}
	return decodeSMSCodeHash(data)
}

// decodeSMSCodeHash converts a raw HGETALL result into SMSCodeData, returning
// (nil, nil) when the map carries no non-empty "code" field.
func decodeSMSCodeHash(data map[string]string) (*SMSCodeData, error) {
	code := data["code"]
	if code == "" {
		return nil, nil
	}
	smsData := &SMSCodeData{Code: code}
	if raw, ok := data["created_at"]; ok {
		if _, err := fmt.Sscanf(raw, "%d", &smsData.CreatedAt); err != nil {
			return nil, fmt.Errorf("parse sms code created_at %q: %w", raw, err)
		}
	}
	if raw, ok := data["attempts"]; ok {
		if _, err := fmt.Sscanf(raw, "%d", &smsData.Attempts); err != nil {
			return nil, fmt.Errorf("parse sms code attempts %q: %w", raw, err)
		}
	}
	if raw, ok := data["used"]; ok {
		// SaveSMSCode writes the bool false (sent as "0") and MarkCodeUsed writes
		// "true", so both spellings of true are accepted.
		smsData.Used = raw == "true" || raw == "1"
	}
	return smsData, nil
}
// DeleteSMSCode removes the code hash stored for (scene, mobile) with DEL and
// returns any Redis error.
func DeleteSMSCode(ctx context.Context, scene, mobile string) error {
	rdb := GetRedisClient()
	if rdb == nil {
		return fmt.Errorf("redis client is not initialized")
	}
	return rdb.Del(ctx, smsCodeKey(scene, mobile)).Err()
}
// incrementAttemptsScript bumps the "attempts" field only while the code hash
// still exists and returns -1 otherwise. A bare HINCRBY on a missing key would
// recreate a partial hash ({attempts: 1}) with no TTL that never expires.
const incrementAttemptsScript = `
if redis.call('EXISTS', KEYS[1]) == 0 then
	return -1
end
return redis.call('HINCRBY', KEYS[1], 'attempts', 1)
`

// IncrementAttempts atomically increments the attempts counter of the code
// stored for (scene, mobile) and returns the new count.
//
// If the code no longer exists (expired or deleted), nothing is written and
// (0, nil) is returned; callers should treat the code as missing.
func IncrementAttempts(ctx context.Context, scene, mobile string) (int, error) {
	client := GetRedisClient()
	if client == nil {
		return 0, fmt.Errorf("redis client is not initialized")
	}
	key := smsCodeKey(scene, mobile)
	count, err := client.Eval(ctx, incrementAttemptsScript, []string{key}).Int64()
	if err != nil {
		return 0, err
	}
	if count < 0 {
		return 0, nil
	}
	return int(count), nil
}
// markCodeUsedScript sets used=true only while the code hash still exists.
// A bare HSET on a missing key would recreate a partial hash ({used: true})
// with no TTL that never expires.
const markCodeUsedScript = `
if redis.call('EXISTS', KEYS[1]) == 0 then
	return 0
end
redis.call('HSET', KEYS[1], 'used', 'true')
return 1
`

// MarkCodeUsed atomically flags the code stored for (scene, mobile) as used.
// If the code has already expired or been deleted, nothing is written and nil
// is returned; only Redis failures produce an error.
func MarkCodeUsed(ctx context.Context, scene, mobile string) error {
	client := GetRedisClient()
	if client == nil {
		return fmt.Errorf("redis client is not initialized")
	}
	key := smsCodeKey(scene, mobile)
	return client.Eval(ctx, markCodeUsedScript, []string{key}).Err()
}
// SaveVerifyToken stores token as a plain string under verify:{scene}:{mobile}
// with the given ttl, using SET.
//
// A non-positive ttl is rejected. go-redis treats an expiration of 0 as "no
// expiry" and -1 as KEEPTTL, either of which could leave a verify token valid
// indefinitely.
func SaveVerifyToken(ctx context.Context, scene, mobile, token string, ttl time.Duration) error {
	client := GetRedisClient()
	if client == nil {
		return fmt.Errorf("redis client is not initialized")
	}
	if ttl <= 0 {
		return fmt.Errorf("invalid verify token ttl: %v", ttl)
	}
	key := verifyTokenKey(scene, mobile)
	return client.Set(ctx, key, token, ttl).Err()
}
// GetVerifyToken returns the verify token stored for (scene, mobile).
// A missing key (redis.Nil) is not an error: it yields "" and a nil error.
func GetVerifyToken(ctx context.Context, scene, mobile string) (string, error) {
	rdb := GetRedisClient()
	if rdb == nil {
		return "", fmt.Errorf("redis client is not initialized")
	}
	token, err := rdb.Get(ctx, verifyTokenKey(scene, mobile)).Result()
	if err == redis.Nil {
		token, err = "", nil
	}
	return token, err
}
// DeleteVerifyToken removes the verify token stored for (scene, mobile) with
// DEL and returns any Redis error.
func DeleteVerifyToken(ctx context.Context, scene, mobile string) error {
	rdb := GetRedisClient()
	if rdb == nil {
		return fmt.Errorf("redis client is not initialized")
	}
	return rdb.Del(ctx, verifyTokenKey(scene, mobile)).Err()
}
// verifyTokenConsumeScript is a Lua script that performs GET + compare + DEL
// atomically on the server.
// ConsumeVerifyToken uses it to consume a verify token exactly once after the
// business operation has succeeded.
// KEYS[1] is the token key and ARGV[1] the presented token. The script returns
// DEL's reply when the stored value equals ARGV[1], and 0 otherwise (a mismatch
// or a missing key).
const verifyTokenConsumeScript = `
local current = redis.call('GET', KEYS[1])
if current == ARGV[1] then
return redis.call('DEL', KEYS[1])
end
return 0
`
// ConsumeVerifyToken validates token against the one stored for
// (scene, mobile) and deletes it atomically (Plan A: Compare-And-Delete).
// Call it only after the business logic has succeeded. That gives
// "consume only on success" semantics:
//   - token valid + business succeeded → the token is deleted, so the next
//     request fails
//   - token valid + business failed → this function is not called, so the
//     token is kept and the user can retry without requesting another SMS
//   - token already consumed by a concurrent request → returns
//     ErrInvalidVerifyToken (replay protection)
//
// A token that does not match, or a key that does not exist, also yields
// ErrInvalidVerifyToken. Redis failures are wrapped and returned.
func ConsumeVerifyToken(ctx context.Context, scene, mobile, token string) error {
	rdb := GetRedisClient()
	if rdb == nil {
		return fmt.Errorf("redis client is not initialized")
	}
	keys := []string{verifyTokenKey(scene, mobile)}
	reply, err := rdb.Eval(ctx, verifyTokenConsumeScript, keys, token).Result()
	if err != nil {
		return fmt.Errorf("failed to consume verify token: %w", err)
	}
	// The script replies with an integer; a non-integer reply counts as 0.
	if n, ok := reply.(int64); ok && n != 0 {
		return nil
	}
	return appErrors.ErrInvalidVerifyToken
}
// rateLimitWindow is the length of the fixed window for the send counters.
const rateLimitWindow = time.Hour

// fixedWindowIncrScript increments the counter at KEYS[1] and returns the new
// count. It sets the expiry (ARGV[1] seconds) only when the window starts
// (first increment) or when the key has no TTL at all. Re-applying EXPIRE on
// every increment, as before, turned the "per hour" quota into a sliding window
// that never resets for a client that keeps sending.
const fixedWindowIncrScript = `
local n = redis.call('INCR', KEYS[1])
if n == 1 or redis.call('TTL', KEYS[1]) == -1 then
	redis.call('EXPIRE', KEYS[1], ARGV[1])
end
return n
`

// underLimit reports whether the counter stored at key is strictly below
// limit. A missing key counts as zero.
func underLimit(ctx context.Context, key string, limit int) (bool, error) {
	client := GetRedisClient()
	if client == nil {
		return false, fmt.Errorf("redis client is not initialized")
	}
	count, err := client.Get(ctx, key).Int()
	if err == redis.Nil {
		return true, nil
	}
	if err != nil {
		return false, err
	}
	return count < limit, nil
}

// incrInWindow atomically increments the counter at key inside a fixed window
// of the given length and returns the new count.
func incrInWindow(ctx context.Context, key string, window time.Duration) (int, error) {
	client := GetRedisClient()
	if client == nil {
		return 0, fmt.Errorf("redis client is not initialized")
	}
	secs := int64(window / time.Second)
	count, err := client.Eval(ctx, fixedWindowIncrScript, []string{key}, secs).Int64()
	if err != nil {
		return 0, err
	}
	return int(count), nil
}

// CheckRateLimitMobile reports whether (scene, mobile) has sent fewer than
// MaxMobileAttemptsPerHour codes in the current one-hour window.
//
// The check and IncrMobileCount are separate round trips, so concurrent
// requests can both pass. Callers that need a strict cap can also compare the
// count returned by IncrMobileCount.
func CheckRateLimitMobile(ctx context.Context, scene, mobile string) (bool, error) {
	return underLimit(ctx, mobileLimitKey(scene, mobile), MaxMobileAttemptsPerHour)
}

// IncrMobileCount increments the send counter for (scene, mobile) and returns
// the new count. The counter expires one hour after the first send in the
// window; later sends do not extend it.
func IncrMobileCount(ctx context.Context, scene, mobile string) (int, error) {
	return incrInWindow(ctx, mobileLimitKey(scene, mobile), rateLimitWindow)
}

// CheckRateLimitIP reports whether ip has sent fewer than MaxIPAttemptsPerHour
// codes in the current one-hour window. The same check-then-increment caveat
// as CheckRateLimitMobile applies.
func CheckRateLimitIP(ctx context.Context, ip string) (bool, error) {
	return underLimit(ctx, ipLimitKey(ip), MaxIPAttemptsPerHour)
}

// IncrIPCount increments the send counter for ip and returns the new count.
// The counter expires one hour after the first send in the window; later
// sends do not extend it.
func IncrIPCount(ctx context.Context, ip string) (int, error) {
	return incrInWindow(ctx, ipLimitKey(ip), rateLimitWindow)
}
// CheckIPBlacklist reports whether a sms:blacklist:ip:{ip} key exists.
// Only the key's existence is checked; its value is never read.
func CheckIPBlacklist(ctx context.Context, ip string) (bool, error) {
	rdb := GetRedisClient()
	if rdb == nil {
		return false, fmt.Errorf("redis client is not initialized")
	}
	n, err := rdb.Exists(ctx, blacklistIPKey(ip)).Result()
	if err != nil {
		return false, err
	}
	blacklisted := n > 0
	return blacklisted, nil
}