package service

import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"image"
	"image/gif"
	"image/jpeg"
	"image/png"
	"io"
	"net"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/google/uuid"
	"github.com/nfnt/resize"
	dto "github.com/topfans/backend/gateway/dto"
	"github.com/topfans/backend/services/assetService/config"
	"go.uber.org/zap"
)

// JobStatus is the lifecycle state of an image generation job.
type JobStatus string

// Job states. StatusPending is declared for API completeness; CreateJob
// currently starts jobs directly in StatusProcessing.
const (
	StatusPending    JobStatus = "PENDING"
	StatusProcessing JobStatus = "PROCESSING"
	StatusCompleted  JobStatus = "COMPLETED"
	StatusFailed     JobStatus = "FAILED"
)

const (
	// maxImageDownloadBytes caps how much of a reference image is downloaded.
	maxImageDownloadBytes = 10 * 1024 * 1024
	// maxImagePixels bounds width*height before a full decode: a small file can
	// declare enormous dimensions and make image.Decode allocate a huge buffer.
	maxImagePixels = 50 * 1000 * 1000
	// maxImageDimension is the longest side, in pixels, of a reference image
	// after downscaling.
	maxImageDimension = 1024
	// maxErrorBodyBytes caps how much of an upstream error body is copied into
	// a job's ErrorMsg.
	maxErrorBodyBytes = 4 * 1024
	// maxImageRedirects mirrors net/http's default redirect limit.
	maxImageRedirects = 10

	apiTimeout           = 120 * time.Second
	imageDownloadTimeout = 30 * time.Second
	jobRetention         = 24 * time.Hour
	cleanupInterval      = time.Hour
)

// ImageGenerationJob is an asynchronous image-to-image generation job.
type ImageGenerationJob struct {
	JobID       string                      `json:"job_id"`
	UserID      int64                       `json:"user_id"`
	StarID      int64                       `json:"star_id"`
	Status      JobStatus                   `json:"status"`
	Progress    int                         `json:"progress"`
	Images      []string                    `json:"images,omitempty"`
	ErrorMsg    string                      `json:"error_msg,omitempty"`
	Request     *dto.ImageGenerationRequest `json:"request,omitempty"`
	CreatedAt   int64                       `json:"created_at"`
	UpdatedAt   int64                       `json:"updated_at"`
	CompletedAt int64                       `json:"completed_at,omitempty"`
}

// MinimaxService forwards image generation requests to the MiniMax API and
// tracks their progress as in-memory jobs.
type MinimaxService interface {
	// CreateJob registers a job and starts processing it in the background.
	// The returned job is a snapshot; poll GetJob for progress.
	CreateJob(ctx context.Context, userID, starID int64, req *dto.ImageGenerationRequest) (*ImageGenerationJob, error)
	// GetJob returns a snapshot of the job if it belongs to userID/starID.
	GetJob(ctx context.Context, jobID string, userID, starID int64) (*ImageGenerationJob, error)
}

type minimaxService struct {
	config *config.AssetConfig

	// jobsLock guards the jobs map AND every mutable field of the jobs in it:
	// processJob updates jobs from a background goroutine while GetJob and
	// cleanupExpiredJobs read them.
	jobsLock sync.RWMutex
	jobs     map[string]*ImageGenerationJob

	// client talks to the MiniMax API.
	client *http.Client
	// imageClient downloads user-supplied reference images; its dialer refuses
	// to connect to non-public addresses.
	imageClient *http.Client
}

// NewMinimaxService creates a MiniMax service and starts its background
// cleanup goroutine, which runs for the lifetime of the process.
func NewMinimaxService(cfg *config.AssetConfig) MinimaxService {
	svc := &minimaxService{
		config:      cfg,
		jobs:        make(map[string]*ImageGenerationJob),
		client:      &http.Client{Timeout: apiTimeout},
		imageClient: newSafeImageClient(imageDownloadTimeout),
	}
	go svc.cleanupExpiredJobs()
	return svc
}

// CreateJob registers a new image generation job and processes it
// asynchronously. It returns a snapshot of the job as created.
func (s *minimaxService) CreateJob(ctx context.Context, userID, starID int64, req *dto.ImageGenerationRequest) (*ImageGenerationJob, error) {
	if req == nil {
		return nil, errors.New("image generation request is nil")
	}

	now := time.Now().UnixMilli()
	job := &ImageGenerationJob{
		JobID:     uuid.New().String(),
		UserID:    userID,
		StarID:    starID,
		Status:    StatusProcessing,
		Progress:  0,
		Request:   req,
		CreatedAt: now,
		UpdatedAt: now,
	}

	s.jobsLock.Lock()
	s.jobs[job.JobID] = job
	// Snapshot before the worker starts: returning the live pointer would let
	// the caller read fields that processJob is concurrently writing.
	snapshot := snapshotJob(job)
	s.jobsLock.Unlock()

	go s.processJob(job)
	return snapshot, nil
}

