feat:增加的ai搭子对话功能

This commit is contained in:
zerosaturation 2026-05-28 12:00:19 +08:00
parent 9b0c3ee3a9
commit 17c68c9233
38 changed files with 5125 additions and 194 deletions

13
.gitignore vendored
View File

@ -6,6 +6,19 @@ frontend/pages/.DS_Store
frontend/.hbuilderx/launch.json
.idea
# Backend 构建产物
backend/services/activityService/activityService
backend/services/assetService/assetService
backend/services/galleryService/galleryService
backend/services/socialService/socialService
backend/services/starbookService/starbookService
backend/services/taskService/taskService
backend/services/userService/userService
backend/services/aiChatService/aiChatService
backend/gateway/gateway
backend/gateway/aiChatService
bin/
node_modules
package-lock.json
# Added by code-review-graph

View File

@ -21,7 +21,7 @@ cleanup() {
fi
# 清理所有 PID 文件并杀服务进程
for service in gateway activityService galleryService socialService assetService userService taskService starbookService; do
for service in gateway activityService galleryService socialService assetService userService taskService starbookService aiChatService; do
pkill -9 -f "$service" 2>/dev/null || true
rm -f "/tmp/dev_sh_${service}.pid" "/tmp/dev_sh_${service}_restart" "/tmp/dev_sh_${service}.lock"
echo -e "${YELLOW} 🛑 $service 已停止${NC}"
@ -82,15 +82,16 @@ if [ -f "$ENV_FILE" ]; then
fi
DB_HOST="${DB_HOST:-localhost}"
DB_PORT="${DB_PORT:-5432}"
DB_USER="${DB_USER:-haihuizhu}"
DB_PASSWORD="${DB_PASSWORD:-admin}"
DB_PORT="${DB_PORT:-15432}"
DB_USER="${DB_USER:-postgres}"
DB_PASSWORD="${DB_PASSWORD:-123456}"
DB_NAME="${DB_NAME:-top-fans}"
REDIS_HOST="${REDIS_HOST:-localhost}"
REDIS_PORT="${REDIS_PORT:-6379}"
REDIS_PASSWORD="${REDIS_PASSWORD:-123456}"
REDIS_DB="${REDIS_DB:-0}"
DB_ARGS=(-db-host="$DB_HOST" -db-port="$DB_PORT" -db-user="$DB_USER" -db-password="$DB_PASSWORD" -db-name="$DB_NAME")
REDIS_ARGS=(-redis-host="$REDIS_HOST" -redis-port="$REDIS_PORT" -redis-db="$REDIS_DB")
REDIS_ARGS=(-redis-host="$REDIS_HOST" -redis-port="$REDIS_PORT" -redis-db="$REDIS_DB" -redis-password="$REDIS_PASSWORD")
# 启动一个服务
# 用法: start_service name binary port use_db use_redis
@ -280,7 +281,8 @@ start_watcher() {
--exclude='galleryService$' \
--exclude='activityService$' \
--exclude='taskService$' \
--exclude='starbookService$' &
--exclude='starbookService$' \
--exclude='aiChatService$' &
fi
done
else
@ -295,7 +297,8 @@ start_watcher() {
--exclude='galleryService$' \
--exclude='activityService$' \
--exclude='taskService$' \
--exclude='starbookService$' &
--exclude='starbookService$' \
--exclude='aiChatService$' &
fi
wait
) &
@ -359,13 +362,13 @@ echo ""
> /tmp/dev_sh_watchers.tmp
# 清理残留 PID 文件(上次非正常退出可能留下)
for service in activityService galleryService socialService assetService userService taskService gateway starbookService; do
for service in activityService galleryService socialService assetService userService taskService gateway starbookService aiChatService; do
rm -f "/tmp/dev_sh_${service}.pid" "/tmp/dev_sh_${service}_restart"
done
# 停止现有服务(清理环境)
echo -e "${YELLOW}🛑 停止现有服务...${NC}"
for service in gateway userService socialService assetService galleryService activityService taskService starbookService; do
for service in gateway userService socialService assetService galleryService activityService taskService starbookService aiChatService; do
pkill -9 -f "$service" 2>/dev/null || true
done
sleep 1
@ -396,6 +399,7 @@ build_service "galleryService" "services/galleryService" "services/galleryServic
build_service "activityService" "services/activityService" "services/activityService/activityService"
build_service "taskService" "services/taskService" "services/taskService/taskService"
build_service "starbookService" "services/starbookService" "services/starbookService/starbookService"
build_service "aiChatService" "services/aiChatService" "services/aiChatService/aiChatService"
cd "$SCRIPT_DIR"
# 启动所有服务
@ -413,6 +417,7 @@ echo -e "${GREEN}✅ galleryService 已启动 (PID: $(cat /tmp/dev_sh_gallerySer
start_service "activityService" "services/activityService/activityService" 20005 1 0
start_service "taskService" "services/taskService/taskService" 20006 1 0
start_service "starbookService" "services/starbookService/starbookService" 20007 1 0
start_service "aiChatService" "services/aiChatService/aiChatService" 20008 1 1
start_service "gateway" "gateway/gateway" 8080 0 0
# 启动所有文件监听器
@ -426,6 +431,7 @@ start_watcher "galleryService" "services/galleryService" "services/gallerySer
start_watcher "activityService" "services/activityService" "services/activityService/activityService" 20005 1 0
start_watcher "taskService" "services/taskService" "services/taskService/taskService" 20006 1 0
start_watcher "starbookService" "services/starbookService" "services/starbookService/starbookService" 20007 1 0
start_watcher "aiChatService" "services/aiChatService:pkg/proto" "services/aiChatService/aiChatService" 20008 1 1
echo ""
echo -e "${GREEN}========================================${NC}"

View File

@ -8,12 +8,13 @@ import (
// Config 网关配置
type Config struct {
Server ServerConfig
Dubbo DubboConfig
JWT JWTConfig
OSS OSSConfig
Redis RedisConfig
Root string
Server ServerConfig
Dubbo DubboConfig
JWT JWTConfig
OSS OSSConfig
Redis RedisConfig
WebSocket WebSocketConfig
Root string
}
// RedisConfig Redis 配置
@ -33,12 +34,13 @@ type ServerConfig struct {
// DubboConfig Dubbo 服务配置
type DubboConfig struct {
UserServiceURL string
SocialServiceURL string
AssetServiceURL string
GalleryServiceURL string
ActivityServiceURL string
TaskServiceURL string
StarbookServiceURL string
SocialServiceURL string
AssetServiceURL string
GalleryServiceURL string
ActivityServiceURL string
TaskServiceURL string
StarbookServiceURL string
AIChatServiceURL string
}
// JWTConfig JWT 配置
@ -70,6 +72,11 @@ func (c *OSSConfig) GetUploadDir(uploadType string) string {
}
}
// WebSocketConfig WebSocket 配置
type WebSocketConfig struct {
AIChatPath string // WebSocket 路径,默认 /ws/ai-chat
}
// Load 加载配置
func Load() *Config {
root, _ := os.Getwd()
@ -81,12 +88,13 @@ func Load() *Config {
},
Dubbo: DubboConfig{
UserServiceURL: getEnv("DUBBO_USER_SERVICE_URL", "tri://127.0.0.1:20000"),
SocialServiceURL: getEnv("DUBBO_SOCIAL_SERVICE_URL", "tri://127.0.0.1:20002"),
AssetServiceURL: getEnv("DUBBO_ASSET_SERVICE_URL", "tri://127.0.0.1:20003"),
GalleryServiceURL: getEnv("DUBBO_GALLERY_SERVICE_URL", "tri://127.0.0.1:20004"),
ActivityServiceURL: getEnv("DUBBO_ACTIVITY_SERVICE_URL", "tri://127.0.0.1:20005"),
TaskServiceURL: getEnv("DUBBO_TASK_SERVICE_URL", "tri://127.0.0.1:20006"),
StarbookServiceURL: getEnv("DUBBO_STARBOOK_SERVICE_URL", "tri://127.0.0.1:20007"),
SocialServiceURL: getEnv("DUBBO_SOCIAL_SERVICE_URL", "tri://127.0.0.1:20002"),
AssetServiceURL: getEnv("DUBBO_ASSET_SERVICE_URL", "tri://127.0.0.1:20003"),
GalleryServiceURL: getEnv("DUBBO_GALLERY_SERVICE_URL", "tri://127.0.0.1:20004"),
ActivityServiceURL: getEnv("DUBBO_ACTIVITY_SERVICE_URL", "tri://127.0.0.1:20005"),
TaskServiceURL: getEnv("DUBBO_TASK_SERVICE_URL", "tri://127.0.0.1:20006"),
StarbookServiceURL: getEnv("DUBBO_STARBOOK_SERVICE_URL", "tri://127.0.0.1:20007"),
AIChatServiceURL: getEnv("DUBBO_AI_CHAT_SERVICE_URL", "tri://127.0.0.1:20008"),
},
JWT: JWTConfig{
Secret: getEnv("JWT_SECRET", "topfans-secret-key-please-change-in-production"),
@ -107,6 +115,9 @@ func Load() *Config {
Password: getEnv("REDIS_PASSWORD", "123456"),
DB: getEnvInt("REDIS_DB", 0),
},
WebSocket: WebSocketConfig{
AIChatPath: getEnv("WS_AI_CHAT_PATH", "/ws/ai-chat"),
},
}
}
@ -159,4 +170,4 @@ func (c *Config) Validate() error {
return fmt.Errorf("JWT secret is required")
}
return nil
}
}

View File

@ -0,0 +1,169 @@
package controller
import (
"context"
"net/http"
"strconv"
"dubbo.apache.org/dubbo-go/v3/client"
"dubbo.apache.org/dubbo-go/v3/common/constant"
"github.com/gin-gonic/gin"
"github.com/topfans/backend/pkg/logger"
"github.com/topfans/backend/gateway/pkg/response"
pbAIChat "github.com/topfans/backend/pkg/proto/ai_chat"
"go.uber.org/zap"
)
// AIChatController AI Chat HTTP 控制器
type AIChatController struct {
aiChatClient *client.Client
}
// NewAIChatController 创建 AIChatController
func NewAIChatController(aiChatClient *client.Client) (*AIChatController, error) {
return &AIChatController{
aiChatClient: aiChatClient,
}, nil
}
// GetPersonas 获取人设列表
// @Summary 获取用户的所有人设
// @Tags ai-chat
// @Accept json
// @Produce json
// @Security BearerAuth
// @Success 200 {object} response.Response
// @Router /api/v1/ai-chat/personas [get]
func (c *AIChatController) GetPersonas(ctx *gin.Context) {
userIDVal, exists := ctx.Get("user_id")
if !exists {
response.Error(ctx, http.StatusUnauthorized, "用户未认证")
return
}
uid, _ := userIDVal.(int64)
aiChatService, err := pbAIChat.NewAIChatService(c.aiChatClient)
if err != nil {
logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err))
response.Error(ctx, http.StatusInternalServerError, "服务暂时不可用")
return
}
// 创建带 Attachments 的 context
userIDStr := strconv.FormatInt(uid, 10)
attachments := map[string]interface{}{
"user_id": userIDStr,
}
rpcCtx := context.WithValue(context.Background(), constant.AttachmentKey, attachments)
req := &pbAIChat.GetPersonasRequest{
UserId: uid,
}
resp, err := aiChatService.GetPersonas(rpcCtx, req)
if err != nil {
logger.Logger.Error("GetPersonas call failed", zap.Error(err))
response.Error(ctx, http.StatusInternalServerError, "获取人设失败")
return
}
personas := make([]map[string]interface{}, len(resp.Personas))
for i, p := range resp.Personas {
personas[i] = map[string]interface{}{
"id": p.Id,
"name": p.Name,
"description": p.Description,
"avatar_url": p.AvatarUrl,
"talk_style": p.TalkStyle,
"is_default": p.IsDefault,
"created_at": p.CreatedAt,
"updated_at": p.UpdatedAt,
}
}
response.Success(ctx, map[string]interface{}{
"personas": personas,
})
}
// GetHistory 获取对话历史
// @Summary 获取对话历史
// @Tags ai-chat
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param sessionId path string true "会话ID"
// @Success 200 {object} response.Response
// @Router /api/v1/ai-chat/history/{sessionId} [get]
func (c *AIChatController) GetHistory(ctx *gin.Context) {
sessionID := ctx.Param("sessionId")
if sessionID == "" {
response.Error(ctx, http.StatusBadRequest, "sessionId不能为空")
return
}
aiChatService, err := pbAIChat.NewAIChatService(c.aiChatClient)
if err != nil {
logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err))
response.Error(ctx, http.StatusInternalServerError, "服务暂时不可用")
return
}
req := &pbAIChat.ChatHistoryRequest{
SessionId: sessionID,
Limit: 20,
}
// 获取 user_id 和 star_id 用于 Attachments
userIDVal, exists := ctx.Get("user_id")
if !exists {
response.Error(ctx, http.StatusUnauthorized, "用户未认证")
return
}
uid, _ := userIDVal.(int64)
starIDVal, _ := ctx.Get("star_id")
sid, _ := starIDVal.(int64)
// 创建带 Attachments 的 context
userIDStr := strconv.FormatInt(uid, 10)
starIDStr := strconv.FormatInt(sid, 10)
attachments := map[string]interface{}{
"user_id": userIDStr,
"star_id": starIDStr,
}
rpcCtx := context.WithValue(context.Background(), constant.AttachmentKey, attachments)
resp, err := aiChatService.GetHistory(rpcCtx, req)
if err != nil {
logger.Logger.Error("GetHistory call failed", zap.Error(err))
response.Error(ctx, http.StatusInternalServerError, "获取历史失败")
return
}
history := make([]map[string]string, len(resp.History))
for i, m := range resp.History {
history[i] = map[string]string{
"role": m.Role,
"content": m.Content,
}
}
response.Success(ctx, map[string]interface{}{
"history": history,
})
}
// extractUserInfoFromContext 从 gin.Context 中提取用户信息
func extractUserInfoFromContext(ctx *gin.Context) (int64, int64, error) {
userID, exists := ctx.Get("user_id")
if !exists {
return 0, 0, nil
}
starID, _ := ctx.Get("star_id")
uid, _ := userID.(int64)
sid, _ := starID.(int64)
return uid, sid, nil
}

View File

@ -146,9 +146,18 @@ func main() {
}
logger.Logger.Info("Starbook Service Dubbo client connected successfully")
// 4.8 AIChatService Client
aiChatClient, err := client.NewClient(
client.WithClientURL(cfg.Dubbo.AIChatServiceURL),
)
if err != nil {
logger.Logger.Fatal("Failed to create AI Chat Service Dubbo client", zap.Error(err))
}
logger.Logger.Info("AI Chat Service Dubbo client connected successfully")
// 5. 设置路由
logger.Logger.Info("Setting up routes...")
r, err := router.SetupRouter(userClient, socialClient, assetClient, galleryClient, activityClient, taskClient, starbookClient)
r, err := router.SetupRouter(userClient, socialClient, assetClient, galleryClient, activityClient, taskClient, starbookClient, aiChatClient, cfg.WebSocket.AIChatPath)
if err != nil {
logger.Logger.Fatal("Failed to setup router", zap.Error(err))
}

View File

@ -1,6 +1,8 @@
package router
import (
"net/http"
"dubbo.apache.org/dubbo-go/v3/client"
"github.com/gin-gonic/gin"
swaggerfiles "github.com/swaggo/files"
@ -8,10 +10,11 @@ import (
_ "github.com/topfans/backend/gateway/docs" // 导入 docs 包以初始化 Swagger
"github.com/topfans/backend/gateway/controller"
"github.com/topfans/backend/gateway/middleware"
"github.com/topfans/backend/gateway/socket"
)
// SetupRouter 设置路由
func SetupRouter(userClient *client.Client, socialClient *client.Client, assetClient *client.Client, galleryClient *client.Client, activityClient *client.Client, taskClient *client.Client, starbookClient *client.Client) (*gin.Engine, error) {
func SetupRouter(userClient *client.Client, socialClient *client.Client, assetClient *client.Client, galleryClient *client.Client, activityClient *client.Client, taskClient *client.Client, starbookClient *client.Client, aiChatClient *client.Client, aiChatPath string) (*gin.Engine, error) {
r := gin.Default()
// 全局中间件
@ -30,6 +33,12 @@ func SetupRouter(userClient *client.Client, socialClient *client.Client, assetCl
swaggerHandler := ginSwagger.WrapHandler(swaggerfiles.Handler)
r.GET("/swagger/*any", swaggerHandler)
// AI Chat WebSocket 路由
aiChatHub := socket.NewHub(aiChatClient, aiChatPath)
r.GET(aiChatPath, gin.WrapF(func(w http.ResponseWriter, r *http.Request) {
aiChatHub.HandleWebSocket(w, r)
}))
// 创建控制器
authCtrl, err := controller.NewAuthController(userClient)
if err != nil {
@ -76,6 +85,11 @@ func SetupRouter(userClient *client.Client, socialClient *client.Client, assetCl
return nil, err
}
aiChatCtrl, err := controller.NewAIChatController(aiChatClient)
if err != nil {
return nil, err
}
// API v1 路由组
v1 := r.Group("/api/v1")
{
@ -287,6 +301,14 @@ func SetupRouter(userClient *client.Client, socialClient *client.Client, assetCl
starbook.GET("/home", starbookCtrl.GetStarbookHome) // 获取星册首页
starbook.GET("/items", starbookCtrl.GetStarbookItems) // 获取星册藏品列表
}
// AI Chat 相关路由(需要认证)
aiChat := v1.Group("/ai-chat")
aiChat.Use(middleware.AuthMiddleware())
{
aiChat.GET("/personas", aiChatCtrl.GetPersonas) // 获取人设列表
aiChat.GET("/history/:sessionId", aiChatCtrl.GetHistory) // 获取对话历史
}
}
return r, nil

View File

@ -0,0 +1,478 @@
package socket
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"time"
"dubbo.apache.org/dubbo-go/v3/client"
"dubbo.apache.org/dubbo-go/v3/common/constant"
"github.com/gorilla/websocket"
"github.com/topfans/backend/pkg/jwt"
"github.com/topfans/backend/pkg/logger"
"go.uber.org/zap"
pbAIChat "github.com/topfans/backend/pkg/proto/ai_chat"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // 允许跨域
},
}
// Hub 管理所有 AI Chat WebSocket 连接
type Hub struct {
// 用户连接映射: userId -> *Connection
clients map[int64]*Connection
// Dubbo 客户端
aiChatClient *client.Client
// WebSocket 配置
aiChatPath string
mu sync.RWMutex
}
// Connection WebSocket 连接
type Connection struct {
UserID int64
StarID int64
Conn *websocket.Conn
Send chan []byte
Hub *Hub
writeMu sync.Mutex // 保护并发写操作
}
// writeJSON 线程安全的 WriteJSON 封装
func (c *Connection) writeJSON(data interface{}) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
return c.Conn.WriteJSON(data)
}
// sendError 发送错误消息
func (c *Connection) sendError(code, message string) {
c.Send <- []byte(fmt.Sprintf(`{"type":"error","code":"%s","message":"%s","session_id":"%d_%d"}`,
code, message, c.UserID, c.StarID))
}
// NewHub 创建 Hub 实例
func NewHub(aiChatClient *client.Client, aiChatPath string) *Hub {
return &Hub{
clients: make(map[int64]*Connection),
aiChatClient: aiChatClient,
aiChatPath: aiChatPath,
}
}
// HandleWebSocket 处理 WebSocket 连接
func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
// 从 URL 参数获取 token
token := r.URL.Query().Get("token")
if token == "" {
// 尝试从 query 参数获取URL 编码后)
token = r.URL.Query().Get("token")
}
// 解析 token 获取用户信息
userID, starID, err := h.validateToken(token)
if err != nil {
logger.Logger.Error("WebSocket token validation failed", zap.Error(err))
// 返回 401 Unauthorized而不是尝试升级 WebSocket
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(map[string]interface{}{
"type": "auth_response",
"success": false,
"error": "invalid_token",
})
return
}
// 升级为 WebSocket 连接
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logger.Logger.Error("WebSocket upgrade failed", zap.Error(err))
return
}
connection := &Connection{
UserID: userID,
StarID: starID,
Conn: conn,
Send: make(chan []byte, 256),
Hub: h,
}
// 注册连接
h.mu.Lock()
h.clients[userID] = connection
h.mu.Unlock()
logger.Logger.Info("WebSocket connection established",
zap.Int64("user_id", userID),
zap.Int64("star_id", starID),
)
// 发送鉴权成功响应
authResp := map[string]interface{}{
"type": "auth_response",
"success": true,
"user_id": userID,
"star_id": starID,
}
conn.WriteJSON(authResp)
// 启动读 goroutine
go connection.readPump()
// 启动写 goroutine
go connection.writePump()
}
// validateToken 验证 token使用 JWT 解析)
func (h *Hub) validateToken(token string) (int64, int64, error) {
// 去掉 "Bearer_" 前缀
if strings.HasPrefix(token, "Bearer_") {
token = strings.TrimPrefix(token, "Bearer_")
}
// 解析 JWT token
claims, err := jwt.ParseToken(token)
if err != nil {
return 0, 0, fmt.Errorf("failed to parse token: %w", err)
}
userID := claims.UserID
starID := claims.StarID
if userID == 0 {
return 0, 0, errors.New("invalid user id")
}
return userID, starID, nil
}
// readPump 读取客户端消息
func (c *Connection) readPump() {
defer func() {
c.Hub.mu.Lock()
delete(c.Hub.clients, c.UserID)
c.Hub.mu.Unlock()
c.Conn.Close()
}()
c.Conn.SetReadLimit(512 * 1024) // 512KB
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
c.Conn.SetPongHandler(func(string) error {
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
return nil
})
for {
_, message, err := c.Conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
logger.Logger.Error("WebSocket read error", zap.Error(err))
}
break
}
// 解析消息
var msg map[string]interface{}
if err := json.Unmarshal(message, &msg); err != nil {
logger.Logger.Error("Failed to parse message", zap.Error(err))
continue
}
// 处理消息
c.handleMessage(msg)
}
}
// writePump 写消息到客户端
func (c *Connection) writePump() {
ticker := time.NewTicker(30 * time.Second)
defer func() {
ticker.Stop()
c.Conn.Close()
}()
for {
select {
case message, ok := <-c.Send:
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if !ok {
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
w, err := c.Conn.NextWriter(websocket.TextMessage)
if err != nil {
return
}
w.Write(message)
if err := w.Close(); err != nil {
return
}
case <-ticker.C:
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}
}
// handleMessage 处理客户端消息
func (c *Connection) handleMessage(msg map[string]interface{}) {
action, _ := msg["action"].(string)
sessionID, _ := msg["session_id"].(string)
message, _ := msg["message"].(string)
personaID, _ := msg["persona_id"].(string)
session := sessionID
if session == "" {
session = fmt.Sprintf("%d_%d", c.UserID, c.StarID)
}
switch action {
case "ping":
// 心跳 - 通过 Send 通道发送,避免并发写
c.Send <- []byte(`{"type":"pong"}`)
case "init_session":
// 初始化会话,返回欢迎消息
go c.initSession(session)
case "send_message":
// 发送消息
go c.sendMessage(session, message, personaID)
case "get_history":
// 获取历史
go c.getHistory(session)
case "get_personas":
// 获取人设列表
go c.getPersonas()
default:
logger.Logger.Warn("Unknown action", zap.String("action", action))
}
}
// initSession 初始化会话,返回欢迎消息
func (c *Connection) initSession(sessionID string) {
aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient)
if err != nil {
logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err))
c.sendError("SERVICE_ERROR", "服务暂时不可用")
return
}
// 创建带 Attachments 的 context
userIDStr := strconv.FormatInt(c.UserID, 10)
starIDStr := strconv.FormatInt(c.StarID, 10)
attachments := map[string]interface{}{
"user_id": userIDStr,
"star_id": starIDStr,
}
ctx := context.WithValue(context.Background(), constant.AttachmentKey, attachments)
req := &pbAIChat.InitSessionRequest{
SessionId: sessionID,
UserId: c.UserID,
}
resp, err := aiChatService.InitSession(ctx, req)
if err != nil {
logger.Logger.Error("InitSession call failed", zap.Error(err))
c.sendError("SERVICE_ERROR", "初始化会话失败")
return
}
// 发送欢迎消息
c.writeJSON(map[string]interface{}{
"type": "message",
"session_id": sessionID,
"content": resp.WelcomeMessage,
"is_end": true,
})
}
// sendMessage 发送消息到 AI Chat Service
func (c *Connection) sendMessage(sessionID, message, personaID string) {
// 创建 Dubbo 客户端调用
aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient)
if err != nil {
logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err))
c.sendError("SERVICE_ERROR", "服务暂时不可用")
return
}
// 创建带 Attachments 的 context
userIDStr := strconv.FormatInt(c.UserID, 10)
starIDStr := strconv.FormatInt(c.StarID, 10)
attachments := map[string]interface{}{
"user_id": userIDStr,
"star_id": starIDStr,
}
ctx := context.WithValue(context.Background(), constant.AttachmentKey, attachments)
req := &pbAIChat.ChatMessageRequest{
SessionId: sessionID,
Message: message,
PersonaId: personaID,
UserId: c.UserID,
}
// 流式调用
stream, err := aiChatService.SendMessage(ctx, req)
if err != nil {
logger.Logger.Error("SendMessage call failed", zap.Error(err))
c.sendError("SERVICE_ERROR", "消息发送失败")
return
}
// 接收流式响应
for {
if !stream.Recv() {
break
}
resp := stream.Msg()
// 发送消息到客户端
respMsg := map[string]interface{}{
"type": "message",
"content": resp.Content,
"session_id": resp.SessionId,
"is_end": resp.IsEnd,
}
if resp.Error != "" {
respMsg["error"] = resp.Error
}
if err := c.writeJSON(respMsg); err != nil {
logger.Logger.Error("Failed to send message to client", zap.Error(err))
break
}
if resp.IsEnd {
break
}
}
}
// getHistory 获取对话历史
func (c *Connection) getHistory(sessionID string) {
aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient)
if err != nil {
logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err))
c.sendError("SERVICE_ERROR", "服务暂时不可用")
return
}
// 创建带 Attachments 的 context
userIDStr := strconv.FormatInt(c.UserID, 10)
starIDStr := strconv.FormatInt(c.StarID, 10)
attachments := map[string]interface{}{
"user_id": userIDStr,
"star_id": starIDStr,
}
ctx := context.WithValue(context.Background(), constant.AttachmentKey, attachments)
req := &pbAIChat.ChatHistoryRequest{
SessionId: sessionID,
Limit: 20,
}
resp, err := aiChatService.GetHistory(ctx, req)
if err != nil {
logger.Logger.Error("GetHistory call failed", zap.Error(err))
c.sendError("SERVICE_ERROR", "获取历史失败")
return
}
// 转换历史消息
history := make([]map[string]string, len(resp.History))
for i, m := range resp.History {
history[i] = map[string]string{
"role": m.Role,
"content": m.Content,
}
}
// 发送响应
c.writeJSON(map[string]interface{}{
"type": "history_response",
"session_id": sessionID,
"history": history,
})
}
// getPersonas 获取人设列表
func (c *Connection) getPersonas() {
aiChatService, err := pbAIChat.NewAIChatService(c.Hub.aiChatClient)
if err != nil {
logger.Logger.Error("Failed to create AI Chat Service client", zap.Error(err))
c.sendError("SERVICE_ERROR", "服务暂时不可用")
return
}
ctx := context.Background()
req := &pbAIChat.GetPersonasRequest{
UserId: c.UserID,
}
resp, err := aiChatService.GetPersonas(ctx, req)
if err != nil {
logger.Logger.Error("GetPersonas call failed", zap.Error(err))
c.sendError("SERVICE_ERROR", "获取人设失败")
return
}
// 转换人设信息
personas := make([]map[string]interface{}, len(resp.Personas))
for i, p := range resp.Personas {
personas[i] = map[string]interface{}{
"id": p.Id,
"name": p.Name,
"description": p.Description,
"avatar_url": p.AvatarUrl,
"talk_style": p.TalkStyle,
"is_default": p.IsDefault,
"created_at": p.CreatedAt,
"updated_at": p.UpdatedAt,
}
}
// 发送响应
c.writeJSON(map[string]interface{}{
"type": "personas_response",
"personas": personas,
})
}
// sendError 发送错误消息 (使用 Send 通道)
// Close 关闭连接
func (h *Hub) Close() {
h.mu.Lock()
defer h.mu.Unlock()
for _, conn := range h.clients {
conn.Conn.Close()
}
}

