467 lines
13 KiB
Go
467 lines
13 KiB
Go
package controller
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"sync"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/google/uuid"
|
||
"go.uber.org/zap"
|
||
|
||
"github.com/topfans/backend/gateway/config"
|
||
"github.com/topfans/backend/gateway/pkg/response"
|
||
"github.com/topfans/backend/gateway/repository"
|
||
"github.com/topfans/backend/gateway/service"
|
||
"github.com/topfans/backend/pkg/database"
|
||
"github.com/topfans/backend/pkg/logger"
|
||
"github.com/topfans/backend/pkg/models"
|
||
)
|
||
|
||
// LaserGenerateController 镭射卡 AI 生成控制器
|
||
type LaserGenerateController struct {
|
||
minimaxClient *service.MinimaxClient
|
||
compositorClient *service.CompositorClient
|
||
jobStore *jobStore
|
||
laserRepo *repository.LaserCardRepository
|
||
}
|
||
|
||
// jobStore 内存中维护 job 状态
|
||
type jobStore struct {
|
||
mu sync.RWMutex
|
||
jobs map[string]*GenerateJob
|
||
}
|
||
|
||
// GenerateJob 生成任务
|
||
type GenerateJob struct {
|
||
ID string `json:"job_id"`
|
||
Status string `json:"status"`
|
||
Progress float64 `json:"progress"`
|
||
CreatedAt time.Time `json:"created_at"`
|
||
Variants []map[string]interface{} `json:"variants"`
|
||
CutoutURL string `json:"cutout_url"`
|
||
Warnings []string `json:"warnings"`
|
||
Error string `json:"error"`
|
||
// 持久化用
|
||
UserID int64 `json:"-"`
|
||
StarID int64 `json:"-"`
|
||
InstanceID int64 `json:"-"`
|
||
InstanceNo string `json:"-"`
|
||
}
|
||
|
||
var globalJobStore = &jobStore{
|
||
jobs: make(map[string]*GenerateJob),
|
||
}
|
||
|
||
// NewLaserGenerateController 创建控制器
|
||
func NewLaserGenerateController(cfg *config.Config) *LaserGenerateController {
|
||
return &LaserGenerateController{
|
||
minimaxClient: service.NewMinimaxClient(cfg.Minimax.APIURL, cfg.Minimax.APIKey),
|
||
compositorClient: service.NewCompositorClient(cfg.LaserCompositor.URL),
|
||
jobStore: globalJobStore,
|
||
laserRepo: repository.NewLaserCardRepository(database.GetDB()),
|
||
}
|
||
}
|
||
|
||
// Request-side types for laser card generation.
type (
	// variantConfig is the generation configuration of a single variant,
	// extracted from one entry of generateReq.RenderConfigs.
	variantConfig struct {
		PresetID      string                 `json:"preset_id"`
		BgPrompt      string                 `json:"bg_prompt"`
		OverlayPrompt string                 `json:"overlay_prompt"`
		GratingConfig map[string]interface{} `json:"grating_config"`
	}

	// generateReq is the JSON body of POST /api/v1/laser/generate.
	generateReq struct {
		CutoutURL     string                   `json:"cutout_url"`
		PresetCodes   []string                 `json:"preset_codes"`
		RenderConfigs []map[string]interface{} `json:"render_configs"`
		// UserPrompt is a user-supplied prompt merged into the AI image prompts.
		UserPrompt string `json:"user_prompt"`
	}
)
|
||
|
||
// CreateGenerateJob POST /api/v1/laser/generate
|
||
func (ctrl *LaserGenerateController) CreateGenerateJob(c *gin.Context) {
|
||
var req generateReq
|
||
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
response.BadRequest(c, "参数错误: "+err.Error())
|
||
return
|
||
}
|
||
|
||
if len(req.RenderConfigs) == 0 {
|
||
response.BadRequest(c, "render_configs 不能为空")
|
||
return
|
||
}
|
||
|
||
userID, ok := c.Get("user_id")
|
||
if !ok {
|
||
response.Unauthorized(c, "未登录")
|
||
return
|
||
}
|
||
starIDRaw, _ := c.Get("star_id")
|
||
var starID int64
|
||
if starIDRaw != nil {
|
||
starID = starIDRaw.(int64)
|
||
}
|
||
|
||
variants := parseVariants(req)
|
||
|
||
jobID := uuid.New().String()
|
||
|
||
// 持久化:先创建 laser_card_instances(status=rendered)
|
||
// 从 presets 中取第一个作为默认 template_code
|
||
templateCode := "default"
|
||
if len(req.PresetCodes) > 0 {
|
||
templateCode = req.PresetCodes[0]
|
||
}
|
||
var instanceID int64
|
||
var instanceNoResp string
|
||
if ctrl.laserRepo != nil {
|
||
renderedInst := &models.LaserCardInstance{
|
||
InstanceNo: "", // BeforeCreate 自动生成
|
||
InstanceUlid: "", // BeforeCreate 自动生成
|
||
TemplateID: 0, // 种子数据 id=1-5,按 code 查找
|
||
TemplateCode: templateCode,
|
||
TemplateVersion: 1,
|
||
OwnerUserID: userID.(int64),
|
||
StarID: starID,
|
||
Status: models.LaserCardInstanceStatusRendered,
|
||
MaterialsSnapshot: models.MaterialsSnapshot{},
|
||
}
|
||
// 尝试按 template_code 查到真实 template_id
|
||
if tpl, err := ctrl.laserRepo.FindTemplateByCode(templateCode); err == nil {
|
||
renderedInst.TemplateID = tpl.ID
|
||
}
|
||
|
||
if err := ctrl.laserRepo.CreateInstance(renderedInst); err != nil {
|
||
logger.Logger.Warn("Failed to persist laser_card_instance on create",
|
||
zap.String("job_id", jobID),
|
||
zap.Error(err),
|
||
)
|
||
} else {
|
||
instanceID = renderedInst.ID
|
||
instanceNoResp = renderedInst.InstanceNo
|
||
_ = ctrl.laserRepo.CreateOperationLogSimple(
|
||
renderedInst.ID, renderedInst.InstanceNo, userID.(int64),
|
||
models.LaserCardActionGenerateVariants, "", models.LaserCardInstanceStatusRendered,
|
||
)
|
||
logger.Logger.Info("Laser card instance persisted",
|
||
zap.Int64("instance_id", instanceID),
|
||
zap.String("instance_no", renderedInst.InstanceNo),
|
||
)
|
||
}
|
||
}
|
||
|
||
job := &GenerateJob{
|
||
ID: jobID,
|
||
Status: "processing",
|
||
Progress: 0,
|
||
CreatedAt: time.Now(),
|
||
CutoutURL: req.CutoutURL,
|
||
UserID: userID.(int64),
|
||
StarID: starID,
|
||
InstanceID: instanceID,
|
||
InstanceNo: instanceNoResp,
|
||
}
|
||
|
||
ctrl.jobStore.mu.Lock()
|
||
ctrl.jobStore.jobs[jobID] = job
|
||
ctrl.jobStore.mu.Unlock()
|
||
|
||
go ctrl.runParallelGeneration(userID.(int64), starID, jobID, req.CutoutURL, req.UserPrompt, variants)
|
||
|
||
response.Success(c, gin.H{
|
||
"job_id": jobID,
|
||
"estimated_seconds": 90,
|
||
"status": "processing",
|
||
"variant_count": len(variants),
|
||
"instance_no": instanceNoResp,
|
||
})
|
||
}
|
||
|
||
// GetGenerateJob GET /api/v1/laser/generate/:id
|
||
func (ctrl *LaserGenerateController) GetGenerateJob(c *gin.Context) {
|
||
jobID := c.Param("id")
|
||
|
||
ctrl.jobStore.mu.RLock()
|
||
job, exists := ctrl.jobStore.jobs[jobID]
|
||
ctrl.jobStore.mu.RUnlock()
|
||
|
||
if !exists {
|
||
response.NotFound(c, "任务不存在")
|
||
return
|
||
}
|
||
|
||
resp := gin.H{
|
||
"status": job.Status,
|
||
"progress": job.Progress,
|
||
}
|
||
|
||
if job.Status == "succeeded" {
|
||
resp["variants"] = job.Variants
|
||
resp["cutout_url"] = job.CutoutURL
|
||
resp["warnings"] = job.Warnings
|
||
if job.InstanceNo != "" {
|
||
resp["instance_no"] = job.InstanceNo
|
||
}
|
||
} else if job.Status == "failed" {
|
||
resp["error"] = job.Error
|
||
}
|
||
|
||
response.Success(c, resp)
|
||
}
|
||
|
||
func parseVariants(req generateReq) []variantConfig {
|
||
configMap := make(map[string]variantConfig)
|
||
for _, rc := range req.RenderConfigs {
|
||
if rc == nil {
|
||
continue
|
||
}
|
||
pid, _ := rc["preset_id"].(string)
|
||
bg, _ := rc["bg_prompt"].(string)
|
||
ov, _ := rc["overlay_prompt"].(string)
|
||
gc, _ := rc["grating_config"].(map[string]interface{})
|
||
if pid != "" {
|
||
configMap[pid] = variantConfig{
|
||
PresetID: pid,
|
||
BgPrompt: bg,
|
||
OverlayPrompt: ov,
|
||
GratingConfig: gc,
|
||
}
|
||
}
|
||
}
|
||
|
||
var variants []variantConfig
|
||
for _, pc := range req.PresetCodes {
|
||
if v, ok := configMap[pc]; ok {
|
||
variants = append(variants, v)
|
||
}
|
||
}
|
||
if len(variants) == 0 {
|
||
for _, v := range configMap {
|
||
variants = append(variants, v)
|
||
}
|
||
}
|
||
return variants
|
||
}
|
||
|
||
// runParallelGeneration 并行生成 5 个 variant:MiniMax(背景+装饰) → compositor 合成
|
||
func (ctrl *LaserGenerateController) runParallelGeneration(userID, starID int64, jobID, cutoutURL, userPrompt string, variants []variantConfig) {
|
||
ctx := context.Background()
|
||
|
||
logger.Logger.Info("runParallelGeneration start",
|
||
zap.String("job_id", jobID),
|
||
zap.Int("variant_count", len(variants)),
|
||
zap.String("cutout_url_prefix", safePrefix(cutoutURL, 60)),
|
||
zap.String("user_prompt", safePrefix(userPrompt, 60)),
|
||
)
|
||
|
||
var wg sync.WaitGroup
|
||
type variantResult struct {
|
||
PresetID string
|
||
Data map[string]interface{}
|
||
Err string
|
||
}
|
||
|
||
resultCh := make(chan variantResult, len(variants))
|
||
total := len(variants)
|
||
var completed int32
|
||
|
||
updateProgress := func(done int) {
|
||
pid := float64(done) / float64(total)
|
||
if pid > 0.95 {
|
||
pid = 0.95
|
||
}
|
||
ctrl.jobStore.mu.Lock()
|
||
if j, ok := ctrl.jobStore.jobs[jobID]; ok {
|
||
j.Progress = pid
|
||
}
|
||
ctrl.jobStore.mu.Unlock()
|
||
}
|
||
|
||
for i, v := range variants {
|
||
wg.Add(1)
|
||
go func(idx int, vc variantConfig) {
|
||
defer wg.Done()
|
||
|
||
ossKey := fmt.Sprintf("%s_%s", vc.PresetID, jobID[:8])
|
||
|
||
// 合并用户 prompt 到背景/装饰模板
|
||
finalBgPrompt := service.BuildBgPrompt(vc.BgPrompt, userPrompt)
|
||
finalOverlayPrompt := service.BuildOverlayPrompt(vc.OverlayPrompt, userPrompt)
|
||
|
||
logger.Logger.Info("Prompt merged",
|
||
zap.String("preset", vc.PresetID),
|
||
zap.String("bg_full", safePrefix(finalBgPrompt, 80)),
|
||
zap.String("ol_full", safePrefix(finalOverlayPrompt, 80)),
|
||
)
|
||
|
||
// Step 1: MiniMax 生成背景图(纯文生图,不含人物)
|
||
bgURL, err := ctrl.minimaxClient.GenerateImage(ctx, finalBgPrompt)
|
||
if err != nil {
|
||
logger.Logger.Error("MiniMax bg failed",
|
||
zap.String("preset", vc.PresetID),
|
||
zap.Error(err),
|
||
)
|
||
resultCh <- variantResult{PresetID: vc.PresetID, Err: fmt.Sprintf("背景生成失败: %v", err)}
|
||
done := int(atomic.AddInt32(&completed, 1))
|
||
updateProgress(done)
|
||
return
|
||
}
|
||
|
||
// Step 2: MiniMax 生成装饰图
|
||
overlayURL, err := ctrl.minimaxClient.GenerateImage(ctx, finalOverlayPrompt)
|
||
if err != nil {
|
||
logger.Logger.Error("MiniMax overlay failed",
|
||
zap.String("preset", vc.PresetID),
|
||
zap.Error(err),
|
||
)
|
||
// 装饰图失败不致命,继续合成
|
||
overlayURL = ""
|
||
}
|
||
|
||
// Step 3: compositor 6 层合成(人像由 compositor 叠加在金属层之上)
|
||
logger.Logger.Info("Compositor request",
|
||
zap.Int("idx", idx),
|
||
zap.String("preset", vc.PresetID),
|
||
zap.String("bg_url_prefix", safePrefix(bgURL, 50)),
|
||
zap.String("cutout_url_prefix", safePrefix(cutoutURL, 50)),
|
||
zap.String("overlay_url_prefix", safePrefix(overlayURL, 50)),
|
||
)
|
||
compReq := service.ComposeRequest{
|
||
BackgroundURL: bgURL,
|
||
CutoutURL: cutoutURL,
|
||
OverlayURL: overlayURL,
|
||
GratingConfig: vc.GratingConfig,
|
||
ExportWidth: 450,
|
||
ExportHeight: 600,
|
||
VariantIndex: idx,
|
||
OutputOSSKey: ossKey,
|
||
}
|
||
|
||
compResp, err := ctrl.compositorClient.Compose(ctx, compReq)
|
||
if err != nil {
|
||
logger.Logger.Error("Compositor failed",
|
||
zap.String("preset", vc.PresetID),
|
||
zap.Error(err),
|
||
)
|
||
resultCh <- variantResult{PresetID: vc.PresetID, Err: fmt.Sprintf("合成失败: %v", err)}
|
||
done := int(atomic.AddInt32(&completed, 1))
|
||
updateProgress(done)
|
||
return
|
||
}
|
||
|
||
data := map[string]interface{}{
|
||
"preset_id": vc.PresetID,
|
||
"oss_key": compResp.OSSKey,
|
||
"signed_url": compResp.SignedURL,
|
||
"width": compResp.Width,
|
||
"height": compResp.Height,
|
||
}
|
||
if compResp.Warning != "" {
|
||
data["warning"] = compResp.Warning
|
||
}
|
||
|
||
resultCh <- variantResult{PresetID: vc.PresetID, Data: data}
|
||
done := int(atomic.AddInt32(&completed, 1))
|
||
updateProgress(done)
|
||
}(i, v)
|
||
}
|
||
|
||
wg.Wait()
|
||
close(resultCh)
|
||
|
||
ctrl.jobStore.mu.Lock()
|
||
defer ctrl.jobStore.mu.Unlock()
|
||
|
||
job, ok := ctrl.jobStore.jobs[jobID]
|
||
if !ok {
|
||
return
|
||
}
|
||
|
||
var warnings []string
|
||
hasError := false
|
||
|
||
for r := range resultCh {
|
||
if r.Err != "" {
|
||
warnings = append(warnings, fmt.Sprintf("%s: %s", r.PresetID, r.Err))
|
||
hasError = true
|
||
} else if r.Data != nil {
|
||
job.Variants = append(job.Variants, r.Data)
|
||
}
|
||
}
|
||
|
||
job.Warnings = warnings
|
||
|
||
if hasError && len(job.Variants) == 0 {
|
||
job.Status = "failed"
|
||
if len(warnings) > 0 {
|
||
job.Error = warnings[0]
|
||
}
|
||
} else {
|
||
job.Status = "succeeded"
|
||
job.Progress = 1.0
|
||
// 持久化:更新 materials_snapshot
|
||
if ctrl.laserRepo != nil && job.InstanceID > 0 {
|
||
snapshot := make(models.MaterialsSnapshot, 0, len(job.Variants)+1)
|
||
// 用户原图/抠图记录
|
||
if job.CutoutURL != "" {
|
||
snapshot = append(snapshot, models.MaterialSnapshotItem{
|
||
Role: "cutout",
|
||
OssKey: job.CutoutURL,
|
||
})
|
||
}
|
||
for _, v := range job.Variants {
|
||
ossKey, _ := v["oss_key"].(string)
|
||
presetID, _ := v["preset_id"].(string)
|
||
if ossKey != "" {
|
||
snapshot = append(snapshot, models.MaterialSnapshotItem{
|
||
Role: "composite",
|
||
OssKey: ossKey,
|
||
PresetID: presetID,
|
||
})
|
||
}
|
||
}
|
||
if err := ctrl.laserRepo.UpdateMaterialsSnapshot(job.InstanceID, snapshot); err != nil {
|
||
logger.Logger.Warn("Failed to update materials_snapshot",
|
||
zap.Int64("instance_id", job.InstanceID),
|
||
zap.Error(err),
|
||
)
|
||
}
|
||
_ = ctrl.laserRepo.CreateOperationLogSimple(
|
||
job.InstanceID, job.InstanceNo, job.UserID,
|
||
models.LaserCardActionGenerateVariants,
|
||
models.LaserCardInstanceStatusRendered,
|
||
"",
|
||
)
|
||
}
|
||
}
|
||
|
||
logger.Logger.Info("Generate job completed",
|
||
zap.String("job_id", jobID),
|
||
zap.Int("variants_count", len(job.Variants)),
|
||
zap.Int("warnings", len(warnings)),
|
||
)
|
||
}
|
||
|
||
// marshalToString returns the JSON encoding of v as a string.
// A nil v and an encoding failure both yield "", so callers cannot tell
// those two cases apart.
func marshalToString(v interface{}) string {
	if v != nil {
		if encoded, err := json.Marshal(v); err == nil {
			return string(encoded)
		}
	}
	return ""
}
|
||
|
||
// safePrefix shortens s for logging. It returns s unchanged when it is at most
// n bytes long. Otherwise it returns at most n bytes followed by "...".
//
// The cut is moved back to a UTF-8 rune boundary. Slicing mid-rune, which is
// easy with CJK prompts, would write invalid UTF-8 into the logs. A negative n
// is treated as 0 instead of panicking.
func safePrefix(s string, n int) string {
	if len(s) <= n {
		return s
	}
	if n < 0 {
		n = 0
	}
	cut := n
	// s[cut] exists because len(s) > n. Step back while it is a continuation
	// byte (10xxxxxx) so the prefix ends on a complete rune.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
|