// GetJob returns a snapshot of jobID if it exists and belongs to userID/starID.
func (s *minimaxService) GetJob(ctx context.Context, jobID string, userID, starID int64) (*ImageGenerationJob, error) {
	// Hold the read lock while copying: processJob mutates job fields under
	// the write lock.
	s.jobsLock.RLock()
	defer s.jobsLock.RUnlock()

	job, ok := s.jobs[jobID]
	if !ok {
		return nil, errors.New("job not found")
	}
	if job.UserID != userID || job.StarID != starID {
		return nil, errors.New("access denied")
	}
	return snapshotJob(job), nil
}

// snapshotJob returns a copy of job whose Images slice is not shared with the
// original. The caller must hold jobsLock (read or write).
func snapshotJob(job *ImageGenerationJob) *ImageGenerationJob {
	cp := *job
	cp.Images = make([]string, len(job.Images))
	copy(cp.Images, job.Images)
	return &cp
}

// updateJob applies mutate to job under jobsLock and stamps UpdatedAt.
func (s *minimaxService) updateJob(job *ImageGenerationJob, mutate func(j *ImageGenerationJob)) {
	s.jobsLock.Lock()
	defer s.jobsLock.Unlock()
	mutate(job)
	job.UpdatedAt = time.Now().UnixMilli()
}

// setProgress records job progress (0-100).
func (s *minimaxService) setProgress(job *ImageGenerationJob, progress int) {
	s.updateJob(job, func(j *ImageGenerationJob) { j.Progress = progress })
}

// failJob marks job as failed with msg.
func (s *minimaxService) failJob(job *ImageGenerationJob, msg string) {
	s.updateJob(job, func(j *ImageGenerationJob) {
		j.Status = StatusFailed
		j.ErrorMsg = msg
	})
}

// processJob runs a job to completion: validates the reference image URLs,
// downloads and downscales them, calls MiniMax, and records the result.
// All job mutations go through updateJob so readers never race with it.
func (s *minimaxService) processJob(job *ImageGenerationJob) {
	defer func() {
		if r := recover(); r != nil {
			zap.S().Errorf("image generation job %s panicked: %v", job.JobID, r)
			s.failJob(job, fmt.Sprintf("panic: %v", r))
		}
	}()

	// job.Request is set once in CreateJob and only read afterwards, so it
	// can be read here without the lock.
	req := job.Request

	// 1. Reject reference URLs that could be used for SSRF.
	for _, ref := range req.SubjectReference {
		if err := validateURL(ref.ImageFile); err != nil {
			s.failJob(job, "invalid image URL: "+err.Error())
			return
		}
	}

	// 2. Download and downscale reference images (progress 10..50).
	refCount := len(req.SubjectReference)
	processedRefs := make([]dto.SubjectReference, refCount)
	for i, ref := range req.SubjectReference {
		// Spread progress evenly so it never exceeds 50 and then moves backwards.
		s.setProgress(job, 10+i*40/refCount)

		compressed := ref.ImageFile
		if ref.ImageFile != "" {
			out, err := s.compressImageIfNeeded(ref.ImageFile)
			if err != nil {
				// Best effort: fall back to the original URL and let MiniMax fetch it.
				zap.S().Warnf("Image compression failed, using original: %v", err)
			} else {
				compressed = out
			}
		}
		processedRefs[i] = dto.SubjectReference{
			Type:      ref.Type,
			ImageFile: compressed,
		}
	}
	s.setProgress(job, 50)

	// 3. Call the MiniMax API.
	images, err := s.callMiniMaxAPI(req.Model, req.Prompt, req.AspectRatio, processedRefs, req.N)
	if err != nil {
		s.failJob(job, "MiniMax API failed: "+err.Error())
		return
	}
	s.setProgress(job, 90)

	// 4. Done.
	s.updateJob(job, func(j *ImageGenerationJob) {
		j.Status = StatusCompleted
		j.Progress = 100
		j.Images = images
		j.CompletedAt = time.Now().UnixMilli()
	})
}

