package socket

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/websocket"
	"github.com/redis/go-redis/v9"
	"github.com/topfans/backend/pkg/jwt"
	"github.com/topfans/backend/pkg/logger"
	"go.uber.org/zap"
)

// Timing and size limits for Activity WebSocket connections.
const (
	// activityWriteWait bounds every single write to a peer so that one slow
	// client cannot block the goroutine doing the write indefinitely.
	activityWriteWait = 10 * time.Second
	// activityPongWait is how long readPump waits for the next frame (or pong)
	// before the read fails and the connection is torn down.
	activityPongWait = 60 * time.Second
	// activityPingPeriod must stay below activityPongWait so the peer's pongs
	// refresh the read deadline before it expires.
	activityPingPeriod = 30 * time.Second
	// activityMaxMessageSize caps a single inbound client frame (64KB).
	activityMaxMessageSize = 64 * 1024
	// activitySendBuffer is the capacity of each connection's outbound queue.
	activitySendBuffer = 256
)

// ActivityHub manages all Activity WebSocket connections and fans Redis
// Pub/Sub messages out to the local connections subscribed to each channel.
type ActivityHub struct {
	clients       map[int64]map[*ActivityConn]struct{}  // userID -> set of conns
	subscriptions map[string]map[*ActivityConn]struct{} // "act:42:messages" / "act:42:contributions" -> conns
	redisClient   *redis.Client
	activityPath  string

	// mu guards clients and subscriptions.
	mu sync.RWMutex
}

// ActivityConn is a single Activity WebSocket connection.
type ActivityConn struct {
	UserID int64
	StarID int64
	Conn   *websocket.Conn
	Send   chan []byte // outbound text frames, drained by writePump
	Hub    *ActivityHub

	// writeMu serializes every write to Conn: gorilla/websocket supports at
	// most one concurrent writer, and writes come from writePump, readPump
	// (handleMessage) and HandleWebSocket.
	writeMu sync.Mutex
}

// writeMessage writes one frame of the given type under writeMu with a
// write deadline. It is safe to call from any goroutine.
func (c *ActivityConn) writeMessage(messageType int, data []byte) error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	if err := c.Conn.SetWriteDeadline(time.Now().Add(activityWriteWait)); err != nil {
		return err
	}
	return c.Conn.WriteMessage(messageType, data)
}

// writeJSON writes data as a JSON text frame under writeMu with a write
// deadline. It is safe to call from any goroutine.
func (c *ActivityConn) writeJSON(data interface{}) error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	if err := c.Conn.SetWriteDeadline(time.Now().Add(activityWriteWait)); err != nil {
		return err
	}
	return c.Conn.WriteJSON(data)
}

// NewActivityHub creates an ActivityHub. redisClient may be nil, in which
// case Run performs no Pub/Sub fanout.
func NewActivityHub(redisClient *redis.Client, activityPath string) *ActivityHub {
	return &ActivityHub{
		clients:       make(map[int64]map[*ActivityConn]struct{}),
		subscriptions: make(map[string]map[*ActivityConn]struct{}),
		redisClient:   redisClient,
		activityPath:  activityPath,
	}
}

// ActivityPath returns the configured WebSocket path.
func (h *ActivityHub) ActivityPath() string {
	return h.activityPath
}

// Run pattern-subscribes to "act:*:messages" and "act:*:contributions" and
// forwards every received message to the local subscribers of that exact
// channel. It returns when ctx is done or the Pub/Sub channel closes. With
// a nil redisClient it only blocks until ctx is done.
func (h *ActivityHub) Run(ctx context.Context) {
	if h.redisClient == nil {
		logger.Logger.Warn("ActivityHub: redisClient is nil, Pub/Sub fanout disabled")
		<-ctx.Done()
		return
	}

	sub := h.redisClient.PSubscribe(ctx, "act:*:messages", "act:*:contributions")
	defer sub.Close()

	ch := sub.Channel()
	logger.Logger.Info("ActivityHub subscribed to Redis Pub/Sub channels")

	for {
		select {
		case <-ctx.Done():
			logger.Logger.Info("ActivityHub Run loop exiting due to context done")
			return
		case msg, ok := <-ch:
			if !ok {
				logger.Logger.Warn("ActivityHub Redis Pub/Sub channel closed")
				return
			}
			var payload map[string]interface{}
			if err := json.Unmarshal([]byte(msg.Payload), &payload); err != nil {
				logger.Logger.Error("ActivityHub failed to unmarshal pubsub payload", zap.Error(err))
				continue
			}
			h.fanout(msg.Channel, payload)
		}
	}
}

