package service

import (
	"context"
	"strings"
	"unicode"
	"unicode/utf8"

	"github.com/topfans/backend/pkg/logger"
	"github.com/topfans/backend/services/aiChatService/model"
	"github.com/topfans/backend/services/aiChatService/repository"
	"go.uber.org/zap"
)

// memPromptHeader is the heading placed above the memory bullet list in a prompt.
const memPromptHeader = "# 用户核心记忆\n"

// memScanUserMessages caps how many of the most recent user messages
// ExtractMemory inspects.
const memScanUserMessages = 5

// memMaxKeywords caps the number of keywords stored with one extracted memory.
const memMaxKeywords = 5

// memSummaryMessages is how many messages (newest first) go into a summary.
const memSummaryMessages = 3

// memSummaryMsgMaxBytes caps each message inside a summary, in bytes.
const memSummaryMsgMaxBytes = 100

// memSummaryMaxBytes caps the whole summary, in bytes.
const memSummaryMaxBytes = 200

// memDefaultWeight is the Weight assigned to newly extracted memories.
// NOTE(review): the scale/meaning of Weight is defined by the repository layer — confirm.
const memDefaultWeight = 50

// memMinKeywordRunes is the minimum token length, in characters (runes), for a
// token to count as a keyword.
const memMinKeywordRunes = 2

// MemoryService combines short-term conversation context storage with
// long-term per-user memory storage.
type MemoryService struct {
	shortTermRepo repository.ShortTermMemoryRepository
	longTermRepo  repository.LongTermMemoryRepository
}

// NewMemoryService creates a MemoryService backed by the given short-term and
// long-term repositories.
func NewMemoryService(shortTermRepo repository.ShortTermMemoryRepository, longTermRepo repository.LongTermMemoryRepository) *MemoryService {
	return &MemoryService{
		shortTermRepo: shortTermRepo,
		longTermRepo:  longTermRepo,
	}
}

// SaveContext stores the short-term conversation context of a session by
// delegating to the short-term repository.
func (s *MemoryService) SaveContext(ctx context.Context, sessionID string, messages []model.Message, personaID string) error {
	return s.shortTermRepo.SaveContext(ctx, sessionID, messages, personaID)
}

// GetContext loads the short-term conversation context of a session from the
// short-term repository.
func (s *MemoryService) GetContext(ctx context.Context, sessionID string) ([]model.Message, error) {
	return s.shortTermRepo.GetContext(ctx, sessionID)
}

// RecallMemories extracts keywords from userInput and passes them, together
// with userID and limit, to the long-term repository. It renders the returned
// memories as a prompt fragment (a heading plus one "- <content>" line each).
// It returns "" with a nil error when nothing is found, and the repository
// error unchanged if the lookup fails.
func (s *MemoryService) RecallMemories(ctx context.Context, userID int64, userInput string, limit int) (string, error) {
	keywords := extractKeywords(userInput)

	memories, err := s.longTermRepo.GetMemories(ctx, userID, keywords, limit)
	if err != nil {
		return "", err
	}

	contents := make([]string, 0, len(memories))
	for _, m := range memories {
		contents = append(contents, m.Content)
	}
	return renderMemoryList(contents), nil
}

// ExtractMemory derives one long-term memory from the most recent user
// messages in recentMessages and saves it via the long-term repository.
// When to call it is decided by the caller (see ShouldExtractMemory).
//
// Up to memScanUserMessages messages with Role "user" are collected, newest
// first. Keywords are taken from all of them. The summary is built from the
// newest memSummaryMessages non-blank ones. Nothing is saved, and nil is
// returned, when there are no user messages or the summary comes out empty.
func (s *MemoryService) ExtractMemory(ctx context.Context, userID int64, recentMessages []model.Message) error {
	var userMessages []string
	for i := len(recentMessages) - 1; i >= 0 && len(userMessages) < memScanUserMessages; i-- {
		if recentMessages[i].Role == "user" {
			userMessages = append(userMessages, recentMessages[i].Content)
		}
	}
	if len(userMessages) == 0 {
		return nil
	}

	content := summarizeMessages(userMessages)
	if content == "" {
		return nil
	}

	memory := &model.UserMemory{
		UserID:   userID,
		Content:  content,
		Keywords: extractKeywordsFromMessages(userMessages),
		Weight:   memDefaultWeight,
	}
	return s.longTermRepo.SaveMemory(ctx, memory)
}

// GetUserMemories returns every long-term memory stored for userID.
func (s *MemoryService) GetUserMemories(ctx context.Context, userID int64) ([]model.UserMemory, error) {
	return s.longTermRepo.GetMemoriesByUserID(ctx, userID)
}

// extractKeywords splits text on whitespace and sentence punctuation (see
// isKeywordSeparator). It keeps tokens of at least memMinKeywordRunes
// characters. Length is counted in runes, so a single CJK character (3 bytes)
// is not mistaken for a two-character word. Tokens keep their order of
// appearance and may repeat; nil is returned when none qualify.
func extractKeywords(text string) []string {
	var keywords []string
	for _, word := range strings.FieldsFunc(text, isKeywordSeparator) {
		if utf8.RuneCountInString(word) >= memMinKeywordRunes {
			keywords = append(keywords, word)
		}
	}
	return keywords
}

