topfans/backend/gateway/controller/laser_generate_controller.go
2026-06-03 22:19:22 +08:00

467 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package controller
import (
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
"github.com/topfans/backend/gateway/config"
"github.com/topfans/backend/gateway/pkg/response"
"github.com/topfans/backend/gateway/repository"
"github.com/topfans/backend/gateway/service"
"github.com/topfans/backend/pkg/database"
"github.com/topfans/backend/pkg/logger"
"github.com/topfans/backend/pkg/models"
)
// LaserGenerateController serves the laser-card AI generation endpoints:
// creating a generation job (CreateGenerateJob) and polling it (GetGenerateJob).
type LaserGenerateController struct {
	minimaxClient    *service.MinimaxClient            // generates background and overlay images
	compositorClient *service.CompositorClient         // composes the final card from the generated layers
	jobStore         *jobStore                         // in-memory job state; the process-wide globalJobStore
	laserRepo        *repository.LaserCardRepository // persistence; handlers check it for nil before use
}
// jobStore keeps generation job state in memory, keyed by job ID.
//
// mu guards both the jobs map and the mutable fields (Status, Progress,
// Variants, ...) of every *GenerateJob stored in it. The state is
// process-local and is lost when the process restarts.
type jobStore struct {
	mu   sync.RWMutex
	jobs map[string]*GenerateJob
}
// GenerateJob is one laser-card generation task tracked in jobStore.
// Its mutable fields must only be accessed while holding jobStore.mu.
type GenerateJob struct {
	ID        string                   `json:"job_id"`
	Status    string                   `json:"status"`   // "processing", "succeeded" or "failed"
	Progress  float64                  `json:"progress"` // fraction of finished variants, in [0, 1]
	CreatedAt time.Time                `json:"created_at"`
	Variants  []map[string]interface{} `json:"variants"` // one payload per successfully composed variant
	CutoutURL string                   `json:"cutout_url"`
	Warnings  []string                 `json:"warnings"` // "<preset_id>: <message>" for each failed variant
	Error     string                   `json:"error"`    // set when Status is "failed"
	// Used for persistence only; never serialized to JSON.
	UserID     int64  `json:"-"`
	StarID     int64  `json:"-"`
	InstanceID int64  `json:"-"` // laser_card_instances row ID; 0 when no instance was persisted
	InstanceNo string `json:"-"`
}
// globalJobStore is the single in-process job store handed to every
// LaserGenerateController, so a job registered by the create handler is
// visible to the polling handler.
var globalJobStore = &jobStore{jobs: map[string]*GenerateJob{}}
// NewLaserGenerateController builds the controller: MiniMax and compositor
// clients configured from cfg, the process-wide globalJobStore, and a
// laser-card repository bound to the handle returned by database.GetDB.
func NewLaserGenerateController(cfg *config.Config) *LaserGenerateController {
	minimax := service.NewMinimaxClient(cfg.Minimax.APIURL, cfg.Minimax.APIKey)
	compositor := service.NewCompositorClient(cfg.LaserCompositor.URL)
	store := globalJobStore
	repo := repository.NewLaserCardRepository(database.GetDB())

	ctrl := new(LaserGenerateController)
	ctrl.minimaxClient = minimax
	ctrl.compositorClient = compositor
	ctrl.jobStore = store
	ctrl.laserRepo = repo
	return ctrl
}
// variantConfig is the generation config for a single variant, parsed from
// one entry of the request's render_configs.
type variantConfig struct {
	PresetID      string                 `json:"preset_id"`
	BgPrompt      string                 `json:"bg_prompt"`      // background prompt template, merged with the user prompt
	OverlayPrompt string                 `json:"overlay_prompt"` // overlay prompt template, merged with the user prompt
	GratingConfig map[string]interface{} `json:"grating_config"` // passed through to the compositor as-is
}
// generateReq is the request body of POST /api/v1/laser/generate.
type generateReq struct {
	CutoutURL     string                   `json:"cutout_url"`     // cutout image handed to the compositor
	PresetCodes   []string                 `json:"preset_codes"`   // selects and orders render_configs by preset_id
	RenderConfigs []map[string]interface{} `json:"render_configs"` // one entry per preset; must be non-empty
	UserPrompt    string                   `json:"user_prompt"`    // user-defined prompt injected into AI image generation
}
// CreateGenerateJob handles POST /api/v1/laser/generate.
//
// It validates the request body and persists a laser_card_instances row in
// the "rendered" status. Persistence is best-effort: on failure the job still
// runs, just without an instance. It then registers an in-memory job and
// starts runParallelGeneration in the background. The response returns
// immediately with the job ID that clients poll via
// GET /api/v1/laser/generate/:id.
func (ctrl *LaserGenerateController) CreateGenerateJob(c *gin.Context) {
	// Jobs older than this are evicted from the in-memory store whenever a new
	// job is registered. Nothing else removes them, so without this the map
	// would grow for the life of the process.
	const jobRetention = 24 * time.Hour

	var req generateReq
	if err := c.ShouldBindJSON(&req); err != nil {
		response.BadRequest(c, "参数错误: "+err.Error())
		return
	}
	if len(req.RenderConfigs) == 0 {
		response.BadRequest(c, "render_configs 不能为空")
		return
	}

	rawUserID, ok := c.Get("user_id")
	if !ok {
		response.Unauthorized(c, "未登录")
		return
	}
	// Checked assertion: an unexpected context value must not panic the handler.
	userID, ok := rawUserID.(int64)
	if !ok {
		response.Unauthorized(c, "未登录")
		return
	}
	var starID int64
	if rawStarID, exists := c.Get("star_id"); exists {
		starID, _ = rawStarID.(int64) // absent or non-int64 means "no star" (0)
	}

	variants := parseVariants(req)
	if len(variants) == 0 {
		// No render config carried a preset_id, so nothing can be generated.
		// Reject before persisting an instance for a job that could never
		// produce output.
		response.BadRequest(c, "render_configs 中没有有效的 preset_id")
		return
	}

	jobID := uuid.New().String()
	instanceID, instanceNo := ctrl.createRenderedInstance(jobID, userID, starID, req.PresetCodes)

	now := time.Now()
	job := &GenerateJob{
		ID:         jobID,
		Status:     "processing",
		Progress:   0,
		CreatedAt:  now,
		CutoutURL:  req.CutoutURL,
		UserID:     userID,
		StarID:     starID,
		InstanceID: instanceID,
		InstanceNo: instanceNo,
	}

	ctrl.jobStore.mu.Lock()
	for id, old := range ctrl.jobStore.jobs {
		if now.Sub(old.CreatedAt) > jobRetention {
			delete(ctrl.jobStore.jobs, id) // deleting during range is permitted in Go
		}
	}
	ctrl.jobStore.jobs[jobID] = job
	ctrl.jobStore.mu.Unlock()

	go ctrl.runParallelGeneration(userID, starID, jobID, req.CutoutURL, req.UserPrompt, variants)

	response.Success(c, gin.H{
		"job_id":            jobID,
		"estimated_seconds": 90,
		"status":            "processing",
		"variant_count":     len(variants),
		"instance_no":       instanceNo,
	})
}

