package provider

import (
	"context"
	"errors"
	"fmt"
	"io"
	"strconv"
	"strings"

	"dubbo.apache.org/dubbo-go/v3/common/constant"
	"github.com/topfans/backend/pkg/logger"
	pb "github.com/topfans/backend/pkg/proto/ai_chat"
	"github.com/topfans/backend/services/aiChatService/model"
	"github.com/topfans/backend/services/aiChatService/service"
	"go.uber.org/zap"
)

// AIChatProvider is the Dubbo provider implementation of the AI chat service.
// It wires the chat, persona, memory and audit services together behind the
// generated pb.AIChatServiceHandler interface.
type AIChatProvider struct {
	chatService    *service.ChatService
	personaService *service.PersonaService
	memoryService  *service.MemoryService
	auditService   *service.AuditService
}

// Compile-time check that AIChatProvider implements pb.AIChatServiceHandler.
var _ pb.AIChatServiceHandler = (*AIChatProvider)(nil)

// NewAIChatProvider creates an AIChatProvider from its service dependencies.
func NewAIChatProvider(
	chatService *service.ChatService,
	personaService *service.PersonaService,
	memoryService *service.MemoryService,
	auditService *service.AuditService,
) *AIChatProvider {
	return &AIChatProvider{
		chatService:    chatService,
		personaService: personaService,
		memoryService:  memoryService,
		auditService:   auditService,
	}
}

// defaultSessionID returns the session ID used when the client does not send
// one: "<userID>_<starID>".
func defaultSessionID(userID, starID int64) string {
	return fmt.Sprintf("%d_%d", userID, starID)
}

// sendChatResponse sends resp on stream and logs (but otherwise ignores) a
// failed send. It is used for notice/terminal frames sent right before the
// handler finishes, where there is nothing better to do with the error.
func sendChatResponse(stream pb.AIChatService_SendMessageServer, resp *pb.ChatMessageResponse) {
	if err := stream.Send(resp); err != nil {
		logger.Logger.Warn("Failed to send message to stream",
			zap.Error(err),
			zap.String("session_id", resp.SessionId),
		)
	}
}

// InitSession initializes a chat session and returns its welcome message.
//
// The caller is identified from the Dubbo attachments; an authentication
// failure is returned as an error. When req.SessionId is empty the session ID
// defaults to "<userID>_<starID>". The resolved session ID is echoed back.
func (p *AIChatProvider) InitSession(ctx context.Context, req *pb.InitSessionRequest) (*pb.InitSessionResponse, error) {
	userID, starID, err := extractUserInfoFromDubboAttachments(ctx)
	if err != nil {
		logger.Logger.Error("Failed to extract user info from attachments",
			zap.Error(err),
		)
		return nil, err
	}

	sessionID := req.SessionId
	if sessionID == "" {
		sessionID = defaultSessionID(userID, starID)
	}

	logger.Logger.Info("Received InitSession request",
		zap.Int64("user_id", userID),
		zap.String("session_id", sessionID),
	)

	welcomeMessage := p.chatService.GetWelcomeMessage(sessionID, userID, starID)

	return &pb.InitSessionResponse{
		WelcomeMessage: welcomeMessage,
		SessionId:      sessionID,
	}, nil
}

