// Package repository 提供 UserDevice 的数据访问层(public.user_devices 表)。 // // 设计约定(与 NotificationRepository 保持一致): // - 写操作(Upsert / Deactivate / MarkAllInactiveForUser)接受 *gorm.DB, // 由 service 层在事务回调内传入 tx。 // - 读操作(ListActiveCIDsByUserID)直接走仓储 db,无事务。 package repository import ( "context" "errors" "time" "github.com/topfans/backend/pkg/logger" "github.com/topfans/backend/services/notificationService/model" "go.uber.org/zap" "gorm.io/gorm" ) // UserDeviceRepository 用户推送设备仓储。 type UserDeviceRepository struct { db *gorm.DB } // NewUserDeviceRepository 创建 UserDeviceRepository。 func NewUserDeviceRepository(db *gorm.DB) *UserDeviceRepository { return &UserDeviceRepository{db: db} } // execDB 返回带 ctx 的执行句柄(优先使用传入的 tx,否则使用仓储默认 db)。 func (r *UserDeviceRepository) execDB(tx *gorm.DB, ctx context.Context) *gorm.DB { if tx != nil { return tx.WithContext(ctx) } return r.db.WithContext(ctx) } // guardDB 在 db 与 tx 都为 nil 时返回 error;便于 service 层在测试桩或异常状态下快速失败。 func (r *UserDeviceRepository) guardDB(tx *gorm.DB) error { if tx != nil { return nil } if r.db == nil { return errors.New("user_device repository: db is nil") } return nil } // UpsertByCID 按 cid upsert:同 cid 已存在则更新 user_id/platform/version/active/updated_at; // 不存在则插入新行(id 由 PG 序列生成)。 // // 返回:写入后的最新 UserDevice(含 id)和 err。 // // 注:cid 是 Pusher 的关键标识,App 端可能因 token 轮换而变化,所以主键是 cid 而非 user_id。 // 同一用户多设备 = 多行;同一设备 token 变化 = 同一行更新。 func (r *UserDeviceRepository) UpsertByCID( ctx context.Context, tx *gorm.DB, cid string, userID int64, platform, appVersion, deviceModel string, ) (*model.UserDevice, error) { if cid == "" { return nil, errors.New("cid is required") } if userID <= 0 { return nil, errors.New("user_id is required") } if err := r.guardDB(tx); err != nil { return nil, err } now := time.Now().UnixMilli() gdb := r.execDB(tx, ctx) // 先查一次:命中则 update,未命中则 insert。 var existing model.UserDevice err := gdb.Where("cid = ?", cid).First(&existing).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } if existing.ID > 0 { // 命中:更新 user/platform/version/model/active/updated_at existing.UserID = userID existing.Platform = platform existing.AppVersion = appVersion existing.DeviceModel = deviceModel existing.Active = true existing.UpdatedAt = now if err := gdb.Save(&existing).Error; err != nil { return nil, err } return &existing, nil } // 未命中:insert(id 由序列生成)。 d := &model.UserDevice{ UserID: userID, CID: cid, Platform: platform, AppVersion: appVersion, DeviceModel: deviceModel, Active: true, CreatedAt: now, UpdatedAt: now, } if err := gdb.Exec(` INSERT INTO public.user_devices (user_id, cid, platform, app_version, device_model, active, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) `, d.UserID, d.CID, d.Platform, d.AppVersion, d.DeviceModel, d.Active, d.CreatedAt, d.UpdatedAt).Error; err != nil { logger.Logger.Error("failed to insert user_device", zap.Error(err)) return nil, err } // 取自增 id var id int64 if err := gdb.Raw(`SELECT currval(pg_get_serial_sequence('public.user_devices','id'))`).Scan(&id).Error; err != nil { return nil, err } d.ID = id return d, nil } // DeactivateByCID 将指定 cid 标记为 inactive(登出 / token 失效)。 // 返回:受影响的行数。 func (r *UserDeviceRepository) DeactivateByCID(ctx context.Context, tx *gorm.DB, cid string) (int64, error) { if cid == "" { return 0, errors.New("cid is required") } if err := r.guardDB(tx); err != nil { return 0, err } gdb := r.execDB(tx, ctx) res := gdb.Exec(` UPDATE public.user_devices SET active = FALSE, updated_at = $1 WHERE cid = $2 AND active = TRUE `, time.Now().UnixMilli(), cid) return res.RowsAffected, res.Error } // ListActiveCIDsByUserID 查询某用户所有 active=TRUE 设备的 cid。 // 用于推送时拉取目标 cids。 func (r *UserDeviceRepository) ListActiveCIDsByUserID(ctx context.Context, userID int64) ([]string, error) { if userID <= 0 { return nil, errors.New("user_id is required") } if r.db == nil { return nil, errors.New("user_device repository: db is nil") } var cids []string err := r.db.WithContext(ctx). Model(&model.UserDevice{}). Where("user_id = ? AND active = TRUE", userID). Pluck("cid", &cids).Error return cids, err } // CountActiveByUserID 统计某用户 active 设备数(用于调试/监控;非关键路径)。 func (r *UserDeviceRepository) CountActiveByUserID(ctx context.Context, userID int64) (int64, error) { var n int64 err := r.db.WithContext(ctx). Model(&model.UserDevice{}). Where("user_id = ? AND active = TRUE", userID). Count(&n).Error return n, err }