169 lines
5.0 KiB
Go
169 lines
5.0 KiB
Go
// Package repository 提供 UserDevice 的数据访问层(public.user_devices 表)。
|
|
//
|
|
// 设计约定(与 NotificationRepository 保持一致):
|
|
// - 写操作(Upsert / Deactivate / MarkAllInactiveForUser)接受 *gorm.DB,
|
|
// 由 service 层在事务回调内传入 tx。
|
|
// - 读操作(ListActiveCIDsByUserID)直接走仓储 db,无事务。
|
|
package repository
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"time"
|
|
|
|
"github.com/topfans/backend/pkg/logger"
|
|
"github.com/topfans/backend/services/notificationService/model"
|
|
"go.uber.org/zap"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// UserDeviceRepository 用户推送设备仓储。
|
|
type UserDeviceRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
// NewUserDeviceRepository 创建 UserDeviceRepository。
|
|
func NewUserDeviceRepository(db *gorm.DB) *UserDeviceRepository {
|
|
return &UserDeviceRepository{db: db}
|
|
}
|
|
|
|
// execDB 返回带 ctx 的执行句柄(优先使用传入的 tx,否则使用仓储默认 db)。
|
|
func (r *UserDeviceRepository) execDB(tx *gorm.DB, ctx context.Context) *gorm.DB {
|
|
if tx != nil {
|
|
return tx.WithContext(ctx)
|
|
}
|
|
return r.db.WithContext(ctx)
|
|
}
|
|
|
|
// guardDB 在 db 与 tx 都为 nil 时返回 error;便于 service 层在测试桩或异常状态下快速失败。
|
|
func (r *UserDeviceRepository) guardDB(tx *gorm.DB) error {
|
|
if tx != nil {
|
|
return nil
|
|
}
|
|
if r.db == nil {
|
|
return errors.New("user_device repository: db is nil")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpsertByCID 按 cid upsert:同 cid 已存在则更新 user_id/platform/version/active/updated_at;
|
|
// 不存在则插入新行(id 由 PG 序列生成)。
|
|
//
|
|
// 返回:写入后的最新 UserDevice(含 id)和 err。
|
|
//
|
|
// 注:cid 是 Pusher 的关键标识,App 端可能因 token 轮换而变化,所以主键是 cid 而非 user_id。
|
|
// 同一用户多设备 = 多行;同一设备 token 变化 = 同一行更新。
|
|
func (r *UserDeviceRepository) UpsertByCID(
|
|
ctx context.Context,
|
|
tx *gorm.DB,
|
|
cid string,
|
|
userID int64,
|
|
platform, appVersion, deviceModel string,
|
|
) (*model.UserDevice, error) {
|
|
if cid == "" {
|
|
return nil, errors.New("cid is required")
|
|
}
|
|
if userID <= 0 {
|
|
return nil, errors.New("user_id is required")
|
|
}
|
|
if err := r.guardDB(tx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
now := time.Now().UnixMilli()
|
|
gdb := r.execDB(tx, ctx)
|
|
|
|
// 先查一次:命中则 update,未命中则 insert。
|
|
var existing model.UserDevice
|
|
err := gdb.Where("cid = ?", cid).First(&existing).Error
|
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, err
|
|
}
|
|
|
|
if existing.ID > 0 {
|
|
// 命中:更新 user/platform/version/model/active/updated_at
|
|
existing.UserID = userID
|
|
existing.Platform = platform
|
|
existing.AppVersion = appVersion
|
|
existing.DeviceModel = deviceModel
|
|
existing.Active = true
|
|
existing.UpdatedAt = now
|
|
if err := gdb.Save(&existing).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return &existing, nil
|
|
}
|
|
|
|
// 未命中:insert(id 由序列生成)。
|
|
d := &model.UserDevice{
|
|
UserID: userID,
|
|
CID: cid,
|
|
Platform: platform,
|
|
AppVersion: appVersion,
|
|
DeviceModel: deviceModel,
|
|
Active: true,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}
|
|
if err := gdb.Exec(`
|
|
INSERT INTO public.user_devices
|
|
(user_id, cid, platform, app_version, device_model, active, created_at, updated_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
|
`, d.UserID, d.CID, d.Platform, d.AppVersion, d.DeviceModel, d.Active, d.CreatedAt, d.UpdatedAt).Error; err != nil {
|
|
logger.Logger.Error("failed to insert user_device", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
|
|
// 取自增 id
|
|
var id int64
|
|
if err := gdb.Raw(`SELECT currval(pg_get_serial_sequence('public.user_devices','id'))`).Scan(&id).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
d.ID = id
|
|
return d, nil
|
|
}
|
|
|
|
// DeactivateByCID 将指定 cid 标记为 inactive(登出 / token 失效)。
|
|
// 返回:受影响的行数。
|
|
func (r *UserDeviceRepository) DeactivateByCID(ctx context.Context, tx *gorm.DB, cid string) (int64, error) {
|
|
if cid == "" {
|
|
return 0, errors.New("cid is required")
|
|
}
|
|
if err := r.guardDB(tx); err != nil {
|
|
return 0, err
|
|
}
|
|
gdb := r.execDB(tx, ctx)
|
|
res := gdb.Exec(`
|
|
UPDATE public.user_devices
|
|
SET active = FALSE, updated_at = $1
|
|
WHERE cid = $2 AND active = TRUE
|
|
`, time.Now().UnixMilli(), cid)
|
|
return res.RowsAffected, res.Error
|
|
}
|
|
|
|
// ListActiveCIDsByUserID 查询某用户所有 active=TRUE 设备的 cid。
|
|
// 用于推送时拉取目标 cids。
|
|
func (r *UserDeviceRepository) ListActiveCIDsByUserID(ctx context.Context, userID int64) ([]string, error) {
|
|
if userID <= 0 {
|
|
return nil, errors.New("user_id is required")
|
|
}
|
|
if r.db == nil {
|
|
return nil, errors.New("user_device repository: db is nil")
|
|
}
|
|
var cids []string
|
|
err := r.db.WithContext(ctx).
|
|
Model(&model.UserDevice{}).
|
|
Where("user_id = ? AND active = TRUE", userID).
|
|
Pluck("cid", &cids).Error
|
|
return cids, err
|
|
}
|
|
|
|
// CountActiveByUserID 统计某用户 active 设备数(用于调试/监控;非关键路径)。
|
|
func (r *UserDeviceRepository) CountActiveByUserID(ctx context.Context, userID int64) (int64, error) {
|
|
var n int64
|
|
err := r.db.WithContext(ctx).
|
|
Model(&model.UserDevice{}).
|
|
Where("user_id = ? AND active = TRUE", userID).
|
|
Count(&n).Error
|
|
return n, err
|
|
} |