feat: 添加 Redis 客户端和 Token 黑名单模块

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
zerosaturation 2026-05-14 17:05:47 +08:00
parent 308c080c1e
commit acacac19c7

View File

@ -1,18 +1,18 @@
package database package database
import ( import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/json" "encoding/json"
"fmt" "fmt"
"time" "time"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
const ( const (
BlacklistKeyPrefix = "blacklist:token:" BlacklistKeyPrefix = "blacklist:token:"
InspirationFlowKeyPrefix = "inspiration_flow:" InspirationFlowKeyPrefix = "inspiration_flow:"
) )
// RedisClient Redis 客户端单例 // RedisClient Redis 客户端单例
@ -20,232 +20,232 @@ var RedisClient *redis.Client
// Config Redis 配置 // Config Redis 配置
type RedisConfig struct { type RedisConfig struct {
Host string Host string
Port int Port int
Password string Password string
DB int DB int
} }
// InitRedis 初始化 Redis 连接 // InitRedis 初始化 Redis 连接
func InitRedis(cfg RedisConfig) error { func InitRedis(cfg RedisConfig) error {
RedisClient = redis.NewClient(&redis.Options{ RedisClient = redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
Password: cfg.Password, Password: cfg.Password,
DB: cfg.DB, DB: cfg.DB,
}) })
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
if err := RedisClient.Ping(ctx).Err(); err != nil { if err := RedisClient.Ping(ctx).Err(); err != nil {
return fmt.Errorf("failed to connect to redis: %w", err) return fmt.Errorf("failed to connect to redis: %w", err)
} }
return nil return nil
} }
// CloseRedis 关闭 Redis 连接 // CloseRedis 关闭 Redis 连接
func CloseRedis() error { func CloseRedis() error {
if RedisClient != nil { if RedisClient != nil {
return RedisClient.Close() return RedisClient.Close()
} }
return nil return nil
} }
// GetRedis 获取 Redis 客户端实例 // GetRedis 获取 Redis 客户端实例
func GetRedis() *redis.Client { func GetRedis() *redis.Client {
return RedisClient return RedisClient
} }
// RedisHealthCheck 健康检查 // RedisHealthCheck 健康检查
func RedisHealthCheck() error { func RedisHealthCheck() error {
if RedisClient == nil { if RedisClient == nil {
return fmt.Errorf("redis client is not initialized") return fmt.Errorf("redis client is not initialized")
} }
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
return RedisClient.Ping(ctx).Err() return RedisClient.Ping(ctx).Err()
} }
// BlacklistEntry 黑名单条目 // BlacklistEntry 黑名单条目
type BlacklistEntry struct { type BlacklistEntry struct {
UserID int64 `json:"user_id"` UserID int64 `json:"user_id"`
Reason string `json:"reason"` Reason string `json:"reason"`
} }
// tokenToHash 将 Token 转换为 SHA256 哈希作为 Key // tokenToHash 将 Token 转换为 SHA256 哈希作为 Key
func tokenToHash(token string) string { func tokenToHash(token string) string {
hash := sha256.Sum256([]byte(token)) hash := sha256.Sum256([]byte(token))
return fmt.Sprintf("%x", hash) return fmt.Sprintf("%x", hash)
} }
// AddToBlacklist 添加 Token 到黑名单 // AddToBlacklist 添加 Token 到黑名单
func AddToBlacklist(ctx context.Context, token string, userID int64, banReason string, ttl time.Duration) error { func AddToBlacklist(ctx context.Context, token string, userID int64, banReason string, ttl time.Duration) error {
if token == "" { if token == "" {
return fmt.Errorf("token is empty") return fmt.Errorf("token is empty")
} }
if RedisClient == nil { if RedisClient == nil {
return fmt.Errorf("redis client is not initialized") return fmt.Errorf("redis client is not initialized")
} }
key := BlacklistKeyPrefix + tokenToHash(token) key := BlacklistKeyPrefix + tokenToHash(token)
entry := BlacklistEntry{UserID: userID, Reason: banReason} entry := BlacklistEntry{UserID: userID, Reason: banReason}
value, err := json.Marshal(entry) value, err := json.Marshal(entry)
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal blacklist entry: %w", err) return fmt.Errorf("failed to marshal blacklist entry: %w", err)
} }
return RedisClient.Set(ctx, key, value, ttl).Err() return RedisClient.Set(ctx, key, value, ttl).Err()
} }
// IsBlacklisted 检查 Token 是否在黑名单 // IsBlacklisted 检查 Token 是否在黑名单
func IsBlacklisted(ctx context.Context, token string) (bool, int64, string, error) { func IsBlacklisted(ctx context.Context, token string) (bool, int64, string, error) {
if token == "" { if token == "" {
return false, 0, "", nil return false, 0, "", nil
} }
if RedisClient == nil { if RedisClient == nil {
return false, 0, "", fmt.Errorf("redis client is not initialized") return false, 0, "", fmt.Errorf("redis client is not initialized")
} }
key := BlacklistKeyPrefix + tokenToHash(token) key := BlacklistKeyPrefix + tokenToHash(token)
value, err := RedisClient.Get(ctx, key).Result() value, err := RedisClient.Get(ctx, key).Result()
if err == redis.Nil { if err == redis.Nil {
return false, 0, "", nil return false, 0, "", nil
} }
if err != nil { if err != nil {
return false, 0, "", err return false, 0, "", err
} }
var entry BlacklistEntry var entry BlacklistEntry
if err := json.Unmarshal([]byte(value), &entry); err != nil { if err := json.Unmarshal([]byte(value), &entry); err != nil {
return false, 0, "", fmt.Errorf("failed to unmarshal blacklist entry: %w", err) return false, 0, "", fmt.Errorf("failed to unmarshal blacklist entry: %w", err)
} }
return true, entry.UserID, entry.Reason, nil return true, entry.UserID, entry.Reason, nil
} }
// RemoveFromBlacklist 从黑名单移除 Token用于解封 // RemoveFromBlacklist 从黑名单移除 Token用于解封
func RemoveFromBlacklist(ctx context.Context, token string) error { func RemoveFromBlacklist(ctx context.Context, token string) error {
if token == "" { if token == "" {
return nil return nil
} }
if RedisClient == nil { if RedisClient == nil {
return fmt.Errorf("redis client is not initialized") return fmt.Errorf("redis client is not initialized")
} }
key := BlacklistKeyPrefix + tokenToHash(token) key := BlacklistKeyPrefix + tokenToHash(token)
return RedisClient.Del(ctx, key).Err() return RedisClient.Del(ctx, key).Err()
} }
// InspirationFlowCacheEntry 单个展品缓存数据 // InspirationFlowCacheEntry 单个展品缓存数据
type InspirationFlowCacheEntry struct { type InspirationFlowCacheEntry struct {
AssetID int64 `json:"asset_id"` AssetID int64 `json:"asset_id"`
Name string `json:"name"` Name string `json:"name"`
CoverURL string `json:"cover_url"` CoverURL string `json:"cover_url"`
LikeCount int32 `json:"like_count"` LikeCount int32 `json:"like_count"`
OwnerNickname string `json:"owner_nickname"` OwnerNickname string `json:"owner_nickname"`
Span int32 `json:"span"` Span int32 `json:"span"`
MaterialType string `json:"material_type"` MaterialType string `json:"material_type"`
} }
// InspirationFlowCache 会话缓存结构 // InspirationFlowCache 会话缓存结构
type InspirationFlowCache struct { type InspirationFlowCache struct {
DisplayedIDs []int64 `json:"displayed_ids"` // 已展示ID列表 DisplayedIDs []int64 `json:"displayed_ids"` // 已展示ID列表
History map[int64]InspirationFlowCacheEntry `json:"history"` // 历史数据详情 History map[int64]InspirationFlowCacheEntry `json:"history"` // 历史数据详情
} }
// InspirationFlowKey 生成灵感瀑布流缓存 Key // InspirationFlowKey 生成灵感瀑布流缓存 Key
func InspirationFlowKey(starID int64, sessionID string) string { func InspirationFlowKey(starID int64, sessionID string) string {
return fmt.Sprintf("%s%d:%s", InspirationFlowKeyPrefix, starID, sessionID) return fmt.Sprintf("%s%d:%s", InspirationFlowKeyPrefix, starID, sessionID)
} }
// GetInspirationFlowCache 获取灵感瀑布流会话缓存 // GetInspirationFlowCache 获取灵感瀑布流会话缓存
func GetInspirationFlowCache(ctx context.Context, starID int64, sessionID string) (*InspirationFlowCache, error) { func GetInspirationFlowCache(ctx context.Context, starID int64, sessionID string) (*InspirationFlowCache, error) {
if RedisClient == nil { if RedisClient == nil {
return nil, fmt.Errorf("redis client is not initialized") return nil, fmt.Errorf("redis client is not initialized")
} }
key := InspirationFlowKey(starID, sessionID) key := InspirationFlowKey(starID, sessionID)
data, err := RedisClient.Get(ctx, key).Result() data, err := RedisClient.Get(ctx, key).Result()
if err == redis.Nil { if err == redis.Nil {
return &InspirationFlowCache{ return &InspirationFlowCache{
DisplayedIDs: []int64{}, DisplayedIDs: []int64{},
History: make(map[int64]InspirationFlowCacheEntry), History: make(map[int64]InspirationFlowCacheEntry),
}, nil }, nil
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
var cache InspirationFlowCache var cache InspirationFlowCache
if err := json.Unmarshal([]byte(data), &cache); err != nil { if err := json.Unmarshal([]byte(data), &cache); err != nil {
return nil, err return nil, err
} }
return &cache, nil return &cache, nil
} }
// SaveInspirationFlowCache 保存灵感瀑布流会话缓存 // SaveInspirationFlowCache 保存灵感瀑布流会话缓存
func SaveInspirationFlowCache(ctx context.Context, starID int64, sessionID string, cache *InspirationFlowCache, ttl time.Duration) error { func SaveInspirationFlowCache(ctx context.Context, starID int64, sessionID string, cache *InspirationFlowCache, ttl time.Duration) error {
if RedisClient == nil { if RedisClient == nil {
return fmt.Errorf("redis client is not initialized") return fmt.Errorf("redis client is not initialized")
} }
key := InspirationFlowKey(starID, sessionID) key := InspirationFlowKey(starID, sessionID)
data, err := json.Marshal(cache) data, err := json.Marshal(cache)
if err != nil { if err != nil {
return err return err
} }
return RedisClient.Set(ctx, key, data, ttl).Err() return RedisClient.Set(ctx, key, data, ttl).Err()
} }
// AddToInspirationFlowCache 添加展品到会话缓存 // AddToInspirationFlowCache 添加展品到会话缓存
func AddToInspirationFlowCache(ctx context.Context, starID int64, sessionID string, entry InspirationFlowCacheEntry, ttl time.Duration) error { func AddToInspirationFlowCache(ctx context.Context, starID int64, sessionID string, entry InspirationFlowCacheEntry, ttl time.Duration) error {
cache, err := GetInspirationFlowCache(ctx, starID, sessionID) cache, err := GetInspirationFlowCache(ctx, starID, sessionID)
if err != nil { if err != nil {
return err return err
} }
// 检查是否已存在 // 检查是否已存在
for _, id := range cache.DisplayedIDs { for _, id := range cache.DisplayedIDs {
if id == entry.AssetID { if id == entry.AssetID {
return nil // 已存在,跳过 return nil // 已存在,跳过
} }
} }
// 添加到已展示列表 // 添加到已展示列表
cache.DisplayedIDs = append(cache.DisplayedIDs, entry.AssetID) cache.DisplayedIDs = append(cache.DisplayedIDs, entry.AssetID)
// 添加到历史详情 // 添加到历史详情
if cache.History == nil { if cache.History == nil {
cache.History = make(map[int64]InspirationFlowCacheEntry) cache.History = make(map[int64]InspirationFlowCacheEntry)
} }
cache.History[entry.AssetID] = entry cache.History[entry.AssetID] = entry
return SaveInspirationFlowCache(ctx, starID, sessionID, cache, ttl) return SaveInspirationFlowCache(ctx, starID, sessionID, cache, ttl)
} }
// GetHistoryPage 获取历史数据的某一页 // GetHistoryPage 获取历史数据的某一页
func GetHistoryPage(cache *InspirationFlowCache, offset, limit int) []InspirationFlowCacheEntry { func GetHistoryPage(cache *InspirationFlowCache, offset, limit int) []InspirationFlowCacheEntry {
if cache == nil || cache.History == nil { if cache == nil || cache.History == nil {
return []InspirationFlowCacheEntry{} return []InspirationFlowCacheEntry{}
} }
// 按展示顺序反向遍历(最新展示的在前面) // 按展示顺序反向遍历(最新展示的在前面)
items := make([]InspirationFlowCacheEntry, 0) items := make([]InspirationFlowCacheEntry, 0)
for i := len(cache.DisplayedIDs) - 1; i >= 0; i-- { for i := len(cache.DisplayedIDs) - 1; i >= 0; i-- {
if entry, ok := cache.History[cache.DisplayedIDs[i]]; ok { if entry, ok := cache.History[cache.DisplayedIDs[i]]; ok {
items = append(items, entry) items = append(items, entry)
} }
} }
// 分页 // 分页
start := offset start := offset
end := offset + limit end := offset + limit
if start >= len(items) { if start >= len(items) {
return []InspirationFlowCacheEntry{} return []InspirationFlowCacheEntry{}
} }
if end > len(items) { if end > len(items) {
end = len(items) end = len(items)
} }
return items[start:end] return items[start:end]
} }