topfans/backend/services/aiChatService/service/memory_service.go
2026-05-28 12:00:19 +08:00

211 lines
5.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"strings"
"github.com/topfans/backend/pkg/logger"
"github.com/topfans/backend/services/aiChatService/model"
"github.com/topfans/backend/services/aiChatService/repository"
"go.uber.org/zap"
)
// MemoryService is the memory service. It holds the two repositories that its
// methods delegate to: one for short-term conversation context and one for
// long-term user memories.
type MemoryService struct {
	// shortTermRepo is used by SaveContext and GetContext.
	shortTermRepo repository.ShortTermMemoryRepository
	// longTermRepo is used by RecallMemories, ExtractMemory and GetUserMemories.
	longTermRepo repository.LongTermMemoryRepository
}
// NewMemoryService creates a MemoryService backed by the given repositories.
//
// Both arguments are stored exactly as passed; no validation or copying is
// performed here.
func NewMemoryService(shortTermRepo repository.ShortTermMemoryRepository, longTermRepo repository.LongTermMemoryRepository) *MemoryService {
	svc := new(MemoryService)
	svc.shortTermRepo = shortTermRepo
	svc.longTermRepo = longTermRepo
	return svc
}
// SaveContext saves the short-term context of a session.
//
// All arguments are forwarded unchanged to the short-term repository's
// SaveContext, and its error (if any) is returned as is.
func (s *MemoryService) SaveContext(ctx context.Context, sessionID string, messages []model.Message, personaID string) error {
	return s.shortTermRepo.SaveContext(ctx, sessionID, messages, personaID)
}
// GetContext fetches the short-term context of a session.
//
// The call is forwarded unchanged to the short-term repository's GetContext;
// its messages and error are returned as is.
func (s *MemoryService) GetContext(ctx context.Context, sessionID string) ([]model.Message, error) {
	return s.shortTermRepo.GetContext(ctx, sessionID)
}
// RecallMemories recalls long-term memories related to userInput and renders
// them as a prompt fragment.
//
// Keywords are derived from userInput with extractKeywords and passed, together
// with userID and limit, to the long-term repository's GetMemories. A
// repository error is returned unchanged alongside an empty string. When no
// memory is returned the result is an empty string and a nil error. Otherwise
// the result is a fixed heading line followed by one "- <content>" line per
// memory, in the order the repository returned them.
func (s *MemoryService) RecallMemories(ctx context.Context, userID int64, userInput string, limit int) (string, error) {
	recalled, err := s.longTermRepo.GetMemories(ctx, userID, extractKeywords(userInput), limit)
	switch {
	case err != nil:
		return "", err
	case len(recalled) == 0:
		return "", nil
	}

	var out strings.Builder
	out.WriteString("# 用户核心记忆\n")
	for _, mem := range recalled {
		out.WriteString("- " + mem.Content + "\n")
	}
	return out.String(), nil
}
// ExtractMemory condenses the most recent user messages into a single
// long-term memory record and saves it.
//
// recentMessages is scanned from the last element backwards, collecting the
// Content of up to five messages whose Role is "user"; the collected list is
// therefore newest-first. If no user message is found, or summarizeMessages
// yields an empty summary, nothing is saved and nil is returned. Otherwise a
// model.UserMemory with the summary as Content, the keywords from
// extractKeywordsFromMessages, and a fixed Weight of 50 is handed to the
// long-term repository's SaveMemory, whose error is returned as is.
//
// NOTE(review): the original comment says this is triggered once every 5
// conversation turns; that scheduling is the caller's responsibility and is
// not enforced here.
func (s *MemoryService) ExtractMemory(ctx context.Context, userID int64, recentMessages []model.Message) error {
	const maxUserMessages = 5

	// Walk backwards so the newest user messages are picked first.
	var latest []string
	for idx := len(recentMessages) - 1; idx >= 0; idx-- {
		if len(latest) >= maxUserMessages {
			break
		}
		if recentMessages[idx].Role != "user" {
			continue
		}
		latest = append(latest, recentMessages[idx].Content)
	}
	if len(latest) == 0 {
		return nil
	}

	summary := summarizeMessages(latest)
	if summary == "" {
		return nil
	}

	return s.longTermRepo.SaveMemory(ctx, &model.UserMemory{
		UserID:   userID,
		Content:  summary,
		Keywords: extractKeywordsFromMessages(latest),
		Weight:   50,
	})
}
// GetUserMemories fetches all long-term memories of a user.
//
// The call is forwarded unchanged to the long-term repository's
// GetMemoriesByUserID; its result and error are returned as is.
func (s *MemoryService) GetUserMemories(ctx context.Context, userID int64) ([]model.UserMemory, error) {
	return s.longTermRepo.GetMemoriesByUserID(ctx, userID)
}
// extractKeywords extracts candidate keywords from user input.
//
// This is a deliberately simple implementation: text is split on ASCII and
// full-width sentence punctuation, spaces and newlines; each fragment is
// trimmed of surrounding whitespace and kept only if it is at least two
// characters long. Fragments are returned in their original order; the result
// is nil when nothing qualifies.
//
// The full-width separators are written as \u escapes so they cannot be
// confused with (or silently replaced by) their ASCII look-alikes, and the
// length check counts runes rather than bytes so that a single multi-byte
// character (e.g. one CJK ideograph, 3 bytes in UTF-8) is not mistaken for a
// two-character word.
func extractKeywords(text string) []string {
	isSeparator := func(r rune) bool {
		switch r {
		case ',', '.', '!', '?', ' ', '\n',
			'\uFF0C', // full-width comma
			'\u3002', // ideographic full stop
			'\uFF01', // full-width exclamation mark
			'\uFF1F': // full-width question mark
			return true
		}
		return false
	}

	var keywords []string
	for _, fragment := range strings.FieldsFunc(text, isSeparator) {
		fragment = strings.TrimSpace(fragment)
		// len([]rune(s)) counts characters, not bytes; the compiler evaluates
		// it without materialising the rune slice.
		if len([]rune(fragment)) >= 2 {
			keywords = append(keywords, fragment)
		}
	}
	return keywords
}
// extractKeywordsFromMessages extracts keywords from several messages.
//
// Each message is run through extractKeywords; keywords are de-duplicated
// while keeping first-seen order, and only the first five distinct keywords
// are returned. The result is nil when no keyword is found.
func extractKeywordsFromMessages(messages []string) []string {
	const maxKeywords = 5

	var distinct []string
	known := make(map[string]bool)
	for _, message := range messages {
		for _, kw := range extractKeywords(message) {
			if known[kw] {
				continue
			}
			known[kw] = true
			distinct = append(distinct, kw)
		}
	}

	if len(distinct) <= maxKeywords {
		return distinct
	}
	return distinct[:maxKeywords]
}
// summarizeMessages builds a memory summary from messages.
//
// This is a simple implementation: the first three messages are concatenated,
// each followed by "; ". A message longer than 100 bytes is shortened and
// suffixed with "..."; the joined text has surrounding whitespace trimmed (so
// it ends in ";") and, if longer than 200 bytes, is shortened and suffixed
// with "..." as well. An empty input yields an empty string.
//
// Shortening never splits a multi-byte UTF-8 character: the cut is moved back
// to the nearest character boundary at or before the byte limit, so the
// summary stays valid UTF-8 for CJK and other non-ASCII text. For ASCII input
// the output is identical to a plain byte-offset cut.
func summarizeMessages(messages []string) string {
	const (
		maxMessages     = 3
		maxMessageBytes = 100
		maxSummaryBytes = 200
		ellipsis        = "..."
	)

	if len(messages) == 0 {
		return ""
	}

	// clip returns s unchanged when it fits in limit bytes; otherwise it cuts
	// at the last character boundary not beyond limit and appends the ellipsis.
	clip := func(s string, limit int) string {
		if len(s) <= limit {
			return s
		}
		cut := 0
		// Ranging over a string yields the byte offset at which each
		// character starts, i.e. exactly the safe places to cut.
		for offset := range s {
			if offset > limit {
				break
			}
			cut = offset
		}
		return s[:cut] + ellipsis
	}

	if len(messages) > maxMessages {
		messages = messages[:maxMessages]
	}

	var summary strings.Builder
	for _, msg := range messages {
		summary.WriteString(clip(msg, maxMessageBytes))
		summary.WriteString("; ")
	}
	return clip(strings.TrimSpace(summary.String()), maxSummaryBytes)
}
// FormatMemoriesForPrompt formats memories as a prompt fragment.
//
// It returns an empty string for an empty (or nil) slice. Otherwise it returns
// a fixed heading line followed by one "- <content>" line per memory, in slice
// order, each terminated by a newline.
func FormatMemoriesForPrompt(memories []model.UserMemory) string {
	if len(memories) == 0 {
		return ""
	}

	var out strings.Builder
	out.WriteString("# 用户核心记忆\n")
	for i := range memories {
		out.WriteString("- " + memories[i].Content + "\n")
	}
	return out.String()
}
// GetTurnCount counts conversation turns in messages.
//
// A turn is a message whose Role is "user" immediately followed by one whose
// Role is "assistant". Every adjacent pair is examined, so the count does not
// depend on where in the slice an exchange starts: a leading message with
// another role, or two user messages in a row, no longer shifts the pairing
// and hides all later turns. For a strictly alternating user/assistant
// sequence the result is the same as pairing at even offsets. Two matched
// turns can never overlap, because the second message of a turn is an
// assistant message and so cannot start another turn.
func GetTurnCount(messages []model.Message) int {
	turns := 0
	for i := 1; i < len(messages); i++ {
		if messages[i-1].Role == "user" && messages[i].Role == "assistant" {
			turns++
		}
	}
	return turns
}
// ShouldExtractMemory decides whether memory extraction should be triggered.
//
// It reports whether the number of turns counted by GetTurnCount has reached
// triggerTurns, and logs the turn count, the threshold and the decision at
// info level on every call.
func ShouldExtractMemory(messages []model.Message, triggerTurns int) bool {
	turnCount := GetTurnCount(messages)
	reached := turnCount >= triggerTurns

	logger.Logger.Info("ShouldExtractMemory check",
		zap.Int("turns", turnCount),
		zap.Int("trigger_turns", triggerTurns),
		zap.Bool("should_extract", reached),
	)
	return reached
}