package service

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/google/uuid"
	"github.com/topfans/backend/gateway/config"
	"github.com/topfans/backend/pkg/logger"
	"go.uber.org/zap"
)

const (
	// SegmentErrorCodeFailed is the ErrorCode reported for every failed segmentation.
	SegmentErrorCodeFailed = "LC_SEGMENT_FAILED"
	maxSegmentImageBytes   = 5 * 1024 * 1024

	// maxSegmentDownloadBytes caps any body fetched by downloadURL (source images and
	// provider cutouts), leaving headroom over the input limit for PNG cutouts.
	maxSegmentDownloadBytes = 2 * maxSegmentImageBytes
)

// errSegmentDownloadTooLarge is wrapped by downloadURL when a response body exceeds
// maxSegmentDownloadBytes, so callers can tell "too big" apart from transport errors.
var errSegmentDownloadTooLarge = errors.New("download exceeds size limit")

// MaxSegmentImageBytes returns the upload size limit in bytes (kept in sync with the
// controller's validation).
func MaxSegmentImageBytes() int { return maxSegmentImageBytes }

// SegmentPortraitResult is the outcome of a portrait cutout request.
type SegmentPortraitResult struct {
	Success         bool   `json:"success"`
	ErrorCode       string `json:"error_code,omitempty"`
	Message         string `json:"message,omitempty"`
	CutoutOssKey    string `json:"cutout_oss_key,omitempty"`
	CutoutURLSigned string `json:"cutout_url_signed,omitempty"`
	Provider        string `json:"provider,omitempty"`
}

// SegmentService performs server-side portrait segmentation using imageseg, IVPD,
// or a self-hosted HTTP inference service.
type SegmentService struct {
	oss      *OssHelper
	ivpd     *IvpdClient
	imageseg *ImagesegClient
	httpSeg  *SegmentHTTPClient
	provider string
	client   *http.Client
}

// NewSegmentService builds a SegmentService from cfg. The provider name is trimmed and
// lower-cased; an empty value is treated as "auto" by inferCutout.
func NewSegmentService(cfg *config.Config) *SegmentService {
	return &SegmentService{
		oss:      NewOssHelper(cfg.OSS),
		ivpd:     NewIvpdClient(cfg.OSS),
		imageseg: NewImagesegClient(cfg.OSS),
		httpSeg:  NewSegmentHTTPClient(cfg.Segment.InferenceURL),
		provider: strings.ToLower(strings.TrimSpace(cfg.Segment.Provider)),
		// The client timeout bounds the whole exchange, including reading the body.
		client: &http.Client{Timeout: 120 * time.Second},
	}
}

// segmentFailure builds a failed SegmentPortraitResult with the generic failure code
// and the given user-facing message.
func segmentFailure(message string) *SegmentPortraitResult {
	return &SegmentPortraitResult{
		Success:   false,
		ErrorCode: SegmentErrorCodeFailed,
		Message:   message,
	}
}

// sniffSegmentImageType classifies data as PNG or (by default) JPEG, returning the
// MIME type and the file extension used for the temporary OSS input key.
func sniffSegmentImageType(data []byte) (contentType, ext string) {
	if http.DetectContentType(data) == "image/png" {
		return "image/png", "png"
	}
	return "image/jpeg", "jpg"
}

// Portrait uploads the original image, runs portrait cutout and stores the result in OSS.
//
// Validation failures (empty or oversized image) are reported through the result with
// Success=false; the returned error is always nil.
func (s *SegmentService) Portrait(ctx context.Context, userID, starID int64, imageData []byte, contentType string) (*SegmentPortraitResult, error) {
	if len(imageData) == 0 {
		return segmentFailure("图片为空"), nil
	}
	if len(imageData) > maxSegmentImageBytes {
		return segmentFailure("图片超过 5MB"), nil
	}

	ct := strings.ToLower(strings.TrimSpace(contentType))
	ext := "jpg"
	switch {
	case ct == "":
		// No declared type: sniff the bytes rather than assuming JPEG.
		ct, ext = sniffSegmentImageType(imageData)
	case strings.Contains(ct, "png"):
		ext = "png"
	}

	fileID := strings.ReplaceAll(uuid.New().String(), "-", "")
	outKey := BuildLaserCardCutoutKey(starID, userID, fileID)
	return s.doPortrait(ctx, starID, userID, fileID, ext, imageData, ct, outKey)
}

