156 lines
4.7 KiB
Go
156 lines
4.7 KiB
Go
package repository
|
|
|
|
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/lib/pq"
	"github.com/redis/go-redis/v9"
	"github.com/topfans/backend/services/aiChatService/model"
	"gorm.io/gorm"
)
|
|
|
|
// ShortTermMemoryRepository 短期记忆仓库接口 (Redis)
|
|
type ShortTermMemoryRepository interface {
|
|
SaveContext(ctx context.Context, sessionID string, messages []model.Message, personaID string) error
|
|
GetContext(ctx context.Context, sessionID string) ([]model.Message, error)
|
|
DeleteContext(ctx context.Context, sessionID string) error
|
|
}
|
|
|
|
// LongTermMemoryRepository 长期记忆仓库接口 (PostgreSQL)
|
|
type LongTermMemoryRepository interface {
|
|
SaveMemory(ctx context.Context, memory *model.UserMemory) error
|
|
GetMemories(ctx context.Context, userID int64, keywords []string, limit int) ([]model.UserMemory, error)
|
|
GetMemoriesByUserID(ctx context.Context, userID int64) ([]model.UserMemory, error)
|
|
}
|
|
|
|
// MemoryRepository 记忆仓库接口(兼容性别名,用于不需要区分短期/长期的场景)
|
|
type MemoryRepository interface {
|
|
ShortTermMemoryRepository
|
|
LongTermMemoryRepository
|
|
}
|
|
|
|
// RedisMemoryRepository Redis 短期记忆实现
|
|
type RedisMemoryRepository struct {
|
|
client *redis.Client
|
|
ttl time.Duration
|
|
}
|
|
|
|
// NewRedisMemoryRepository 创建 Redis 短期记忆仓库
|
|
func NewRedisMemoryRepository(client *redis.Client, ttlSeconds int) *RedisMemoryRepository {
|
|
return &RedisMemoryRepository{
|
|
client: client,
|
|
ttl: time.Duration(ttlSeconds) * time.Second,
|
|
}
|
|
}
|
|
|
|
// contextKey returns the Redis key under which the context of sessionID is
// stored: the session ID prefixed with "context:".
func contextKey(sessionID string) string {
	const keyFormat = "context:%s"
	return fmt.Sprintf(keyFormat, sessionID)
}
|
|
|
|
// SaveContext 保存短期上下文到 Redis
|
|
func (r *RedisMemoryRepository) SaveContext(ctx context.Context, sessionID string, messages []model.Message, personaID string) error {
|
|
data := model.ChatContext{
|
|
SessionID: sessionID,
|
|
Messages: messages,
|
|
PersonaID: personaID,
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
|
|
jsonData, err := json.Marshal(data)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal context: %w", err)
|
|
}
|
|
|
|
return r.client.Set(ctx, contextKey(sessionID), jsonData, r.ttl).Err()
|
|
}
|
|
|
|
// GetContext 从 Redis 获取短期上下文
|
|
func (r *RedisMemoryRepository) GetContext(ctx context.Context, sessionID string) ([]model.Message, error) {
|
|
data, err := r.client.Get(ctx, contextKey(sessionID)).Bytes()
|
|
if err != nil {
|
|
if err == redis.Nil {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to get context: %w", err)
|
|
}
|
|
|
|
var chatCtx model.ChatContext
|
|
if err := json.Unmarshal(data, &chatCtx); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal context: %w", err)
|
|
}
|
|
|
|
return chatCtx.Messages, nil
|
|
}
|
|
|
|
// DeleteContext 删除短期上下文
|
|
func (r *RedisMemoryRepository) DeleteContext(ctx context.Context, sessionID string) error {
|
|
return r.client.Del(ctx, contextKey(sessionID)).Err()
|
|
}
|
|
|
|
// PostgreSQLMemoryRepository PostgreSQL 长期记忆实现
|
|
type PostgreSQLMemoryRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
// NewPostgreSQLMemoryRepository 创建 PostgreSQL 长期记忆仓库
|
|
func NewPostgreSQLMemoryRepository(db *gorm.DB) *PostgreSQLMemoryRepository {
|
|
return &PostgreSQLMemoryRepository{db: db}
|
|
}
|
|
|
|
// SaveMemory 保存长期记忆
|
|
func (r *PostgreSQLMemoryRepository) SaveMemory(ctx context.Context, memory *model.UserMemory) error {
|
|
now := time.Now().UnixMilli()
|
|
if memory.CreatedAt == 0 {
|
|
memory.CreatedAt = now
|
|
}
|
|
if memory.UpdatedAt == 0 {
|
|
memory.UpdatedAt = now
|
|
}
|
|
|
|
return r.db.WithContext(ctx).Exec(
|
|
`INSERT INTO "ai_user_memories" ("user_id", "content", "keywords", "weight", "is_core", "created_at", "updated_at")
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
|
|
memory.UserID,
|
|
memory.Content,
|
|
pq.Array(memory.Keywords),
|
|
memory.Weight,
|
|
memory.IsCore,
|
|
memory.CreatedAt,
|
|
memory.UpdatedAt,
|
|
).Error
|
|
}
|
|
|
|
// GetMemories 根据关键词查询记忆
|
|
func (r *PostgreSQLMemoryRepository) GetMemories(ctx context.Context, userID int64, keywords []string, limit int) ([]model.UserMemory, error) {
|
|
var memories []model.UserMemory
|
|
query := r.db.WithContext(ctx).
|
|
Where("user_id = ?", userID).
|
|
Order("weight DESC, created_at DESC").
|
|
Limit(limit)
|
|
|
|
if len(keywords) > 0 {
|
|
query = query.Where("keywords && $1", pq.Array(keywords))
|
|
}
|
|
|
|
if err := query.Find(&memories).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to get memories: %w", err)
|
|
}
|
|
|
|
return memories, nil
|
|
}
|
|
|
|
// GetMemoriesByUserID 获取用户所有记忆
|
|
func (r *PostgreSQLMemoryRepository) GetMemoriesByUserID(ctx context.Context, userID int64) ([]model.UserMemory, error) {
|
|
var memories []model.UserMemory
|
|
if err := r.db.WithContext(ctx).
|
|
Where("user_id = ?", userID).
|
|
Order("weight DESC, created_at DESC").
|
|
Find(&memories).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to get memories: %w", err)
|
|
}
|
|
return memories, nil
|
|
} |