topfans/backend/gateway/service/segment_service.go
2026-06-03 22:19:22 +08:00

306 lines
8.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/google/uuid"
"github.com/topfans/backend/gateway/config"
"github.com/topfans/backend/pkg/logger"
"go.uber.org/zap"
)
const (
	// SegmentErrorCodeFailed is the ErrorCode placed on every failed
	// SegmentPortraitResult produced by this service.
	SegmentErrorCodeFailed = "LC_SEGMENT_FAILED"
	// maxSegmentImageBytes is the largest accepted input image (5 MiB).
	maxSegmentImageBytes = 5 << 20
)

// MaxSegmentImageBytes returns the upload size limit in bytes. The
// controller-side validation is meant to use the same limit.
func MaxSegmentImageBytes() int {
	return maxSegmentImageBytes
}
// SegmentPortraitResult is the outcome of a portrait cutout request.
// Failures are reported through Success=false plus ErrorCode/Message rather
// than through a Go error.
type SegmentPortraitResult struct {
	// Success reports whether a cutout was produced and stored in OSS.
	Success bool `json:"success"`
	// ErrorCode is set on failure (always SegmentErrorCodeFailed in this file).
	ErrorCode string `json:"error_code,omitempty"`
	// Message is a failure reason, or a warning on a partial success
	// (cutout stored but URL signing failed).
	Message string `json:"message,omitempty"`
	// CutoutOssKey is the OSS object key of the stored cutout.
	CutoutOssKey string `json:"cutout_oss_key,omitempty"`
	// CutoutURLSigned is a signed GET URL for CutoutOssKey; empty when
	// signing failed.
	CutoutURLSigned string `json:"cutout_url_signed,omitempty"`
	// Provider names the backend that produced the cutout
	// ("imageseg", "ivpd" or "http").
	Provider string `json:"provider,omitempty"`
}
// SegmentService performs server-side portrait cutout using one of several
// backends: imageseg, IVPD, or a self-hosted HTTP segmentation service.
type SegmentService struct {
	// oss stores temporary input images and the resulting cutouts.
	oss *OssHelper
	// ivpd is the IVPD client; it works from a (signed) image URL.
	ivpd *IvpdClient
	// imageseg is the imageseg client; it works from a (signed) image URL.
	imageseg *ImagesegClient
	// httpSeg is the self-hosted segmenter; it receives the raw image bytes.
	httpSeg *SegmentHTTPClient
	// provider is the configured provider name, trimmed and lower-cased;
	// an empty value is treated as "auto".
	provider string
	// client downloads images and provider result URLs.
	client *http.Client
}
// NewSegmentService builds a SegmentService from cfg.
//
// The OSS helper and the imageseg/IVPD clients are all built from cfg.OSS;
// the self-hosted segmenter uses cfg.Segment.InferenceURL. The provider name
// is normalised (whitespace trimmed, lower-cased), and downloads go through
// an HTTP client with a 120-second timeout.
func NewSegmentService(cfg *config.Config) *SegmentService {
	svc := new(SegmentService)
	svc.oss = NewOssHelper(cfg.OSS)
	svc.ivpd = NewIvpdClient(cfg.OSS)
	svc.imageseg = NewImagesegClient(cfg.OSS)
	svc.httpSeg = NewSegmentHTTPClient(cfg.Segment.InferenceURL)
	svc.provider = strings.ToLower(strings.TrimSpace(cfg.Segment.Provider))
	svc.client = &http.Client{Timeout: 120 * time.Second}
	return svc
}
// Portrait takes an uploaded image, cuts out the portrait and writes the
// result to OSS under a key built from starID, userID and a fresh file ID.
//
// contentType is the declared MIME type of imageData. When it is empty or the
// generic "application/octet-stream", the type is sniffed from the PNG
// signature; anything that is not PNG is treated as JPEG.
//
// Validation and inference failures are reported through Success=false and
// Message; this function never returns a non-nil error itself.
func (s *SegmentService) Portrait(ctx context.Context, userID, starID int64, imageData []byte, contentType string) (*SegmentPortraitResult, error) {
	if len(imageData) == 0 {
		return &SegmentPortraitResult{
			Success:   false,
			ErrorCode: SegmentErrorCodeFailed,
			Message:   "图片为空",
		}, nil
	}
	if len(imageData) > maxSegmentImageBytes {
		return &SegmentPortraitResult{
			Success:   false,
			ErrorCode: SegmentErrorCodeFailed,
			Message:   "图片超过 5MB",
		}, nil
	}

	ct := strings.ToLower(strings.TrimSpace(contentType))
	if ct == "" || ct == "application/octet-stream" {
		// No usable declared type: sniff it so the temp-key extension and the
		// OSS Content-Type match the actual bytes. (Previously an empty type
		// always fell through to JPEG, even for PNG data.)
		ct = "image/jpeg"
		if bytes.HasPrefix(imageData, []byte("\x89PNG")) {
			ct = "image/png"
		}
	}
	ext := "jpg"
	if strings.Contains(ct, "png") {
		ext = "png"
	}

	fileID := strings.ReplaceAll(uuid.New().String(), "-", "")
	outKey := BuildLaserCardCutoutKey(starID, userID, fileID)
	return s.doPortrait(ctx, starID, userID, fileID, ext, imageData, ct, outKey)
}
// PortraitFromURL downloads an image from imageURL, cuts out the portrait and
// writes the result to OSS. It is meant to be called from the Dify workflow.
//
// Dify calls carry no user context, so starID and userID are fixed at 0 when
// building the OSS keys. The original note says this yields a path under
// laser-card/dify/. NOTE(review): that depends on BuildLaserCardCutoutKey;
// confirm.
//
// Like Portrait, validation/download/inference failures are reported through
// Success=false and Message; this function never returns a non-nil error
// itself.
//
// NOTE(review): imageURL is fetched server-side with no host allow-list; if
// untrusted callers can reach this entry point it is an SSRF vector.
func (s *SegmentService) PortraitFromURL(ctx context.Context, imageURL string) (*SegmentPortraitResult, error) {
	// Trim once so the URL that is downloaded is exactly the one validated;
	// surrounding whitespace would otherwise make the request URL unparsable.
	imageURL = strings.TrimSpace(imageURL)
	if !isHTTPURL(imageURL) {
		return &SegmentPortraitResult{
			Success:   false,
			ErrorCode: SegmentErrorCodeFailed,
			Message:   "image_url 格式不正确",
		}, nil
	}
	imageData, err := s.downloadURL(ctx, imageURL)
	if err != nil {
		return &SegmentPortraitResult{
			Success:   false,
			ErrorCode: SegmentErrorCodeFailed,
			Message:   "下载图片失败: " + err.Error(),
		}, nil
	}
	if len(imageData) == 0 {
		return &SegmentPortraitResult{
			Success:   false,
			ErrorCode: SegmentErrorCodeFailed,
			Message:   "图片为空",
		}, nil
	}
	if len(imageData) > maxSegmentImageBytes {
		return &SegmentPortraitResult{
			Success:   false,
			ErrorCode: SegmentErrorCodeFailed,
			Message:   "图片超过 5MB",
		}, nil
	}

	// Infer the type from the PNG signature (anything else is treated as
	// JPEG) and keep the temp-key extension consistent with it.
	contentType, ext := "image/jpeg", "jpg"
	if bytes.HasPrefix(imageData, []byte("\x89PNG")) {
		contentType, ext = "image/png", "png"
	}

	fileID := strings.ReplaceAll(uuid.New().String(), "-", "")
	outKey := BuildLaserCardCutoutKey(0, 0, fileID)
	return s.doPortrait(ctx, 0, 0, fileID, ext, imageData, contentType, outKey)
}
// doPortrait runs inference on imageData, uploads the cutout to outKey in OSS
// with Content-Type image/png, and returns a result carrying a signed GET URL
// for it (expiry argument 3600, presumably seconds).
//
// All failures are reported through Success=false and Message; the returned
// error is always nil. If only the signing step fails, the cutout is already
// stored, so the result is still successful but has no CutoutURLSigned.
func (s *SegmentService) doPortrait(ctx context.Context, starID, userID int64, fileID, ext string, imageData []byte, contentType, outKey string) (*SegmentPortraitResult, error) {
	cutout, err := s.inferCutout(ctx, starID, userID, fileID, ext, imageData, contentType)
	if err != nil {
		logger.Logger.Warn("segment infer failed", zap.Error(err))
		return &SegmentPortraitResult{
			Success:   false,
			ErrorCode: SegmentErrorCodeFailed,
			Message:   err.Error(),
		}, nil
	}
	if err := s.oss.PutObject(outKey, bytes.NewReader(cutout.Bytes), "image/png"); err != nil {
		logger.Logger.Warn("segment upload cutout failed", zap.Error(err))
		return &SegmentPortraitResult{
			Success:   false,
			ErrorCode: SegmentErrorCodeFailed,
			Message:   "保存抠图失败: " + truncate(err.Error(), 200),
		}, nil
	}
	outSigned, err := s.oss.SignGetURL(outKey, 3600)
	if err != nil {
		// The cutout is stored; report success with the key only, but keep the
		// cause visible in the logs instead of dropping it.
		logger.Logger.Warn("segment sign cutout url failed", zap.Error(err), zap.String("key", outKey))
		return &SegmentPortraitResult{
			Success:      true,
			CutoutOssKey: outKey,
			Provider:     cutout.Provider,
			Message:      "抠图成功但签名 URL 生成失败",
		}, nil
	}
	return &SegmentPortraitResult{
		Success:         true,
		CutoutOssKey:    outKey,
		CutoutURLSigned: outSigned,
		Provider:        cutout.Provider,
	}, nil
}
// cutoutInferResult is the raw output of one successful segmentation attempt.
type cutoutInferResult struct {
	// Bytes is the cutout image as returned/downloaded from the provider.
	Bytes []byte
	// Provider names the backend that produced Bytes ("imageseg", "ivpd" or "http").
	Provider string
}
// inferCutout produces cutout bytes for imageData using the configured
// provider(s) and reports which provider succeeded.
//
// Provider selection (s.provider; "" means "auto"):
//   - "imageseg" / "viapi": imageseg only.
//   - "ivpd": imageseg first, then IVPD.
//   - "http": the self-hosted HTTP segmenter only.
//   - anything else ("auto"): HTTP (only if enabled), then imageseg, then IVPD.
//
// imageseg and IVPD take an image URL, so the original image is uploaded to a
// temporary OSS key and signed on first use. The HTTP provider receives the
// raw bytes and never triggers that upload. When every attempted provider
// fails, the last provider error is returned.
func (s *SegmentService) inferCutout(ctx context.Context, starID, userID int64, fileID, ext string, imageData []byte, contentType string) (*cutoutInferResult, error) {
	inKey := BuildSegmentTempInputKey(starID, userID, fileID, ext)

	// Upload lazily and at most once, so a successful HTTP attempt never pays
	// for the OSS round trip. "" means the upload or signing failed (logged).
	var (
		inAttempted bool
		inSigned    string
	)
	signedInput := func() string {
		if inAttempted {
			return inSigned
		}
		inAttempted = true
		if err := s.oss.PutObject(inKey, bytes.NewReader(imageData), contentType); err != nil {
			logger.Logger.Warn("segment upload input failed", zap.Error(err), zap.String("key", inKey))
			return ""
		}
		signed, err := s.oss.SignGetURL(inKey, 3600)
		if err != nil {
			logger.Logger.Warn("segment sign input url failed", zap.Error(err), zap.String("key", inKey))
			return ""
		}
		inSigned = signed
		return inSigned
	}

	provider := s.provider
	if provider == "" {
		provider = "auto"
	}

	tryImageseg := func() (*cutoutInferResult, error) {
		inURL := signedInput()
		if inURL == "" {
			return nil, fmt.Errorf("分割抠图需要 OSS 签名 URL")
		}
		outURL, err := s.imageseg.SegmentHDBodyURL(ctx, inURL)
		if err != nil {
			return nil, err
		}
		raw, err := s.downloadURL(ctx, outURL)
		if err != nil {
			return nil, err
		}
		return &cutoutInferResult{Bytes: raw, Provider: "imageseg"}, nil
	}
	tryIVPD := func() (*cutoutInferResult, error) {
		inURL := signedInput()
		if inURL == "" {
			return nil, fmt.Errorf("IVPD 需要 OSS 签名 URL")
		}
		outURL, err := s.ivpd.SegmentImageURL(ctx, inURL)
		if err != nil {
			return nil, err
		}
		raw, err := s.downloadURL(ctx, outURL)
		if err != nil {
			return nil, err
		}
		return &cutoutInferResult{Bytes: raw, Provider: "ivpd"}, nil
	}
	tryHTTP := func() (*cutoutInferResult, error) {
		raw, err := s.httpSeg.RemoveBackground(ctx, imageData, contentType)
		if err != nil {
			return nil, err
		}
		return &cutoutInferResult{Bytes: raw, Provider: "http"}, nil
	}

	// run executes one provider attempt, remembering and logging its error.
	var lastErr error
	run := func(name string, fn func() (*cutoutInferResult, error)) (*cutoutInferResult, bool) {
		res, err := fn()
		if err == nil {
			return res, true
		}
		lastErr = err
		logger.Logger.Warn("segment provider failed", zap.String("provider", name), zap.Error(err))
		return nil, false
	}

	switch provider {
	case "imageseg", "viapi":
		if res, ok := run("imageseg", tryImageseg); ok {
			return res, nil
		}
	case "ivpd":
		// NOTE(review): "ivpd" deliberately(?) tries imageseg first and only
		// falls back to IVPD; confirm this ordering is intended.
		if res, ok := run("imageseg", tryImageseg); ok {
			return res, nil
		}
		if res, ok := run("ivpd", tryIVPD); ok {
			return res, nil
		}
	case "http":
		if res, ok := run("http", tryHTTP); ok {
			return res, nil
		}
	default: // auto
		if s.httpSeg.enabled() {
			if res, ok := run("http", tryHTTP); ok {
				return res, nil
			}
		}
		if res, ok := run("imageseg", tryImageseg); ok {
			return res, nil
		}
		if res, ok := run("ivpd", tryIVPD); ok {
			return res, nil
		}
	}
	if lastErr != nil {
		return nil, lastErr
	}
	return nil, fmt.Errorf("抠图服务未配置")
}
// downloadURL fetches rawURL with s.client and returns the response body.
//
// Non-2xx responses are errors. The body is capped at twice
// maxSegmentImageBytes. A larger body is rejected with an error instead of
// being silently cut short; a truncated image would otherwise be passed on,
// for example uploaded to OSS as a corrupt cutout.
func (s *SegmentService) downloadURL(ctx context.Context, rawURL string) ([]byte, error) {
	const limit = maxSegmentImageBytes * 2

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
	if err != nil {
		return nil, err
	}
	res, err := s.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()
	if res.StatusCode < 200 || res.StatusCode >= 300 {
		return nil, fmt.Errorf("HTTP %d", res.StatusCode)
	}
	// Read one byte past the cap so an oversized body can be detected.
	data, err := io.ReadAll(io.LimitReader(res.Body, limit+1))
	if err != nil {
		return nil, err
	}
	if len(data) > limit {
		return nil, fmt.Errorf("response body exceeds %d bytes", limit)
	}
	return data, nil
}
// isHTTPURL reports whether s, ignoring surrounding whitespace, starts with
// an http:// or https:// scheme (case-insensitive, as URL schemes are) and
// has something after the scheme. A bare "http://" is rejected.
func isHTTPURL(s string) bool {
	s = strings.TrimSpace(s)
	for _, scheme := range [...]string{"http://", "https://"} {
		if len(s) > len(scheme) && strings.EqualFold(s[:len(scheme)], scheme) {
			return true
		}
	}
	return false
}
// truncate shortens s to at most n bytes without splitting a UTF-8 sequence,
// so truncated (possibly non-ASCII) error text stays valid UTF-8. A
// non-positive n yields "".
func truncate(s string, n int) string {
	if len(s) <= n {
		return s
	}
	// Ranging over a string yields the byte offset of each rune; keep the
	// last rune boundary that does not exceed n.
	cut := 0
	for i := range s {
		if i > n {
			break
		}
		cut = i
	}
	return s[:cut]
}