View File

@ -54,6 +54,7 @@ require (
github.com/k0kubun/pp v3.0.1+incompatible // indirect
github.com/knadh/koanf v1.5.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lib/pq v1.12.3 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.5 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect

View File

@ -470,6 +470,8 @@ github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjS
github.com/lestrrat/go-envload v0.0.0-20180220120943-6ed08b54a570/go.mod h1:BLt8L9ld7wVsvEWQbuLrUZnCMnUmLZ+CGDzKtclrTlE=
github.com/lestrrat/go-file-rotatelogs v0.0.0-20180223000712-d3151e2a480f/go.mod h1:UGmTpUd3rjbtfIpwAPrcfmGf/Z1HS95TATB+m57TPB8=
github.com/lestrrat/go-strftime v0.0.0-20180220042222-ba3bf9c1d042/go.mod h1:TPpsiPUEh0zFL1Snz4crhMlBe60PYxRHr5oFF3rRYg0=
github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ=
github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM=
github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=

View File

@ -8,4 +8,5 @@ use (
./services/galleryService
./services/socialService
./services/userService
./services/aiChatService
)

View File

@ -26,39 +26,11 @@ github.com/alecthomas/kingpin/v2 v2.4.0/go.mod h1:0gyi0zQnjuFk8xrkNKamJoyUo382HR
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafoB+tBA3gMyHYHrpOtNuDiK/uB5uXxq5wM=
github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 h1:s6gZFSlWYmbqAuRjVTiNNhvNRfY2Wxp9nhfyel4rklc=
github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE=
github.com/alibabacloud-go/alibabacloud-gateway-pop v0.0.6/go.mod h1:4EUIoxs/do24zMOGGqYVWgw0s9NtiylnJglOeEB5UJo=
github.com/alibabacloud-go/alibabacloud-gateway-spi v0.0.4/go.mod h1:sCavSAvdzOjul4cEqeVtvlSaSScfNsTQ+46HwlTL1hc=
github.com/alibabacloud-go/alibabacloud-gateway-spi v0.0.5 h1:zE8vH9C7JiZLNJJQ5OwjU9mSi4T9ef9u3BURT6LCLC8=
github.com/alibabacloud-go/alibabacloud-gateway-spi v0.0.5/go.mod h1:tWnyE9AjF8J8qqLk645oUmVUnFybApTQWklQmi5tY6g=
github.com/alibabacloud-go/darabonba-array v0.1.0/go.mod h1:BLKxr0brnggqOJPqT09DFJ8g3fsDshapUD3C3aOEFaI=
github.com/alibabacloud-go/darabonba-encode-util v0.0.2/go.mod h1:JiW9higWHYXm7F4PKuMgEUETNZasrDM6vqVr/Can7H8=
github.com/alibabacloud-go/darabonba-map v0.0.2/go.mod h1:28AJaX8FOE/ym8OUFWga+MtEzBunJwQGceGQlvaPGPc=
github.com/alibabacloud-go/darabonba-openapi/v2 v2.0.8/go.mod h1:CzQnh+94WDnJOnKZH5YRyouL+OOcdBnXY5VWAf0McgI=
github.com/alibabacloud-go/darabonba-openapi/v2 v2.1.13 h1:Q00FU3H94Ts0ZIHDmY+fYGgB7dV9D/YX6FGsgorQPgw=
github.com/alibabacloud-go/darabonba-openapi/v2 v2.1.13/go.mod h1:lxFGfobinVsQ49ntjpgWghXmIF0/Sm4+wvBJ1h5RtaE=
github.com/alibabacloud-go/darabonba-signature-util v0.0.7/go.mod h1:oUzCYV2fcCH797xKdL6BDH8ADIHlzrtKVjeRtunBNTQ=
github.com/alibabacloud-go/darabonba-string v1.0.2/go.mod h1:93cTfV3vuPhhEwGGpKKqhVW4jLe7tDpo3LUM0i0g6mA=
github.com/alibabacloud-go/dysmsapi-20180501/v2 v2.0.8 h1:aDPyz6C+nenypx24N5qEt09NjpS6mu7Cu1A+wf9UTaY=
github.com/alibabacloud-go/dysmsapi-20180501/v2 v2.0.8/go.mod h1:e/vWJ5gLVnraPROSh+3oMSodf5ukaUlqNgH0IIcnz98=
github.com/alibabacloud-go/endpoint-util v1.1.0/go.mod h1:O5FuCALmCKs2Ff7JFJMudHs0I5EBgecXXxZRyswlEjE=
github.com/alibabacloud-go/openapi-util v0.1.0/go.mod h1:sQuElr4ywwFRlCCberQwKRFhRzIyG4QTP/P4y1CJ6Ws=
github.com/alibabacloud-go/tea v1.1.7/go.mod h1:/tmnEaQMyb4Ky1/5D+SE1BAsa5zj/KeGOFfwYm3N/p4=
github.com/alibabacloud-go/tea v1.1.8/go.mod h1:/tmnEaQMyb4Ky1/5D+SE1BAsa5zj/KeGOFfwYm3N/p4=
github.com/alibabacloud-go/tea v1.1.11/go.mod h1:/tmnEaQMyb4Ky1/5D+SE1BAsa5zj/KeGOFfwYm3N/p4=
github.com/alibabacloud-go/tea v1.1.20/go.mod h1:nXxjm6CIFkBhwW4FQkNrolwbfon8Svy6cujmKFUq98A=
github.com/alibabacloud-go/tea v1.2.1/go.mod h1:qbzof29bM/IFhLMtJPrgTGK3eauV5J2wSyEUo4OEmnA=
github.com/alibabacloud-go/tea v1.3.13 h1:WhGy6LIXaMbBM6VBYcsDCz6K/TPsT1Ri2hPmmZffZ94=
github.com/alibabacloud-go/tea v1.3.13/go.mod h1:A560v/JTQ1n5zklt2BEpurJzZTI8TUT+Psg2drWlxRg=
github.com/alibabacloud-go/tea-utils v1.3.1/go.mod h1:EI/o33aBfj3hETm4RLiAxF/ThQdSngxrpF8rKUDJjPE=
github.com/alibabacloud-go/tea-utils/v2 v2.0.5/go.mod h1:dL6vbUT35E4F4bFTHL845eUloqaerYBYPsdWR2/jhe4=
github.com/alibabacloud-go/tea-utils/v2 v2.0.7 h1:WDx5qW3Xa5ZgJ1c8NfqJkF6w+AU5wB8835UdhPr6Ax0=
github.com/alibabacloud-go/tea-utils/v2 v2.0.7/go.mod h1:qxn986l+q33J5VkialKMqT/TTs3E+U9MJpd001iWQ9I=
github.com/alibabacloud-go/tea-utils/v2 v2.0.8/go.mod h1:qxn986l+q33J5VkialKMqT/TTs3E+U9MJpd001iWQ9I=
github.com/alibabacloud-go/tea-xml v1.1.3/go.mod h1:Rq08vgCcCAjHyRi/M7xlHKUykZCEtyBy9+DPF6GgEu8=
github.com/aliyun/credentials-go v1.1.2/go.mod h1:ozcZaMR5kLM7pwtCMEpVmQ242suV6qTJya2bDq4X1Tw=
github.com/aliyun/credentials-go v1.3.1/go.mod h1:8jKYhQuDawt8x2+fusqa1Y6mPxemTsBEN04dgcAcYz0=
github.com/aliyun/credentials-go v1.3.6/go.mod h1:1LxUuX7L5YrZUWzBrRyk0SwSdH4OmPrib8NVePL3fxM=
github.com/aliyun/credentials-go v1.4.5/go.mod h1:Jm6d+xIgwJVLVWT561vy67ZRP4lPTQxMbEYRuT2Ti1U=
github.com/antihax/optional v1.0.0 h1:xK2lYat7ZLaVVcIuj82J8kIro4V6kDe0AUDFboUCwcg=
github.com/apache/thrift v0.13.0 h1:5hryIiq9gtn+MiLVn0wP37kb/uTeRZgN08WoCsAhIhI=
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e h1:QEF07wC0T1rKkctt1RINW/+RMTVmiwxETico2l3gxJA=
@ -81,8 +53,6 @@ github.com/aws/smithy-go v1.8.0 h1:AEwwwXQZtUwP5Mz506FeXXrKBe0jA8gVM+1gEcSRooc=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQkY=
github.com/bketelsen/crypt v0.0.4 h1:w/jqZtC9YD4DS/Vp9GhWfWcCpuAL58oTnLoI8vE9YHU=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/casbin/casbin/v2 v2.1.2 h1:bTwon/ECRx9dwBy2ewRVr5OiqjeXSGiTUY74sDPQi/g=
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
github.com/census-instrumentation/opencensus-proto v0.2.1 h1:glEXhBS5PSLLv4IXzLA5yPRVX4bilULVyxxbrfOtDAk=
@ -94,8 +64,6 @@ github.com/chzyer/logex v1.2.0 h1:+eqR0HfOetur4tgnC8ftU5imRnhi4te+BadWS95c5AM=
github.com/chzyer/readline v1.5.0 h1:lSwwFrbNviGePhkewF1az4oLmcwqCZijQ2/Wi3BGHAI=
github.com/chzyer/test v0.0.0-20210722231415-061457976a23 h1:dZ0/VyGgQdVGAss6Ju0dt5P0QltE0SFY5Woh6hbIfiQ=
github.com/clbanning/mxj/v2 v2.5.5/go.mod h1:hNiWqW14h+kc+MdF9C6/YoRfjEJoR3ou6tn/Qo+ve2s=
github.com/clbanning/mxj/v2 v2.7.0 h1:WA/La7UGCanFe5NpHF0Q3DNtnCsVoxbPKuyBNHWRyME=
github.com/clbanning/mxj/v2 v2.7.0/go.mod h1:hNiWqW14h+kc+MdF9C6/YoRfjEJoR3ou6tn/Qo+ve2s=
github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec h1:EdRZT3IeKQmfCSrgo8SZ8V3MEnskuJP0wCYNpe+aiXo=
github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI=
github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4 h1:hzAQntlaYRkVSFEfj9OTWlVV1H155FMD8BTKktLv0QI=
@ -112,6 +80,7 @@ github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f h1:lBNOc5arjvs8E5mO2tbp
github.com/cpuguy83/go-md2man/v2 v2.0.0 h1:EoUDS0afbrsXAZ9YQ9jdu/mZ2sXgT1/2yyNng4PGlyM=
github.com/creack/pty v1.1.11 h1:07n33Z8lZxZ2qwegKbObQohDhXDQxiMMz1NOUGYlesw=
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954 h1:RMLoZVzv4GliuWafOuPuQDKSm1SJph7uCRnnS61JAn4=
github.com/dubbogo/jsonparser v1.0.1 h1:sAIr8gk+gkahkIm6CnUxh9wTCkbgwLEQ8dTXTnAXyzo=
github.com/dubbogo/net v0.0.4 h1:Rn9aMPZwOiRE22YhtxmDEE3H0Q3cfVRNhuEjNMelJ/8=
@ -174,13 +143,11 @@ github.com/gonum/internal v0.0.0-20181124074243-f884aa714029 h1:8jtTdc+Nfj9AR+0s
github.com/gonum/lapack v0.0.0-20181123203213-e4cdc5a0bff9 h1:7qnwS9+oeSiOIsiUMajT+0R7HR6hw5NegnKPmn/94oI=
github.com/gonum/matrix v0.0.0-20181209220409-c518dec07be9 h1:V2IgdyerlBa/MxaEFRbV5juy/C3MGdj4ePi+g6ePIp4=
github.com/gonum/stat v0.0.0-20181125101827-41a0da705a5b h1:fbskpz/cPqWH8VqkQ7LJghFkl2KPAiIFUHrTJ2O3RGk=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw=
github.com/google/martian v2.1.0+incompatible h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no=
github.com/google/martian/v3 v3.1.0 h1:wCKgOCHuUEVfsaQLpPSJb7VdYCdTVZQAuOdYm1yc/60=
github.com/google/renameio v0.1.0 h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA=
github.com/googleapis/gax-go/v2 v2.0.5 h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM=
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw=
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 h1:MJG/KsmcqMwFAkh8mTnAwhyKoB+sTAnY4CACC110tbU=
@ -306,6 +273,7 @@ github.com/rabbitmq/amqp091-go v1.8.1 h1:RejT1SBUim5doqcL6s7iN6SBmsQqyTgXb1xMlH0
github.com/rabbitmq/amqp091-go v1.8.1/go.mod h1:+jPrT9iY2eLjRaMSRHUhc3z14E/l85kv/f+6luSD3pc=
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/redis/go-redis/v9 v9.5.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
github.com/rhnvrm/simples3 v0.6.1 h1:H0DJwybR6ryQE+Odi9eqkHuzjYAeJgtGcGtuBwOhsH8=
github.com/rogpeppe/fastuuid v1.2.0 h1:Ppwyp6VYCF1nvBTXL3trRso7mXMlRrw9ooo375wvi2s=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
@ -316,18 +284,13 @@ github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da h1:p3Vo3i64TCL
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 h1:nn5Wsu0esKSJiIVhscUtVbo7ada43DJhG55ua/hjS5I=
github.com/shirou/gopsutil v3.20.11+incompatible h1:LJr4ZQK4mPpIV5gOa4jCOKOGb4ty4DZO54I4FGqIpto=
github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo=
github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo=
github.com/sony/gobreaker v0.4.1 h1:oMnRNZXX5j85zso6xCPRNPtmAycat+WcoKbklScLDgQ=
github.com/spf13/cobra v1.1.1 h1:KfztREH0tPxJJ+geloSLaAkaPkr4ki2Er5quFV1TDo4=
github.com/spiffe/go-spiffe/v2 v2.6.0 h1:l+DolpxNWYgruGQVV0xsfeya3CsC7m8iBzDnMpsbLuo=
github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs=
github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271 h1:WhxRHzgeVGETMlmVfqhRn8RIeeNoPr2Czh33I4Zdccw=
github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a h1:AhmOdSHeswKHBjhsLs/7+1voOxT+LLrSk/Nxvk35fug=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/tebeka/strftime v0.1.3 h1:5HQXOqWKYRFfNyBMNVc9z5+QzuBtIXy03psIhtdJYto=
github.com/tjfoc/gmsm v1.3.2/go.mod h1:HaUcFuY0auTiaHB9MHFGCPx5IaLhTUd2atbCFBQXn9w=
github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho=
github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE=
github.com/toolkits/concurrent v0.0.0-20150624120057-a4371d70e3e3 h1:kF/7m/ZU+0D4Jj5eZ41Zm3IH/J8OElK1Qtd7tVKAwLk=
github.com/ugorji/go v1.2.6 h1:tGiWC9HENWE2tqYycIqFTNorMmFRVhNwCpDOpWqnk8E=
github.com/urfave/cli v1.22.1 h1:+mkCCcOFKPnCmVYVcURKps1Xe+3zP90gSYGNfRkjoIY=
@ -335,75 +298,34 @@ github.com/urfave/cli/v2 v2.3.0 h1:qph92Y649prgesehzOrQjdWyxFOp/QVM+6imKHad91M=
github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI=
github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8Ydu2Bstc=
github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU=
github.com/yuin/goldmark v1.1.30/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE=
github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s=
github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M=
go.opentelemetry.io/contrib/detectors/gcp v1.38.0 h1:ZoYbqX7OaA/TAikspPl3ozPI6iY6LiIY9I8cUfm+pJs=
go.opentelemetry.io/contrib/detectors/gcp v1.38.0/go.mod h1:SU+iU7nu5ud4oCb3LQOhIZ3nRLj6FNVrKgtflbaf2ts=
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4=
golang.org/x/crypto v0.0.0-20191219195013-becbf705a915/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4=
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 h1:VLliZ0d+/avPrXXH+OakdXhpJuEoBZuwh1m2j7U6Iug=
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 h1:4+4C/Iv2U4fMZBiMCc98MG1In4gJY5YRhtpDNeDeHWs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY=
golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20200509044756-6aff5f38e54f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/telemetry v0.0.0-20251111182119-bc8e575c7b54 h1:E2/AqCUMZGgd73TQkxUMcMla25GB9i/5HOdLr+uH7Vo=
golang.org/x/telemetry v0.0.0-20251111182119-bc8e575c7b54/go.mod h1:hKdjCMrbv9skySur+Nek8Hd0uJ0GuxJIoIX2payrIdQ=
golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo=
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/tools v0.0.0-20200509030707-2212a7e161a5/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
@ -417,7 +339,6 @@ gopkg.in/cheggaaa/pb.v1 v1.0.25 h1:Ev7yu1/f6+d+b3pi5vPdRPc6nNtP1umSfcWiEfRqv6I=
gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8=
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
gopkg.in/gcfg.v1 v1.2.3 h1:m8OOJ4ccYHnx2f4gQwpno8nAX5OGOh7RLaaz0pj3Ogs=
gopkg.in/ini.v1 v1.56.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/resty.v1 v1.12.0 h1:CuXP0Pjfw9rOuY6EP+UvtNvt5DSqHpIxILZKT/quCZI=
gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=

View File

@ -0,0 +1,726 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.11
// protoc v7.34.0
// source: ai_chat.proto
package ai_chat
import (
_ "google.golang.org/genproto/googleapis/api/annotations"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type Message struct {
state protoimpl.MessageState `protogen:"open.v1"`
Role string `protobuf:"bytes,1,opt,name=role,proto3" json:"role,omitempty"` // "user" / "assistant"
Content string `protobuf:"bytes,2,opt,name=content,proto3" json:"content,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Message) Reset() {
*x = Message{}
mi := &file_ai_chat_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Message) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Message) ProtoMessage() {}
func (x *Message) ProtoReflect() protoreflect.Message {
mi := &file_ai_chat_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Message.ProtoReflect.Descriptor instead.
func (*Message) Descriptor() ([]byte, []int) {
return file_ai_chat_proto_rawDescGZIP(), []int{0}
}
func (x *Message) GetRole() string {
if x != nil {
return x.Role
}
return ""
}
func (x *Message) GetContent() string {
if x != nil {
return x.Content
}
return ""
}
type InitSessionRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
SessionId string `protobuf:"bytes,1,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"`
UserId int64 `protobuf:"varint,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *InitSessionRequest) Reset() {
*x = InitSessionRequest{}
mi := &file_ai_chat_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *InitSessionRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*InitSessionRequest) ProtoMessage() {}
func (x *InitSessionRequest) ProtoReflect() protoreflect.Message {
mi := &file_ai_chat_proto_msgTypes[1]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use InitSessionRequest.ProtoReflect.Descriptor instead.
func (*InitSessionRequest) Descriptor() ([]byte, []int) {
return file_ai_chat_proto_rawDescGZIP(), []int{1}
}
func (x *InitSessionRequest) GetSessionId() string {
if x != nil {
return x.SessionId
}
return ""
}
func (x *InitSessionRequest) GetUserId() int64 {
if x != nil {
return x.UserId
}
return 0
}
type InitSessionResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
WelcomeMessage string `protobuf:"bytes,1,opt,name=welcome_message,json=welcomeMessage,proto3" json:"welcome_message,omitempty"`
SessionId string `protobuf:"bytes,2,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *InitSessionResponse) Reset() {
*x = InitSessionResponse{}
mi := &file_ai_chat_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *InitSessionResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*InitSessionResponse) ProtoMessage() {}
func (x *InitSessionResponse) ProtoReflect() protoreflect.Message {
mi := &file_ai_chat_proto_msgTypes[2]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use InitSessionResponse.ProtoReflect.Descriptor instead.
func (*InitSessionResponse) Descriptor() ([]byte, []int) {
return file_ai_chat_proto_rawDescGZIP(), []int{2}
}
func (x *InitSessionResponse) GetWelcomeMessage() string {
if x != nil {
return x.WelcomeMessage
}
return ""
}
func (x *InitSessionResponse) GetSessionId() string {
if x != nil {
return x.SessionId
}
return ""
}
type ChatMessageRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
SessionId string `protobuf:"bytes,1,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"`
Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"`
PersonaId string `protobuf:"bytes,3,opt,name=persona_id,json=personaId,proto3" json:"persona_id,omitempty"` // 可选,空则用默认人设
UserId int64 `protobuf:"varint,4,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ChatMessageRequest) Reset() {
*x = ChatMessageRequest{}
mi := &file_ai_chat_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ChatMessageRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ChatMessageRequest) ProtoMessage() {}
func (x *ChatMessageRequest) ProtoReflect() protoreflect.Message {
mi := &file_ai_chat_proto_msgTypes[3]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ChatMessageRequest.ProtoReflect.Descriptor instead.
func (*ChatMessageRequest) Descriptor() ([]byte, []int) {
return file_ai_chat_proto_rawDescGZIP(), []int{3}
}
func (x *ChatMessageRequest) GetSessionId() string {
if x != nil {
return x.SessionId
}
return ""
}
func (x *ChatMessageRequest) GetMessage() string {
if x != nil {
return x.Message
}
return ""
}
func (x *ChatMessageRequest) GetPersonaId() string {
if x != nil {
return x.PersonaId
}
return ""
}
func (x *ChatMessageRequest) GetUserId() int64 {
if x != nil {
return x.UserId
}
return 0
}
type ChatMessageResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Content string `protobuf:"bytes,1,opt,name=content,proto3" json:"content,omitempty"`
SessionId string `protobuf:"bytes,2,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"`
IsEnd bool `protobuf:"varint,3,opt,name=is_end,json=isEnd,proto3" json:"is_end,omitempty"`
Error string `protobuf:"bytes,4,opt,name=error,proto3" json:"error,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ChatMessageResponse) Reset() {
*x = ChatMessageResponse{}
mi := &file_ai_chat_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ChatMessageResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ChatMessageResponse) ProtoMessage() {}
func (x *ChatMessageResponse) ProtoReflect() protoreflect.Message {
mi := &file_ai_chat_proto_msgTypes[4]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ChatMessageResponse.ProtoReflect.Descriptor instead.
func (*ChatMessageResponse) Descriptor() ([]byte, []int) {
return file_ai_chat_proto_rawDescGZIP(), []int{4}
}
func (x *ChatMessageResponse) GetContent() string {
if x != nil {
return x.Content
}
return ""
}
func (x *ChatMessageResponse) GetSessionId() string {
if x != nil {
return x.SessionId
}
return ""
}
func (x *ChatMessageResponse) GetIsEnd() bool {
if x != nil {
return x.IsEnd
}
return false
}
func (x *ChatMessageResponse) GetError() string {
if x != nil {
return x.Error
}
return ""
}
type ChatHistoryRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
SessionId string `protobuf:"bytes,1,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"`
Limit int32 `protobuf:"varint,2,opt,name=limit,proto3" json:"limit,omitempty"` // 默认 20
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ChatHistoryRequest) Reset() {
*x = ChatHistoryRequest{}
mi := &file_ai_chat_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ChatHistoryRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ChatHistoryRequest) ProtoMessage() {}
func (x *ChatHistoryRequest) ProtoReflect() protoreflect.Message {
mi := &file_ai_chat_proto_msgTypes[5]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ChatHistoryRequest.ProtoReflect.Descriptor instead.
func (*ChatHistoryRequest) Descriptor() ([]byte, []int) {
return file_ai_chat_proto_rawDescGZIP(), []int{5}
}
func (x *ChatHistoryRequest) GetSessionId() string {
if x != nil {
return x.SessionId
}
return ""
}
func (x *ChatHistoryRequest) GetLimit() int32 {
if x != nil {
return x.Limit
}
return 0
}
type ChatHistoryResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
History []*Message `protobuf:"bytes,1,rep,name=history,proto3" json:"history,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ChatHistoryResponse) Reset() {
*x = ChatHistoryResponse{}
mi := &file_ai_chat_proto_msgTypes[6]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ChatHistoryResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ChatHistoryResponse) ProtoMessage() {}
func (x *ChatHistoryResponse) ProtoReflect() protoreflect.Message {
mi := &file_ai_chat_proto_msgTypes[6]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ChatHistoryResponse.ProtoReflect.Descriptor instead.
func (*ChatHistoryResponse) Descriptor() ([]byte, []int) {
return file_ai_chat_proto_rawDescGZIP(), []int{6}
}
func (x *ChatHistoryResponse) GetHistory() []*Message {
if x != nil {
return x.History
}
return nil
}
type GetPersonasRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
UserId int64 `protobuf:"varint,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *GetPersonasRequest) Reset() {
*x = GetPersonasRequest{}
mi := &file_ai_chat_proto_msgTypes[7]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *GetPersonasRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetPersonasRequest) ProtoMessage() {}
func (x *GetPersonasRequest) ProtoReflect() protoreflect.Message {
mi := &file_ai_chat_proto_msgTypes[7]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetPersonasRequest.ProtoReflect.Descriptor instead.
func (*GetPersonasRequest) Descriptor() ([]byte, []int) {
return file_ai_chat_proto_rawDescGZIP(), []int{7}
}
func (x *GetPersonasRequest) GetUserId() int64 {
if x != nil {
return x.UserId
}
return 0
}
type PersonaListResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Personas []*PersonaInfo `protobuf:"bytes,1,rep,name=personas,proto3" json:"personas,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *PersonaListResponse) Reset() {
*x = PersonaListResponse{}
mi := &file_ai_chat_proto_msgTypes[8]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *PersonaListResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*PersonaListResponse) ProtoMessage() {}
func (x *PersonaListResponse) ProtoReflect() protoreflect.Message {
mi := &file_ai_chat_proto_msgTypes[8]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use PersonaListResponse.ProtoReflect.Descriptor instead.
func (*PersonaListResponse) Descriptor() ([]byte, []int) {
return file_ai_chat_proto_rawDescGZIP(), []int{8}
}
func (x *PersonaListResponse) GetPersonas() []*PersonaInfo {
if x != nil {
return x.Personas
}
return nil
}
type PersonaInfo struct {
state protoimpl.MessageState `protogen:"open.v1"`
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"`
Description string `protobuf:"bytes,3,opt,name=description,proto3" json:"description,omitempty"`
AvatarUrl string `protobuf:"bytes,4,opt,name=avatar_url,json=avatarUrl,proto3" json:"avatar_url,omitempty"`
TalkStyle string `protobuf:"bytes,5,opt,name=talk_style,json=talkStyle,proto3" json:"talk_style,omitempty"`
IsDefault bool `protobuf:"varint,6,opt,name=is_default,json=isDefault,proto3" json:"is_default,omitempty"`
CreatedAt int64 `protobuf:"varint,7,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"`
UpdatedAt int64 `protobuf:"varint,8,opt,name=updated_at,json=updatedAt,proto3" json:"updated_at,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *PersonaInfo) Reset() {
*x = PersonaInfo{}
mi := &file_ai_chat_proto_msgTypes[9]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *PersonaInfo) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*PersonaInfo) ProtoMessage() {}
func (x *PersonaInfo) ProtoReflect() protoreflect.Message {
mi := &file_ai_chat_proto_msgTypes[9]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use PersonaInfo.ProtoReflect.Descriptor instead.
func (*PersonaInfo) Descriptor() ([]byte, []int) {
return file_ai_chat_proto_rawDescGZIP(), []int{9}
}
func (x *PersonaInfo) GetId() string {
if x != nil {
return x.Id
}
return ""
}
func (x *PersonaInfo) GetName() string {
if x != nil {
return x.Name
}
return ""
}
func (x *PersonaInfo) GetDescription() string {
if x != nil {
return x.Description
}
return ""
}
func (x *PersonaInfo) GetAvatarUrl() string {
if x != nil {
return x.AvatarUrl
}
return ""
}
func (x *PersonaInfo) GetTalkStyle() string {
if x != nil {
return x.TalkStyle
}
return ""
}
func (x *PersonaInfo) GetIsDefault() bool {
if x != nil {
return x.IsDefault
}
return false
}
func (x *PersonaInfo) GetCreatedAt() int64 {
if x != nil {
return x.CreatedAt
}
return 0
}
func (x *PersonaInfo) GetUpdatedAt() int64 {
if x != nil {
return x.UpdatedAt
}
return 0
}
var File_ai_chat_proto protoreflect.FileDescriptor
const file_ai_chat_proto_rawDesc = "" +
"\n" +
"\rai_chat.proto\x12\x0ftopfans.ai_chat\x1a\x1cgoogle/api/annotations.proto\"7\n" +
"\aMessage\x12\x12\n" +
"\x04role\x18\x01 \x01(\tR\x04role\x12\x18\n" +
"\acontent\x18\x02 \x01(\tR\acontent\"L\n" +
"\x12InitSessionRequest\x12\x1d\n" +
"\n" +
"session_id\x18\x01 \x01(\tR\tsessionId\x12\x17\n" +
"\auser_id\x18\x02 \x01(\x03R\x06userId\"]\n" +
"\x13InitSessionResponse\x12'\n" +
"\x0fwelcome_message\x18\x01 \x01(\tR\x0ewelcomeMessage\x12\x1d\n" +
"\n" +
"session_id\x18\x02 \x01(\tR\tsessionId\"\x85\x01\n" +
"\x12ChatMessageRequest\x12\x1d\n" +
"\n" +
"session_id\x18\x01 \x01(\tR\tsessionId\x12\x18\n" +
"\amessage\x18\x02 \x01(\tR\amessage\x12\x1d\n" +
"\n" +
"persona_id\x18\x03 \x01(\tR\tpersonaId\x12\x17\n" +
"\auser_id\x18\x04 \x01(\x03R\x06userId\"{\n" +
"\x13ChatMessageResponse\x12\x18\n" +
"\acontent\x18\x01 \x01(\tR\acontent\x12\x1d\n" +
"\n" +
"session_id\x18\x02 \x01(\tR\tsessionId\x12\x15\n" +
"\x06is_end\x18\x03 \x01(\bR\x05isEnd\x12\x14\n" +
"\x05error\x18\x04 \x01(\tR\x05error\"I\n" +
"\x12ChatHistoryRequest\x12\x1d\n" +
"\n" +
"session_id\x18\x01 \x01(\tR\tsessionId\x12\x14\n" +
"\x05limit\x18\x02 \x01(\x05R\x05limit\"I\n" +
"\x13ChatHistoryResponse\x122\n" +
"\ahistory\x18\x01 \x03(\v2\x18.topfans.ai_chat.MessageR\ahistory\"-\n" +
"\x12GetPersonasRequest\x12\x17\n" +
"\auser_id\x18\x01 \x01(\x03R\x06userId\"O\n" +
"\x13PersonaListResponse\x128\n" +
"\bpersonas\x18\x01 \x03(\v2\x1c.topfans.ai_chat.PersonaInfoR\bpersonas\"\xee\x01\n" +
"\vPersonaInfo\x12\x0e\n" +
"\x02id\x18\x01 \x01(\tR\x02id\x12\x12\n" +
"\x04name\x18\x02 \x01(\tR\x04name\x12 \n" +
"\vdescription\x18\x03 \x01(\tR\vdescription\x12\x1d\n" +
"\n" +
"avatar_url\x18\x04 \x01(\tR\tavatarUrl\x12\x1d\n" +
"\n" +
"talk_style\x18\x05 \x01(\tR\ttalkStyle\x12\x1d\n" +
"\n" +
"is_default\x18\x06 \x01(\bR\tisDefault\x12\x1d\n" +
"\n" +
"created_at\x18\a \x01(\x03R\tcreatedAt\x12\x1d\n" +
"\n" +
"updated_at\x18\b \x01(\x03R\tupdatedAt2\xc9\x03\n" +
"\rAIChatService\x12X\n" +
"\vInitSession\x12#.topfans.ai_chat.InitSessionRequest\x1a$.topfans.ai_chat.InitSessionResponse\x12Z\n" +
"\vSendMessage\x12#.topfans.ai_chat.ChatMessageRequest\x1a$.topfans.ai_chat.ChatMessageResponse0\x01\x12\x85\x01\n" +
"\n" +
"GetHistory\x12#.topfans.ai_chat.ChatHistoryRequest\x1a$.topfans.ai_chat.ChatHistoryResponse\",\x82\xd3\xe4\x93\x02&\x12$/api/v1/ai-chat/history/{session_id}\x12z\n" +
"\vGetPersonas\x12#.topfans.ai_chat.GetPersonasRequest\x1a$.topfans.ai_chat.PersonaListResponse\" \x82\xd3\xe4\x93\x02\x1a\x12\x18/api/v1/ai-chat/personasB6Z4github.com/topfans/backend/pkg/proto/ai_chat;ai_chatb\x06proto3"
var (
file_ai_chat_proto_rawDescOnce sync.Once
file_ai_chat_proto_rawDescData []byte
)
func file_ai_chat_proto_rawDescGZIP() []byte {
file_ai_chat_proto_rawDescOnce.Do(func() {
file_ai_chat_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_ai_chat_proto_rawDesc), len(file_ai_chat_proto_rawDesc)))
})
return file_ai_chat_proto_rawDescData
}
var file_ai_chat_proto_msgTypes = make([]protoimpl.MessageInfo, 10)
var file_ai_chat_proto_goTypes = []any{
(*Message)(nil), // 0: topfans.ai_chat.Message
(*InitSessionRequest)(nil), // 1: topfans.ai_chat.InitSessionRequest
(*InitSessionResponse)(nil), // 2: topfans.ai_chat.InitSessionResponse
(*ChatMessageRequest)(nil), // 3: topfans.ai_chat.ChatMessageRequest
(*ChatMessageResponse)(nil), // 4: topfans.ai_chat.ChatMessageResponse
(*ChatHistoryRequest)(nil), // 5: topfans.ai_chat.ChatHistoryRequest
(*ChatHistoryResponse)(nil), // 6: topfans.ai_chat.ChatHistoryResponse
(*GetPersonasRequest)(nil), // 7: topfans.ai_chat.GetPersonasRequest
(*PersonaListResponse)(nil), // 8: topfans.ai_chat.PersonaListResponse
(*PersonaInfo)(nil), // 9: topfans.ai_chat.PersonaInfo
}
var file_ai_chat_proto_depIdxs = []int32{
0, // 0: topfans.ai_chat.ChatHistoryResponse.history:type_name -> topfans.ai_chat.Message
9, // 1: topfans.ai_chat.PersonaListResponse.personas:type_name -> topfans.ai_chat.PersonaInfo
1, // 2: topfans.ai_chat.AIChatService.InitSession:input_type -> topfans.ai_chat.InitSessionRequest
3, // 3: topfans.ai_chat.AIChatService.SendMessage:input_type -> topfans.ai_chat.ChatMessageRequest
5, // 4: topfans.ai_chat.AIChatService.GetHistory:input_type -> topfans.ai_chat.ChatHistoryRequest
7, // 5: topfans.ai_chat.AIChatService.GetPersonas:input_type -> topfans.ai_chat.GetPersonasRequest
2, // 6: topfans.ai_chat.AIChatService.InitSession:output_type -> topfans.ai_chat.InitSessionResponse
4, // 7: topfans.ai_chat.AIChatService.SendMessage:output_type -> topfans.ai_chat.ChatMessageResponse
6, // 8: topfans.ai_chat.AIChatService.GetHistory:output_type -> topfans.ai_chat.ChatHistoryResponse
8, // 9: topfans.ai_chat.AIChatService.GetPersonas:output_type -> topfans.ai_chat.PersonaListResponse
6, // [6:10] is the sub-list for method output_type
2, // [2:6] is the sub-list for method input_type
2, // [2:2] is the sub-list for extension type_name
2, // [2:2] is the sub-list for extension extendee
0, // [0:2] is the sub-list for field type_name
}
func init() { file_ai_chat_proto_init() }
func file_ai_chat_proto_init() {
if File_ai_chat_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_ai_chat_proto_rawDesc), len(file_ai_chat_proto_rawDesc)),
NumEnums: 0,
NumMessages: 10,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_ai_chat_proto_goTypes,
DependencyIndexes: file_ai_chat_proto_depIdxs,
MessageInfos: file_ai_chat_proto_msgTypes,
}.Build()
File_ai_chat_proto = out.File
file_ai_chat_proto_goTypes = nil
file_ai_chat_proto_depIdxs = nil
}

View File

@ -0,0 +1,258 @@
// Code generated by protoc-gen-triple. DO NOT EDIT.
//
// Source: ai_chat.proto
package ai_chat
import (
"context"
"net/http"
)
import (
"dubbo.apache.org/dubbo-go/v3"
"dubbo.apache.org/dubbo-go/v3/client"
"dubbo.apache.org/dubbo-go/v3/common"
"dubbo.apache.org/dubbo-go/v3/common/constant"
"dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol"
"dubbo.apache.org/dubbo-go/v3/server"
)
// This is a compile-time assertion to ensure that this generated file and the Triple package
// are compatible. If you get a compiler error that this constant is not defined, this code was
// generated with a version of Triple newer than the one compiled into your binary. You can fix the
// problem by either regenerating this code with an older version of Triple or updating the Triple
// version compiled into your binary.
const _ = triple_protocol.IsAtLeastVersion0_1_0
const (
// AIChatServiceName is the fully-qualified name of the AIChatService service.
AIChatServiceName = "topfans.ai_chat.AIChatService"
)
// These constants are the fully-qualified names of the RPCs defined in this package. They're
// exposed at runtime as procedure and as the final two segments of the HTTP route.
//
// Note that these are different from the fully-qualified method names used by
// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to
// reflection-formatted method names, remove the leading slash and convert the remaining slash to a
// period.
const (
// AIChatServiceInitSessionProcedure is the fully-qualified name of the AIChatService's InitSession RPC.
AIChatServiceInitSessionProcedure = "/topfans.ai_chat.AIChatService/InitSession"
// AIChatServiceSendMessageProcedure is the fully-qualified name of the AIChatService's SendMessage RPC.
AIChatServiceSendMessageProcedure = "/topfans.ai_chat.AIChatService/SendMessage"
// AIChatServiceGetHistoryProcedure is the fully-qualified name of the AIChatService's GetHistory RPC.
AIChatServiceGetHistoryProcedure = "/topfans.ai_chat.AIChatService/GetHistory"
// AIChatServiceGetPersonasProcedure is the fully-qualified name of the AIChatService's GetPersonas RPC.
AIChatServiceGetPersonasProcedure = "/topfans.ai_chat.AIChatService/GetPersonas"
)
var (
_ AIChatService = (*AIChatServiceImpl)(nil)
_ AIChatService_SendMessageClient = (*AIChatServiceSendMessageClient)(nil)
_ AIChatService_SendMessageServer = (*AIChatServiceSendMessageServer)(nil)
)
// AIChatService is a client for the topfans.ai_chat.AIChatService service.
type AIChatService interface {
InitSession(ctx context.Context, req *InitSessionRequest, opts ...client.CallOption) (*InitSessionResponse, error)
SendMessage(ctx context.Context, req *ChatMessageRequest, opts ...client.CallOption) (AIChatService_SendMessageClient, error)
GetHistory(ctx context.Context, req *ChatHistoryRequest, opts ...client.CallOption) (*ChatHistoryResponse, error)
GetPersonas(ctx context.Context, req *GetPersonasRequest, opts ...client.CallOption) (*PersonaListResponse, error)
}
// NewAIChatService constructs a client for the ai_chat.AIChatService service.
func NewAIChatService(cli *client.Client, opts ...client.ReferenceOption) (AIChatService, error) {
conn, err := cli.DialWithInfo("topfans.ai_chat.AIChatService", &AIChatService_ClientInfo, opts...)
if err != nil {
return nil, err
}
return &AIChatServiceImpl{
conn: conn,
}, nil
}
func SetConsumerAIChatService(srv common.RPCService) {
dubbo.SetConsumerServiceWithInfo(srv, &AIChatService_ClientInfo)
}
// AIChatServiceImpl implements AIChatService.
type AIChatServiceImpl struct {
conn *client.Connection
}
func (c *AIChatServiceImpl) InitSession(ctx context.Context, req *InitSessionRequest, opts ...client.CallOption) (*InitSessionResponse, error) {
resp := new(InitSessionResponse)
if err := c.conn.CallUnary(ctx, []interface{}{req}, resp, "InitSession", opts...); err != nil {
return nil, err
}
return resp, nil
}
func (c *AIChatServiceImpl) SendMessage(ctx context.Context, req *ChatMessageRequest, opts ...client.CallOption) (AIChatService_SendMessageClient, error) {
stream, err := c.conn.CallServerStream(ctx, req, "SendMessage", opts...)
if err != nil {
return nil, err
}
rawStream := stream.(*triple_protocol.ServerStreamForClient)
return &AIChatServiceSendMessageClient{rawStream}, nil
}
func (c *AIChatServiceImpl) GetHistory(ctx context.Context, req *ChatHistoryRequest, opts ...client.CallOption) (*ChatHistoryResponse, error) {
resp := new(ChatHistoryResponse)
if err := c.conn.CallUnary(ctx, []interface{}{req}, resp, "GetHistory", opts...); err != nil {
return nil, err
}
return resp, nil
}
func (c *AIChatServiceImpl) GetPersonas(ctx context.Context, req *GetPersonasRequest, opts ...client.CallOption) (*PersonaListResponse, error) {
resp := new(PersonaListResponse)
if err := c.conn.CallUnary(ctx, []interface{}{req}, resp, "GetPersonas", opts...); err != nil {
return nil, err
}
return resp, nil
}
type AIChatService_SendMessageClient interface {
Recv() bool
ResponseHeader() http.Header
ResponseTrailer() http.Header
Msg() *ChatMessageResponse
Err() error
Conn() (triple_protocol.StreamingClientConn, error)
Close() error
}
type AIChatServiceSendMessageClient struct {
*triple_protocol.ServerStreamForClient
}
func (cli *AIChatServiceSendMessageClient) Recv() bool {
msg := new(ChatMessageResponse)
return cli.ServerStreamForClient.Receive(msg)
}
func (cli *AIChatServiceSendMessageClient) Msg() *ChatMessageResponse {
msg := cli.ServerStreamForClient.Msg()
if msg == nil {
return new(ChatMessageResponse)
}
return msg.(*ChatMessageResponse)
}
func (cli *AIChatServiceSendMessageClient) Conn() (triple_protocol.StreamingClientConn, error) {
return cli.ServerStreamForClient.Conn()
}
var AIChatService_ClientInfo = client.ClientInfo{
InterfaceName: "topfans.ai_chat.AIChatService",
MethodNames: []string{"InitSession", "SendMessage", "GetHistory", "GetPersonas"},
ConnectionInjectFunc: func(dubboCliRaw interface{}, conn *client.Connection) {
dubboCli := dubboCliRaw.(*AIChatServiceImpl)
dubboCli.conn = conn
},
}
// AIChatServiceHandler is an implementation of the topfans.ai_chat.AIChatService service.
type AIChatServiceHandler interface {
InitSession(context.Context, *InitSessionRequest) (*InitSessionResponse, error)
SendMessage(context.Context, *ChatMessageRequest, AIChatService_SendMessageServer) error
GetHistory(context.Context, *ChatHistoryRequest) (*ChatHistoryResponse, error)
GetPersonas(context.Context, *GetPersonasRequest) (*PersonaListResponse, error)
}
func RegisterAIChatServiceHandler(srv *server.Server, hdlr AIChatServiceHandler, opts ...server.ServiceOption) error {
return srv.Register(hdlr, &AIChatService_ServiceInfo, opts...)
}
func SetProviderAIChatService(srv common.RPCService) {
dubbo.SetProviderServiceWithInfo(srv, &AIChatService_ServiceInfo)
}
type AIChatService_SendMessageServer interface {
Send(*ChatMessageResponse) error
ResponseHeader() http.Header
ResponseTrailer() http.Header
Conn() triple_protocol.StreamingHandlerConn
}
type AIChatServiceSendMessageServer struct {
*triple_protocol.ServerStream
}
func (g *AIChatServiceSendMessageServer) Send(msg *ChatMessageResponse) error {
return g.ServerStream.Send(msg)
}
var AIChatService_ServiceInfo = server.ServiceInfo{
InterfaceName: "topfans.ai_chat.AIChatService",
ServiceType: (*AIChatServiceHandler)(nil),
Methods: []server.MethodInfo{
{
Name: "InitSession",
Type: constant.CallUnary,
ReqInitFunc: func() interface{} {
return new(InitSessionRequest)
},
MethodFunc: func(ctx context.Context, args []interface{}, handler interface{}) (interface{}, error) {
req := args[0].(*InitSessionRequest)
res, err := handler.(AIChatServiceHandler).InitSession(ctx, req)
if err != nil {
return nil, err
}
return triple_protocol.NewResponse(res), nil
},
},
{
Name: "SendMessage",
Type: constant.CallServerStream,
ReqInitFunc: func() interface{} {
return new(ChatMessageRequest)
},
StreamInitFunc: func(baseStream interface{}) interface{} {
return &AIChatServiceSendMessageServer{baseStream.(*triple_protocol.ServerStream)}
},
MethodFunc: func(ctx context.Context, args []interface{}, handler interface{}) (interface{}, error) {
req := args[0].(*ChatMessageRequest)
stream := args[1].(AIChatService_SendMessageServer)
if err := handler.(AIChatServiceHandler).SendMessage(ctx, req, stream); err != nil {
return nil, err
}
return nil, nil
},
},
{
Name: "GetHistory",
Type: constant.CallUnary,
ReqInitFunc: func() interface{} {
return new(ChatHistoryRequest)
},
MethodFunc: func(ctx context.Context, args []interface{}, handler interface{}) (interface{}, error) {
req := args[0].(*ChatHistoryRequest)
res, err := handler.(AIChatServiceHandler).GetHistory(ctx, req)
if err != nil {
return nil, err
}
return triple_protocol.NewResponse(res), nil
},
},
{
Name: "GetPersonas",
Type: constant.CallUnary,
ReqInitFunc: func() interface{} {
return new(GetPersonasRequest)
},
MethodFunc: func(ctx context.Context, args []interface{}, handler interface{}) (interface{}, error) {
req := args[0].(*GetPersonasRequest)
res, err := handler.(AIChatServiceHandler).GetPersonas(ctx, req)
if err != nil {
return nil, err
}
return triple_protocol.NewResponse(res), nil
},
},
},
}

View File

@ -0,0 +1,94 @@
syntax = "proto3";
package topfans.ai_chat;
option go_package = "github.com/topfans/backend/pkg/proto/ai_chat;ai_chat";
import "google/api/annotations.proto";
// ==================== ====================
message Message {
string role = 1; // "user" / "assistant"
string content = 2;
}
// ==================== ====================
message InitSessionRequest {
string session_id = 1;
int64 user_id = 2;
}
message InitSessionResponse {
string welcome_message = 1;
string session_id = 2;
}
message ChatMessageRequest {
string session_id = 1;
string message = 2;
string persona_id = 3; //
int64 user_id = 4;
}
message ChatMessageResponse {
string content = 1;
string session_id = 2;
bool is_end = 3;
string error = 4;
}
message ChatHistoryRequest {
string session_id = 1;
int32 limit = 2; // 20
}
message ChatHistoryResponse {
repeated Message history = 1;
}
// ==================== ====================
message GetPersonasRequest {
int64 user_id = 1;
}
message PersonaListResponse {
repeated PersonaInfo personas = 1;
}
message PersonaInfo {
string id = 1;
string name = 2;
string description = 3;
string avatar_url = 4;
string talk_style = 5;
bool is_default = 6;
int64 created_at = 7;
int64 updated_at = 8;
}
// ==================== AI Chat Service ====================
service AIChatService {
//
rpc InitSession(InitSessionRequest) returns (InitSessionResponse);
//
rpc SendMessage(ChatMessageRequest) returns (stream ChatMessageResponse);
//
rpc GetHistory(ChatHistoryRequest) returns (ChatHistoryResponse) {
option (google.api.http) = {
get: "/api/v1/ai-chat/history/{session_id}"
};
}
//
rpc GetPersonas(GetPersonasRequest) returns (PersonaListResponse) {
option (google.api.http) = {
get: "/api/v1/ai-chat/personas"
};
}
}

View File

@ -67,7 +67,7 @@ move_triple_files() {
# 预先创建目标目录
echo "📁 创建目标目录..."
for name in common user social asset gallery ranking activity task starbook; do
for name in common user social asset gallery ranking activity task starbook ai_chat; do
mkdir -p "pkg/proto/$name"
done
echo ""
@ -194,6 +194,20 @@ move_triple_files "topfans/backend/pkg/proto/starbook" "pkg/proto/starbook"
echo "✅ starbook.proto 编译完成"
echo ""
# 编译 ai_chat.proto
echo "📦 编译 ai_chat.proto ..."
protoc --proto_path=proto \
--proto_path=. \
--go_out=pkg/proto/ai_chat \
--go_opt=paths=source_relative \
--go-triple_out=pkg/proto/ai_chat \
--go-triple_opt=paths=source_relative \
ai_chat.proto
move_triple_files "topfans/backend/pkg/proto/ai_chat" "pkg/proto/ai_chat"
echo "✅ ai_chat.proto 编译完成"
echo ""
# 清理可能存在的冗余目录和文件
echo "🔄 清理冗余文件..."
@ -204,7 +218,7 @@ if [ -d "github.com" ]; then
fi
# 删除 proto 目录下的生成文件(如果存在)
for name in common user social asset gallery ranking activity task starbook; do
for name in common user social asset gallery ranking activity task starbook ai_chat; do
if [ -f "proto/$name.pb.go" ]; then
rm "proto/$name.pb.go"
echo " ✅ proto/$name.pb.go 已清理"

View File

@ -0,0 +1,150 @@
-- AI Chat Service 数据库迁移
-- 创建人设表、用户记忆表、配置表
-- ============================================
-- 1. ai_personas 表 - 人设表
-- 存储用户创建的 AI 聊天人设配置
-- ============================================
CREATE TABLE IF NOT EXISTS ai_personas (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), -- 人设唯一标识符
user_id BIGINT NOT NULL, -- 所属用户 ID
name VARCHAR(64) NOT NULL, -- 人设名称
description TEXT, -- 人设描述/简介
avatar_url VARCHAR(512), -- 头像 URL
talk_style VARCHAR(256), -- 说话风格描述
system_prompt TEXT NOT NULL, -- 系统提示词(人设核心定义)
is_default BOOLEAN DEFAULT FALSE, -- 是否为默认人设
created_at BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()) * 1000, -- 创建时间(毫秒时间戳)
updated_at BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()) * 1000 -- 更新时间(毫秒时间戳)
);
-- 索引
CREATE INDEX IF NOT EXISTS idx_ai_personas_user_id ON ai_personas(user_id);
CREATE UNIQUE INDEX IF NOT EXISTS idx_ai_personas_user_default ON ai_personas(user_id) WHERE is_default = TRUE;
-- ============================================
-- 2. ai_user_memories 表 - 用户长期记忆表
-- 存储用户的长期记忆信息,用于 AI 对话时召回
-- ============================================
CREATE TABLE IF NOT EXISTS ai_user_memories (
id SERIAL PRIMARY KEY, -- 记忆唯一标识符
user_id BIGINT NOT NULL, -- 所属用户 ID
content TEXT NOT NULL, -- 记忆内容
keywords TEXT[], -- 关键词数组(用于检索)
weight INTEGER DEFAULT 50, -- 权重0-100影响召回优先级
is_core BOOLEAN DEFAULT FALSE, -- 是否为核心记忆(核心记忆优先召回)
created_at BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()) * 1000, -- 创建时间
updated_at BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()) * 1000 -- 更新时间
);
-- 索引
CREATE INDEX IF NOT EXISTS idx_ai_user_memories_user_id ON ai_user_memories(user_id);
CREATE INDEX IF NOT EXISTS idx_ai_user_memories_keywords ON ai_user_memories USING GIN(keywords);
CREATE INDEX IF NOT EXISTS idx_ai_user_memories_weight ON ai_user_memories(weight DESC);
-- ============================================
-- 3. ai_chat_configs 表 - AI Chat 配置表
-- 存储 AI Chat 服务的各类配置项
-- ============================================
CREATE TABLE IF NOT EXISTS ai_chat_configs (
id SERIAL PRIMARY KEY, -- 配置项唯一 ID
config_key VARCHAR(128) NOT NULL UNIQUE, -- 配置键
config_value TEXT NOT NULL, -- 配置值
config_type VARCHAR(32) NOT NULL DEFAULT 'string', -- 配置类型string/number/boolean/json
category VARCHAR(64) NOT NULL, -- 配置分类redis/llm/dialog/token/circuit/summary
description VARCHAR(256), -- 配置描述
is_encrypted BOOLEAN DEFAULT FALSE, -- 是否加密存储(加密字段不返回明文)
updated_at BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()) * 1000, -- 更新时间
created_at BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()) * 1000 -- 创建时间
);
-- 索引
CREATE INDEX IF NOT EXISTS idx_ai_chat_configs_category ON ai_chat_configs(category);
CREATE INDEX IF NOT EXISTS idx_ai_chat_configs_key ON ai_chat_configs(config_key);
-- 初始配置数据
INSERT INTO ai_chat_configs (config_key, config_value, config_type, category, description, is_encrypted) VALUES
-- Redis 配置
('redis.host', '127.0.0.1', 'string', 'redis', 'Redis 主机地址', FALSE),
('redis.port', '6379', 'number', 'redis', 'Redis 端口', FALSE),
('redis.password', '123456', 'string', 'redis', 'Redis 密码', TRUE),
('redis.db', '0', 'number', 'redis', 'Redis 数据库编号', FALSE),
-- MiniMax 大模型配置
('minimax.api_key', '', 'string', 'llm', 'MiniMax API Key', TRUE),
('minimax.api_url', 'https://api.minimaxi.com/v1', 'string', 'llm', 'MiniMax API 地址', FALSE),
('minimax.model', 'M2-her', 'string', 'llm', 'MiniMax 模型名称', FALSE),
-- 通义备用模型配置
('qwen.api_key', '', 'string', 'llm', '通义 API Key', TRUE),
('qwen.api_url', 'https://dashscope.aliyuncs.com/compatible-mode/v1', 'string', 'llm', '通义 API 地址', FALSE),
('qwen.model', 'qwen-plus', 'string', 'llm', '通义模型名称', FALSE),
-- 对话配置
('dialog.max_context_turns', '10', 'number', 'dialog', '最大上下文轮数', FALSE),
('dialog.context_expire_seconds', '86400', 'number', 'dialog', '上下文过期时间(秒)', FALSE),
('dialog.memory_recall_topn', '5', 'number', 'dialog', '记忆召回返回条数', FALSE),
('dialog.fallback_threshold', '3', 'number', 'dialog', '模型降级连续失败次数阈值', FALSE),
('dialog.slow_response_ms', '5000', 'number', 'dialog', '慢响应判定阈值(毫秒)', FALSE),
-- AI Token 限制配置
('token.max_total', '32000', 'number', 'token', '总 Token 上限', FALSE),
('token.max_history', '24000', 'number', 'token', '对话历史最大 Token', FALSE),
('token.max_system', '4000', 'number', 'token', 'System Prompt 最大 Token', FALSE),
('token.max_memory', '2000', 'number', 'token', '记忆召回最大 Token', FALSE),
('token.reserved', '2000', 'number', 'token', '保留空间 Token', FALSE),
-- 熔断配置
('circuit.max_fail_count', '5', 'number', 'circuit', '熔断连续失败次数', FALSE),
('circuit.breaker_timeout', '60', 'number', 'circuit', '熔断恢复超时(秒)', FALSE),
-- 摘要配置
('summary.trigger_turns', '10', 'number', 'summary', '自动摘要触发轮数', FALSE),
('summary.max_length', '100', 'number', 'summary', '摘要最大字数', FALSE)
ON CONFLICT (config_key) DO NOTHING;
-- ============================================
-- 表结构注释
-- ============================================
COMMENT ON TABLE ai_personas IS 'AI 人设表 - 存储用户创建的聊天人设配置';
COMMENT ON TABLE ai_user_memories IS '用户长期记忆表 - 存储用户记忆信息,用于对话时召回';
COMMENT ON TABLE ai_chat_configs IS 'AI Chat 配置表 - 存储服务各类配置项';
-- ============================================
-- ai_personas 列注释
-- ============================================
COMMENT ON COLUMN ai_personas.id IS '人设唯一标识符';
COMMENT ON COLUMN ai_personas.user_id IS '所属用户ID';
COMMENT ON COLUMN ai_personas.name IS '人设名称';
COMMENT ON COLUMN ai_personas.description IS '人设描述/简介';
COMMENT ON COLUMN ai_personas.avatar_url IS '人设头像URL';
COMMENT ON COLUMN ai_personas.talk_style IS 'AI说话风格描述';
COMMENT ON COLUMN ai_personas.system_prompt IS '系统提示词,定义人设核心性格和行为';
COMMENT ON COLUMN ai_personas.is_default IS '是否为用户的默认人设';
COMMENT ON COLUMN ai_personas.created_at IS '记录创建时间(毫秒时间戳)';
COMMENT ON COLUMN ai_personas.updated_at IS '记录更新时间(毫秒时间戳)';
-- ============================================
-- ai_user_memories 列注释
-- ============================================
COMMENT ON COLUMN ai_user_memories.id IS '记忆记录唯一ID';
COMMENT ON COLUMN ai_user_memories.user_id IS '所属用户ID';
COMMENT ON COLUMN ai_user_memories.content IS '记忆内容文本';
COMMENT ON COLUMN ai_user_memories.keywords IS '关键词数组,用于记忆检索';
COMMENT ON COLUMN ai_user_memories.weight IS '记忆权重(0-100),影响召回优先级';
COMMENT ON COLUMN ai_user_memories.is_core IS '是否为核心记忆,核心记忆优先召回';
COMMENT ON COLUMN ai_user_memories.created_at IS '记忆创建时间';
COMMENT ON COLUMN ai_user_memories.updated_at IS '记忆更新时间';
-- ============================================
-- ai_chat_configs 列注释
-- ============================================
COMMENT ON COLUMN ai_chat_configs.id IS '配置项唯一ID';
COMMENT ON COLUMN ai_chat_configs.config_key IS '配置键名称';
COMMENT ON COLUMN ai_chat_configs.config_value IS '配置值';
COMMENT ON COLUMN ai_chat_configs.config_type IS '配置类型string/number/boolean/json';
COMMENT ON COLUMN ai_chat_configs.category IS '配置分类redis/llm/dialog/token/circuit/summary';
COMMENT ON COLUMN ai_chat_configs.description IS '配置项的中文描述';
COMMENT ON COLUMN ai_chat_configs.is_encrypted IS '是否加密存储,加密字段不会返回明文';
COMMENT ON COLUMN ai_chat_configs.updated_at IS '配置更新时间';
COMMENT ON COLUMN ai_chat_configs.created_at IS '配置创建时间';

View File

@ -0,0 +1,35 @@
# Dubbo AI Chat Service 配置文件
dubbo:
# 应用配置
application:
name: ai-chat-service
version: 1.0.0
# 注册中心配置
registries:
nacos:
protocol: nacos
address: 127.0.0.1:8848
timeout: 5s
# 协议配置
protocols:
triple:
name: tri
port: 20008 # AI Chat Service 端口
# Provider 配置
provider:
registry-ids: nacos
protocol-ids: triple
services:
# AI Chat Service 服务定义
AIChatService:
interface: "github.com/topfans/backend/pkg/proto/ai_chat.AIChatService"
# Consumer 配置
consumer:
registry-ids: nacos
# 超时配置
timeout: 60s

View File

@ -0,0 +1,23 @@
module github.com/topfans/backend/services/aiChatService
go 1.25.5
require (
dubbo.apache.org/dubbo-go/v3 v3.3.1
github.com/google/uuid v1.6.0
github.com/redis/go-redis/v9 v9.5.1
github.com/stretchr/testify v1.11.1
github.com/topfans/backend v0.0.0
go.uber.org/zap v1.27.1
gorm.io/gorm v1.31.1
)
require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/stretchr/objx v0.4.0 // indirect
github.com/stretchr/testify v1.7.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
replace github.com/topfans/backend => ../..

View File

@ -0,0 +1,265 @@
package main
import (
"context"
"flag"
"fmt"
"os"
"os/signal"
"strconv"
"syscall"
"time"
_ "dubbo.apache.org/dubbo-go/v3/imports"
_ "dubbo.apache.org/dubbo-go/v3/protocol/triple"
"dubbo.apache.org/dubbo-go/v3/protocol"
"dubbo.apache.org/dubbo-go/v3/server"
"github.com/joho/godotenv"
"github.com/redis/go-redis/v9"
"github.com/topfans/backend/pkg/database"
"github.com/topfans/backend/pkg/health"
"github.com/topfans/backend/pkg/logger"
"github.com/topfans/backend/services/aiChatService/model"
"github.com/topfans/backend/services/aiChatService/provider"
"github.com/topfans/backend/services/aiChatService/repository"
"github.com/topfans/backend/services/aiChatService/service"
pbAIChat "github.com/topfans/backend/pkg/proto/ai_chat"
"go.uber.org/zap"
)
var (
port = flag.Int("port", getEnvInt("PORT", 20008), "Dubbo service port")
dbHost = flag.String("db-host", getEnv("DB_HOST", "localhost"), "Database host")
dbPort = flag.Int("db-port", getEnvInt("DB_PORT", 5432), "Database port")
dbUser = flag.String("db-user", getEnv("DB_USER", "postgres"), "Database user")
dbPassword = flag.String("db-password", getEnv("DB_PASSWORD", ""), "Database password")
dbName = flag.String("db-name", getEnv("DB_NAME", "top-fans"), "Database name")
redisHost = flag.String("redis-host", getEnv("REDIS_HOST", "127.0.0.1"), "Redis host")
redisPort = flag.Int("redis-port", getEnvInt("REDIS_PORT", 6379), "Redis port")
redisPassword = flag.String("redis-password", getEnv("REDIS_PASSWORD", ""), "Redis password")
redisDB = flag.Int("redis-db", getEnvInt("REDIS_DB", 0), "Redis db")
contextTTL = flag.Int("context-ttl", getEnvInt("CONTEXT_TTL", 86400), "Context TTL in seconds")
triggerTurns = flag.Int("trigger-turns", getEnvInt("TRIGGER_TURNS", 5), "Turns to trigger memory extraction")
healthHandler *health.Handler
)
func getEnv(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
func getEnvInt(key string, fallback int) int {
if v := os.Getenv(key); v != "" {
if n, err := strconv.Atoi(v); err == nil {
return n
}
}
return fallback
}
func main() {
godotenv.Load()
flag.Parse()
env := os.Getenv("ENV")
if env == "" {
env = "development"
}
if err := logger.Init(logger.Config{
ServiceName: "ai-chat-service",
Environment: env,
LogLevel: os.Getenv("LOG_LEVEL"),
}); err != nil {
panic(fmt.Sprintf("Failed to initialize logger: %v", err))
}
defer logger.Sync()
logger.Logger.Info("Starting AI Chat Service...")
// 初始化数据库
dbConfig := database.Config{
Host: *dbHost,
Port: *dbPort,
User: *dbUser,
Password: *dbPassword,
DBName: *dbName,
SSLMode: "disable",
TimeZone: "Asia/Shanghai",
}
if err := database.Init(dbConfig); err != nil {
logger.Logger.Fatal(fmt.Sprintf("Failed to initialize database: %v", err))
}
logger.Logger.Info("Database initialized successfully")
// 初始化 Redis
redisClient := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", *redisHost, *redisPort),
Password: *redisPassword,
DB: *redisDB,
})
// 测试 Redis 连接
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
if err := redisClient.Ping(ctx).Err(); err != nil {
logger.Logger.Fatal(fmt.Sprintf("Failed to connect to Redis: %v", err))
}
cancel()
logger.Logger.Info("Redis connected successfully")
// 启动健康检查 HTTP 服务器
healthPort := *port + 1000
healthHandler = health.NewHandler("ai-chat-service", healthPort)
healthHandler.Start()
// 自动迁移数据库表
if err := autoMigrate(); err != nil {
logger.Logger.Fatal(fmt.Sprintf("Failed to migrate database: %v", err))
}
// 创建 Repository 层实例
personaRepo := repository.NewPostgreSQLPersonaRepository(database.GetDB())
configRepo := repository.NewPostgreSQLConfigRepository(database.GetDB())
shortTermMemoryRepo := repository.NewRedisMemoryRepository(redisClient, *contextTTL)
longTermMemoryRepo := repository.NewPostgreSQLMemoryRepository(database.GetDB())
logger.Logger.Info("Repository layer initialized")
// 从数据库加载 LLM 配置
loadCtx := context.Background()
llmConfigs, err := configRepo.GetByCategory(loadCtx, "llm")
if err != nil {
logger.Logger.Warn("Failed to load LLM configs from database, using env defaults", zap.Error(err))
}
// 获取 MiniMax 配置
miniMaxAPIKey := getEnv("MINIMAX_API_KEY", "")
miniMaxAPIURL := getEnv("MINIMAX_API_URL", "https://api.minimaxi.com/v1")
miniMaxModel := getEnv("MINIMAX_MODEL", "M2-her")
if val, ok := llmConfigs["minimax.api_key"]; ok && val != "" {
miniMaxAPIKey = val
}
if val, ok := llmConfigs["minimax.api_url"]; ok && val != "" {
miniMaxAPIURL = val
}
if val, ok := llmConfigs["minimax.model"]; ok && val != "" {
miniMaxModel = val
}
// 获取 Qwen 配置
qwenAPIKey := getEnv("QWEN_API_KEY", "")
qwenAPIURL := getEnv("QWEN_API_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
qwenModel := getEnv("QWEN_MODEL", "qwen-plus")
if val, ok := llmConfigs["qwen.api_key"]; ok && val != "" {
qwenAPIKey = val
}
if val, ok := llmConfigs["qwen.api_url"]; ok && val != "" {
qwenAPIURL = val
}
if val, ok := llmConfigs["qwen.model"]; ok && val != "" {
qwenModel = val
}
logger.Logger.Info("LLM configs loaded",
zap.String("minimax_url", miniMaxAPIURL),
zap.String("minimax_model", miniMaxModel),
zap.Bool("has_minimax_key", miniMaxAPIKey != ""),
zap.String("qwen_url", qwenAPIURL),
zap.String("qwen_model", qwenModel),
zap.Bool("has_qwen_key", qwenAPIKey != ""),
)
// 创建 Service 层实例
llmService := service.NewLLMService(
miniMaxAPIURL,
miniMaxAPIKey,
miniMaxModel,
qwenAPIURL,
qwenAPIKey,
qwenModel,
)
personaService := service.NewPersonaService(personaRepo)
memoryService := service.NewMemoryService(shortTermMemoryRepo, longTermMemoryRepo)
auditService := service.NewAuditService()
chatService := service.NewChatService(
llmService,
personaService,
memoryService,
auditService,
*contextTTL,
*triggerTurns,
)
logger.Logger.Info("Service layer initialized")
// 创建 Provider 层实例
aiChatProvider := provider.NewAIChatProvider(
chatService,
personaService,
memoryService,
auditService,
)
logger.Logger.Info("Provider layer initialized")
// 创建 Dubbo 服务器
srv, err := server.NewServer(
server.WithServerProtocol(
protocol.WithPort(*port),
protocol.WithTriple(),
),
)
if err != nil {
logger.Logger.Fatal(fmt.Sprintf("Failed to create Dubbo server: %v", err))
}
// 注册 AI Chat Service
if err := pbAIChat.RegisterAIChatServiceHandler(srv, aiChatProvider); err != nil {
logger.Logger.Fatal(fmt.Sprintf("Failed to register AI Chat Service: %v", err))
}
// 启动服务
if err := srv.Serve(); err != nil {
logger.Logger.Fatal(fmt.Sprintf("Failed to start AI Chat Service: %v", err))
}
logger.Logger.Info(fmt.Sprintf("AI Chat Service started successfully on port %d", *port))
// 等待退出信号
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
logger.Logger.Info("Shutting down AI Chat Service...")
if healthHandler != nil {
healthHandler.Stop()
}
redisClient.Close()
}
// autoMigrate 自动迁移数据库表
func autoMigrate() error {
db := database.GetDB()
if db == nil {
return fmt.Errorf("database is not initialized")
}
tables := []interface{}{
&model.Persona{},
&model.UserMemory{},
&model.Config{},
}
for _, table := range tables {
if err := db.AutoMigrate(table); err != nil {
// 如果约束删除失败(表已存在该约束),忽略并继续
logger.Logger.Warn(fmt.Sprintf("Migration warning for %T: %v", table, err))
}
}
logger.Logger.Info("Database migration completed successfully")
return nil
}

View File

@ -0,0 +1,26 @@
package model
import "errors"
var (
// ErrPersonaNotFound 人设不存在
ErrPersonaNotFound = errors.New("persona_not_found")
// ErrMemoryNotFound 记忆不存在
ErrMemoryNotFound = errors.New("memory_not_found")
// ErrAuditBlocked 内容审核未通过
ErrAuditBlocked = errors.New("audit_blocked")
// ErrSessionNotFound 会话不存在
ErrSessionNotFound = errors.New("session_not_found")
// ErrConfigNotFound 配置不存在
ErrConfigNotFound = errors.New("config_not_found")
// ErrLLMFailed 大模型调用失败
ErrLLMFailed = errors.New("llm_failed")
// ErrInvalidRequest 无效请求
ErrInvalidRequest = errors.New("invalid_request")
)

View File

@ -0,0 +1,102 @@
package model
import (
"time"
"github.com/google/uuid"
)
// Persona 人设结构
type Persona struct {
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;default:gen_random_uuid()"`
UserID int64 `json:"user_id" gorm:"index;not null"`
Name string `json:"name" gorm:"type:varchar(64);not null"`
Description string `json:"description" gorm:"type:text"`
AvatarURL string `json:"avatar_url" gorm:"type:varchar(512)"`
TalkStyle string `json:"talk_style" gorm:"type:varchar(256)"`
SystemPrompt string `json:"system_prompt" gorm:"type:text;not null"`
IsDefault bool `json:"is_default" gorm:"default:false"`
CreatedAt int64 `json:"created_at" gorm:"autoCreateTime:milli"`
UpdatedAt int64 `json:"updated_at" gorm:"autoUpdateTime:milli"`
}
// TableName 指定表名
func (Persona) TableName() string {
return "ai_personas"
}
// UserMemory 用户长期记忆
type UserMemory struct {
ID uint `json:"id" gorm:"primaryKey;autoIncrement"`
UserID int64 `json:"user_id" gorm:"index;not null"`
Content string `json:"content" gorm:"type:text;not null"`
Keywords []string `json:"keywords" gorm:"type:text[]"`
Weight int `json:"weight" gorm:"default:50"`
IsCore bool `json:"is_core" gorm:"default:false"`
CreatedAt int64 `json:"created_at" gorm:"autoCreateTime:milli"`
UpdatedAt int64 `json:"updated_at" gorm:"autoUpdateTime:milli"`
}
// TableName 指定表名
func (UserMemory) TableName() string {
return "ai_user_memories"
}
// Config 配置结构
type Config struct {
ID uint `json:"id" gorm:"primaryKey;autoIncrement"`
ConfigKey string `json:"config_key" gorm:"type:varchar(128);uniqueIndex;not null"`
ConfigValue string `json:"config_value" gorm:"type:text;not null"`
ConfigType string `json:"config_type" gorm:"type:varchar(32);default:string"`
Category string `json:"category" gorm:"type:varchar(64);index"`
Description string `json:"description" gorm:"type:varchar(256)"`
IsEncrypted bool `json:"is_encrypted" gorm:"default:false"`
UpdatedAt int64 `json:"updated_at" gorm:"autoUpdateTime:milli"`
CreatedAt int64 `json:"created_at" gorm:"autoCreateTime:milli"`
}
// TableName 指定表名
func (Config) TableName() string {
return "ai_chat_configs"
}
// Message 对话消息
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
// ChatContext 短期上下文Redis 存储)
type ChatContext struct {
SessionID string `json:"session_id"`
Messages []Message `json:"messages"`
PersonaID string `json:"persona_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// PersonaInfo 人设信息API 返回)
type PersonaInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
AvatarURL string `json:"avatar_url"`
TalkStyle string `json:"talk_style"`
IsDefault bool `json:"is_default"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
}
// ToPersonaInfo converts Persona to PersonaInfo
func (p *Persona) ToPersonaInfo() PersonaInfo {
return PersonaInfo{
ID: p.ID.String(),
Name: p.Name,
Description: p.Description,
AvatarURL: p.AvatarURL,
TalkStyle: p.TalkStyle,
IsDefault: p.IsDefault,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}

View File

@ -0,0 +1,442 @@
package provider
import (
"context"
"fmt"
"io"
"dubbo.apache.org/dubbo-go/v3/common/constant"
"github.com/topfans/backend/pkg/logger"
"github.com/topfans/backend/services/aiChatService/model"
"github.com/topfans/backend/services/aiChatService/service"
pb "github.com/topfans/backend/pkg/proto/ai_chat"
"go.uber.org/zap"
)
// AIChatProvider AI Chat 服务 Provider 实现
type AIChatProvider struct {
chatService *service.ChatService
personaService *service.PersonaService
memoryService *service.MemoryService
auditService *service.AuditService
}
// 确保 AIChatProvider 实现了 AIChatServiceHandler 接口
var _ pb.AIChatServiceHandler = (*AIChatProvider)(nil)
// NewAIChatProvider 创建 AIChatProvider 实例
func NewAIChatProvider(
chatService *service.ChatService,
personaService *service.PersonaService,
memoryService *service.MemoryService,
auditService *service.AuditService,
) *AIChatProvider {
return &AIChatProvider{
chatService: chatService,
personaService: personaService,
memoryService: memoryService,
auditService: auditService,
}
}
// InitSession 初始化会话,返回欢迎消息
func (p *AIChatProvider) InitSession(ctx context.Context, req *pb.InitSessionRequest) (*pb.InitSessionResponse, error) {
userID, starID, err := extractUserInfoFromDubboAttachments(ctx)
sessionID := req.SessionId
if sessionID == "" {
sessionID = fmt.Sprintf("%d_%d", userID, starID)
}
if err != nil {
logger.Logger.Error("Failed to extract user info from attachments",
zap.Error(err),
)
return nil, err
}
logger.Logger.Info("Received InitSession request",
zap.Int64("user_id", userID),
zap.String("session_id", sessionID),
)
// 获取欢迎消息
welcomeMessage := p.chatService.GetWelcomeMessage(sessionID, userID, starID)
return &pb.InitSessionResponse{
WelcomeMessage: welcomeMessage,
SessionId: sessionID,
}, nil
}
// SendMessage 发送消息(流式返回)
func (p *AIChatProvider) SendMessage(ctx context.Context, req *pb.ChatMessageRequest, stream pb.AIChatService_SendMessageServer) error {
userID, starID, err := extractUserInfoFromDubboAttachments(ctx)
sessionID := req.SessionId
if sessionID == "" {
sessionID = fmt.Sprintf("%d_%d", userID, starID)
}
if err != nil {
logger.Logger.Error("Failed to extract user info from attachments",
zap.Error(err),
)
stream.Send(&pb.ChatMessageResponse{
Content: "user authentication required",
SessionId: sessionID,
IsEnd: true,
Error: err.Error(),
})
return err
}
if sessionID == "" {
// 如果没有 sessionID生成一个
sessionID = fmt.Sprintf("%d_%d", userID, starID)
}
message := req.Message
personaID := req.PersonaId
logger.Logger.Info("Received SendMessage request",
zap.Int64("user_id", userID),
zap.String("session_id", sessionID),
zap.String("message", message),
)
// 1. 前置审核
if !p.auditService.AuditText(message) {
logger.Logger.Info("Message blocked by audit")
stream.Send(&pb.ChatMessageResponse{
Content: p.auditService.DefaultSafeResponse(),
SessionId: sessionID,
IsEnd: false,
})
stream.Send(&pb.ChatMessageResponse{
SessionId: sessionID,
IsEnd: true,
})
return nil
}
// 2. 获取人设
persona, err := p.personaService.GetPersonaOrDefault(ctx, userID, personaID)
if err != nil {
logger.Logger.Error("Failed to get persona", zap.Error(err))
stream.Send(&pb.ChatMessageResponse{
Content: err.Error(),
IsEnd: true,
Error: err.Error(),
})
return err
}
// 3. 记忆召回
memoryText, _ := p.memoryService.RecallMemories(ctx, userID, message, 5)
// 4. 获取对话历史
history, _ := p.memoryService.GetContext(ctx, sessionID)
// 5. 构建 Prompt
tokenizer := &service.Tokenizer{}
messages, _ := service.BuildPrompt(
persona.SystemPrompt,
memoryText,
history,
message,
tokenizer,
)
// 6. 检查是否需要调用大模型
if service.IsNoNeedLLMCall(message) {
stream.Send(&pb.ChatMessageResponse{
Content: "好的,我听到了。",
SessionId: sessionID,
IsEnd: false,
})
stream.Send(&pb.ChatMessageResponse{
SessionId: sessionID,
IsEnd: true,
})
return nil
}
// 7. 调用大模型(流式)
streamReader, err := p.chatService.LLMService.StreamChat(ctx, messages)
if err != nil {
logger.Logger.Error("LLM call failed", zap.Error(err))
// 检查是否是敏感内容错误
if _, ok := err.(*service.SensitiveContentError); ok {
logger.Logger.Info("Content blocked by MiniMax safety filter, trying backup model")
// 尝试备用模型
streamReader, err = p.chatService.LLMService.StreamChatWithBackup(ctx, messages)
if err != nil {
// 备用模型也失败
logger.Logger.Error("Backup model also failed", zap.Error(err))
stream.Send(&pb.ChatMessageResponse{
Content: p.auditService.DefaultSafeResponse(),
SessionId: sessionID,
IsEnd: true,
})
return nil
}
} else {
// 其他错误,尝试备用模型
streamReader, err = p.chatService.LLMService.StreamChatWithBackup(ctx, messages)
if err != nil {
logger.Logger.Error("Backup model also failed", zap.Error(err))
stream.Send(&pb.ChatMessageResponse{
Content: "抱歉,服务暂时不可用",
SessionId: sessionID,
IsEnd: true,
Error: err.Error(),
})
return err
}
}
}
defer streamReader.Close()
// 8. 流式处理
var fullResponse string
var sentEnd = false
for {
content, done, err := streamReader.Next()
if err != nil {
if err == io.EOF {
// 流结束,发送 is_end
if !sentEnd {
stream.Send(&pb.ChatMessageResponse{
SessionId: sessionID,
IsEnd: true,
})
sentEnd = true
}
break
}
logger.Logger.Error("Stream read error", zap.Error(err))
break
}
// 后置审核(逐 token
if !p.auditService.AuditResponse(content) {
logger.Logger.Info("Response blocked by audit")
streamReader.Close()
// 发送安全回复作为替代
stream.Send(&pb.ChatMessageResponse{
Content: p.auditService.DefaultSafeResponse(),
SessionId: sessionID,
IsEnd: false,
})
stream.Send(&pb.ChatMessageResponse{
SessionId: sessionID,
IsEnd: true,
})
sentEnd = true
return nil
}
fullResponse += content
// 发送 token 给客户端
if err := stream.Send(&pb.ChatMessageResponse{
Content: content,
SessionId: sessionID,
IsEnd: done,
}); err != nil {
logger.Logger.Error("Failed to send message to stream", zap.Error(err))
return err
}
if done {
sentEnd = true
}
}
// 9. 保存上下文
newHistory := append(history, model.Message{Role: "user", Content: message})
newHistory = append(newHistory, model.Message{Role: "assistant", Content: fullResponse})
p.memoryService.SaveContext(ctx, sessionID, newHistory, personaID)
// 10. 触发记忆提取每5轮
newTurns := len(newHistory) / 2
logger.Logger.Info("Memory extraction check",
zap.Int("message_count", len(newHistory)),
zap.Int("turns", newTurns),
zap.Bool("should_extract", newTurns >= 5),
)
if newTurns >= 5 {
logger.Logger.Info("Triggering memory extraction", zap.Int64("user_id", userID))
if err := p.memoryService.ExtractMemory(ctx, userID, newHistory); err != nil {
logger.Logger.Error("Failed to extract memory", zap.Error(err))
} else {
logger.Logger.Info("Memory extracted successfully", zap.Int64("user_id", userID))
}
}
logger.Logger.Info("SendMessage completed",
zap.Int64("user_id", userID),
zap.String("session_id", sessionID),
zap.Int("response_length", len(fullResponse)),
)
return nil
}
// GetHistory 获取对话历史
func (p *AIChatProvider) GetHistory(ctx context.Context, req *pb.ChatHistoryRequest) (*pb.ChatHistoryResponse, error) {
userID, starID, err := extractUserInfoFromDubboAttachments(ctx)
if err != nil {
logger.Logger.Error("Failed to extract user info from attachments",
zap.Error(err),
)
return nil, err
}
sessionID := req.SessionId
if sessionID == "" {
sessionID = fmt.Sprintf("%d_%d", userID, starID)
}
logger.Logger.Info("Received GetHistory request",
zap.Int64("user_id", userID),
zap.String("session_id", sessionID),
)
messages, err := p.memoryService.GetContext(ctx, sessionID)
if err != nil {
return nil, err
}
pbMessages := make([]*pb.Message, len(messages))
for i, m := range messages {
pbMessages[i] = &pb.Message{
Role: m.Role,
Content: m.Content,
}
}
return &pb.ChatHistoryResponse{
History: pbMessages,
}, nil
}
// GetPersonas 获取用户的所有人设
func (p *AIChatProvider) GetPersonas(ctx context.Context, req *pb.GetPersonasRequest) (*pb.PersonaListResponse, error) {
userID := req.UserId
logger.Logger.Info("Received GetPersonas request",
zap.Int64("user_id", userID),
)
personas, err := p.personaService.GetPersonas(ctx, userID)
if err != nil {
return nil, err
}
pbPersonas := make([]*pb.PersonaInfo, len(personas))
for i, persona := range personas {
pbPersonas[i] = &pb.PersonaInfo{
Id: persona.ID,
Name: persona.Name,
Description: persona.Description,
AvatarUrl: persona.AvatarURL,
TalkStyle: persona.TalkStyle,
IsDefault: persona.IsDefault,
CreatedAt: persona.CreatedAt,
UpdatedAt: persona.UpdatedAt,
}
}
return &pb.PersonaListResponse{
Personas: pbPersonas,
}, nil
}
// extractUserInfoFromDubboAttachments 从 Dubbo attachments 中提取用户信息
func extractUserInfoFromDubboAttachments(ctx context.Context) (int64, int64, error) {
logger.Logger.Debug("Extracting user info from Dubbo attachments",
zap.Any("context_type", fmt.Sprintf("%T", ctx)),
)
// Try to get any value from context
if attachments := ctx.Value(constant.AttachmentKey); attachments != nil {
logger.Logger.Debug("Found attachments via constant.AttachmentKey",
zap.Any("attachments", attachments),
zap.String("type", fmt.Sprintf("%T", attachments)),
)
if attMap, ok := attachments.(map[string]interface{}); ok {
logger.Logger.Debug("Attachments map content",
zap.Any("map", attMap),
)
userID := parseIntValue(attMap["user_id"])
starID := parseIntValue(attMap["star_id"])
logger.Logger.Debug("Parsed user info from attachments",
zap.Any("user_id_raw", attMap["user_id"]),
zap.Int64("user_id", userID),
zap.Any("star_id_raw", attMap["star_id"]),
zap.Int64("star_id", starID),
)
if userID > 0 && starID > 0 {
return userID, starID, nil
}
logger.Logger.Warn("Parsed user_id or star_id is zero",
zap.Int64("user_id", userID),
zap.Int64("star_id", starID),
)
} else {
logger.Logger.Warn("Attachments is not map[string]interface{}",
zap.String("actual_type", fmt.Sprintf("%T", attachments)),
)
}
} else {
logger.Logger.Warn("ctx.Value(constant.AttachmentKey) returned nil",
zap.String("constant_attachment_key", string(constant.AttachmentKey)),
)
}
// Debug: list all keys in context
logger.Logger.Warn("Checking alternative key: 'attachment'")
if val := ctx.Value("attachment"); val != nil {
logger.Logger.Debug("Found value with key 'attachment'",
zap.Any("value", val),
zap.String("type", fmt.Sprintf("%T", val)),
)
}
return 0, 0, fmt.Errorf("user info not found in Dubbo attachments")
}
// parseIntValue 解析各种类型的值为 int64
func parseIntValue(v interface{}) int64 {
switch val := v.(type) {
case int64:
return val
case int:
return int64(val)
case float64:
return int64(val)
case string:
var result int64
fmt.Sscanf(val, "%d", &result)
return result
case []string:
if len(val) > 0 {
var result int64
fmt.Sscanf(val[0], "%d", &result)
return result
}
case []interface{}:
if len(val) > 0 {
switch s := val[0].(type) {
case string:
var result int64
fmt.Sscanf(s, "%d", &result)
return result
case int:
return int64(s)
case int64:
return s
}
}
}
return 0
}

View File

@ -0,0 +1,60 @@
package repository
import (
"context"
"fmt"
"github.com/topfans/backend/services/aiChatService/model"
"gorm.io/gorm"
)
// ConfigRepository 配置仓库接口
type ConfigRepository interface {
Get(ctx context.Context, key string) (string, error)
GetByCategory(ctx context.Context, category string) (map[string]string, error)
Update(ctx context.Context, key, value string) error
}
// PostgreSQLConfigRepository PostgreSQL 配置仓库实现
type PostgreSQLConfigRepository struct {
db *gorm.DB
}
// NewPostgreSQLConfigRepository 创建配置仓库
func NewPostgreSQLConfigRepository(db *gorm.DB) *PostgreSQLConfigRepository {
return &PostgreSQLConfigRepository{db: db}
}
// Get 获取单个配置值
func (r *PostgreSQLConfigRepository) Get(ctx context.Context, key string) (string, error) {
var config model.Config
if err := r.db.WithContext(ctx).Where("config_key = ?", key).First(&config).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return "", model.ErrConfigNotFound
}
return "", fmt.Errorf("failed to get config: %w", err)
}
return config.ConfigValue, nil
}
// GetByCategory 按分类获取所有配置
func (r *PostgreSQLConfigRepository) GetByCategory(ctx context.Context, category string) (map[string]string, error) {
var configs []model.Config
if err := r.db.WithContext(ctx).Where("category = ?", category).Find(&configs).Error; err != nil {
return nil, fmt.Errorf("failed to get configs: %w", err)
}
result := make(map[string]string)
for _, c := range configs {
result[c.ConfigKey] = c.ConfigValue
}
return result, nil
}
// Update 更新配置值
func (r *PostgreSQLConfigRepository) Update(ctx context.Context, key, value string) error {
return r.db.WithContext(ctx).
Model(&model.Config{}).
Where("config_key = ?", key).
Update("config_value", value).Error
}

View File

@ -0,0 +1,156 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/lib/pq"
"github.com/redis/go-redis/v9"
"github.com/topfans/backend/services/aiChatService/model"
"gorm.io/gorm"
)
// ShortTermMemoryRepository 短期记忆仓库接口 (Redis)
type ShortTermMemoryRepository interface {
SaveContext(ctx context.Context, sessionID string, messages []model.Message, personaID string) error
GetContext(ctx context.Context, sessionID string) ([]model.Message, error)
DeleteContext(ctx context.Context, sessionID string) error
}
// LongTermMemoryRepository 长期记忆仓库接口 (PostgreSQL)
type LongTermMemoryRepository interface {
SaveMemory(ctx context.Context, memory *model.UserMemory) error
GetMemories(ctx context.Context, userID int64, keywords []string, limit int) ([]model.UserMemory, error)
GetMemoriesByUserID(ctx context.Context, userID int64) ([]model.UserMemory, error)
}
// MemoryRepository 记忆仓库接口(兼容性别名,用于不需要区分短期/长期的场景)
type MemoryRepository interface {
ShortTermMemoryRepository
LongTermMemoryRepository
}
// RedisMemoryRepository Redis 短期记忆实现
type RedisMemoryRepository struct {
client *redis.Client
ttl time.Duration
}
// NewRedisMemoryRepository 创建 Redis 短期记忆仓库
func NewRedisMemoryRepository(client *redis.Client, ttlSeconds int) *RedisMemoryRepository {
return &RedisMemoryRepository{
client: client,
ttl: time.Duration(ttlSeconds) * time.Second,
}
}
// contextKey 生成 Redis key
func contextKey(sessionID string) string {
return fmt.Sprintf("context:%s", sessionID)
}
// SaveContext 保存短期上下文到 Redis
func (r *RedisMemoryRepository) SaveContext(ctx context.Context, sessionID string, messages []model.Message, personaID string) error {
data := model.ChatContext{
SessionID: sessionID,
Messages: messages,
PersonaID: personaID,
UpdatedAt: time.Now(),
}
jsonData, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("failed to marshal context: %w", err)
}
return r.client.Set(ctx, contextKey(sessionID), jsonData, r.ttl).Err()
}
// GetContext 从 Redis 获取短期上下文
func (r *RedisMemoryRepository) GetContext(ctx context.Context, sessionID string) ([]model.Message, error) {
data, err := r.client.Get(ctx, contextKey(sessionID)).Bytes()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, fmt.Errorf("failed to get context: %w", err)
}
var chatCtx model.ChatContext
if err := json.Unmarshal(data, &chatCtx); err != nil {
return nil, fmt.Errorf("failed to unmarshal context: %w", err)
}
return chatCtx.Messages, nil
}
// DeleteContext 删除短期上下文
func (r *RedisMemoryRepository) DeleteContext(ctx context.Context, sessionID string) error {
return r.client.Del(ctx, contextKey(sessionID)).Err()
}
// PostgreSQLMemoryRepository PostgreSQL 长期记忆实现
type PostgreSQLMemoryRepository struct {
db *gorm.DB
}
// NewPostgreSQLMemoryRepository 创建 PostgreSQL 长期记忆仓库
func NewPostgreSQLMemoryRepository(db *gorm.DB) *PostgreSQLMemoryRepository {
return &PostgreSQLMemoryRepository{db: db}
}
// SaveMemory 保存长期记忆
func (r *PostgreSQLMemoryRepository) SaveMemory(ctx context.Context, memory *model.UserMemory) error {
now := time.Now().UnixMilli()
if memory.CreatedAt == 0 {
memory.CreatedAt = now
}
if memory.UpdatedAt == 0 {
memory.UpdatedAt = now
}
return r.db.WithContext(ctx).Exec(
`INSERT INTO "ai_user_memories" ("user_id", "content", "keywords", "weight", "is_core", "created_at", "updated_at")
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
memory.UserID,
memory.Content,
pq.Array(memory.Keywords),
memory.Weight,
memory.IsCore,
memory.CreatedAt,
memory.UpdatedAt,
).Error
}
// GetMemories 根据关键词查询记忆
func (r *PostgreSQLMemoryRepository) GetMemories(ctx context.Context, userID int64, keywords []string, limit int) ([]model.UserMemory, error) {
var memories []model.UserMemory
query := r.db.WithContext(ctx).
Where("user_id = ?", userID).
Order("weight DESC, created_at DESC").
Limit(limit)
if len(keywords) > 0 {
query = query.Where("keywords && $1", pq.Array(keywords))
}
if err := query.Find(&memories).Error; err != nil {
return nil, fmt.Errorf("failed to get memories: %w", err)
}
return memories, nil
}
// GetMemoriesByUserID 获取用户所有记忆
func (r *PostgreSQLMemoryRepository) GetMemoriesByUserID(ctx context.Context, userID int64) ([]model.UserMemory, error) {
var memories []model.UserMemory
if err := r.db.WithContext(ctx).
Where("user_id = ?", userID).
Order("weight DESC, created_at DESC").
Find(&memories).Error; err != nil {
return nil, fmt.Errorf("failed to get memories: %w", err)
}
return memories, nil
}

View File

@ -0,0 +1,122 @@
package repository
import (
"context"
"fmt"
"github.com/google/uuid"
"github.com/topfans/backend/services/aiChatService/model"
"gorm.io/gorm"
)
// PersonaRepository 人设仓库接口
type PersonaRepository interface {
Create(ctx context.Context, persona *model.Persona) error
GetByID(ctx context.Context, id uuid.UUID) (*model.Persona, error)
GetByUserID(ctx context.Context, userID int64) ([]model.Persona, error)
GetDefaultByUserID(ctx context.Context, userID int64) (*model.Persona, error)
Update(ctx context.Context, persona *model.Persona) error
Delete(ctx context.Context, id uuid.UUID) error
EnsureDefaultPersona(ctx context.Context, userID int64) (*model.Persona, error)
}
// DefaultSystemPrompt 默认系统提示词
const DefaultSystemPrompt = `你是一个温柔体贴的AI伴侣名字叫角角你善于倾听能理解用户的情绪
用温暖的话语陪伴用户说话风格亲切自然像朋友聊天一样
不要过于正式或说教当用户情绪低落时先给予共情和安慰`
// DefaultPersonaName 默认人设名称
const DefaultPersonaName = "角角"
// DefaultPersonaDescription 默认人设描述
const DefaultPersonaDescription = "温柔陪伴型闺蜜"
// PostgreSQLPersonaRepository PostgreSQL 人设仓库实现
type PostgreSQLPersonaRepository struct {
db *gorm.DB
}
// NewPostgreSQLPersonaRepository 创建人设仓库
func NewPostgreSQLPersonaRepository(db *gorm.DB) *PostgreSQLPersonaRepository {
return &PostgreSQLPersonaRepository{db: db}
}
// Create 创建人设
func (r *PostgreSQLPersonaRepository) Create(ctx context.Context, persona *model.Persona) error {
return r.db.WithContext(ctx).Create(persona).Error
}
// GetByID 根据 ID 获取人设
func (r *PostgreSQLPersonaRepository) GetByID(ctx context.Context, id uuid.UUID) (*model.Persona, error) {
var persona model.Persona
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&persona).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, model.ErrPersonaNotFound
}
return nil, fmt.Errorf("failed to get persona: %w", err)
}
return &persona, nil
}
// GetByUserID 获取用户的所有人设
func (r *PostgreSQLPersonaRepository) GetByUserID(ctx context.Context, userID int64) ([]model.Persona, error) {
var personas []model.Persona
if err := r.db.WithContext(ctx).
Where("user_id = ?", userID).
Order("created_at DESC").
Find(&personas).Error; err != nil {
return nil, fmt.Errorf("failed to get personas: %w", err)
}
return personas, nil
}
// GetDefaultByUserID 获取用户的默认人设
func (r *PostgreSQLPersonaRepository) GetDefaultByUserID(ctx context.Context, userID int64) (*model.Persona, error) {
var persona model.Persona
if err := r.db.WithContext(ctx).
Where("user_id = ? AND is_default = TRUE", userID).
First(&persona).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, model.ErrPersonaNotFound
}
return nil, fmt.Errorf("failed to get default persona: %w", err)
}
return &persona, nil
}
// Update 更新人设
func (r *PostgreSQLPersonaRepository) Update(ctx context.Context, persona *model.Persona) error {
return r.db.WithContext(ctx).Save(persona).Error
}
// Delete 删除人设
func (r *PostgreSQLPersonaRepository) Delete(ctx context.Context, id uuid.UUID) error {
return r.db.WithContext(ctx).Delete(&model.Persona{}, "id = ?", id).Error
}
// EnsureDefaultPersona 确保用户有默认人设
func (r *PostgreSQLPersonaRepository) EnsureDefaultPersona(ctx context.Context, userID int64) (*model.Persona, error) {
// 检查是否已有默认人设
persona, err := r.GetDefaultByUserID(ctx, userID)
if err == nil {
return persona, nil
}
if err != model.ErrPersonaNotFound {
return nil, err
}
// 创建默认人设
persona = &model.Persona{
UserID: userID,
Name: DefaultPersonaName,
Description: DefaultPersonaDescription,
SystemPrompt: DefaultSystemPrompt,
IsDefault: true,
}
if err := r.Create(ctx, persona); err != nil {
return nil, fmt.Errorf("failed to create default persona: %w", err)
}
return persona, nil
}

View File

@ -0,0 +1,59 @@
package service
import "strings"
// AuditService 审核服务
type AuditService struct {
// 敏感词列表
sensitiveWords []string
}
// NewAuditService 创建审核服务
func NewAuditService() *AuditService {
return &AuditService{
sensitiveWords: []string{
// 政治类
"台独", "港独", "藏独", "疆独", "分裂", "颠覆",
// 色情类
"色情", "裸聊", "约炮", "成人",
// 暴力类
"杀人", "虐待", "暴力",
// 违规诱导
"转账", "汇款", "银行卡", "密码",
// AI 身份冒充
"你是真人", "你是人类", "真人在吗",
},
}
}
// AuditText 审核文本内容
// 返回 true 表示通过false 表示违规
func (s *AuditService) AuditText(text string) bool {
for _, word := range s.sensitiveWords {
if strings.Contains(text, word) {
return false
}
}
return true
}
// AuditResponse 审核 AI 回复(逐 token
func (s *AuditService) AuditResponse(token string) bool {
// 累加 token 判断是否违规
for _, word := range s.sensitiveWords {
if strings.Contains(token, word) {
return false
}
}
return true
}
// DefaultSafeResponse 获取默认安全回复
func (s *AuditService) DefaultSafeResponse() string {
return "抱歉,这个话题我无法继续,我们换个话题聊聊吧。"
}
// AddSensitiveWord 添加敏感词
func (s *AuditService) AddSensitiveWord(word string) {
s.sensitiveWords = append(s.sensitiveWords, word)
}

View File

@ -0,0 +1,60 @@
package service
import (
"context"
"github.com/topfans/backend/services/aiChatService/model"
)
// ChatService 对话服务
type ChatService struct {
LLMService *LLMService
personaService *PersonaService
memoryService *MemoryService
auditService *AuditService
contextTTL int // 上下文过期时间(秒)
triggerTurns int // 触发记忆提取的轮数
}
// NewChatService 创建对话服务
func NewChatService(
llmService *LLMService,
personaService *PersonaService,
memoryService *MemoryService,
auditService *AuditService,
contextTTL int,
triggerTurns int,
) *ChatService {
return &ChatService{
LLMService: llmService,
personaService: personaService,
memoryService: memoryService,
auditService: auditService,
contextTTL: contextTTL,
triggerTurns: triggerTurns,
}
}
// GetHistory 获取对话历史
func (s *ChatService) GetHistory(ctx context.Context, sessionID string) ([]model.Message, error) {
return s.memoryService.GetContext(ctx, sessionID)
}
// SaveContext 保存对话上下文
func (s *ChatService) SaveContext(ctx context.Context, sessionID string, messages []model.Message, personaID string) error {
return s.memoryService.SaveContext(ctx, sessionID, messages, personaID)
}
// ExtractMemory 提取记忆
func (s *ChatService) ExtractMemory(ctx context.Context, userID int64, recentMessages []model.Message) error {
if ShouldExtractMemory(recentMessages, s.triggerTurns) {
return s.memoryService.ExtractMemory(ctx, userID, recentMessages)
}
return nil
}
// GetWelcomeMessage 获取欢迎消息
func (s *ChatService) GetWelcomeMessage(sessionID string, userID int64, starID int64) string {
// 默认欢迎消息
return "亲爱的你来辣 ~~"
}

View File

@ -0,0 +1,312 @@
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/topfans/backend/services/aiChatService/model"
)
// StreamReader 流式读取器接口
type StreamReader interface {
Next() (content string, done bool, err error)
Close() error
}
// SensitiveContentError 敏感内容错误
type SensitiveContentError struct {
Message string
}
func (e *SensitiveContentError) Error() string {
return e.Message
}
// LLMService 大模型服务
type LLMService struct {
miniMaxClient *http.Client
qwenClient *http.Client
miniMaxURL string
miniMaxKey string
miniMaxModel string
qwenURL string
qwenKey string
qwenModel string
failCount int
fallbackCount int
useBackup bool
mu sync.RWMutex
}
// MiniMaxMessage MiniMax 消息格式
type MiniMaxMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// MiniMaxRequest MiniMax 请求格式
type MiniMaxRequest struct {
Model string `json:"model"`
Messages []MiniMaxMessage `json:"messages"`
Stream bool `json:"stream"`
}
// MiniMaxResponse MiniMax 响应格式
type MiniMaxResponse struct {
Choices []struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
} `json:"choices"`
}
// MiniMaxSSERecord SSE 格式记录
type MiniMaxSSERecord struct {
ID string `json:"id"`
Object string `json:"object"`
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
Delta struct {
Content string `json:"content"`
} `json:"delta"`
} `json:"choices"`
}
// NewLLMService 创建大模型服务
func NewLLMService(miniMaxURL, miniMaxKey, miniMaxModel, qwenURL, qwenKey, qwenModel string) *LLMService {
return &LLMService{
miniMaxClient: &http.Client{Timeout: 60 * time.Second},
qwenClient: &http.Client{Timeout: 60 * time.Second},
miniMaxURL: miniMaxURL,
miniMaxKey: miniMaxKey,
miniMaxModel: miniMaxModel,
qwenURL: qwenURL,
qwenKey: qwenKey,
qwenModel: qwenModel,
}
}
// StreamChat 流式聊天
func (s *LLMService) StreamChat(ctx context.Context, messages []model.Message) (StreamReader, error) {
// Qwen fallback 已禁用
return s.streamChatMiniMax(ctx, messages)
}
// StreamChatWithBackup 禁用备用模型
func (s *LLMService) StreamChatWithBackup(ctx context.Context, messages []model.Message) (StreamReader, error) {
// Qwen fallback 已禁用,直接返回 MiniMax
return s.streamChatMiniMax(ctx, messages)
}
// streamChatMiniMax MiniMax 流式聊天
func (s *LLMService) streamChatMiniMax(ctx context.Context, messages []model.Message) (*MiniMaxStreamReader, error) {
reqMessages := make([]MiniMaxMessage, len(messages))
for i, m := range messages {
reqMessages[i] = MiniMaxMessage{
Role: m.Role,
Content: m.Content,
}
}
reqBody := MiniMaxRequest{
Model: s.miniMaxModel,
Messages: reqMessages,
Stream: true,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", s.miniMaxURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+s.miniMaxKey)
resp, err := s.miniMaxClient.Do(req)
if err != nil {
s.incrementFailCount()
return nil, fmt.Errorf("failed to call MiniMax: %w", err)
}
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
bodyStr := string(body)
// 检查是否是敏感内容错误 (1027)
if strings.Contains(bodyStr, "1027") || strings.Contains(bodyStr, "new_sensitive") {
return nil, &SensitiveContentError{Message: "content blocked by MiniMax safety filter"}
}
s.incrementFailCount()
// 如果是 401/403 等错误,切换到备用模型
if resp.StatusCode == 401 || resp.StatusCode == 403 {
s.SwitchToBackup()
}
return nil, fmt.Errorf("MiniMax API error: %d, body: %s", resp.StatusCode, bodyStr)
}
s.resetFailCount()
return &MiniMaxStreamReader{
reader: resp.Body,
decoder: NewSSEDecoder(resp.Body),
}, nil
}
// streamChatQwen 通义流式聊天
func (s *LLMService) streamChatQwen(ctx context.Context, messages []model.Message) (*MiniMaxStreamReader, error) {
reqMessages := make([]MiniMaxMessage, len(messages))
for i, m := range messages {
reqMessages[i] = MiniMaxMessage{
Role: m.Role,
Content: m.Content,
}
}
reqBody := MiniMaxRequest{
Model: s.qwenModel,
Messages: reqMessages,
Stream: true,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", s.qwenURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+s.qwenKey)
resp, err := s.qwenClient.Do(req)
if err != nil {
s.incrementFailCount()
return nil, fmt.Errorf("failed to call Qwen: %w", err)
}
if resp.StatusCode != http.StatusOK {
s.incrementFailCount()
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
return nil, fmt.Errorf("Qwen API error: %d, body: %s", resp.StatusCode, string(body))
}
s.resetFailCount()
return &MiniMaxStreamReader{
reader: resp.Body,
decoder: NewSSEDecoder(resp.Body),
}, nil
}
func (s *LLMService) incrementFailCount() {
s.mu.Lock()
s.failCount++
if s.failCount >= 3 {
s.useBackup = true
}
s.mu.Unlock()
}
func (s *LLMService) resetFailCount() {
s.mu.Lock()
s.failCount = 0
s.mu.Unlock()
}
func (s *LLMService) SwitchToBackup() {
s.mu.Lock()
s.useBackup = true
s.mu.Unlock()
}
// MiniMaxStreamReader MiniMax 流式读取器
type MiniMaxStreamReader struct {
reader io.ReadCloser
decoder *SSEDecoder
buffer string
}
func (r *MiniMaxStreamReader) Next() (content string, done bool, err error) {
for {
line, err := r.decoder.Next()
if err != nil {
r.Close()
return "", true, err
}
if line == "" {
continue
}
// SSE data: 格式
if strings.HasPrefix(line, "data: ") {
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
return "", true, nil
}
var record MiniMaxSSERecord
if err := json.Unmarshal([]byte(data), &record); err != nil {
continue
}
if len(record.Choices) > 0 && record.Choices[0].Delta.Content != "" {
return record.Choices[0].Delta.Content, false, nil
}
}
}
}
func (r *MiniMaxStreamReader) Close() error {
return r.reader.Close()
}
// SSEDecoder SSE 解码器
type SSEDecoder struct {
reader io.Reader
}
func NewSSEDecoder(reader io.Reader) *SSEDecoder {
return &SSEDecoder{reader: reader}
}
func (d *SSEDecoder) Next() (string, error) {
var line []byte
buf := make([]byte, 1)
for {
n, err := d.reader.Read(buf)
if err != nil {
return "", err
}
if n == 0 {
continue
}
if buf[0] == '\n' {
break
}
line = append(line, buf[0])
}
return strings.TrimRight(string(line), "\r"), nil
}

View File

@ -0,0 +1,211 @@
package service
import (
"context"
"strings"
"github.com/topfans/backend/pkg/logger"
"github.com/topfans/backend/services/aiChatService/model"
"github.com/topfans/backend/services/aiChatService/repository"
"go.uber.org/zap"
)
// MemoryService 记忆服务
type MemoryService struct {
shortTermRepo repository.ShortTermMemoryRepository
longTermRepo repository.LongTermMemoryRepository
}
// NewMemoryService 创建记忆服务
func NewMemoryService(shortTermRepo repository.ShortTermMemoryRepository, longTermRepo repository.LongTermMemoryRepository) *MemoryService {
return &MemoryService{
shortTermRepo: shortTermRepo,
longTermRepo: longTermRepo,
}
}
// SaveContext 保存短期上下文
func (s *MemoryService) SaveContext(ctx context.Context, sessionID string, messages []model.Message, personaID string) error {
return s.shortTermRepo.SaveContext(ctx, sessionID, messages, personaID)
}
// GetContext 获取短期上下文
func (s *MemoryService) GetContext(ctx context.Context, sessionID string) ([]model.Message, error) {
return s.shortTermRepo.GetContext(ctx, sessionID)
}
// RecallMemories 召回相关记忆
func (s *MemoryService) RecallMemories(ctx context.Context, userID int64, userInput string, limit int) (string, error) {
// 从用户输入提取关键词
keywords := extractKeywords(userInput)
// 查询长期记忆
memories, err := s.longTermRepo.GetMemories(ctx, userID, keywords, limit)
if err != nil {
return "", err
}
if len(memories) == 0 {
return "", nil
}
// 组装记忆文本
var builder strings.Builder
builder.WriteString("# 用户核心记忆\n")
for _, m := range memories {
builder.WriteString("- ")
builder.WriteString(m.Content)
builder.WriteString("\n")
}
return builder.String(), nil
}
// ExtractMemory 提取记忆每5轮对话触发一次
func (s *MemoryService) ExtractMemory(ctx context.Context, userID int64, recentMessages []model.Message) error {
// 从最近的用户消息中提取关键词
var userMessages []string
for i := len(recentMessages) - 1; i >= 0 && len(userMessages) < 5; i-- {
if recentMessages[i].Role == "user" {
userMessages = append(userMessages, recentMessages[i].Content)
}
}
if len(userMessages) == 0 {
return nil
}
// 简单关键词提取
keywords := extractKeywordsFromMessages(userMessages)
// 生成记忆摘要
content := summarizeMessages(userMessages)
if content == "" {
return nil
}
// 保存到长期记忆
memory := &model.UserMemory{
UserID: userID,
Content: content,
Keywords: keywords,
Weight: 50,
}
return s.longTermRepo.SaveMemory(ctx, memory)
}
// GetUserMemories 获取用户所有长期记忆
func (s *MemoryService) GetUserMemories(ctx context.Context, userID int64) ([]model.UserMemory, error) {
return s.longTermRepo.GetMemoriesByUserID(ctx, userID)
}
// extractKeywords 从用户输入提取关键词
func extractKeywords(text string) []string {
// 简单实现按标点符号分割检查长度大于2的词
var keywords []string
words := strings.FieldsFunc(text, func(r rune) bool {
return r == ',' || r == '.' || r == '!' || r == '?' || r == '' || r == '。' || r == '' || r == '' || r == ' ' || r == '\n'
})
for _, word := range words {
word = strings.TrimSpace(word)
if len(word) >= 2 {
keywords = append(keywords, word)
}
}
return keywords
}
// extractKeywordsFromMessages 从多条消息提取关键词
func extractKeywordsFromMessages(messages []string) []string {
var allKeywords []string
seen := make(map[string]bool)
for _, msg := range messages {
keywords := extractKeywords(msg)
for _, k := range keywords {
if !seen[k] {
seen[k] = true
allKeywords = append(allKeywords, k)
}
}
}
// 只保留前5个关键词
if len(allKeywords) > 5 {
allKeywords = allKeywords[:5]
}
return allKeywords
}
// summarizeMessages 生成记忆摘要
func summarizeMessages(messages []string) string {
if len(messages) == 0 {
return ""
}
// 简单实现拼接前3条消息的核心内容
var summary strings.Builder
count := 0
for _, msg := range messages {
if count >= 3 {
break
}
// 截断过长的消息
if len(msg) > 100 {
msg = msg[:100] + "..."
}
summary.WriteString(msg)
summary.WriteString("; ")
count++
}
result := strings.TrimSpace(summary.String())
if len(result) > 200 {
result = result[:200] + "..."
}
return result
}
// FormatMemoriesForPrompt 将记忆格式化为 prompt 片段
func FormatMemoriesForPrompt(memories []model.UserMemory) string {
if len(memories) == 0 {
return ""
}
var builder strings.Builder
builder.WriteString("# 用户核心记忆\n")
for _, m := range memories {
builder.WriteString("- ")
builder.WriteString(m.Content)
builder.WriteString("\n")
}
return builder.String()
}
// GetTurnCount 计算对话轮数
func GetTurnCount(messages []model.Message) int {
turns := 0
for i := 0; i < len(messages)-1; i += 2 {
if i+1 < len(messages) && messages[i].Role == "user" && messages[i+1].Role == "assistant" {
turns++
}
}
return turns
}
// ShouldExtractMemory 判断是否应该触发记忆提取
func ShouldExtractMemory(messages []model.Message, triggerTurns int) bool {
turns := GetTurnCount(messages)
logger.Logger.Info("ShouldExtractMemory check",
zap.Int("turns", turns),
zap.Int("trigger_turns", triggerTurns),
zap.Bool("should_extract", turns >= triggerTurns),
)
return turns >= triggerTurns
}

View File

@ -0,0 +1,68 @@
package service
import (
"context"
"github.com/google/uuid"
"github.com/topfans/backend/services/aiChatService/model"
"github.com/topfans/backend/services/aiChatService/repository"
)
// PersonaService 人设服务
type PersonaService struct {
personaRepo repository.PersonaRepository
}
// NewPersonaService 创建人设服务
func NewPersonaService(personaRepo repository.PersonaRepository) *PersonaService {
return &PersonaService{
personaRepo: personaRepo,
}
}
// GetPersonas 获取用户的所有人设
func (s *PersonaService) GetPersonas(ctx context.Context, userID int64) ([]model.PersonaInfo, error) {
personas, err := s.personaRepo.GetByUserID(ctx, userID)
if err != nil {
return nil, err
}
infos := make([]model.PersonaInfo, len(personas))
for i, p := range personas {
infos[i] = p.ToPersonaInfo()
}
return infos, nil
}
// GetPersona 获取指定人设
func (s *PersonaService) GetPersona(ctx context.Context, personaID string) (*model.Persona, error) {
id, err := uuid.Parse(personaID)
if err != nil {
return nil, model.ErrInvalidRequest
}
return s.personaRepo.GetByID(ctx, id)
}
// GetDefaultPersona 获取用户默认人设
func (s *PersonaService) GetDefaultPersona(ctx context.Context, userID int64) (*model.Persona, error) {
return s.personaRepo.GetDefaultByUserID(ctx, userID)
}
// GetPersonaOrDefault 获取人设,如果 personaID 为空则获取默认人设
func (s *PersonaService) GetPersonaOrDefault(ctx context.Context, userID int64, personaID string) (*model.Persona, error) {
// 如果指定了 personaID优先使用指定人设
if personaID != "" {
id, err := uuid.Parse(personaID)
if err != nil {
return nil, model.ErrInvalidRequest
}
persona, err := s.personaRepo.GetByID(ctx, id)
if err == nil {
return persona, nil
}
// 如果人设不存在fallback 到默认人设
}
// 获取或创建默认人设
return s.personaRepo.EnsureDefaultPersona(ctx, userID)
}

View File

@ -0,0 +1,180 @@
package service
import (
"github.com/topfans/backend/services/aiChatService/model"
)
// Token 限制配置
const (
MaxTotalTokens = 32000 // 总 Token 上限
MaxHistoryTokens = 24000 // 对话历史最大 Token
MaxSystemTokens = 4000 // System Prompt 最大 Token
MaxMemoryTokens = 2000 // 记忆召回最大 Token
ReservedTokens = 2000 // 保留空间
MinHistoryMessages = 4 // 最少保留消息对数
MaxSingleMessageTokens = 4000 // 单条消息最大 Token
)
// Tokenizer Token 计算器
type Tokenizer struct{}
// EstimateTokens 估算 Token 数量
func (t *Tokenizer) EstimateTokens(text string) int {
var count int
for _, r := range text {
switch {
case r >= 0x4e00 && r <= 0x9fff: // 中文
count += 2
case r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z': // 英文
count += 1
case r >= '0' && r <= '9':
count += 1
case r < 128: // ASCII 符号
count += 1
default:
count += 2
}
}
return count
}
// EstimateMessagesTokens 估算消息列表的总 Token
func (t *Tokenizer) EstimateMessagesTokens(messages []model.Message) int {
var total int
for _, m := range messages {
total += t.EstimateTokens(m.Role) + t.EstimateTokens(m.Content) + 10
}
return total
}
// BuildPrompt 组装 Prompt
func BuildPrompt(
systemPrompt string,
userCoreInfo string,
history []model.Message,
userInput string,
tokenizer *Tokenizer,
) ([]model.Message, int) {
// 1. 计算各部分 Token
systemTokens := tokenizer.EstimateTokens(systemPrompt)
memoryTokens := tokenizer.EstimateTokens(userCoreInfo)
// 2. 预留空间计算
reserved := ReservedTokens
if systemTokens > MaxSystemTokens {
reserved += systemTokens - MaxSystemTokens
}
// 3. 计算可用于对话历史的 Token
availableTokens := MaxHistoryTokens - memoryTokens - reserved
if availableTokens < 500 {
availableTokens = 500
}
// 4. 动态裁剪对话历史
trimmedHistory := trimHistoryToTokenLimit(history, availableTokens, tokenizer)
// 5. 组装最终消息
messages := []model.Message{
{Role: "system", Content: systemPrompt},
}
if userCoreInfo != "" {
messages = append(messages, model.Message{
Role: "system",
Content: "# 用户核心记忆\n" + userCoreInfo,
})
}
messages = append(messages, trimmedHistory...)
messages = append(messages, model.Message{Role: "user", Content: userInput})
// 6. 最终 Token 统计
totalTokens := tokenizer.EstimateMessagesTokens(messages)
return messages, totalTokens
}
// trimHistoryToTokenLimit 裁剪历史消息至 Token 限制内
func trimHistoryToTokenLimit(history []model.Message, maxTokens int, tokenizer *Tokenizer) []model.Message {
if len(history) == 0 {
return history
}
currentTokens := tokenizer.EstimateMessagesTokens(history)
if currentTokens <= maxTokens {
return history
}
result := make([]model.Message, 0)
var usedTokens int
// 从最新开始保留
for i := len(history) - 1; i >= 0; i -= 2 {
msgToken := tokenizer.EstimateTokens(history[i].Content) + 10
prevToken := 0
if i > 0 {
prevToken = tokenizer.EstimateTokens(history[i-1].Content) + 10
}
pairTokens := msgToken + prevToken
if len(result)/2 >= MinHistoryMessages && usedTokens+pairTokens > maxTokens {
break
}
if i > 0 {
result = append([]model.Message{history[i-1], history[i]}, result...)
} else {
result = append([]model.Message{history[i]}, result...)
}
usedTokens += pairTokens
}
return result
}
// truncateMessageIfNeeded 截断超长单条消息
func truncateMessageIfNeeded(content string, maxTokens int, tokenizer *Tokenizer) string {
if tokenizer.EstimateTokens(content) <= maxTokens {
return content
}
runes := []rune(content)
lo, hi := 0, len(runes)
for lo < hi {
mid := (lo + hi + 1) / 2
if tokenizer.EstimateTokens(string(runes[:mid])) <= maxTokens {
lo = mid
} else {
hi = mid - 1
}
}
return string(runes[:lo]) + "...(已截断)"
}
// EstimateTurnCount 估算对话轮数
func EstimateTurnCount(messages []model.Message) int {
turns := 0
for i := 0; i < len(messages)-1; i += 2 {
if messages[i].Role == "user" && messages[i+1].Role == "assistant" {
turns++
}
}
return turns
}
// IsNoNeedLLMCall 判断是否不需要调用大模型
func IsNoNeedLLMCall(input string) bool {
// 纯符号或数字
symbolOnly := true
for _, r := range input {
if r != ' ' && r != '\n' && (r < '0' || r > '9') && r != '?' && r != '' && r != '.' && r != '。' {
symbolOnly = false
break
}
}
return symbolOnly
}

View File

@ -1,13 +1,33 @@
<script>
import { getGlobalSocket } from '@/utils/socket'
export default {
onLaunch: function() {
console.log('App Launch')
// AI Chat
},
onShow: function() {
console.log('App Show')
},
onHide: function() {
console.log('App Hide')
// WebSocket
this.closeWebSocket()
},
methods: {
initWebSocket() {
const token = uni.getStorageSync('access_token')
if (token) {
console.log('初始化全局 WebSocket 连接')
const globalSocket = getGlobalSocket()
globalSocket.init(token)
}
},
closeWebSocket() {
console.log('关闭全局 WebSocket 连接')
const globalSocket = getGlobalSocket()
globalSocket.closeAll()
}
}
}
</script>

View File

@ -39,11 +39,11 @@
</view>
<!-- AI角色对话气泡 -->
<view class="dialog-area">
<view class="dialog-bubble">
<view class="dialog-area" v-if="aiMessage">
<!-- <view class="dialog-bubble">
<image class="bubble-bg" src="/static/AIimg/duihuakuang.png" mode="widthFix" style="width: 320rpx;" />
<text class="bubble-text">亲爱的你来辣 ~~</text>
</view>
<text class="bubble-text">{{ aiMessage }}</text>
</view> -->
</view>
<!-- AI角色毛绒小怪兽占位区域 -->
@ -54,72 +54,305 @@
</view> -->
</view>
<!-- 聊天消息区域 -->
<scroll-view class="chat-messages" scroll-y :scroll-top="scrollTop" :scroll-into-view="scrollIntoViewId"
:show-scrollbar="false" scroll-with-animation>
<view v-for="(msg, index) in messages" :key="index" :class="['message-item', msg.role]">
<view v-if="msg.role === 'assistant'" class="message-bubble assistant-bubble">
<text class="message-text">{{ msg.content }}</text>
</view>
<view v-else class="message-bubble user-bubble">
<text class="message-text">{{ msg.content }}</text>
</view>
</view>
<view v-if="isTyping" class="message-item assistant">
<view class="message-bubble assistant-bubble">
<text class="message-text typing">...</text>
</view>
</view>
<view id="scroll-bottom-view"></view>
</scroll-view>
<!-- 底部输入框 -->
<view class="bottom-bar">
<view class="input-wrapper">
<input
class="chat-input"
v-model="inputText"
placeholder="发送消息给角角"
placeholder-class="input-placeholder"
confirm-type="send"
@confirm="handleSend"
/>
<view
class="send-btn"
:class="{ 'send-btn-disabled': !inputText.trim() }"
@click="inputText.trim() && handleSend()"
>
<input class="chat-input" v-model="inputText" placeholder="发送消息给角角"
placeholder-class="input-placeholder" confirm-type="send" @confirm="handleSend"
:disabled="isTyping" />
<view class="send-btn" :class="{ 'send-btn-disabled': !inputText.trim() || isTyping }"
@click="inputText.trim() && !isTyping && handleSend()">
<text class="send-icon">发送</text>
</view>
</view>
<!-- <view class="add-btn" @click="handleAdd">
<text class="add-icon">+</text>
</view> -->
</view>
</view>
</template>
<script setup>
import { ref } from 'vue';
<script>
import { getAiChatSocket, closeAiChatSocket, isAiChatClosing, resetAiChatClosing } from '@/utils/socket'
const inputText = ref('');
export default {
data() {
return {
inputText: '',
messages: [],
aiMessage: '亲爱的你来辣 ~~',
isTyping: false,
scrollTop: 0,
scrollIntoViewId: '',
sessionId: '',
backIconColor: '#fff',
currentAssistantMessage: '',
streamTimeout: null,
streamTimeoutMs: 15000, // 15
isErrorProcessing: false //
}
},
onLoad() {
this.initSession()
this.initWebSocket()
},
onReady() {
console.log('[AI Chat] onReady called')
},
onUnload() {
this.clearStreamTimeout()
this.closeWebSocket()
},
methods: {
initSession() {
// sessionId: userId_starId
const userRaw = uni.getStorageSync('user')
const starId = uni.getStorageSync('star_id')
const handleClose = () => {
//
const pages = getCurrentPages();
if (pages.length > 1) {
//
uni.navigateBack();
} else {
// square
uni.reLaunch({
url: '/pages/square/square'
});
// user JSON
let user = userRaw
if (typeof user === 'string') {
try {
user = JSON.parse(user)
} catch (e) {
console.error('[AI Chat] Failed to parse user:', e)
}
}
const uid = user?.uid || user?.["uid"]
if (uid && starId) {
this.sessionId = uid + '_' + starId
} else {
this.sessionId = 'temp_' + Date.now()
}
},
initWebSocket() {
const token = uni.getStorageSync('access_token')
if (!token) {
uni.showToast({ title: '请先登录', icon: 'none' })
return
}
const socket = getAiChatSocket()
console.log('[AI Chat] initWebSocket called, socket already connected:', socket.isConnected)
// -
socket.setOnMessageCallback((data) => {
console.log('[AI Chat] Message received:', data.type, data.session_id)
// session_id session_id
if (data.session_id === this.sessionId || !data.session_id) {
this.handleStreamMessage(data)
} else {
console.log('[AI Chat] Session ID mismatch:', data.session_id, '!=', this.sessionId)
}
})
//
socket.setOnErrorCallback((data) => {
console.error('AI Chat error:', data)
//
if (this.isErrorProcessing) {
return
}
this.isErrorProcessing = true
this.isTyping = false
uni.showToast({ title: data.message || data.error || '发生错误', icon: 'none' })
this.aiMessage = '亲爱的你来辣 ~~'
//
setTimeout(() => {
this.isErrorProcessing = false
}, 1000)
})
//
socket.setOnConnectCallback(() => {
console.log('[AI Chat] Connected callback triggered')
socket.initSession(this.sessionId)
this.loadHistory()
})
//
resetAiChatClosing()
//
socket.connect(token)
},
handleStreamMessage(data) {
//
if (data.error) {
console.log('[AI Chat] Error detected, setting isTyping=false')
this.clearStreamTimeout()
//
if (this.isErrorProcessing) {
return
}
this.isErrorProcessing = true
this.isTyping = false
this.currentAssistantMessage = ''
uni.showToast({ title: data.content || data.error || '发生错误', icon: 'none' })
this.aiMessage = '亲爱的你来辣 ~~'
setTimeout(() => {
this.isErrorProcessing = false
}, 1000)
return
}
if (data.is_end) {
//
this.clearStreamTimeout()
//
this.isTyping = false
// messages
if (this.currentAssistantMessage) {
this.messages.push({
role: 'assistant',
content: this.currentAssistantMessage
})
this.currentAssistantMessage = ''
} else if (data.content) {
// init_session
this.messages.push({
role: 'assistant',
content: data.content
})
}
this.aiMessage = '亲爱的你来辣 ~~'
this.scrollToBottom()
} else {
//
if (!this.currentAssistantMessage) {
this.currentAssistantMessage = ''
}
this.currentAssistantMessage += data.content
//
this.aiMessage = this.currentAssistantMessage
}
},
loadHistory() {
const token = uni.getStorageSync('access_token')
if (!token) return
const socket = getAiChatSocket()
// -
socket.setOnHistoryCallback((data) => {
if ((data.session_id === this.sessionId || !data.session_id) && data.history) {
this.messages = data.history
}
this.$nextTick(() => {
this.scrollToBottom()
})
})
socket.getHistory(this.sessionId, 20)
},
handleSend() {
if (!this.inputText.trim() || this.isTyping) return
const message = this.inputText.trim()
const token = uni.getStorageSync('access_token')
if (!token) {
uni.showToast({ title: '请先登录', icon: 'none' })
return
}
//
this.messages.push({
role: 'user',
content: message
})
this.inputText = ''
this.scrollToBottom()
//
const socket = getAiChatSocket()
this.isTyping = true
this.currentAssistantMessage = ''
//
this.startStreamTimeout()
socket.sendMessage(message, this.sessionId)
},
startStreamTimeout() {
this.clearStreamTimeout()
this.streamTimeout = setTimeout(() => {
console.log('[AI Chat] Stream timeout, using fallback response')
this.isTyping = false
if (this.currentAssistantMessage) {
this.messages.push({
role: 'assistant',
content: this.currentAssistantMessage
})
} else {
//
this.messages.push({
role: 'assistant',
content: '抱歉,服务响应有点慢,请稍后再试~'
})
}
this.currentAssistantMessage = ''
this.aiMessage = '亲爱的你来辣 ~~'
this.scrollToBottom()
}, this.streamTimeoutMs)
},
clearStreamTimeout() {
if (this.streamTimeout) {
clearTimeout(this.streamTimeout)
this.streamTimeout = null
}
},
scrollToBottom() {
this.scrollIntoViewId = ''
this.$nextTick(() => {
this.scrollIntoViewId = 'scroll-bottom-view'
})
},
closeWebSocket() {
console.log('[AI Chat] Closing WebSocket connection')
//
const socket = getAiChatSocket()
socket.setClosing(true)
closeAiChatSocket()
},
handleClose() {
const pages = getCurrentPages()
if (pages.length > 1) {
uni.navigateBack()
} else {
uni.reLaunch({
url: '/pages/square/square'
})
}
},
handleDressup() {
uni.showToast({ title: '装扮功能开发中', icon: 'none' })
},
handleScene() {
uni.showToast({ title: '场景功能开发中', icon: 'none' })
},
handleHistory() {
uni.showToast({ title: '追星历程开发中', icon: 'none' })
}
}
};
const handleDressup = () => {
uni.showToast({ title: '装扮功能开发中', icon: 'none' });
};
const handleScene = () => {
uni.showToast({ title: '场景功能开发中', icon: 'none' });
};
const handleHistory = () => {
uni.showToast({ title: '追星历程开发中', icon: 'none' });
};
const handleSend = () => {
if (!inputText.value.trim()) return;
uni.showToast({ title: `发送:${inputText.value}`, icon: 'none' });
inputText.value = '';
};
const handleAdd = () => {
uni.showToast({ title: '更多功能开发中', icon: 'none' });
};
}
</script>
<style scoped>
@ -144,12 +377,15 @@ const handleAdd = () => {
/* 关闭按钮 */
.close-btn {
position: absolute;
top: 80rpx;
left: 24rpx;
width: 80rpx;
height: 80rpx;
display: flex;
align-items: center;
justify-content: center;
z-index: 50;
}
.nav-back {
@ -302,6 +538,71 @@ const handleAdd = () => {
text-shadow: 0 2rpx 8rpx rgba(0, 0, 0, 0.5);
}
/* 聊天消息区域 */
.chat-messages {
position: absolute;
bottom: 192rpx;
left: 0;
right: 0;
height: 544rpx;
padding: 20rpx;
z-index: 15;
box-sizing: border-box;
}
.message-item {
display: flex;
margin-bottom: 20rpx;
}
.message-item.user {
justify-content: flex-end;
}
.message-item.assistant {
justify-content: flex-start;
}
.message-bubble {
max-width: 70%;
padding: 20rpx;
border-radius: 20rpx;
font-size: 28rpx;
line-height: 1.4;
}
.user-bubble {
background: linear-gradient(135deg, #ff9de2, #c97bff);
color: #fff;
border-bottom-right-radius: 8rpx;
}
.assistant-bubble {
background: rgba(255, 255, 255, 0.9);
color: #5a3060;
border-bottom-left-radius: 8rpx;
}
.message-text {
word-break: break-all;
}
.typing {
animation: blink 1s infinite;
}
@keyframes blink {
0%,
100% {
opacity: 1;
}
50% {
opacity: 0.3;
}
}
/* 角色区域 */
.character-area {
position: absolute;

View File

@ -5,31 +5,60 @@ const DEV_BASE = 'http://192.168.110.60:8080' // 开发环境
const PROD_BASE = 'http://101.132.250.62:8080' // 生产环境
const HEALTH_URL = DEV_BASE + '/health'
// 默认使用生产地址
let baseURL = PROD_BASE
// 启动时探测开发环境是否可用(异步,不阻塞后续逻辑)
uni.request({
url: HEALTH_URL,
method: 'GET',
timeout: 2000,
success: (res) => {
if (res.statusCode === 200) {
baseURL = DEV_BASE // 开发环境可用,切换到开发地址
console.log('[API] 使用开发环境地址:', DEV_BASE)
}
},
fail: () => {
console.log('[API] 开发环境不可用,使用生产环境地址:', PROD_BASE)
}
})
// 是否使用模拟数据(开发调试时设为 true后端API准备好后改为 false
const USE_MOCK_API = false
// 环境检测状态0=检测中, 1=开发环境, 2=生产环境
let envStatus = 0
let baseURL = PROD_BASE
// 环境检测 Promise确保 getApiBaseUrl / getWebSocketBaseUrl 等待检测完成
const envReadyPromise = new Promise((resolve) => {
uni.request({
url: HEALTH_URL,
method: 'GET',
timeout: 2000,
success: (res) => {
if (res.statusCode === 200) {
baseURL = DEV_BASE
envStatus = 1
console.log('[API] 使用开发环境地址:', DEV_BASE)
} else {
envStatus = 2
console.log('[API] 开发环境返回非200使用生产环境地址:', PROD_BASE)
}
resolve(envStatus)
},
fail: () => {
envStatus = 2
console.log('[API] 开发环境不可用,使用生产环境地址:', PROD_BASE)
resolve(envStatus)
}
})
})
/** 等待环境检测完成(返回 'dev' | 'prod' */
export async function waitForEnvReady() {
await envReadyPromise
return envStatus === 1 ? 'dev' : 'prod'
}
/** 网关根地址(供 uni.uploadFile 等无法走 request 封装的场景拼接完整 URL */
export function getApiBaseUrl() {
return String(baseURL).replace(/\/+$/, '')
export async function getApiBaseUrl() {
await envReadyPromise
return String(baseURL).replace(/\/+$/, '')
}
/** 获取 WebSocket 基础地址(将 http:// 替换为 ws:// */
export async function getWebSocketBaseUrl() {
await envReadyPromise
const httpUrl = String(baseURL).replace(/\/+$/, '')
return httpUrl.replace(/^http:/, 'ws:').replace(/^https:/, 'wss:')
}
// 兼容旧代码:同步版本在环境检测完成前返回默认值(生产地址)
export function getApiBaseUrlSync() {
return String(baseURL).replace(/\/+$/, '')
}
// 模拟网络延迟

View File

@ -0,0 +1,187 @@
import SocketManager from './SocketManager'
/**
* AI Chat WebSocket 连接
* 继承自 SocketManager添加 AI Chat 特定的业务逻辑
*/
class AiChatSocket extends SocketManager {
constructor() {
super({
serviceName: 'AiChat',
baseUrl: 'ws://gateway:8080',
path: '/ws/ai-chat',
reconnectInterval: 3000,
heartbeatInterval: 30000,
maxReconnectAttempts: 5
})
// AI Chat 特定回调
this.onMessageCallback = null
this.onHistoryCallback = null
this.onPersonasCallback = null
this.onErrorCallback = null
this.onConnectCallback = null
// 注册 AI Chat 特定的消息处理器
this._registerAiChatHandlers()
}
setClosing(val) {
this.isClosing = val
}
_registerAiChatHandlers() {
// 流式消息
this.registerHandler('message', (data) => {
if (this.onMessageCallback) {
this.onMessageCallback(data)
}
})
// 历史消息响应
this.registerHandler('history_response', (data) => {
if (this.onHistoryCallback) {
this.onHistoryCallback(data)
}
})
// 人设列表响应
this.registerHandler('personas_response', (data) => {
if (this.onPersonasCallback) {
this.onPersonasCallback(data)
}
})
}
/**
* 连接到 AI Chat 服务
*/
async connect(token) {
await super.connect(token, '/ws/ai-chat')
}
/**
* 发送消息
*/
sendMessage(message, sessionId, personaId = '') {
return this.send({
action: 'send_message',
session_id: sessionId,
message: message,
persona_id: personaId
})
}
/**
* 获取历史记录
*/
getHistory(sessionId, limit = 20) {
return this.send({
action: 'get_history',
session_id: sessionId,
limit: limit
})
}
/**
* 初始化会话
*/
initSession(sessionId) {
return this.send({
action: 'init_session',
session_id: sessionId
})
}
/**
* 获取人设列表
*/
getPersonas() {
return this.send({
action: 'get_personas'
})
}
/**
* 设置消息回调
*/
setOnMessageCallback(callback) {
this.onMessageCallback = callback
this.off('message', this._messageHandler)
this._messageHandler = (data) => {
if (this.onMessageCallback) this.onMessageCallback(data)
}
this.on('message', this._messageHandler)
}
/**
* 设置历史记录回调
*/
setOnHistoryCallback(callback) {
this.onHistoryCallback = callback
}
/**
* 设置人设列表回调
*/
setOnPersonasCallback(callback) {
this.onPersonasCallback = callback
}
/**
* 设置错误回调
*/
setOnErrorCallback(callback) {
this.onErrorCallback = callback
this.off('error', this._errorHandler)
this._errorHandler = (data) => {
if (callback) callback(data)
}
this.on('error', this._errorHandler)
}
/**
* 设置连接成功回调
*/
setOnConnectCallback(callback) {
// 移除旧的 connect listener防止重复注册
if (this._connectHandler) {
this.off('connect', this._connectHandler)
}
this.onConnectCallback = callback
this._connectHandler = () => {
if (callback) callback()
}
this.on('connect', this._connectHandler)
}
}
// 单例模式
let aiChatInstance = null
let closing = false // 标记是否主动关闭
export function getAiChatSocket() {
if (!aiChatInstance) {
aiChatInstance = new AiChatSocket()
}
return aiChatInstance
}
export function closeAiChatSocket() {
if (aiChatInstance) {
closing = true
aiChatInstance.close()
// 不再设为 null保持实例以防止旧定时器创建新连接
// 下次 getAiChatSocket() 继续使用同一实例
}
}
export function isAiChatClosing() {
return closing
}
export function resetAiChatClosing() {
closing = false
}
export default AiChatSocket

View File

@ -0,0 +1,50 @@
import { getAiChatSocket } from './AiChatSocket'
/**
* 全局 WebSocket 管理器
* 统一管理多个服务的 WebSocket 连接
*/
class GlobalSocketManager {
constructor() {
this.sockets = {} // serviceName -> SocketManager
this.token = null
this.isAllConnected = false
}
/**
* 初始化所有连接
*/
init(token) {
this.token = token
this._initAiChat()
// Future: this._initNotification()
}
async _initAiChat() {
const aiChat = getAiChatSocket()
aiChat.on('connect', () => console.log('AI Chat connected'))
aiChat.on('error', (err) => console.error('AI Chat error:', err))
await aiChat.connect(this.token)
this.sockets['ai_chat'] = aiChat
}
getSocket(serviceName) {
return this.sockets[serviceName]
}
closeAll() {
Object.values(this.sockets).forEach(socket => socket.close())
this.sockets = {}
}
}
let globalInstance = null
export function getGlobalSocket() {
if (!globalInstance) {
globalInstance = new GlobalSocketManager()
}
return globalInstance
}
export default GlobalSocketManager

View File

@ -0,0 +1,345 @@
import { getWebSocketBaseUrl } from '../api'
/**
* WebSocket 管理器
* 支持多个 WebSocket 连接自动管理鉴权心跳重连
*/
class SocketManager {
constructor(options = {}) {
this.serviceName = options.serviceName || 'unknown'
this.baseUrl = 'ws://127.0.0.1:8080' // 占位符,等 connect() 时再获取真实地址
this.token = null
this.socket = null
this.heartbeatTimer = null
this.reconnectTimer = null
this.reconnectInterval = options.reconnectInterval || 3000
this.heartbeatInterval = options.heartbeatInterval || 30000
this.maxReconnectAttempts = options.maxReconnectAttempts || 5
this.reconnectAttempts = 0
// 状态
this.isConnected = false
this.isAuthed = false
this.isClosing = false // 标记是否主动关闭
// 事件处理器
this.eventHandlers = {
'connect': [],
'disconnect': [],
'auth_success': [],
'auth_fail': [],
'error': [],
'message': [] // 通用消息处理
}
// 子类可覆盖的消息类型处理
this.messageHandlers = {}
}
/**
* 连接到 WebSocket 服务器
*/
async connect(token, path) {
this.token = token
this.path = path
this.reconnectAttempts = 0
this.isClosing = false // 重置关闭状态,允许重连
// 异步获取真实 WebSocket 地址(等待环境检测完成)
this.baseUrl = await getWebSocketBaseUrl()
console.log(`[${this.serviceName}] WebSocket base URL: ${this.baseUrl}`)
this._doConnect()
}
_doConnect() {
// 如果已有连接且已连接,不重复连接
if (this.socket && this.isConnected) {
console.log(`[${this.serviceName}] Already connected (${this.isConnected}), skip reconnect`)
return
}
console.log(`[${this.serviceName}] _doConnect called, clearing old socket`)
// 清理旧连接
if (this.socket) {
this.socket = null
}
this.isConnected = false
const url = `${this.baseUrl}${this.path}?token=Bearer_${this.token}`
console.log(`[${this.serviceName}] Connecting to ${url}`)
// UniApp: connectSocket 是异步的,需要处理错误
try {
this.socket = uni.connectSocket({
url,
fail: (err) => {
console.error(`[${this.serviceName}] connectSocket fail:`, err)
this._emit('error', { code: 'CONNECT_FAILED', message: err.errMsg || '连接失败' })
}
})
if (!this.socket) {
console.error(`[${this.serviceName}] socket is null`)
return
}
console.log(`[${this.serviceName}] SocketTask created, checking methods:`, Object.keys(this.socket))
this._setupListeners()
} catch (err) {
console.error(`[${this.serviceName}] Exception during connect:`, err)
}
}
_setupListeners() {
// 检查 socket 是否有效
if (!this.socket) {
console.error(`[${this.serviceName}] socket is null in _setupListeners`)
return
}
// UniApp SocketTask API: onOpen/onClose/onError/onMessage 是设置回调的方法
// 也可能是事件名是 'open', 'close', 'error', 'message'
const socket = this.socket
const self = this
// 连接打开
if (typeof socket.onOpen === 'function') {
socket.onOpen(function() {
console.log(`[${self.serviceName}] WebSocket connected`)
self.isConnected = true
// 清除重连计时器
if (self.reconnectTimer) {
clearTimeout(self.reconnectTimer)
self.reconnectTimer = null
}
self.reconnectAttempts = 0
self._emit('connect')
})
} else if (typeof socket.onopen === 'function') {
// 标准 WebSocket 风格
socket.onopen(function() {
console.log(`[${self.serviceName}] WebSocket connected`)
self.isConnected = true
// 清除重连计时器
if (self.reconnectTimer) {
clearTimeout(self.reconnectTimer)
self.reconnectTimer = null
}
self.reconnectAttempts = 0
self._emit('connect')
})
} else {
console.warn(`[${self.serviceName}] Unknown socket API, socket keys:`, Object.keys(socket))
}
// 接收消息
if (typeof socket.onMessage === 'function') {
socket.onMessage(function(event) {
const data = JSON.parse(event.data)
self._handleMessage(data)
})
} else if (typeof socket.onmessage === 'function') {
socket.onmessage(function(event) {
const data = JSON.parse(event.data)
self._handleMessage(data)
})
}
// 连接关闭
if (typeof socket.onClose === 'function') {
socket.onClose(function() {
console.log(`[${self.serviceName}] WebSocket closed`)
self._cleanup()
self._emit('disconnect')
self._tryReconnect()
})
} else if (typeof socket.onclose === 'function') {
socket.onclose(function() {
console.log(`[${self.serviceName}] WebSocket closed`)
self._cleanup()
self._emit('disconnect')
self._tryReconnect()
})
}
// 连接错误
if (typeof socket.onError === 'function') {
socket.onError(function(err) {
console.error(`[${self.serviceName}] WebSocket error:`, err)
self._emit('error', err)
})
} else if (typeof socket.onerror === 'function') {
socket.onerror(function(err) {
console.error(`[${self.serviceName}] WebSocket error:`, err)
self._emit('error', err)
})
}
}
_handleMessage(data) {
// 触发通用消息事件
this._emit('message', data)
// 根据消息类型处理
const { type, action } = data
// 1. 鉴权响应(通用)
if (type === 'auth_response') {
if (data.success) {
this.isAuthed = true
this._emit('auth_success', data)
this._startHeartbeat()
} else {
this.isAuthed = false
this._emit('auth_fail', data)
this.close()
}
return
}
// 2. 心跳响应(通用)
if (type === 'pong') {
console.log(`[${this.serviceName}] Heartbeat received`)
return
}
// 3. 错误响应(通用)
if (type === 'error') {
this._emit('error', data)
return
}
// 4. 服务特定消息类型处理
const handler = this.messageHandlers[type] || this.messageHandlers[action]
if (handler) {
handler(data)
}
}
/**
* 发送消息
*/
send(data) {
if (!this.socket || !this.isConnected) {
console.warn(`[${this.serviceName}] Socket not connected`)
return false
}
// UniApp: socket.send({ data, success, fail })
if (typeof this.socket.send === 'function') {
this.socket.send({
data: JSON.stringify(data),
fail: (err) => {
console.error(`[${this.serviceName}] send fail:`, err)
}
})
} else {
console.warn(`[${this.serviceName}] socket.send is not a function`)
}
return true
}
/**
* 发送心跳
*/
_startHeartbeat() {
this._stopHeartbeat()
this.heartbeatTimer = setInterval(() => {
if (this.isConnected) {
this.send({ action: 'ping' })
}
}, this.heartbeatInterval)
}
_stopHeartbeat() {
if (this.heartbeatTimer) {
clearInterval(this.heartbeatTimer)
this.heartbeatTimer = null
}
}
/**
* 重连机制
*/
_tryReconnect() {
if (this.isClosing) {
console.log(`[${this.serviceName}] Closing intentionally, skip reconnect`)
return
}
if (this.reconnectAttempts >= this.maxReconnectAttempts) {
console.warn(`[${this.serviceName}] Max reconnect attempts reached`)
this._emit('error', { code: 'RECONNECT_FAILED', message: '重连次数已达上限' })
return
}
this.reconnectAttempts++
console.log(`[${this.serviceName}] Auto reconnecting... (${this.reconnectAttempts}/${this.maxReconnectAttempts})`)
this.reconnectTimer = setTimeout(() => {
this._doConnect()
}, this.reconnectInterval)
}
_cleanup() {
this._stopHeartbeat()
if (this.reconnectTimer) {
clearTimeout(this.reconnectTimer)
this.reconnectTimer = null
}
this.isConnected = false
this.isAuthed = false
}
/**
* 关闭连接
*/
close() {
this.isClosing = true
this._cleanup()
if (this.socket) {
// UniApp: socket.close() 可能不存在,用 complete 代替
if (typeof this.socket.close === 'function') {
this.socket.close()
} else if (typeof this.socket.complete === 'function') {
this.socket.complete()
}
this.socket = null
}
}
// ===== 事件系统 =====
on(event, handler) {
if (this.eventHandlers[event]) {
this.eventHandlers[event].push(handler)
}
return () => this.off(event, handler) // 返回取消订阅函数
}
off(event, handler) {
if (this.eventHandlers[event]) {
this.eventHandlers[event] = this.eventHandlers[event].filter(h => h !== handler)
}
}
_emit(event, data) {
if (this.eventHandlers[event]) {
this.eventHandlers[event].forEach(handler => handler(data))
}
}
// ===== 订阅特定消息类型 =====
/**
* 注册特定消息类型的处理器
* @param {string} type 消息类型或 action
* @param {function} handler 处理函数
*/
registerHandler(type, handler) {
this.messageHandlers[type] = handler
}
}
/** 导出 */
export default SocketManager

View File

@ -0,0 +1,3 @@
export { default as SocketManager } from './SocketManager'
export { default as AiChatSocket, getAiChatSocket, closeAiChatSocket, isAiChatClosing, resetAiChatClosing } from './AiChatSocket'
export { default as GlobalSocketManager, getGlobalSocket } from './GlobalSocketManager'