package repository

import (
	"context"
	"encoding/json"
	"fmt"
	"time"

	"github.com/lib/pq"
	"github.com/redis/go-redis/v9"
	"github.com/topfans/backend/services/aiChatService/model"
	"gorm.io/gorm"
)

// ShortTermMemoryRepository is the short-term memory repository interface
// (implemented on Redis in this file).
type ShortTermMemoryRepository interface {
	SaveContext(ctx context.Context, sessionID string, messages []model.Message, personaID string) error
	GetContext(ctx context.Context, sessionID string) ([]model.Message, error)
	DeleteContext(ctx context.Context, sessionID string) error
}

// LongTermMemoryRepository is the long-term memory repository interface
// (implemented on PostgreSQL in this file).
type LongTermMemoryRepository interface {
	SaveMemory(ctx context.Context, memory *model.UserMemory) error
	GetMemories(ctx context.Context, userID int64, keywords []string, limit int) ([]model.UserMemory, error)
	GetMemoriesByUserID(ctx context.Context, userID int64) ([]model.UserMemory, error)
}

// MemoryRepository is a compatibility alias that combines both interfaces,
// for callers that do not need to distinguish short-term from long-term
// memory.
type MemoryRepository interface {
	ShortTermMemoryRepository
	LongTermMemoryRepository
}

// Compile-time checks that the concrete types satisfy their interfaces.
var (
	_ ShortTermMemoryRepository = (*RedisMemoryRepository)(nil)
	_ LongTermMemoryRepository  = (*PostgreSQLMemoryRepository)(nil)
)

// RedisMemoryRepository is the Redis implementation of short-term memory.
type RedisMemoryRepository struct {
	client *redis.Client
	ttl    time.Duration
}

// NewRedisMemoryRepository creates a Redis short-term memory repository.
// ttlSeconds is converted to a time.Duration and passed as the expiration of
// every key written by SaveContext.
func NewRedisMemoryRepository(client *redis.Client, ttlSeconds int) *RedisMemoryRepository {
	return &RedisMemoryRepository{
		client: client,
		ttl:    time.Duration(ttlSeconds) * time.Second,
	}
}

// contextKey builds the Redis key for a session: "context:<sessionID>".
func contextKey(sessionID string) string {
	return fmt.Sprintf("context:%s", sessionID)
}

// SaveContext stores the short-term context for sessionID in Redis.
//
// The messages and personaID are wrapped in a model.ChatContext stamped with
// the current time, JSON-encoded, and written under contextKey(sessionID)
// with the repository's TTL. Any existing value for the session is
// overwritten.
func (r *RedisMemoryRepository) SaveContext(ctx context.Context, sessionID string, messages []model.Message, personaID string) error {
	data := model.ChatContext{
		SessionID: sessionID,
		Messages:  messages,
		PersonaID: personaID,
		UpdatedAt: time.Now(),
	}

	jsonData, err := json.Marshal(data)
	if err != nil {
		return fmt.Errorf("failed to marshal context: %w", err)
	}

	return r.client.Set(ctx, contextKey(sessionID), jsonData, r.ttl).Err()
}

// GetContext loads the short-term context messages for sessionID from Redis.
//
// A missing key is not an error: it yields (nil, nil). Any other
// Redis failure, or a stored value that cannot be decoded, is returned as a
// wrapped error.
func (r *RedisMemoryRepository) GetContext(ctx context.Context, sessionID string) ([]model.Message, error) {
	data, err := r.client.Get(ctx, contextKey(sessionID)).Bytes()
	if err != nil {
		// redis.Nil signals "key does not exist"; report it as an empty result.
		if err == redis.Nil {
			return nil, nil
		}
		return nil, fmt.Errorf("failed to get context: %w", err)
	}

	var chatCtx model.ChatContext
	if err := json.Unmarshal(data, &chatCtx); err != nil {
		return nil, fmt.Errorf("failed to unmarshal context: %w", err)
	}

	return chatCtx.Messages, nil
}

// DeleteContext removes the short-term context stored for sessionID.
func (r *RedisMemoryRepository) DeleteContext(ctx context.Context, sessionID string) error {
	return r.client.Del(ctx, contextKey(sessionID)).Err()
}

// PostgreSQLMemoryRepository is the PostgreSQL implementation of long-term
// memory.
type PostgreSQLMemoryRepository struct {
	db *gorm.DB
}

// NewPostgreSQLMemoryRepository creates a PostgreSQL long-term memory
// repository backed by db.
func NewPostgreSQLMemoryRepository(db *gorm.DB) *PostgreSQLMemoryRepository {
	return &PostgreSQLMemoryRepository{db: db}
}

// SaveMemory inserts a long-term memory row into "ai_user_memories".
//
// If memory.CreatedAt or memory.UpdatedAt is zero it is set, on the passed-in
// struct, to the current Unix time in milliseconds before the insert. The
// insert is a raw statement whose result is only checked for an error, so no
// database-generated values are written back to memory. A nil memory is
// rejected with an error.
func (r *PostgreSQLMemoryRepository) SaveMemory(ctx context.Context, memory *model.UserMemory) error {
	if memory == nil {
		return fmt.Errorf("failed to save memory: memory is nil")
	}

	now := time.Now().UnixMilli()
	if memory.CreatedAt == 0 {
		memory.CreatedAt = now
	}
	if memory.UpdatedAt == 0 {
		memory.UpdatedAt = now
	}

	// Use GORM's "?" placeholders so every argument is bound explicitly by the
	// SQL builder; the dialect renders them as numbered parameters.
	// pq.Array adapts the Go slice to a PostgreSQL array value.
	return r.db.WithContext(ctx).Exec(
		`INSERT INTO "ai_user_memories" ("user_id", "content", "keywords", "weight", "is_core", "created_at", "updated_at") VALUES (?, ?, ?, ?, ?, ?, ?)`,
		memory.UserID,
		memory.Content,
		pq.Array(memory.Keywords),
		memory.Weight,
		memory.IsCore,
		memory.CreatedAt,
		memory.UpdatedAt,
	).Error
}

// GetMemories returns up to limit memories for userID, ordered by weight and
// then creation time, both descending.
//
// When keywords is non-empty, only memories whose keywords array overlaps it
// (PostgreSQL "&&" operator) are returned; when it is empty, no keyword filter
// is applied.
func (r *PostgreSQLMemoryRepository) GetMemories(ctx context.Context, userID int64, keywords []string, limit int) ([]model.UserMemory, error) {
	var memories []model.UserMemory

	query := r.db.WithContext(ctx).
		Where("user_id = ?", userID).
		Order("weight DESC, created_at DESC").
		Limit(limit)

	if len(keywords) > 0 {
		// The placeholder must be "?": GORM numbers bound parameters itself,
		// so a literal "$1" here would refer to the user_id parameter above
		// and leave the array argument without a placeholder.
		query = query.Where("keywords && ?", pq.Array(keywords))
	}

	if err := query.Find(&memories).Error; err != nil {
		return nil, fmt.Errorf("failed to get memories: %w", err)
	}

	return memories, nil
}

// GetMemoriesByUserID returns all memories for userID, ordered by weight and
// then creation time, both descending.
func (r *PostgreSQLMemoryRepository) GetMemoriesByUserID(ctx context.Context, userID int64) ([]model.UserMemory, error) {
	var memories []model.UserMemory

	if err := r.db.WithContext(ctx).
		Where("user_id = ?", userID).
		Order("weight DESC, created_at DESC").
		Find(&memories).Error; err != nil {
		return nil, fmt.Errorf("failed to get memories: %w", err)
	}

	return memories, nil
}