// SendMessage handles one chat turn and streams the assistant's reply to the
// client as a sequence of ChatMessageResponse frames.
//
// Pipeline:
//  1. authenticate the caller from the Dubbo attachments;
//  2. audit the user's message (a blocked message is answered with the audit
//     service's default safe response);
//  3. resolve the persona;
//  4. answer trivial messages (service.IsNoNeedLLMCall) with a canned reply;
//  5. recall memories, load the conversation history and build the prompt;
//  6. stream the LLM reply, falling back to the backup model if the primary
//     call fails, auditing every chunk before forwarding it;
//  7. persist the turn and possibly trigger memory extraction.
//
// Except when a send itself fails (client gone), every path ends the client
// stream with an IsEnd frame. It returns nil when the turn was answered
// (including safe-response substitutions) and a non-nil error when
// authentication, persona loading, both LLM calls, a stream read or a stream
// send fails.
//
// NOTE(review): a client-supplied req.SessionId is used as-is; nothing here
// checks that the session belongs to the authenticated user — confirm this
// is enforced elsewhere.
func (p *AIChatProvider) SendMessage(ctx context.Context, req *pb.ChatMessageRequest, stream pb.AIChatService_SendMessageServer) error {
	const (
		memoryRecallLimit  = 5 // max long-term memories recalled into the prompt
		memoryExtractTurns = 5 // turn count from which memory extraction runs
	)

	userID, starID, err := extractUserInfoFromDubboAttachments(ctx)
	sessionID := req.SessionId
	if sessionID == "" {
		sessionID = defaultSessionID(userID, starID)
	}
	if err != nil {
		logger.Logger.Error("Failed to extract user info from attachments",
			zap.Error(err),
		)
		sendChatResponse(stream, &pb.ChatMessageResponse{
			Content:   "user authentication required",
			SessionId: sessionID,
			IsEnd:     true,
			Error:     err.Error(),
		})
		return err
	}

	message := req.Message
	personaID := req.PersonaId

	logger.Logger.Info("Received SendMessage request",
		zap.Int64("user_id", userID),
		zap.String("session_id", sessionID),
		zap.String("message", message),
	)

	// sendEnd terminates the client stream with an empty IsEnd frame.
	sendEnd := func() {
		sendChatResponse(stream, &pb.ChatMessageResponse{
			SessionId: sessionID,
			IsEnd:     true,
		})
	}
	// sendTextAndEnd sends a complete replacement reply followed by the
	// terminating IsEnd frame.
	sendTextAndEnd := func(content string) {
		sendChatResponse(stream, &pb.ChatMessageResponse{
			Content:   content,
			SessionId: sessionID,
			IsEnd:     false,
		})
		sendEnd()
	}

	// 1. Pre-audit the user's message.
	if !p.auditService.AuditText(message) {
		logger.Logger.Info("Message blocked by audit")
		sendTextAndEnd(p.auditService.DefaultSafeResponse())
		return nil
	}

	// 2. Resolve the persona.
	persona, err := p.personaService.GetPersonaOrDefault(ctx, userID, personaID)
	if err != nil {
		logger.Logger.Error("Failed to get persona", zap.Error(err))
		sendChatResponse(stream, &pb.ChatMessageResponse{
			Content:   err.Error(),
			SessionId: sessionID,
			IsEnd:     true,
			Error:     err.Error(),
		})
		return err
	}

	// 3. Trivial messages get a canned acknowledgement. This runs before
	// memory recall, history loading and prompt building because none of
	// their results are used on this path.
	if service.IsNoNeedLLMCall(message) {
		sendTextAndEnd("好的,我听到了。")
		return nil
	}

	// 4. Recall long-term memories. Best effort: the second result is
	// discarded, so a failed recall only leaves the prompt without memories.
	memoryText, _ := p.memoryService.RecallMemories(ctx, userID, message, memoryRecallLimit)

	// 5. Load the conversation history. Best effort: the turn is still
	// answered without history.
	// NOTE(review): when this fails, SaveContext below stores only the new
	// turn, overwriting any existing history — confirm GetContext cannot fail
	// transiently for an existing session.
	history, err := p.memoryService.GetContext(ctx, sessionID)
	if err != nil {
		logger.Logger.Warn("Failed to load conversation history",
			zap.Error(err),
			zap.String("session_id", sessionID),
		)
	}

	// 6. Build the prompt.
	// NOTE(review): BuildPrompt's second result is discarded; confirm it never
	// signals a failure that should abort the turn.
	messages, _ := service.BuildPrompt(
		persona.SystemPrompt,
		memoryText,
		history,
		message,
		&service.Tokenizer{},
	)

	// 7. Call the LLM in streaming mode; on any failure retry once with the
	// backup model.
	streamReader, err := p.chatService.LLMService.StreamChat(ctx, messages)
	if err != nil {
		logger.Logger.Error("LLM call failed", zap.Error(err))

		var sensitiveErr *service.SensitiveContentError
		blockedBySafety := errors.As(err, &sensitiveErr)
		if blockedBySafety {
			logger.Logger.Info("Content blocked by MiniMax safety filter, trying backup model")
		}

		streamReader, err = p.chatService.LLMService.StreamChatWithBackup(ctx, messages)
		if err != nil {
			logger.Logger.Error("Backup model also failed", zap.Error(err))
			if blockedBySafety {
				// The primary model refused on safety grounds: answer with the
				// safe response instead of surfacing an error.
				sendChatResponse(stream, &pb.ChatMessageResponse{
					Content:   p.auditService.DefaultSafeResponse(),
					SessionId: sessionID,
					IsEnd:     true,
				})
				return nil
			}
			sendChatResponse(stream, &pb.ChatMessageResponse{
				Content:   "抱歉,服务暂时不可用",
				SessionId: sessionID,
				IsEnd:     true,
				Error:     err.Error(),
			})
			return err
		}
	}
	// The only place the reader is closed; every return below relies on it.
	defer streamReader.Close()

	// 8. Relay the reply chunk by chunk until the reader reports done or EOF.
	var reply strings.Builder
	for {
		content, done, err := streamReader.Next()
		if errors.Is(err, io.EOF) {
			// Upstream ended without flagging a final chunk; close the client
			// stream ourselves.
			sendEnd()
			break
		}
		if err != nil {
			// Terminate the client stream so it does not wait forever for an
			// IsEnd frame, and do not persist the incomplete reply.
			logger.Logger.Error("Stream read error", zap.Error(err))
			sendChatResponse(stream, &pb.ChatMessageResponse{
				SessionId: sessionID,
				IsEnd:     true,
				Error:     err.Error(),
			})
			return err
		}

		// Post-audit: each chunk is audited on its own before it is forwarded.
		if !p.auditService.AuditResponse(content) {
			logger.Logger.Info("Response blocked by audit")
			sendTextAndEnd(p.auditService.DefaultSafeResponse())
			return nil
		}

		reply.WriteString(content)

		if err := stream.Send(&pb.ChatMessageResponse{
			Content:   content,
			SessionId: sessionID,
			IsEnd:     done,
		}); err != nil {
			logger.Logger.Error("Failed to send message to stream", zap.Error(err))
			return err
		}
		if done {
			break
		}
	}
	fullResponse := reply.String()

	// 9. Persist the turn. Build a fresh slice so the appends cannot write
	// into the backing array of the slice returned by GetContext.
	newHistory := make([]model.Message, 0, len(history)+2)
	newHistory = append(newHistory, history...)
	newHistory = append(newHistory,
		model.Message{Role: "user", Content: message},
		model.Message{Role: "assistant", Content: fullResponse},
	)
	// NOTE(review): any result of SaveContext is discarded; confirm whether a
	// failed save should at least be logged.
	p.memoryService.SaveContext(ctx, sessionID, newHistory, personaID)

	// 10. Memory extraction.
	// NOTE(review): the original comment said "every 5 turns", but this fires
	// on every turn once the session holds at least memoryExtractTurns turns.
	// The right cadence depends on whether stored history is capped, which is
	// not visible here — confirm the intent before changing it.
	turns := len(newHistory) / 2
	shouldExtract := turns >= memoryExtractTurns
	logger.Logger.Info("Memory extraction check",
		zap.Int("message_count", len(newHistory)),
		zap.Int("turns", turns),
		zap.Bool("should_extract", shouldExtract),
	)
	if shouldExtract {
		logger.Logger.Info("Triggering memory extraction", zap.Int64("user_id", userID))
		if err := p.memoryService.ExtractMemory(ctx, userID, newHistory); err != nil {
			logger.Logger.Error("Failed to extract memory", zap.Error(err))
		} else {
			logger.Logger.Info("Memory extracted successfully", zap.Int64("user_id", userID))
		}
	}

	logger.Logger.Info("SendMessage completed",
		zap.Int64("user_id", userID),
		zap.String("session_id", sessionID),
		zap.Int("response_length", len(fullResponse)),
	)

	return nil
}

