topfans/backend/services/aiChatService/repository/memory_repository.go
2026-05-28 12:00:19 +08:00

156 lines
4.7 KiB
Go

package repository
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/lib/pq"
	"github.com/redis/go-redis/v9"
	"github.com/topfans/backend/services/aiChatService/model"
	"gorm.io/gorm"
)
// ShortTermMemoryRepository is the short-term memory repository interface (Redis).
// It covers the per-session chat context: saving, loading and deleting it.
type ShortTermMemoryRepository interface {
// SaveContext stores messages and personaID as the context of sessionID.
SaveContext(ctx context.Context, sessionID string, messages []model.Message, personaID string) error
// GetContext returns the messages stored for sessionID.
GetContext(ctx context.Context, sessionID string) ([]model.Message, error)
// DeleteContext removes the context stored for sessionID.
DeleteContext(ctx context.Context, sessionID string) error
}
// LongTermMemoryRepository is the long-term memory repository interface (PostgreSQL).
// It covers persisting user memories and reading them back per user.
type LongTermMemoryRepository interface {
// SaveMemory persists a single user memory record.
SaveMemory(ctx context.Context, memory *model.UserMemory) error
// GetMemories looks up memories of userID, optionally narrowed by keywords;
// limit caps the number of results.
GetMemories(ctx context.Context, userID int64, keywords []string, limit int) ([]model.UserMemory, error)
// GetMemoriesByUserID returns every memory stored for userID.
GetMemoriesByUserID(ctx context.Context, userID int64) ([]model.UserMemory, error)
}
// MemoryRepository is the combined memory repository interface. It is a
// compatibility alias for callers that do not need to distinguish short-term
// from long-term storage, and is formed by embedding both interfaces.
type MemoryRepository interface {
ShortTermMemoryRepository
LongTermMemoryRepository
}
// RedisMemoryRepository is the Redis-backed short-term memory implementation.
type RedisMemoryRepository struct {
// client is the Redis client used for all context reads, writes and deletes.
client *redis.Client
// ttl is the expiration passed to Redis each time a context is saved.
ttl time.Duration
}
// NewRedisMemoryRepository builds a Redis short-term memory repository.
//
// client is the Redis client the repository will use. ttlSeconds is the
// lifetime, in seconds, applied to every saved context; it is converted to a
// time.Duration once, here.
func NewRedisMemoryRepository(client *redis.Client, ttlSeconds int) *RedisMemoryRepository {
	repo := &RedisMemoryRepository{client: client}
	repo.ttl = time.Second * time.Duration(ttlSeconds)
	return repo
}
// contextKey builds the Redis key under which the context of sessionID is
// stored: the session ID prefixed with "context:".
func contextKey(sessionID string) string {
	const keyFormat = "context:%s"
	return fmt.Sprintf(keyFormat, sessionID)
}
// SaveContext writes the short-term context of a session to Redis.
//
// The session ID, messages and persona ID are wrapped in a model.ChatContext
// stamped with the current time, serialized to JSON, and stored under the
// session's context key with the repository's TTL. It returns a wrapped error
// if serialization fails, or the Redis error if the write fails.
func (r *RedisMemoryRepository) SaveContext(ctx context.Context, sessionID string, messages []model.Message, personaID string) error {
	var snapshot model.ChatContext
	snapshot.SessionID = sessionID
	snapshot.Messages = messages
	snapshot.PersonaID = personaID
	snapshot.UpdatedAt = time.Now()

	payload, marshalErr := json.Marshal(snapshot)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal context: %w", marshalErr)
	}

	status := r.client.Set(ctx, contextKey(sessionID), payload, r.ttl)
	return status.Err()
}
// GetContext loads the short-term context of a session from Redis.
//
// It returns the messages stored under the session's context key. When the
// key does not exist it returns (nil, nil), so a missing context is not
// treated as an error. Any other Redis failure, or a payload that cannot be
// decoded as a model.ChatContext, is returned as a wrapped error.
func (r *RedisMemoryRepository) GetContext(ctx context.Context, sessionID string) ([]model.Message, error) {
	data, err := r.client.Get(ctx, contextKey(sessionID)).Bytes()
	// errors.Is also matches a wrapped redis.Nil, which a plain == comparison
	// would misreport as a real failure.
	if errors.Is(err, redis.Nil) {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get context: %w", err)
	}

	var chatCtx model.ChatContext
	if err := json.Unmarshal(data, &chatCtx); err != nil {
		return nil, fmt.Errorf("failed to unmarshal context: %w", err)
	}
	return chatCtx.Messages, nil
}
// DeleteContext removes the short-term context of a session from Redis and
// returns the error reported by the delete command, if any.
func (r *RedisMemoryRepository) DeleteContext(ctx context.Context, sessionID string) error {
	key := contextKey(sessionID)
	result := r.client.Del(ctx, key)
	return result.Err()
}
// PostgreSQLMemoryRepository is the PostgreSQL-backed long-term memory implementation.
type PostgreSQLMemoryRepository struct {
// db is the GORM handle used for all memory queries and inserts.
db *gorm.DB
}
// NewPostgreSQLMemoryRepository builds a PostgreSQL long-term memory
// repository that runs its queries through the given GORM handle.
func NewPostgreSQLMemoryRepository(db *gorm.DB) *PostgreSQLMemoryRepository {
	repo := new(PostgreSQLMemoryRepository)
	repo.db = db
	return repo
}
// SaveMemory inserts a long-term memory row into "ai_user_memories".
//
// CreatedAt and UpdatedAt are filled with the current Unix time in
// milliseconds when the caller left them at zero; the passed-in memory is
// mutated accordingly. Keywords are sent through pq.Array. The returned error
// is whatever the insert statement reports.
func (r *PostgreSQLMemoryRepository) SaveMemory(ctx context.Context, memory *model.UserMemory) error {
	// One clock reading is shared by both timestamps so they match on first save.
	stamp := time.Now().UnixMilli()
	if memory.CreatedAt == 0 {
		memory.CreatedAt = stamp
	}
	if memory.UpdatedAt == 0 {
		memory.UpdatedAt = stamp
	}

	const insertSQL = `INSERT INTO "ai_user_memories" ("user_id", "content", "keywords", "weight", "is_core", "created_at", "updated_at")
VALUES ($1, $2, $3, $4, $5, $6, $7)`

	// Argument order must follow the $1..$7 positions in insertSQL.
	args := []interface{}{
		memory.UserID,
		memory.Content,
		pq.Array(memory.Keywords),
		memory.Weight,
		memory.IsCore,
		memory.CreatedAt,
		memory.UpdatedAt,
	}

	tx := r.db.WithContext(ctx).Exec(insertSQL, args...)
	return tx.Error
}
// GetMemories queries the memories of a user, optionally filtered by keywords.
//
// Rows are restricted to userID, ordered by weight and then creation time
// (both descending), and capped with limit. When keywords is non-empty, only
// rows whose "keywords" array overlaps it (PostgreSQL && operator) are
// returned; when it is empty, no keyword filter is applied. Query failures
// are returned as a wrapped error.
func (r *PostgreSQLMemoryRepository) GetMemories(ctx context.Context, userID int64, keywords []string, limit int) ([]model.UserMemory, error) {
	var memories []model.UserMemory

	query := r.db.WithContext(ctx).
		Where("user_id = ?", userID).
		Order("weight DESC, created_at DESC").
		Limit(limit)
	if len(keywords) > 0 {
		// GORM numbers bind variables itself, so the placeholder must be "?".
		// A hand-written "$1" is emitted verbatim and would refer to the first
		// bound value (userID) rather than the keyword array.
		query = query.Where("keywords && ?", pq.Array(keywords))
	}

	if err := query.Find(&memories).Error; err != nil {
		return nil, fmt.Errorf("failed to get memories: %w", err)
	}
	return memories, nil
}
// GetMemoriesByUserID returns all memories stored for a user, ordered by
// weight and then creation time (both descending). Query failures are
// returned as a wrapped error.
func (r *PostgreSQLMemoryRepository) GetMemoriesByUserID(ctx context.Context, userID int64) ([]model.UserMemory, error) {
	var found []model.UserMemory

	scope := r.db.WithContext(ctx).Where("user_id = ?", userID)
	scope = scope.Order("weight DESC, created_at DESC")

	result := scope.Find(&found)
	if result.Error != nil {
		return nil, fmt.Errorf("failed to get memories: %w", result.Error)
	}
	return found, nil
}