From 17c68c9233256cdfccb06d2e3bd7567720b551bc Mon Sep 17 00:00:00 2001 From: zerosaturation Date: Thu, 28 May 2026 12:00:19 +0800 Subject: [PATCH] =?UTF-8?q?feat:=E5=A2=9E=E5=8A=A0=E7=9A=84ai=E6=90=AD?= =?UTF-8?q?=E5=AD=90=E5=AF=B9=E8=AF=9D=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 13 + backend/dev.sh | 24 +- backend/gateway/config/config.go | 49 +- .../gateway/controller/ai_chat_controller.go | 169 ++++ backend/gateway/main.go | 11 +- backend/gateway/router/router.go | 24 +- backend/gateway/socket/ai_chat_socket.go | 478 ++++++++++++ backend/go.mod | 1 + backend/go.sum | 2 + backend/go.work | 1 + backend/go.work.sum | 87 +-- backend/pkg/proto/ai_chat/ai_chat.pb.go | 726 ++++++++++++++++++ backend/pkg/proto/ai_chat/ai_chat.triple.go | 258 +++++++ backend/proto/ai_chat.proto | 94 +++ backend/scripts/compile-proto.sh | 18 +- .../migrations/migrate_ai_chat_tables.sql | 150 ++++ .../services/aiChatService/configs/dubbo.yaml | 35 + backend/services/aiChatService/go.mod | 23 + backend/services/aiChatService/main.go | 265 +++++++ .../aiChatService/model/ai_chat_errors.go | 26 + .../aiChatService/model/ai_chat_models.go | 102 +++ .../provider/ai_chat_provider.go | 442 +++++++++++ .../repository/config_repository.go | 60 ++ .../repository/memory_repository.go | 156 ++++ .../repository/persona_repository.go | 122 +++ .../aiChatService/service/audit_service.go | 59 ++ .../aiChatService/service/chat_service.go | 60 ++ .../aiChatService/service/llm_service.go | 312 ++++++++ .../aiChatService/service/memory_service.go | 211 +++++ .../aiChatService/service/persona_service.go | 68 ++ .../aiChatService/service/prompt_builder.go | 180 +++++ frontend/App.vue | 20 + frontend/pages/ai-dazi/index.vue | 417 ++++++++-- frontend/utils/api.js | 71 +- frontend/utils/socket/AiChatSocket.js | 187 +++++ frontend/utils/socket/GlobalSocketManager.js | 50 ++ frontend/utils/socket/SocketManager.js | 345 +++++++++ frontend/utils/socket/index.js | 3 + 38 files changed, 5125 insertions(+), 194 deletions(-) create mode 100644 backend/gateway/controller/ai_chat_controller.go create mode 100644 backend/gateway/socket/ai_chat_socket.go create mode 100644 backend/pkg/proto/ai_chat/ai_chat.pb.go create mode 100644 backend/pkg/proto/ai_chat/ai_chat.triple.go create mode 100644 backend/proto/ai_chat.proto create mode 100644 backend/scripts/migrations/migrate_ai_chat_tables.sql create mode 100644 backend/services/aiChatService/configs/dubbo.yaml create mode 100644 backend/services/aiChatService/go.mod create mode 100644 backend/services/aiChatService/main.go create mode 100644 backend/services/aiChatService/model/ai_chat_errors.go create mode 100644 backend/services/aiChatService/model/ai_chat_models.go create mode 100644 backend/services/aiChatService/provider/ai_chat_provider.go create mode 100644 backend/services/aiChatService/repository/config_repository.go create mode 100644 backend/services/aiChatService/repository/memory_repository.go create mode 100644 backend/services/aiChatService/repository/persona_repository.go create mode 100644 backend/services/aiChatService/service/audit_service.go create mode 100644 backend/services/aiChatService/service/chat_service.go create mode 100644 backend/services/aiChatService/service/llm_service.go create mode 100644 backend/services/aiChatService/service/memory_service.go create mode 100644 backend/services/aiChatService/service/persona_service.go create mode 100644 backend/services/aiChatService/service/prompt_builder.go create mode 100644 frontend/utils/socket/AiChatSocket.js create mode 100644 frontend/utils/socket/GlobalSocketManager.js create mode 100644 frontend/utils/socket/SocketManager.js create mode 100644 frontend/utils/socket/index.js diff --git a/.gitignore b/.gitignore index e0ca283..73076b2 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,19 @@ frontend/pages/.DS_Store frontend/.hbuilderx/launch.json .idea +# Backend 构建产物 +backend/services/activityService/activityService +backend/services/assetService/assetService +backend/services/galleryService/galleryService +backend/services/socialService/socialService +backend/services/starbookService/starbookService +backend/services/taskService/taskService +backend/services/userService/userService +backend/services/aiChatService/aiChatService +backend/gateway/gateway +backend/gateway/aiChatService +bin/ + node_modules package-lock.json # Added by code-review-graph diff --git a/backend/dev.sh b/backend/dev.sh index 1d04dd4..22c937e 100755 --- a/backend/dev.sh +++ b/backend/dev.sh @@ -21,7 +21,7 @@ cleanup() { fi # 清理所有 PID 文件并杀服务进程 - for service in gateway activityService galleryService socialService assetService userService taskService starbookService; do + for service in gateway activityService galleryService socialService assetService userService taskService starbookService aiChatService; do pkill -9 -f "$service" 2>/dev/null || true rm -f "/tmp/dev_sh_${service}.pid" "/tmp/dev_sh_${service}_restart" "/tmp/dev_sh_${service}.lock" echo -e "${YELLOW} 🛑 $service 已停止${NC}" @@ -82,15 +82,16 @@ if [ -f "$ENV_FILE" ]; then fi DB_HOST="${DB_HOST:-localhost}" -DB_PORT="${DB_PORT:-5432}" -DB_USER="${DB_USER:-haihuizhu}" -DB_PASSWORD="${DB_PASSWORD:-admin}" +DB_PORT="${DB_PORT:-15432}" +DB_USER="${DB_USER:-postgres}" +DB_PASSWORD="${DB_PASSWORD:-123456}" DB_NAME="${DB_NAME:-top-fans}" REDIS_HOST="${REDIS_HOST:-localhost}" REDIS_PORT="${REDIS_PORT:-6379}" +REDIS_PASSWORD="${REDIS_PASSWORD:-123456}" REDIS_DB="${REDIS_DB:-0}" DB_ARGS=(-db-host="$DB_HOST" -db-port="$DB_PORT" -db-user="$DB_USER" -db-password="$DB_PASSWORD" -db-name="$DB_NAME") -REDIS_ARGS=(-redis-host="$REDIS_HOST" -redis-port="$REDIS_PORT" -redis-db="$REDIS_DB") +REDIS_ARGS=(-redis-host="$REDIS_HOST" -redis-port="$REDIS_PORT" -redis-db="$REDIS_DB" -redis-password="$REDIS_PASSWORD") # 启动一个服务 # 用法: start_service name binary port use_db use_redis @@ -280,7 +281,8 @@ start_watcher() { --exclude='galleryService$' \ --exclude='activityService$' \ --exclude='taskService$' \ - --exclude='starbookService$' & + --exclude='starbookService$' \ + --exclude='aiChatService$' & fi done else @@ -295,7 +297,8 @@ start_watcher() { --exclude='galleryService$' \ --exclude='activityService$' \ --exclude='taskService$' \ - --exclude='starbookService$' & + --exclude='starbookService$' \ + --exclude='aiChatService$' & fi wait ) & @@ -359,13 +362,13 @@ echo "" > /tmp/dev_sh_watchers.tmp # 清理残留 PID 文件(上次非正常退出可能留下) -for service in activityService galleryService socialService assetService userService taskService gateway starbookService; do +for service in activityService galleryService socialService assetService userService taskService gateway starbookService aiChatService; do rm -f "/tmp/dev_sh_${service}.pid" "/tmp/dev_sh_${service}_restart" done # 停止现有服务(清理环境) echo -e "${YELLOW}🛑 停止现有服务...${NC}" -for service in gateway userService socialService assetService galleryService activityService taskService starbookService; do +for service in gateway userService socialService assetService galleryService activityService taskService starbookService aiChatService; do pkill -9 -f "$service" 2>/dev/null || true done sleep 1 @@ -396,6 +399,7 @@ build_service "galleryService" "services/galleryService" "services/galleryServic build_service "activityService" "services/activityService" "services/activityService/activityService" build_service "taskService" "services/taskService" "services/taskService/taskService" build_service "starbookService" "services/starbookService" "services/starbookService/starbookService" +build_service "aiChatService" "services/aiChatService" "services/aiChatService/aiChatService" cd "$SCRIPT_DIR" # 启动所有服务 @@ -413,6 +417,7 @@ echo -e "${GREEN}✅ galleryService 已启动 (PID: $(cat /tmp/dev_sh_gallerySer start_service "activityService" "services/activityService/activityService" 20005 1 0 start_service "taskService" "services/taskService/taskService" 20006 1 0 start_service "starbookService" "services/starbookService/starbookService" 20007 1 0 +start_service "aiChatService" "services/aiChatService/aiChatService" 20008 1 1 start_service "gateway" "gateway/gateway" 8080 0 0 # 启动所有文件监听器 @@ -426,6 +431,7 @@ start_watcher "galleryService" "services/galleryService" "services/gallerySer start_watcher "activityService" "services/activityService" "services/activityService/activityService" 20005 1 0 start_watcher "taskService" "services/taskService" "services/taskService/taskService" 20006 1 0 start_watcher "starbookService" "services/starbookService" "services/starbookService/starbookService" 20007 1 0 +start_watcher "aiChatService" "services/aiChatService:pkg/proto" "services/aiChatService/aiChatService" 20008 1 1 echo "" echo -e "${GREEN}========================================${NC}" diff --git a/backend/gateway/config/config.go b/backend/gateway/config/config.go index b91430e..ad348c3 100644 --- a/backend/gateway/config/config.go +++ b/backend/gateway/config/config.go @@ -8,12 +8,13 @@ import ( // Config 网关配置 type Config struct { - Server ServerConfig - Dubbo DubboConfig - JWT JWTConfig - OSS OSSConfig - Redis RedisConfig - Root string + Server ServerConfig + Dubbo DubboConfig + JWT JWTConfig + OSS OSSConfig + Redis RedisConfig + WebSocket WebSocketConfig + Root string } // RedisConfig Redis 配置 @@ -33,12 +34,13 @@ type ServerConfig struct { // DubboConfig Dubbo 服务配置 type DubboConfig struct { UserServiceURL string - SocialServiceURL string - AssetServiceURL string - GalleryServiceURL string - ActivityServiceURL string - TaskServiceURL string - StarbookServiceURL string + SocialServiceURL string + AssetServiceURL string + GalleryServiceURL string + ActivityServiceURL string + TaskServiceURL string + StarbookServiceURL string + AIChatServiceURL string } // JWTConfig JWT 配置 @@ -70,6 +72,11 @@ func (c *OSSConfig) GetUploadDir(uploadType string) string { } } +// WebSocketConfig WebSocket 配置 +type WebSocketConfig struct { + AIChatPath string // WebSocket 路径,默认 /ws/ai-chat +} + // Load 加载配置 func Load() *Config { root, _ := os.Getwd() @@ -81,12 +88,13 @@ func Load() *Config { }, Dubbo: DubboConfig{ UserServiceURL: getEnv("DUBBO_USER_SERVICE_URL", "tri://127.0.0.1:20000"), - SocialServiceURL: getEnv("DUBBO_SOCIAL_SERVICE_URL", "tri://127.0.0.1:20002"), - AssetServiceURL: getEnv("DUBBO_ASSET_SERVICE_URL", "tri://127.0.0.1:20003"), - GalleryServiceURL: getEnv("DUBBO_GALLERY_SERVICE_URL", "tri://127.0.0.1:20004"), - ActivityServiceURL: getEnv("DUBBO_ACTIVITY_SERVICE_URL", "tri://127.0.0.1:20005"), - TaskServiceURL: getEnv("DUBBO_TASK_SERVICE_URL", "tri://127.0.0.1:20006"), - StarbookServiceURL: getEnv("DUBBO_STARBOOK_SERVICE_URL", "tri://127.0.0.1:20007"), + SocialServiceURL: getEnv("DUBBO_SOCIAL_SERVICE_URL", "tri://127.0.0.1:20002"), + AssetServiceURL: getEnv("DUBBO_ASSET_SERVICE_URL", "tri://127.0.0.1:20003"), + GalleryServiceURL: getEnv("DUBBO_GALLERY_SERVICE_URL", "tri://127.0.0.1:20004"), + ActivityServiceURL: getEnv("DUBBO_ACTIVITY_SERVICE_URL", "tri://127.0.0.1:20005"), + TaskServiceURL: getEnv("DUBBO_TASK_SERVICE_URL", "tri://127.0.0.1:20006"), + StarbookServiceURL: getEnv("DUBBO_STARBOOK_SERVICE_URL", "tri://127.0.0.1:20007"), + AIChatServiceURL: getEnv("DUBBO_AI_CHAT_SERVICE_URL", "tri://127.0.0.1:20008"), }, JWT: JWTConfig{ Secret: getEnv("JWT_SECRET", "topfans-secret-key-please-change-in-production"), @@ -107,6 +115,9 @@ func Load() *Config { Password: getEnv("REDIS_PASSWORD", "123456"), DB: getEnvInt("REDIS_DB", 0), }, + WebSocket: WebSocketConfig{ + AIChatPath: getEnv("WS_AI_CHAT_PATH", "/ws/ai-chat"), + }, } } @@ -159,4 +170,4 @@ func (c *Config) Validate() error { return fmt.Errorf("JWT secret is required") } return nil -} +} \ No newline at end of file diff --git a/backend/gateway/controller/ai_chat_controller.go b/backend/gateway/controller/ai_chat_controller.go new file mode 100644 index 0000000..e8c08f7 --- /dev/null +++ b/backend/gateway/controller/ai_chat_controller.go @@ -0,0 +1,169 @@ +package controller + +import ( + "context" + "net/http" + "strconv" + + "dubbo.apache.org/dubbo-go/v3/client" + "dubbo.apache.org/dubbo-go/v3/common/constant" + "github.com/gin-gonic/gin" + "github.com/topfans/backend/pkg/logger" + "github.com/topfans/backend/gateway/pkg/response" + pbAIChat "github.com/topfans/backend/pkg/proto/ai_chat" + "go.uber.org/zap" +) + +// AIChatController AI Chat HTTP 控制器 +type AIChatController struct { + aiChatClient *client.Client +} + +// NewAIChatController 创建 AIChatController +func NewAIChatController(aiChatClient *client.Client) (*AIChatController, error) { + return &AIChatController{ + aiChatClient: aiChatClient, + }, nil +} + +// GetPersonas 获取人设列表 +// @Summary 获取用户的所有人设 +// @Tags ai-chat +// @Accept json +// @Produce json +// @Security BearerAuth +// @Success 200 {object} response.Response +// @Router /api/v1/ai-chat/personas [get] +func (c *AIChatController) GetPersonas(ctx *gin.Context) { + userIDVal, exists := ctx.Get("user_id") + if !exists { + response.Error(ctx, http.StatusUnauthorized, "用户未认证") + return + } + uid, _ := userIDVal.(int64) + + aiChatService, err := pbAIChat.NewAIChatService(c.aiChatClient) + if err != nil { + logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err)) + response.Error(ctx, http.StatusInternalServerError, "服务暂时不可用") + return + } + + // 创建带 Attachments 的 context + userIDStr := strconv.FormatInt(uid, 10) + attachments := map[string]interface{}{ + "user_id": userIDStr, + } + rpcCtx := context.WithValue(context.Background(), constant.AttachmentKey, attachments) + + req := &pbAIChat.GetPersonasRequest{ + UserId: uid, + } + + resp, err := aiChatService.GetPersonas(rpcCtx, req) + if err != nil { + logger.Logger.Error("GetPersonas call failed", zap.Error(err)) + response.Error(ctx, http.StatusInternalServerError, "获取人设失败") + return + } + + personas := make([]map[string]interface{}, len(resp.Personas)) + for i, p := range resp.Personas { + personas[i] = map[string]interface{}{ + "id": p.Id, + "name": p.Name, + "description": p.Description, + "avatar_url": p.AvatarUrl, + "talk_style": p.TalkStyle, + "is_default": p.IsDefault, + "created_at": p.CreatedAt, + "updated_at": p.UpdatedAt, + } + } + + response.Success(ctx, map[string]interface{}{ + "personas": personas, + }) +} + +// GetHistory 获取对话历史 +// @Summary 获取对话历史 +// @Tags ai-chat +// @Accept json +// @Produce json +// @Security BearerAuth +// @Param sessionId path string true "会话ID" +// @Success 200 {object} response.Response +// @Router /api/v1/ai-chat/history/{sessionId} [get] +func (c *AIChatController) GetHistory(ctx *gin.Context) { + sessionID := ctx.Param("sessionId") + if sessionID == "" { + response.Error(ctx, http.StatusBadRequest, "sessionId不能为空") + return + } + + aiChatService, err := pbAIChat.NewAIChatService(c.aiChatClient) + if err != nil { + logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err)) + response.Error(ctx, http.StatusInternalServerError, "服务暂时不可用") + return + } + + req := &pbAIChat.ChatHistoryRequest{ + SessionId: sessionID, + Limit: 20, + } + + // 获取 user_id 和 star_id 用于 Attachments + userIDVal, exists := ctx.Get("user_id") + if !exists { + response.Error(ctx, http.StatusUnauthorized, "用户未认证") + return + } + uid, _ := userIDVal.(int64) + starIDVal, _ := ctx.Get("star_id") + sid, _ := starIDVal.(int64) + + // 创建带 Attachments 的 context + userIDStr := strconv.FormatInt(uid, 10) + starIDStr := strconv.FormatInt(sid, 10) + attachments := map[string]interface{}{ + "user_id": userIDStr, + "star_id": starIDStr, + } + rpcCtx := context.WithValue(context.Background(), constant.AttachmentKey, attachments) + + resp, err := aiChatService.GetHistory(rpcCtx, req) + if err != nil { + logger.Logger.Error("GetHistory call failed", zap.Error(err)) + response.Error(ctx, http.StatusInternalServerError, "获取历史失败") + return + } + + history := make([]map[string]string, len(resp.History)) + for i, m := range resp.History { + history[i] = map[string]string{ + "role": m.Role, + "content": m.Content, + } + } + + response.Success(ctx, map[string]interface{}{ + "history": history, + }) +} + +// extractUserInfoFromContext 从 gin.Context 中提取用户信息 +func extractUserInfoFromContext(ctx *gin.Context) (int64, int64, error) { + userID, exists := ctx.Get("user_id") + if !exists { + return 0, 0, nil + } + + starID, _ := ctx.Get("star_id") + + uid, _ := userID.(int64) + sid, _ := starID.(int64) + + return uid, sid, nil +} \ No newline at end of file diff --git a/backend/gateway/main.go b/backend/gateway/main.go index f4ff7e1..7a25d27 100644 --- a/backend/gateway/main.go +++ b/backend/gateway/main.go @@ -146,9 +146,18 @@ func main() { } logger.Logger.Info("Starbook Service Dubbo client connected successfully") + // 4.8 AIChatService Client + aiChatClient, err := client.NewClient( + client.WithClientURL(cfg.Dubbo.AIChatServiceURL), + ) + if err != nil { + logger.Logger.Fatal("Failed to create AI Chat Service Dubbo client", zap.Error(err)) + } + logger.Logger.Info("AI Chat Service Dubbo client connected successfully") + // 5. 设置路由 logger.Logger.Info("Setting up routes...") - r, err := router.SetupRouter(userClient, socialClient, assetClient, galleryClient, activityClient, taskClient, starbookClient) + r, err := router.SetupRouter(userClient, socialClient, assetClient, galleryClient, activityClient, taskClient, starbookClient, aiChatClient, cfg.WebSocket.AIChatPath) if err != nil { logger.Logger.Fatal("Failed to setup router", zap.Error(err)) } diff --git a/backend/gateway/router/router.go b/backend/gateway/router/router.go index 584ef4a..8c1b69d 100644 --- a/backend/gateway/router/router.go +++ b/backend/gateway/router/router.go @@ -1,6 +1,8 @@ package router import ( + "net/http" + "dubbo.apache.org/dubbo-go/v3/client" "github.com/gin-gonic/gin" swaggerfiles "github.com/swaggo/files" @@ -8,10 +10,11 @@ import ( _ "github.com/topfans/backend/gateway/docs" // 导入 docs 包以初始化 Swagger "github.com/topfans/backend/gateway/controller" "github.com/topfans/backend/gateway/middleware" + "github.com/topfans/backend/gateway/socket" ) // SetupRouter 设置路由 -func SetupRouter(userClient *client.Client, socialClient *client.Client, assetClient *client.Client, galleryClient *client.Client, activityClient *client.Client, taskClient *client.Client, starbookClient *client.Client) (*gin.Engine, error) { +func SetupRouter(userClient *client.Client, socialClient *client.Client, assetClient *client.Client, galleryClient *client.Client, activityClient *client.Client, taskClient *client.Client, starbookClient *client.Client, aiChatClient *client.Client, aiChatPath string) (*gin.Engine, error) { r := gin.Default() // 全局中间件 @@ -30,6 +33,12 @@ func SetupRouter(userClient *client.Client, socialClient *client.Client, assetCl swaggerHandler := ginSwagger.WrapHandler(swaggerfiles.Handler) r.GET("/swagger/*any", swaggerHandler) + // AI Chat WebSocket 路由 + aiChatHub := socket.NewHub(aiChatClient, aiChatPath) + r.GET(aiChatPath, gin.WrapF(func(w http.ResponseWriter, r *http.Request) { + aiChatHub.HandleWebSocket(w, r) + })) + // 创建控制器 authCtrl, err := controller.NewAuthController(userClient) if err != nil { @@ -76,6 +85,11 @@ func SetupRouter(userClient *client.Client, socialClient *client.Client, assetCl return nil, err } + aiChatCtrl, err := controller.NewAIChatController(aiChatClient) + if err != nil { + return nil, err + } + // API v1 路由组 v1 := r.Group("/api/v1") { @@ -287,6 +301,14 @@ func SetupRouter(userClient *client.Client, socialClient *client.Client, assetCl starbook.GET("/home", starbookCtrl.GetStarbookHome) // 获取星册首页 starbook.GET("/items", starbookCtrl.GetStarbookItems) // 获取星册藏品列表 } + + // AI Chat 相关路由(需要认证) + aiChat := v1.Group("/ai-chat") + aiChat.Use(middleware.AuthMiddleware()) + { + aiChat.GET("/personas", aiChatCtrl.GetPersonas) // 获取人设列表 + aiChat.GET("/history/:sessionId", aiChatCtrl.GetHistory) // 获取对话历史 + } } return r, nil diff --git a/backend/gateway/socket/ai_chat_socket.go b/backend/gateway/socket/ai_chat_socket.go new file mode 100644 index 0000000..bee3ee4 --- /dev/null +++ b/backend/gateway/socket/ai_chat_socket.go @@ -0,0 +1,478 @@ +package socket + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "dubbo.apache.org/dubbo-go/v3/client" + "dubbo.apache.org/dubbo-go/v3/common/constant" + "github.com/gorilla/websocket" + "github.com/topfans/backend/pkg/jwt" + "github.com/topfans/backend/pkg/logger" + "go.uber.org/zap" + + pbAIChat "github.com/topfans/backend/pkg/proto/ai_chat" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true // 允许跨域 + }, +} + +// Hub 管理所有 AI Chat WebSocket 连接 +type Hub struct { + // 用户连接映射: userId -> *Connection + clients map[int64]*Connection + + // Dubbo 客户端 + aiChatClient *client.Client + + // WebSocket 配置 + aiChatPath string + + mu sync.RWMutex +} + +// Connection WebSocket 连接 +type Connection struct { + UserID int64 + StarID int64 + Conn *websocket.Conn + Send chan []byte + Hub *Hub + writeMu sync.Mutex // 保护并发写操作 +} + +// writeJSON 线程安全的 WriteJSON 封装 +func (c *Connection) writeJSON(data interface{}) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + return c.Conn.WriteJSON(data) +} + +// sendError 发送错误消息 +func (c *Connection) sendError(code, message string) { + c.Send <- []byte(fmt.Sprintf(`{"type":"error","code":"%s","message":"%s","session_id":"%d_%d"}`, + code, message, c.UserID, c.StarID)) +} + +// NewHub 创建 Hub 实例 +func NewHub(aiChatClient *client.Client, aiChatPath string) *Hub { + return &Hub{ + clients: make(map[int64]*Connection), + aiChatClient: aiChatClient, + aiChatPath: aiChatPath, + } +} + +// HandleWebSocket 处理 WebSocket 连接 +func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { + // 从 URL 参数获取 token + token := r.URL.Query().Get("token") + if token == "" { + // 尝试从 query 参数获取(URL 编码后) + token = r.URL.Query().Get("token") + } + + // 解析 token 获取用户信息 + userID, starID, err := h.validateToken(token) + if err != nil { + logger.Logger.Error("WebSocket token validation failed", zap.Error(err)) + // 返回 401 Unauthorized,而不是尝试升级 WebSocket + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(map[string]interface{}{ + "type": "auth_response", + "success": false, + "error": "invalid_token", + }) + return + } + + // 升级为 WebSocket 连接 + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + logger.Logger.Error("WebSocket upgrade failed", zap.Error(err)) + return + } + + connection := &Connection{ + UserID: userID, + StarID: starID, + Conn: conn, + Send: make(chan []byte, 256), + Hub: h, + } + + // 注册连接 + h.mu.Lock() + h.clients[userID] = connection + h.mu.Unlock() + + logger.Logger.Info("WebSocket connection established", + zap.Int64("user_id", userID), + zap.Int64("star_id", starID), + ) + + // 发送鉴权成功响应 + authResp := map[string]interface{}{ + "type": "auth_response", + "success": true, + "user_id": userID, + "star_id": starID, + } + conn.WriteJSON(authResp) + + // 启动读 goroutine + go connection.readPump() + + // 启动写 goroutine + go connection.writePump() +} + +// validateToken 验证 token(使用 JWT 解析) +func (h *Hub) validateToken(token string) (int64, int64, error) { + // 去掉 "Bearer_" 前缀 + if strings.HasPrefix(token, "Bearer_") { + token = strings.TrimPrefix(token, "Bearer_") + } + + // 解析 JWT token + claims, err := jwt.ParseToken(token) + if err != nil { + return 0, 0, fmt.Errorf("failed to parse token: %w", err) + } + + userID := claims.UserID + starID := claims.StarID + + if userID == 0 { + return 0, 0, errors.New("invalid user id") + } + + return userID, starID, nil +} + +// readPump 读取客户端消息 +func (c *Connection) readPump() { + defer func() { + c.Hub.mu.Lock() + delete(c.Hub.clients, c.UserID) + c.Hub.mu.Unlock() + c.Conn.Close() + }() + + c.Conn.SetReadLimit(512 * 1024) // 512KB + c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + c.Conn.SetPongHandler(func(string) error { + c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + + for { + _, message, err := c.Conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + logger.Logger.Error("WebSocket read error", zap.Error(err)) + } + break + } + + // 解析消息 + var msg map[string]interface{} + if err := json.Unmarshal(message, &msg); err != nil { + logger.Logger.Error("Failed to parse message", zap.Error(err)) + continue + } + + // 处理消息 + c.handleMessage(msg) + } +} + +// writePump 写消息到客户端 +func (c *Connection) writePump() { + ticker := time.NewTicker(30 * time.Second) + defer func() { + ticker.Stop() + c.Conn.Close() + }() + + for { + select { + case message, ok := <-c.Send: + c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if !ok { + c.Conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + + w, err := c.Conn.NextWriter(websocket.TextMessage) + if err != nil { + return + } + w.Write(message) + + if err := w.Close(); err != nil { + return + } + + case <-ticker.C: + c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } +} + +// handleMessage 处理客户端消息 +func (c *Connection) handleMessage(msg map[string]interface{}) { + action, _ := msg["action"].(string) + sessionID, _ := msg["session_id"].(string) + message, _ := msg["message"].(string) + personaID, _ := msg["persona_id"].(string) + + session := sessionID + if session == "" { + session = fmt.Sprintf("%d_%d", c.UserID, c.StarID) + } + + switch action { + case "ping": + // 心跳 - 通过 Send 通道发送,避免并发写 + c.Send <- []byte(`{"type":"pong"}`) + + case "init_session": + // 初始化会话,返回欢迎消息 + go c.initSession(session) + + case "send_message": + // 发送消息 + go c.sendMessage(session, message, personaID) + + case "get_history": + // 获取历史 + go c.getHistory(session) + + case "get_personas": + // 获取人设列表 + go c.getPersonas() + + default: + logger.Logger.Warn("Unknown action", zap.String("action", action)) + } +} + +// initSession 初始化会话,返回欢迎消息 +func (c *Connection) initSession(sessionID string) { + aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient) + if err != nil { + logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err)) + c.sendError("SERVICE_ERROR", "服务暂时不可用") + return + } + + // 创建带 Attachments 的 context + userIDStr := strconv.FormatInt(c.UserID, 10) + starIDStr := strconv.FormatInt(c.StarID, 10) + attachments := map[string]interface{}{ + "user_id": userIDStr, + "star_id": starIDStr, + } + ctx := context.WithValue(context.Background(), constant.AttachmentKey, attachments) + + req := &pbAIChat.InitSessionRequest{ + SessionId: sessionID, + UserId: c.UserID, + } + + resp, err := aiChatService.InitSession(ctx, req) + if err != nil { + logger.Logger.Error("InitSession call failed", zap.Error(err)) + c.sendError("SERVICE_ERROR", "初始化会话失败") + return + } + + // 发送欢迎消息 + c.writeJSON(map[string]interface{}{ + "type": "message", + "session_id": sessionID, + "content": resp.WelcomeMessage, + "is_end": true, + }) +} + +// sendMessage 发送消息到 AI Chat Service +func (c *Connection) sendMessage(sessionID, message, personaID string) { + // 创建 Dubbo 客户端调用 + aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient) + if err != nil { + logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err)) + c.sendError("SERVICE_ERROR", "服务暂时不可用") + return + } + + // 创建带 Attachments 的 context + userIDStr := strconv.FormatInt(c.UserID, 10) + starIDStr := strconv.FormatInt(c.StarID, 10) + attachments := map[string]interface{}{ + "user_id": userIDStr, + "star_id": starIDStr, + } + ctx := context.WithValue(context.Background(), constant.AttachmentKey, attachments) + + req := &pbAIChat.ChatMessageRequest{ + SessionId: sessionID, + Message: message, + PersonaId: personaID, + UserId: c.UserID, + } + + // 流式调用 + stream, err := aiChatService.SendMessage(ctx, req) + if err != nil { + logger.Logger.Error("SendMessage call failed", zap.Error(err)) + c.sendError("SERVICE_ERROR", "消息发送失败") + return + } + + // 接收流式响应 + for { + if !stream.Recv() { + break + } + resp := stream.Msg() + + // 发送消息到客户端 + respMsg := map[string]interface{}{ + "type": "message", + "content": resp.Content, + "session_id": resp.SessionId, + "is_end": resp.IsEnd, + } + if resp.Error != "" { + respMsg["error"] = resp.Error + } + + if err := c.writeJSON(respMsg); err != nil { + logger.Logger.Error("Failed to send message to client", zap.Error(err)) + break + } + + if resp.IsEnd { + break + } + } +} + +// getHistory 获取对话历史 +func (c *Connection) getHistory(sessionID string) { + aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient) + if err != nil { + logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err)) + c.sendError("SERVICE_ERROR", "服务暂时不可用") + return + } + + // 创建带 Attachments 的 context + userIDStr := strconv.FormatInt(c.UserID, 10) + starIDStr := strconv.FormatInt(c.StarID, 10) + attachments := map[string]interface{}{ + "user_id": userIDStr, + "star_id": starIDStr, + } + ctx := context.WithValue(context.Background(), constant.AttachmentKey, attachments) + + req := &pbAIChat.ChatHistoryRequest{ + SessionId: sessionID, + Limit: 20, + } + + resp, err := aiChatService.GetHistory(ctx, req) + if err != nil { + logger.Logger.Error("GetHistory call failed", zap.Error(err)) + c.sendError("SERVICE_ERROR", "获取历史失败") + return + } + + // 转换历史消息 + history := make([]map[string]string, len(resp.History)) + for i, m := range resp.History { + history[i] = map[string]string{ + "role": m.Role, + "content": m.Content, + } + } + + // 发送响应 + c.writeJSON(map[string]interface{}{ + "type": "history_response", + "session_id": sessionID, + "history": history, + }) +} + +// getPersonas 获取人设列表 +func (c *Connection) getPersonas() { + aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient) + if err != nil { + logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err)) + c.sendError("SERVICE_ERROR", "服务暂时不可用") + return + } + + ctx := context.Background() + req := &pbAIChat.GetPersonasRequest{ + UserId: c.UserID, + } + + resp, err := aiChatService.GetPersonas(ctx, req) + if err != nil { + logger.Logger.Error("GetPersonas call failed", zap.Error(err)) + c.sendError("SERVICE_ERROR", "获取人设失败") + return + } + + // 转换人设信息 + personas := make([]map[string]interface{}, len(resp.Personas)) + for i, p := range resp.Personas { + personas[i] = map[string]interface{}{ + "id": p.Id, + "name": p.Name, + "description": p.Description, + "avatar_url": p.AvatarUrl, + "talk_style": p.TalkStyle, + "is_default": p.IsDefault, + "created_at": p.CreatedAt, + "updated_at": p.UpdatedAt, + } + } + + // 发送响应 + c.writeJSON(map[string]interface{}{ + "type": "personas_response", + "personas": personas, + }) +} + +// sendError 发送错误消息 (使用 Send 通道) + +// Close 关闭连接 +func (h *Hub) Close() { + h.mu.Lock() + defer h.mu.Unlock() + + for _, conn := range h.clients { + conn.Conn.Close() + } +} \ No newline at end of file diff --git a/backend/go.mod b/backend/go.mod index 3018663..27e4326 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -54,6 +54,7 @@ require ( github.com/k0kubun/pp v3.0.1+incompatible // indirect github.com/knadh/koanf v1.5.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect + github.com/lib/pq v1.12.3 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/magiconair/properties v1.8.5 // indirect github.com/mattn/go-colorable v0.1.13 // indirect diff --git a/backend/go.sum b/backend/go.sum index b0fb740..3079fef 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -470,6 +470,8 @@ github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjS github.com/lestrrat/go-envload v0.0.0-20180220120943-6ed08b54a570/go.mod h1:BLt8L9ld7wVsvEWQbuLrUZnCMnUmLZ+CGDzKtclrTlE= github.com/lestrrat/go-file-rotatelogs v0.0.0-20180223000712-d3151e2a480f/go.mod h1:UGmTpUd3rjbtfIpwAPrcfmGf/Z1HS95TATB+m57TPB8= github.com/lestrrat/go-strftime v0.0.0-20180220042222-ba3bf9c1d042/go.mod h1:TPpsiPUEh0zFL1Snz4crhMlBe60PYxRHr5oFF3rRYg0= +github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= +github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= diff --git a/backend/go.work b/backend/go.work index 885d8d1..d58db23 100644 --- a/backend/go.work +++ b/backend/go.work @@ -8,4 +8,5 @@ use ( ./services/galleryService ./services/socialService ./services/userService + ./services/aiChatService ) diff --git a/backend/go.work.sum b/backend/go.work.sum index b11b254..48be06c 100644 --- a/backend/go.work.sum +++ b/backend/go.work.sum @@ -26,39 +26,11 @@ github.com/alecthomas/kingpin/v2 v2.4.0/go.mod h1:0gyi0zQnjuFk8xrkNKamJoyUo382HR github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafoB+tBA3gMyHYHrpOtNuDiK/uB5uXxq5wM= github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 h1:s6gZFSlWYmbqAuRjVTiNNhvNRfY2Wxp9nhfyel4rklc= github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE= -github.com/alibabacloud-go/alibabacloud-gateway-pop v0.0.6/go.mod h1:4EUIoxs/do24zMOGGqYVWgw0s9NtiylnJglOeEB5UJo= -github.com/alibabacloud-go/alibabacloud-gateway-spi v0.0.4/go.mod h1:sCavSAvdzOjul4cEqeVtvlSaSScfNsTQ+46HwlTL1hc= -github.com/alibabacloud-go/alibabacloud-gateway-spi v0.0.5 h1:zE8vH9C7JiZLNJJQ5OwjU9mSi4T9ef9u3BURT6LCLC8= -github.com/alibabacloud-go/alibabacloud-gateway-spi v0.0.5/go.mod h1:tWnyE9AjF8J8qqLk645oUmVUnFybApTQWklQmi5tY6g= -github.com/alibabacloud-go/darabonba-array v0.1.0/go.mod h1:BLKxr0brnggqOJPqT09DFJ8g3fsDshapUD3C3aOEFaI= -github.com/alibabacloud-go/darabonba-encode-util v0.0.2/go.mod h1:JiW9higWHYXm7F4PKuMgEUETNZasrDM6vqVr/Can7H8= -github.com/alibabacloud-go/darabonba-map v0.0.2/go.mod h1:28AJaX8FOE/ym8OUFWga+MtEzBunJwQGceGQlvaPGPc= github.com/alibabacloud-go/darabonba-openapi/v2 v2.0.8/go.mod h1:CzQnh+94WDnJOnKZH5YRyouL+OOcdBnXY5VWAf0McgI= github.com/alibabacloud-go/darabonba-openapi/v2 v2.1.13 h1:Q00FU3H94Ts0ZIHDmY+fYGgB7dV9D/YX6FGsgorQPgw= -github.com/alibabacloud-go/darabonba-openapi/v2 v2.1.13/go.mod h1:lxFGfobinVsQ49ntjpgWghXmIF0/Sm4+wvBJ1h5RtaE= -github.com/alibabacloud-go/darabonba-signature-util v0.0.7/go.mod h1:oUzCYV2fcCH797xKdL6BDH8ADIHlzrtKVjeRtunBNTQ= -github.com/alibabacloud-go/darabonba-string v1.0.2/go.mod h1:93cTfV3vuPhhEwGGpKKqhVW4jLe7tDpo3LUM0i0g6mA= -github.com/alibabacloud-go/dysmsapi-20180501/v2 v2.0.8 h1:aDPyz6C+nenypx24N5qEt09NjpS6mu7Cu1A+wf9UTaY= -github.com/alibabacloud-go/dysmsapi-20180501/v2 v2.0.8/go.mod h1:e/vWJ5gLVnraPROSh+3oMSodf5ukaUlqNgH0IIcnz98= -github.com/alibabacloud-go/endpoint-util v1.1.0/go.mod h1:O5FuCALmCKs2Ff7JFJMudHs0I5EBgecXXxZRyswlEjE= -github.com/alibabacloud-go/openapi-util v0.1.0/go.mod h1:sQuElr4ywwFRlCCberQwKRFhRzIyG4QTP/P4y1CJ6Ws= -github.com/alibabacloud-go/tea v1.1.7/go.mod h1:/tmnEaQMyb4Ky1/5D+SE1BAsa5zj/KeGOFfwYm3N/p4= -github.com/alibabacloud-go/tea v1.1.8/go.mod h1:/tmnEaQMyb4Ky1/5D+SE1BAsa5zj/KeGOFfwYm3N/p4= -github.com/alibabacloud-go/tea v1.1.11/go.mod h1:/tmnEaQMyb4Ky1/5D+SE1BAsa5zj/KeGOFfwYm3N/p4= -github.com/alibabacloud-go/tea v1.1.20/go.mod h1:nXxjm6CIFkBhwW4FQkNrolwbfon8Svy6cujmKFUq98A= github.com/alibabacloud-go/tea v1.2.1/go.mod h1:qbzof29bM/IFhLMtJPrgTGK3eauV5J2wSyEUo4OEmnA= -github.com/alibabacloud-go/tea v1.3.13 h1:WhGy6LIXaMbBM6VBYcsDCz6K/TPsT1Ri2hPmmZffZ94= -github.com/alibabacloud-go/tea v1.3.13/go.mod h1:A560v/JTQ1n5zklt2BEpurJzZTI8TUT+Psg2drWlxRg= -github.com/alibabacloud-go/tea-utils v1.3.1/go.mod h1:EI/o33aBfj3hETm4RLiAxF/ThQdSngxrpF8rKUDJjPE= -github.com/alibabacloud-go/tea-utils/v2 v2.0.5/go.mod h1:dL6vbUT35E4F4bFTHL845eUloqaerYBYPsdWR2/jhe4= -github.com/alibabacloud-go/tea-utils/v2 v2.0.7 h1:WDx5qW3Xa5ZgJ1c8NfqJkF6w+AU5wB8835UdhPr6Ax0= -github.com/alibabacloud-go/tea-utils/v2 v2.0.7/go.mod h1:qxn986l+q33J5VkialKMqT/TTs3E+U9MJpd001iWQ9I= github.com/alibabacloud-go/tea-utils/v2 v2.0.8/go.mod h1:qxn986l+q33J5VkialKMqT/TTs3E+U9MJpd001iWQ9I= github.com/alibabacloud-go/tea-xml v1.1.3/go.mod h1:Rq08vgCcCAjHyRi/M7xlHKUykZCEtyBy9+DPF6GgEu8= -github.com/aliyun/credentials-go v1.1.2/go.mod h1:ozcZaMR5kLM7pwtCMEpVmQ242suV6qTJya2bDq4X1Tw= -github.com/aliyun/credentials-go v1.3.1/go.mod h1:8jKYhQuDawt8x2+fusqa1Y6mPxemTsBEN04dgcAcYz0= -github.com/aliyun/credentials-go v1.3.6/go.mod h1:1LxUuX7L5YrZUWzBrRyk0SwSdH4OmPrib8NVePL3fxM= -github.com/aliyun/credentials-go v1.4.5/go.mod h1:Jm6d+xIgwJVLVWT561vy67ZRP4lPTQxMbEYRuT2Ti1U= github.com/antihax/optional v1.0.0 h1:xK2lYat7ZLaVVcIuj82J8kIro4V6kDe0AUDFboUCwcg= github.com/apache/thrift v0.13.0 h1:5hryIiq9gtn+MiLVn0wP37kb/uTeRZgN08WoCsAhIhI= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e h1:QEF07wC0T1rKkctt1RINW/+RMTVmiwxETico2l3gxJA= @@ -81,8 +53,6 @@ github.com/aws/smithy-go v1.8.0 h1:AEwwwXQZtUwP5Mz506FeXXrKBe0jA8gVM+1gEcSRooc= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQkY= github.com/bketelsen/crypt v0.0.4 h1:w/jqZtC9YD4DS/Vp9GhWfWcCpuAL58oTnLoI8vE9YHU= -github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= -github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/casbin/casbin/v2 v2.1.2 h1:bTwon/ECRx9dwBy2ewRVr5OiqjeXSGiTUY74sDPQi/g= github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= github.com/census-instrumentation/opencensus-proto v0.2.1 h1:glEXhBS5PSLLv4IXzLA5yPRVX4bilULVyxxbrfOtDAk= @@ -94,8 +64,6 @@ github.com/chzyer/logex v1.2.0 h1:+eqR0HfOetur4tgnC8ftU5imRnhi4te+BadWS95c5AM= github.com/chzyer/readline v1.5.0 h1:lSwwFrbNviGePhkewF1az4oLmcwqCZijQ2/Wi3BGHAI= github.com/chzyer/test v0.0.0-20210722231415-061457976a23 h1:dZ0/VyGgQdVGAss6Ju0dt5P0QltE0SFY5Woh6hbIfiQ= github.com/clbanning/mxj/v2 v2.5.5/go.mod h1:hNiWqW14h+kc+MdF9C6/YoRfjEJoR3ou6tn/Qo+ve2s= -github.com/clbanning/mxj/v2 v2.7.0 h1:WA/La7UGCanFe5NpHF0Q3DNtnCsVoxbPKuyBNHWRyME= -github.com/clbanning/mxj/v2 v2.7.0/go.mod h1:hNiWqW14h+kc+MdF9C6/YoRfjEJoR3ou6tn/Qo+ve2s= github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec h1:EdRZT3IeKQmfCSrgo8SZ8V3MEnskuJP0wCYNpe+aiXo= github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI= github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4 h1:hzAQntlaYRkVSFEfj9OTWlVV1H155FMD8BTKktLv0QI= @@ -112,6 +80,7 @@ github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f h1:lBNOc5arjvs8E5mO2tbp github.com/cpuguy83/go-md2man/v2 v2.0.0 h1:EoUDS0afbrsXAZ9YQ9jdu/mZ2sXgT1/2yyNng4PGlyM= github.com/creack/pty v1.1.11 h1:07n33Z8lZxZ2qwegKbObQohDhXDQxiMMz1NOUGYlesw= github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954 h1:RMLoZVzv4GliuWafOuPuQDKSm1SJph7uCRnnS61JAn4= github.com/dubbogo/jsonparser v1.0.1 h1:sAIr8gk+gkahkIm6CnUxh9wTCkbgwLEQ8dTXTnAXyzo= github.com/dubbogo/net v0.0.4 h1:Rn9aMPZwOiRE22YhtxmDEE3H0Q3cfVRNhuEjNMelJ/8= @@ -174,13 +143,11 @@ github.com/gonum/internal v0.0.0-20181124074243-f884aa714029 h1:8jtTdc+Nfj9AR+0s github.com/gonum/lapack v0.0.0-20181123203213-e4cdc5a0bff9 h1:7qnwS9+oeSiOIsiUMajT+0R7HR6hw5NegnKPmn/94oI= github.com/gonum/matrix v0.0.0-20181209220409-c518dec07be9 h1:V2IgdyerlBa/MxaEFRbV5juy/C3MGdj4ePi+g6ePIp4= github.com/gonum/stat v0.0.0-20181125101827-41a0da705a5b h1:fbskpz/cPqWH8VqkQ7LJghFkl2KPAiIFUHrTJ2O3RGk= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw= github.com/google/martian v2.1.0+incompatible h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no= github.com/google/martian/v3 v3.1.0 h1:wCKgOCHuUEVfsaQLpPSJb7VdYCdTVZQAuOdYm1yc/60= github.com/google/renameio v0.1.0 h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA= github.com/googleapis/gax-go/v2 v2.0.5 h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM= -github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw= github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 h1:MJG/KsmcqMwFAkh8mTnAwhyKoB+sTAnY4CACC110tbU= @@ -306,6 +273,7 @@ github.com/rabbitmq/amqp091-go v1.8.1 h1:RejT1SBUim5doqcL6s7iN6SBmsQqyTgXb1xMlH0 github.com/rabbitmq/amqp091-go v1.8.1/go.mod h1:+jPrT9iY2eLjRaMSRHUhc3z14E/l85kv/f+6luSD3pc= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/redis/go-redis/v9 v9.5.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/rhnvrm/simples3 v0.6.1 h1:H0DJwybR6ryQE+Odi9eqkHuzjYAeJgtGcGtuBwOhsH8= github.com/rogpeppe/fastuuid v1.2.0 h1:Ppwyp6VYCF1nvBTXL3trRso7mXMlRrw9ooo375wvi2s= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= @@ -316,18 +284,13 @@ github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da h1:p3Vo3i64TCL github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 h1:nn5Wsu0esKSJiIVhscUtVbo7ada43DJhG55ua/hjS5I= github.com/shirou/gopsutil v3.20.11+incompatible h1:LJr4ZQK4mPpIV5gOa4jCOKOGb4ty4DZO54I4FGqIpto= github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo= -github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= github.com/sony/gobreaker v0.4.1 h1:oMnRNZXX5j85zso6xCPRNPtmAycat+WcoKbklScLDgQ= github.com/spf13/cobra v1.1.1 h1:KfztREH0tPxJJ+geloSLaAkaPkr4ki2Er5quFV1TDo4= github.com/spiffe/go-spiffe/v2 v2.6.0 h1:l+DolpxNWYgruGQVV0xsfeya3CsC7m8iBzDnMpsbLuo= github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs= github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271 h1:WhxRHzgeVGETMlmVfqhRn8RIeeNoPr2Czh33I4Zdccw= github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a h1:AhmOdSHeswKHBjhsLs/7+1voOxT+LLrSk/Nxvk35fug= -github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/tebeka/strftime v0.1.3 h1:5HQXOqWKYRFfNyBMNVc9z5+QzuBtIXy03psIhtdJYto= -github.com/tjfoc/gmsm v1.3.2/go.mod h1:HaUcFuY0auTiaHB9MHFGCPx5IaLhTUd2atbCFBQXn9w= -github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho= -github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE= github.com/toolkits/concurrent v0.0.0-20150624120057-a4371d70e3e3 h1:kF/7m/ZU+0D4Jj5eZ41Zm3IH/J8OElK1Qtd7tVKAwLk= github.com/ugorji/go v1.2.6 h1:tGiWC9HENWE2tqYycIqFTNorMmFRVhNwCpDOpWqnk8E= github.com/urfave/cli v1.22.1 h1:+mkCCcOFKPnCmVYVcURKps1Xe+3zP90gSYGNfRkjoIY= @@ -335,75 +298,34 @@ github.com/urfave/cli/v2 v2.3.0 h1:qph92Y649prgesehzOrQjdWyxFOp/QVM+6imKHad91M= github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI= github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8Ydu2Bstc= github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU= -github.com/yuin/goldmark v1.1.30/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= -github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s= +github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M= go.opentelemetry.io/contrib/detectors/gcp v1.38.0 h1:ZoYbqX7OaA/TAikspPl3ozPI6iY6LiIY9I8cUfm+pJs= go.opentelemetry.io/contrib/detectors/gcp v1.38.0/go.mod h1:SU+iU7nu5ud4oCb3LQOhIZ3nRLj6FNVrKgtflbaf2ts= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= -golang.org/x/crypto v0.0.0-20191219195013-becbf705a915/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= -golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= -golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc= golang.org/x/image v0.0.0-20190802002840-cff245a6509b h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4= golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 h1:VLliZ0d+/avPrXXH+OakdXhpJuEoBZuwh1m2j7U6Iug= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 h1:4+4C/Iv2U4fMZBiMCc98MG1In4gJY5YRhtpDNeDeHWs= -golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= -golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= -golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= -golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.0.0-20200509044756-6aff5f38e54f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/telemetry v0.0.0-20251111182119-bc8e575c7b54 h1:E2/AqCUMZGgd73TQkxUMcMla25GB9i/5HOdLr+uH7Vo= golang.org/x/telemetry v0.0.0-20251111182119-bc8e575c7b54/go.mod h1:hKdjCMrbv9skySur+Nek8Hd0uJ0GuxJIoIX2payrIdQ= golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo= -golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= -golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= -golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= -golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= -golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= -golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= -golang.org/x/tools v0.0.0-20200509030707-2212a7e161a5/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= @@ -417,7 +339,6 @@ gopkg.in/cheggaaa/pb.v1 v1.0.25 h1:Ev7yu1/f6+d+b3pi5vPdRPc6nNtP1umSfcWiEfRqv6I= gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/gcfg.v1 v1.2.3 h1:m8OOJ4ccYHnx2f4gQwpno8nAX5OGOh7RLaaz0pj3Ogs= -gopkg.in/ini.v1 v1.56.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/resty.v1 v1.12.0 h1:CuXP0Pjfw9rOuY6EP+UvtNvt5DSqHpIxILZKT/quCZI= gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= diff --git a/backend/pkg/proto/ai_chat/ai_chat.pb.go b/backend/pkg/proto/ai_chat/ai_chat.pb.go new file mode 100644 index 0000000..3d97051 --- /dev/null +++ b/backend/pkg/proto/ai_chat/ai_chat.pb.go @@ -0,0 +1,726 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v7.34.0 +// source: ai_chat.proto + +package ai_chat + +import ( + _ "google.golang.org/genproto/googleapis/api/annotations" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Message struct { + state protoimpl.MessageState `protogen:"open.v1"` + Role string `protobuf:"bytes,1,opt,name=role,proto3" json:"role,omitempty"` // "user" / "assistant" + Content string `protobuf:"bytes,2,opt,name=content,proto3" json:"content,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Message) Reset() { + *x = Message{} + mi := &file_ai_chat_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Message) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Message) ProtoMessage() {} + +func (x *Message) ProtoReflect() protoreflect.Message { + mi := &file_ai_chat_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Message.ProtoReflect.Descriptor instead. +func (*Message) Descriptor() ([]byte, []int) { + return file_ai_chat_proto_rawDescGZIP(), []int{0} +} + +func (x *Message) GetRole() string { + if x != nil { + return x.Role + } + return "" +} + +func (x *Message) GetContent() string { + if x != nil { + return x.Content + } + return "" +} + +type InitSessionRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + SessionId string `protobuf:"bytes,1,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` + UserId int64 `protobuf:"varint,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InitSessionRequest) Reset() { + *x = InitSessionRequest{} + mi := &file_ai_chat_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InitSessionRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InitSessionRequest) ProtoMessage() {} + +func (x *InitSessionRequest) ProtoReflect() protoreflect.Message { + mi := &file_ai_chat_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InitSessionRequest.ProtoReflect.Descriptor instead. +func (*InitSessionRequest) Descriptor() ([]byte, []int) { + return file_ai_chat_proto_rawDescGZIP(), []int{1} +} + +func (x *InitSessionRequest) GetSessionId() string { + if x != nil { + return x.SessionId + } + return "" +} + +func (x *InitSessionRequest) GetUserId() int64 { + if x != nil { + return x.UserId + } + return 0 +} + +type InitSessionResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + WelcomeMessage string `protobuf:"bytes,1,opt,name=welcome_message,json=welcomeMessage,proto3" json:"welcome_message,omitempty"` + SessionId string `protobuf:"bytes,2,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InitSessionResponse) Reset() { + *x = InitSessionResponse{} + mi := &file_ai_chat_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InitSessionResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InitSessionResponse) ProtoMessage() {} + +func (x *InitSessionResponse) ProtoReflect() protoreflect.Message { + mi := &file_ai_chat_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InitSessionResponse.ProtoReflect.Descriptor instead. +func (*InitSessionResponse) Descriptor() ([]byte, []int) { + return file_ai_chat_proto_rawDescGZIP(), []int{2} +} + +func (x *InitSessionResponse) GetWelcomeMessage() string { + if x != nil { + return x.WelcomeMessage + } + return "" +} + +func (x *InitSessionResponse) GetSessionId() string { + if x != nil { + return x.SessionId + } + return "" +} + +type ChatMessageRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + SessionId string `protobuf:"bytes,1,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + PersonaId string `protobuf:"bytes,3,opt,name=persona_id,json=personaId,proto3" json:"persona_id,omitempty"` // 可选,空则用默认人设 + UserId int64 `protobuf:"varint,4,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ChatMessageRequest) Reset() { + *x = ChatMessageRequest{} + mi := &file_ai_chat_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ChatMessageRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ChatMessageRequest) ProtoMessage() {} + +func (x *ChatMessageRequest) ProtoReflect() protoreflect.Message { + mi := &file_ai_chat_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ChatMessageRequest.ProtoReflect.Descriptor instead. +func (*ChatMessageRequest) Descriptor() ([]byte, []int) { + return file_ai_chat_proto_rawDescGZIP(), []int{3} +} + +func (x *ChatMessageRequest) GetSessionId() string { + if x != nil { + return x.SessionId + } + return "" +} + +func (x *ChatMessageRequest) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +func (x *ChatMessageRequest) GetPersonaId() string { + if x != nil { + return x.PersonaId + } + return "" +} + +func (x *ChatMessageRequest) GetUserId() int64 { + if x != nil { + return x.UserId + } + return 0 +} + +type ChatMessageResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Content string `protobuf:"bytes,1,opt,name=content,proto3" json:"content,omitempty"` + SessionId string `protobuf:"bytes,2,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` + IsEnd bool `protobuf:"varint,3,opt,name=is_end,json=isEnd,proto3" json:"is_end,omitempty"` + Error string `protobuf:"bytes,4,opt,name=error,proto3" json:"error,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ChatMessageResponse) Reset() { + *x = ChatMessageResponse{} + mi := &file_ai_chat_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ChatMessageResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ChatMessageResponse) ProtoMessage() {} + +func (x *ChatMessageResponse) ProtoReflect() protoreflect.Message { + mi := &file_ai_chat_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ChatMessageResponse.ProtoReflect.Descriptor instead. +func (*ChatMessageResponse) Descriptor() ([]byte, []int) { + return file_ai_chat_proto_rawDescGZIP(), []int{4} +} + +func (x *ChatMessageResponse) GetContent() string { + if x != nil { + return x.Content + } + return "" +} + +func (x *ChatMessageResponse) GetSessionId() string { + if x != nil { + return x.SessionId + } + return "" +} + +func (x *ChatMessageResponse) GetIsEnd() bool { + if x != nil { + return x.IsEnd + } + return false +} + +func (x *ChatMessageResponse) GetError() string { + if x != nil { + return x.Error + } + return "" +} + +type ChatHistoryRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + SessionId string `protobuf:"bytes,1,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` + Limit int32 `protobuf:"varint,2,opt,name=limit,proto3" json:"limit,omitempty"` // 默认 20 + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ChatHistoryRequest) Reset() { + *x = ChatHistoryRequest{} + mi := &file_ai_chat_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ChatHistoryRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ChatHistoryRequest) ProtoMessage() {} + +func (x *ChatHistoryRequest) ProtoReflect() protoreflect.Message { + mi := &file_ai_chat_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ChatHistoryRequest.ProtoReflect.Descriptor instead. +func (*ChatHistoryRequest) Descriptor() ([]byte, []int) { + return file_ai_chat_proto_rawDescGZIP(), []int{5} +} + +func (x *ChatHistoryRequest) GetSessionId() string { + if x != nil { + return x.SessionId + } + return "" +} + +func (x *ChatHistoryRequest) GetLimit() int32 { + if x != nil { + return x.Limit + } + return 0 +} + +type ChatHistoryResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + History []*Message `protobuf:"bytes,1,rep,name=history,proto3" json:"history,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ChatHistoryResponse) Reset() { + *x = ChatHistoryResponse{} + mi := &file_ai_chat_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ChatHistoryResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ChatHistoryResponse) ProtoMessage() {} + +func (x *ChatHistoryResponse) ProtoReflect() protoreflect.Message { + mi := &file_ai_chat_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ChatHistoryResponse.ProtoReflect.Descriptor instead. +func (*ChatHistoryResponse) Descriptor() ([]byte, []int) { + return file_ai_chat_proto_rawDescGZIP(), []int{6} +} + +func (x *ChatHistoryResponse) GetHistory() []*Message { + if x != nil { + return x.History + } + return nil +} + +type GetPersonasRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + UserId int64 `protobuf:"varint,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetPersonasRequest) Reset() { + *x = GetPersonasRequest{} + mi := &file_ai_chat_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetPersonasRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetPersonasRequest) ProtoMessage() {} + +func (x *GetPersonasRequest) ProtoReflect() protoreflect.Message { + mi := &file_ai_chat_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetPersonasRequest.ProtoReflect.Descriptor instead. +func (*GetPersonasRequest) Descriptor() ([]byte, []int) { + return file_ai_chat_proto_rawDescGZIP(), []int{7} +} + +func (x *GetPersonasRequest) GetUserId() int64 { + if x != nil { + return x.UserId + } + return 0 +} + +type PersonaListResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Personas []*PersonaInfo `protobuf:"bytes,1,rep,name=personas,proto3" json:"personas,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PersonaListResponse) Reset() { + *x = PersonaListResponse{} + mi := &file_ai_chat_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PersonaListResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PersonaListResponse) ProtoMessage() {} + +func (x *PersonaListResponse) ProtoReflect() protoreflect.Message { + mi := &file_ai_chat_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PersonaListResponse.ProtoReflect.Descriptor instead. +func (*PersonaListResponse) Descriptor() ([]byte, []int) { + return file_ai_chat_proto_rawDescGZIP(), []int{8} +} + +func (x *PersonaListResponse) GetPersonas() []*PersonaInfo { + if x != nil { + return x.Personas + } + return nil +} + +type PersonaInfo struct { + state protoimpl.MessageState `protogen:"open.v1"` + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + Description string `protobuf:"bytes,3,opt,name=description,proto3" json:"description,omitempty"` + AvatarUrl string `protobuf:"bytes,4,opt,name=avatar_url,json=avatarUrl,proto3" json:"avatar_url,omitempty"` + TalkStyle string `protobuf:"bytes,5,opt,name=talk_style,json=talkStyle,proto3" json:"talk_style,omitempty"` + IsDefault bool `protobuf:"varint,6,opt,name=is_default,json=isDefault,proto3" json:"is_default,omitempty"` + CreatedAt int64 `protobuf:"varint,7,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` + UpdatedAt int64 `protobuf:"varint,8,opt,name=updated_at,json=updatedAt,proto3" json:"updated_at,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PersonaInfo) Reset() { + *x = PersonaInfo{} + mi := &file_ai_chat_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PersonaInfo) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PersonaInfo) ProtoMessage() {} + +func (x *PersonaInfo) ProtoReflect() protoreflect.Message { + mi := &file_ai_chat_proto_msgTypes[9] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PersonaInfo.ProtoReflect.Descriptor instead. +func (*PersonaInfo) Descriptor() ([]byte, []int) { + return file_ai_chat_proto_rawDescGZIP(), []int{9} +} + +func (x *PersonaInfo) GetId() string { + if x != nil { + return x.Id + } + return "" +} + +func (x *PersonaInfo) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *PersonaInfo) GetDescription() string { + if x != nil { + return x.Description + } + return "" +} + +func (x *PersonaInfo) GetAvatarUrl() string { + if x != nil { + return x.AvatarUrl + } + return "" +} + +func (x *PersonaInfo) GetTalkStyle() string { + if x != nil { + return x.TalkStyle + } + return "" +} + +func (x *PersonaInfo) GetIsDefault() bool { + if x != nil { + return x.IsDefault + } + return false +} + +func (x *PersonaInfo) GetCreatedAt() int64 { + if x != nil { + return x.CreatedAt + } + return 0 +} + +func (x *PersonaInfo) GetUpdatedAt() int64 { + if x != nil { + return x.UpdatedAt + } + return 0 +} + +var File_ai_chat_proto protoreflect.FileDescriptor + +const file_ai_chat_proto_rawDesc = "" + + "\n" + + "\rai_chat.proto\x12\x0ftopfans.ai_chat\x1a\x1cgoogle/api/annotations.proto\"7\n" + + "\aMessage\x12\x12\n" + + "\x04role\x18\x01 \x01(\tR\x04role\x12\x18\n" + + "\acontent\x18\x02 \x01(\tR\acontent\"L\n" + + "\x12InitSessionRequest\x12\x1d\n" + + "\n" + + "session_id\x18\x01 \x01(\tR\tsessionId\x12\x17\n" + + "\auser_id\x18\x02 \x01(\x03R\x06userId\"]\n" + + "\x13InitSessionResponse\x12'\n" + + "\x0fwelcome_message\x18\x01 \x01(\tR\x0ewelcomeMessage\x12\x1d\n" + + "\n" + + "session_id\x18\x02 \x01(\tR\tsessionId\"\x85\x01\n" + + "\x12ChatMessageRequest\x12\x1d\n" + + "\n" + + "session_id\x18\x01 \x01(\tR\tsessionId\x12\x18\n" + + "\amessage\x18\x02 \x01(\tR\amessage\x12\x1d\n" + + "\n" + + "persona_id\x18\x03 \x01(\tR\tpersonaId\x12\x17\n" + + "\auser_id\x18\x04 \x01(\x03R\x06userId\"{\n" + + "\x13ChatMessageResponse\x12\x18\n" + + "\acontent\x18\x01 \x01(\tR\acontent\x12\x1d\n" + + "\n" + + "session_id\x18\x02 \x01(\tR\tsessionId\x12\x15\n" + + "\x06is_end\x18\x03 \x01(\bR\x05isEnd\x12\x14\n" + + "\x05error\x18\x04 \x01(\tR\x05error\"I\n" + + "\x12ChatHistoryRequest\x12\x1d\n" + + "\n" + + "session_id\x18\x01 \x01(\tR\tsessionId\x12\x14\n" + + "\x05limit\x18\x02 \x01(\x05R\x05limit\"I\n" + + "\x13ChatHistoryResponse\x122\n" + + "\ahistory\x18\x01 \x03(\v2\x18.topfans.ai_chat.MessageR\ahistory\"-\n" + + "\x12GetPersonasRequest\x12\x17\n" + + "\auser_id\x18\x01 \x01(\x03R\x06userId\"O\n" + + "\x13PersonaListResponse\x128\n" + + "\bpersonas\x18\x01 \x03(\v2\x1c.topfans.ai_chat.PersonaInfoR\bpersonas\"\xee\x01\n" + + "\vPersonaInfo\x12\x0e\n" + + "\x02id\x18\x01 \x01(\tR\x02id\x12\x12\n" + + "\x04name\x18\x02 \x01(\tR\x04name\x12 \n" + + "\vdescription\x18\x03 \x01(\tR\vdescription\x12\x1d\n" + + "\n" + + "avatar_url\x18\x04 \x01(\tR\tavatarUrl\x12\x1d\n" + + "\n" + + "talk_style\x18\x05 \x01(\tR\ttalkStyle\x12\x1d\n" + + "\n" + + "is_default\x18\x06 \x01(\bR\tisDefault\x12\x1d\n" + + "\n" + + "created_at\x18\a \x01(\x03R\tcreatedAt\x12\x1d\n" + + "\n" + + "updated_at\x18\b \x01(\x03R\tupdatedAt2\xc9\x03\n" + + "\rAIChatService\x12X\n" + + "\vInitSession\x12#.topfans.ai_chat.InitSessionRequest\x1a$.topfans.ai_chat.InitSessionResponse\x12Z\n" + + "\vSendMessage\x12#.topfans.ai_chat.ChatMessageRequest\x1a$.topfans.ai_chat.ChatMessageResponse0\x01\x12\x85\x01\n" + + "\n" + + "GetHistory\x12#.topfans.ai_chat.ChatHistoryRequest\x1a$.topfans.ai_chat.ChatHistoryResponse\",\x82\xd3\xe4\x93\x02&\x12$/api/v1/ai-chat/history/{session_id}\x12z\n" + + "\vGetPersonas\x12#.topfans.ai_chat.GetPersonasRequest\x1a$.topfans.ai_chat.PersonaListResponse\" \x82\xd3\xe4\x93\x02\x1a\x12\x18/api/v1/ai-chat/personasB6Z4github.com/topfans/backend/pkg/proto/ai_chat;ai_chatb\x06proto3" + +var ( + file_ai_chat_proto_rawDescOnce sync.Once + file_ai_chat_proto_rawDescData []byte +) + +func file_ai_chat_proto_rawDescGZIP() []byte { + file_ai_chat_proto_rawDescOnce.Do(func() { + file_ai_chat_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_ai_chat_proto_rawDesc), len(file_ai_chat_proto_rawDesc))) + }) + return file_ai_chat_proto_rawDescData +} + +var file_ai_chat_proto_msgTypes = make([]protoimpl.MessageInfo, 10) +var file_ai_chat_proto_goTypes = []any{ + (*Message)(nil), // 0: topfans.ai_chat.Message + (*InitSessionRequest)(nil), // 1: topfans.ai_chat.InitSessionRequest + (*InitSessionResponse)(nil), // 2: topfans.ai_chat.InitSessionResponse + (*ChatMessageRequest)(nil), // 3: topfans.ai_chat.ChatMessageRequest + (*ChatMessageResponse)(nil), // 4: topfans.ai_chat.ChatMessageResponse + (*ChatHistoryRequest)(nil), // 5: topfans.ai_chat.ChatHistoryRequest + (*ChatHistoryResponse)(nil), // 6: topfans.ai_chat.ChatHistoryResponse + (*GetPersonasRequest)(nil), // 7: topfans.ai_chat.GetPersonasRequest + (*PersonaListResponse)(nil), // 8: topfans.ai_chat.PersonaListResponse + (*PersonaInfo)(nil), // 9: topfans.ai_chat.PersonaInfo +} +var file_ai_chat_proto_depIdxs = []int32{ + 0, // 0: topfans.ai_chat.ChatHistoryResponse.history:type_name -> topfans.ai_chat.Message + 9, // 1: topfans.ai_chat.PersonaListResponse.personas:type_name -> topfans.ai_chat.PersonaInfo + 1, // 2: topfans.ai_chat.AIChatService.InitSession:input_type -> topfans.ai_chat.InitSessionRequest + 3, // 3: topfans.ai_chat.AIChatService.SendMessage:input_type -> topfans.ai_chat.ChatMessageRequest + 5, // 4: topfans.ai_chat.AIChatService.GetHistory:input_type -> topfans.ai_chat.ChatHistoryRequest + 7, // 5: topfans.ai_chat.AIChatService.GetPersonas:input_type -> topfans.ai_chat.GetPersonasRequest + 2, // 6: topfans.ai_chat.AIChatService.InitSession:output_type -> topfans.ai_chat.InitSessionResponse + 4, // 7: topfans.ai_chat.AIChatService.SendMessage:output_type -> topfans.ai_chat.ChatMessageResponse + 6, // 8: topfans.ai_chat.AIChatService.GetHistory:output_type -> topfans.ai_chat.ChatHistoryResponse + 8, // 9: topfans.ai_chat.AIChatService.GetPersonas:output_type -> topfans.ai_chat.PersonaListResponse + 6, // [6:10] is the sub-list for method output_type + 2, // [2:6] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_ai_chat_proto_init() } +func file_ai_chat_proto_init() { + if File_ai_chat_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_ai_chat_proto_rawDesc), len(file_ai_chat_proto_rawDesc)), + NumEnums: 0, + NumMessages: 10, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_ai_chat_proto_goTypes, + DependencyIndexes: file_ai_chat_proto_depIdxs, + MessageInfos: file_ai_chat_proto_msgTypes, + }.Build() + File_ai_chat_proto = out.File + file_ai_chat_proto_goTypes = nil + file_ai_chat_proto_depIdxs = nil +} diff --git a/backend/pkg/proto/ai_chat/ai_chat.triple.go b/backend/pkg/proto/ai_chat/ai_chat.triple.go new file mode 100644 index 0000000..136aab4 --- /dev/null +++ b/backend/pkg/proto/ai_chat/ai_chat.triple.go @@ -0,0 +1,258 @@ +// Code generated by protoc-gen-triple. DO NOT EDIT. +// +// Source: ai_chat.proto +package ai_chat + +import ( + "context" + "net/http" +) + +import ( + "dubbo.apache.org/dubbo-go/v3" + "dubbo.apache.org/dubbo-go/v3/client" + "dubbo.apache.org/dubbo-go/v3/common" + "dubbo.apache.org/dubbo-go/v3/common/constant" + "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol" + "dubbo.apache.org/dubbo-go/v3/server" +) + +// This is a compile-time assertion to ensure that this generated file and the Triple package +// are compatible. If you get a compiler error that this constant is not defined, this code was +// generated with a version of Triple newer than the one compiled into your binary. You can fix the +// problem by either regenerating this code with an older version of Triple or updating the Triple +// version compiled into your binary. +const _ = triple_protocol.IsAtLeastVersion0_1_0 + +const ( + // AIChatServiceName is the fully-qualified name of the AIChatService service. + AIChatServiceName = "topfans.ai_chat.AIChatService" +) + +// These constants are the fully-qualified names of the RPCs defined in this package. They're +// exposed at runtime as procedure and as the final two segments of the HTTP route. +// +// Note that these are different from the fully-qualified method names used by +// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to +// reflection-formatted method names, remove the leading slash and convert the remaining slash to a +// period. +const ( + // AIChatServiceInitSessionProcedure is the fully-qualified name of the AIChatService's InitSession RPC. + AIChatServiceInitSessionProcedure = "/topfans.ai_chat.AIChatService/InitSession" + // AIChatServiceSendMessageProcedure is the fully-qualified name of the AIChatService's SendMessage RPC. + AIChatServiceSendMessageProcedure = "/topfans.ai_chat.AIChatService/SendMessage" + // AIChatServiceGetHistoryProcedure is the fully-qualified name of the AIChatService's GetHistory RPC. + AIChatServiceGetHistoryProcedure = "/topfans.ai_chat.AIChatService/GetHistory" + // AIChatServiceGetPersonasProcedure is the fully-qualified name of the AIChatService's GetPersonas RPC. + AIChatServiceGetPersonasProcedure = "/topfans.ai_chat.AIChatService/GetPersonas" +) + +var ( + _ AIChatService = (*AIChatServiceImpl)(nil) + + _ AIChatService_SendMessageClient = (*AIChatServiceSendMessageClient)(nil) + + _ AIChatService_SendMessageServer = (*AIChatServiceSendMessageServer)(nil) +) + +// AIChatService is a client for the topfans.ai_chat.AIChatService service. +type AIChatService interface { + InitSession(ctx context.Context, req *InitSessionRequest, opts ...client.CallOption) (*InitSessionResponse, error) + SendMessage(ctx context.Context, req *ChatMessageRequest, opts ...client.CallOption) (AIChatService_SendMessageClient, error) + GetHistory(ctx context.Context, req *ChatHistoryRequest, opts ...client.CallOption) (*ChatHistoryResponse, error) + GetPersonas(ctx context.Context, req *GetPersonasRequest, opts ...client.CallOption) (*PersonaListResponse, error) +} + +// NewAIChatService constructs a client for the ai_chat.AIChatService service. +func NewAIChatService(cli *client.Client, opts ...client.ReferenceOption) (AIChatService, error) { + conn, err := cli.DialWithInfo("topfans.ai_chat.AIChatService", &AIChatService_ClientInfo, opts...) + if err != nil { + return nil, err + } + return &AIChatServiceImpl{ + conn: conn, + }, nil +} + +func SetConsumerAIChatService(srv common.RPCService) { + dubbo.SetConsumerServiceWithInfo(srv, &AIChatService_ClientInfo) +} + +// AIChatServiceImpl implements AIChatService. +type AIChatServiceImpl struct { + conn *client.Connection +} + +func (c *AIChatServiceImpl) InitSession(ctx context.Context, req *InitSessionRequest, opts ...client.CallOption) (*InitSessionResponse, error) { + resp := new(InitSessionResponse) + if err := c.conn.CallUnary(ctx, []interface{}{req}, resp, "InitSession", opts...); err != nil { + return nil, err + } + return resp, nil +} + +func (c *AIChatServiceImpl) SendMessage(ctx context.Context, req *ChatMessageRequest, opts ...client.CallOption) (AIChatService_SendMessageClient, error) { + stream, err := c.conn.CallServerStream(ctx, req, "SendMessage", opts...) + if err != nil { + return nil, err + } + rawStream := stream.(*triple_protocol.ServerStreamForClient) + return &AIChatServiceSendMessageClient{rawStream}, nil +} + +func (c *AIChatServiceImpl) GetHistory(ctx context.Context, req *ChatHistoryRequest, opts ...client.CallOption) (*ChatHistoryResponse, error) { + resp := new(ChatHistoryResponse) + if err := c.conn.CallUnary(ctx, []interface{}{req}, resp, "GetHistory", opts...); err != nil { + return nil, err + } + return resp, nil +} + +func (c *AIChatServiceImpl) GetPersonas(ctx context.Context, req *GetPersonasRequest, opts ...client.CallOption) (*PersonaListResponse, error) { + resp := new(PersonaListResponse) + if err := c.conn.CallUnary(ctx, []interface{}{req}, resp, "GetPersonas", opts...); err != nil { + return nil, err + } + return resp, nil +} + +type AIChatService_SendMessageClient interface { + Recv() bool + ResponseHeader() http.Header + ResponseTrailer() http.Header + Msg() *ChatMessageResponse + Err() error + Conn() (triple_protocol.StreamingClientConn, error) + Close() error +} + +type AIChatServiceSendMessageClient struct { + *triple_protocol.ServerStreamForClient +} + +func (cli *AIChatServiceSendMessageClient) Recv() bool { + msg := new(ChatMessageResponse) + return cli.ServerStreamForClient.Receive(msg) +} + +func (cli *AIChatServiceSendMessageClient) Msg() *ChatMessageResponse { + msg := cli.ServerStreamForClient.Msg() + if msg == nil { + return new(ChatMessageResponse) + } + return msg.(*ChatMessageResponse) +} + +func (cli *AIChatServiceSendMessageClient) Conn() (triple_protocol.StreamingClientConn, error) { + return cli.ServerStreamForClient.Conn() +} + +var AIChatService_ClientInfo = client.ClientInfo{ + InterfaceName: "topfans.ai_chat.AIChatService", + MethodNames: []string{"InitSession", "SendMessage", "GetHistory", "GetPersonas"}, + ConnectionInjectFunc: func(dubboCliRaw interface{}, conn *client.Connection) { + dubboCli := dubboCliRaw.(*AIChatServiceImpl) + dubboCli.conn = conn + }, +} + +// AIChatServiceHandler is an implementation of the topfans.ai_chat.AIChatService service. +type AIChatServiceHandler interface { + InitSession(context.Context, *InitSessionRequest) (*InitSessionResponse, error) + SendMessage(context.Context, *ChatMessageRequest, AIChatService_SendMessageServer) error + GetHistory(context.Context, *ChatHistoryRequest) (*ChatHistoryResponse, error) + GetPersonas(context.Context, *GetPersonasRequest) (*PersonaListResponse, error) +} + +func RegisterAIChatServiceHandler(srv *server.Server, hdlr AIChatServiceHandler, opts ...server.ServiceOption) error { + return srv.Register(hdlr, &AIChatService_ServiceInfo, opts...) +} + +func SetProviderAIChatService(srv common.RPCService) { + dubbo.SetProviderServiceWithInfo(srv, &AIChatService_ServiceInfo) +} + +type AIChatService_SendMessageServer interface { + Send(*ChatMessageResponse) error + ResponseHeader() http.Header + ResponseTrailer() http.Header + Conn() triple_protocol.StreamingHandlerConn +} + +type AIChatServiceSendMessageServer struct { + *triple_protocol.ServerStream +} + +func (g *AIChatServiceSendMessageServer) Send(msg *ChatMessageResponse) error { + return g.ServerStream.Send(msg) +} + +var AIChatService_ServiceInfo = server.ServiceInfo{ + InterfaceName: "topfans.ai_chat.AIChatService", + ServiceType: (*AIChatServiceHandler)(nil), + Methods: []server.MethodInfo{ + { + Name: "InitSession", + Type: constant.CallUnary, + ReqInitFunc: func() interface{} { + return new(InitSessionRequest) + }, + MethodFunc: func(ctx context.Context, args []interface{}, handler interface{}) (interface{}, error) { + req := args[0].(*InitSessionRequest) + res, err := handler.(AIChatServiceHandler).InitSession(ctx, req) + if err != nil { + return nil, err + } + return triple_protocol.NewResponse(res), nil + }, + }, + { + Name: "SendMessage", + Type: constant.CallServerStream, + ReqInitFunc: func() interface{} { + return new(ChatMessageRequest) + }, + StreamInitFunc: func(baseStream interface{}) interface{} { + return &AIChatServiceSendMessageServer{baseStream.(*triple_protocol.ServerStream)} + }, + MethodFunc: func(ctx context.Context, args []interface{}, handler interface{}) (interface{}, error) { + req := args[0].(*ChatMessageRequest) + stream := args[1].(AIChatService_SendMessageServer) + if err := handler.(AIChatServiceHandler).SendMessage(ctx, req, stream); err != nil { + return nil, err + } + return nil, nil + }, + }, + { + Name: "GetHistory", + Type: constant.CallUnary, + ReqInitFunc: func() interface{} { + return new(ChatHistoryRequest) + }, + MethodFunc: func(ctx context.Context, args []interface{}, handler interface{}) (interface{}, error) { + req := args[0].(*ChatHistoryRequest) + res, err := handler.(AIChatServiceHandler).GetHistory(ctx, req) + if err != nil { + return nil, err + } + return triple_protocol.NewResponse(res), nil + }, + }, + { + Name: "GetPersonas", + Type: constant.CallUnary, + ReqInitFunc: func() interface{} { + return new(GetPersonasRequest) + }, + MethodFunc: func(ctx context.Context, args []interface{}, handler interface{}) (interface{}, error) { + req := args[0].(*GetPersonasRequest) + res, err := handler.(AIChatServiceHandler).GetPersonas(ctx, req) + if err != nil { + return nil, err + } + return triple_protocol.NewResponse(res), nil + }, + }, + }, +} diff --git a/backend/proto/ai_chat.proto b/backend/proto/ai_chat.proto new file mode 100644 index 0000000..15b9259 --- /dev/null +++ b/backend/proto/ai_chat.proto @@ -0,0 +1,94 @@ +syntax = "proto3"; + +package topfans.ai_chat; + +option go_package = "github.com/topfans/backend/pkg/proto/ai_chat;ai_chat"; + +import "google/api/annotations.proto"; + +// ==================== 对话消息 ==================== + +message Message { + string role = 1; // "user" / "assistant" + string content = 2; +} + +// ==================== 对话相关 ==================== + +message InitSessionRequest { + string session_id = 1; + int64 user_id = 2; +} + +message InitSessionResponse { + string welcome_message = 1; + string session_id = 2; +} + +message ChatMessageRequest { + string session_id = 1; + string message = 2; + string persona_id = 3; // 可选,空则用默认人设 + int64 user_id = 4; +} + +message ChatMessageResponse { + string content = 1; + string session_id = 2; + bool is_end = 3; + string error = 4; +} + +message ChatHistoryRequest { + string session_id = 1; + int32 limit = 2; // 默认 20 +} + +message ChatHistoryResponse { + repeated Message history = 1; +} + +// ==================== 人设管理 ==================== + +message GetPersonasRequest { + int64 user_id = 1; +} + +message PersonaListResponse { + repeated PersonaInfo personas = 1; +} + +message PersonaInfo { + string id = 1; + string name = 2; + string description = 3; + string avatar_url = 4; + string talk_style = 5; + bool is_default = 6; + int64 created_at = 7; + int64 updated_at = 8; +} + +// ==================== AI Chat Service ==================== + +service AIChatService { + // 初始化会话,获取欢迎消息 + rpc InitSession(InitSessionRequest) returns (InitSessionResponse); + + // 发送消息,流式返回 + rpc SendMessage(ChatMessageRequest) returns (stream ChatMessageResponse); + + // 获取对话历史 + rpc GetHistory(ChatHistoryRequest) returns (ChatHistoryResponse) { + option (google.api.http) = { + get: "/api/v1/ai-chat/history/{session_id}" + }; + } + + // 获取用户的所有人设 + rpc GetPersonas(GetPersonasRequest) returns (PersonaListResponse) { + option (google.api.http) = { + get: "/api/v1/ai-chat/personas" + }; + } +} \ No newline at end of file diff --git a/backend/scripts/compile-proto.sh b/backend/scripts/compile-proto.sh index eb50999..fa3a4fe 100755 --- a/backend/scripts/compile-proto.sh +++ b/backend/scripts/compile-proto.sh @@ -67,7 +67,7 @@ move_triple_files() { # 预先创建目标目录 echo "📁 创建目标目录..." -for name in common user social asset gallery ranking activity task starbook; do +for name in common user social asset gallery ranking activity task starbook ai_chat; do mkdir -p "pkg/proto/$name" done echo "" @@ -194,6 +194,20 @@ move_triple_files "topfans/backend/pkg/proto/starbook" "pkg/proto/starbook" echo "✅ starbook.proto 编译完成" echo "" +# 编译 ai_chat.proto +echo "📦 编译 ai_chat.proto ..." +protoc --proto_path=proto \ + --proto_path=. \ + --go_out=pkg/proto/ai_chat \ + --go_opt=paths=source_relative \ + --go-triple_out=pkg/proto/ai_chat \ + --go-triple_opt=paths=source_relative \ + ai_chat.proto +move_triple_files "topfans/backend/pkg/proto/ai_chat" "pkg/proto/ai_chat" + +echo "✅ ai_chat.proto 编译完成" +echo "" + # 清理可能存在的冗余目录和文件 echo "🔄 清理冗余文件..." @@ -204,7 +218,7 @@ if [ -d "github.com" ]; then fi # 删除 proto 目录下的生成文件(如果存在) -for name in common user social asset gallery ranking activity task starbook; do +for name in common user social asset gallery ranking activity task starbook ai_chat; do if [ -f "proto/$name.pb.go" ]; then rm "proto/$name.pb.go" echo " ✅ proto/$name.pb.go 已清理" diff --git a/backend/scripts/migrations/migrate_ai_chat_tables.sql b/backend/scripts/migrations/migrate_ai_chat_tables.sql new file mode 100644 index 0000000..e17d1d5 --- /dev/null +++ b/backend/scripts/migrations/migrate_ai_chat_tables.sql @@ -0,0 +1,150 @@ +-- AI Chat Service 数据库迁移 +-- 创建人设表、用户记忆表、配置表 + +-- ============================================ +-- 1. ai_personas 表 - 人设表 +-- 存储用户创建的 AI 聊天人设配置 +-- ============================================ +CREATE TABLE IF NOT EXISTS ai_personas ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), -- 人设唯一标识符 + user_id BIGINT NOT NULL, -- 所属用户 ID + name VARCHAR(64) NOT NULL, -- 人设名称 + description TEXT, -- 人设描述/简介 + avatar_url VARCHAR(512), -- 头像 URL + talk_style VARCHAR(256), -- 说话风格描述 + system_prompt TEXT NOT NULL, -- 系统提示词(人设核心定义) + is_default BOOLEAN DEFAULT FALSE, -- 是否为默认人设 + created_at BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()) * 1000, -- 创建时间(毫秒时间戳) + updated_at BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()) * 1000 -- 更新时间(毫秒时间戳) +); + +-- 索引 +CREATE INDEX IF NOT EXISTS idx_ai_personas_user_id ON ai_personas(user_id); +CREATE UNIQUE INDEX IF NOT EXISTS idx_ai_personas_user_default ON ai_personas(user_id) WHERE is_default = TRUE; + +-- ============================================ +-- 2. ai_user_memories 表 - 用户长期记忆表 +-- 存储用户的长期记忆信息,用于 AI 对话时召回 +-- ============================================ +CREATE TABLE IF NOT EXISTS ai_user_memories ( + id SERIAL PRIMARY KEY, -- 记忆唯一标识符 + user_id BIGINT NOT NULL, -- 所属用户 ID + content TEXT NOT NULL, -- 记忆内容 + keywords TEXT[], -- 关键词数组(用于检索) + weight INTEGER DEFAULT 50, -- 权重(0-100,影响召回优先级) + is_core BOOLEAN DEFAULT FALSE, -- 是否为核心记忆(核心记忆优先召回) + created_at BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()) * 1000, -- 创建时间 + updated_at BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()) * 1000 -- 更新时间 +); + +-- 索引 +CREATE INDEX IF NOT EXISTS idx_ai_user_memories_user_id ON ai_user_memories(user_id); +CREATE INDEX IF NOT EXISTS idx_ai_user_memories_keywords ON ai_user_memories USING GIN(keywords); +CREATE INDEX IF NOT EXISTS idx_ai_user_memories_weight ON ai_user_memories(weight DESC); + +-- ============================================ +-- 3. ai_chat_configs 表 - AI Chat 配置表 +-- 存储 AI Chat 服务的各类配置项 +-- ============================================ +CREATE TABLE IF NOT EXISTS ai_chat_configs ( + id SERIAL PRIMARY KEY, -- 配置项唯一 ID + config_key VARCHAR(128) NOT NULL UNIQUE, -- 配置键 + config_value TEXT NOT NULL, -- 配置值 + config_type VARCHAR(32) NOT NULL DEFAULT 'string', -- 配置类型(string/number/boolean/json) + category VARCHAR(64) NOT NULL, -- 配置分类(redis/llm/dialog/token/circuit/summary) + description VARCHAR(256), -- 配置描述 + is_encrypted BOOLEAN DEFAULT FALSE, -- 是否加密存储(加密字段不返回明文) + updated_at BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()) * 1000, -- 更新时间 + created_at BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()) * 1000 -- 创建时间 +); + +-- 索引 +CREATE INDEX IF NOT EXISTS idx_ai_chat_configs_category ON ai_chat_configs(category); +CREATE INDEX IF NOT EXISTS idx_ai_chat_configs_key ON ai_chat_configs(config_key); + +-- 初始配置数据 +INSERT INTO ai_chat_configs (config_key, config_value, config_type, category, description, is_encrypted) VALUES +-- Redis 配置 +('redis.host', '127.0.0.1', 'string', 'redis', 'Redis 主机地址', FALSE), +('redis.port', '6379', 'number', 'redis', 'Redis 端口', FALSE), +('redis.password', '123456', 'string', 'redis', 'Redis 密码', TRUE), +('redis.db', '0', 'number', 'redis', 'Redis 数据库编号', FALSE), + +-- MiniMax 大模型配置 +('minimax.api_key', '', 'string', 'llm', 'MiniMax API Key', TRUE), +('minimax.api_url', 'https://api.minimaxi.com/v1', 'string', 'llm', 'MiniMax API 地址', FALSE), +('minimax.model', 'M2-her', 'string', 'llm', 'MiniMax 模型名称', FALSE), + +-- 通义备用模型配置 +('qwen.api_key', '', 'string', 'llm', '通义 API Key', TRUE), +('qwen.api_url', 'https://dashscope.aliyuncs.com/compatible-mode/v1', 'string', 'llm', '通义 API 地址', FALSE), +('qwen.model', 'qwen-plus', 'string', 'llm', '通义模型名称', FALSE), + +-- 对话配置 +('dialog.max_context_turns', '10', 'number', 'dialog', '最大上下文轮数', FALSE), +('dialog.context_expire_seconds', '86400', 'number', 'dialog', '上下文过期时间(秒)', FALSE), +('dialog.memory_recall_topn', '5', 'number', 'dialog', '记忆召回返回条数', FALSE), +('dialog.fallback_threshold', '3', 'number', 'dialog', '模型降级连续失败次数阈值', FALSE), +('dialog.slow_response_ms', '5000', 'number', 'dialog', '慢响应判定阈值(毫秒)', FALSE), + +-- AI Token 限制配置 +('token.max_total', '32000', 'number', 'token', '总 Token 上限', FALSE), +('token.max_history', '24000', 'number', 'token', '对话历史最大 Token', FALSE), +('token.max_system', '4000', 'number', 'token', 'System Prompt 最大 Token', FALSE), +('token.max_memory', '2000', 'number', 'token', '记忆召回最大 Token', FALSE), +('token.reserved', '2000', 'number', 'token', '保留空间 Token', FALSE), + +-- 熔断配置 +('circuit.max_fail_count', '5', 'number', 'circuit', '熔断连续失败次数', FALSE), +('circuit.breaker_timeout', '60', 'number', 'circuit', '熔断恢复超时(秒)', FALSE), + +-- 摘要配置 +('summary.trigger_turns', '10', 'number', 'summary', '自动摘要触发轮数', FALSE), +('summary.max_length', '100', 'number', 'summary', '摘要最大字数', FALSE) +ON CONFLICT (config_key) DO NOTHING; + +-- ============================================ +-- 表结构注释 +-- ============================================ +COMMENT ON TABLE ai_personas IS 'AI 人设表 - 存储用户创建的聊天人设配置'; +COMMENT ON TABLE ai_user_memories IS '用户长期记忆表 - 存储用户记忆信息,用于对话时召回'; +COMMENT ON TABLE ai_chat_configs IS 'AI Chat 配置表 - 存储服务各类配置项'; + +-- ============================================ +-- ai_personas 列注释 +-- ============================================ +COMMENT ON COLUMN ai_personas.id IS '人设唯一标识符'; +COMMENT ON COLUMN ai_personas.user_id IS '所属用户ID'; +COMMENT ON COLUMN ai_personas.name IS '人设名称'; +COMMENT ON COLUMN ai_personas.description IS '人设描述/简介'; +COMMENT ON COLUMN ai_personas.avatar_url IS '人设头像URL'; +COMMENT ON COLUMN ai_personas.talk_style IS 'AI说话风格描述'; +COMMENT ON COLUMN ai_personas.system_prompt IS '系统提示词,定义人设核心性格和行为'; +COMMENT ON COLUMN ai_personas.is_default IS '是否为用户的默认人设'; +COMMENT ON COLUMN ai_personas.created_at IS '记录创建时间(毫秒时间戳)'; +COMMENT ON COLUMN ai_personas.updated_at IS '记录更新时间(毫秒时间戳)'; + +-- ============================================ +-- ai_user_memories 列注释 +-- ============================================ +COMMENT ON COLUMN ai_user_memories.id IS '记忆记录唯一ID'; +COMMENT ON COLUMN ai_user_memories.user_id IS '所属用户ID'; +COMMENT ON COLUMN ai_user_memories.content IS '记忆内容文本'; +COMMENT ON COLUMN ai_user_memories.keywords IS '关键词数组,用于记忆检索'; +COMMENT ON COLUMN ai_user_memories.weight IS '记忆权重(0-100),影响召回优先级'; +COMMENT ON COLUMN ai_user_memories.is_core IS '是否为核心记忆,核心记忆优先召回'; +COMMENT ON COLUMN ai_user_memories.created_at IS '记忆创建时间'; +COMMENT ON COLUMN ai_user_memories.updated_at IS '记忆更新时间'; + +-- ============================================ +-- ai_chat_configs 列注释 +-- ============================================ +COMMENT ON COLUMN ai_chat_configs.id IS '配置项唯一ID'; +COMMENT ON COLUMN ai_chat_configs.config_key IS '配置键名称'; +COMMENT ON COLUMN ai_chat_configs.config_value IS '配置值'; +COMMENT ON COLUMN ai_chat_configs.config_type IS '配置类型:string/number/boolean/json'; +COMMENT ON COLUMN ai_chat_configs.category IS '配置分类:redis/llm/dialog/token/circuit/summary'; +COMMENT ON COLUMN ai_chat_configs.description IS '配置项的中文描述'; +COMMENT ON COLUMN ai_chat_configs.is_encrypted IS '是否加密存储,加密字段不会返回明文'; +COMMENT ON COLUMN ai_chat_configs.updated_at IS '配置更新时间'; +COMMENT ON COLUMN ai_chat_configs.created_at IS '配置创建时间'; \ No newline at end of file diff --git a/backend/services/aiChatService/configs/dubbo.yaml b/backend/services/aiChatService/configs/dubbo.yaml new file mode 100644 index 0000000..e92c3c4 --- /dev/null +++ b/backend/services/aiChatService/configs/dubbo.yaml @@ -0,0 +1,35 @@ +# Dubbo AI Chat Service 配置文件 +dubbo: + # 应用配置 + application: + name: ai-chat-service + version: 1.0.0 + + # 注册中心配置 + registries: + nacos: + protocol: nacos + address: 127.0.0.1:8848 + timeout: 5s + + # 协议配置 + protocols: + triple: + name: tri + port: 20008 # AI Chat Service 端口 + + # Provider 配置 + provider: + registry-ids: nacos + protocol-ids: triple + services: + # AI Chat Service 服务定义 + AIChatService: + interface: "github.com/topfans/backend/pkg/proto/ai_chat.AIChatService" + + # Consumer 配置 + consumer: + registry-ids: nacos + + # 超时配置 + timeout: 60s \ No newline at end of file diff --git a/backend/services/aiChatService/go.mod b/backend/services/aiChatService/go.mod new file mode 100644 index 0000000..fd07a7b --- /dev/null +++ b/backend/services/aiChatService/go.mod @@ -0,0 +1,23 @@ +module github.com/topfans/backend/services/aiChatService + +go 1.25.5 + +require ( + dubbo.apache.org/dubbo-go/v3 v3.3.1 + github.com/google/uuid v1.6.0 + github.com/redis/go-redis/v9 v9.5.1 + github.com/stretchr/testify v1.11.1 + github.com/topfans/backend v0.0.0 + go.uber.org/zap v1.27.1 + gorm.io/gorm v1.31.1 +) + +require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/stretchr/objx v0.4.0 // indirect + github.com/stretchr/testify v1.7.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +replace github.com/topfans/backend => ../.. \ No newline at end of file diff --git a/backend/services/aiChatService/main.go b/backend/services/aiChatService/main.go new file mode 100644 index 0000000..476283c --- /dev/null +++ b/backend/services/aiChatService/main.go @@ -0,0 +1,265 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + "os/signal" + "strconv" + "syscall" + "time" + + _ "dubbo.apache.org/dubbo-go/v3/imports" + _ "dubbo.apache.org/dubbo-go/v3/protocol/triple" + "dubbo.apache.org/dubbo-go/v3/protocol" + "dubbo.apache.org/dubbo-go/v3/server" + "github.com/joho/godotenv" + "github.com/redis/go-redis/v9" + + "github.com/topfans/backend/pkg/database" + "github.com/topfans/backend/pkg/health" + "github.com/topfans/backend/pkg/logger" + "github.com/topfans/backend/services/aiChatService/model" + "github.com/topfans/backend/services/aiChatService/provider" + "github.com/topfans/backend/services/aiChatService/repository" + "github.com/topfans/backend/services/aiChatService/service" + pbAIChat "github.com/topfans/backend/pkg/proto/ai_chat" + "go.uber.org/zap" +) + +var ( + port = flag.Int("port", getEnvInt("PORT", 20008), "Dubbo service port") + dbHost = flag.String("db-host", getEnv("DB_HOST", "localhost"), "Database host") + dbPort = flag.Int("db-port", getEnvInt("DB_PORT", 5432), "Database port") + dbUser = flag.String("db-user", getEnv("DB_USER", "postgres"), "Database user") + dbPassword = flag.String("db-password", getEnv("DB_PASSWORD", ""), "Database password") + dbName = flag.String("db-name", getEnv("DB_NAME", "top-fans"), "Database name") + redisHost = flag.String("redis-host", getEnv("REDIS_HOST", "127.0.0.1"), "Redis host") + redisPort = flag.Int("redis-port", getEnvInt("REDIS_PORT", 6379), "Redis port") + redisPassword = flag.String("redis-password", getEnv("REDIS_PASSWORD", ""), "Redis password") + redisDB = flag.Int("redis-db", getEnvInt("REDIS_DB", 0), "Redis db") + contextTTL = flag.Int("context-ttl", getEnvInt("CONTEXT_TTL", 86400), "Context TTL in seconds") + triggerTurns = flag.Int("trigger-turns", getEnvInt("TRIGGER_TURNS", 5), "Turns to trigger memory extraction") + healthHandler *health.Handler +) + +func getEnv(key, fallback string) string { + if v := os.Getenv(key); v != "" { + return v + } + return fallback +} + +func getEnvInt(key string, fallback int) int { + if v := os.Getenv(key); v != "" { + if n, err := strconv.Atoi(v); err == nil { + return n + } + } + return fallback +} + +func main() { + godotenv.Load() + flag.Parse() + + env := os.Getenv("ENV") + if env == "" { + env = "development" + } + + if err := logger.Init(logger.Config{ + ServiceName: "ai-chat-service", + Environment: env, + LogLevel: os.Getenv("LOG_LEVEL"), + }); err != nil { + panic(fmt.Sprintf("Failed to initialize logger: %v", err)) + } + defer logger.Sync() + + logger.Logger.Info("Starting AI Chat Service...") + + // 初始化数据库 + dbConfig := database.Config{ + Host: *dbHost, + Port: *dbPort, + User: *dbUser, + Password: *dbPassword, + DBName: *dbName, + SSLMode: "disable", + TimeZone: "Asia/Shanghai", + } + + if err := database.Init(dbConfig); err != nil { + logger.Logger.Fatal(fmt.Sprintf("Failed to initialize database: %v", err)) + } + logger.Logger.Info("Database initialized successfully") + + // 初始化 Redis + redisClient := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%s:%d", *redisHost, *redisPort), + Password: *redisPassword, + DB: *redisDB, + }) + + // 测试 Redis 连接 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + if err := redisClient.Ping(ctx).Err(); err != nil { + logger.Logger.Fatal(fmt.Sprintf("Failed to connect to Redis: %v", err)) + } + cancel() + logger.Logger.Info("Redis connected successfully") + + // 启动健康检查 HTTP 服务器 + healthPort := *port + 1000 + healthHandler = health.NewHandler("ai-chat-service", healthPort) + healthHandler.Start() + + // 自动迁移数据库表 + if err := autoMigrate(); err != nil { + logger.Logger.Fatal(fmt.Sprintf("Failed to migrate database: %v", err)) + } + + // 创建 Repository 层实例 + personaRepo := repository.NewPostgreSQLPersonaRepository(database.GetDB()) + configRepo := repository.NewPostgreSQLConfigRepository(database.GetDB()) + shortTermMemoryRepo := repository.NewRedisMemoryRepository(redisClient, *contextTTL) + longTermMemoryRepo := repository.NewPostgreSQLMemoryRepository(database.GetDB()) + logger.Logger.Info("Repository layer initialized") + + // 从数据库加载 LLM 配置 + loadCtx := context.Background() + llmConfigs, err := configRepo.GetByCategory(loadCtx, "llm") + if err != nil { + logger.Logger.Warn("Failed to load LLM configs from database, using env defaults", zap.Error(err)) + } + + // 获取 MiniMax 配置 + miniMaxAPIKey := getEnv("MINIMAX_API_KEY", "") + miniMaxAPIURL := getEnv("MINIMAX_API_URL", "https://api.minimaxi.com/v1") + miniMaxModel := getEnv("MINIMAX_MODEL", "M2-her") + if val, ok := llmConfigs["minimax.api_key"]; ok && val != "" { + miniMaxAPIKey = val + } + if val, ok := llmConfigs["minimax.api_url"]; ok && val != "" { + miniMaxAPIURL = val + } + if val, ok := llmConfigs["minimax.model"]; ok && val != "" { + miniMaxModel = val + } + + // 获取 Qwen 配置 + qwenAPIKey := getEnv("QWEN_API_KEY", "") + qwenAPIURL := getEnv("QWEN_API_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1") + qwenModel := getEnv("QWEN_MODEL", "qwen-plus") + if val, ok := llmConfigs["qwen.api_key"]; ok && val != "" { + qwenAPIKey = val + } + if val, ok := llmConfigs["qwen.api_url"]; ok && val != "" { + qwenAPIURL = val + } + if val, ok := llmConfigs["qwen.model"]; ok && val != "" { + qwenModel = val + } + + logger.Logger.Info("LLM configs loaded", + zap.String("minimax_url", miniMaxAPIURL), + zap.String("minimax_model", miniMaxModel), + zap.Bool("has_minimax_key", miniMaxAPIKey != ""), + zap.String("qwen_url", qwenAPIURL), + zap.String("qwen_model", qwenModel), + zap.Bool("has_qwen_key", qwenAPIKey != ""), + ) + + // 创建 Service 层实例 + llmService := service.NewLLMService( + miniMaxAPIURL, + miniMaxAPIKey, + miniMaxModel, + qwenAPIURL, + qwenAPIKey, + qwenModel, + ) + personaService := service.NewPersonaService(personaRepo) + memoryService := service.NewMemoryService(shortTermMemoryRepo, longTermMemoryRepo) + auditService := service.NewAuditService() + chatService := service.NewChatService( + llmService, + personaService, + memoryService, + auditService, + *contextTTL, + *triggerTurns, + ) + logger.Logger.Info("Service layer initialized") + + // 创建 Provider 层实例 + aiChatProvider := provider.NewAIChatProvider( + chatService, + personaService, + memoryService, + auditService, + ) + logger.Logger.Info("Provider layer initialized") + + // 创建 Dubbo 服务器 + srv, err := server.NewServer( + server.WithServerProtocol( + protocol.WithPort(*port), + protocol.WithTriple(), + ), + ) + if err != nil { + logger.Logger.Fatal(fmt.Sprintf("Failed to create Dubbo server: %v", err)) + } + + // 注册 AI Chat Service + if err := pbAIChat.RegisterAIChatServiceHandler(srv, aiChatProvider); err != nil { + logger.Logger.Fatal(fmt.Sprintf("Failed to register AI Chat Service: %v", err)) + } + + // 启动服务 + if err := srv.Serve(); err != nil { + logger.Logger.Fatal(fmt.Sprintf("Failed to start AI Chat Service: %v", err)) + } + + logger.Logger.Info(fmt.Sprintf("AI Chat Service started successfully on port %d", *port)) + + // 等待退出信号 + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + + logger.Logger.Info("Shutting down AI Chat Service...") + + if healthHandler != nil { + healthHandler.Stop() + } + + redisClient.Close() +} + +// autoMigrate 自动迁移数据库表 +func autoMigrate() error { + db := database.GetDB() + if db == nil { + return fmt.Errorf("database is not initialized") + } + + tables := []interface{}{ + &model.Persona{}, + &model.UserMemory{}, + &model.Config{}, + } + + for _, table := range tables { + if err := db.AutoMigrate(table); err != nil { + // 如果约束删除失败(表已存在该约束),忽略并继续 + logger.Logger.Warn(fmt.Sprintf("Migration warning for %T: %v", table, err)) + } + } + + logger.Logger.Info("Database migration completed successfully") + return nil +} \ No newline at end of file diff --git a/backend/services/aiChatService/model/ai_chat_errors.go b/backend/services/aiChatService/model/ai_chat_errors.go new file mode 100644 index 0000000..f5448b0 --- /dev/null +++ b/backend/services/aiChatService/model/ai_chat_errors.go @@ -0,0 +1,26 @@ +package model + +import "errors" + +var ( + // ErrPersonaNotFound 人设不存在 + ErrPersonaNotFound = errors.New("persona_not_found") + + // ErrMemoryNotFound 记忆不存在 + ErrMemoryNotFound = errors.New("memory_not_found") + + // ErrAuditBlocked 内容审核未通过 + ErrAuditBlocked = errors.New("audit_blocked") + + // ErrSessionNotFound 会话不存在 + ErrSessionNotFound = errors.New("session_not_found") + + // ErrConfigNotFound 配置不存在 + ErrConfigNotFound = errors.New("config_not_found") + + // ErrLLMFailed 大模型调用失败 + ErrLLMFailed = errors.New("llm_failed") + + // ErrInvalidRequest 无效请求 + ErrInvalidRequest = errors.New("invalid_request") +) \ No newline at end of file diff --git a/backend/services/aiChatService/model/ai_chat_models.go b/backend/services/aiChatService/model/ai_chat_models.go new file mode 100644 index 0000000..2010223 --- /dev/null +++ b/backend/services/aiChatService/model/ai_chat_models.go @@ -0,0 +1,102 @@ +package model + +import ( + "time" + + "github.com/google/uuid" +) + +// Persona 人设结构 +type Persona struct { + ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;default:gen_random_uuid()"` + UserID int64 `json:"user_id" gorm:"index;not null"` + Name string `json:"name" gorm:"type:varchar(64);not null"` + Description string `json:"description" gorm:"type:text"` + AvatarURL string `json:"avatar_url" gorm:"type:varchar(512)"` + TalkStyle string `json:"talk_style" gorm:"type:varchar(256)"` + SystemPrompt string `json:"system_prompt" gorm:"type:text;not null"` + IsDefault bool `json:"is_default" gorm:"default:false"` + CreatedAt int64 `json:"created_at" gorm:"autoCreateTime:milli"` + UpdatedAt int64 `json:"updated_at" gorm:"autoUpdateTime:milli"` +} + +// TableName 指定表名 +func (Persona) TableName() string { + return "ai_personas" +} + +// UserMemory 用户长期记忆 +type UserMemory struct { + ID uint `json:"id" gorm:"primaryKey;autoIncrement"` + UserID int64 `json:"user_id" gorm:"index;not null"` + Content string `json:"content" gorm:"type:text;not null"` + Keywords []string `json:"keywords" gorm:"type:text[]"` + Weight int `json:"weight" gorm:"default:50"` + IsCore bool `json:"is_core" gorm:"default:false"` + CreatedAt int64 `json:"created_at" gorm:"autoCreateTime:milli"` + UpdatedAt int64 `json:"updated_at" gorm:"autoUpdateTime:milli"` +} + +// TableName 指定表名 +func (UserMemory) TableName() string { + return "ai_user_memories" +} + +// Config 配置结构 +type Config struct { + ID uint `json:"id" gorm:"primaryKey;autoIncrement"` + ConfigKey string `json:"config_key" gorm:"type:varchar(128);uniqueIndex;not null"` + ConfigValue string `json:"config_value" gorm:"type:text;not null"` + ConfigType string `json:"config_type" gorm:"type:varchar(32);default:string"` + Category string `json:"category" gorm:"type:varchar(64);index"` + Description string `json:"description" gorm:"type:varchar(256)"` + IsEncrypted bool `json:"is_encrypted" gorm:"default:false"` + UpdatedAt int64 `json:"updated_at" gorm:"autoUpdateTime:milli"` + CreatedAt int64 `json:"created_at" gorm:"autoCreateTime:milli"` +} + +// TableName 指定表名 +func (Config) TableName() string { + return "ai_chat_configs" +} + +// Message 对话消息 +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// ChatContext 短期上下文(Redis 存储) +type ChatContext struct { + SessionID string `json:"session_id"` + Messages []Message `json:"messages"` + PersonaID string `json:"persona_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// PersonaInfo 人设信息(API 返回) +type PersonaInfo struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + AvatarURL string `json:"avatar_url"` + TalkStyle string `json:"talk_style"` + IsDefault bool `json:"is_default"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +// ToPersonaInfo converts Persona to PersonaInfo +func (p *Persona) ToPersonaInfo() PersonaInfo { + return PersonaInfo{ + ID: p.ID.String(), + Name: p.Name, + Description: p.Description, + AvatarURL: p.AvatarURL, + TalkStyle: p.TalkStyle, + IsDefault: p.IsDefault, + CreatedAt: p.CreatedAt, + UpdatedAt: p.UpdatedAt, + } +} \ No newline at end of file diff --git a/backend/services/aiChatService/provider/ai_chat_provider.go b/backend/services/aiChatService/provider/ai_chat_provider.go new file mode 100644 index 0000000..2cdb732 --- /dev/null +++ b/backend/services/aiChatService/provider/ai_chat_provider.go @@ -0,0 +1,442 @@ +package provider + +import ( + "context" + "fmt" + "io" + + "dubbo.apache.org/dubbo-go/v3/common/constant" + "github.com/topfans/backend/pkg/logger" + "github.com/topfans/backend/services/aiChatService/model" + "github.com/topfans/backend/services/aiChatService/service" + pb "github.com/topfans/backend/pkg/proto/ai_chat" + "go.uber.org/zap" +) + +// AIChatProvider AI Chat 服务 Provider 实现 +type AIChatProvider struct { + chatService *service.ChatService + personaService *service.PersonaService + memoryService *service.MemoryService + auditService *service.AuditService +} + +// 确保 AIChatProvider 实现了 AIChatServiceHandler 接口 +var _ pb.AIChatServiceHandler = (*AIChatProvider)(nil) + +// NewAIChatProvider 创建 AIChatProvider 实例 +func NewAIChatProvider( + chatService *service.ChatService, + personaService *service.PersonaService, + memoryService *service.MemoryService, + auditService *service.AuditService, +) *AIChatProvider { + return &AIChatProvider{ + chatService: chatService, + personaService: personaService, + memoryService: memoryService, + auditService: auditService, + } +} + +// InitSession 初始化会话,返回欢迎消息 +func (p *AIChatProvider) InitSession(ctx context.Context, req *pb.InitSessionRequest) (*pb.InitSessionResponse, error) { + userID, starID, err := extractUserInfoFromDubboAttachments(ctx) + sessionID := req.SessionId + if sessionID == "" { + sessionID = fmt.Sprintf("%d_%d", userID, starID) + } + + if err != nil { + logger.Logger.Error("Failed to extract user info from attachments", + zap.Error(err), + ) + return nil, err + } + + logger.Logger.Info("Received InitSession request", + zap.Int64("user_id", userID), + zap.String("session_id", sessionID), + ) + + // 获取欢迎消息 + welcomeMessage := p.chatService.GetWelcomeMessage(sessionID, userID, starID) + + return &pb.InitSessionResponse{ + WelcomeMessage: welcomeMessage, + SessionId: sessionID, + }, nil +} + +// SendMessage 发送消息(流式返回) +func (p *AIChatProvider) SendMessage(ctx context.Context, req *pb.ChatMessageRequest, stream pb.AIChatService_SendMessageServer) error { + userID, starID, err := extractUserInfoFromDubboAttachments(ctx) + sessionID := req.SessionId + if sessionID == "" { + sessionID = fmt.Sprintf("%d_%d", userID, starID) + } + + if err != nil { + logger.Logger.Error("Failed to extract user info from attachments", + zap.Error(err), + ) + stream.Send(&pb.ChatMessageResponse{ + Content: "user authentication required", + SessionId: sessionID, + IsEnd: true, + Error: err.Error(), + }) + return err + } + if sessionID == "" { + // 如果没有 sessionID,生成一个 + sessionID = fmt.Sprintf("%d_%d", userID, starID) + } + + message := req.Message + personaID := req.PersonaId + + logger.Logger.Info("Received SendMessage request", + zap.Int64("user_id", userID), + zap.String("session_id", sessionID), + zap.String("message", message), + ) + + // 1. 前置审核 + if !p.auditService.AuditText(message) { + logger.Logger.Info("Message blocked by audit") + stream.Send(&pb.ChatMessageResponse{ + Content: p.auditService.DefaultSafeResponse(), + SessionId: sessionID, + IsEnd: false, + }) + stream.Send(&pb.ChatMessageResponse{ + SessionId: sessionID, + IsEnd: true, + }) + return nil + } + + // 2. 获取人设 + persona, err := p.personaService.GetPersonaOrDefault(ctx, userID, personaID) + if err != nil { + logger.Logger.Error("Failed to get persona", zap.Error(err)) + stream.Send(&pb.ChatMessageResponse{ + Content: err.Error(), + IsEnd: true, + Error: err.Error(), + }) + return err + } + + // 3. 记忆召回 + memoryText, _ := p.memoryService.RecallMemories(ctx, userID, message, 5) + + // 4. 获取对话历史 + history, _ := p.memoryService.GetContext(ctx, sessionID) + + // 5. 构建 Prompt + tokenizer := &service.Tokenizer{} + messages, _ := service.BuildPrompt( + persona.SystemPrompt, + memoryText, + history, + message, + tokenizer, + ) + + // 6. 检查是否需要调用大模型 + if service.IsNoNeedLLMCall(message) { + stream.Send(&pb.ChatMessageResponse{ + Content: "好的,我听到了。", + SessionId: sessionID, + IsEnd: false, + }) + stream.Send(&pb.ChatMessageResponse{ + SessionId: sessionID, + IsEnd: true, + }) + return nil + } + + // 7. 调用大模型(流式) + streamReader, err := p.chatService.LLMService.StreamChat(ctx, messages) + if err != nil { + logger.Logger.Error("LLM call failed", zap.Error(err)) + + // 检查是否是敏感内容错误 + if _, ok := err.(*service.SensitiveContentError); ok { + logger.Logger.Info("Content blocked by MiniMax safety filter, trying backup model") + // 尝试备用模型 + streamReader, err = p.chatService.LLMService.StreamChatWithBackup(ctx, messages) + if err != nil { + // 备用模型也失败 + logger.Logger.Error("Backup model also failed", zap.Error(err)) + stream.Send(&pb.ChatMessageResponse{ + Content: p.auditService.DefaultSafeResponse(), + SessionId: sessionID, + IsEnd: true, + }) + return nil + } + } else { + // 其他错误,尝试备用模型 + streamReader, err = p.chatService.LLMService.StreamChatWithBackup(ctx, messages) + if err != nil { + logger.Logger.Error("Backup model also failed", zap.Error(err)) + stream.Send(&pb.ChatMessageResponse{ + Content: "抱歉,服务暂时不可用", + SessionId: sessionID, + IsEnd: true, + Error: err.Error(), + }) + return err + } + } + } + defer streamReader.Close() + + // 8. 流式处理 + var fullResponse string + var sentEnd = false + for { + content, done, err := streamReader.Next() + if err != nil { + if err == io.EOF { + // 流结束,发送 is_end + if !sentEnd { + stream.Send(&pb.ChatMessageResponse{ + SessionId: sessionID, + IsEnd: true, + }) + sentEnd = true + } + break + } + logger.Logger.Error("Stream read error", zap.Error(err)) + break + } + + // 后置审核(逐 token) + if !p.auditService.AuditResponse(content) { + logger.Logger.Info("Response blocked by audit") + streamReader.Close() + // 发送安全回复作为替代 + stream.Send(&pb.ChatMessageResponse{ + Content: p.auditService.DefaultSafeResponse(), + SessionId: sessionID, + IsEnd: false, + }) + stream.Send(&pb.ChatMessageResponse{ + SessionId: sessionID, + IsEnd: true, + }) + sentEnd = true + return nil + } + + fullResponse += content + + // 发送 token 给客户端 + if err := stream.Send(&pb.ChatMessageResponse{ + Content: content, + SessionId: sessionID, + IsEnd: done, + }); err != nil { + logger.Logger.Error("Failed to send message to stream", zap.Error(err)) + return err + } + if done { + sentEnd = true + } + } + + // 9. 保存上下文 + newHistory := append(history, model.Message{Role: "user", Content: message}) + newHistory = append(newHistory, model.Message{Role: "assistant", Content: fullResponse}) + p.memoryService.SaveContext(ctx, sessionID, newHistory, personaID) + + // 10. 触发记忆提取(每5轮) + newTurns := len(newHistory) / 2 + logger.Logger.Info("Memory extraction check", + zap.Int("message_count", len(newHistory)), + zap.Int("turns", newTurns), + zap.Bool("should_extract", newTurns >= 5), + ) + if newTurns >= 5 { + logger.Logger.Info("Triggering memory extraction", zap.Int64("user_id", userID)) + if err := p.memoryService.ExtractMemory(ctx, userID, newHistory); err != nil { + logger.Logger.Error("Failed to extract memory", zap.Error(err)) + } else { + logger.Logger.Info("Memory extracted successfully", zap.Int64("user_id", userID)) + } + } + + logger.Logger.Info("SendMessage completed", + zap.Int64("user_id", userID), + zap.String("session_id", sessionID), + zap.Int("response_length", len(fullResponse)), + ) + + return nil +} + +// GetHistory 获取对话历史 +func (p *AIChatProvider) GetHistory(ctx context.Context, req *pb.ChatHistoryRequest) (*pb.ChatHistoryResponse, error) { + userID, starID, err := extractUserInfoFromDubboAttachments(ctx) + if err != nil { + logger.Logger.Error("Failed to extract user info from attachments", + zap.Error(err), + ) + return nil, err + } + + sessionID := req.SessionId + if sessionID == "" { + sessionID = fmt.Sprintf("%d_%d", userID, starID) + } + + logger.Logger.Info("Received GetHistory request", + zap.Int64("user_id", userID), + zap.String("session_id", sessionID), + ) + + messages, err := p.memoryService.GetContext(ctx, sessionID) + if err != nil { + return nil, err + } + + pbMessages := make([]*pb.Message, len(messages)) + for i, m := range messages { + pbMessages[i] = &pb.Message{ + Role: m.Role, + Content: m.Content, + } + } + + return &pb.ChatHistoryResponse{ + History: pbMessages, + }, nil +} + +// GetPersonas 获取用户的所有人设 +func (p *AIChatProvider) GetPersonas(ctx context.Context, req *pb.GetPersonasRequest) (*pb.PersonaListResponse, error) { + userID := req.UserId + + logger.Logger.Info("Received GetPersonas request", + zap.Int64("user_id", userID), + ) + + personas, err := p.personaService.GetPersonas(ctx, userID) + if err != nil { + return nil, err + } + + pbPersonas := make([]*pb.PersonaInfo, len(personas)) + for i, persona := range personas { + pbPersonas[i] = &pb.PersonaInfo{ + Id: persona.ID, + Name: persona.Name, + Description: persona.Description, + AvatarUrl: persona.AvatarURL, + TalkStyle: persona.TalkStyle, + IsDefault: persona.IsDefault, + CreatedAt: persona.CreatedAt, + UpdatedAt: persona.UpdatedAt, + } + } + + return &pb.PersonaListResponse{ + Personas: pbPersonas, + }, nil +} + +// extractUserInfoFromDubboAttachments 从 Dubbo attachments 中提取用户信息 +func extractUserInfoFromDubboAttachments(ctx context.Context) (int64, int64, error) { + logger.Logger.Debug("Extracting user info from Dubbo attachments", + zap.Any("context_type", fmt.Sprintf("%T", ctx)), + ) + + // Try to get any value from context + if attachments := ctx.Value(constant.AttachmentKey); attachments != nil { + logger.Logger.Debug("Found attachments via constant.AttachmentKey", + zap.Any("attachments", attachments), + zap.String("type", fmt.Sprintf("%T", attachments)), + ) + if attMap, ok := attachments.(map[string]interface{}); ok { + logger.Logger.Debug("Attachments map content", + zap.Any("map", attMap), + ) + userID := parseIntValue(attMap["user_id"]) + starID := parseIntValue(attMap["star_id"]) + logger.Logger.Debug("Parsed user info from attachments", + zap.Any("user_id_raw", attMap["user_id"]), + zap.Int64("user_id", userID), + zap.Any("star_id_raw", attMap["star_id"]), + zap.Int64("star_id", starID), + ) + if userID > 0 && starID > 0 { + return userID, starID, nil + } + logger.Logger.Warn("Parsed user_id or star_id is zero", + zap.Int64("user_id", userID), + zap.Int64("star_id", starID), + ) + } else { + logger.Logger.Warn("Attachments is not map[string]interface{}", + zap.String("actual_type", fmt.Sprintf("%T", attachments)), + ) + } + } else { + logger.Logger.Warn("ctx.Value(constant.AttachmentKey) returned nil", + zap.String("constant_attachment_key", string(constant.AttachmentKey)), + ) + } + + // Debug: list all keys in context + logger.Logger.Warn("Checking alternative key: 'attachment'") + if val := ctx.Value("attachment"); val != nil { + logger.Logger.Debug("Found value with key 'attachment'", + zap.Any("value", val), + zap.String("type", fmt.Sprintf("%T", val)), + ) + } + + return 0, 0, fmt.Errorf("user info not found in Dubbo attachments") +} + +// parseIntValue 解析各种类型的值为 int64 +func parseIntValue(v interface{}) int64 { + switch val := v.(type) { + case int64: + return val + case int: + return int64(val) + case float64: + return int64(val) + case string: + var result int64 + fmt.Sscanf(val, "%d", &result) + return result + case []string: + if len(val) > 0 { + var result int64 + fmt.Sscanf(val[0], "%d", &result) + return result + } + case []interface{}: + if len(val) > 0 { + switch s := val[0].(type) { + case string: + var result int64 + fmt.Sscanf(s, "%d", &result) + return result + case int: + return int64(s) + case int64: + return s + } + } + } + return 0 +} diff --git a/backend/services/aiChatService/repository/config_repository.go b/backend/services/aiChatService/repository/config_repository.go new file mode 100644 index 0000000..17aa09c --- /dev/null +++ b/backend/services/aiChatService/repository/config_repository.go @@ -0,0 +1,60 @@ +package repository + +import ( + "context" + "fmt" + + "github.com/topfans/backend/services/aiChatService/model" + "gorm.io/gorm" +) + +// ConfigRepository 配置仓库接口 +type ConfigRepository interface { + Get(ctx context.Context, key string) (string, error) + GetByCategory(ctx context.Context, category string) (map[string]string, error) + Update(ctx context.Context, key, value string) error +} + +// PostgreSQLConfigRepository PostgreSQL 配置仓库实现 +type PostgreSQLConfigRepository struct { + db *gorm.DB +} + +// NewPostgreSQLConfigRepository 创建配置仓库 +func NewPostgreSQLConfigRepository(db *gorm.DB) *PostgreSQLConfigRepository { + return &PostgreSQLConfigRepository{db: db} +} + +// Get 获取单个配置值 +func (r *PostgreSQLConfigRepository) Get(ctx context.Context, key string) (string, error) { + var config model.Config + if err := r.db.WithContext(ctx).Where("config_key = ?", key).First(&config).Error; err != nil { + if err == gorm.ErrRecordNotFound { + return "", model.ErrConfigNotFound + } + return "", fmt.Errorf("failed to get config: %w", err) + } + return config.ConfigValue, nil +} + +// GetByCategory 按分类获取所有配置 +func (r *PostgreSQLConfigRepository) GetByCategory(ctx context.Context, category string) (map[string]string, error) { + var configs []model.Config + if err := r.db.WithContext(ctx).Where("category = ?", category).Find(&configs).Error; err != nil { + return nil, fmt.Errorf("failed to get configs: %w", err) + } + + result := make(map[string]string) + for _, c := range configs { + result[c.ConfigKey] = c.ConfigValue + } + return result, nil +} + +// Update 更新配置值 +func (r *PostgreSQLConfigRepository) Update(ctx context.Context, key, value string) error { + return r.db.WithContext(ctx). + Model(&model.Config{}). + Where("config_key = ?", key). + Update("config_value", value).Error +} \ No newline at end of file diff --git a/backend/services/aiChatService/repository/memory_repository.go b/backend/services/aiChatService/repository/memory_repository.go new file mode 100644 index 0000000..642d0b8 --- /dev/null +++ b/backend/services/aiChatService/repository/memory_repository.go @@ -0,0 +1,156 @@ +package repository + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/lib/pq" + "github.com/redis/go-redis/v9" + "github.com/topfans/backend/services/aiChatService/model" + "gorm.io/gorm" +) + +// ShortTermMemoryRepository 短期记忆仓库接口 (Redis) +type ShortTermMemoryRepository interface { + SaveContext(ctx context.Context, sessionID string, messages []model.Message, personaID string) error + GetContext(ctx context.Context, sessionID string) ([]model.Message, error) + DeleteContext(ctx context.Context, sessionID string) error +} + +// LongTermMemoryRepository 长期记忆仓库接口 (PostgreSQL) +type LongTermMemoryRepository interface { + SaveMemory(ctx context.Context, memory *model.UserMemory) error + GetMemories(ctx context.Context, userID int64, keywords []string, limit int) ([]model.UserMemory, error) + GetMemoriesByUserID(ctx context.Context, userID int64) ([]model.UserMemory, error) +} + +// MemoryRepository 记忆仓库接口(兼容性别名,用于不需要区分短期/长期的场景) +type MemoryRepository interface { + ShortTermMemoryRepository + LongTermMemoryRepository +} + +// RedisMemoryRepository Redis 短期记忆实现 +type RedisMemoryRepository struct { + client *redis.Client + ttl time.Duration +} + +// NewRedisMemoryRepository 创建 Redis 短期记忆仓库 +func NewRedisMemoryRepository(client *redis.Client, ttlSeconds int) *RedisMemoryRepository { + return &RedisMemoryRepository{ + client: client, + ttl: time.Duration(ttlSeconds) * time.Second, + } +} + +// contextKey 生成 Redis key +func contextKey(sessionID string) string { + return fmt.Sprintf("context:%s", sessionID) +} + +// SaveContext 保存短期上下文到 Redis +func (r *RedisMemoryRepository) SaveContext(ctx context.Context, sessionID string, messages []model.Message, personaID string) error { + data := model.ChatContext{ + SessionID: sessionID, + Messages: messages, + PersonaID: personaID, + UpdatedAt: time.Now(), + } + + jsonData, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal context: %w", err) + } + + return r.client.Set(ctx, contextKey(sessionID), jsonData, r.ttl).Err() +} + +// GetContext 从 Redis 获取短期上下文 +func (r *RedisMemoryRepository) GetContext(ctx context.Context, sessionID string) ([]model.Message, error) { + data, err := r.client.Get(ctx, contextKey(sessionID)).Bytes() + if err != nil { + if err == redis.Nil { + return nil, nil + } + return nil, fmt.Errorf("failed to get context: %w", err) + } + + var chatCtx model.ChatContext + if err := json.Unmarshal(data, &chatCtx); err != nil { + return nil, fmt.Errorf("failed to unmarshal context: %w", err) + } + + return chatCtx.Messages, nil +} + +// DeleteContext 删除短期上下文 +func (r *RedisMemoryRepository) DeleteContext(ctx context.Context, sessionID string) error { + return r.client.Del(ctx, contextKey(sessionID)).Err() +} + +// PostgreSQLMemoryRepository PostgreSQL 长期记忆实现 +type PostgreSQLMemoryRepository struct { + db *gorm.DB +} + +// NewPostgreSQLMemoryRepository 创建 PostgreSQL 长期记忆仓库 +func NewPostgreSQLMemoryRepository(db *gorm.DB) *PostgreSQLMemoryRepository { + return &PostgreSQLMemoryRepository{db: db} +} + +// SaveMemory 保存长期记忆 +func (r *PostgreSQLMemoryRepository) SaveMemory(ctx context.Context, memory *model.UserMemory) error { + now := time.Now().UnixMilli() + if memory.CreatedAt == 0 { + memory.CreatedAt = now + } + if memory.UpdatedAt == 0 { + memory.UpdatedAt = now + } + + return r.db.WithContext(ctx).Exec( + `INSERT INTO "ai_user_memories" ("user_id", "content", "keywords", "weight", "is_core", "created_at", "updated_at") + VALUES ($1, $2, $3, $4, $5, $6, $7)`, + memory.UserID, + memory.Content, + pq.Array(memory.Keywords), + memory.Weight, + memory.IsCore, + memory.CreatedAt, + memory.UpdatedAt, + ).Error +} + +// GetMemories 根据关键词查询记忆 +func (r *PostgreSQLMemoryRepository) GetMemories(ctx context.Context, userID int64, keywords []string, limit int) ([]model.UserMemory, error) { + var memories []model.UserMemory + query := r.db.WithContext(ctx). + Where("user_id = ?", userID). + Order("weight DESC, created_at DESC"). + Limit(limit) + + if len(keywords) > 0 { + query = query.Where("keywords && $1", pq.Array(keywords)) + } + + if err := query.Find(&memories).Error; err != nil { + return nil, fmt.Errorf("failed to get memories: %w", err) + } + + return memories, nil +} + +// GetMemoriesByUserID 获取用户所有记忆 +func (r *PostgreSQLMemoryRepository) GetMemoriesByUserID(ctx context.Context, userID int64) ([]model.UserMemory, error) { + var memories []model.UserMemory + if err := r.db.WithContext(ctx). + Where("user_id = ?", userID). + Order("weight DESC, created_at DESC"). + Find(&memories).Error; err != nil { + return nil, fmt.Errorf("failed to get memories: %w", err) + } + return memories, nil +} \ No newline at end of file diff --git a/backend/services/aiChatService/repository/persona_repository.go b/backend/services/aiChatService/repository/persona_repository.go new file mode 100644 index 0000000..1260a2c --- /dev/null +++ b/backend/services/aiChatService/repository/persona_repository.go @@ -0,0 +1,122 @@ +package repository + +import ( + "context" + "fmt" + + "github.com/google/uuid" + "github.com/topfans/backend/services/aiChatService/model" + "gorm.io/gorm" +) + +// PersonaRepository 人设仓库接口 +type PersonaRepository interface { + Create(ctx context.Context, persona *model.Persona) error + GetByID(ctx context.Context, id uuid.UUID) (*model.Persona, error) + GetByUserID(ctx context.Context, userID int64) ([]model.Persona, error) + GetDefaultByUserID(ctx context.Context, userID int64) (*model.Persona, error) + Update(ctx context.Context, persona *model.Persona) error + Delete(ctx context.Context, id uuid.UUID) error + EnsureDefaultPersona(ctx context.Context, userID int64) (*model.Persona, error) +} + +// DefaultSystemPrompt 默认系统提示词 +const DefaultSystemPrompt = `你是一个温柔体贴的AI伴侣,名字叫角角。你善于倾听,能理解用户的情绪, +用温暖的话语陪伴用户。说话风格亲切自然,像朋友聊天一样。 +不要过于正式或说教,当用户情绪低落时,先给予共情和安慰。` + +// DefaultPersonaName 默认人设名称 +const DefaultPersonaName = "角角" + +// DefaultPersonaDescription 默认人设描述 +const DefaultPersonaDescription = "温柔陪伴型闺蜜" + +// PostgreSQLPersonaRepository PostgreSQL 人设仓库实现 +type PostgreSQLPersonaRepository struct { + db *gorm.DB +} + +// NewPostgreSQLPersonaRepository 创建人设仓库 +func NewPostgreSQLPersonaRepository(db *gorm.DB) *PostgreSQLPersonaRepository { + return &PostgreSQLPersonaRepository{db: db} +} + +// Create 创建人设 +func (r *PostgreSQLPersonaRepository) Create(ctx context.Context, persona *model.Persona) error { + return r.db.WithContext(ctx).Create(persona).Error +} + +// GetByID 根据 ID 获取人设 +func (r *PostgreSQLPersonaRepository) GetByID(ctx context.Context, id uuid.UUID) (*model.Persona, error) { + var persona model.Persona + if err := r.db.WithContext(ctx).Where("id = ?", id).First(&persona).Error; err != nil { + if err == gorm.ErrRecordNotFound { + return nil, model.ErrPersonaNotFound + } + return nil, fmt.Errorf("failed to get persona: %w", err) + } + return &persona, nil +} + +// GetByUserID 获取用户的所有人设 +func (r *PostgreSQLPersonaRepository) GetByUserID(ctx context.Context, userID int64) ([]model.Persona, error) { + var personas []model.Persona + if err := r.db.WithContext(ctx). + Where("user_id = ?", userID). + Order("created_at DESC"). + Find(&personas).Error; err != nil { + return nil, fmt.Errorf("failed to get personas: %w", err) + } + return personas, nil +} + +// GetDefaultByUserID 获取用户的默认人设 +func (r *PostgreSQLPersonaRepository) GetDefaultByUserID(ctx context.Context, userID int64) (*model.Persona, error) { + var persona model.Persona + if err := r.db.WithContext(ctx). + Where("user_id = ? AND is_default = TRUE", userID). + First(&persona).Error; err != nil { + if err == gorm.ErrRecordNotFound { + return nil, model.ErrPersonaNotFound + } + return nil, fmt.Errorf("failed to get default persona: %w", err) + } + return &persona, nil +} + +// Update 更新人设 +func (r *PostgreSQLPersonaRepository) Update(ctx context.Context, persona *model.Persona) error { + return r.db.WithContext(ctx).Save(persona).Error +} + +// Delete 删除人设 +func (r *PostgreSQLPersonaRepository) Delete(ctx context.Context, id uuid.UUID) error { + return r.db.WithContext(ctx).Delete(&model.Persona{}, "id = ?", id).Error +} + +// EnsureDefaultPersona 确保用户有默认人设 +func (r *PostgreSQLPersonaRepository) EnsureDefaultPersona(ctx context.Context, userID int64) (*model.Persona, error) { + // 检查是否已有默认人设 + persona, err := r.GetDefaultByUserID(ctx, userID) + if err == nil { + return persona, nil + } + if err != model.ErrPersonaNotFound { + return nil, err + } + + // 创建默认人设 + persona = &model.Persona{ + UserID: userID, + Name: DefaultPersonaName, + Description: DefaultPersonaDescription, + SystemPrompt: DefaultSystemPrompt, + IsDefault: true, + } + + if err := r.Create(ctx, persona); err != nil { + return nil, fmt.Errorf("failed to create default persona: %w", err) + } + + return persona, nil +} diff --git a/backend/services/aiChatService/service/audit_service.go b/backend/services/aiChatService/service/audit_service.go new file mode 100644 index 0000000..a5b37c6 --- /dev/null +++ b/backend/services/aiChatService/service/audit_service.go @@ -0,0 +1,59 @@ +package service + +import "strings" + +// AuditService 审核服务 +type AuditService struct { + // 敏感词列表 + sensitiveWords []string +} + +// NewAuditService 创建审核服务 +func NewAuditService() *AuditService { + return &AuditService{ + sensitiveWords: []string{ + // 政治类 + "台独", "港独", "藏独", "疆独", "分裂", "颠覆", + // 色情类 + "色情", "裸聊", "约炮", "成人", + // 暴力类 + "杀人", "虐待", "暴力", + // 违规诱导 + "转账", "汇款", "银行卡", "密码", + // AI 身份冒充 + "你是真人", "你是人类", "真人在吗", + }, + } +} + +// AuditText 审核文本内容 +// 返回 true 表示通过,false 表示违规 +func (s *AuditService) AuditText(text string) bool { + for _, word := range s.sensitiveWords { + if strings.Contains(text, word) { + return false + } + } + return true +} + +// AuditResponse 审核 AI 回复(逐 token) +func (s *AuditService) AuditResponse(token string) bool { + // 累加 token 判断是否违规 + for _, word := range s.sensitiveWords { + if strings.Contains(token, word) { + return false + } + } + return true +} + +// DefaultSafeResponse 获取默认安全回复 +func (s *AuditService) DefaultSafeResponse() string { + return "抱歉,这个话题我无法继续,我们换个话题聊聊吧。" +} + +// AddSensitiveWord 添加敏感词 +func (s *AuditService) AddSensitiveWord(word string) { + s.sensitiveWords = append(s.sensitiveWords, word) +} \ No newline at end of file diff --git a/backend/services/aiChatService/service/chat_service.go b/backend/services/aiChatService/service/chat_service.go new file mode 100644 index 0000000..fc72f60 --- /dev/null +++ b/backend/services/aiChatService/service/chat_service.go @@ -0,0 +1,60 @@ +package service + +import ( + "context" + + "github.com/topfans/backend/services/aiChatService/model" +) + +// ChatService 对话服务 +type ChatService struct { + LLMService *LLMService + personaService *PersonaService + memoryService *MemoryService + auditService *AuditService + contextTTL int // 上下文过期时间(秒) + triggerTurns int // 触发记忆提取的轮数 +} + +// NewChatService 创建对话服务 +func NewChatService( + llmService *LLMService, + personaService *PersonaService, + memoryService *MemoryService, + auditService *AuditService, + contextTTL int, + triggerTurns int, +) *ChatService { + return &ChatService{ + LLMService: llmService, + personaService: personaService, + memoryService: memoryService, + auditService: auditService, + contextTTL: contextTTL, + triggerTurns: triggerTurns, + } +} + +// GetHistory 获取对话历史 +func (s *ChatService) GetHistory(ctx context.Context, sessionID string) ([]model.Message, error) { + return s.memoryService.GetContext(ctx, sessionID) +} + +// SaveContext 保存对话上下文 +func (s *ChatService) SaveContext(ctx context.Context, sessionID string, messages []model.Message, personaID string) error { + return s.memoryService.SaveContext(ctx, sessionID, messages, personaID) +} + +// ExtractMemory 提取记忆 +func (s *ChatService) ExtractMemory(ctx context.Context, userID int64, recentMessages []model.Message) error { + if ShouldExtractMemory(recentMessages, s.triggerTurns) { + return s.memoryService.ExtractMemory(ctx, userID, recentMessages) + } + return nil +} + +// GetWelcomeMessage 获取欢迎消息 +func (s *ChatService) GetWelcomeMessage(sessionID string, userID int64, starID int64) string { + // 默认欢迎消息 + return "亲爱的你来辣 ~~" +} \ No newline at end of file diff --git a/backend/services/aiChatService/service/llm_service.go b/backend/services/aiChatService/service/llm_service.go new file mode 100644 index 0000000..61158c2 --- /dev/null +++ b/backend/services/aiChatService/service/llm_service.go @@ -0,0 +1,312 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" + + "github.com/topfans/backend/services/aiChatService/model" +) + +// StreamReader 流式读取器接口 +type StreamReader interface { + Next() (content string, done bool, err error) + Close() error +} + +// SensitiveContentError 敏感内容错误 +type SensitiveContentError struct { + Message string +} + +func (e *SensitiveContentError) Error() string { + return e.Message +} + +// LLMService 大模型服务 +type LLMService struct { + miniMaxClient *http.Client + qwenClient *http.Client + + miniMaxURL string + miniMaxKey string + miniMaxModel string + + qwenURL string + qwenKey string + qwenModel string + + failCount int + fallbackCount int + useBackup bool + mu sync.RWMutex +} + +// MiniMaxMessage MiniMax 消息格式 +type MiniMaxMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// MiniMaxRequest MiniMax 请求格式 +type MiniMaxRequest struct { + Model string `json:"model"` + Messages []MiniMaxMessage `json:"messages"` + Stream bool `json:"stream"` +} + +// MiniMaxResponse MiniMax 响应格式 +type MiniMaxResponse struct { + Choices []struct { + Delta struct { + Content string `json:"content"` + } `json:"delta"` + } `json:"choices"` +} + +// MiniMaxSSERecord SSE 格式记录 +type MiniMaxSSERecord struct { + ID string `json:"id"` + Object string `json:"object"` + Model string `json:"model"` + Choices []struct { + Index int `json:"index"` + Delta struct { + Content string `json:"content"` + } `json:"delta"` + } `json:"choices"` +} + +// NewLLMService 创建大模型服务 +func NewLLMService(miniMaxURL, miniMaxKey, miniMaxModel, qwenURL, qwenKey, qwenModel string) *LLMService { + return &LLMService{ + miniMaxClient: &http.Client{Timeout: 60 * time.Second}, + qwenClient: &http.Client{Timeout: 60 * time.Second}, + miniMaxURL: miniMaxURL, + miniMaxKey: miniMaxKey, + miniMaxModel: miniMaxModel, + qwenURL: qwenURL, + qwenKey: qwenKey, + qwenModel: qwenModel, + } +} + +// StreamChat 流式聊天 +func (s *LLMService) StreamChat(ctx context.Context, messages []model.Message) (StreamReader, error) { + // Qwen fallback 已禁用 + return s.streamChatMiniMax(ctx, messages) +} + +// StreamChatWithBackup 禁用备用模型 +func (s *LLMService) StreamChatWithBackup(ctx context.Context, messages []model.Message) (StreamReader, error) { + // Qwen fallback 已禁用,直接返回 MiniMax + return s.streamChatMiniMax(ctx, messages) +} + +// streamChatMiniMax MiniMax 流式聊天 +func (s *LLMService) streamChatMiniMax(ctx context.Context, messages []model.Message) (*MiniMaxStreamReader, error) { + reqMessages := make([]MiniMaxMessage, len(messages)) + for i, m := range messages { + reqMessages[i] = MiniMaxMessage{ + Role: m.Role, + Content: m.Content, + } + } + + reqBody := MiniMaxRequest{ + Model: s.miniMaxModel, + Messages: reqMessages, + Stream: true, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", s.miniMaxURL+"/chat/completions", bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+s.miniMaxKey) + + resp, err := s.miniMaxClient.Do(req) + if err != nil { + s.incrementFailCount() + return nil, fmt.Errorf("failed to call MiniMax: %w", err) + } + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + bodyStr := string(body) + + // 检查是否是敏感内容错误 (1027) + if strings.Contains(bodyStr, "1027") || strings.Contains(bodyStr, "new_sensitive") { + return nil, &SensitiveContentError{Message: "content blocked by MiniMax safety filter"} + } + + s.incrementFailCount() + // 如果是 401/403 等错误,切换到备用模型 + if resp.StatusCode == 401 || resp.StatusCode == 403 { + s.SwitchToBackup() + } + return nil, fmt.Errorf("MiniMax API error: %d, body: %s", resp.StatusCode, bodyStr) + } + + s.resetFailCount() + return &MiniMaxStreamReader{ + reader: resp.Body, + decoder: NewSSEDecoder(resp.Body), + }, nil +} + +// streamChatQwen 通义流式聊天 +func (s *LLMService) streamChatQwen(ctx context.Context, messages []model.Message) (*MiniMaxStreamReader, error) { + reqMessages := make([]MiniMaxMessage, len(messages)) + for i, m := range messages { + reqMessages[i] = MiniMaxMessage{ + Role: m.Role, + Content: m.Content, + } + } + + reqBody := MiniMaxRequest{ + Model: s.qwenModel, + Messages: reqMessages, + Stream: true, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", s.qwenURL+"/chat/completions", bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+s.qwenKey) + + resp, err := s.qwenClient.Do(req) + if err != nil { + s.incrementFailCount() + return nil, fmt.Errorf("failed to call Qwen: %w", err) + } + + if resp.StatusCode != http.StatusOK { + s.incrementFailCount() + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + return nil, fmt.Errorf("Qwen API error: %d, body: %s", resp.StatusCode, string(body)) + } + + s.resetFailCount() + return &MiniMaxStreamReader{ + reader: resp.Body, + decoder: NewSSEDecoder(resp.Body), + }, nil +} + +func (s *LLMService) incrementFailCount() { + s.mu.Lock() + s.failCount++ + if s.failCount >= 3 { + s.useBackup = true + } + s.mu.Unlock() +} + +func (s *LLMService) resetFailCount() { + s.mu.Lock() + s.failCount = 0 + s.mu.Unlock() +} + +func (s *LLMService) SwitchToBackup() { + s.mu.Lock() + s.useBackup = true + s.mu.Unlock() +} + +// MiniMaxStreamReader MiniMax 流式读取器 +type MiniMaxStreamReader struct { + reader io.ReadCloser + decoder *SSEDecoder + buffer string +} + +func (r *MiniMaxStreamReader) Next() (content string, done bool, err error) { + for { + line, err := r.decoder.Next() + if err != nil { + r.Close() + return "", true, err + } + if line == "" { + continue + } + + // SSE data: 格式 + if strings.HasPrefix(line, "data: ") { + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + return "", true, nil + } + + var record MiniMaxSSERecord + if err := json.Unmarshal([]byte(data), &record); err != nil { + continue + } + + if len(record.Choices) > 0 && record.Choices[0].Delta.Content != "" { + return record.Choices[0].Delta.Content, false, nil + } + } + } +} + +func (r *MiniMaxStreamReader) Close() error { + return r.reader.Close() +} + +// SSEDecoder SSE 解码器 +type SSEDecoder struct { + reader io.Reader +} + +func NewSSEDecoder(reader io.Reader) *SSEDecoder { + return &SSEDecoder{reader: reader} +} + +func (d *SSEDecoder) Next() (string, error) { + var line []byte + buf := make([]byte, 1) + + for { + n, err := d.reader.Read(buf) + if err != nil { + return "", err + } + if n == 0 { + continue + } + + if buf[0] == '\n' { + break + } + line = append(line, buf[0]) + } + + return strings.TrimRight(string(line), "\r"), nil +} \ No newline at end of file diff --git a/backend/services/aiChatService/service/memory_service.go b/backend/services/aiChatService/service/memory_service.go new file mode 100644 index 0000000..7577bf0 --- /dev/null +++ b/backend/services/aiChatService/service/memory_service.go @@ -0,0 +1,211 @@ +package service + +import ( + "context" + "strings" + + "github.com/topfans/backend/pkg/logger" + "github.com/topfans/backend/services/aiChatService/model" + "github.com/topfans/backend/services/aiChatService/repository" + "go.uber.org/zap" +) + +// MemoryService 记忆服务 +type MemoryService struct { + shortTermRepo repository.ShortTermMemoryRepository + longTermRepo repository.LongTermMemoryRepository +} + +// NewMemoryService 创建记忆服务 +func NewMemoryService(shortTermRepo repository.ShortTermMemoryRepository, longTermRepo repository.LongTermMemoryRepository) *MemoryService { + return &MemoryService{ + shortTermRepo: shortTermRepo, + longTermRepo: longTermRepo, + } +} + +// SaveContext 保存短期上下文 +func (s *MemoryService) SaveContext(ctx context.Context, sessionID string, messages []model.Message, personaID string) error { + return s.shortTermRepo.SaveContext(ctx, sessionID, messages, personaID) +} + +// GetContext 获取短期上下文 +func (s *MemoryService) GetContext(ctx context.Context, sessionID string) ([]model.Message, error) { + return s.shortTermRepo.GetContext(ctx, sessionID) +} + +// RecallMemories 召回相关记忆 +func (s *MemoryService) RecallMemories(ctx context.Context, userID int64, userInput string, limit int) (string, error) { + // 从用户输入提取关键词 + keywords := extractKeywords(userInput) + + // 查询长期记忆 + memories, err := s.longTermRepo.GetMemories(ctx, userID, keywords, limit) + if err != nil { + return "", err + } + + if len(memories) == 0 { + return "", nil + } + + // 组装记忆文本 + var builder strings.Builder + builder.WriteString("# 用户核心记忆\n") + for _, m := range memories { + builder.WriteString("- ") + builder.WriteString(m.Content) + builder.WriteString("\n") + } + + return builder.String(), nil +} + +// ExtractMemory 提取记忆(每5轮对话触发一次) +func (s *MemoryService) ExtractMemory(ctx context.Context, userID int64, recentMessages []model.Message) error { + // 从最近的用户消息中提取关键词 + var userMessages []string + for i := len(recentMessages) - 1; i >= 0 && len(userMessages) < 5; i-- { + if recentMessages[i].Role == "user" { + userMessages = append(userMessages, recentMessages[i].Content) + } + } + + if len(userMessages) == 0 { + return nil + } + + // 简单关键词提取 + keywords := extractKeywordsFromMessages(userMessages) + + // 生成记忆摘要 + content := summarizeMessages(userMessages) + if content == "" { + return nil + } + + // 保存到长期记忆 + memory := &model.UserMemory{ + UserID: userID, + Content: content, + Keywords: keywords, + Weight: 50, + } + + return s.longTermRepo.SaveMemory(ctx, memory) +} + +// GetUserMemories 获取用户所有长期记忆 +func (s *MemoryService) GetUserMemories(ctx context.Context, userID int64) ([]model.UserMemory, error) { + return s.longTermRepo.GetMemoriesByUserID(ctx, userID) +} + +// extractKeywords 从用户输入提取关键词 +func extractKeywords(text string) []string { + // 简单实现:按标点符号分割,检查长度大于2的词 + var keywords []string + words := strings.FieldsFunc(text, func(r rune) bool { + return r == ',' || r == '.' || r == '!' || r == '?' || r == ',' || r == '。' || r == '!' || r == '?' || r == ' ' || r == '\n' + }) + + for _, word := range words { + word = strings.TrimSpace(word) + if len(word) >= 2 { + keywords = append(keywords, word) + } + } + + return keywords +} + +// extractKeywordsFromMessages 从多条消息提取关键词 +func extractKeywordsFromMessages(messages []string) []string { + var allKeywords []string + seen := make(map[string]bool) + + for _, msg := range messages { + keywords := extractKeywords(msg) + for _, k := range keywords { + if !seen[k] { + seen[k] = true + allKeywords = append(allKeywords, k) + } + } + } + + // 只保留前5个关键词 + if len(allKeywords) > 5 { + allKeywords = allKeywords[:5] + } + + return allKeywords +} + +// summarizeMessages 生成记忆摘要 +func summarizeMessages(messages []string) string { + if len(messages) == 0 { + return "" + } + + // 简单实现:拼接前3条消息的核心内容 + var summary strings.Builder + count := 0 + for _, msg := range messages { + if count >= 3 { + break + } + // 截断过长的消息 + if len(msg) > 100 { + msg = msg[:100] + "..." + } + summary.WriteString(msg) + summary.WriteString("; ") + count++ + } + + result := strings.TrimSpace(summary.String()) + if len(result) > 200 { + result = result[:200] + "..." + } + + return result +} + +// FormatMemoriesForPrompt 将记忆格式化为 prompt 片段 +func FormatMemoriesForPrompt(memories []model.UserMemory) string { + if len(memories) == 0 { + return "" + } + + var builder strings.Builder + builder.WriteString("# 用户核心记忆\n") + for _, m := range memories { + builder.WriteString("- ") + builder.WriteString(m.Content) + builder.WriteString("\n") + } + + return builder.String() +} + +// GetTurnCount 计算对话轮数 +func GetTurnCount(messages []model.Message) int { + turns := 0 + for i := 0; i < len(messages)-1; i += 2 { + if i+1 < len(messages) && messages[i].Role == "user" && messages[i+1].Role == "assistant" { + turns++ + } + } + return turns +} + +// ShouldExtractMemory 判断是否应该触发记忆提取 +func ShouldExtractMemory(messages []model.Message, triggerTurns int) bool { + turns := GetTurnCount(messages) + logger.Logger.Info("ShouldExtractMemory check", + zap.Int("turns", turns), + zap.Int("trigger_turns", triggerTurns), + zap.Bool("should_extract", turns >= triggerTurns), + ) + return turns >= triggerTurns +} \ No newline at end of file diff --git a/backend/services/aiChatService/service/persona_service.go b/backend/services/aiChatService/service/persona_service.go new file mode 100644 index 0000000..61d85a6 --- /dev/null +++ b/backend/services/aiChatService/service/persona_service.go @@ -0,0 +1,68 @@ +package service + +import ( + "context" + + "github.com/google/uuid" + "github.com/topfans/backend/services/aiChatService/model" + "github.com/topfans/backend/services/aiChatService/repository" +) + +// PersonaService 人设服务 +type PersonaService struct { + personaRepo repository.PersonaRepository +} + +// NewPersonaService 创建人设服务 +func NewPersonaService(personaRepo repository.PersonaRepository) *PersonaService { + return &PersonaService{ + personaRepo: personaRepo, + } +} + +// GetPersonas 获取用户的所有人设 +func (s *PersonaService) GetPersonas(ctx context.Context, userID int64) ([]model.PersonaInfo, error) { + personas, err := s.personaRepo.GetByUserID(ctx, userID) + if err != nil { + return nil, err + } + + infos := make([]model.PersonaInfo, len(personas)) + for i, p := range personas { + infos[i] = p.ToPersonaInfo() + } + return infos, nil +} + +// GetPersona 获取指定人设 +func (s *PersonaService) GetPersona(ctx context.Context, personaID string) (*model.Persona, error) { + id, err := uuid.Parse(personaID) + if err != nil { + return nil, model.ErrInvalidRequest + } + return s.personaRepo.GetByID(ctx, id) +} + +// GetDefaultPersona 获取用户默认人设 +func (s *PersonaService) GetDefaultPersona(ctx context.Context, userID int64) (*model.Persona, error) { + return s.personaRepo.GetDefaultByUserID(ctx, userID) +} + +// GetPersonaOrDefault 获取人设,如果 personaID 为空则获取默认人设 +func (s *PersonaService) GetPersonaOrDefault(ctx context.Context, userID int64, personaID string) (*model.Persona, error) { + // 如果指定了 personaID,优先使用指定人设 + if personaID != "" { + id, err := uuid.Parse(personaID) + if err != nil { + return nil, model.ErrInvalidRequest + } + persona, err := s.personaRepo.GetByID(ctx, id) + if err == nil { + return persona, nil + } + // 如果人设不存在,fallback 到默认人设 + } + + // 获取或创建默认人设 + return s.personaRepo.EnsureDefaultPersona(ctx, userID) +} \ No newline at end of file diff --git a/backend/services/aiChatService/service/prompt_builder.go b/backend/services/aiChatService/service/prompt_builder.go new file mode 100644 index 0000000..5562b37 --- /dev/null +++ b/backend/services/aiChatService/service/prompt_builder.go @@ -0,0 +1,180 @@ +package service + +import ( + "github.com/topfans/backend/services/aiChatService/model" +) + +// Token 限制配置 +const ( + MaxTotalTokens = 32000 // 总 Token 上限 + MaxHistoryTokens = 24000 // 对话历史最大 Token + MaxSystemTokens = 4000 // System Prompt 最大 Token + MaxMemoryTokens = 2000 // 记忆召回最大 Token + ReservedTokens = 2000 // 保留空间 + MinHistoryMessages = 4 // 最少保留消息对数 + MaxSingleMessageTokens = 4000 // 单条消息最大 Token +) + +// Tokenizer Token 计算器 +type Tokenizer struct{} + +// EstimateTokens 估算 Token 数量 +func (t *Tokenizer) EstimateTokens(text string) int { + var count int + for _, r := range text { + switch { + case r >= 0x4e00 && r <= 0x9fff: // 中文 + count += 2 + case r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z': // 英文 + count += 1 + case r >= '0' && r <= '9': + count += 1 + case r < 128: // ASCII 符号 + count += 1 + default: + count += 2 + } + } + return count +} + +// EstimateMessagesTokens 估算消息列表的总 Token +func (t *Tokenizer) EstimateMessagesTokens(messages []model.Message) int { + var total int + for _, m := range messages { + total += t.EstimateTokens(m.Role) + t.EstimateTokens(m.Content) + 10 + } + return total +} + +// BuildPrompt 组装 Prompt +func BuildPrompt( + systemPrompt string, + userCoreInfo string, + history []model.Message, + userInput string, + tokenizer *Tokenizer, +) ([]model.Message, int) { + // 1. 计算各部分 Token + systemTokens := tokenizer.EstimateTokens(systemPrompt) + memoryTokens := tokenizer.EstimateTokens(userCoreInfo) + + // 2. 预留空间计算 + reserved := ReservedTokens + if systemTokens > MaxSystemTokens { + reserved += systemTokens - MaxSystemTokens + } + + // 3. 计算可用于对话历史的 Token + availableTokens := MaxHistoryTokens - memoryTokens - reserved + if availableTokens < 500 { + availableTokens = 500 + } + + // 4. 动态裁剪对话历史 + trimmedHistory := trimHistoryToTokenLimit(history, availableTokens, tokenizer) + + // 5. 组装最终消息 + messages := []model.Message{ + {Role: "system", Content: systemPrompt}, + } + + if userCoreInfo != "" { + messages = append(messages, model.Message{ + Role: "system", + Content: "# 用户核心记忆\n" + userCoreInfo, + }) + } + + messages = append(messages, trimmedHistory...) + messages = append(messages, model.Message{Role: "user", Content: userInput}) + + // 6. 最终 Token 统计 + totalTokens := tokenizer.EstimateMessagesTokens(messages) + + return messages, totalTokens +} + +// trimHistoryToTokenLimit 裁剪历史消息至 Token 限制内 +func trimHistoryToTokenLimit(history []model.Message, maxTokens int, tokenizer *Tokenizer) []model.Message { + if len(history) == 0 { + return history + } + + currentTokens := tokenizer.EstimateMessagesTokens(history) + if currentTokens <= maxTokens { + return history + } + + result := make([]model.Message, 0) + var usedTokens int + + // 从最新开始保留 + for i := len(history) - 1; i >= 0; i -= 2 { + msgToken := tokenizer.EstimateTokens(history[i].Content) + 10 + prevToken := 0 + if i > 0 { + prevToken = tokenizer.EstimateTokens(history[i-1].Content) + 10 + } + + pairTokens := msgToken + prevToken + + if len(result)/2 >= MinHistoryMessages && usedTokens+pairTokens > maxTokens { + break + } + + if i > 0 { + result = append([]model.Message{history[i-1], history[i]}, result...) + } else { + result = append([]model.Message{history[i]}, result...) + } + usedTokens += pairTokens + } + + return result +} + +// truncateMessageIfNeeded 截断超长单条消息 +func truncateMessageIfNeeded(content string, maxTokens int, tokenizer *Tokenizer) string { + if tokenizer.EstimateTokens(content) <= maxTokens { + return content + } + + runes := []rune(content) + lo, hi := 0, len(runes) + + for lo < hi { + mid := (lo + hi + 1) / 2 + if tokenizer.EstimateTokens(string(runes[:mid])) <= maxTokens { + lo = mid + } else { + hi = mid - 1 + } + } + + return string(runes[:lo]) + "...(已截断)" +} + +// EstimateTurnCount 估算对话轮数 +func EstimateTurnCount(messages []model.Message) int { + turns := 0 + for i := 0; i < len(messages)-1; i += 2 { + if messages[i].Role == "user" && messages[i+1].Role == "assistant" { + turns++ + } + } + return turns +} + +// IsNoNeedLLMCall 判断是否不需要调用大模型 +func IsNoNeedLLMCall(input string) bool { + // 纯符号或数字 + symbolOnly := true + for _, r := range input { + if r != ' ' && r != '\n' && (r < '0' || r > '9') && r != '?' && r != '?' && r != '.' && r != '。' { + symbolOnly = false + break + } + } + return symbolOnly +} \ No newline at end of file diff --git a/frontend/App.vue b/frontend/App.vue index b660703..a920dfa 100644 --- a/frontend/App.vue +++ b/frontend/App.vue @@ -1,13 +1,33 @@ diff --git a/frontend/pages/ai-dazi/index.vue b/frontend/pages/ai-dazi/index.vue index 2ba1a17..235ba70 100644 --- a/frontend/pages/ai-dazi/index.vue +++ b/frontend/pages/ai-dazi/index.vue @@ -39,11 +39,11 @@ - - + + @@ -54,72 +54,305 @@ --> + + + + + {{ msg.content }} + + + {{ msg.content }} + + + + + ... + + + + + - - + + 发送 - -