topfans/backend/gateway/socket/activity_socket.go

369 lines
8.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package socket
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/redis/go-redis/v9"
"github.com/topfans/backend/pkg/jwt"
"github.com/topfans/backend/pkg/logger"
"go.uber.org/zap"
)
// ActivityHub manages all activity WebSocket connections on this instance.
//
// It indexes connections by user (clients) and by subscribed Redis channel
// (subscriptions); Run forwards messages published on Redis to the matching
// local subscribers. Every access to clients and subscriptions in this file
// is made while holding mu.
type ActivityHub struct {
	// clients maps a user ID to the set of that user's open connections
	// (a user may hold several connections at once).
	clients map[int64]map[*ActivityConn]struct{}

	// subscriptions maps a channel key such as "act:42:messages" or
	// "act:42:contributions" to the set of connections subscribed to it.
	subscriptions map[string]map[*ActivityConn]struct{}

	// redisClient is the Pub/Sub source; may be nil, which disables fanout
	// (see Run).
	redisClient *redis.Client

	// activityPath is the configured WebSocket route, exposed via ActivityPath.
	activityPath string

	// mu guards clients and subscriptions.
	mu sync.RWMutex
}
// ActivityConn is a single client WebSocket connection.
type ActivityConn struct {
	// UserID and StarID are taken from the validated JWT claims
	// (see validateToken and HandleWebSocket).
	UserID int64
	StarID int64

	// Conn is the underlying gorilla/websocket connection.
	Conn *websocket.Conn

	// Send is an outbound queue of raw text messages drained by writePump
	// (HandleWebSocket creates it with capacity 256).
	Send chan []byte

	// Hub is the ActivityHub this connection is registered with.
	Hub *ActivityHub

	// writeMu serializes writes made through writeJSON.
	writeMu sync.Mutex
}
// writeJSON encodes data as JSON and sends it to the client as one WebSocket
// message.
//
// It is called from more than one goroutine (fanout on the hub's Run loop,
// handleMessage on the connection's readPump). gorilla/websocket allows only
// one concurrent writer per connection, so writes are serialized with writeMu.
//
// A 10-second write deadline (the same value writePump uses) bounds how long
// a client that has stopped reading can block the caller. Without it a
// stalled client could block fanout, and with it delivery to every other
// subscriber, indefinitely.
func (c *ActivityConn) writeJSON(data interface{}) error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	if err := c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil {
		return err
	}
	return c.Conn.WriteJSON(data)
}
// NewActivityHub returns a hub with empty connection and subscription
// indexes. redisClient may be nil, in which case Run only waits for its
// context to be cancelled. activityPath is stored as-is and reported by
// ActivityPath.
func NewActivityHub(redisClient *redis.Client, activityPath string) *ActivityHub {
	hub := &ActivityHub{
		redisClient:  redisClient,
		activityPath: activityPath,
	}
	hub.clients = make(map[int64]map[*ActivityConn]struct{})
	hub.subscriptions = make(map[string]map[*ActivityConn]struct{})
	return hub
}