// fanout queues payload for every local connection subscribed to channel.
//
// The payload is JSON-encoded once and the same bytes are shared by all
// targets. Delivery is non-blocking: a connection whose Send queue is full
// is closed instead of stalling the caller (the Redis Run loop). Closing it
// makes its readPump return and unregister it.
func (h *ActivityHub) fanout(channel string, payload map[string]interface{}) {
	h.mu.RLock()
	conns := h.subscriptions[channel]
	targets := make([]*ActivityConn, 0, len(conns))
	for c := range conns {
		targets = append(targets, c)
	}
	h.mu.RUnlock()

	if len(targets) == 0 {
		return
	}

	data, err := json.Marshal(payload)
	if err != nil {
		logger.Logger.Error("ActivityHub failed to marshal fanout payload",
			zap.String("channel", channel), zap.Error(err))
		return
	}

	for _, c := range targets {
		select {
		case c.Send <- data:
		default:
			logger.Logger.Warn("ActivityHub send queue full, closing connection",
				zap.Int64("user_id", c.UserID), zap.String("channel", channel))
			_ = c.Conn.Close()
		}
	}
}

// HandleWebSocket handles the /activity handshake. It validates the "token"
// query parameter, upgrades the request, registers the connection, sends an
// auth_response and starts the read/write pumps. An invalid token is
// rejected with HTTP 401 and a JSON auth_response body before any upgrade.
func (h *ActivityHub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
	token := r.URL.Query().Get("token")
	userID, starID, err := h.validateToken(token)
	if err != nil {
		logger.Logger.Error("Activity WebSocket token validation failed", zap.Error(err))
		// Headers must be set before WriteHeader to take effect.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusUnauthorized)
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"type":    "auth_response",
			"success": false,
			"error":   "invalid_token",
		})
		return
	}

	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		logger.Logger.Error("Activity WebSocket upgrade failed", zap.Error(err))
		return
	}

	c := &ActivityConn{
		UserID: userID,
		StarID: starID,
		Conn:   conn,
		Send:   make(chan []byte, activitySendBuffer),
		Hub:    h,
	}

	h.mu.Lock()
	if h.clients[userID] == nil {
		h.clients[userID] = make(map[*ActivityConn]struct{})
	}
	h.clients[userID][c] = struct{}{}
	h.mu.Unlock()

	logger.Logger.Info("Activity WebSocket connection established",
		zap.Int64("user_id", userID),
		zap.Int64("star_id", starID),
	)

	// Push auth_response immediately, through the locked, deadline-bounded path.
	if err := c.writeJSON(map[string]interface{}{
		"type":    "auth_response",
		"success": true,
		"user_id": userID,
		"star_id": starID,
	}); err != nil {
		logger.Logger.Warn("Activity WebSocket auth_response write failed",
			zap.Int64("user_id", userID), zap.Error(err))
	}

	// Start the pumps even if the write failed: readPump's deferred cleanup
	// is what removes the connection from h.clients.
	go c.readPump()
	go c.writePump()
}

// validateToken parses a JWT, optionally prefixed with "Bearer_", and
// returns its user and star IDs. Empty tokens and tokens whose user ID is
// zero are rejected.
func (h *ActivityHub) validateToken(token string) (int64, int64, error) {
	token = strings.TrimPrefix(token, "Bearer_")
	if token == "" {
		return 0, 0, fmt.Errorf("missing token")
	}
	claims, err := jwt.ParseToken(token)
	if err != nil {
		return 0, 0, fmt.Errorf("failed to parse token: %w", err)
	}
	if claims.UserID == 0 {
		return 0, 0, fmt.Errorf("invalid user id")
	}
	return claims.UserID, claims.StarID, nil
}

// readPump reads client frames and dispatches them to handleMessage until a
// read fails (including read-deadline expiry). It then unregisters the
// connection and closes it.
func (c *ActivityConn) readPump() {
	defer func() {
		c.Hub.unregister(c)
		_ = c.Conn.Close()
	}()

	c.Conn.SetReadLimit(activityMaxMessageSize)
	c.Conn.SetReadDeadline(time.Now().Add(activityPongWait))
	c.Conn.SetPongHandler(func(string) error {
		// A pong proves the peer is alive; extend the read deadline.
		c.Conn.SetReadDeadline(time.Now().Add(activityPongWait))
		return nil
	})

	for {
		_, message, err := c.Conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				logger.Logger.Error("Activity WebSocket read error", zap.Error(err))
			}
			break
		}

		var msg map[string]interface{}
		if err := json.Unmarshal(message, &msg); err != nil {
			logger.Logger.Error("Failed to parse activity message", zap.Error(err))
			continue
		}
		c.handleMessage(msg)
	}
}

