- 使用 resize.Lanczos3 替代不存在的 resize.Lanczos - 使用 os.Getenv 替代缺失的 config 方法 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
334 lines
8.3 KiB
Go
334 lines
8.3 KiB
Go
package service
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"image"
|
|
"image/gif"
|
|
"image/jpeg"
|
|
"image/png"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/nfnt/resize"
|
|
"github.com/topfans/backend/services/assetService/config"
|
|
dto "github.com/topfans/backend/gateway/dto"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// JobStatus is the lifecycle state of an image-generation job. It is
// serialized as its string value in the job's "status" JSON field.
type JobStatus string

// Lifecycle states of an image-generation job.
const (
	// StatusPending marks a job that has not started processing yet.
	StatusPending = JobStatus("PENDING")
	// StatusProcessing marks a job whose pipeline is currently running.
	StatusProcessing = JobStatus("PROCESSING")
	// StatusCompleted marks a job that finished and has result images.
	StatusCompleted = JobStatus("COMPLETED")
	// StatusFailed marks a job that stopped with an error message.
	StatusFailed = JobStatus("FAILED")
)
|
|
|
|
// ImageGenerationJob 图生图任务
|
|
type ImageGenerationJob struct {
|
|
JobID string `json:"job_id"`
|
|
UserID int64 `json:"user_id"`
|
|
StarID int64 `json:"star_id"`
|
|
Status JobStatus `json:"status"`
|
|
Progress int `json:"progress"`
|
|
Images []string `json:"images,omitempty"`
|
|
ErrorMsg string `json:"error_msg,omitempty"`
|
|
Request *dto.ImageGenerationRequest `json:"request,omitempty"`
|
|
CreatedAt int64 `json:"created_at"`
|
|
UpdatedAt int64 `json:"updated_at"`
|
|
CompletedAt int64 `json:"completed_at,omitempty"`
|
|
}
|
|
|
|
// MinimaxService MiniMax API 转发服务
|
|
type MinimaxService interface {
|
|
CreateJob(ctx context.Context, userID, starID int64, req *dto.ImageGenerationRequest) (*ImageGenerationJob, error)
|
|
GetJob(ctx context.Context, jobID string, userID, starID int64) (*ImageGenerationJob, error)
|
|
}
|
|
|
|
type minimaxService struct {
|
|
config *config.AssetConfig
|
|
jobs map[string]*ImageGenerationJob
|
|
jobsLock sync.RWMutex
|
|
}
|
|
|
|
// NewMinimaxService 创建 MiniMax 服务
|
|
func NewMinimaxService(cfg *config.AssetConfig) MinimaxService {
|
|
svc := &minimaxService{
|
|
config: cfg,
|
|
jobs: make(map[string]*ImageGenerationJob),
|
|
}
|
|
go svc.cleanupExpiredJobs()
|
|
return svc
|
|
}
|
|
|
|
// CreateJob 创建图生图任务
|
|
func (s *minimaxService) CreateJob(ctx context.Context, userID, starID int64, req *dto.ImageGenerationRequest) (*ImageGenerationJob, error) {
|
|
jobID := uuid.New().String()
|
|
now := time.Now().UnixMilli()
|
|
|
|
job := &ImageGenerationJob{
|
|
JobID: jobID,
|
|
UserID: userID,
|
|
StarID: starID,
|
|
Status: StatusProcessing,
|
|
Progress: 0,
|
|
Request: req,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}
|
|
|
|
s.jobsLock.Lock()
|
|
s.jobs[jobID] = job
|
|
s.jobsLock.Unlock()
|
|
|
|
go s.processJob(job)
|
|
|
|
return job, nil
|
|
}
|
|
|
|
// GetJob 获取任务
|
|
func (s *minimaxService) GetJob(ctx context.Context, jobID string, userID, starID int64) (*ImageGenerationJob, error) {
|
|
s.jobsLock.RLock()
|
|
job, ok := s.jobs[jobID]
|
|
s.jobsLock.RUnlock()
|
|
|
|
if !ok {
|
|
return nil, fmt.Errorf("job not found")
|
|
}
|
|
if job.UserID != userID || job.StarID != starID {
|
|
return nil, fmt.Errorf("access denied")
|
|
}
|
|
|
|
return job, nil
|
|
}
|
|
|
|
// processJob 异步处理任务
|
|
func (s *minimaxService) processJob(job *ImageGenerationJob) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
job.Status = StatusFailed
|
|
job.ErrorMsg = fmt.Sprintf("panic: %v", r)
|
|
job.UpdatedAt = time.Now().UnixMilli()
|
|
}
|
|
}()
|
|
|
|
// 1. 校验 SSRF
|
|
for _, ref := range job.Request.SubjectReference {
|
|
if err := validateURL(ref.ImageFile); err != nil {
|
|
job.Status = StatusFailed
|
|
job.ErrorMsg = "invalid image URL: " + err.Error()
|
|
job.UpdatedAt = time.Now().UnixMilli()
|
|
return
|
|
}
|
|
}
|
|
|
|
// 2. 压缩图片
|
|
processedRefs := make([]dto.SubjectReference, len(job.Request.SubjectReference))
|
|
for i, ref := range job.Request.SubjectReference {
|
|
job.Progress = 10 + i*20
|
|
job.UpdatedAt = time.Now().UnixMilli()
|
|
|
|
compressed, err := s.compressImageIfNeeded(ref.ImageFile)
|
|
if err != nil {
|
|
compressed = ref.ImageFile
|
|
zap.S().Warnf("Image compression failed, using original: %v", err)
|
|
}
|
|
processedRefs[i] = dto.SubjectReference{
|
|
Type: ref.Type,
|
|
ImageFile: compressed,
|
|
}
|
|
}
|
|
|
|
job.Progress = 50
|
|
job.UpdatedAt = time.Now().UnixMilli()
|
|
|
|
// 3. 调用 MiniMax API
|
|
images, err := s.callMiniMaxAPI(job.Request.Model, job.Request.Prompt, job.Request.AspectRatio, processedRefs, job.Request.N)
|
|
if err != nil {
|
|
job.Status = StatusFailed
|
|
job.ErrorMsg = "MiniMax API failed: " + err.Error()
|
|
job.UpdatedAt = time.Now().UnixMilli()
|
|
return
|
|
}
|
|
|
|
job.Progress = 90
|
|
job.UpdatedAt = time.Now().UnixMilli()
|
|
|
|
// 4. 完成
|
|
job.Status = StatusCompleted
|
|
job.Progress = 100
|
|
job.Images = images
|
|
job.CompletedAt = time.Now().UnixMilli()
|
|
job.UpdatedAt = time.Now().UnixMilli()
|
|
}
|
|
|
|
// callMiniMaxAPI 调用 MiniMax API
|
|
func (s *minimaxService) callMiniMaxAPI(model, prompt, aspectRatio string, refs []dto.SubjectReference, n int) ([]string, error) {
|
|
apiURL := os.Getenv("MINIMAX_API_URL")
|
|
apiKey := os.Getenv("MINIMAX_API_KEY")
|
|
|
|
payload := map[string]interface{}{
|
|
"model": model,
|
|
"prompt": prompt,
|
|
"aspect_ratio": aspectRatio,
|
|
"subject_reference": refs,
|
|
"n": n,
|
|
}
|
|
|
|
jsonData, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
client := &http.Client{Timeout: 120 * time.Second}
|
|
req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var result struct {
|
|
Images []struct {
|
|
URL string `json:"url"`
|
|
} `json:"images"`
|
|
}
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
images := make([]string, len(result.Images))
|
|
for i, img := range result.Images {
|
|
images[i] = img.URL
|
|
}
|
|
return images, nil
|
|
}
|
|
|
|
// compressImageIfNeeded 下载并压缩图片
|
|
func (s *minimaxService) compressImageIfNeeded(imageURL string) (string, error) {
|
|
resp, err := http.Get(imageURL)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
imgData, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
img, format, err := image.Decode(bytes.NewReader(imgData))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
bounds := img.Bounds()
|
|
maxDim := uint(1024)
|
|
newWidth := uint(bounds.Dx())
|
|
newHeight := uint(bounds.Dy())
|
|
|
|
if newWidth > maxDim || newHeight > maxDim {
|
|
if newWidth > newHeight {
|
|
ratio := float64(maxDim) / float64(newWidth)
|
|
newWidth = maxDim
|
|
newHeight = uint(float64(newHeight) * ratio)
|
|
} else {
|
|
ratio := float64(maxDim) / float64(newHeight)
|
|
newHeight = maxDim
|
|
newWidth = uint(float64(newWidth) * ratio)
|
|
}
|
|
}
|
|
|
|
if newWidth == uint(bounds.Dx()) && newHeight == uint(bounds.Dy()) {
|
|
return "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(imgData), nil
|
|
}
|
|
|
|
resized := resize.Thumbnail(newWidth, newHeight, img, resize.Lanczos3)
|
|
|
|
var buf bytes.Buffer
|
|
switch format {
|
|
case "png":
|
|
err = png.Encode(&buf, resized)
|
|
case "gif":
|
|
err = gif.Encode(&buf, resized, nil)
|
|
default:
|
|
err = jpeg.Encode(&buf, resized, &jpeg.Options{Quality: 85})
|
|
}
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
encoded := base64.StdEncoding.EncodeToString(buf.Bytes())
|
|
mimeType := "image/jpeg"
|
|
if format == "png" {
|
|
mimeType = "image/png"
|
|
} else if format == "gif" {
|
|
mimeType = "image/gif"
|
|
}
|
|
return "data:" + mimeType + ";base64," + encoded, nil
|
|
}
|
|
|
|
// validateURL is a best-effort SSRF filter for user-supplied image URLs.
//
// It returns nil for an empty string and for URLs whose host looks public.
// It returns an error when the URL cannot be parsed, when the host is an
// IP literal in a loopback, private, unspecified or link-local range
// (which includes the 169.254.169.254 cloud metadata address), or when the
// host is "localhost" or ends in .localhost, .local, .internal or
// .private.
//
// The URL scheme is not restricted, and host names are not resolved, so a
// public name that resolves to an internal address is not detected here.
func validateURL(rawURL string) error {
	if rawURL == "" {
		return nil
	}
	u, err := url.Parse(rawURL)
	if err != nil {
		return err
	}
	host := u.Hostname()

	// Normalize before matching: host names are case-insensitive and may
	// carry a trailing root dot ("foo.local.").
	normalized := strings.TrimSuffix(strings.ToLower(host), ".")

	// An IPv6 literal may carry a zone ("fe80::1%eth0"), which net.ParseIP
	// rejects; strip it so the address is still classified.
	ipText := normalized
	if i := strings.IndexByte(ipText, '%'); i >= 0 {
		ipText = ipText[:i]
	}

	if ip := net.ParseIP(ipText); ip != nil {
		if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() ||
			ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
			return fmt.Errorf("private IP not allowed: %s", host)
		}
		return nil
	}

	if normalized == "localhost" {
		return fmt.Errorf("internal domain not allowed: %s", host)
	}
	for _, suffix := range []string{".localhost", ".local", ".internal", ".private"} {
		if strings.HasSuffix(normalized, suffix) {
			return fmt.Errorf("internal domain not allowed: %s", host)
		}
	}

	return nil
}
|
|
|
|
// cleanupExpiredJobs 清理过期任务
|
|
func (s *minimaxService) cleanupExpiredJobs() {
|
|
ticker := time.NewTicker(1 * time.Hour)
|
|
for range ticker.C {
|
|
s.jobsLock.Lock()
|
|
now := time.Now().UnixMilli()
|
|
expiredThreshold := int64(24 * 60 * 60 * 1000) // 24h
|
|
for jobID, job := range s.jobs {
|
|
if job.Status == StatusCompleted || job.Status == StatusFailed {
|
|
if now-job.UpdatedAt > expiredThreshold {
|
|
delete(s.jobs, jobID)
|
|
}
|
|
}
|
|
}
|
|
s.jobsLock.Unlock()
|
|
}
|
|
} |