// isKeywordSeparator reports whether r delimits keywords. Separators are any
// Unicode whitespace, the ASCII forms of , . ! ?, their full-width
// counterparts, and the ideographic full stop.
func isKeywordSeparator(r rune) bool {
	if unicode.IsSpace(r) {
		return true
	}
	switch r {
	case ',', '.', '!', '?',
		'\uFF0C', // full-width comma
		'\u3002', // ideographic full stop
		'\uFF01', // full-width exclamation mark
		'\uFF1F': // full-width question mark
		return true
	}
	return false
}

// extractKeywordsFromMessages collects the distinct keywords of messages, in
// order of first appearance. It returns at most memMaxKeywords of them.
func extractKeywordsFromMessages(messages []string) []string {
	var keywords []string
	seen := make(map[string]struct{})
	for _, msg := range messages {
		for _, k := range extractKeywords(msg) {
			if _, dup := seen[k]; dup {
				continue
			}
			seen[k] = struct{}{}
			keywords = append(keywords, k)
			// Stop early: later keywords would be discarded anyway.
			if len(keywords) == memMaxKeywords {
				return keywords
			}
		}
	}
	return keywords
}

// summarizeMessages builds a memory summary from the first memSummaryMessages
// non-blank entries of messages, joined with "; ".
//
// Each entry is trimmed and capped at memSummaryMsgMaxBytes bytes. The joined
// result is capped at memSummaryMaxBytes bytes, and "..." marks any cut.
// Truncation always lands on a rune boundary, so multi-byte (e.g. Chinese)
// text stays valid UTF-8. It returns "" when there is nothing to summarize.
func summarizeMessages(messages []string) string {
	parts := make([]string, 0, memSummaryMessages)
	for _, msg := range messages {
		if len(parts) == memSummaryMessages {
			break
		}
		msg = strings.TrimSpace(msg)
		if msg == "" {
			continue
		}
		parts = append(parts, truncateMemoryText(msg, memSummaryMsgMaxBytes))
	}
	return truncateMemoryText(strings.Join(parts, "; "), memSummaryMaxBytes)
}

// truncateMemoryText shortens s to at most maxBytes bytes without splitting a
// UTF-8 sequence, appending "..." when anything was removed. maxBytes must be
// non-negative.
func truncateMemoryText(s string, maxBytes int) string {
	if len(s) <= maxBytes {
		return s
	}
	cut := maxBytes
	// Back up over continuation bytes so s[:cut] ends on a rune boundary.
	for cut > 0 && !utf8.RuneStart(s[cut]) {
		cut--
	}
	return s[:cut] + "..."
}

// FormatMemoriesForPrompt renders memories as a prompt fragment: a heading
// followed by one "- <content>" line per memory. It returns "" for an empty
// slice.
func FormatMemoriesForPrompt(memories []model.UserMemory) string {
	contents := make([]string, 0, len(memories))
	for _, m := range memories {
		contents = append(contents, m.Content)
	}
	return renderMemoryList(contents)
}

// renderMemoryList writes memPromptHeader followed by one "- <content>\n" line
// per entry; it returns "" when contents is empty.
func renderMemoryList(contents []string) string {
	if len(contents) == 0 {
		return ""
	}
	var b strings.Builder
	b.WriteString(memPromptHeader)
	for _, c := range contents {
		b.WriteString("- ")
		b.WriteString(c)
		b.WriteString("\n")
	}
	return b.String()
}

// GetTurnCount counts completed conversation turns. A turn is a message with
// Role "user" immediately followed by one with Role "assistant". Messages that
// do not start such a pair (for example a leading non-user message, or a user
// message that got no reply) are skipped one at a time. This keeps them from
// misaligning the pairing for the rest of the history.
func GetTurnCount(messages []model.Message) int {
	turns := 0
	for i := 0; i+1 < len(messages); {
		if messages[i].Role == "user" && messages[i+1].Role == "assistant" {
			turns++
			i += 2
			continue
		}
		i++
	}
	return turns
}

// ShouldExtractMemory reports whether messages contain at least triggerTurns
// completed turns (as counted by GetTurnCount), and logs the decision.
//
// NOTE(review): this returns true on every call once the threshold is reached.
// It presumably relies on the caller trimming/resetting the short-term context
// after extraction — confirm, otherwise extraction fires on every later
// message. A triggerTurns <= 0 always yields true.
func ShouldExtractMemory(messages []model.Message, triggerTurns int) bool {
	turns := GetTurnCount(messages)
	should := turns >= triggerTurns
	logger.Logger.Info("ShouldExtractMemory check",
		zap.Int("turns", turns),
		zap.Int("trigger_turns", triggerTurns),
		zap.Bool("should_extract", should),
	)
	return should
}