297 lines
7.4 KiB
Go
297 lines
7.4 KiB
Go
package service
|
||
|
||
import (
	"context"
	"crypto/rand"
	"fmt"
	"math/big"
	"strconv"
	"time"

	"github.com/redis/go-redis/v9"
	"github.com/topfans/backend/pkg/database"
)
|
||
|
||
const (
	// SMSCodeKeyPrefix prefixes the hash holding an issued verification code:
	// sms:register:{mobile}.
	SMSCodeKeyPrefix = "sms:register:"

	// SMSVerifyTokenPrefix prefixes the string holding a verify token:
	// verify:register:{mobile}.
	SMSVerifyTokenPrefix = "verify:register:"

	// SMSLimitMobilePrefix prefixes the per-mobile send counter:
	// sms:limit:mobile:{mobile}. Unlike the code and token keys, this key
	// carries no "register" segment.
	SMSLimitMobilePrefix = "sms:limit:mobile:"

	// SMSLimitIPPrefix prefixes the per-IP send counter: sms:limit:ip:send:{ip}.
	SMSLimitIPPrefix = "sms:limit:ip:send:"

	// SMSBlacklistIPPrefix prefixes the IP blacklist marker: sms:blacklist:ip:{ip}.
	SMSBlacklistIPPrefix = "sms:blacklist:ip:"

	// MaxMobileAttemptsPerHour is the send count per mobile number at which
	// CheckRateLimitMobile starts refusing further sends.
	MaxMobileAttemptsPerHour = 10

	// MaxIPAttemptsPerHour is the send count per client IP at which
	// CheckRateLimitIP starts refusing further sends.
	MaxIPAttemptsPerHour = 30
)
|
||
|
||
// SMSCodeData is the decoded form of the Redis hash that SaveSMSCode writes
// for a mobile number and GetSMSCode reads back.
type SMSCodeData struct {
	// Code is the digit string that was issued.
	Code string `json:"code"`

	// CreatedAt is the Unix time, in seconds, at which SaveSMSCode stored the code.
	CreatedAt int64 `json:"created_at"`

	// Attempts counts verification attempts recorded by IncrementAttempts.
	Attempts int `json:"attempts"`

	// Used reports whether MarkCodeUsed has flagged the code as consumed.
	Used bool `json:"used"`
}
|
||
|
||
// GenerateCode returns a six-character string of decimal digits. Each digit is
// drawn independently and uniformly from crypto/rand. An error is returned only
// when the secure random source fails.
func GenerateCode() (string, error) {
	const length = 6
	ten := big.NewInt(10) // rand.Int does not modify its max argument, so one value serves every draw

	var digits [length]byte
	for i := 0; i < length; i++ {
		d, err := rand.Int(rand.Reader, ten)
		if err != nil {
			return "", fmt.Errorf("failed to generate random number: %w", err)
		}
		digits[i] = '0' + byte(d.Int64())
	}
	return string(digits[:]), nil
}
|
||
|
||
// GetRedisClient returns the Redis client instance
|
||
func GetRedisClient() *redis.Client {
|
||
return database.GetRedis()
|
||
}
|
||
|
||
// smsCodeKey generates the Redis key for SMS code
|
||
func smsCodeKey(mobile string) string {
|
||
return SMSCodeKeyPrefix + mobile
|
||
}
|
||
|
||
// verifyTokenKey generates the Redis key for verify token
|
||
func verifyTokenKey(mobile string) string {
|
||
return SMSVerifyTokenPrefix + mobile
|
||
}
|
||
|
||
// mobileLimitKey generates the Redis key for mobile rate limit
|
||
func mobileLimitKey(mobile string) string {
|
||
return SMSLimitMobilePrefix + mobile
|
||
}
|
||
|
||
// ipLimitKey generates the Redis key for IP rate limit
|
||
func ipLimitKey(ip string) string {
|
||
return SMSLimitIPPrefix + ip
|
||
}
|
||
|
||
// blacklistIPKey generates the Redis key for IP blacklist
|
||
func blacklistIPKey(ip string) string {
|
||
return SMSBlacklistIPPrefix + ip
|
||
}
|
||
|
||
// SaveSMSCode saves the SMS code as a Hash with fields: code, created_at, attempts, used
|
||
func SaveSMSCode(ctx context.Context, mobile, code string, ttl time.Duration) error {
|
||
client := GetRedisClient()
|
||
if client == nil {
|
||
return fmt.Errorf("redis client is not initialized")
|
||
}
|
||
|
||
key := smsCodeKey(mobile)
|
||
now := time.Now().Unix()
|
||
|
||
pipe := client.Pipeline()
|
||
pipe.HSet(ctx, key, map[string]interface{}{
|
||
"code": code,
|
||
"created_at": now,
|
||
"attempts": 0,
|
||
"used": false,
|
||
})
|
||
pipe.Expire(ctx, key, ttl)
|
||
_, err := pipe.Exec(ctx)
|
||
return err
|
||
}
|
||
|
||
// GetSMSCode retrieves the SMS code data for a mobile number
|
||
func GetSMSCode(ctx context.Context, mobile string) (*SMSCodeData, error) {
|
||
client := GetRedisClient()
|
||
if client == nil {
|
||
return nil, fmt.Errorf("redis client is not initialized")
|
||
}
|
||
|
||
key := smsCodeKey(mobile)
|
||
data, err := client.HGetAll(ctx, key).Result()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if len(data) == 0 {
|
||
return nil, nil
|
||
}
|
||
|
||
smsData := &SMSCodeData{}
|
||
|
||
if code, ok := data["code"]; ok {
|
||
smsData.Code = code
|
||
}
|
||
if createdAt, ok := data["created_at"]; ok {
|
||
fmt.Sscanf(createdAt, "%d", &smsData.CreatedAt)
|
||
}
|
||
if attempts, ok := data["attempts"]; ok {
|
||
fmt.Sscanf(attempts, "%d", &smsData.Attempts)
|
||
}
|
||
if used, ok := data["used"]; ok {
|
||
smsData.Used = used == "true" || used == "1"
|
||
}
|
||
|
||
return smsData, nil
|
||
}
|
||
|
||
// DeleteSMSCode deletes the SMS code key
|
||
func DeleteSMSCode(ctx context.Context, mobile string) error {
|
||
client := GetRedisClient()
|
||
if client == nil {
|
||
return fmt.Errorf("redis client is not initialized")
|
||
}
|
||
|
||
key := smsCodeKey(mobile)
|
||
return client.Del(ctx, key).Err()
|
||
}
|
||
|
||
// IncrementAttempts increments the attempts counter and returns the new count
|
||
func IncrementAttempts(ctx context.Context, mobile string) (int, error) {
|
||
client := GetRedisClient()
|
||
if client == nil {
|
||
return 0, fmt.Errorf("redis client is not initialized")
|
||
}
|
||
|
||
key := smsCodeKey(mobile)
|
||
count, err := client.HIncrBy(ctx, key, "attempts", 1).Result()
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return int(count), nil
|
||
}
|
||
|
||
// MarkCodeUsed marks the SMS code as used
|
||
func MarkCodeUsed(ctx context.Context, mobile string) error {
|
||
client := GetRedisClient()
|
||
if client == nil {
|
||
return fmt.Errorf("redis client is not initialized")
|
||
}
|
||
|
||
key := smsCodeKey(mobile)
|
||
return client.HSet(ctx, key, "used", "true").Err()
|
||
}
|
||
|
||
// SaveVerifyToken saves the verify token as a String
|
||
func SaveVerifyToken(ctx context.Context, mobile, token string, ttl time.Duration) error {
|
||
client := GetRedisClient()
|
||
if client == nil {
|
||
return fmt.Errorf("redis client is not initialized")
|
||
}
|
||
|
||
key := verifyTokenKey(mobile)
|
||
return client.Set(ctx, key, token, ttl).Err()
|
||
}
|
||
|
||
// GetVerifyToken retrieves the verify token for a mobile number
|
||
func GetVerifyToken(ctx context.Context, mobile string) (string, error) {
|
||
client := GetRedisClient()
|
||
if client == nil {
|
||
return "", fmt.Errorf("redis client is not initialized")
|
||
}
|
||
|
||
key := verifyTokenKey(mobile)
|
||
token, err := client.Get(ctx, key).Result()
|
||
if err == redis.Nil {
|
||
return "", nil
|
||
}
|
||
return token, err
|
||
}
|
||
|
||
// DeleteVerifyToken deletes the verify token
|
||
func DeleteVerifyToken(ctx context.Context, mobile string) error {
|
||
client := GetRedisClient()
|
||
if client == nil {
|
||
return fmt.Errorf("redis client is not initialized")
|
||
}
|
||
|
||
key := verifyTokenKey(mobile)
|
||
return client.Del(ctx, key).Err()
|
||
}
|
||
|
||
// CheckRateLimitMobile checks if the mobile has sent less than 10 times in the current hour
|
||
func CheckRateLimitMobile(ctx context.Context, mobile string) (bool, error) {
|
||
client := GetRedisClient()
|
||
if client == nil {
|
||
return false, fmt.Errorf("redis client is not initialized")
|
||
}
|
||
|
||
key := mobileLimitKey(mobile)
|
||
count, err := client.Get(ctx, key).Int()
|
||
if err == redis.Nil {
|
||
return true, nil
|
||
}
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return count < MaxMobileAttemptsPerHour, nil
|
||
}
|
||
|
||
// IncrMobileCount increments the mobile send count with 1 hour TTL
|
||
func IncrMobileCount(ctx context.Context, mobile string) (int, error) {
|
||
client := GetRedisClient()
|
||
if client == nil {
|
||
return 0, fmt.Errorf("redis client is not initialized")
|
||
}
|
||
|
||
key := mobileLimitKey(mobile)
|
||
pipe := client.Pipeline()
|
||
|
||
incr := pipe.Incr(ctx, key)
|
||
pipe.Expire(ctx, key, time.Hour)
|
||
|
||
_, err := pipe.Exec(ctx)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
return int(incr.Val()), nil
|
||
}
|
||
|
||
// CheckRateLimitIP checks if the IP has sent less than 30 times in the current hour
|
||
func CheckRateLimitIP(ctx context.Context, ip string) (bool, error) {
|
||
client := GetRedisClient()
|
||
if client == nil {
|
||
return false, fmt.Errorf("redis client is not initialized")
|
||
}
|
||
|
||
key := ipLimitKey(ip)
|
||
count, err := client.Get(ctx, key).Int()
|
||
if err == redis.Nil {
|
||
return true, nil
|
||
}
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return count < MaxIPAttemptsPerHour, nil
|
||
}
|
||
|
||
// IncrIPCount increments the IP send count with 1 hour TTL
|
||
func IncrIPCount(ctx context.Context, ip string) (int, error) {
|
||
client := GetRedisClient()
|
||
if client == nil {
|
||
return 0, fmt.Errorf("redis client is not initialized")
|
||
}
|
||
|
||
key := ipLimitKey(ip)
|
||
pipe := client.Pipeline()
|
||
|
||
incr := pipe.Incr(ctx, key)
|
||
pipe.Expire(ctx, key, time.Hour)
|
||
|
||
_, err := pipe.Exec(ctx)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
return int(incr.Val()), nil
|
||
}
|
||
|
||
// CheckIPBlacklist checks if the IP is blacklisted
|
||
func CheckIPBlacklist(ctx context.Context, ip string) (bool, error) {
|
||
client := GetRedisClient()
|
||
if client == nil {
|
||
return false, fmt.Errorf("redis client is not initialized")
|
||
}
|
||
|
||
key := blacklistIPKey(ip)
|
||
exists, err := client.Exists(ctx, key).Result()
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return exists > 0, nil
|
||
} |