package socket import ( "context" "encoding/json" "errors" "fmt" "net/http" "strconv" "strings" "sync" "time" "dubbo.apache.org/dubbo-go/v3/client" "dubbo.apache.org/dubbo-go/v3/common/constant" "github.com/gorilla/websocket" "github.com/topfans/backend/pkg/jwt" "github.com/topfans/backend/pkg/logger" "go.uber.org/zap" pbAIChat "github.com/topfans/backend/pkg/proto/ai_chat" ) var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { return true // 允许跨域 }, } // Hub 管理所有 AI Chat WebSocket 连接 type Hub struct { // 用户连接映射: userId -> *Connection clients map[int64]*Connection // Dubbo 客户端 aiChatClient *client.Client // WebSocket 配置 aiChatPath string mu sync.RWMutex } // Connection WebSocket 连接 type Connection struct { UserID int64 StarID int64 Conn *websocket.Conn Send chan []byte Hub *Hub writeMu sync.Mutex // 保护并发写操作 } // writeJSON 线程安全的 WriteJSON 封装 func (c *Connection) writeJSON(data interface{}) error { c.writeMu.Lock() defer c.writeMu.Unlock() return c.Conn.WriteJSON(data) } // sendError 发送错误消息 func (c *Connection) sendError(code, message string) { c.Send <- []byte(fmt.Sprintf(`{"type":"error","code":"%s","message":"%s","session_id":"%d_%d"}`, code, message, c.UserID, c.StarID)) } // NewHub 创建 Hub 实例 func NewHub(aiChatClient *client.Client, aiChatPath string) *Hub { return &Hub{ clients: make(map[int64]*Connection), aiChatClient: aiChatClient, aiChatPath: aiChatPath, } } // HandleWebSocket 处理 WebSocket 连接 func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { // 从 URL 参数获取 token token := r.URL.Query().Get("token") if token == "" { // 尝试从 query 参数获取(URL 编码后) token = r.URL.Query().Get("token") } // 解析 token 获取用户信息 userID, starID, err := h.validateToken(token) if err != nil { logger.Logger.Error("WebSocket token validation failed", zap.Error(err)) // 返回 401 Unauthorized,而不是尝试升级 WebSocket w.WriteHeader(http.StatusUnauthorized) json.NewEncoder(w).Encode(map[string]interface{}{ "type": "auth_response", "success": false, "error": "invalid_token", }) return } // 升级为 WebSocket 连接 conn, err := upgrader.Upgrade(w, r, nil) if err != nil { logger.Logger.Error("WebSocket upgrade failed", zap.Error(err)) return } connection := &Connection{ UserID: userID, StarID: starID, Conn: conn, Send: make(chan []byte, 256), Hub: h, } // 注册连接 h.mu.Lock() h.clients[userID] = connection h.mu.Unlock() logger.Logger.Info("WebSocket connection established", zap.Int64("user_id", userID), zap.Int64("star_id", starID), ) // 发送鉴权成功响应 authResp := map[string]interface{}{ "type": "auth_response", "success": true, "user_id": userID, "star_id": starID, } conn.WriteJSON(authResp) // 启动读 goroutine go connection.readPump() // 启动写 goroutine go connection.writePump() } // validateToken 验证 token(使用 JWT 解析) func (h *Hub) validateToken(token string) (int64, int64, error) { // 去掉 "Bearer_" 前缀 if strings.HasPrefix(token, "Bearer_") { token = strings.TrimPrefix(token, "Bearer_") } // 解析 JWT token claims, err := jwt.ParseToken(token) if err != nil { return 0, 0, fmt.Errorf("failed to parse token: %w", err) } userID := claims.UserID starID := claims.StarID if userID == 0 { return 0, 0, errors.New("invalid user id") } return userID, starID, nil } // readPump 读取客户端消息 func (c *Connection) readPump() { defer func() { c.Hub.mu.Lock() delete(c.Hub.clients, c.UserID) c.Hub.mu.Unlock() c.Conn.Close() }() c.Conn.SetReadLimit(512 * 1024) // 512KB c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second)) c.Conn.SetPongHandler(func(string) error { c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second)) return nil }) for { _, message, err := c.Conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { logger.Logger.Error("WebSocket read error", zap.Error(err)) } break } // 解析消息 var msg map[string]interface{} if err := json.Unmarshal(message, &msg); err != nil { logger.Logger.Error("Failed to parse message", zap.Error(err)) continue } // 处理消息 c.handleMessage(msg) } } // writePump 写消息到客户端 func (c *Connection) writePump() { ticker := time.NewTicker(30 * time.Second) defer func() { ticker.Stop() c.Conn.Close() }() for { select { case message, ok := <-c.Send: c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if !ok { c.Conn.WriteMessage(websocket.CloseMessage, []byte{}) return } w, err := c.Conn.NextWriter(websocket.TextMessage) if err != nil { return } w.Write(message) if err := w.Close(); err != nil { return } case <-ticker.C: c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } } } } // handleMessage 处理客户端消息 func (c *Connection) handleMessage(msg map[string]interface{}) { action, _ := msg["action"].(string) sessionID, _ := msg["session_id"].(string) message, _ := msg["message"].(string) personaID, _ := msg["persona_id"].(string) session := sessionID if session == "" { session = fmt.Sprintf("%d_%d", c.UserID, c.StarID) } switch action { case "ping": // 心跳 - 通过 Send 通道发送,避免并发写 c.Send <- []byte(`{"type":"pong"}`) case "init_session": // 初始化会话,返回欢迎消息 go c.initSession(session) case "send_message": // 发送消息 go c.sendMessage(session, message, personaID) case "get_history": // 获取历史 go c.getHistory(session) case "get_personas": // 获取人设列表 go c.getPersonas() default: logger.Logger.Warn("Unknown action", zap.String("action", action)) } } // initSession 初始化会话,返回欢迎消息 func (c *Connection) initSession(sessionID string) { aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient) if err != nil { logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err)) c.sendError("SERVICE_ERROR", "服务暂时不可用") return } // 创建带 Attachments 的 context userIDStr := strconv.FormatInt(c.UserID, 10) starIDStr := strconv.FormatInt(c.StarID, 10) attachments := map[string]interface{}{ "user_id": userIDStr, "star_id": starIDStr, } ctx := context.WithValue(context.Background(), constant.AttachmentKey, attachments) req := &pbAIChat.InitSessionRequest{ SessionId: sessionID, UserId: c.UserID, } resp, err := aiChatService.InitSession(ctx, req) if err != nil { logger.Logger.Error("InitSession call failed", zap.Error(err)) c.sendError("SERVICE_ERROR", "初始化会话失败") return } // 发送欢迎消息 c.writeJSON(map[string]interface{}{ "type": "message", "session_id": sessionID, "content": resp.WelcomeMessage, "is_end": true, }) } // sendMessage 发送消息到 AI Chat Service func (c *Connection) sendMessage(sessionID, message, personaID string) { // 创建 Dubbo 客户端调用 aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient) if err != nil { logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err)) c.sendError("SERVICE_ERROR", "服务暂时不可用") return } // 创建带 Attachments 的 context userIDStr := strconv.FormatInt(c.UserID, 10) starIDStr := strconv.FormatInt(c.StarID, 10) attachments := map[string]interface{}{ "user_id": userIDStr, "star_id": starIDStr, } ctx := context.WithValue(context.Background(), constant.AttachmentKey, attachments) req := &pbAIChat.ChatMessageRequest{ SessionId: sessionID, Message: message, PersonaId: personaID, UserId: c.UserID, } // 流式调用 stream, err := aiChatService.SendMessage(ctx, req) if err != nil { logger.Logger.Error("SendMessage call failed", zap.Error(err)) c.sendError("SERVICE_ERROR", "消息发送失败") return } // 接收流式响应 for { if !stream.Recv() { break } resp := stream.Msg() // 发送消息到客户端 respMsg := map[string]interface{}{ "type": "message", "content": resp.Content, "session_id": resp.SessionId, "is_end": resp.IsEnd, } if resp.Error != "" { respMsg["error"] = resp.Error } if err := c.writeJSON(respMsg); err != nil { logger.Logger.Error("Failed to send message to client", zap.Error(err)) break } if resp.IsEnd { break } } } // getHistory 获取对话历史 func (c *Connection) getHistory(sessionID string) { aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient) if err != nil { logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err)) c.sendError("SERVICE_ERROR", "服务暂时不可用") return } // 创建带 Attachments 的 context userIDStr := strconv.FormatInt(c.UserID, 10) starIDStr := strconv.FormatInt(c.StarID, 10) attachments := map[string]interface{}{ "user_id": userIDStr, "star_id": starIDStr, } ctx := context.WithValue(context.Background(), constant.AttachmentKey, attachments) req := &pbAIChat.ChatHistoryRequest{ SessionId: sessionID, Limit: 20, } resp, err := aiChatService.GetHistory(ctx, req) if err != nil { logger.Logger.Error("GetHistory call failed", zap.Error(err)) c.sendError("SERVICE_ERROR", "获取历史失败") return } // 转换历史消息 history := make([]map[string]string, len(resp.History)) for i, m := range resp.History { history[i] = map[string]string{ "role": m.Role, "content": m.Content, } } // 发送响应 c.writeJSON(map[string]interface{}{ "type": "history_response", "session_id": sessionID, "history": history, }) } // getPersonas 获取人设列表 func (c *Connection) getPersonas() { aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient) if err != nil { logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err)) c.sendError("SERVICE_ERROR", "服务暂时不可用") return } ctx := context.Background() req := &pbAIChat.GetPersonasRequest{ UserId: c.UserID, } resp, err := aiChatService.GetPersonas(ctx, req) if err != nil { logger.Logger.Error("GetPersonas call failed", zap.Error(err)) c.sendError("SERVICE_ERROR", "获取人设失败") return } // 转换人设信息 personas := make([]map[string]interface{}, len(resp.Personas)) for i, p := range resp.Personas { personas[i] = map[string]interface{}{ "id": p.Id, "name": p.Name, "description": p.Description, "avatar_url": p.AvatarUrl, "talk_style": p.TalkStyle, "is_default": p.IsDefault, "created_at": p.CreatedAt, "updated_at": p.UpdatedAt, } } // 发送响应 c.writeJSON(map[string]interface{}{ "type": "personas_response", "personas": personas, }) } // sendError 发送错误消息 (使用 Send 通道) // Close 关闭连接 func (h *Hub) Close() { h.mu.Lock() defer h.mu.Unlock() for _, conn := range h.clients { conn.Conn.Close() } }