// GetHistory returns the stored conversation history of a session.
//
// The caller is identified from the Dubbo attachments; when req.SessionId is
// empty the session ID defaults to "<userID>_<starID>". Errors from
// authentication or from loading the history are returned unchanged.
//
// NOTE(review): a client-supplied req.SessionId is not checked for ownership
// here — confirm that is enforced elsewhere.
func (p *AIChatProvider) GetHistory(ctx context.Context, req *pb.ChatHistoryRequest) (*pb.ChatHistoryResponse, error) {
	userID, starID, err := extractUserInfoFromDubboAttachments(ctx)
	if err != nil {
		logger.Logger.Error("Failed to extract user info from attachments",
			zap.Error(err),
		)
		return nil, err
	}

	sessionID := req.SessionId
	if sessionID == "" {
		sessionID = defaultSessionID(userID, starID)
	}

	logger.Logger.Info("Received GetHistory request",
		zap.Int64("user_id", userID),
		zap.String("session_id", sessionID),
	)

	messages, err := p.memoryService.GetContext(ctx, sessionID)
	if err != nil {
		return nil, err
	}

	pbMessages := make([]*pb.Message, len(messages))
	for i, m := range messages {
		pbMessages[i] = &pb.Message{
			Role:    m.Role,
			Content: m.Content,
		}
	}

	return &pb.ChatHistoryResponse{
		History: pbMessages,
	}, nil
}