// ActivityPath reports the WebSocket route this hub was configured with.
func (h *ActivityHub) ActivityPath() string {
	return h.activityPath
}
// Run subscribes to the Redis channel patterns "act:*:messages" and
// "act:*:contributions". It forwards each published JSON object to the local
// connections subscribed to that exact channel (via fanout).
//
// It blocks until ctx is cancelled or the Pub/Sub message channel closes.
// If the hub has no Redis client, it logs a warning and just waits for ctx.
func (h *ActivityHub) Run(ctx context.Context) {
	if h.redisClient == nil {
		logger.Logger.Warn("ActivityHub: redisClient is nil, Pub/Sub fanout disabled")
		<-ctx.Done()
		return
	}

	sub := h.redisClient.PSubscribe(ctx, "act:*:messages", "act:*:contributions")
	defer sub.Close()
	ch := sub.Channel()
	logger.Logger.Info("ActivityHub subscribed to Redis Pub/Sub channels")

	for {
		select {
		case <-ctx.Done():
			logger.Logger.Info("ActivityHub Run loop exiting due to context done")
			return
		case msg, ok := <-ch:
			if !ok {
				logger.Logger.Warn("ActivityHub Redis Pub/Sub channel closed")
				return
			}
			// Decode with UseNumber so numeric fields stay json.Number and are
			// re-encoded verbatim when fanout writes them out. A plain
			// json.Unmarshal into interface{} turns every number into float64,
			// which silently corrupts integers above 2^53 (e.g. large 64-bit IDs).
			var payload map[string]interface{}
			dec := json.NewDecoder(strings.NewReader(msg.Payload))
			dec.UseNumber()
			if err := dec.Decode(&payload); err != nil {
				logger.Logger.Error("ActivityHub failed to unmarshal pubsub payload", zap.Error(err))
				continue
			}
			h.fanout(msg.Channel, payload)
		}
	}
}
// fanout sends payload as JSON to every local connection subscribed to
// channel, an exact key such as "act:42:messages".
//
// The subscriber set is copied under the read lock, so no network I/O happens
// while h.mu is held. Writes are sequential: a slow client delays delivery to
// the connections after it.
func (h *ActivityHub) fanout(channel string, payload map[string]interface{}) {
	h.mu.RLock()
	subscribers := h.subscriptions[channel]
	targets := make([]*ActivityConn, 0, len(subscribers))
	for c := range subscribers {
		targets = append(targets, c)
	}
	h.mu.RUnlock()

	for _, c := range targets {
		if err := c.writeJSON(payload); err != nil {
			logger.Logger.Error("ActivityHub writeJSON failed", zap.Int64("user_id", c.UserID), zap.Error(err))
			// After a failed write the connection is treated as dead. Closing it
			// makes readPump's pending read fail, so readPump unregisters the
			// connection. Otherwise it would stay subscribed and fail again on
			// every later message.
			_ = c.Conn.Close()
		}
	}
}
// HandleWebSocket handles the /activity WebSocket handshake.
//
// The client authenticates with a JWT in the "token" query parameter. On
// failure the request gets a 401 with a JSON auth_response body and is not
// upgraded. On success the connection is:
//   - upgraded and registered under the user ID;
//   - sent an auth_response carrying user_id and star_id;
//   - handed to readPump and writePump.
func (h *ActivityHub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
	token := r.URL.Query().Get("token")
	userID, starID, err := h.validateToken(token)
	if err != nil {
		logger.Logger.Error("Activity WebSocket token validation failed", zap.Error(err))
		// Header fields only take effect if set before WriteHeader.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusUnauthorized)
		// Best effort: nothing useful can be done if the body cannot be written.
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"type":    "auth_response",
			"success": false,
			"error":   "invalid_token",
		})
		return
	}

	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		logger.Logger.Error("Activity WebSocket upgrade failed", zap.Error(err))
		return
	}

	c := &ActivityConn{
		UserID: userID,
		StarID: starID,
		Conn:   conn,
		Send:   make(chan []byte, 256),
		Hub:    h,
	}

	h.mu.Lock()
	if h.clients[userID] == nil {
		h.clients[userID] = make(map[*ActivityConn]struct{})
	}
	h.clients[userID][c] = struct{}{}
	h.mu.Unlock()

	logger.Logger.Info("Activity WebSocket connection established",
		zap.Int64("user_id", userID),
		zap.Int64("star_id", starID),
	)

	// Send auth_response immediately, through writeJSON so this write is
	// serialized with every later write on the connection. If it cannot be
	// delivered, the peer is most likely gone: undo the registration and close
	// instead of starting pumps for it.
	if err := c.writeJSON(map[string]interface{}{
		"type":    "auth_response",
		"success": true,
		"user_id": userID,
		"star_id": starID,
	}); err != nil {
		logger.Logger.Warn("Activity WebSocket auth_response write failed",
			zap.Int64("user_id", userID),
			zap.Error(err),
		)
		h.unregister(c)
		_ = conn.Close()
		return
	}

	go c.readPump()
	go c.writePump()
}
// validateToken parses the client's JWT and returns the user ID and star ID
// from its claims. An optional "Bearer_" prefix is removed before parsing.
// Tokens that fail to parse, or whose claims carry a zero user ID, are
// rejected with an error.
func (h *ActivityHub) validateToken(token string) (int64, int64, error) {
	claims, err := jwt.ParseToken(strings.TrimPrefix(token, "Bearer_"))
	switch {
	case err != nil:
		return 0, 0, fmt.Errorf("failed to parse token: %w", err)
	case claims.UserID == 0:
		return 0, 0, fmt.Errorf("invalid user id")
	default:
		return claims.UserID, claims.StarID, nil
	}
}
// readPump reads client messages until reading fails. It then unregisters the
// connection from the hub and closes it.
//
// Incoming messages are limited to 64 KiB. The read deadline is 60 seconds
// and is extended every time a pong arrives (writePump pings every 30
// seconds). Each message is decoded as a JSON object and handed to
// handleMessage. Messages that fail to decode are logged and skipped.
func (c *ActivityConn) readPump() {
	defer func() {
		c.Hub.unregister(c)
		c.Conn.Close()
	}()

	const idleTimeout = 60 * time.Second
	c.Conn.SetReadLimit(64 * 1024)
	c.Conn.SetReadDeadline(time.Now().Add(idleTimeout))
	c.Conn.SetPongHandler(func(string) error {
		c.Conn.SetReadDeadline(time.Now().Add(idleTimeout))
		return nil
	})

	for {
		_, frame, err := c.Conn.ReadMessage()
		if err != nil {
			// Normal departures (going away / abnormal closure) are not logged.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				logger.Logger.Error("Activity WebSocket read error", zap.Error(err))
			}
			return
		}

		var msg map[string]interface{}
		if jsonErr := json.Unmarshal(frame, &msg); jsonErr != nil {
			logger.Logger.Error("Failed to parse activity message", zap.Error(jsonErr))
			continue
		}
		c.handleMessage(msg)
	}
}
// writePump sends queued messages from c.Send to the client. It also pings
// the client every 30 seconds, so the pongs keep extending readPump's read
// deadline. It returns, closing the connection, on the first write error, or
// once c.Send is closed (after sending a close frame).
//
// Every write holds c.writeMu. gorilla/websocket supports only one concurrent
// writer per connection, and other goroutines write to the same connection
// through writeJSON, which takes writeMu. Writing here without the lock would
// race with them.
func (c *ActivityConn) writePump() {
	ticker := time.NewTicker(30 * time.Second)
	defer func() {
		ticker.Stop()
		c.Conn.Close()
	}()

	// write performs one serialized write with a 10-second deadline.
	write := func(messageType int, data []byte) error {
		c.writeMu.Lock()
		defer c.writeMu.Unlock()
		if err := c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)); err != nil {
			return err
		}
		return c.Conn.WriteMessage(messageType, data)
	}

	for {
		select {
		case message, ok := <-c.Send:
			if !ok {
				// Send was closed: tell the peer with a close frame, best effort.
				_ = write(websocket.CloseMessage, []byte{})
				return
			}
			if err := write(websocket.TextMessage, message); err != nil {
				return
			}
		case <-ticker.C:
			if err := write(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}
// handleMessage dispatches one decoded client message by its "action" field:
//
//   - "ping": queues {"type":"pong"} on c.Send for writePump.
//   - "subscribe" / "unsubscribe": updates the hub's subscriptions using
//     "activity_id" and "topics", then replies with a subscribe_response /
//     unsubscribe_response echoing both values.
//   - anything else is logged and ignored.
//
// Called from readPump.
func (c *ActivityConn) handleMessage(msg map[string]interface{}) {
	action, _ := msg["action"].(string)
	switch action {
	case "ping":
		// Never block here. If writePump has stalled or exited while c.Send is
		// full, a blocking send would park readPump forever, so the connection
		// would never be read again, unregistered or closed. Dropping a pong is
		// harmless; the client can ping again.
		select {
		case c.Send <- []byte(`{"type":"pong"}`):
		default:
			logger.Logger.Warn("Activity pong dropped: send buffer full", zap.Int64("user_id", c.UserID))
		}
	case "subscribe":
		activityID := toInt64(msg["activity_id"])
		topics := toStringSlice(msg["topics"])
		c.Hub.subscribe(c, activityID, topics)
		// Write errors are not handled here; a broken connection is torn down
		// by the read/write pumps.
		_ = c.writeJSON(map[string]interface{}{
			"type":        "subscribe_response",
			"activity_id": activityID,
			"topics":      topics,
		})
	case "unsubscribe":
		activityID := toInt64(msg["activity_id"])
		topics := toStringSlice(msg["topics"])
		c.Hub.unsubscribe(c, activityID, topics)
		_ = c.writeJSON(map[string]interface{}{
			"type":        "unsubscribe_response",
			"activity_id": activityID,
			"topics":      topics,
		})
	default:
		logger.Logger.Warn("Unknown activity action", zap.String("action", action))
	}
}
// subscribe adds c to the subscriber set of "act:<activityID>:<topic>" for
// each topic. Subscribing again to the same channel is a no-op. Calls with a
// non-positive activityID or an empty topic list are ignored.
func (h *ActivityHub) subscribe(c *ActivityConn, activityID int64, topics []string) {
	if len(topics) == 0 || activityID <= 0 {
		return
	}
	prefix := "act:" + strconv.FormatInt(activityID, 10) + ":"

	h.mu.Lock()
	defer h.mu.Unlock()
	for _, topic := range topics {
		key := prefix + topic
		set := h.subscriptions[key]
		if set == nil {
			set = make(map[*ActivityConn]struct{})
			h.subscriptions[key] = set
		}
		set[c] = struct{}{}
	}
}
// unsubscribe removes c from the subscriber set of "act:<activityID>:<topic>"
// for each topic, deleting sets that become empty. Unsubscribing from a
// channel c is not on is a no-op. Calls with a non-positive activityID or an
// empty topic list are ignored.
func (h *ActivityHub) unsubscribe(c *ActivityConn, activityID int64, topics []string) {
	if len(topics) == 0 || activityID <= 0 {
		return
	}
	prefix := "act:" + strconv.FormatInt(activityID, 10) + ":"

	h.mu.Lock()
	defer h.mu.Unlock()
	for _, topic := range topics {
		key := prefix + topic
		set, exists := h.subscriptions[key]
		if !exists {
			continue
		}
		delete(set, c)
		if len(set) == 0 {
			delete(h.subscriptions, key)
		}
	}
}
// unregister removes c from the hub: from its user's connection set and from
// every channel subscription. Sets left empty are deleted so the indexes do
// not accumulate stale keys. Called by readPump when the connection ends.
func (h *ActivityHub) unregister(c *ActivityConn) {
	h.mu.Lock()
	defer h.mu.Unlock()

	// removeFrom deletes c from set and reports whether set is now empty.
	removeFrom := func(set map[*ActivityConn]struct{}) bool {
		delete(set, c)
		return len(set) == 0
	}

	if userConns, found := h.clients[c.UserID]; found && removeFrom(userConns) {
		delete(h.clients, c.UserID)
	}
	for key, subscribers := range h.subscriptions {
		if _, member := subscribers[c]; member && removeFrom(subscribers) {
			delete(h.subscriptions, key)
		}
	}
}
// Close closes the underlying WebSocket of every registered connection. It
// holds the hub lock for the whole sweep and does not clear the indexes
// itself; each connection's readPump unregisters it once its read fails.
func (h *ActivityHub) Close() {
	h.mu.Lock()
	defer h.mu.Unlock()
	for userID := range h.clients {
		for conn := range h.clients[userID] {
			_ = conn.Conn.Close()
		}
	}
}
// toInt64 converts a loosely typed JSON value to an int64. It returns 0 when
// the value is missing, of an unsupported type, malformed, or not
// representable as an int64.
//
// Supported inputs:
//   - float64: what encoding/json yields for numbers decoded into
//     interface{}; the fractional part is truncated.
//   - json.Number: from decoders using UseNumber.
//   - int64 and int.
//   - base-10 integer strings.
func toInt64(v interface{}) int64 {
	switch x := v.(type) {
	case float64:
		// Go leaves out-of-range float-to-int conversions implementation-
		// dependent. NaN, ±Inf and |x| >= 2^63 therefore map to 0 explicitly
		// instead of yielding a platform-specific value. NaN fails both
		// comparisons, so it is rejected too.
		const limit = 1 << 63
		if !(x >= -limit && x < limit) {
			return 0
		}
		return int64(x)
	case json.Number:
		i, err := x.Int64()
		if err != nil {
			return 0
		}
		return i
	case int64:
		return x
	case int:
		return int64(x)
	case string:
		// ParseInt returns a clamped value on range errors; treat any error as
		// "not representable", consistent with the float64 case.
		i, err := strconv.ParseInt(x, 10, 64)
		if err != nil {
			return 0
		}
		return i
	}
	return 0
}
// toStringSlice extracts the string elements of a decoded JSON array
// ([]interface{}), in order, skipping any non-string element. It returns nil
// when v is not a []interface{}. For an array with no string elements it
// returns an empty, non-nil slice.
func toStringSlice(v interface{}) []string {
	items, isList := v.([]interface{})
	if !isList {
		return nil
	}
	strs := make([]string, 0, len(items))
	for i := range items {
		str, isStr := items[i].(string)
		if !isStr {
			continue
		}
		strs = append(strs, str)
	}
	return strs
}