// createRenderedInstance persists the laser_card_instances row for a new job in
// the "rendered" status and records a generate-variants operation log.
// It is best-effort. Without a repository, or when the insert fails, it
// returns (0, "") and the job proceeds without a persisted instance.
func (ctrl *LaserGenerateController) createRenderedInstance(jobID string, userID, starID int64, presetCodes []string) (int64, string) {
	if ctrl.laserRepo == nil {
		return 0, ""
	}

	// The first requested preset doubles as the instance's template code.
	templateCode := "default"
	if len(presetCodes) > 0 {
		templateCode = presetCodes[0]
	}

	inst := &models.LaserCardInstance{
		InstanceNo:        "", // generated in BeforeCreate
		InstanceUlid:      "", // generated in BeforeCreate
		TemplateID:        0,  // seed data uses ids 1-5; resolved by code below
		TemplateCode:      templateCode,
		TemplateVersion:   1,
		OwnerUserID:       userID,
		StarID:            starID,
		Status:            models.LaserCardInstanceStatusRendered,
		MaterialsSnapshot: models.MaterialsSnapshot{},
	}
	// Resolve the real template_id from template_code when the template exists.
	if tpl, err := ctrl.laserRepo.FindTemplateByCode(templateCode); err == nil {
		inst.TemplateID = tpl.ID
	}

	if err := ctrl.laserRepo.CreateInstance(inst); err != nil {
		logger.Logger.Warn("Failed to persist laser_card_instance on create",
			zap.String("job_id", jobID),
			zap.Error(err),
		)
		return 0, ""
	}

	if err := ctrl.laserRepo.CreateOperationLogSimple(
		inst.ID, inst.InstanceNo, userID,
		models.LaserCardActionGenerateVariants, "", models.LaserCardInstanceStatusRendered,
	); err != nil {
		logger.Logger.Warn("Failed to write laser card operation log",
			zap.Int64("instance_id", inst.ID),
			zap.Error(err),
		)
	}
	logger.Logger.Info("Laser card instance persisted",
		zap.Int64("instance_id", inst.ID),
		zap.String("instance_no", inst.InstanceNo),
	)
	return inst.ID, inst.InstanceNo
}
// GetGenerateJob handles GET /api/v1/laser/generate/:id.
//
// It always reports the job's status and progress. Once the job has
// succeeded it adds variants, cutout_url, warnings and (when persisted)
// instance_no; once it has failed it adds error. Unknown job IDs, and jobs
// owned by a different authenticated user, answer 404.
func (ctrl *LaserGenerateController) GetGenerateJob(c *gin.Context) {
	jobID := c.Param("id")

	// Snapshot the job while holding the read lock. runParallelGeneration
	// updates these fields under the write lock, so reading them after
	// RUnlock would be a data race.
	var (
		resp    gin.H
		ownerID int64
	)
	ctrl.jobStore.mu.RLock()
	job, exists := ctrl.jobStore.jobs[jobID]
	if exists {
		ownerID = job.UserID
		resp = gin.H{
			"status":   job.Status,
			"progress": job.Progress,
		}
		switch job.Status {
		case "succeeded":
			resp["variants"] = job.Variants
			resp["cutout_url"] = job.CutoutURL
			resp["warnings"] = job.Warnings
			if job.InstanceNo != "" {
				resp["instance_no"] = job.InstanceNo
			}
		case "failed":
			resp["error"] = job.Error
		}
	}
	ctrl.jobStore.mu.RUnlock()

	// The job ID comes from the URL, so don't let one user read another
	// user's job. Answer 404 rather than 403 so the job's existence is not
	// revealed.
	if exists {
		if raw, ok := c.Get("user_id"); ok {
			if uid, ok := raw.(int64); ok && uid != ownerID {
				exists = false
			}
		}
	}
	if !exists {
		response.NotFound(c, "任务不存在")
		return
	}
	response.Success(c, resp)
}
// parseVariants turns the request's render_configs into the ordered list of
// variants to generate.
//
// nil entries and entries without a string preset_id are skipped. When a
// preset_id repeats, its last config wins. If preset_codes selects any
// configured preset, the variants follow preset_codes order, and codes
// without a matching config are skipped. Otherwise every configured preset
// is used, in the order it first appears in render_configs. This keeps the
// result deterministic, unlike ranging over a map.
func parseVariants(req generateReq) []variantConfig {
	configs := make(map[string]variantConfig, len(req.RenderConfigs))
	var order []string // preset IDs in first-appearance order
	for _, rc := range req.RenderConfigs {
		if rc == nil {
			continue
		}
		pid, _ := rc["preset_id"].(string)
		if pid == "" {
			continue
		}
		bg, _ := rc["bg_prompt"].(string)
		ov, _ := rc["overlay_prompt"].(string)
		gc, _ := rc["grating_config"].(map[string]interface{})
		if _, seen := configs[pid]; !seen {
			order = append(order, pid)
		}
		configs[pid] = variantConfig{
			PresetID:      pid,
			BgPrompt:      bg,
			OverlayPrompt: ov,
			GratingConfig: gc,
		}
	}

	var variants []variantConfig
	for _, code := range req.PresetCodes {
		if v, ok := configs[code]; ok {
			variants = append(variants, v)
		}
	}
	if len(variants) == 0 {
		for _, pid := range order {
			variants = append(variants, configs[pid])
		}
	}
	return variants
}
// runParallelGeneration generates every variant concurrently and records the
// outcome on the in-memory job identified by jobID. Each variant goes through
// a MiniMax background and overlay, then a compositor pass.
//
// It runs on the goroutine started by CreateGenerateJob. Progress is published
// as variants finish, capped at 0.95 until the job is finalized. Variants are
// stored in request order. When at least one variant succeeds, the instance's
// materials_snapshot is persisted BEFORE the job is marked "succeeded", so a
// client that observes success sees an up-to-date database row. That database
// I/O runs without holding the job-store lock, so polling is never blocked on
// it. A job with no successful variant is marked "failed".
//
// userID and starID are currently unused; persistence uses the job's own fields.
func (ctrl *LaserGenerateController) runParallelGeneration(userID, starID int64, jobID, cutoutURL, userPrompt string, variants []variantConfig) {
	// Upper bound for the whole pipeline. It keeps a hung upstream call from
	// leaving the job, and its goroutines, "processing" forever. The estimate
	// advertised to clients is 90s.
	const generationTimeout = 10 * time.Minute
	ctx, cancel := context.WithTimeout(context.Background(), generationTimeout)
	defer cancel()

	logger.Logger.Info("runParallelGeneration start",
		zap.String("job_id", jobID),
		zap.Int("variant_count", len(variants)),
		zap.String("cutout_url_prefix", safePrefix(cutoutURL, 60)),
		zap.String("user_prompt", safePrefix(userPrompt, 60)),
	)

	type variantResult struct {
		PresetID string
		Data     map[string]interface{}
		Err      string
	}
	total := len(variants)
	// results[i] is written only by the goroutine for variants[i]. The slots
	// are disjoint, so no lock is needed, and the final order matches the
	// request order.
	results := make([]variantResult, total)
	var completed int32

	updateProgress := func(done int) {
		p := float64(done) / float64(total)
		if p > 0.95 {
			p = 0.95
		}
		ctrl.jobStore.mu.Lock()
		// Completions can publish out of order; never move progress backwards.
		if j, ok := ctrl.jobStore.jobs[jobID]; ok && p > j.Progress {
			j.Progress = p
		}
		ctrl.jobStore.mu.Unlock()
	}

	var wg sync.WaitGroup
	for i, v := range variants {
		wg.Add(1)
		go func(idx int, vc variantConfig) {
			defer wg.Done()
			defer func() {
				// A panic in a detached goroutine would take down the whole
				// gateway; record it as a failed variant instead.
				if rec := recover(); rec != nil {
					logger.Logger.Error("Variant generation panicked",
						zap.String("preset", vc.PresetID),
						zap.Any("panic", rec),
					)
					results[idx] = variantResult{PresetID: vc.PresetID, Err: fmt.Sprintf("内部错误: %v", rec)}
				}
				updateProgress(int(atomic.AddInt32(&completed, 1)))
			}()
			data, errMsg := ctrl.generateLaserVariant(ctx, jobID, cutoutURL, userPrompt, idx, vc)
			results[idx] = variantResult{PresetID: vc.PresetID, Data: data, Err: errMsg}
		}(i, v)
	}
	wg.Wait()

	var (
		produced []map[string]interface{}
		warnings []string
	)
	for _, r := range results {
		switch {
		case r.Err != "":
			warnings = append(warnings, fmt.Sprintf("%s: %s", r.PresetID, r.Err))
		case r.Data != nil:
			produced = append(produced, r.Data)
		}
	}

	// The instance identity is fixed when the job is created; copy it under the lock.
	var (
		instanceID int64
		instanceNo string
		ownerID    int64
		jobCutout  string
	)
	ctrl.jobStore.mu.RLock()
	job, ok := ctrl.jobStore.jobs[jobID]
	if ok {
		instanceID, instanceNo, ownerID, jobCutout = job.InstanceID, job.InstanceNo, job.UserID, job.CutoutURL
	}
	ctrl.jobStore.mu.RUnlock()
	if !ok {
		return
	}

	if len(produced) > 0 {
		ctrl.persistLaserJobMaterials(instanceID, instanceNo, ownerID, jobCutout, produced)
	}

	ctrl.jobStore.mu.Lock()
	job.Variants = produced
	job.Warnings = warnings
	if len(produced) > 0 {
		job.Status = "succeeded"
		job.Progress = 1.0
	} else {
		job.Status = "failed"
		job.Error = "未生成任何 variant"
		if len(warnings) > 0 {
			job.Error = warnings[0]
		}
	}
	ctrl.jobStore.mu.Unlock()

	logger.Logger.Info("Generate job completed",
		zap.String("job_id", jobID),
		zap.Int("variants_count", len(produced)),
		zap.Int("warnings", len(warnings)),
	)
}

