// Package database holds the shared Redis client and a Redis-backed token blacklist.
package database
import (
	"context"
	"crypto/sha256"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"strconv"
	"time"

	"github.com/redis/go-redis/v9"
)
// BlacklistKeyPrefix is the prefix placed in front of a hashed token to form
// the Redis key of its blacklist entry.
const BlacklistKeyPrefix = "blacklist:token:"
// RedisClient Redis 客户端单例
|
||
var RedisClient *redis.Client
// RedisConfig carries the connection settings used to build a Redis client.
type RedisConfig struct {
	// Host is the server host name or IP address.
	Host string
	// Port is the server TCP port.
	Port int
	// Password is forwarded to the client as its password option.
	Password string
	// DB is forwarded to the client as its database option.
	DB int
}
// InitRedis 初始化 Redis 连接
|
||
func InitRedis(cfg RedisConfig) error {
|
||
RedisClient = redis.NewClient(&redis.Options{
|
||
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||
Password: cfg.Password,
|
||
DB: cfg.DB,
|
||
})
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
|
||
if err := RedisClient.Ping(ctx).Err(); err != nil {
|
||
return fmt.Errorf("failed to connect to redis: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
// CloseRedis 关闭 Redis 连接
|
||
func CloseRedis() error {
|
||
if RedisClient != nil {
|
||
return RedisClient.Close()
|
||
}
|
||
return nil
|
||
}
// GetRedis 获取 Redis 客户端实例
|
||
func GetRedis() *redis.Client {
|
||
return RedisClient
|
||
}
// RedisHealthCheck 健康检查
|
||
func RedisHealthCheck() error {
|
||
if RedisClient == nil {
|
||
return fmt.Errorf("redis client is not initialized")
|
||
}
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
return RedisClient.Ping(ctx).Err()
|
||
}
// BlacklistEntry is the value stored, JSON-encoded, for a blacklisted token.
type BlacklistEntry struct {
	// UserID is the user ID recorded with the entry.
	UserID int64 `json:"user_id"`
	// Reason is the ban reason recorded with the entry.
	Reason string `json:"reason"`
}
// tokenToHash returns the lowercase hexadecimal SHA-256 digest of token.
// The result is always 64 characters long, including for an empty token.
func tokenToHash(token string) string {
	hasher := sha256.New()
	hasher.Write([]byte(token)) // hash.Hash documents that Write never returns an error
	return fmt.Sprintf("%x", hasher.Sum(nil))
}
// AddToBlacklist 添加 Token 到黑名单
|
||
func AddToBlacklist(ctx context.Context, token string, userID int64, banReason string, ttl time.Duration) error {
|
||
if token == "" {
|
||
return fmt.Errorf("token is empty")
|
||
}
|
||
if RedisClient == nil {
|
||
return fmt.Errorf("redis client is not initialized")
|
||
}
|
||
|
||
key := BlacklistKeyPrefix + tokenToHash(token)
|
||
entry := BlacklistEntry{UserID: userID, Reason: banReason}
|
||
value, err := json.Marshal(entry)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to marshal blacklist entry: %w", err)
|
||
}
|
||
|
||
return RedisClient.Set(ctx, key, value, ttl).Err()
|
||
}
// IsBlacklisted 检查 Token 是否在黑名单
|
||
func IsBlacklisted(ctx context.Context, token string) (bool, int64, string, error) {
|
||
if token == "" {
|
||
return false, 0, "", nil
|
||
}
|
||
if RedisClient == nil {
|
||
return false, 0, "", fmt.Errorf("redis client is not initialized")
|
||
}
|
||
|
||
key := BlacklistKeyPrefix + tokenToHash(token)
|
||
value, err := RedisClient.Get(ctx, key).Result()
|
||
if err == redis.Nil {
|
||
return false, 0, "", nil
|
||
}
|
||
if err != nil {
|
||
return false, 0, "", err
|
||
}
|
||
|
||
var entry BlacklistEntry
|
||
if err := json.Unmarshal([]byte(value), &entry); err != nil {
|
||
return false, 0, "", fmt.Errorf("failed to unmarshal blacklist entry: %w", err)
|
||
}
|
||
|
||
return true, entry.UserID, entry.Reason, nil
|
||
}
// RemoveFromBlacklist 从黑名单移除 Token(用于解封)
|
||
func RemoveFromBlacklist(ctx context.Context, token string) error {
|
||
if token == "" {
|
||
return nil
|
||
}
|
||
if RedisClient == nil {
|
||
return fmt.Errorf("redis client is not initialized")
|
||
}
|
||
|
||
key := BlacklistKeyPrefix + tokenToHash(token)
|
||
return RedisClient.Del(ctx, key).Err()
|
||
} |