topfans/backend/services/notificationService/repository/user_device_repository.go

169 lines
5.0 KiB
Go

// Package repository 提供 UserDevice 的数据访问层(public.user_devices 表)。
//
// 设计约定(与 NotificationRepository 保持一致):
// - 写操作(Upsert / Deactivate / MarkAllInactiveForUser)接受 *gorm.DB,
// 由 service 层在事务回调内传入 tx。
// - 读操作(ListActiveCIDsByUserID)直接走仓储 db,无事务。
package repository
import (
"context"
"errors"
"time"
"github.com/topfans/backend/pkg/logger"
"github.com/topfans/backend/services/notificationService/model"
"go.uber.org/zap"
"gorm.io/gorm"
)
// UserDeviceRepository 用户推送设备仓储。
type UserDeviceRepository struct {
db *gorm.DB
}
// NewUserDeviceRepository 创建 UserDeviceRepository。
func NewUserDeviceRepository(db *gorm.DB) *UserDeviceRepository {
return &UserDeviceRepository{db: db}
}
// execDB 返回带 ctx 的执行句柄(优先使用传入的 tx,否则使用仓储默认 db)。
func (r *UserDeviceRepository) execDB(tx *gorm.DB, ctx context.Context) *gorm.DB {
if tx != nil {
return tx.WithContext(ctx)
}
return r.db.WithContext(ctx)
}
// guardDB 在 db 与 tx 都为 nil 时返回 error;便于 service 层在测试桩或异常状态下快速失败。
func (r *UserDeviceRepository) guardDB(tx *gorm.DB) error {
if tx != nil {
return nil
}
if r.db == nil {
return errors.New("user_device repository: db is nil")
}
return nil
}
// UpsertByCID 按 cid upsert:同 cid 已存在则更新 user_id/platform/version/active/updated_at;
// 不存在则插入新行(id 由 PG 序列生成)。
//
// 返回:写入后的最新 UserDevice(含 id)和 err。
//
// 注:cid 是 Pusher 的关键标识,App 端可能因 token 轮换而变化,所以主键是 cid 而非 user_id。
// 同一用户多设备 = 多行;同一设备 token 变化 = 同一行更新。
func (r *UserDeviceRepository) UpsertByCID(
ctx context.Context,
tx *gorm.DB,
cid string,
userID int64,
platform, appVersion, deviceModel string,
) (*model.UserDevice, error) {
if cid == "" {
return nil, errors.New("cid is required")
}
if userID <= 0 {
return nil, errors.New("user_id is required")
}
if err := r.guardDB(tx); err != nil {
return nil, err
}
now := time.Now().UnixMilli()
gdb := r.execDB(tx, ctx)
// 先查一次:命中则 update,未命中则 insert。
var existing model.UserDevice
err := gdb.Where("cid = ?", cid).First(&existing).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
if existing.ID > 0 {
// 命中:更新 user/platform/version/model/active/updated_at
existing.UserID = userID
existing.Platform = platform
existing.AppVersion = appVersion
existing.DeviceModel = deviceModel
existing.Active = true
existing.UpdatedAt = now
if err := gdb.Save(&existing).Error; err != nil {
return nil, err
}
return &existing, nil
}
// 未命中:insert(id 由序列生成)。
d := &model.UserDevice{
UserID: userID,
CID: cid,
Platform: platform,
AppVersion: appVersion,
DeviceModel: deviceModel,
Active: true,
CreatedAt: now,
UpdatedAt: now,
}
if err := gdb.Exec(`
INSERT INTO public.user_devices
(user_id, cid, platform, app_version, device_model, active, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
`, d.UserID, d.CID, d.Platform, d.AppVersion, d.DeviceModel, d.Active, d.CreatedAt, d.UpdatedAt).Error; err != nil {
logger.Logger.Error("failed to insert user_device", zap.Error(err))
return nil, err
}
// 取自增 id
var id int64
if err := gdb.Raw(`SELECT currval(pg_get_serial_sequence('public.user_devices','id'))`).Scan(&id).Error; err != nil {
return nil, err
}
d.ID = id
return d, nil
}
// DeactivateByCID 将指定 cid 标记为 inactive(登出 / token 失效)。
// 返回:受影响的行数。
func (r *UserDeviceRepository) DeactivateByCID(ctx context.Context, tx *gorm.DB, cid string) (int64, error) {
if cid == "" {
return 0, errors.New("cid is required")
}
if err := r.guardDB(tx); err != nil {
return 0, err
}
gdb := r.execDB(tx, ctx)
res := gdb.Exec(`
UPDATE public.user_devices
SET active = FALSE, updated_at = $1
WHERE cid = $2 AND active = TRUE
`, time.Now().UnixMilli(), cid)
return res.RowsAffected, res.Error
}
// ListActiveCIDsByUserID 查询某用户所有 active=TRUE 设备的 cid。
// 用于推送时拉取目标 cids。
func (r *UserDeviceRepository) ListActiveCIDsByUserID(ctx context.Context, userID int64) ([]string, error) {
if userID <= 0 {
return nil, errors.New("user_id is required")
}
if r.db == nil {
return nil, errors.New("user_device repository: db is nil")
}
var cids []string
err := r.db.WithContext(ctx).
Model(&model.UserDevice{}).
Where("user_id = ? AND active = TRUE", userID).
Pluck("cid", &cids).Error
return cids, err
}
// CountActiveByUserID 统计某用户 active 设备数(用于调试/监控;非关键路径)。
func (r *UserDeviceRepository) CountActiveByUserID(ctx context.Context, userID int64) (int64, error) {
var n int64
err := r.db.WithContext(ctx).
Model(&model.UserDevice{}).
Where("user_id = ? AND active = TRUE", userID).
Count(&n).Error
return n, err
}