123 lines
4.0 KiB
Go
123 lines
4.0 KiB
Go
package repository
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
|
||
"github.com/google/uuid"
|
||
"github.com/topfans/backend/services/aiChatService/model"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// PersonaRepository 人设仓库接口
|
||
type PersonaRepository interface {
|
||
Create(ctx context.Context, persona *model.Persona) error
|
||
GetByID(ctx context.Context, id uuid.UUID) (*model.Persona, error)
|
||
GetByUserID(ctx context.Context, userID int64) ([]model.Persona, error)
|
||
GetDefaultByUserID(ctx context.Context, userID int64) (*model.Persona, error)
|
||
Update(ctx context.Context, persona *model.Persona) error
|
||
Delete(ctx context.Context, id uuid.UUID) error
|
||
EnsureDefaultPersona(ctx context.Context, userID int64) (*model.Persona, error)
|
||
}
|
||
|
||
// DefaultSystemPrompt 默认系统提示词
|
||
const DefaultSystemPrompt = `你是一个温柔体贴的AI伴侣,名字叫角角。你善于倾听,能理解用户的情绪,
|
||
用温暖的话语陪伴用户。说话风格亲切自然,像朋友聊天一样。
|
||
不要过于正式或说教,当用户情绪低落时,先给予共情和安慰。`
|
||
|
||
// DefaultPersonaName 默认人设名称
|
||
const DefaultPersonaName = "角角"
|
||
|
||
// DefaultPersonaDescription 默认人设描述
|
||
const DefaultPersonaDescription = "温柔陪伴型闺蜜"
|
||
|
||
// PostgreSQLPersonaRepository PostgreSQL 人设仓库实现
|
||
type PostgreSQLPersonaRepository struct {
|
||
db *gorm.DB
|
||
}
|
||
|
||
// NewPostgreSQLPersonaRepository 创建人设仓库
|
||
func NewPostgreSQLPersonaRepository(db *gorm.DB) *PostgreSQLPersonaRepository {
|
||
return &PostgreSQLPersonaRepository{db: db}
|
||
}
|
||
|
||
// Create 创建人设
|
||
func (r *PostgreSQLPersonaRepository) Create(ctx context.Context, persona *model.Persona) error {
|
||
return r.db.WithContext(ctx).Create(persona).Error
|
||
}
|
||
|
||
// GetByID 根据 ID 获取人设
|
||
func (r *PostgreSQLPersonaRepository) GetByID(ctx context.Context, id uuid.UUID) (*model.Persona, error) {
|
||
var persona model.Persona
|
||
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&persona).Error; err != nil {
|
||
if err == gorm.ErrRecordNotFound {
|
||
return nil, model.ErrPersonaNotFound
|
||
}
|
||
return nil, fmt.Errorf("failed to get persona: %w", err)
|
||
}
|
||
return &persona, nil
|
||
}
|
||
|
||
// GetByUserID 获取用户的所有人设
|
||
func (r *PostgreSQLPersonaRepository) GetByUserID(ctx context.Context, userID int64) ([]model.Persona, error) {
|
||
var personas []model.Persona
|
||
if err := r.db.WithContext(ctx).
|
||
Where("user_id = ?", userID).
|
||
Order("created_at DESC").
|
||
Find(&personas).Error; err != nil {
|
||
return nil, fmt.Errorf("failed to get personas: %w", err)
|
||
}
|
||
return personas, nil
|
||
}
|
||
|
||
// GetDefaultByUserID 获取用户的默认人设
|
||
func (r *PostgreSQLPersonaRepository) GetDefaultByUserID(ctx context.Context, userID int64) (*model.Persona, error) {
|
||
var persona model.Persona
|
||
if err := r.db.WithContext(ctx).
|
||
Where("user_id = ? AND is_default = TRUE", userID).
|
||
First(&persona).Error; err != nil {
|
||
if err == gorm.ErrRecordNotFound {
|
||
return nil, model.ErrPersonaNotFound
|
||
}
|
||
return nil, fmt.Errorf("failed to get default persona: %w", err)
|
||
}
|
||
return &persona, nil
|
||
}
|
||
|
||
// Update 更新人设
|
||
func (r *PostgreSQLPersonaRepository) Update(ctx context.Context, persona *model.Persona) error {
|
||
return r.db.WithContext(ctx).Save(persona).Error
|
||
}
|
||
|
||
// Delete 删除人设
|
||
func (r *PostgreSQLPersonaRepository) Delete(ctx context.Context, id uuid.UUID) error {
|
||
return r.db.WithContext(ctx).Delete(&model.Persona{}, "id = ?", id).Error
|
||
}
|
||
|
||
// EnsureDefaultPersona 确保用户有默认人设
|
||
func (r *PostgreSQLPersonaRepository) EnsureDefaultPersona(ctx context.Context, userID int64) (*model.Persona, error) {
|
||
// 检查是否已有默认人设
|
||
persona, err := r.GetDefaultByUserID(ctx, userID)
|
||
if err == nil {
|
||
return persona, nil
|
||
}
|
||
if err != model.ErrPersonaNotFound {
|
||
return nil, err
|
||
}
|
||
|
||
// 创建默认人设
|
||
persona = &model.Persona{
|
||
UserID: userID,
|
||
Name: DefaultPersonaName,
|
||
Description: DefaultPersonaDescription,
|
||
SystemPrompt: DefaultSystemPrompt,
|
||
IsDefault: true,
|
||
}
|
||
|
||
if err := r.Create(ctx, persona); err != nil {
|
||
return nil, fmt.Errorf("failed to create default persona: %w", err)
|
||
}
|
||
|
||
return persona, nil
|
||
}
|