topfans/backend/services/aiChatService/repository/persona_repository.go
2026-05-28 12:00:19 +08:00

123 lines
4.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
	"context"
	"errors"
	"fmt"

	"github.com/google/uuid"
	"github.com/topfans/backend/services/aiChatService/model"
	"gorm.io/gorm"
)
// PersonaRepository 人设仓库接口
type PersonaRepository interface {
Create(ctx context.Context, persona *model.Persona) error
GetByID(ctx context.Context, id uuid.UUID) (*model.Persona, error)
GetByUserID(ctx context.Context, userID int64) ([]model.Persona, error)
GetDefaultByUserID(ctx context.Context, userID int64) (*model.Persona, error)
Update(ctx context.Context, persona *model.Persona) error
Delete(ctx context.Context, id uuid.UUID) error
EnsureDefaultPersona(ctx context.Context, userID int64) (*model.Persona, error)
}
// Built-in values that EnsureDefaultPersona uses when it has to create a
// default persona for a user who does not have one.
const (
	// DefaultPersonaName is the display name of the built-in persona.
	DefaultPersonaName = "角角"

	// DefaultPersonaDescription is the short tagline of the built-in persona
	// (roughly: "gentle, companion-style best friend").
	DefaultPersonaDescription = "温柔陪伴型闺蜜"

	// DefaultSystemPrompt is the system prompt of the built-in persona. It
	// describes a gentle, attentive AI companion named 角角 who listens,
	// comforts with warm words, speaks casually like a friend, avoids being
	// formal or preachy, and leads with empathy when the user feels down.
	//
	// NOTE(review): the first line appears to be missing punctuation (after
	// "AI伴侣" and "倾听", and at the end of the line). Confirm whether that is
	// intentional before editing the prompt. The text is sent verbatim, and
	// the continuation lines must stay at column 0 because this is a raw
	// string.
	DefaultSystemPrompt = `你是一个温柔体贴的AI伴侣名字叫角角。你善于倾听能理解用户的情绪
用温暖的话语陪伴用户。说话风格亲切自然,像朋友聊天一样。
不要过于正式或说教,当用户情绪低落时,先给予共情和安慰。`
)
// PostgreSQLPersonaRepository implements PersonaRepository on top of a GORM
// database handle (presumably connected to PostgreSQL, per the type name).
type PostgreSQLPersonaRepository struct {
	db *gorm.DB
}

// NewPostgreSQLPersonaRepository wraps db in a persona repository. The handle
// is stored as-is. No connectivity check is performed here.
func NewPostgreSQLPersonaRepository(db *gorm.DB) *PostgreSQLPersonaRepository {
	repo := new(PostgreSQLPersonaRepository)
	repo.db = db
	return repo
}
// Create inserts persona as a new row, scoped to ctx. Any database error is
// returned as-is, without additional wrapping.
func (r *PostgreSQLPersonaRepository) Create(ctx context.Context, persona *model.Persona) error {
	tx := r.db.WithContext(ctx)
	result := tx.Create(persona)
	return result.Error
}
// GetByID loads the persona whose id column equals id.
//
// It returns model.ErrPersonaNotFound when no row matches. Any other database
// error is wrapped with context.
func (r *PostgreSQLPersonaRepository) GetByID(ctx context.Context, id uuid.UUID) (*model.Persona, error) {
	var persona model.Persona
	err := r.db.WithContext(ctx).Where("id = ?", id).First(&persona).Error
	switch {
	case err == nil:
		return &persona, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// Translate GORM's sentinel into the domain error callers match on.
		// errors.Is (not ==) keeps this working if the error is ever wrapped.
		return nil, model.ErrPersonaNotFound
	default:
		return nil, fmt.Errorf("failed to get persona: %w", err)
	}
}
// GetByUserID lists every persona owned by userID, newest first (ordered by
// created_at descending). GORM's Find does not report "record not found", so a
// user with no personas yields an empty result rather than an error. Database
// failures are wrapped with context.
func (r *PostgreSQLPersonaRepository) GetByUserID(ctx context.Context, userID int64) ([]model.Persona, error) {
	query := r.db.WithContext(ctx).
		Where("user_id = ?", userID).
		Order("created_at DESC")

	var result []model.Persona
	if err := query.Find(&result).Error; err != nil {
		return nil, fmt.Errorf("failed to get personas: %w", err)
	}
	return result, nil
}
// GetDefaultByUserID loads the persona that userID has flagged as default
// (is_default = TRUE).
//
// It returns model.ErrPersonaNotFound when the user has no default persona.
// Any other database error is wrapped with context. If several rows are
// flagged as default, GORM's First returns the first one by primary key.
func (r *PostgreSQLPersonaRepository) GetDefaultByUserID(ctx context.Context, userID int64) (*model.Persona, error) {
	var persona model.Persona
	err := r.db.WithContext(ctx).
		Where("user_id = ? AND is_default = TRUE", userID).
		First(&persona).Error
	switch {
	case err == nil:
		return &persona, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// Map GORM's sentinel to the domain error. EnsureDefaultPersona
		// relies on this to decide whether to create a default.
		return nil, model.ErrPersonaNotFound
	default:
		return nil, fmt.Errorf("failed to get default persona: %w", err)
	}
}
// Update writes every field of persona back to the database using GORM's Save.
// Any database error is returned as-is, without additional wrapping.
//
// NOTE(review): GORM's Save inserts instead of updating when the primary key
// is zero-valued. Confirm that callers always pass a previously loaded persona.
func (r *PostgreSQLPersonaRepository) Update(ctx context.Context, persona *model.Persona) error {
	result := r.db.WithContext(ctx).Save(persona)
	return result.Error
}
// Delete removes the persona whose id column equals id. Any database error is
// returned as-is. Deleting an ID that matches no row is not reported as an
// error.
//
// NOTE(review): whether this is a hard or soft delete depends on whether
// model.Persona carries a gorm.DeletedAt field. Confirm against the model.
func (r *PostgreSQLPersonaRepository) Delete(ctx context.Context, id uuid.UUID) error {
	session := r.db.WithContext(ctx)
	return session.Delete(&model.Persona{}, "id = ?", id).Error
}
// EnsureDefaultPersona returns userID's default persona. If the user has none
// yet, it first creates one from DefaultPersonaName, DefaultPersonaDescription
// and DefaultSystemPrompt.
//
// Lookup errors other than model.ErrPersonaNotFound are returned unchanged. A
// failed insert is wrapped with context.
//
// NOTE(review): the lookup and the insert are not atomic. Two concurrent calls
// for the same user could therefore both create a default persona, unless the
// schema enforces at most one default per user (e.g. a partial unique index).
// Confirm.
func (r *PostgreSQLPersonaRepository) EnsureDefaultPersona(ctx context.Context, userID int64) (*model.Persona, error) {
	existing, err := r.GetDefaultByUserID(ctx, userID)
	switch {
	case err == nil:
		return existing, nil
	case !errors.Is(err, model.ErrPersonaNotFound):
		// A real failure (not merely "no default yet"). Do not mask it by
		// creating a persona.
		return nil, err
	}

	// No default persona yet: seed one from the built-in constants.
	persona := &model.Persona{
		UserID:       userID,
		Name:         DefaultPersonaName,
		Description:  DefaultPersonaDescription,
		SystemPrompt: DefaultSystemPrompt,
		IsDefault:    true,
	}
	if err := r.Create(ctx, persona); err != nil {
		return nil, fmt.Errorf("failed to create default persona: %w", err)
	}
	return persona, nil
}