topfans/backend/services/aiChatService/provider/ai_chat_provider.go
2026-05-28 12:00:19 +08:00

443 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package provider
import (
	"context"
	"errors"
	"fmt"
	"io"
	"strings"

	"dubbo.apache.org/dubbo-go/v3/common/constant"
	"github.com/topfans/backend/pkg/logger"
	pb "github.com/topfans/backend/pkg/proto/ai_chat"
	"github.com/topfans/backend/services/aiChatService/model"
	"github.com/topfans/backend/services/aiChatService/service"
	"go.uber.org/zap"
)
// AIChatProvider is the Dubbo provider implementation of the AI chat service.
// It holds no state of its own; every RPC delegates to the injected services.
type AIChatProvider struct {
	chatService    *service.ChatService    // welcome messages and LLM access (LLMService)
	personaService *service.PersonaService // persona lookup and listing
	memoryService  *service.MemoryService  // session context, memory recall and extraction
	auditService   *service.AuditService   // input/output text auditing and the safe fallback reply
}

// Compile-time check that AIChatProvider implements pb.AIChatServiceHandler.
var _ pb.AIChatServiceHandler = (*AIChatProvider)(nil)
// NewAIChatProvider builds an AIChatProvider from its service dependencies.
// The services are stored as given; nil values are not rejected here.
func NewAIChatProvider(
	chatService *service.ChatService,
	personaService *service.PersonaService,
	memoryService *service.MemoryService,
	auditService *service.AuditService,
) *AIChatProvider {
	provider := new(AIChatProvider)
	provider.chatService = chatService
	provider.personaService = personaService
	provider.memoryService = memoryService
	provider.auditService = auditService
	return provider
}
// InitSession initialises a chat session and returns its welcome message.
//
// The caller's user and star IDs are read from the Dubbo attachments; the call
// fails if they are missing. When the request carries no session ID, the
// default "<userID>_<starID>" is used and echoed back in the response.
func (p *AIChatProvider) InitSession(ctx context.Context, req *pb.InitSessionRequest) (*pb.InitSessionResponse, error) {
	// Authenticate first so a session ID is never derived from zero-valued IDs.
	user, star, authErr := extractUserInfoFromDubboAttachments(ctx)
	if authErr != nil {
		logger.Logger.Error("Failed to extract user info from attachments",
			zap.Error(authErr),
		)
		return nil, authErr
	}

	session := req.SessionId
	if session == "" {
		session = fmt.Sprintf("%d_%d", user, star)
	}
	logger.Logger.Info("Received InitSession request",
		zap.Int64("user_id", user),
		zap.String("session_id", session),
	)

	resp := &pb.InitSessionResponse{SessionId: session}
	resp.WelcomeMessage = p.chatService.GetWelcomeMessage(session, user, star)
	return resp, nil
}
// SendMessage handles one chat turn and streams the assistant reply to the client.
//
// Steps:
//  1. Authenticate from the Dubbo attachments.
//  2. Audit the incoming message.
//  3. Resolve the persona.
//  4. Answer with a fixed reply when no LLM call is needed.
//  5. Recall memories and load history, then build the prompt.
//  6. Stream the LLM reply, falling back to the backup model if the primary call fails.
//  7. Persist the turn and possibly trigger memory extraction.
//
// Every exit path sends a final frame with IsEnd=true, unless sending itself
// fails, so the client can tell the reply is complete.
func (p *AIChatProvider) SendMessage(ctx context.Context, req *pb.ChatMessageRequest, stream pb.AIChatService_SendMessageServer) error {
	// send delivers one frame. On these paths the handler is about to return
	// anyway, so a failed send is logged rather than propagated.
	send := func(resp *pb.ChatMessageResponse) {
		if sendErr := stream.Send(resp); sendErr != nil {
			logger.Logger.Error("Failed to send message to stream", zap.Error(sendErr))
		}
	}

	userID, starID, err := extractUserInfoFromDubboAttachments(ctx)
	if err != nil {
		logger.Logger.Error("Failed to extract user info from attachments",
			zap.Error(err),
		)
		// The caller is unauthenticated: echo the requested session ID instead
		// of deriving one from zero-valued IDs.
		send(&pb.ChatMessageResponse{
			Content:   "user authentication required",
			SessionId: req.SessionId,
			IsEnd:     true,
			Error:     err.Error(),
		})
		return err
	}

	sessionID := req.SessionId
	if sessionID == "" {
		sessionID = fmt.Sprintf("%d_%d", userID, starID)
	}
	message := req.Message
	personaID := req.PersonaId

	logger.Logger.Info("Received SendMessage request",
		zap.Int64("user_id", userID),
		zap.String("session_id", sessionID),
		zap.String("message", message),
	)

	// sendReply sends a canned reply as one content frame followed by the end marker.
	sendReply := func(content string) {
		send(&pb.ChatMessageResponse{Content: content, SessionId: sessionID})
		send(&pb.ChatMessageResponse{SessionId: sessionID, IsEnd: true})
	}

	// 1. Pre-audit the user's message.
	if !p.auditService.AuditText(message) {
		logger.Logger.Info("Message blocked by audit")
		sendReply(p.auditService.DefaultSafeResponse())
		return nil
	}

	// 2. Resolve the persona via GetPersonaOrDefault.
	persona, err := p.personaService.GetPersonaOrDefault(ctx, userID, personaID)
	if err != nil {
		logger.Logger.Error("Failed to get persona", zap.Error(err))
		send(&pb.ChatMessageResponse{
			Content:   err.Error(),
			SessionId: sessionID,
			IsEnd:     true,
			Error:     err.Error(),
		})
		return err
	}

	// 3. Messages that need no LLM call get a fixed reply. This check runs
	// before the memory/history lookups so that work is not wasted.
	// Such turns are not persisted.
	if service.IsNoNeedLLMCall(message) {
		sendReply("好的,我听到了。")
		return nil
	}

	// 4. Memory recall and history are best-effort: a failure only degrades
	// the prompt, so it is logged and the turn continues.
	memoryText, err := p.memoryService.RecallMemories(ctx, userID, message, 5)
	if err != nil {
		logger.Logger.Warn("Failed to recall memories", zap.Error(err))
	}
	history, err := p.memoryService.GetContext(ctx, sessionID)
	if err != nil {
		// NOTE(review): history is persisted below as if it were complete. After
		// a transient read failure, that may overwrite the stored context. Confirm
		// how SaveContext behaves.
		logger.Logger.Warn("Failed to load conversation history", zap.Error(err))
	}

	// 5. Build the prompt. Errors are logged but not fatal; the returned
	// messages are still used.
	tokenizer := &service.Tokenizer{}
	messages, err := service.BuildPrompt(
		persona.SystemPrompt,
		memoryText,
		history,
		message,
		tokenizer,
	)
	if err != nil {
		logger.Logger.Warn("Failed to build prompt", zap.Error(err))
	}

	// 6. Open the LLM stream, falling back to the backup model on any error.
	streamReader, err := p.chatService.LLMService.StreamChat(ctx, messages)
	if err != nil {
		logger.Logger.Error("LLM call failed", zap.Error(err))
		// errors.As also matches a SensitiveContentError wrapped by the LLM layer.
		var sensitiveErr *service.SensitiveContentError
		blocked := errors.As(err, &sensitiveErr)
		if blocked {
			logger.Logger.Info("Content blocked by MiniMax safety filter, trying backup model")
		}
		streamReader, err = p.chatService.LLMService.StreamChatWithBackup(ctx, messages)
		if err != nil {
			logger.Logger.Error("Backup model also failed", zap.Error(err))
			if blocked {
				// Safety-filter failures get the safe reply, not an error.
				send(&pb.ChatMessageResponse{
					Content:   p.auditService.DefaultSafeResponse(),
					SessionId: sessionID,
					IsEnd:     true,
				})
				return nil
			}
			send(&pb.ChatMessageResponse{
				Content:   "抱歉,服务暂时不可用",
				SessionId: sessionID,
				IsEnd:     true,
				Error:     err.Error(),
			})
			return err
		}
	}
	// Single owner of Close: every return below releases the LLM stream here.
	defer streamReader.Close()

	// 7. Relay the reply chunk by chunk.
	var reply strings.Builder
	for {
		content, done, err := streamReader.Next()
		if errors.Is(err, io.EOF) {
			// Reader ended without reporting done: close the client stream here.
			send(&pb.ChatMessageResponse{SessionId: sessionID, IsEnd: true})
			break
		}
		if err != nil {
			// Tell the client the reply is over instead of leaving it waiting,
			// and do not persist a truncated turn.
			logger.Logger.Error("Stream read error", zap.Error(err))
			send(&pb.ChatMessageResponse{
				SessionId: sessionID,
				IsEnd:     true,
				Error:     err.Error(),
			})
			return err
		}

		// Post-audit each chunk. NOTE(review): chunks are audited one at a
		// time, so text split across chunk boundaries is presumably not caught.
		// Confirm against AuditResponse.
		if !p.auditService.AuditResponse(content) {
			logger.Logger.Info("Response blocked by audit")
			sendReply(p.auditService.DefaultSafeResponse())
			return nil
		}

		reply.WriteString(content)
		if err := stream.Send(&pb.ChatMessageResponse{
			Content:   content,
			SessionId: sessionID,
			IsEnd:     done,
		}); err != nil {
			logger.Logger.Error("Failed to send message to stream", zap.Error(err))
			return err
		}
		if done {
			// The end marker went out with this frame; stop reading.
			break
		}
	}
	fullResponse := reply.String()

	// 8. Persist the turn. Copy into a fresh slice so the appends can never
	// write into a backing array that history shares with its source.
	newHistory := make([]model.Message, 0, len(history)+2)
	newHistory = append(newHistory, history...)
	newHistory = append(newHistory,
		model.Message{Role: "user", Content: message},
		model.Message{Role: "assistant", Content: fullResponse},
	)
	// NOTE(review): any return value of SaveContext is ignored here. Confirm
	// whether it can fail.
	p.memoryService.SaveContext(ctx, sessionID, newHistory, personaID)

	// 9. Memory extraction.
	// NOTE(review): `>= 5` triggers extraction on every turn once the history
	// reaches 5 turns, not once every 5 turns. A modulo check may be intended,
	// but if GetContext caps the history length such a check could stop firing.
	// Confirm the cadence before changing it.
	newTurns := len(newHistory) / 2
	logger.Logger.Info("Memory extraction check",
		zap.Int("message_count", len(newHistory)),
		zap.Int("turns", newTurns),
		zap.Bool("should_extract", newTurns >= 5),
	)
	if newTurns >= 5 {
		logger.Logger.Info("Triggering memory extraction", zap.Int64("user_id", userID))
		if err := p.memoryService.ExtractMemory(ctx, userID, newHistory); err != nil {
			logger.Logger.Error("Failed to extract memory", zap.Error(err))
		} else {
			logger.Logger.Info("Memory extracted successfully", zap.Int64("user_id", userID))
		}
	}

	logger.Logger.Info("SendMessage completed",
		zap.Int64("user_id", userID),
		zap.String("session_id", sessionID),
		zap.Int("response_length", len(fullResponse)),
	)
	return nil
}
// GetHistory returns the stored conversation history of a session.
//
// The caller is identified from the Dubbo attachments. When req.SessionId is
// empty, the default "<userID>_<starID>" session is read. Errors from the
// memory service are returned unchanged.
//
// NOTE(review): a non-empty req.SessionId is used as-is and is not checked
// against the authenticated user, so any authenticated caller can read a
// session whose ID they know. Confirm whether ownership is enforced upstream.
func (p *AIChatProvider) GetHistory(ctx context.Context, req *pb.ChatHistoryRequest) (*pb.ChatHistoryResponse, error) {
	uid, starID, authErr := extractUserInfoFromDubboAttachments(ctx)
	if authErr != nil {
		logger.Logger.Error("Failed to extract user info from attachments",
			zap.Error(authErr),
		)
		return nil, authErr
	}

	session := req.SessionId
	if len(session) == 0 {
		session = fmt.Sprintf("%d_%d", uid, starID)
	}
	logger.Logger.Info("Received GetHistory request",
		zap.Int64("user_id", uid),
		zap.String("session_id", session),
	)

	stored, loadErr := p.memoryService.GetContext(ctx, session)
	if loadErr != nil {
		return nil, loadErr
	}

	// Always a non-nil slice, even when there is no history.
	out := make([]*pb.Message, 0, len(stored))
	for _, msg := range stored {
		out = append(out, &pb.Message{Role: msg.Role, Content: msg.Content})
	}
	return &pb.ChatHistoryResponse{History: out}, nil
}
// GetPersonas lists every persona belonging to req.UserId.
//
// Errors from the persona service are returned unchanged. The response list
// is always non-nil, even when the user has no personas.
//
// NOTE(review): unlike the other handlers, the user ID comes from the request
// body rather than the Dubbo attachments, so the caller is not authenticated
// here. Confirm that the gateway fills or validates req.UserId.
func (p *AIChatProvider) GetPersonas(ctx context.Context, req *pb.GetPersonasRequest) (*pb.PersonaListResponse, error) {
	owner := req.UserId
	logger.Logger.Info("Received GetPersonas request",
		zap.Int64("user_id", owner),
	)

	list, lookupErr := p.personaService.GetPersonas(ctx, owner)
	if lookupErr != nil {
		return nil, lookupErr
	}

	out := make([]*pb.PersonaInfo, 0, len(list))
	for _, ps := range list {
		out = append(out, &pb.PersonaInfo{
			Id:          ps.ID,
			Name:        ps.Name,
			Description: ps.Description,
			AvatarUrl:   ps.AvatarURL,
			TalkStyle:   ps.TalkStyle,
			IsDefault:   ps.IsDefault,
			CreatedAt:   ps.CreatedAt,
			UpdatedAt:   ps.UpdatedAt,
		})
	}
	return &pb.PersonaListResponse{Personas: out}, nil
}
// extractUserInfoFromDubboAttachments reads the caller's user and star IDs
// from the Dubbo attachments stored in ctx under constant.AttachmentKey.
//
// The attachments must be a map[string]interface{}. Its "user_id" and
// "star_id" entries are converted with parseIntValue. Both IDs must come out
// positive; otherwise (0, 0, error) is returned.
func extractUserInfoFromDubboAttachments(ctx context.Context) (int64, int64, error) {
	logger.Logger.Debug("Extracting user info from Dubbo attachments",
		zap.String("context_type", fmt.Sprintf("%T", ctx)),
	)

	attachments := ctx.Value(constant.AttachmentKey)
	attMap, isMap := attachments.(map[string]interface{})
	switch {
	case attachments == nil:
		logger.Logger.Warn("ctx.Value(constant.AttachmentKey) returned nil",
			zap.String("constant_attachment_key", string(constant.AttachmentKey)),
		)
	case !isMap:
		logger.Logger.Warn("Attachments is not map[string]interface{}",
			zap.String("actual_type", fmt.Sprintf("%T", attachments)),
		)
	default:
		// Only the two relevant entries are logged; the attachment map may carry
		// other caller metadata that should not end up in logs.
		rawUser, rawStar := attMap["user_id"], attMap["star_id"]
		userID := parseIntValue(rawUser)
		starID := parseIntValue(rawStar)
		logger.Logger.Debug("Parsed user info from attachments",
			zap.Any("user_id_raw", rawUser),
			zap.Int64("user_id", userID),
			zap.Any("star_id_raw", rawStar),
			zap.Int64("star_id", starID),
		)
		if userID > 0 && starID > 0 {
			return userID, starID, nil
		}
		logger.Logger.Warn("Parsed user_id or star_id is zero",
			zap.Int64("user_id", userID),
			zap.Int64("star_id", starID),
		)
	}
	return 0, 0, errors.New("user info not found in Dubbo attachments")
}
// parseIntValue converts a loosely typed attachment value to int64.
//
// Supported inputs:
//   - every built-in integer kind;
//   - float32/float64, truncated toward zero;
//   - strings, parsed with fmt.Sscanf("%d");
//   - []string and []interface{}, using their first element, which is parsed
//     recursively. Attachments may arrive as single-element slices depending
//     on the transport.
//
// Anything else yields 0: nil, empty slices, unparsable strings, and uint
// values above the int64 range.
func parseIntValue(v interface{}) int64 {
	switch val := v.(type) {
	case int64:
		return val
	case int:
		return int64(val)
	case int32:
		// e.g. Hessian-decoded Java ints.
		return int64(val)
	case int16:
		return int64(val)
	case int8:
		return int64(val)
	case uint32:
		return int64(val)
	case uint16:
		return int64(val)
	case uint8:
		return int64(val)
	case uint:
		// Values above math.MaxInt64 would wrap negative; treat them as invalid.
		if n := int64(val); n >= 0 {
			return n
		}
	case uint64:
		if n := int64(val); n >= 0 {
			return n
		}
	case float64:
		return int64(val)
	case float32:
		return int64(val)
	case string:
		var result int64
		if _, err := fmt.Sscanf(val, "%d", &result); err != nil {
			return 0
		}
		return result
	case []string:
		if len(val) > 0 {
			return parseIntValue(val[0])
		}
	case []interface{}:
		if len(val) > 0 {
			return parseIntValue(val[0])
		}
	}
	return 0
}