// generateLaserVariant runs the pipeline for a single variant:
//   - a MiniMax background image (required);
//   - a MiniMax overlay image (best-effort; on failure the compositor gets an
//     empty overlay URL);
//   - a compositor pass that layers the cutout on top.
//
// It returns the variant payload, or a non-empty error message when the
// background or compositing step fails.
func (ctrl *LaserGenerateController) generateLaserVariant(ctx context.Context, jobID, cutoutURL, userPrompt string, idx int, vc variantConfig) (map[string]interface{}, string) {
	shortJobID := jobID
	if len(shortJobID) > 8 {
		shortJobID = shortJobID[:8]
	}
	ossKey := fmt.Sprintf("%s_%s", vc.PresetID, shortJobID)

	// Merge the user's prompt into the preset's background/overlay templates.
	finalBgPrompt := service.BuildBgPrompt(vc.BgPrompt, userPrompt)
	finalOverlayPrompt := service.BuildOverlayPrompt(vc.OverlayPrompt, userPrompt)
	logger.Logger.Info("Prompt merged",
		zap.String("preset", vc.PresetID),
		zap.String("bg_full", safePrefix(finalBgPrompt, 80)),
		zap.String("ol_full", safePrefix(finalOverlayPrompt, 80)),
	)

	// Step 1: MiniMax background image (pure text-to-image, no person).
	bgURL, err := ctrl.minimaxClient.GenerateImage(ctx, finalBgPrompt)
	if err != nil {
		logger.Logger.Error("MiniMax bg failed",
			zap.String("preset", vc.PresetID),
			zap.Error(err),
		)
		return nil, fmt.Sprintf("背景生成失败: %v", err)
	}

	// Step 2: MiniMax overlay image. Not fatal: composite without it on failure.
	overlayURL, err := ctrl.minimaxClient.GenerateImage(ctx, finalOverlayPrompt)
	if err != nil {
		logger.Logger.Error("MiniMax overlay failed",
			zap.String("preset", vc.PresetID),
			zap.Error(err),
		)
		overlayURL = ""
	}

	// Step 3: 6-layer compositor pass (the portrait is layered above the metal layer).
	logger.Logger.Info("Compositor request",
		zap.Int("idx", idx),
		zap.String("preset", vc.PresetID),
		zap.String("bg_url_prefix", safePrefix(bgURL, 50)),
		zap.String("cutout_url_prefix", safePrefix(cutoutURL, 50)),
		zap.String("overlay_url_prefix", safePrefix(overlayURL, 50)),
	)
	compResp, err := ctrl.compositorClient.Compose(ctx, service.ComposeRequest{
		BackgroundURL: bgURL,
		CutoutURL:     cutoutURL,
		OverlayURL:    overlayURL,
		GratingConfig: vc.GratingConfig,
		ExportWidth:   450,
		ExportHeight:  600,
		VariantIndex:  idx,
		OutputOSSKey:  ossKey,
	})
	if err != nil {
		logger.Logger.Error("Compositor failed",
			zap.String("preset", vc.PresetID),
			zap.Error(err),
		)
		return nil, fmt.Sprintf("合成失败: %v", err)
	}

	data := map[string]interface{}{
		"preset_id":  vc.PresetID,
		"oss_key":    compResp.OSSKey,
		"signed_url": compResp.SignedURL,
		"width":      compResp.Width,
		"height":     compResp.Height,
	}
	if compResp.Warning != "" {
		data["warning"] = compResp.Warning
	}
	return data, ""
}