// GetPersonas lists all personas of a user.
//
// The user is taken from the authenticated Dubbo attachments when they are
// present. In that case a non-zero req.UserId that names a different user is
// rejected, so one user cannot enumerate another user's personas. When the
// attachments are missing, req.UserId is used as before for backward
// compatibility.
// NOTE(review): consider rejecting unauthenticated calls once every caller
// propagates attachments.
func (p *AIChatProvider) GetPersonas(ctx context.Context, req *pb.GetPersonasRequest) (*pb.PersonaListResponse, error) {
	userID := req.UserId
	if authUserID, _, err := extractUserInfoFromDubboAttachments(ctx); err == nil {
		if userID != 0 && userID != authUserID {
			logger.Logger.Warn("GetPersonas user_id does not match authenticated user",
				zap.Int64("requested_user_id", userID),
				zap.Int64("authenticated_user_id", authUserID),
			)
			return nil, fmt.Errorf("user_id %d does not match authenticated user", userID)
		}
		userID = authUserID
	}

	logger.Logger.Info("Received GetPersonas request",
		zap.Int64("user_id", userID),
	)

	personas, err := p.personaService.GetPersonas(ctx, userID)
	if err != nil {
		return nil, err
	}

	pbPersonas := make([]*pb.PersonaInfo, len(personas))
	for i, persona := range personas {
		pbPersonas[i] = &pb.PersonaInfo{
			Id:          persona.ID,
			Name:        persona.Name,
			Description: persona.Description,
			AvatarUrl:   persona.AvatarURL,
			TalkStyle:   persona.TalkStyle,
			IsDefault:   persona.IsDefault,
			CreatedAt:   persona.CreatedAt,
			UpdatedAt:   persona.UpdatedAt,
		}
	}

	return &pb.PersonaListResponse{
		Personas: pbPersonas,
	}, nil
}

// extractUserInfoFromDubboAttachments reads the caller's user_id and star_id
// from the Dubbo attachments stored in ctx under constant.AttachmentKey.
//
// The attachments must be a map[string]interface{}, and both IDs must parse
// (via parseIntValue) to positive integers. Otherwise it returns zero IDs and
// a non-nil error. Raw attachment maps are not logged, to keep any
// credentials they may carry out of the logs.
func extractUserInfoFromDubboAttachments(ctx context.Context) (int64, int64, error) {
	logger.Logger.Debug("Extracting user info from Dubbo attachments",
		zap.String("context_type", fmt.Sprintf("%T", ctx)),
	)

	attachments := ctx.Value(constant.AttachmentKey)
	if attachments == nil {
		logger.Logger.Warn("ctx.Value(constant.AttachmentKey) returned nil",
			zap.String("constant_attachment_key", string(constant.AttachmentKey)),
		)
		return 0, 0, errors.New("user info not found in Dubbo attachments")
	}

	attMap, ok := attachments.(map[string]interface{})
	if !ok {
		logger.Logger.Warn("Attachments is not map[string]interface{}",
			zap.String("actual_type", fmt.Sprintf("%T", attachments)),
		)
		return 0, 0, errors.New("user info not found in Dubbo attachments")
	}

	userID := parseIntValue(attMap["user_id"])
	starID := parseIntValue(attMap["star_id"])
	logger.Logger.Debug("Parsed user info from attachments",
		zap.Any("user_id_raw", attMap["user_id"]),
		zap.Int64("user_id", userID),
		zap.Any("star_id_raw", attMap["star_id"]),
		zap.Int64("star_id", starID),
	)

	if userID <= 0 || starID <= 0 {
		logger.Logger.Warn("Parsed user_id or star_id is not positive",
			zap.Int64("user_id", userID),
			zap.Int64("star_id", starID),
		)
		return 0, 0, errors.New("user info not found in Dubbo attachments")
	}
	return userID, starID, nil
}

// parseIntValue converts an attachment value to int64, returning 0 when the
// value is missing, of an unsupported type, or not a valid integer.
//
// Supported forms:
//   - signed/unsigned integers;
//   - float64 (truncated toward zero);
//   - decimal strings (surrounding whitespace is ignored, trailing junk is
//     rejected);
//   - []string / []interface{}, of which only the first element is parsed.
//
// A uint64 above math.MaxInt64 yields 0 instead of wrapping negative.
func parseIntValue(v interface{}) int64 {
	switch val := v.(type) {
	case int64:
		return val
	case int:
		return int64(val)
	case int32:
		return int64(val)
	case uint32:
		return int64(val)
	case uint64:
		if val > 1<<63-1 {
			return 0
		}
		return int64(val)
	case float64:
		return int64(val)
	case string:
		n, err := strconv.ParseInt(strings.TrimSpace(val), 10, 64)
		if err != nil {
			return 0
		}
		return n
	case []string:
		if len(val) > 0 {
			return parseIntValue(val[0])
		}
	case []interface{}:
		if len(val) > 0 {
			return parseIntValue(val[0])
		}
	}
	return 0
}