443 lines
12 KiB
Go
443 lines
12 KiB
Go
package provider
|
||
|
||
import (
	"context"
	"errors"
	"fmt"
	"io"
	"strconv"
	"strings"

	"dubbo.apache.org/dubbo-go/v3/common/constant"
	"github.com/topfans/backend/pkg/logger"
	pb "github.com/topfans/backend/pkg/proto/ai_chat"
	"github.com/topfans/backend/services/aiChatService/model"
	"github.com/topfans/backend/services/aiChatService/service"
	"go.uber.org/zap"
)
|
||
|
||
// AIChatProvider AI Chat 服务 Provider 实现
|
||
type AIChatProvider struct {
|
||
chatService *service.ChatService
|
||
personaService *service.PersonaService
|
||
memoryService *service.MemoryService
|
||
auditService *service.AuditService
|
||
}
|
||
|
||
// 确保 AIChatProvider 实现了 AIChatServiceHandler 接口
|
||
var _ pb.AIChatServiceHandler = (*AIChatProvider)(nil)
|
||
|
||
// NewAIChatProvider 创建 AIChatProvider 实例
|
||
func NewAIChatProvider(
|
||
chatService *service.ChatService,
|
||
personaService *service.PersonaService,
|
||
memoryService *service.MemoryService,
|
||
auditService *service.AuditService,
|
||
) *AIChatProvider {
|
||
return &AIChatProvider{
|
||
chatService: chatService,
|
||
personaService: personaService,
|
||
memoryService: memoryService,
|
||
auditService: auditService,
|
||
}
|
||
}
|
||
|
||
// InitSession 初始化会话,返回欢迎消息
|
||
func (p *AIChatProvider) InitSession(ctx context.Context, req *pb.InitSessionRequest) (*pb.InitSessionResponse, error) {
|
||
userID, starID, err := extractUserInfoFromDubboAttachments(ctx)
|
||
sessionID := req.SessionId
|
||
if sessionID == "" {
|
||
sessionID = fmt.Sprintf("%d_%d", userID, starID)
|
||
}
|
||
|
||
if err != nil {
|
||
logger.Logger.Error("Failed to extract user info from attachments",
|
||
zap.Error(err),
|
||
)
|
||
return nil, err
|
||
}
|
||
|
||
logger.Logger.Info("Received InitSession request",
|
||
zap.Int64("user_id", userID),
|
||
zap.String("session_id", sessionID),
|
||
)
|
||
|
||
// 获取欢迎消息
|
||
welcomeMessage := p.chatService.GetWelcomeMessage(sessionID, userID, starID)
|
||
|
||
return &pb.InitSessionResponse{
|
||
WelcomeMessage: welcomeMessage,
|
||
SessionId: sessionID,
|
||
}, nil
|
||
}
|
||
|
||
// SendMessage 发送消息(流式返回)
|
||
func (p *AIChatProvider) SendMessage(ctx context.Context, req *pb.ChatMessageRequest, stream pb.AIChatService_SendMessageServer) error {
|
||
userID, starID, err := extractUserInfoFromDubboAttachments(ctx)
|
||
sessionID := req.SessionId
|
||
if sessionID == "" {
|
||
sessionID = fmt.Sprintf("%d_%d", userID, starID)
|
||
}
|
||
|
||
if err != nil {
|
||
logger.Logger.Error("Failed to extract user info from attachments",
|
||
zap.Error(err),
|
||
)
|
||
stream.Send(&pb.ChatMessageResponse{
|
||
Content: "user authentication required",
|
||
SessionId: sessionID,
|
||
IsEnd: true,
|
||
Error: err.Error(),
|
||
})
|
||
return err
|
||
}
|
||
if sessionID == "" {
|
||
// 如果没有 sessionID,生成一个
|
||
sessionID = fmt.Sprintf("%d_%d", userID, starID)
|
||
}
|
||
|
||
message := req.Message
|
||
personaID := req.PersonaId
|
||
|
||
logger.Logger.Info("Received SendMessage request",
|
||
zap.Int64("user_id", userID),
|
||
zap.String("session_id", sessionID),
|
||
zap.String("message", message),
|
||
)
|
||
|
||
// 1. 前置审核
|
||
if !p.auditService.AuditText(message) {
|
||
logger.Logger.Info("Message blocked by audit")
|
||
stream.Send(&pb.ChatMessageResponse{
|
||
Content: p.auditService.DefaultSafeResponse(),
|
||
SessionId: sessionID,
|
||
IsEnd: false,
|
||
})
|
||
stream.Send(&pb.ChatMessageResponse{
|
||
SessionId: sessionID,
|
||
IsEnd: true,
|
||
})
|
||
return nil
|
||
}
|
||
|
||
// 2. 获取人设
|
||
persona, err := p.personaService.GetPersonaOrDefault(ctx, userID, personaID)
|
||
if err != nil {
|
||
logger.Logger.Error("Failed to get persona", zap.Error(err))
|
||
stream.Send(&pb.ChatMessageResponse{
|
||
Content: err.Error(),
|
||
IsEnd: true,
|
||
Error: err.Error(),
|
||
})
|
||
return err
|
||
}
|
||
|
||
// 3. 记忆召回
|
||
memoryText, _ := p.memoryService.RecallMemories(ctx, userID, message, 5)
|
||
|
||
// 4. 获取对话历史
|
||
history, _ := p.memoryService.GetContext(ctx, sessionID)
|
||
|
||
// 5. 构建 Prompt
|
||
tokenizer := &service.Tokenizer{}
|
||
messages, _ := service.BuildPrompt(
|
||
persona.SystemPrompt,
|
||
memoryText,
|
||
history,
|
||
message,
|
||
tokenizer,
|
||
)
|
||
|
||
// 6. 检查是否需要调用大模型
|
||
if service.IsNoNeedLLMCall(message) {
|
||
stream.Send(&pb.ChatMessageResponse{
|
||
Content: "好的,我听到了。",
|
||
SessionId: sessionID,
|
||
IsEnd: false,
|
||
})
|
||
stream.Send(&pb.ChatMessageResponse{
|
||
SessionId: sessionID,
|
||
IsEnd: true,
|
||
})
|
||
return nil
|
||
}
|
||
|
||
// 7. 调用大模型(流式)
|
||
streamReader, err := p.chatService.LLMService.StreamChat(ctx, messages)
|
||
if err != nil {
|
||
logger.Logger.Error("LLM call failed", zap.Error(err))
|
||
|
||
// 检查是否是敏感内容错误
|
||
if _, ok := err.(*service.SensitiveContentError); ok {
|
||
logger.Logger.Info("Content blocked by MiniMax safety filter, trying backup model")
|
||
// 尝试备用模型
|
||
streamReader, err = p.chatService.LLMService.StreamChatWithBackup(ctx, messages)
|
||
if err != nil {
|
||
// 备用模型也失败
|
||
logger.Logger.Error("Backup model also failed", zap.Error(err))
|
||
stream.Send(&pb.ChatMessageResponse{
|
||
Content: p.auditService.DefaultSafeResponse(),
|
||
SessionId: sessionID,
|
||
IsEnd: true,
|
||
})
|
||
return nil
|
||
}
|
||
} else {
|
||
// 其他错误,尝试备用模型
|
||
streamReader, err = p.chatService.LLMService.StreamChatWithBackup(ctx, messages)
|
||
if err != nil {
|
||
logger.Logger.Error("Backup model also failed", zap.Error(err))
|
||
stream.Send(&pb.ChatMessageResponse{
|
||
Content: "抱歉,服务暂时不可用",
|
||
SessionId: sessionID,
|
||
IsEnd: true,
|
||
Error: err.Error(),
|
||
})
|
||
return err
|
||
}
|
||
}
|
||
}
|
||
defer streamReader.Close()
|
||
|
||
// 8. 流式处理
|
||
var fullResponse string
|
||
var sentEnd = false
|
||
for {
|
||
content, done, err := streamReader.Next()
|
||
if err != nil {
|
||
if err == io.EOF {
|
||
// 流结束,发送 is_end
|
||
if !sentEnd {
|
||
stream.Send(&pb.ChatMessageResponse{
|
||
SessionId: sessionID,
|
||
IsEnd: true,
|
||
})
|
||
sentEnd = true
|
||
}
|
||
break
|
||
}
|
||
logger.Logger.Error("Stream read error", zap.Error(err))
|
||
break
|
||
}
|
||
|
||
// 后置审核(逐 token)
|
||
if !p.auditService.AuditResponse(content) {
|
||
logger.Logger.Info("Response blocked by audit")
|
||
streamReader.Close()
|
||
// 发送安全回复作为替代
|
||
stream.Send(&pb.ChatMessageResponse{
|
||
Content: p.auditService.DefaultSafeResponse(),
|
||
SessionId: sessionID,
|
||
IsEnd: false,
|
||
})
|
||
stream.Send(&pb.ChatMessageResponse{
|
||
SessionId: sessionID,
|
||
IsEnd: true,
|
||
})
|
||
sentEnd = true
|
||
return nil
|
||
}
|
||
|
||
fullResponse += content
|
||
|
||
// 发送 token 给客户端
|
||
if err := stream.Send(&pb.ChatMessageResponse{
|
||
Content: content,
|
||
SessionId: sessionID,
|
||
IsEnd: done,
|
||
}); err != nil {
|
||
logger.Logger.Error("Failed to send message to stream", zap.Error(err))
|
||
return err
|
||
}
|
||
if done {
|
||
sentEnd = true
|
||
}
|
||
}
|
||
|
||
// 9. 保存上下文
|
||
newHistory := append(history, model.Message{Role: "user", Content: message})
|
||
newHistory = append(newHistory, model.Message{Role: "assistant", Content: fullResponse})
|
||
p.memoryService.SaveContext(ctx, sessionID, newHistory, personaID)
|
||
|
||
// 10. 触发记忆提取(每5轮)
|
||
newTurns := len(newHistory) / 2
|
||
logger.Logger.Info("Memory extraction check",
|
||
zap.Int("message_count", len(newHistory)),
|
||
zap.Int("turns", newTurns),
|
||
zap.Bool("should_extract", newTurns >= 5),
|
||
)
|
||
if newTurns >= 5 {
|
||
logger.Logger.Info("Triggering memory extraction", zap.Int64("user_id", userID))
|
||
if err := p.memoryService.ExtractMemory(ctx, userID, newHistory); err != nil {
|
||
logger.Logger.Error("Failed to extract memory", zap.Error(err))
|
||
} else {
|
||
logger.Logger.Info("Memory extracted successfully", zap.Int64("user_id", userID))
|
||
}
|
||
}
|
||
|
||
logger.Logger.Info("SendMessage completed",
|
||
zap.Int64("user_id", userID),
|
||
zap.String("session_id", sessionID),
|
||
zap.Int("response_length", len(fullResponse)),
|
||
)
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetHistory 获取对话历史
|
||
func (p *AIChatProvider) GetHistory(ctx context.Context, req *pb.ChatHistoryRequest) (*pb.ChatHistoryResponse, error) {
|
||
userID, starID, err := extractUserInfoFromDubboAttachments(ctx)
|
||
if err != nil {
|
||
logger.Logger.Error("Failed to extract user info from attachments",
|
||
zap.Error(err),
|
||
)
|
||
return nil, err
|
||
}
|
||
|
||
sessionID := req.SessionId
|
||
if sessionID == "" {
|
||
sessionID = fmt.Sprintf("%d_%d", userID, starID)
|
||
}
|
||
|
||
logger.Logger.Info("Received GetHistory request",
|
||
zap.Int64("user_id", userID),
|
||
zap.String("session_id", sessionID),
|
||
)
|
||
|
||
messages, err := p.memoryService.GetContext(ctx, sessionID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
pbMessages := make([]*pb.Message, len(messages))
|
||
for i, m := range messages {
|
||
pbMessages[i] = &pb.Message{
|
||
Role: m.Role,
|
||
Content: m.Content,
|
||
}
|
||
}
|
||
|
||
return &pb.ChatHistoryResponse{
|
||
History: pbMessages,
|
||
}, nil
|
||
}
|
||
|
||
// GetPersonas 获取用户的所有人设
|
||
func (p *AIChatProvider) GetPersonas(ctx context.Context, req *pb.GetPersonasRequest) (*pb.PersonaListResponse, error) {
|
||
userID := req.UserId
|
||
|
||
logger.Logger.Info("Received GetPersonas request",
|
||
zap.Int64("user_id", userID),
|
||
)
|
||
|
||
personas, err := p.personaService.GetPersonas(ctx, userID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
pbPersonas := make([]*pb.PersonaInfo, len(personas))
|
||
for i, persona := range personas {
|
||
pbPersonas[i] = &pb.PersonaInfo{
|
||
Id: persona.ID,
|
||
Name: persona.Name,
|
||
Description: persona.Description,
|
||
AvatarUrl: persona.AvatarURL,
|
||
TalkStyle: persona.TalkStyle,
|
||
IsDefault: persona.IsDefault,
|
||
CreatedAt: persona.CreatedAt,
|
||
UpdatedAt: persona.UpdatedAt,
|
||
}
|
||
}
|
||
|
||
return &pb.PersonaListResponse{
|
||
Personas: pbPersonas,
|
||
}, nil
|
||
}
|
||
|
||
// extractUserInfoFromDubboAttachments 从 Dubbo attachments 中提取用户信息
|
||
func extractUserInfoFromDubboAttachments(ctx context.Context) (int64, int64, error) {
|
||
logger.Logger.Debug("Extracting user info from Dubbo attachments",
|
||
zap.Any("context_type", fmt.Sprintf("%T", ctx)),
|
||
)
|
||
|
||
// Try to get any value from context
|
||
if attachments := ctx.Value(constant.AttachmentKey); attachments != nil {
|
||
logger.Logger.Debug("Found attachments via constant.AttachmentKey",
|
||
zap.Any("attachments", attachments),
|
||
zap.String("type", fmt.Sprintf("%T", attachments)),
|
||
)
|
||
if attMap, ok := attachments.(map[string]interface{}); ok {
|
||
logger.Logger.Debug("Attachments map content",
|
||
zap.Any("map", attMap),
|
||
)
|
||
userID := parseIntValue(attMap["user_id"])
|
||
starID := parseIntValue(attMap["star_id"])
|
||
logger.Logger.Debug("Parsed user info from attachments",
|
||
zap.Any("user_id_raw", attMap["user_id"]),
|
||
zap.Int64("user_id", userID),
|
||
zap.Any("star_id_raw", attMap["star_id"]),
|
||
zap.Int64("star_id", starID),
|
||
)
|
||
if userID > 0 && starID > 0 {
|
||
return userID, starID, nil
|
||
}
|
||
logger.Logger.Warn("Parsed user_id or star_id is zero",
|
||
zap.Int64("user_id", userID),
|
||
zap.Int64("star_id", starID),
|
||
)
|
||
} else {
|
||
logger.Logger.Warn("Attachments is not map[string]interface{}",
|
||
zap.String("actual_type", fmt.Sprintf("%T", attachments)),
|
||
)
|
||
}
|
||
} else {
|
||
logger.Logger.Warn("ctx.Value(constant.AttachmentKey) returned nil",
|
||
zap.String("constant_attachment_key", string(constant.AttachmentKey)),
|
||
)
|
||
}
|
||
|
||
// Debug: list all keys in context
|
||
logger.Logger.Warn("Checking alternative key: 'attachment'")
|
||
if val := ctx.Value("attachment"); val != nil {
|
||
logger.Logger.Debug("Found value with key 'attachment'",
|
||
zap.Any("value", val),
|
||
zap.String("type", fmt.Sprintf("%T", val)),
|
||
)
|
||
}
|
||
|
||
return 0, 0, fmt.Errorf("user info not found in Dubbo attachments")
|
||
}
|
||
|
||
// parseIntValue converts an attachment value into an int64.
//
// Supported inputs:
//   - int, int32, int64: converted directly;
//   - float64: truncated toward zero; NaN and values outside the int64 range
//     yield 0, because Go leaves such conversions implementation-defined;
//   - string: must be a complete base-10 integer after trimming surrounding
//     whitespace ("12abc" yields 0 rather than 12);
//   - []string and []interface{}: the first element is parsed recursively, as
//     multi-valued attachments arrive as slices.
//
// Any other input, an empty slice or an unparsable string yields 0, which
// callers treat as "missing".
func parseIntValue(v interface{}) int64 {
	switch val := v.(type) {
	case int64:
		return val
	case int:
		return int64(val)
	case int32:
		return int64(val)
	case float64:
		// NaN fails both comparisons, so it is rejected here as well.
		if !(val >= -(1<<63) && val < 1<<63) {
			return 0
		}
		return int64(val)
	case string:
		n, err := strconv.ParseInt(strings.TrimSpace(val), 10, 64)
		if err != nil {
			return 0
		}
		return n
	case []string:
		if len(val) > 0 {
			return parseIntValue(val[0])
		}
	case []interface{}:
		if len(val) > 0 {
			return parseIntValue(val[0])
		}
	}
	return 0
}
|