124 lines
3.3 KiB
Go
124 lines
3.3 KiB
Go
package controller
|
||
|
||
import (
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/topfans/backend/gateway/config"
|
||
"github.com/topfans/backend/gateway/pkg/response"
|
||
"github.com/topfans/backend/gateway/service"
|
||
"github.com/topfans/backend/pkg/logger"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// SegmentController 人像抠图代理(密钥仅在服务端,客户端走 multipart 上传)
|
||
type SegmentController struct {
|
||
svc *service.SegmentService
|
||
}
|
||
|
||
func NewSegmentController() *SegmentController {
|
||
cfg := config.Load()
|
||
return &SegmentController{svc: service.NewSegmentService(cfg)}
|
||
}
|
||
|
||
// Portrait POST /api/v1/segment — scene=portrait(客户端 multipart 上传)
|
||
func (ctrl *SegmentController) Portrait(c *gin.Context) {
|
||
scene := strings.TrimSpace(c.DefaultPostForm("scene", "portrait"))
|
||
if scene != "portrait" {
|
||
response.Error(c, http.StatusBadRequest, "本期仅支持 scene=portrait")
|
||
return
|
||
}
|
||
|
||
userID, ok := c.Get("user_id")
|
||
if !ok {
|
||
response.Unauthorized(c, "未登录")
|
||
return
|
||
}
|
||
starID, ok := c.Get("star_id")
|
||
if !ok {
|
||
response.Unauthorized(c, "缺少 star_id")
|
||
return
|
||
}
|
||
|
||
file, header, err := c.Request.FormFile("image")
|
||
if err != nil {
|
||
response.Error(c, http.StatusBadRequest, "请上传 image 字段")
|
||
return
|
||
}
|
||
defer file.Close()
|
||
|
||
limit := int64(service.MaxSegmentImageBytes()) + 1
|
||
data, err := io.ReadAll(io.LimitReader(file, limit))
|
||
if err != nil {
|
||
response.Error(c, http.StatusBadRequest, "读取图片失败")
|
||
return
|
||
}
|
||
if len(data) > service.MaxSegmentImageBytes() {
|
||
response.Error(c, http.StatusBadRequest, "图片超过 5MB")
|
||
return
|
||
}
|
||
|
||
contentType := header.Header.Get("Content-Type")
|
||
result, err := ctrl.svc.Portrait(c.Request.Context(), userID.(int64), starID.(int64), data, contentType)
|
||
if err != nil {
|
||
logger.Logger.Error("segment portrait", zap.Error(err))
|
||
response.Error(c, http.StatusInternalServerError, "抠图服务异常")
|
||
return
|
||
}
|
||
|
||
// 失败返回 200 + success:false;前端须中断 thinking,不做降级
|
||
response.Success(c, result)
|
||
}
|
||
|
||
// segmentJSONReq is the JSON request body sent by the Dify workflow to PortraitByURL.
type segmentJSONReq struct {
	// ImageURL is the address of the image to segment.
	ImageURL string `json:"image_url"`
	// Scene selects the segmentation scene; the handler treats empty as "portrait".
	Scene string `json:"scene"`
}
|
||
|
||
// PortraitByURL POST /api/v1/segment/json — Dify 工作流调用
|
||
// 接收 JSON body { "image_url": "...", "scene": "portrait" }
|
||
// 无需 JWT 鉴权(仅限内网 Dify→Gateway),使用默认 star_id/user_id 写入 OSS
|
||
func (ctrl *SegmentController) PortraitByURL(c *gin.Context) {
|
||
var req segmentJSONReq
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
response.BadRequest(c, "参数错误: "+err.Error())
|
||
return
|
||
}
|
||
|
||
scene := strings.TrimSpace(req.Scene)
|
||
if scene == "" {
|
||
scene = "portrait"
|
||
}
|
||
if scene != "portrait" {
|
||
response.Error(c, http.StatusBadRequest, "本期仅支持 scene=portrait")
|
||
return
|
||
}
|
||
|
||
imageURL := strings.TrimSpace(req.ImageURL)
|
||
if imageURL == "" {
|
||
response.BadRequest(c, "image_url 不能为空")
|
||
return
|
||
}
|
||
|
||
logger.Logger.Info("segment json: downloading image", zap.String("url", imageURL[:min(len(imageURL), 80)]))
|
||
|
||
result, err := ctrl.svc.PortraitFromURL(c.Request.Context(), imageURL)
|
||
if err != nil {
|
||
logger.Logger.Error("segment json failed", zap.Error(err))
|
||
response.Error(c, http.StatusInternalServerError, "抠图服务异常: "+err.Error())
|
||
return
|
||
}
|
||
|
||
response.Success(c, result)
|
||
}
|
||
|
||
// min returns the smaller of a and b; when they are equal it returns b.
// NOTE(review): on Go 1.21+ this package-level declaration shadows the built-in
// min throughout the package; it can be deleted once the module targets 1.21+.
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
|