topfans/backend/gateway/controller/segment_controller.go
2026-06-03 22:19:22 +08:00

124 lines
3.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package controller
import (
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/topfans/backend/gateway/config"
"github.com/topfans/backend/gateway/pkg/response"
"github.com/topfans/backend/gateway/service"
"github.com/topfans/backend/pkg/logger"
"go.uber.org/zap"
)
// SegmentController proxies portrait segmentation (matting) requests.
// The segmentation credentials stay on the server; clients only upload the
// image via multipart form data.
type SegmentController struct {
	// svc performs the actual segmentation calls.
	svc *service.SegmentService
}
// NewSegmentController returns a SegmentController whose SegmentService is
// built from the gateway configuration returned by config.Load().
func NewSegmentController() *SegmentController {
	segmentSvc := service.NewSegmentService(config.Load())
	return &SegmentController{svc: segmentSvc}
}
// Portrait handles POST /api/v1/segment (scene=portrait; the client uploads
// the image as multipart form data).
//
// Form fields:
//   - scene: optional, defaults to "portrait"; any other value is rejected with 400.
//   - image: required, non-empty file part of at most
//     service.MaxSegmentImageBytes() bytes; otherwise 400.
//
// The caller identity is read from the gin context keys "user_id" and
// "star_id", which must hold int64 values (presumably set by the auth
// middleware — confirm). A missing or mistyped value yields 401 instead of
// a panic.
func (ctrl *SegmentController) Portrait(c *gin.Context) {
	scene := strings.TrimSpace(c.DefaultPostForm("scene", "portrait"))
	if scene != "portrait" {
		response.Error(c, http.StatusBadRequest, "本期仅支持 scene=portrait")
		return
	}

	// Checked type assertions: a bare .(int64) would panic if the middleware
	// ever stored the ID as another type. The "exists" flag from c.Get is
	// deliberately ignored because a missing key yields nil, which also
	// fails the assertion below.
	rawUserID, _ := c.Get("user_id")
	userID, ok := rawUserID.(int64)
	if !ok {
		response.Unauthorized(c, "未登录")
		return
	}
	rawStarID, _ := c.Get("star_id")
	starID, ok := rawStarID.(int64)
	if !ok {
		response.Unauthorized(c, "缺少 star_id")
		return
	}

	file, header, err := c.Request.FormFile("image")
	if err != nil {
		response.Error(c, http.StatusBadRequest, "请上传 image 字段")
		return
	}
	defer file.Close()

	// Read at most one byte past the limit so an oversize upload is detected
	// without copying the whole part into memory here.
	maxBytes := service.MaxSegmentImageBytes()
	data, err := io.ReadAll(io.LimitReader(file, int64(maxBytes)+1))
	if err != nil {
		response.Error(c, http.StatusBadRequest, "读取图片失败")
		return
	}
	if len(data) == 0 {
		// An empty file part is as unusable as a missing one; don't spend an
		// upstream segmentation call on it.
		response.Error(c, http.StatusBadRequest, "请上传 image 字段")
		return
	}
	if len(data) > maxBytes {
		// NOTE(review): this message hard-codes 5MB; keep it in sync with
		// service.MaxSegmentImageBytes().
		response.Error(c, http.StatusBadRequest, "图片超过 5MB")
		return
	}

	// The part's Content-Type is client-supplied and may be absent; sniff the
	// bytes in that case so a media type is always passed along.
	contentType := header.Header.Get("Content-Type")
	if contentType == "" {
		contentType = http.DetectContentType(data)
	}

	result, err := ctrl.svc.Portrait(c.Request.Context(), userID, starID, data, contentType)
	if err != nil {
		logger.Logger.Error("segment portrait",
			zap.Int64("user_id", userID),
			zap.Int64("star_id", starID),
			zap.Error(err))
		response.Error(c, http.StatusInternalServerError, "抠图服务异常")
		return
	}
	// A segmentation failure is reported as 200 with success:false; the
	// frontend must abort its "thinking" state and must not fall back to a
	// degraded result.
	response.Success(c, result)
}
// segmentJSONReq is the JSON request body the Dify workflow sends to
// POST /api/v1/segment/json.
type segmentJSONReq struct {
	// ImageURL is the image to segment; surrounding whitespace is trimmed
	// and an empty value is rejected by the handler.
	ImageURL string `json:"image_url"`
	// Scene is optional; an empty value is treated as "portrait", which is
	// currently the only supported scene.
	Scene string `json:"scene"`
}
// PortraitByURL handles POST /api/v1/segment/json, called by the Dify workflow.
//
// Request body (JSON): { "image_url": "...", "scene": "portrait" }
//   - image_url: required, absolute http:// or https:// URL.
//   - scene: optional; empty defaults to "portrait", the only supported value.
//
// No JWT auth is applied (intended for internal Dify→Gateway traffic only);
// the result is written to OSS under a default star_id/user_id.
//
// Because the route is unauthenticated, upstream/download error details are
// logged rather than echoed to the caller, and the logged URL omits its
// query string (which may carry signed-URL credentials).
//
// NOTE(review): rejecting non-http(s) schemes does not prevent SSRF against
// internal hosts; confirm the route is network-restricted and/or that the
// fetcher blocks private addresses.
func (ctrl *SegmentController) PortraitByURL(c *gin.Context) {
	var req segmentJSONReq
	if err := c.ShouldBindJSON(&req); err != nil {
		response.BadRequest(c, "参数错误: "+err.Error())
		return
	}

	scene := strings.TrimSpace(req.Scene)
	if scene == "" {
		scene = "portrait"
	}
	if scene != "portrait" {
		response.Error(c, http.StatusBadRequest, "本期仅支持 scene=portrait")
		return
	}

	imageURL := strings.TrimSpace(req.ImageURL)
	if imageURL == "" {
		response.BadRequest(c, "image_url 不能为空")
		return
	}
	// Only http(s) downloads make sense here; reject file://, gopher://, etc.
	// up front with a clear 400 instead of an opaque fetch failure.
	lowerURL := strings.ToLower(imageURL)
	if !strings.HasPrefix(lowerURL, "http://") && !strings.HasPrefix(lowerURL, "https://") {
		response.BadRequest(c, "image_url 必须是 http(s) 地址")
		return
	}

	// Drop the query string (it may hold signatures/tokens) and cap the
	// length for the log. ToValidUTF8 removes a multi-byte character that
	// the byte-level cut may have split.
	logURL, _, _ := strings.Cut(imageURL, "?")
	if len(logURL) > 80 {
		logURL = strings.ToValidUTF8(logURL[:80], "")
	}
	logger.Logger.Info("segment json: downloading image", zap.String("url", logURL))

	result, err := ctrl.svc.PortraitFromURL(c.Request.Context(), imageURL)
	if err != nil {
		// Keep the detail server-side: on an unauthenticated route, echoing
		// download/connection errors would reveal internal topology.
		logger.Logger.Error("segment json failed", zap.String("url", logURL), zap.Error(err))
		response.Error(c, http.StatusInternalServerError, "抠图服务异常")
		return
	}
	response.Success(c, result)
}
// min returns the smaller of a and b.
//
// NOTE(review): from Go 1.21 onward this package-level function shadows the
// built-in min. It can be deleted once the module targets Go 1.21+ and no
// other file in package controller relies on it.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}