topfans/backend/gateway/socket/ai_chat_socket.go
2026-05-28 12:00:19 +08:00

478 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package socket
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"time"
"dubbo.apache.org/dubbo-go/v3/client"
"dubbo.apache.org/dubbo-go/v3/common/constant"
"github.com/gorilla/websocket"
"github.com/topfans/backend/pkg/jwt"
"github.com/topfans/backend/pkg/logger"
"go.uber.org/zap"
pbAIChat "github.com/topfans/backend/pkg/proto/ai_chat"
)
// upgrader performs the HTTP -> WebSocket handshake for AI chat connections.
// Both I/O buffers are 1 KiB, and CheckOrigin accepts every request, so
// cross-origin handshakes are allowed.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1 << 10,
	WriteBufferSize: 1 << 10,
	// Accept any Origin header (cross-origin connections are permitted).
	CheckOrigin: func(*http.Request) bool { return true },
}
// Hub tracks the AI chat WebSocket connections and holds the Dubbo client
// used to reach the AI chat service.
type Hub struct {
// clients maps a user ID to that user's registered connection. The map holds
// one connection per user ID, so a later registration replaces an earlier one.
clients map[int64]*Connection
// aiChatClient is the Dubbo client from which AI chat service stubs are built.
aiChatClient *client.Client
// aiChatPath is the WebSocket path handed to NewHub. It is stored but not
// read by any method shown in this file.
aiChatPath string
// mu guards clients.
mu sync.RWMutex
}
// Connection is one authenticated WebSocket session.
type Connection struct {
// UserID is the user ID taken from the validated token.
UserID int64
// StarID is the star ID taken from the validated token.
StarID int64
// Conn is the underlying WebSocket connection.
Conn *websocket.Conn
// Send queues outbound text frames; writePump drains it.
Send chan []byte
// Hub is the hub this connection is registered with.
Hub *Hub
writeMu sync.Mutex // serializes the writes issued through writeJSON
}
// writeJSON encodes data as JSON and writes it to the peer as one message.
// It may be called from several goroutines; writeMu serializes the callers.
//
// The write deadline is refreshed on every call. gorilla/websocket applies
// the most recently set deadline to all later writes, and writePump sets a
// 10-second deadline each time it writes. Without the refresh, a call made
// after that deadline had passed would fail with a timeout, after which the
// connection rejects every further write.
//
// It returns the error reported by the underlying connection, if any.
func (c *Connection) writeJSON(data interface{}) error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	// Same 10-second budget that writePump uses for its own writes.
	if err := c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil {
		return err
	}
	return c.Conn.WriteJSON(data)
}
// sendError queues an error frame for the client on the Send channel.
//
// The frame is a JSON object with the fields "type" (always "error"),
// "code", "message" and "session_id" (formatted as "<UserID>_<StarID>"), in
// that order. The payload is produced by json.Marshal so that quotes,
// backslashes and control characters in code or message are escaped and the
// frame is always valid JSON.
//
// The send blocks while the Send buffer is full.
func (c *Connection) sendError(code, message string) {
	payload, err := json.Marshal(struct {
		Type      string `json:"type"`
		Code      string `json:"code"`
		Message   string `json:"message"`
		SessionID string `json:"session_id"`
	}{
		Type:      "error",
		Code:      code,
		Message:   message,
		SessionID: fmt.Sprintf("%d_%d", c.UserID, c.StarID),
	})
	if err != nil {
		// Not expected for a struct of plain strings; log rather than queue
		// a malformed frame.
		logger.Logger.Error("Failed to encode error message", zap.Error(err))
		return
	}
	c.Send <- payload
}
// NewHub returns a Hub with an empty connection registry that uses
// aiChatClient for AI chat service calls and remembers aiChatPath.
func NewHub(aiChatClient *client.Client, aiChatPath string) *Hub {
	hub := new(Hub)
	hub.clients = map[int64]*Connection{}
	hub.aiChatClient = aiChatClient
	hub.aiChatPath = aiChatPath
	return hub
}
// HandleWebSocket authenticates an incoming request, upgrades it to a
// WebSocket connection and starts the connection's read and write pumps.
//
// The token is read from the "token" URL query parameter. When validation
// fails, the request is NOT upgraded: the handler answers with HTTP 401 and a
// JSON body {"type":"auth_response","success":false,"error":"invalid_token"}.
//
// After a successful upgrade the client first receives an auth_response
// frame carrying user_id and star_id. If that frame cannot be written, the
// connection is closed without being registered. Otherwise the connection is
// stored in the hub under its user ID, replacing any entry already stored for
// that user, and the pumps are started.
func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
	// Query().Get already returns the URL-decoded value, so one lookup is
	// sufficient.
	token := r.URL.Query().Get("token")

	userID, starID, err := h.validateToken(token)
	if err != nil {
		logger.Logger.Error("WebSocket token validation failed", zap.Error(err))
		// Reply with 401 instead of attempting the upgrade. Headers must be
		// set before WriteHeader.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusUnauthorized)
		encErr := json.NewEncoder(w).Encode(map[string]interface{}{
			"type":    "auth_response",
			"success": false,
			"error":   "invalid_token",
		})
		if encErr != nil {
			logger.Logger.Error("Failed to write auth failure response", zap.Error(encErr))
		}
		return
	}

	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		logger.Logger.Error("WebSocket upgrade failed", zap.Error(err))
		return
	}

	connection := &Connection{
		UserID: userID,
		StarID: starID,
		Conn:   conn,
		Send:   make(chan []byte, 256),
		Hub:    h,
	}

	// Confirm authentication before the connection becomes visible to the
	// rest of the hub; a peer that cannot even receive this frame is dropped.
	authResp := map[string]interface{}{
		"type":    "auth_response",
		"success": true,
		"user_id": userID,
		"star_id": starID,
	}
	if err := connection.writeJSON(authResp); err != nil {
		logger.Logger.Error("Failed to send auth response",
			zap.Int64("user_id", userID),
			zap.Error(err),
		)
		conn.Close()
		return
	}

	// Register the connection.
	h.mu.Lock()
	h.clients[userID] = connection
	h.mu.Unlock()

	logger.Logger.Info("WebSocket connection established",
		zap.Int64("user_id", userID),
		zap.Int64("star_id", starID),
	)

	go connection.readPump()
	go connection.writePump()
}
// validateToken parses token as a JWT and returns the user ID and star ID
// from its claims. A leading "Bearer_" is stripped before parsing. An error
// is returned when parsing fails or when the user ID in the claims is zero.
func (h *Hub) validateToken(token string) (int64, int64, error) {
	// TrimPrefix leaves the string untouched when the prefix is absent.
	claims, err := jwt.ParseToken(strings.TrimPrefix(token, "Bearer_"))
	if err != nil {
		return 0, 0, fmt.Errorf("failed to parse token: %w", err)
	}
	if claims.UserID == 0 {
		return 0, 0, errors.New("invalid user id")
	}
	return claims.UserID, claims.StarID, nil
}
// readPump reads client frames until the connection fails or is closed,
// decoding each one as a JSON object and passing it to handleMessage. Frames
// that are not valid JSON are logged and skipped.
//
// Incoming messages are limited to 512 KiB. The read deadline starts at 60
// seconds and is pushed out by another 60 seconds whenever a pong arrives.
//
// On exit the connection is closed and removed from the hub, but only if the
// hub still points at this connection: when the same user has reconnected,
// the registry entry belongs to the newer connection and must stay.
func (c *Connection) readPump() {
	const readTimeout = 60 * time.Second

	defer func() {
		c.Hub.mu.Lock()
		if c.Hub.clients[c.UserID] == c {
			delete(c.Hub.clients, c.UserID)
		}
		c.Hub.mu.Unlock()
		c.Conn.Close()
	}()

	c.Conn.SetReadLimit(512 * 1024) // 512 KiB
	c.Conn.SetReadDeadline(time.Now().Add(readTimeout))
	c.Conn.SetPongHandler(func(string) error {
		c.Conn.SetReadDeadline(time.Now().Add(readTimeout))
		return nil
	})

	for {
		_, message, err := c.Conn.ReadMessage()
		if err != nil {
			// Going-away and abnormal closures are routine disconnects; only
			// other close reasons are logged.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				logger.Logger.Error("WebSocket read error", zap.Error(err))
			}
			return
		}

		var msg map[string]interface{}
		if err := json.Unmarshal(message, &msg); err != nil {
			logger.Logger.Error("Failed to parse message", zap.Error(err))
			continue
		}
		c.handleMessage(msg)
	}
}
// writePump delivers queued frames from the Send channel to the client and
// sends a ping every 30 seconds. It returns, closing the connection, when a
// write fails or when Send is closed (after sending a close frame).
//
// Every write holds writeMu. Other goroutines write to the same connection
// through writeJSON, and gorilla/websocket supports only one concurrent
// writer, so unsynchronized writes here would race with them. Each write gets
// a fresh 10-second deadline.
func (c *Connection) writePump() {
	const (
		pingInterval = 30 * time.Second
		writeTimeout = 10 * time.Second
	)

	ticker := time.NewTicker(pingInterval)
	defer func() {
		ticker.Stop()
		c.Conn.Close()
	}()

	// write sends a single frame under writeMu with a fresh deadline.
	write := func(messageType int, payload []byte) error {
		c.writeMu.Lock()
		defer c.writeMu.Unlock()
		if err := c.Conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
			return err
		}
		return c.Conn.WriteMessage(messageType, payload)
	}

	for {
		select {
		case message, ok := <-c.Send:
			if !ok {
				// Send was closed: tell the peer we are done. The connection
				// is closed right after, so the result is not actionable.
				_ = write(websocket.CloseMessage, []byte{})
				return
			}
			if err := write(websocket.TextMessage, message); err != nil {
				return
			}
		case <-ticker.C:
			if err := write(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}
// handleMessage dispatches one decoded client message according to its
// "action" field:
//
//   - "ping":         queues a pong frame on the Send channel
//   - "init_session": runs initSession in a new goroutine
//   - "send_message": runs sendMessage in a new goroutine
//   - "get_history":  runs getHistory in a new goroutine
//   - "get_personas": runs getPersonas in a new goroutine
//
// Any other action is logged as a warning. Fields that are missing or not
// strings are treated as empty. An empty "session_id" falls back to
// "<UserID>_<StarID>".
//
// NOTE(review): a non-empty session_id is forwarded exactly as the client
// sent it; confirm the AI chat service checks it against the authenticated
// user.
func (c *Connection) handleMessage(msg map[string]interface{}) {
	// str returns msg[key] when it is a string and "" otherwise.
	str := func(key string) string {
		s, _ := msg[key].(string)
		return s
	}

	sid := str("session_id")
	if sid == "" {
		sid = fmt.Sprintf("%d_%d", c.UserID, c.StarID)
	}

	switch act := str("action"); act {
	case "ping":
		// Heartbeat: goes through the Send channel rather than writing to
		// the connection directly.
		c.Send <- []byte(`{"type":"pong"}`)
	case "init_session":
		go c.initSession(sid)
	case "send_message":
		go c.sendMessage(sid, str("message"), str("persona_id"))
	case "get_history":
		go c.getHistory(sid)
	case "get_personas":
		go c.getPersonas()
	default:
		logger.Logger.Warn("Unknown action", zap.String("action", act))
	}
}
// initSession asks the AI chat service to initialise sessionID and forwards
// the returned welcome message to the client as a "message" frame with
// is_end set to true.
//
// The RPC context carries the connection's user_id and star_id as string
// attachments. If the service stub cannot be created or the call fails, the
// client receives a SERVICE_ERROR frame instead. A failure to deliver the
// welcome frame is logged.
func (c *Connection) initSession(sessionID string) {
	aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient)
	if err != nil {
		logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err))
		c.sendError("SERVICE_ERROR", "服务暂时不可用")
		return
	}

	// Attach the caller's identity to the RPC context.
	attachments := map[string]interface{}{
		"user_id": strconv.FormatInt(c.UserID, 10),
		"star_id": strconv.FormatInt(c.StarID, 10),
	}
	ctx := context.WithValue(context.Background(), constant.AttachmentKey, attachments)

	resp, err := aiChatService.InitSession(ctx, &pbAIChat.InitSessionRequest{
		SessionId: sessionID,
		UserId:    c.UserID,
	})
	if err != nil {
		logger.Logger.Error("InitSession call failed", zap.Error(err))
		c.sendError("SERVICE_ERROR", "初始化会话失败")
		return
	}

	// Deliver the welcome message.
	err = c.writeJSON(map[string]interface{}{
		"type":       "message",
		"session_id": sessionID,
		"content":    resp.WelcomeMessage,
		"is_end":     true,
	})
	if err != nil {
		logger.Logger.Error("Failed to send welcome message to client", zap.Error(err))
	}
}
// sendMessage forwards a user message to the AI chat service and relays the
// streamed reply to the client.
//
// Each streamed chunk becomes a "message" frame with content, session_id and
// is_end; a non-empty chunk error is added under "error". Relaying stops when
// a chunk is marked is_end, when the stream finishes, or when writing to the
// client fails.
//
// The RPC context carries the connection's user_id and star_id as string
// attachments. The stream is always closed on return so that an early exit
// does not leave it open. If the stream ends because of an error, the error
// is logged and the client receives a SERVICE_ERROR frame, so it is not left
// waiting for an is_end chunk that never arrives.
func (c *Connection) sendMessage(sessionID, message, personaID string) {
	aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient)
	if err != nil {
		logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err))
		c.sendError("SERVICE_ERROR", "服务暂时不可用")
		return
	}

	// Attach the caller's identity to the RPC context.
	attachments := map[string]interface{}{
		"user_id": strconv.FormatInt(c.UserID, 10),
		"star_id": strconv.FormatInt(c.StarID, 10),
	}
	ctx := context.WithValue(context.Background(), constant.AttachmentKey, attachments)

	// Server-streaming call.
	stream, err := aiChatService.SendMessage(ctx, &pbAIChat.ChatMessageRequest{
		SessionId: sessionID,
		Message:   message,
		PersonaId: personaID,
		UserId:    c.UserID,
	})
	if err != nil {
		logger.Logger.Error("SendMessage call failed", zap.Error(err))
		c.sendError("SERVICE_ERROR", "消息发送失败")
		return
	}
	defer func() {
		if closeErr := stream.Close(); closeErr != nil {
			logger.Logger.Warn("Failed to close SendMessage stream", zap.Error(closeErr))
		}
	}()

	for stream.Recv() {
		resp := stream.Msg()

		respMsg := map[string]interface{}{
			"type":       "message",
			"content":    resp.Content,
			"session_id": resp.SessionId,
			"is_end":     resp.IsEnd,
		}
		if resp.Error != "" {
			respMsg["error"] = resp.Error
		}

		if err := c.writeJSON(respMsg); err != nil {
			logger.Logger.Error("Failed to send message to client", zap.Error(err))
			return
		}
		if resp.IsEnd {
			return
		}
	}

	// Recv returned false: tell a clean end of stream apart from a failure.
	if err := stream.Err(); err != nil {
		logger.Logger.Error("SendMessage stream failed", zap.Error(err))
		c.sendError("SERVICE_ERROR", "消息发送失败")
	}
}
// getHistory requests up to 20 history entries for sessionID from the AI
// chat service and sends them to the client as a "history_response" frame.
// Each entry is reduced to its "role" and "content".
//
// The RPC context carries the connection's user_id and star_id as string
// attachments. If the service stub cannot be created or the call fails, the
// client receives a SERVICE_ERROR frame instead. A failure to deliver the
// response frame is logged.
func (c *Connection) getHistory(sessionID string) {
	aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient)
	if err != nil {
		logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err))
		c.sendError("SERVICE_ERROR", "服务暂时不可用")
		return
	}

	// Attach the caller's identity to the RPC context.
	attachments := map[string]interface{}{
		"user_id": strconv.FormatInt(c.UserID, 10),
		"star_id": strconv.FormatInt(c.StarID, 10),
	}
	ctx := context.WithValue(context.Background(), constant.AttachmentKey, attachments)

	resp, err := aiChatService.GetHistory(ctx, &pbAIChat.ChatHistoryRequest{
		SessionId: sessionID,
		Limit:     20,
	})
	if err != nil {
		logger.Logger.Error("GetHistory call failed", zap.Error(err))
		c.sendError("SERVICE_ERROR", "获取历史失败")
		return
	}

	// Convert the history entries. The slice is allocated with make, so an
	// empty history encodes as [] rather than null.
	history := make([]map[string]string, len(resp.History))
	for i, m := range resp.History {
		history[i] = map[string]string{
			"role":    m.Role,
			"content": m.Content,
		}
	}

	err = c.writeJSON(map[string]interface{}{
		"type":       "history_response",
		"session_id": sessionID,
		"history":    history,
	})
	if err != nil {
		logger.Logger.Error("Failed to send history to client", zap.Error(err))
	}
}
// getPersonas fetches the persona list for the connection's user from the AI
// chat service and sends it to the client as a "personas_response" frame.
//
// If the service stub cannot be created or the call fails, the client
// receives a SERVICE_ERROR frame instead. A failure to deliver the response
// frame is logged.
//
// NOTE(review): unlike the other service calls in this file, this one uses a
// plain background context with no user_id/star_id attachments; confirm
// that the service does not need them here.
func (c *Connection) getPersonas() {
	aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient)
	if err != nil {
		logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err))
		c.sendError("SERVICE_ERROR", "服务暂时不可用")
		return
	}

	resp, err := aiChatService.GetPersonas(context.Background(), &pbAIChat.GetPersonasRequest{
		UserId: c.UserID,
	})
	if err != nil {
		logger.Logger.Error("GetPersonas call failed", zap.Error(err))
		c.sendError("SERVICE_ERROR", "获取人设失败")
		return
	}

	// Convert the personas. The slice is allocated with make, so an empty
	// list encodes as [] rather than null.
	personas := make([]map[string]interface{}, len(resp.Personas))
	for i, p := range resp.Personas {
		personas[i] = map[string]interface{}{
			"id":          p.Id,
			"name":        p.Name,
			"description": p.Description,
			"avatar_url":  p.AvatarUrl,
			"talk_style":  p.TalkStyle,
			"is_default":  p.IsDefault,
			"created_at":  p.CreatedAt,
			"updated_at":  p.UpdatedAt,
		}
	}

	err = c.writeJSON(map[string]interface{}{
		"type":     "personas_response",
		"personas": personas,
	})
	if err != nil {
		logger.Logger.Error("Failed to send personas to client", zap.Error(err))
	}
}
// Close closes the WebSocket connection of every client currently registered
// with the hub. The hub lock is held for the duration of the sweep; registry
// entries are not removed here.
func (h *Hub) Close() {
	h.mu.Lock()
	defer h.mu.Unlock()

	for userID := range h.clients {
		h.clients[userID].Conn.Close()
	}
}