// PortraitFromURL downloads an image from imageURL, runs portrait cutout and writes the
// result to OSS. It is intended for Dify workflow calls and always uses starID=0 and
// userID=0 when building OSS keys (the resulting prefix is decided by
// BuildLaserCardCutoutKey; previously documented as laser-card/dify/ — confirm).
//
// NOTE(review): imageURL is fetched server-side. If workflow input can be influenced by
// end users this is an SSRF vector (internal hosts, cloud metadata); consider an
// allow-list of hosts.
func (s *SegmentService) PortraitFromURL(ctx context.Context, imageURL string) (*SegmentPortraitResult, error) {
	// isHTTPURL validates the trimmed form, so download that same form.
	imageURL = strings.TrimSpace(imageURL)
	if !isHTTPURL(imageURL) {
		return segmentFailure("image_url 格式不正确"), nil
	}

	imageData, err := s.downloadURL(ctx, imageURL)
	if err != nil {
		if errors.Is(err, errSegmentDownloadTooLarge) {
			return segmentFailure("图片超过 5MB"), nil
		}
		return segmentFailure("下载图片失败: " + err.Error()), nil
	}
	if len(imageData) == 0 {
		return segmentFailure("图片为空"), nil
	}
	if len(imageData) > maxSegmentImageBytes {
		return segmentFailure("图片超过 5MB"), nil
	}

	// Keep the temp-input extension consistent with the detected content type.
	contentType, ext := sniffSegmentImageType(imageData)

	fileID := strings.ReplaceAll(uuid.New().String(), "-", "")
	// Dify calls use the fixed star_id=0, user_id=0.
	outKey := BuildLaserCardCutoutKey(0, 0, fileID)
	return s.doPortrait(ctx, 0, 0, fileID, ext, imageData, contentType, outKey)
}

// doPortrait runs inference, uploads the PNG cutout to outKey, and returns a result with
// a one-hour signed GET URL. A signing failure still reports Success=true (the cutout
// is stored) with an explanatory Message. The returned error is always nil; failures
// are carried in the result.
func (s *SegmentService) doPortrait(ctx context.Context, starID, userID int64, fileID, ext string, imageData []byte, contentType, outKey string) (*SegmentPortraitResult, error) {
	cutout, err := s.inferCutout(ctx, starID, userID, fileID, ext, imageData, contentType)
	if err != nil {
		logger.Logger.Warn("segment infer failed", zap.Error(err))
		return segmentFailure(truncate(err.Error(), 200)), nil
	}

	if err := s.oss.PutObject(outKey, bytes.NewReader(cutout.Bytes), "image/png"); err != nil {
		logger.Logger.Warn("segment upload cutout failed", zap.Error(err), zap.String("key", outKey))
		return segmentFailure("保存抠图失败: " + truncate(err.Error(), 200)), nil
	}

	result := &SegmentPortraitResult{
		Success:      true,
		CutoutOssKey: outKey,
		Provider:     cutout.Provider,
	}
	signed, err := s.oss.SignGetURL(outKey, 3600)
	if err != nil {
		logger.Logger.Warn("segment sign cutout url failed", zap.Error(err), zap.String("key", outKey))
		result.Message = "抠图成功但签名 URL 生成失败"
		return result, nil
	}
	result.CutoutURLSigned = signed
	return result, nil
}

// cutoutInferResult holds the raw cutout bytes and the name of the provider that produced them.
type cutoutInferResult struct {
	Bytes    []byte
	Provider string
}

