169 lines
4.4 KiB
Go
169 lines
4.4 KiB
Go
package controller
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"strconv"
|
|
|
|
"dubbo.apache.org/dubbo-go/v3/client"
|
|
"dubbo.apache.org/dubbo-go/v3/common/constant"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/topfans/backend/pkg/logger"
|
|
"github.com/topfans/backend/gateway/pkg/response"
|
|
pbAIChat "github.com/topfans/backend/pkg/proto/ai_chat"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// AIChatController AI Chat HTTP 控制器
|
|
type AIChatController struct {
|
|
aiChatClient *client.Client
|
|
}
|
|
|
|
// NewAIChatController 创建 AIChatController
|
|
func NewAIChatController(aiChatClient *client.Client) (*AIChatController, error) {
|
|
return &AIChatController{
|
|
aiChatClient: aiChatClient,
|
|
}, nil
|
|
}
|
|
|
|
// GetPersonas 获取人设列表
|
|
// @Summary 获取用户的所有人设
|
|
// @Tags ai-chat
|
|
// @Accept json
|
|
// @Produce json
|
|
// @Security BearerAuth
|
|
// @Success 200 {object} response.Response
|
|
// @Router /api/v1/ai-chat/personas [get]
|
|
func (c *AIChatController) GetPersonas(ctx *gin.Context) {
|
|
userIDVal, exists := ctx.Get("user_id")
|
|
if !exists {
|
|
response.Error(ctx, http.StatusUnauthorized, "用户未认证")
|
|
return
|
|
}
|
|
uid, _ := userIDVal.(int64)
|
|
|
|
aiChatService, err := pbAIChat.NewAIChatService(c.aiChatClient)
|
|
if err != nil {
|
|
logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err))
|
|
response.Error(ctx, http.StatusInternalServerError, "服务暂时不可用")
|
|
return
|
|
}
|
|
|
|
// 创建带 Attachments 的 context
|
|
userIDStr := strconv.FormatInt(uid, 10)
|
|
attachments := map[string]interface{}{
|
|
"user_id": userIDStr,
|
|
}
|
|
rpcCtx := context.WithValue(context.Background(), constant.AttachmentKey, attachments)
|
|
|
|
req := &pbAIChat.GetPersonasRequest{
|
|
UserId: uid,
|
|
}
|
|
|
|
resp, err := aiChatService.GetPersonas(rpcCtx, req)
|
|
if err != nil {
|
|
logger.Logger.Error("GetPersonas call failed", zap.Error(err))
|
|
response.Error(ctx, http.StatusInternalServerError, "获取人设失败")
|
|
return
|
|
}
|
|
|
|
personas := make([]map[string]interface{}, len(resp.Personas))
|
|
for i, p := range resp.Personas {
|
|
personas[i] = map[string]interface{}{
|
|
"id": p.Id,
|
|
"name": p.Name,
|
|
"description": p.Description,
|
|
"avatar_url": p.AvatarUrl,
|
|
"talk_style": p.TalkStyle,
|
|
"is_default": p.IsDefault,
|
|
"created_at": p.CreatedAt,
|
|
"updated_at": p.UpdatedAt,
|
|
}
|
|
}
|
|
|
|
response.Success(ctx, map[string]interface{}{
|
|
"personas": personas,
|
|
})
|
|
}
|
|
|
|
// GetHistory 获取对话历史
|
|
// @Summary 获取对话历史
|
|
// @Tags ai-chat
|
|
// @Accept json
|
|
// @Produce json
|
|
// @Security BearerAuth
|
|
// @Param sessionId path string true "会话ID"
|
|
// @Success 200 {object} response.Response
|
|
// @Router /api/v1/ai-chat/history/{sessionId} [get]
|
|
func (c *AIChatController) GetHistory(ctx *gin.Context) {
|
|
sessionID := ctx.Param("sessionId")
|
|
if sessionID == "" {
|
|
response.Error(ctx, http.StatusBadRequest, "sessionId不能为空")
|
|
return
|
|
}
|
|
|
|
aiChatService, err := pbAIChat.NewAIChatService(c.aiChatClient)
|
|
if err != nil {
|
|
logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err))
|
|
response.Error(ctx, http.StatusInternalServerError, "服务暂时不可用")
|
|
return
|
|
}
|
|
|
|
req := &pbAIChat.ChatHistoryRequest{
|
|
SessionId: sessionID,
|
|
Limit: 20,
|
|
}
|
|
|
|
// 获取 user_id 和 star_id 用于 Attachments
|
|
userIDVal, exists := ctx.Get("user_id")
|
|
if !exists {
|
|
response.Error(ctx, http.StatusUnauthorized, "用户未认证")
|
|
return
|
|
}
|
|
uid, _ := userIDVal.(int64)
|
|
starIDVal, _ := ctx.Get("star_id")
|
|
sid, _ := starIDVal.(int64)
|
|
|
|
// 创建带 Attachments 的 context
|
|
userIDStr := strconv.FormatInt(uid, 10)
|
|
starIDStr := strconv.FormatInt(sid, 10)
|
|
attachments := map[string]interface{}{
|
|
"user_id": userIDStr,
|
|
"star_id": starIDStr,
|
|
}
|
|
rpcCtx := context.WithValue(context.Background(), constant.AttachmentKey, attachments)
|
|
|
|
resp, err := aiChatService.GetHistory(rpcCtx, req)
|
|
if err != nil {
|
|
logger.Logger.Error("GetHistory call failed", zap.Error(err))
|
|
response.Error(ctx, http.StatusInternalServerError, "获取历史失败")
|
|
return
|
|
}
|
|
|
|
history := make([]map[string]string, len(resp.History))
|
|
for i, m := range resp.History {
|
|
history[i] = map[string]string{
|
|
"role": m.Role,
|
|
"content": m.Content,
|
|
}
|
|
}
|
|
|
|
response.Success(ctx, map[string]interface{}{
|
|
"history": history,
|
|
})
|
|
}
|
|
|
|
// extractUserInfoFromContext 从 gin.Context 中提取用户信息
|
|
func extractUserInfoFromContext(ctx *gin.Context) (int64, int64, error) {
|
|
userID, exists := ctx.Get("user_id")
|
|
if !exists {
|
|
return 0, 0, nil
|
|
}
|
|
|
|
starID, _ := ctx.Get("star_id")
|
|
|
|
uid, _ := userID.(int64)
|
|
sid, _ := starID.(int64)
|
|
|
|
return uid, sid, nil
|
|
} |