// callMiniMaxAPI posts a generation request to the endpoint in MINIMAX_API_URL,
// authenticated with MINIMAX_API_KEY, and returns the generated image URLs.
//
// It accepts either an `images[].url` response body or MiniMax's documented
// `data.image_urls` body. A non-zero `base_resp.status_code` is treated as an
// error. NOTE(review): confirm which shape the configured endpoint returns.
func (s *minimaxService) callMiniMaxAPI(model, prompt, aspectRatio string, refs []dto.SubjectReference, n int) ([]string, error) {
	apiURL := os.Getenv("MINIMAX_API_URL")
	apiKey := os.Getenv("MINIMAX_API_KEY")
	if apiURL == "" {
		return nil, errors.New("MINIMAX_API_URL is not set")
	}

	payload := map[string]interface{}{
		"model":             model,
		"prompt":            prompt,
		"aspect_ratio":      aspectRatio,
		"subject_reference": refs,
		"n":                 n,
	}
	jsonData, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("encoding request: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewReader(jsonData))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer "+apiKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Best effort: the body only enriches the error message, so a read
		// error is deliberately ignored. The limit keeps ErrorMsg bounded.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
	}

	var result struct {
		Images []struct {
			URL string `json:"url"`
		} `json:"images"`
		Data *struct {
			ImageURLs []string `json:"image_urls"`
		} `json:"data"`
		BaseResp *struct {
			StatusCode int    `json:"status_code"`
			StatusMsg  string `json:"status_msg"`
		} `json:"base_resp"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("decoding response: %w", err)
	}
	if result.BaseResp != nil && result.BaseResp.StatusCode != 0 {
		return nil, fmt.Errorf("API returned status_code %d: %s", result.BaseResp.StatusCode, result.BaseResp.StatusMsg)
	}

	images := make([]string, 0, len(result.Images))
	for _, img := range result.Images {
		images = append(images, img.URL)
	}
	if result.Data != nil {
		images = append(images, result.Data.ImageURLs...)
	}
	if len(images) == 0 {
		// Completing a job with no images would hide an upstream failure.
		return nil, errors.New("API response contained no images")
	}
	return images, nil
}

// compressImageIfNeeded downloads imageURL and returns it as a base64 data URI.
// Images whose longer side exceeds maxImageDimension are downscaled with
// Lanczos3 and re-encoded (PNG and GIF keep their format; anything else becomes
// JPEG at quality 85). Smaller images are embedded unchanged under their own
// MIME type.
func (s *minimaxService) compressImageIfNeeded(imageURL string) (string, error) {
	client := s.imageClient
	if client == nil {
		// Covers a minimaxService built without NewMinimaxService.
		client = newSafeImageClient(imageDownloadTimeout)
	}

	resp, err := client.Get(imageURL)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return "", fmt.Errorf("image download returned status %d", resp.StatusCode)
	}

	// Read one byte past the limit so an oversized image is rejected instead
	// of being silently truncated.
	imgData, err := io.ReadAll(io.LimitReader(resp.Body, maxImageDownloadBytes+1))
	if err != nil {
		return "", err
	}
	if len(imgData) > maxImageDownloadBytes {
		return "", fmt.Errorf("image exceeds %d bytes", maxImageDownloadBytes)
	}

	cfg, _, err := image.DecodeConfig(bytes.NewReader(imgData))
	if err != nil {
		return "", err
	}
	if cfg.Width <= 0 || cfg.Height <= 0 || int64(cfg.Width)*int64(cfg.Height) > maxImagePixels {
		return "", fmt.Errorf("image dimensions %dx%d not allowed", cfg.Width, cfg.Height)
	}

	img, format, err := image.Decode(bytes.NewReader(imgData))
	if err != nil {
		return "", err
	}

	bounds := img.Bounds()
	if bounds.Dx() <= maxImageDimension && bounds.Dy() <= maxImageDimension {
		// Small enough: embed the original bytes under their real MIME type.
		// Registered image format names ("jpeg", "png", "gif", ...) match the
		// MIME subtypes.
		return "data:image/" + format + ";base64," + base64.StdEncoding.EncodeToString(imgData), nil
	}

	// Thumbnail preserves the aspect ratio within the box and clamps each side to >= 1px.
	resized := resize.Thumbnail(maxImageDimension, maxImageDimension, img, resize.Lanczos3)

	var buf bytes.Buffer
	mimeType := "image/jpeg"
	switch format {
	case "png":
		mimeType = "image/png"
		err = png.Encode(&buf, resized)
	case "gif":
		mimeType = "image/gif"
		err = gif.Encode(&buf, resized, nil)
	default:
		err = jpeg.Encode(&buf, resized, &jpeg.Options{Quality: 85})
	}
	if err != nil {
		return "", err
	}
	return "data:" + mimeType + ";base64," + base64.StdEncoding.EncodeToString(buf.Bytes()), nil
}

// newSafeImageClient returns an HTTP client for fetching untrusted URLs. Its
// dialer rejects every connection to a non-public IP, and each redirect target
// is re-validated with validateURL.
func newSafeImageClient(timeout time.Duration) *http.Client {
	dialer := &net.Dialer{
		Timeout: 10 * time.Second,
		// Control runs after DNS resolution, immediately before connect, so it
		// sees the IP actually being dialed. This catches hostnames that
		// resolve to internal addresses, DNS rebinding, and redirects, which a
		// check of the URL string alone cannot.
		Control: func(network, address string, _ syscall.RawConn) error {
			host, _, err := net.SplitHostPort(address)
			if err != nil {
				return err
			}
			ip := net.ParseIP(host)
			if ip == nil || isDisallowedIP(ip) {
				return fmt.Errorf("refusing to connect to non-public address %s", host)
			}
			return nil
		},
	}
	transport := &http.Transport{
		// Proxy is intentionally left nil: behind a proxy the Control hook
		// would only see the proxy's address, not the target's.
		DialContext:         dialer.DialContext,
		ForceAttemptHTTP2:   true,
		TLSHandshakeTimeout: 10 * time.Second,
		MaxIdleConns:        10,
		IdleConnTimeout:     90 * time.Second,
	}
	return &http.Client{
		Timeout:   timeout,
		Transport: transport,
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			if len(via) >= maxImageRedirects {
				return fmt.Errorf("stopped after %d redirects", maxImageRedirects)
			}
			return validateURL(req.URL.String())
		},
	}
}

// isDisallowedIP reports whether ip is not a public unicast address.
func isDisallowedIP(ip net.IP) bool {
	return ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() ||
		ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() ||
		ip.IsInterfaceLocalMulticast() || ip.IsMulticast()
}

// validateURL performs a static SSRF check on rawURL. It requires an http(s)
// scheme and a host, rejects literal non-public IPs (including link-local
// addresses such as cloud metadata endpoints), and rejects localhost and
// internal-looking domain suffixes. An empty URL is accepted.
//
// This check cannot see DNS resolution; connections must still go through
// newSafeImageClient, which enforces the IP policy at dial time.
func validateURL(rawURL string) error {
	if rawURL == "" {
		return nil
	}
	u, err := url.Parse(rawURL)
	if err != nil {
		return err
	}
	if u.Scheme != "http" && u.Scheme != "https" {
		return fmt.Errorf("unsupported scheme: %s", u.Scheme)
	}

	host := u.Hostname()
	if host == "" {
		return errors.New("missing host")
	}
	if ip := net.ParseIP(host); ip != nil {
		if isDisallowedIP(ip) {
			return fmt.Errorf("private IP not allowed: %s", host)
		}
		return nil
	}

	// A trailing dot ("foo.internal.") is a valid FQDN spelling; strip it so
	// the suffix checks cannot be bypassed.
	lowerHost := strings.TrimSuffix(strings.ToLower(host), ".")
	if lowerHost == "localhost" ||
		strings.HasSuffix(lowerHost, ".localhost") ||
		strings.HasSuffix(lowerHost, ".local") ||
		strings.HasSuffix(lowerHost, ".internal") ||
		strings.HasSuffix(lowerHost, ".private") {
		return fmt.Errorf("internal domain not allowed: %s", host)
	}
	return nil
}

// cleanupExpiredJobs removes finished jobs once per cleanupInterval. It never
// returns.
// TODO(review): add a stop signal if the service ever needs a clean shutdown.
func (s *minimaxService) cleanupExpiredJobs() {
	ticker := time.NewTicker(cleanupInterval)
	defer ticker.Stop()
	for range ticker.C {
		s.removeExpiredJobs(time.Now())
	}
}

// removeExpiredJobs deletes completed or failed jobs whose last update is older
// than jobRetention relative to now. Jobs still processing are kept.
func (s *minimaxService) removeExpiredJobs(now time.Time) {
	cutoff := now.Add(-jobRetention).UnixMilli()

	s.jobsLock.Lock()
	defer s.jobsLock.Unlock()
	for jobID, job := range s.jobs {
		finished := job.Status == StatusCompleted || job.Status == StatusFailed
		if finished && job.UpdatedAt < cutoff {
			delete(s.jobs, jobID) // deleting during range is safe in Go
		}
	}
}