// writePump drains the Send queue to the client and sends a ping every
// activityPingPeriod. It returns, closing the connection, on the first
// failed write. If Send is closed, it first sends a close frame.
func (c *ActivityConn) writePump() {
	ticker := time.NewTicker(activityPingPeriod)
	defer func() {
		ticker.Stop()
		_ = c.Conn.Close()
	}()

	for {
		select {
		case message, ok := <-c.Send:
			if !ok {
				_ = c.writeMessage(websocket.CloseMessage, []byte{})
				return
			}
			if err := c.writeMessage(websocket.TextMessage, message); err != nil {
				return
			}
		case <-ticker.C:
			if err := c.writeMessage(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}

// handleMessage handles the client actions "ping", "subscribe" and
// "unsubscribe". Unknown actions are logged and ignored.
func (c *ActivityConn) handleMessage(msg map[string]interface{}) {
	action, _ := msg["action"].(string)
	switch action {
	case "ping":
		// Non-blocking: if writePump has already exited nobody drains Send,
		// and a blocking send would hang readPump (and its cleanup) forever.
		select {
		case c.Send <- []byte(`{"type":"pong"}`):
		default:
			logger.Logger.Warn("Activity send queue full, dropping pong", zap.Int64("user_id", c.UserID))
		}
	case "subscribe":
		activityID := toInt64(msg["activity_id"])
		topics := toStringSlice(msg["topics"])
		c.Hub.subscribe(c, activityID, topics)
		_ = c.writeJSON(map[string]interface{}{
			"type":        "subscribe_response",
			"activity_id": activityID,
			"topics":      topics,
		})
	case "unsubscribe":
		activityID := toInt64(msg["activity_id"])
		topics := toStringSlice(msg["topics"])
		c.Hub.unsubscribe(c, activityID, topics)
		_ = c.writeJSON(map[string]interface{}{
			"type":        "unsubscribe_response",
			"activity_id": activityID,
			"topics":      topics,
		})
	default:
		logger.Logger.Warn("Unknown activity action", zap.String("action", action))
	}
}

// subscribe idempotently adds c to the "act:<activityID>:<topic>" channels.
// A non-positive activityID or empty topic list is a no-op. Only the
// "messages" and "contributions" topics are accepted, because only those
// are covered by Run's PSubscribe patterns. Anything else could never
// receive data and would only let clients create arbitrary map keys.
func (h *ActivityHub) subscribe(c *ActivityConn, activityID int64, topics []string) {
	if activityID <= 0 || len(topics) == 0 {
		return
	}
	h.mu.Lock()
	defer h.mu.Unlock()
	for _, t := range topics {
		if t != "messages" && t != "contributions" {
			continue
		}
		ch := fmt.Sprintf("act:%d:%s", activityID, t)
		if h.subscriptions[ch] == nil {
			h.subscriptions[ch] = make(map[*ActivityConn]struct{})
		}
		h.subscriptions[ch][c] = struct{}{}
	}
}

// unsubscribe idempotently removes c from the "act:<activityID>:<topic>"
// channels. It also drops channel entries that become empty.
func (h *ActivityHub) unsubscribe(c *ActivityConn, activityID int64, topics []string) {
	if activityID <= 0 || len(topics) == 0 {
		return
	}
	h.mu.Lock()
	defer h.mu.Unlock()
	for _, t := range topics {
		ch := fmt.Sprintf("act:%d:%s", activityID, t)
		if conns, ok := h.subscriptions[ch]; ok {
			delete(conns, c)
			if len(conns) == 0 {
				delete(h.subscriptions, ch)
			}
		}
	}
}

// unregister removes c from the client set and from every subscription.
// It is called when the connection's readPump exits.
func (h *ActivityHub) unregister(c *ActivityConn) {
	h.mu.Lock()
	defer h.mu.Unlock()
	if conns, ok := h.clients[c.UserID]; ok {
		delete(conns, c)
		if len(conns) == 0 {
			delete(h.clients, c.UserID)
		}
	}
	for ch, conns := range h.subscriptions {
		if _, ok := conns[c]; ok {
			delete(conns, c)
			if len(conns) == 0 {
				delete(h.subscriptions, ch)
			}
		}
	}
}

// Close closes every registered connection. Each connection's readPump then
// fails its next read and unregisters itself.
func (h *ActivityHub) Close() {
	h.mu.Lock()
	defer h.mu.Unlock()
	for _, conns := range h.clients {
		for c := range conns {
			_ = c.Conn.Close()
		}
	}
}

// toInt64 converts a decoded JSON value (float64 from encoding/json, an
// integer, or a decimal string) to int64. It returns 0 for anything else
// and for unparsable strings.
func toInt64(v interface{}) int64 {
	switch x := v.(type) {
	case float64:
		return int64(x)
	case int64:
		return x
	case int:
		return int64(x)
	case string:
		i, _ := strconv.ParseInt(x, 10, 64)
		return i
	}
	return 0
}

// toStringSlice extracts the string elements of a decoded JSON array,
// skipping non-string items. It returns nil if v is not an array.
func toStringSlice(v interface{}) []string {
	arr, ok := v.([]interface{})
	if !ok {
		return nil
	}
	out := make([]string, 0, len(arr))
	for _, item := range arr {
		if s, ok := item.(string); ok {
			out = append(out, s)
		}
	}
	return out
}