306 lines
8.2 KiB
Go
306 lines
8.2 KiB
Go
package service
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/google/uuid"
|
||
"github.com/topfans/backend/gateway/config"
|
||
"github.com/topfans/backend/pkg/logger"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// Segmentation constants shared by the service and its callers.
const (
	// SegmentErrorCodeFailed is the ErrorCode placed in every failed SegmentPortraitResult.
	SegmentErrorCodeFailed = "LC_SEGMENT_FAILED"

	// maxSegmentImageBytes is the largest accepted image payload: 5 MiB.
	maxSegmentImageBytes = 5 << 20
)

// MaxSegmentImageBytes returns the upload size limit in bytes. It is meant to
// stay consistent with the controller-side validation.
func MaxSegmentImageBytes() int {
	return maxSegmentImageBytes
}
|
||
|
||
// SegmentPortraitResult is the outcome of a portrait segmentation (background
// removal) request, serialized to JSON for the caller.
type SegmentPortraitResult struct {
	// Success is true once a cutout has been produced and stored.
	Success bool `json:"success"`
	// ErrorCode is SegmentErrorCodeFailed on failure; empty otherwise.
	ErrorCode string `json:"error_code,omitempty"`
	// Message carries a human-readable detail: the failure reason, or a note
	// when the cutout was stored but its signed URL could not be generated.
	Message string `json:"message,omitempty"`
	// CutoutOssKey is the OSS object key of the stored cutout.
	CutoutOssKey string `json:"cutout_oss_key,omitempty"`
	// CutoutURLSigned is a signed GET URL for the cutout; empty if signing failed.
	CutoutURLSigned string `json:"cutout_url_signed,omitempty"`
	// Provider names the backend that produced the cutout ("http", "imageseg" or "ivpd").
	Provider string `json:"provider,omitempty"`
}
|
||
|
||
// SegmentService 服务端人像抠图(imageseg / IVPD / 自部署 HTTP)
|
||
type SegmentService struct {
|
||
oss *OssHelper
|
||
ivpd *IvpdClient
|
||
imageseg *ImagesegClient
|
||
httpSeg *SegmentHTTPClient
|
||
provider string
|
||
client *http.Client
|
||
}
|
||
|
||
func NewSegmentService(cfg *config.Config) *SegmentService {
|
||
return &SegmentService{
|
||
oss: NewOssHelper(cfg.OSS),
|
||
ivpd: NewIvpdClient(cfg.OSS),
|
||
imageseg: NewImagesegClient(cfg.OSS),
|
||
httpSeg: NewSegmentHTTPClient(cfg.Segment.InferenceURL),
|
||
provider: strings.ToLower(strings.TrimSpace(cfg.Segment.Provider)),
|
||
client: &http.Client{Timeout: 120 * time.Second},
|
||
}
|
||
}
|
||
|
||
// Portrait 上传原图 → 抠图 → 结果写入 OSS
|
||
func (s *SegmentService) Portrait(ctx context.Context, userID, starID int64, imageData []byte, contentType string) (*SegmentPortraitResult, error) {
|
||
if len(imageData) == 0 {
|
||
return &SegmentPortraitResult{
|
||
Success: false,
|
||
ErrorCode: SegmentErrorCodeFailed,
|
||
Message: "图片为空",
|
||
}, nil
|
||
}
|
||
if len(imageData) > maxSegmentImageBytes {
|
||
return &SegmentPortraitResult{
|
||
Success: false,
|
||
ErrorCode: SegmentErrorCodeFailed,
|
||
Message: "图片超过 5MB",
|
||
}, nil
|
||
}
|
||
|
||
ext := "jpg"
|
||
ct := strings.ToLower(strings.TrimSpace(contentType))
|
||
if strings.Contains(ct, "png") {
|
||
ext = "png"
|
||
}
|
||
if ct == "" {
|
||
if ext == "png" {
|
||
ct = "image/png"
|
||
} else {
|
||
ct = "image/jpeg"
|
||
}
|
||
}
|
||
|
||
fileID := strings.ReplaceAll(uuid.New().String(), "-", "")
|
||
outKey := BuildLaserCardCutoutKey(starID, userID, fileID)
|
||
|
||
return s.doPortrait(ctx, starID, userID, fileID, ext, imageData, ct, outKey)
|
||
}
|
||
|
||
// PortraitFromURL 从 URL 下载图片 → 抠图 → 写入 OSS(供 Dify 工作流调用)
|
||
// 使用默认的 starID=0, userID=0,路径为 laser-card/dify/
|
||
func (s *SegmentService) PortraitFromURL(ctx context.Context, imageURL string) (*SegmentPortraitResult, error) {
|
||
if !isHTTPURL(imageURL) {
|
||
return &SegmentPortraitResult{
|
||
Success: false,
|
||
ErrorCode: SegmentErrorCodeFailed,
|
||
Message: "image_url 格式不正确",
|
||
}, nil
|
||
}
|
||
|
||
imageData, err := s.downloadURL(ctx, imageURL)
|
||
if err != nil {
|
||
return &SegmentPortraitResult{
|
||
Success: false,
|
||
ErrorCode: SegmentErrorCodeFailed,
|
||
Message: "下载图片失败: " + err.Error(),
|
||
}, nil
|
||
}
|
||
|
||
if len(imageData) > maxSegmentImageBytes {
|
||
return &SegmentPortraitResult{
|
||
Success: false,
|
||
ErrorCode: SegmentErrorCodeFailed,
|
||
Message: "图片超过 5MB",
|
||
}, nil
|
||
}
|
||
|
||
fileID := strings.ReplaceAll(uuid.New().String(), "-", "")
|
||
// Dify 调用使用固定 star_id=0, user_id=0
|
||
outKey := BuildLaserCardCutoutKey(0, 0, fileID)
|
||
|
||
// 推断 Content-Type
|
||
contentType := "image/jpeg"
|
||
if len(imageData) > 4 && imageData[0] == 0x89 && imageData[1] == 0x50 && imageData[2] == 0x4e && imageData[3] == 0x47 {
|
||
contentType = "image/png"
|
||
}
|
||
|
||
return s.doPortrait(ctx, 0, 0, fileID, "png", imageData, contentType, outKey)
|
||
}
|
||
|
||
func (s *SegmentService) doPortrait(ctx context.Context, starID, userID int64, fileID, ext string, imageData []byte, contentType, outKey string) (*SegmentPortraitResult, error) {
|
||
|
||
cutout, err := s.inferCutout(ctx, starID, userID, fileID, ext, imageData, contentType)
|
||
if err != nil {
|
||
logger.Logger.Warn("segment infer failed", zap.Error(err))
|
||
return &SegmentPortraitResult{
|
||
Success: false,
|
||
ErrorCode: SegmentErrorCodeFailed,
|
||
Message: err.Error(),
|
||
}, nil
|
||
}
|
||
|
||
if err := s.oss.PutObject(outKey, bytes.NewReader(cutout.Bytes), "image/png"); err != nil {
|
||
logger.Logger.Warn("segment upload cutout failed", zap.Error(err))
|
||
return &SegmentPortraitResult{
|
||
Success: false,
|
||
ErrorCode: SegmentErrorCodeFailed,
|
||
Message: "保存抠图失败: " + truncate(err.Error(), 200),
|
||
}, nil
|
||
}
|
||
|
||
outSigned, err := s.oss.SignGetURL(outKey, 3600)
|
||
if err != nil {
|
||
return &SegmentPortraitResult{
|
||
Success: true,
|
||
CutoutOssKey: outKey,
|
||
Provider: cutout.Provider,
|
||
Message: "抠图成功但签名 URL 生成失败",
|
||
}, nil
|
||
}
|
||
|
||
return &SegmentPortraitResult{
|
||
Success: true,
|
||
CutoutOssKey: outKey,
|
||
CutoutURLSigned: outSigned,
|
||
Provider: cutout.Provider,
|
||
}, nil
|
||
}
|
||
|
||
// cutoutInferResult is what a single segmentation backend produced.
type cutoutInferResult struct {
	// Bytes holds the cutout image as returned or downloaded from the backend.
	Bytes []byte
	// Provider identifies the backend: "imageseg", "ivpd" or "http".
	Provider string
}
|
||
|
||
func (s *SegmentService) inferCutout(ctx context.Context, starID, userID int64, fileID, ext string, imageData []byte, contentType string) (*cutoutInferResult, error) {
|
||
inKey := BuildSegmentTempInputKey(starID, userID, fileID, ext)
|
||
var inSigned string
|
||
if err := s.oss.PutObject(inKey, bytes.NewReader(imageData), contentType); err == nil {
|
||
if signed, signErr := s.oss.SignGetURL(inKey, 3600); signErr == nil {
|
||
inSigned = signed
|
||
}
|
||
} else {
|
||
logger.Logger.Warn("segment upload input failed", zap.Error(err), zap.String("key", inKey))
|
||
}
|
||
|
||
provider := s.provider
|
||
if provider == "" {
|
||
provider = "auto"
|
||
}
|
||
|
||
tryImageseg := func() (*cutoutInferResult, error) {
|
||
if inSigned == "" {
|
||
return nil, fmt.Errorf("分割抠图需要 OSS 签名 URL")
|
||
}
|
||
outURL, err := s.imageseg.SegmentHDBodyURL(ctx, inSigned)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
raw, err := s.downloadURL(ctx, outURL)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &cutoutInferResult{Bytes: raw, Provider: "imageseg"}, nil
|
||
}
|
||
|
||
tryIVPD := func() (*cutoutInferResult, error) {
|
||
if inSigned == "" {
|
||
return nil, fmt.Errorf("IVPD 需要 OSS 签名 URL")
|
||
}
|
||
outURL, err := s.ivpd.SegmentImageURL(ctx, inSigned)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
raw, err := s.downloadURL(ctx, outURL)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &cutoutInferResult{Bytes: raw, Provider: "ivpd"}, nil
|
||
}
|
||
|
||
tryHTTP := func() (*cutoutInferResult, error) {
|
||
raw, err := s.httpSeg.RemoveBackground(ctx, imageData, contentType)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &cutoutInferResult{Bytes: raw, Provider: "http"}, nil
|
||
}
|
||
|
||
var lastErr error
|
||
run := func(name string, fn func() (*cutoutInferResult, error)) (*cutoutInferResult, bool) {
|
||
res, err := fn()
|
||
if err == nil {
|
||
return res, true
|
||
}
|
||
lastErr = err
|
||
logger.Logger.Warn("segment provider failed", zap.String("provider", name), zap.Error(err))
|
||
return nil, false
|
||
}
|
||
|
||
switch provider {
|
||
case "imageseg", "viapi":
|
||
if res, ok := run("imageseg", tryImageseg); ok {
|
||
return res, nil
|
||
}
|
||
case "ivpd":
|
||
if res, ok := run("imageseg", tryImageseg); ok {
|
||
return res, nil
|
||
}
|
||
if res, ok := run("ivpd", tryIVPD); ok {
|
||
return res, nil
|
||
}
|
||
case "http":
|
||
if res, ok := run("http", tryHTTP); ok {
|
||
return res, nil
|
||
}
|
||
default: // auto
|
||
if s.httpSeg.enabled() {
|
||
if res, ok := run("http", tryHTTP); ok {
|
||
return res, nil
|
||
}
|
||
}
|
||
if res, ok := run("imageseg", tryImageseg); ok {
|
||
return res, nil
|
||
}
|
||
if res, ok := run("ivpd", tryIVPD); ok {
|
||
return res, nil
|
||
}
|
||
}
|
||
|
||
if lastErr != nil {
|
||
return nil, lastErr
|
||
}
|
||
return nil, fmt.Errorf("抠图服务未配置")
|
||
}
|
||
|
||
func (s *SegmentService) downloadURL(ctx context.Context, rawURL string) ([]byte, error) {
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
res, err := s.client.Do(req)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer res.Body.Close()
|
||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||
return nil, fmt.Errorf("HTTP %d", res.StatusCode)
|
||
}
|
||
return io.ReadAll(io.LimitReader(res.Body, maxSegmentImageBytes*2))
|
||
}
|
||
|
||
// isHTTPURL reports whether s, ignoring surrounding whitespace, starts with the
// "http://" or "https://" prefix. The prefix check is case-sensitive.
func isHTTPURL(s string) bool {
	trimmed := strings.TrimSpace(s)
	for _, prefix := range [...]string{"http://", "https://"} {
		if strings.HasPrefix(trimmed, prefix) {
			return true
		}
	}
	return false
}
|
||
|
||
// truncate returns s limited to at most n bytes.
//
// The cut backs off to a UTF-8 rune boundary, so multi-byte characters (for
// example CJK text in error messages) are never split into invalid bytes. A
// non-positive n yields "" instead of panicking.
func truncate(s string, n int) string {
	if n <= 0 {
		return ""
	}
	if len(s) <= n {
		return s
	}
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; a rune starts at any other byte.
	for n > 0 && s[n]&0xC0 == 0x80 {
		n--
	}
	return s[:n]
}
|