// persistLaserJobMaterials writes the materials_snapshot to the instance row
// and appends a generate-variants operation log. The snapshot holds the
// cutout plus one "composite" item per variant that has an oss_key.
// Failures are only logged, because the in-memory job result remains usable.
// It is a no-op without a repository or a persisted instance.
func (ctrl *LaserGenerateController) persistLaserJobMaterials(instanceID int64, instanceNo string, userID int64, cutoutURL string, variants []map[string]interface{}) {
	if ctrl.laserRepo == nil || instanceID <= 0 {
		return
	}

	snapshot := make(models.MaterialsSnapshot, 0, len(variants)+1)
	// Record of the user's original/cutout image.
	// NOTE(review): this stores the cutout URL in OssKey; confirm that is intended.
	if cutoutURL != "" {
		snapshot = append(snapshot, models.MaterialSnapshotItem{
			Role:   "cutout",
			OssKey: cutoutURL,
		})
	}
	for _, v := range variants {
		ossKey, _ := v["oss_key"].(string)
		presetID, _ := v["preset_id"].(string)
		if ossKey == "" {
			continue
		}
		snapshot = append(snapshot, models.MaterialSnapshotItem{
			Role:     "composite",
			OssKey:   ossKey,
			PresetID: presetID,
		})
	}

	if err := ctrl.laserRepo.UpdateMaterialsSnapshot(instanceID, snapshot); err != nil {
		logger.Logger.Warn("Failed to update materials_snapshot",
			zap.Int64("instance_id", instanceID),
			zap.Error(err),
		)
	}
	if err := ctrl.laserRepo.CreateOperationLogSimple(
		instanceID, instanceNo, userID,
		models.LaserCardActionGenerateVariants,
		models.LaserCardInstanceStatusRendered,
		"",
	); err != nil {
		logger.Logger.Warn("Failed to write laser card operation log",
			zap.Int64("instance_id", instanceID),
			zap.Error(err),
		)
	}
}
// marshalToString encodes v as compact JSON. It yields "" for a nil interface
// value and for any value encoding/json cannot marshal.
func marshalToString(v interface{}) string {
	if v != nil {
		if encoded, err := json.Marshal(v); err == nil {
			return string(encoded)
		}
	}
	return ""
}
// safePrefix shortens s for logging. Strings of at most n bytes are returned
// unchanged; longer ones are cut to at most n bytes and suffixed with "...".
//
// The cut is moved back to a rune boundary, so multi-byte UTF-8 text (such as
// Chinese user prompts) is never split mid-character. A negative n yields
// "..." instead of panicking.
func safePrefix(s string, n int) string {
	if len(s) <= n {
		return s
	}
	// Ranging over a string visits the byte offset of each rune start; keep
	// the last one that does not exceed n.
	cut := 0
	for i := range s {
		if i > n {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}