478 lines
11 KiB
Go
478 lines
11 KiB
Go
package socket
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"net/http"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"dubbo.apache.org/dubbo-go/v3/client"
|
||
"dubbo.apache.org/dubbo-go/v3/common/constant"
|
||
"github.com/gorilla/websocket"
|
||
"github.com/topfans/backend/pkg/jwt"
|
||
"github.com/topfans/backend/pkg/logger"
|
||
"go.uber.org/zap"
|
||
|
||
pbAIChat "github.com/topfans/backend/pkg/proto/ai_chat"
|
||
)
|
||
|
||
var upgrader = websocket.Upgrader{
|
||
ReadBufferSize: 1024,
|
||
WriteBufferSize: 1024,
|
||
CheckOrigin: func(r *http.Request) bool {
|
||
return true // 允许跨域
|
||
},
|
||
}
|
||
|
||
// Hub 管理所有 AI Chat WebSocket 连接
|
||
type Hub struct {
|
||
// 用户连接映射: userId -> *Connection
|
||
clients map[int64]*Connection
|
||
|
||
// Dubbo 客户端
|
||
aiChatClient *client.Client
|
||
|
||
// WebSocket 配置
|
||
aiChatPath string
|
||
|
||
mu sync.RWMutex
|
||
}
|
||
|
||
// Connection WebSocket 连接
|
||
type Connection struct {
|
||
UserID int64
|
||
StarID int64
|
||
Conn *websocket.Conn
|
||
Send chan []byte
|
||
Hub *Hub
|
||
writeMu sync.Mutex // 保护并发写操作
|
||
}
|
||
|
||
// writeJSON 线程安全的 WriteJSON 封装
|
||
func (c *Connection) writeJSON(data interface{}) error {
|
||
c.writeMu.Lock()
|
||
defer c.writeMu.Unlock()
|
||
return c.Conn.WriteJSON(data)
|
||
}
|
||
|
||
// sendError 发送错误消息
|
||
func (c *Connection) sendError(code, message string) {
|
||
c.Send <- []byte(fmt.Sprintf(`{"type":"error","code":"%s","message":"%s","session_id":"%d_%d"}`,
|
||
code, message, c.UserID, c.StarID))
|
||
}
|
||
|
||
// NewHub 创建 Hub 实例
|
||
func NewHub(aiChatClient *client.Client, aiChatPath string) *Hub {
|
||
return &Hub{
|
||
clients: make(map[int64]*Connection),
|
||
aiChatClient: aiChatClient,
|
||
aiChatPath: aiChatPath,
|
||
}
|
||
}
|
||
|
||
// HandleWebSocket 处理 WebSocket 连接
|
||
func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||
// 从 URL 参数获取 token
|
||
token := r.URL.Query().Get("token")
|
||
if token == "" {
|
||
// 尝试从 query 参数获取(URL 编码后)
|
||
token = r.URL.Query().Get("token")
|
||
}
|
||
|
||
// 解析 token 获取用户信息
|
||
userID, starID, err := h.validateToken(token)
|
||
if err != nil {
|
||
logger.Logger.Error("WebSocket token validation failed", zap.Error(err))
|
||
// 返回 401 Unauthorized,而不是尝试升级 WebSocket
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||
"type": "auth_response",
|
||
"success": false,
|
||
"error": "invalid_token",
|
||
})
|
||
return
|
||
}
|
||
|
||
// 升级为 WebSocket 连接
|
||
conn, err := upgrader.Upgrade(w, r, nil)
|
||
if err != nil {
|
||
logger.Logger.Error("WebSocket upgrade failed", zap.Error(err))
|
||
return
|
||
}
|
||
|
||
connection := &Connection{
|
||
UserID: userID,
|
||
StarID: starID,
|
||
Conn: conn,
|
||
Send: make(chan []byte, 256),
|
||
Hub: h,
|
||
}
|
||
|
||
// 注册连接
|
||
h.mu.Lock()
|
||
h.clients[userID] = connection
|
||
h.mu.Unlock()
|
||
|
||
logger.Logger.Info("WebSocket connection established",
|
||
zap.Int64("user_id", userID),
|
||
zap.Int64("star_id", starID),
|
||
)
|
||
|
||
// 发送鉴权成功响应
|
||
authResp := map[string]interface{}{
|
||
"type": "auth_response",
|
||
"success": true,
|
||
"user_id": userID,
|
||
"star_id": starID,
|
||
}
|
||
conn.WriteJSON(authResp)
|
||
|
||
// 启动读 goroutine
|
||
go connection.readPump()
|
||
|
||
// 启动写 goroutine
|
||
go connection.writePump()
|
||
}
|
||
|
||
// validateToken 验证 token(使用 JWT 解析)
|
||
func (h *Hub) validateToken(token string) (int64, int64, error) {
|
||
// 去掉 "Bearer_" 前缀
|
||
if strings.HasPrefix(token, "Bearer_") {
|
||
token = strings.TrimPrefix(token, "Bearer_")
|
||
}
|
||
|
||
// 解析 JWT token
|
||
claims, err := jwt.ParseToken(token)
|
||
if err != nil {
|
||
return 0, 0, fmt.Errorf("failed to parse token: %w", err)
|
||
}
|
||
|
||
userID := claims.UserID
|
||
starID := claims.StarID
|
||
|
||
if userID == 0 {
|
||
return 0, 0, errors.New("invalid user id")
|
||
}
|
||
|
||
return userID, starID, nil
|
||
}
|
||
|
||
// readPump 读取客户端消息
|
||
func (c *Connection) readPump() {
|
||
defer func() {
|
||
c.Hub.mu.Lock()
|
||
delete(c.Hub.clients, c.UserID)
|
||
c.Hub.mu.Unlock()
|
||
c.Conn.Close()
|
||
}()
|
||
|
||
c.Conn.SetReadLimit(512 * 1024) // 512KB
|
||
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||
c.Conn.SetPongHandler(func(string) error {
|
||
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||
return nil
|
||
})
|
||
|
||
for {
|
||
_, message, err := c.Conn.ReadMessage()
|
||
if err != nil {
|
||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||
logger.Logger.Error("WebSocket read error", zap.Error(err))
|
||
}
|
||
break
|
||
}
|
||
|
||
// 解析消息
|
||
var msg map[string]interface{}
|
||
if err := json.Unmarshal(message, &msg); err != nil {
|
||
logger.Logger.Error("Failed to parse message", zap.Error(err))
|
||
continue
|
||
}
|
||
|
||
// 处理消息
|
||
c.handleMessage(msg)
|
||
}
|
||
}
|
||
|
||
// writePump 写消息到客户端
|
||
func (c *Connection) writePump() {
|
||
ticker := time.NewTicker(30 * time.Second)
|
||
defer func() {
|
||
ticker.Stop()
|
||
c.Conn.Close()
|
||
}()
|
||
|
||
for {
|
||
select {
|
||
case message, ok := <-c.Send:
|
||
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||
if !ok {
|
||
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||
return
|
||
}
|
||
|
||
w, err := c.Conn.NextWriter(websocket.TextMessage)
|
||
if err != nil {
|
||
return
|
||
}
|
||
w.Write(message)
|
||
|
||
if err := w.Close(); err != nil {
|
||
return
|
||
}
|
||
|
||
case <-ticker.C:
|
||
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||
return
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// handleMessage 处理客户端消息
|
||
func (c *Connection) handleMessage(msg map[string]interface{}) {
|
||
action, _ := msg["action"].(string)
|
||
sessionID, _ := msg["session_id"].(string)
|
||
message, _ := msg["message"].(string)
|
||
personaID, _ := msg["persona_id"].(string)
|
||
|
||
session := sessionID
|
||
if session == "" {
|
||
session = fmt.Sprintf("%d_%d", c.UserID, c.StarID)
|
||
}
|
||
|
||
switch action {
|
||
case "ping":
|
||
// 心跳 - 通过 Send 通道发送,避免并发写
|
||
c.Send <- []byte(`{"type":"pong"}`)
|
||
|
||
case "init_session":
|
||
// 初始化会话,返回欢迎消息
|
||
go c.initSession(session)
|
||
|
||
case "send_message":
|
||
// 发送消息
|
||
go c.sendMessage(session, message, personaID)
|
||
|
||
case "get_history":
|
||
// 获取历史
|
||
go c.getHistory(session)
|
||
|
||
case "get_personas":
|
||
// 获取人设列表
|
||
go c.getPersonas()
|
||
|
||
default:
|
||
logger.Logger.Warn("Unknown action", zap.String("action", action))
|
||
}
|
||
}
|
||
|
||
// initSession 初始化会话,返回欢迎消息
|
||
func (c *Connection) initSession(sessionID string) {
|
||
aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient)
|
||
if err != nil {
|
||
logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err))
|
||
c.sendError("SERVICE_ERROR", "服务暂时不可用")
|
||
return
|
||
}
|
||
|
||
// 创建带 Attachments 的 context
|
||
userIDStr := strconv.FormatInt(c.UserID, 10)
|
||
starIDStr := strconv.FormatInt(c.StarID, 10)
|
||
attachments := map[string]interface{}{
|
||
"user_id": userIDStr,
|
||
"star_id": starIDStr,
|
||
}
|
||
ctx := context.WithValue(context.Background(), constant.AttachmentKey, attachments)
|
||
|
||
req := &pbAIChat.InitSessionRequest{
|
||
SessionId: sessionID,
|
||
UserId: c.UserID,
|
||
}
|
||
|
||
resp, err := aiChatService.InitSession(ctx, req)
|
||
if err != nil {
|
||
logger.Logger.Error("InitSession call failed", zap.Error(err))
|
||
c.sendError("SERVICE_ERROR", "初始化会话失败")
|
||
return
|
||
}
|
||
|
||
// 发送欢迎消息
|
||
c.writeJSON(map[string]interface{}{
|
||
"type": "message",
|
||
"session_id": sessionID,
|
||
"content": resp.WelcomeMessage,
|
||
"is_end": true,
|
||
})
|
||
}
|
||
|
||
// sendMessage 发送消息到 AI Chat Service
|
||
func (c *Connection) sendMessage(sessionID, message, personaID string) {
|
||
// 创建 Dubbo 客户端调用
|
||
aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient)
|
||
if err != nil {
|
||
logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err))
|
||
c.sendError("SERVICE_ERROR", "服务暂时不可用")
|
||
return
|
||
}
|
||
|
||
// 创建带 Attachments 的 context
|
||
userIDStr := strconv.FormatInt(c.UserID, 10)
|
||
starIDStr := strconv.FormatInt(c.StarID, 10)
|
||
attachments := map[string]interface{}{
|
||
"user_id": userIDStr,
|
||
"star_id": starIDStr,
|
||
}
|
||
ctx := context.WithValue(context.Background(), constant.AttachmentKey, attachments)
|
||
|
||
req := &pbAIChat.ChatMessageRequest{
|
||
SessionId: sessionID,
|
||
Message: message,
|
||
PersonaId: personaID,
|
||
UserId: c.UserID,
|
||
}
|
||
|
||
// 流式调用
|
||
stream, err := aiChatService.SendMessage(ctx, req)
|
||
if err != nil {
|
||
logger.Logger.Error("SendMessage call failed", zap.Error(err))
|
||
c.sendError("SERVICE_ERROR", "消息发送失败")
|
||
return
|
||
}
|
||
|
||
// 接收流式响应
|
||
for {
|
||
if !stream.Recv() {
|
||
break
|
||
}
|
||
resp := stream.Msg()
|
||
|
||
// 发送消息到客户端
|
||
respMsg := map[string]interface{}{
|
||
"type": "message",
|
||
"content": resp.Content,
|
||
"session_id": resp.SessionId,
|
||
"is_end": resp.IsEnd,
|
||
}
|
||
if resp.Error != "" {
|
||
respMsg["error"] = resp.Error
|
||
}
|
||
|
||
if err := c.writeJSON(respMsg); err != nil {
|
||
logger.Logger.Error("Failed to send message to client", zap.Error(err))
|
||
break
|
||
}
|
||
|
||
if resp.IsEnd {
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// getHistory 获取对话历史
|
||
func (c *Connection) getHistory(sessionID string) {
|
||
aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient)
|
||
if err != nil {
|
||
logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err))
|
||
c.sendError("SERVICE_ERROR", "服务暂时不可用")
|
||
return
|
||
}
|
||
|
||
// 创建带 Attachments 的 context
|
||
userIDStr := strconv.FormatInt(c.UserID, 10)
|
||
starIDStr := strconv.FormatInt(c.StarID, 10)
|
||
attachments := map[string]interface{}{
|
||
"user_id": userIDStr,
|
||
"star_id": starIDStr,
|
||
}
|
||
ctx := context.WithValue(context.Background(), constant.AttachmentKey, attachments)
|
||
|
||
req := &pbAIChat.ChatHistoryRequest{
|
||
SessionId: sessionID,
|
||
Limit: 20,
|
||
}
|
||
|
||
resp, err := aiChatService.GetHistory(ctx, req)
|
||
if err != nil {
|
||
logger.Logger.Error("GetHistory call failed", zap.Error(err))
|
||
c.sendError("SERVICE_ERROR", "获取历史失败")
|
||
return
|
||
}
|
||
|
||
// 转换历史消息
|
||
history := make([]map[string]string, len(resp.History))
|
||
for i, m := range resp.History {
|
||
history[i] = map[string]string{
|
||
"role": m.Role,
|
||
"content": m.Content,
|
||
}
|
||
}
|
||
|
||
// 发送响应
|
||
c.writeJSON(map[string]interface{}{
|
||
"type": "history_response",
|
||
"session_id": sessionID,
|
||
"history": history,
|
||
})
|
||
}
|
||
|
||
// getPersonas 获取人设列表
|
||
func (c *Connection) getPersonas() {
|
||
aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient)
|
||
if err != nil {
|
||
logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err))
|
||
c.sendError("SERVICE_ERROR", "服务暂时不可用")
|
||
return
|
||
}
|
||
|
||
ctx := context.Background()
|
||
req := &pbAIChat.GetPersonasRequest{
|
||
UserId: c.UserID,
|
||
}
|
||
|
||
resp, err := aiChatService.GetPersonas(ctx, req)
|
||
if err != nil {
|
||
logger.Logger.Error("GetPersonas call failed", zap.Error(err))
|
||
c.sendError("SERVICE_ERROR", "获取人设失败")
|
||
return
|
||
}
|
||
|
||
// 转换人设信息
|
||
personas := make([]map[string]interface{}, len(resp.Personas))
|
||
for i, p := range resp.Personas {
|
||
personas[i] = map[string]interface{}{
|
||
"id": p.Id,
|
||
"name": p.Name,
|
||
"description": p.Description,
|
||
"avatar_url": p.AvatarUrl,
|
||
"talk_style": p.TalkStyle,
|
||
"is_default": p.IsDefault,
|
||
"created_at": p.CreatedAt,
|
||
"updated_at": p.UpdatedAt,
|
||
}
|
||
}
|
||
|
||
// 发送响应
|
||
c.writeJSON(map[string]interface{}{
|
||
"type": "personas_response",
|
||
"personas": personas,
|
||
})
|
||
}
|
||
|
||
// Note: sendError, which sends error messages through the Send channel, is defined earlier in this file.
|
||
|
||
// Close 关闭连接
|
||
func (h *Hub) Close() {
|
||
h.mu.Lock()
|
||
defer h.mu.Unlock()
|
||
|
||
for _, conn := range h.clients {
|
||
conn.Conn.Close()
|
||
}
|
||
} |