211 lines
5.3 KiB
Go
211 lines
5.3 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"strings"
|
||
|
||
"github.com/topfans/backend/pkg/logger"
|
||
"github.com/topfans/backend/services/aiChatService/model"
|
||
"github.com/topfans/backend/services/aiChatService/repository"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// MemoryService 记忆服务
|
||
type MemoryService struct {
|
||
shortTermRepo repository.ShortTermMemoryRepository
|
||
longTermRepo repository.LongTermMemoryRepository
|
||
}
|
||
|
||
// NewMemoryService 创建记忆服务
|
||
func NewMemoryService(shortTermRepo repository.ShortTermMemoryRepository, longTermRepo repository.LongTermMemoryRepository) *MemoryService {
|
||
return &MemoryService{
|
||
shortTermRepo: shortTermRepo,
|
||
longTermRepo: longTermRepo,
|
||
}
|
||
}
|
||
|
||
// SaveContext 保存短期上下文
|
||
func (s *MemoryService) SaveContext(ctx context.Context, sessionID string, messages []model.Message, personaID string) error {
|
||
return s.shortTermRepo.SaveContext(ctx, sessionID, messages, personaID)
|
||
}
|
||
|
||
// GetContext 获取短期上下文
|
||
func (s *MemoryService) GetContext(ctx context.Context, sessionID string) ([]model.Message, error) {
|
||
return s.shortTermRepo.GetContext(ctx, sessionID)
|
||
}
|
||
|
||
// RecallMemories 召回相关记忆
|
||
func (s *MemoryService) RecallMemories(ctx context.Context, userID int64, userInput string, limit int) (string, error) {
|
||
// 从用户输入提取关键词
|
||
keywords := extractKeywords(userInput)
|
||
|
||
// 查询长期记忆
|
||
memories, err := s.longTermRepo.GetMemories(ctx, userID, keywords, limit)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
if len(memories) == 0 {
|
||
return "", nil
|
||
}
|
||
|
||
// 组装记忆文本
|
||
var builder strings.Builder
|
||
builder.WriteString("# 用户核心记忆\n")
|
||
for _, m := range memories {
|
||
builder.WriteString("- ")
|
||
builder.WriteString(m.Content)
|
||
builder.WriteString("\n")
|
||
}
|
||
|
||
return builder.String(), nil
|
||
}
|
||
|
||
// ExtractMemory 提取记忆(每5轮对话触发一次)
|
||
func (s *MemoryService) ExtractMemory(ctx context.Context, userID int64, recentMessages []model.Message) error {
|
||
// 从最近的用户消息中提取关键词
|
||
var userMessages []string
|
||
for i := len(recentMessages) - 1; i >= 0 && len(userMessages) < 5; i-- {
|
||
if recentMessages[i].Role == "user" {
|
||
userMessages = append(userMessages, recentMessages[i].Content)
|
||
}
|
||
}
|
||
|
||
if len(userMessages) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 简单关键词提取
|
||
keywords := extractKeywordsFromMessages(userMessages)
|
||
|
||
// 生成记忆摘要
|
||
content := summarizeMessages(userMessages)
|
||
if content == "" {
|
||
return nil
|
||
}
|
||
|
||
// 保存到长期记忆
|
||
memory := &model.UserMemory{
|
||
UserID: userID,
|
||
Content: content,
|
||
Keywords: keywords,
|
||
Weight: 50,
|
||
}
|
||
|
||
return s.longTermRepo.SaveMemory(ctx, memory)
|
||
}
|
||
|
||
// GetUserMemories 获取用户所有长期记忆
|
||
func (s *MemoryService) GetUserMemories(ctx context.Context, userID int64) ([]model.UserMemory, error) {
|
||
return s.longTermRepo.GetMemoriesByUserID(ctx, userID)
|
||
}
|
||
|
||
// extractKeywords splits user input into candidate keywords.
//
// The text is split on sentence punctuation in both ASCII and full-width
// (Chinese) form, and on common whitespace. The punctuation covered is comma,
// full stop, exclamation mark and question mark; the whitespace includes the
// ideographic space U+3000. Each piece is trimmed and kept only if it is at
// least 2 characters long.
//
// Length is counted in runes, not bytes, so a single CJK character (3 bytes
// in UTF-8) is not mistaken for a two-character word. Pieces keep their input
// order and duplicates are not removed. Returns nil when nothing qualifies.
func extractKeywords(text string) []string {
	const minRunes = 2

	// Separators: ASCII , . ! ?; their full-width forms U+FF0C, U+3002
	// (ideographic full stop), U+FF01 and U+FF1F; and whitespace including
	// the ideographic space U+3000. Escapes keep the full-width runes from
	// being confused with (or normalized into) their ASCII look-alikes.
	isSeparator := func(r rune) bool {
		switch r {
		case ',', '.', '!', '?',
			'\uFF0C', '\u3002', '\uFF01', '\uFF1F',
			' ', '\t', '\n', '\r', '\u3000':
			return true
		}
		return false
	}

	var keywords []string
	for _, word := range strings.FieldsFunc(text, isSeparator) {
		word = strings.TrimSpace(word)
		// Count characters (runes), not bytes.
		if len([]rune(word)) >= minRunes {
			keywords = append(keywords, word)
		}
	}
	return keywords
}
|
||
|
||
// extractKeywordsFromMessages 从多条消息提取关键词
|
||
func extractKeywordsFromMessages(messages []string) []string {
|
||
var allKeywords []string
|
||
seen := make(map[string]bool)
|
||
|
||
for _, msg := range messages {
|
||
keywords := extractKeywords(msg)
|
||
for _, k := range keywords {
|
||
if !seen[k] {
|
||
seen[k] = true
|
||
allKeywords = append(allKeywords, k)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 只保留前5个关键词
|
||
if len(allKeywords) > 5 {
|
||
allKeywords = allKeywords[:5]
|
||
}
|
||
|
||
return allKeywords
|
||
}
|
||
|
||
// summarizeMessages builds a short memory summary from user messages.
//
// It takes the first three non-blank entries of messages, in the order given,
// and trims surrounding whitespace from each. Each entry is capped at 100
// bytes; the entries are then joined with "; " and the whole summary is capped
// at 200 bytes. Whenever text is cut, "..." is appended.
//
// Cuts always land on a UTF-8 rune boundary, so multi-byte text such as
// Chinese is never split mid-character; a cut may therefore keep slightly
// fewer bytes than the cap. Returns "" when there is nothing to summarize.
func summarizeMessages(messages []string) string {
	const (
		maxMessages     = 3
		maxMessageBytes = 100
		maxSummaryBytes = 200
	)

	// truncate shortens s to at most limit bytes (before the "..." marker),
	// backing off to the start of any rune that straddles the limit.
	truncate := func(s string, limit int) string {
		if len(s) <= limit {
			return s
		}
		cut := 0
		for i := range s { // ranging over a string yields each rune's starting byte offset
			if i > limit {
				break
			}
			cut = i
		}
		return s[:cut] + "..."
	}

	parts := make([]string, 0, maxMessages)
	for _, msg := range messages {
		if len(parts) == maxMessages {
			break
		}
		msg = strings.TrimSpace(msg)
		if msg == "" {
			// A blank message would otherwise yield a memory consisting of just ";".
			continue
		}
		parts = append(parts, truncate(msg, maxMessageBytes))
	}

	// Join (rather than a trailing "; " per part) avoids a dangling separator.
	return truncate(strings.Join(parts, "; "), maxSummaryBytes)
}
|
||
|
||
// FormatMemoriesForPrompt 将记忆格式化为 prompt 片段
|
||
func FormatMemoriesForPrompt(memories []model.UserMemory) string {
|
||
if len(memories) == 0 {
|
||
return ""
|
||
}
|
||
|
||
var builder strings.Builder
|
||
builder.WriteString("# 用户核心记忆\n")
|
||
for _, m := range memories {
|
||
builder.WriteString("- ")
|
||
builder.WriteString(m.Content)
|
||
builder.WriteString("\n")
|
||
}
|
||
|
||
return builder.String()
|
||
}
|
||
|
||
// GetTurnCount 计算对话轮数
|
||
func GetTurnCount(messages []model.Message) int {
|
||
turns := 0
|
||
for i := 0; i < len(messages)-1; i += 2 {
|
||
if i+1 < len(messages) && messages[i].Role == "user" && messages[i+1].Role == "assistant" {
|
||
turns++
|
||
}
|
||
}
|
||
return turns
|
||
}
|
||
|
||
// ShouldExtractMemory 判断是否应该触发记忆提取
|
||
func ShouldExtractMemory(messages []model.Message, triggerTurns int) bool {
|
||
turns := GetTurnCount(messages)
|
||
logger.Logger.Info("ShouldExtractMemory check",
|
||
zap.Int("turns", turns),
|
||
zap.Int("trigger_turns", triggerTurns),
|
||
zap.Bool("should_extract", turns >= triggerTurns),
|
||
)
|
||
return turns >= triggerTurns
|
||
} |