// inferCutout runs the configured provider(s) and returns the first non-empty cutout.
//
// Provider selection:
//   - "imageseg" / "viapi": imageseg only
//   - "ivpd": imageseg, then IVPD
//   - "http": the self-hosted HTTP service only
//   - anything else, including "" (auto): HTTP (if enabled), then imageseg, then IVPD
//
// When every attempted provider fails, the last provider's error is returned.
func (s *SegmentService) inferCutout(ctx context.Context, starID, userID int64, fileID, ext string, imageData []byte, contentType string) (*cutoutInferResult, error) {
	inKey := BuildSegmentTempInputKey(starID, userID, fileID, ext)

	// imageseg and IVPD read the source image from OSS via a signed URL. Upload lazily on
	// first use so a successful HTTP-provider run skips the OSS write entirely. An empty
	// string means the upload or signing failed (already logged).
	var (
		inputPrepared bool
		inSigned      string
	)
	signedInput := func() string {
		if inputPrepared {
			return inSigned
		}
		inputPrepared = true
		if err := s.oss.PutObject(inKey, bytes.NewReader(imageData), contentType); err != nil {
			logger.Logger.Warn("segment upload input failed", zap.Error(err), zap.String("key", inKey))
			return ""
		}
		signed, err := s.oss.SignGetURL(inKey, 3600)
		if err != nil {
			logger.Logger.Warn("segment sign input url failed", zap.Error(err), zap.String("key", inKey))
			return ""
		}
		inSigned = signed
		return inSigned
	}

	provider := s.provider
	if provider == "" {
		provider = "auto"
	}

	tryImageseg := func() (*cutoutInferResult, error) {
		src := signedInput()
		if src == "" {
			return nil, fmt.Errorf("分割抠图需要 OSS 签名 URL")
		}
		outURL, err := s.imageseg.SegmentHDBodyURL(ctx, src)
		if err != nil {
			return nil, err
		}
		raw, err := s.downloadURL(ctx, outURL)
		if err != nil {
			return nil, err
		}
		return &cutoutInferResult{Bytes: raw, Provider: "imageseg"}, nil
	}
	tryIVPD := func() (*cutoutInferResult, error) {
		src := signedInput()
		if src == "" {
			return nil, fmt.Errorf("IVPD 需要 OSS 签名 URL")
		}
		outURL, err := s.ivpd.SegmentImageURL(ctx, src)
		if err != nil {
			return nil, err
		}
		raw, err := s.downloadURL(ctx, outURL)
		if err != nil {
			return nil, err
		}
		return &cutoutInferResult{Bytes: raw, Provider: "ivpd"}, nil
	}
	tryHTTP := func() (*cutoutInferResult, error) {
		raw, err := s.httpSeg.RemoveBackground(ctx, imageData, contentType)
		if err != nil {
			return nil, err
		}
		return &cutoutInferResult{Bytes: raw, Provider: "http"}, nil
	}

	var lastErr error
	// run executes one provider, treating an empty result as a failure so that an empty
	// PNG is never uploaded, and records/logs the error for fallback reporting.
	run := func(name string, fn func() (*cutoutInferResult, error)) (*cutoutInferResult, bool) {
		res, err := fn()
		if err == nil && (res == nil || len(res.Bytes) == 0) {
			err = errors.New("抠图结果为空")
		}
		if err != nil {
			lastErr = err
			logger.Logger.Warn("segment provider failed", zap.String("provider", name), zap.Error(err))
			return nil, false
		}
		return res, true
	}

	switch provider {
	case "imageseg", "viapi":
		if res, ok := run("imageseg", tryImageseg); ok {
			return res, nil
		}
	case "ivpd":
		// NOTE(review): imageseg is tried before IVPD even when "ivpd" is configured.
		// Presumably an intentional migration fallback — confirm.
		if res, ok := run("imageseg", tryImageseg); ok {
			return res, nil
		}
		if res, ok := run("ivpd", tryIVPD); ok {
			return res, nil
		}
	case "http":
		if res, ok := run("http", tryHTTP); ok {
			return res, nil
		}
	default: // auto
		if s.httpSeg.enabled() {
			if res, ok := run("http", tryHTTP); ok {
				return res, nil
			}
		}
		if res, ok := run("imageseg", tryImageseg); ok {
			return res, nil
		}
		if res, ok := run("ivpd", tryIVPD); ok {
			return res, nil
		}
	}

	if lastErr != nil {
		return nil, lastErr
	}
	// Defensive: every branch above runs at least one provider, so this is not expected.
	return nil, fmt.Errorf("抠图服务未配置")
}

// downloadURL GETs rawURL and returns the body, failing on non-2xx statuses.
//
// Bodies larger than maxSegmentDownloadBytes are rejected with an error wrapping
// errSegmentDownloadTooLarge instead of being silently truncated (a truncated image
// would otherwise be passed on as if it were valid).
func (s *SegmentService) downloadURL(ctx context.Context, rawURL string) ([]byte, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
	if err != nil {
		return nil, err
	}
	res, err := s.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	if res.StatusCode < 200 || res.StatusCode >= 300 {
		return nil, fmt.Errorf("HTTP %d", res.StatusCode)
	}

	// Read one byte past the cap so an oversized body can be detected.
	data, err := io.ReadAll(io.LimitReader(res.Body, maxSegmentDownloadBytes+1))
	if err != nil {
		return nil, err
	}
	if len(data) > maxSegmentDownloadBytes {
		return nil, fmt.Errorf("%w: more than %d bytes", errSegmentDownloadTooLarge, maxSegmentDownloadBytes)
	}
	return data, nil
}

// isHTTPURL reports whether s, after trimming surrounding whitespace, starts with
// "http://" or "https://" (case-sensitive).
func isHTTPURL(s string) bool {
	s = strings.TrimSpace(s)
	return strings.HasPrefix(s, "http://") || strings.HasPrefix(s, "https://")
}

// truncate returns at most the first n bytes of s, backing off to a UTF-8 rune boundary
// so multi-byte text (e.g. Chinese error messages) is never cut mid-character.
// A non-positive n yields "".
func truncate(s string, n int) string {
	if len(s) <= n {
		return s
	}
	if n <= 0 {
		return ""
	}
	for n > 0 && !utf8.RuneStart(s[n]) {
		n--
	}
	return s[:n]
}