265 lines
7.7 KiB
Go
265 lines
7.7 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"flag"
|
|
"fmt"
|
|
"os"
|
|
"os/signal"
|
|
"strconv"
|
|
"syscall"
|
|
"time"
|
|
|
|
_ "dubbo.apache.org/dubbo-go/v3/imports"
|
|
_ "dubbo.apache.org/dubbo-go/v3/protocol/triple"
|
|
"dubbo.apache.org/dubbo-go/v3/protocol"
|
|
"dubbo.apache.org/dubbo-go/v3/server"
|
|
"github.com/joho/godotenv"
|
|
"github.com/redis/go-redis/v9"
|
|
|
|
"github.com/topfans/backend/pkg/database"
|
|
"github.com/topfans/backend/pkg/health"
|
|
"github.com/topfans/backend/pkg/logger"
|
|
"github.com/topfans/backend/services/aiChatService/model"
|
|
"github.com/topfans/backend/services/aiChatService/provider"
|
|
"github.com/topfans/backend/services/aiChatService/repository"
|
|
"github.com/topfans/backend/services/aiChatService/service"
|
|
pbAIChat "github.com/topfans/backend/pkg/proto/ai_chat"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
var (
|
|
port = flag.Int("port", getEnvInt("PORT", 20008), "Dubbo service port")
|
|
dbHost = flag.String("db-host", getEnv("DB_HOST", "localhost"), "Database host")
|
|
dbPort = flag.Int("db-port", getEnvInt("DB_PORT", 5432), "Database port")
|
|
dbUser = flag.String("db-user", getEnv("DB_USER", "postgres"), "Database user")
|
|
dbPassword = flag.String("db-password", getEnv("DB_PASSWORD", ""), "Database password")
|
|
dbName = flag.String("db-name", getEnv("DB_NAME", "top-fans"), "Database name")
|
|
redisHost = flag.String("redis-host", getEnv("REDIS_HOST", "127.0.0.1"), "Redis host")
|
|
redisPort = flag.Int("redis-port", getEnvInt("REDIS_PORT", 6379), "Redis port")
|
|
redisPassword = flag.String("redis-password", getEnv("REDIS_PASSWORD", ""), "Redis password")
|
|
redisDB = flag.Int("redis-db", getEnvInt("REDIS_DB", 0), "Redis db")
|
|
contextTTL = flag.Int("context-ttl", getEnvInt("CONTEXT_TTL", 86400), "Context TTL in seconds")
|
|
triggerTurns = flag.Int("trigger-turns", getEnvInt("TRIGGER_TURNS", 5), "Turns to trigger memory extraction")
|
|
healthHandler *health.Handler
|
|
)
|
|
|
|
// getEnv reports the value of the environment variable named key. A variable
// that is unset and one that is set to the empty string are treated the same
// way: both yield fallback.
func getEnv(key, fallback string) string {
	value, present := os.LookupEnv(key)
	if !present || value == "" {
		return fallback
	}
	return value
}
|
|
|
|
// getEnvInt returns the environment variable named key parsed with
// strconv.Atoi. It returns fallback when the variable is unset or empty, and
// also when the value does not parse.
//
// A malformed value is reported on stderr rather than silently ignored. This
// function supplies the package-level flag defaults, so it runs before main
// initializes the structured logger. Without the warning, a typo such as
// PORT=2000x would go unnoticed.
func getEnvInt(key string, fallback int) int {
	raw := os.Getenv(key)
	if raw == "" {
		return fallback
	}
	n, err := strconv.Atoi(raw)
	if err != nil {
		fmt.Fprintf(os.Stderr, "warning: ignoring invalid integer %s=%q, using default %d: %v\n", key, raw, fallback, err)
		return fallback
	}
	return n
}
|
|
|
|
func main() {
|
|
godotenv.Load()
|
|
flag.Parse()
|
|
|
|
env := os.Getenv("ENV")
|
|
if env == "" {
|
|
env = "development"
|
|
}
|
|
|
|
if err := logger.Init(logger.Config{
|
|
ServiceName: "ai-chat-service",
|
|
Environment: env,
|
|
LogLevel: os.Getenv("LOG_LEVEL"),
|
|
}); err != nil {
|
|
panic(fmt.Sprintf("Failed to initialize logger: %v", err))
|
|
}
|
|
defer logger.Sync()
|
|
|
|
logger.Logger.Info("Starting AI Chat Service...")
|
|
|
|
// 初始化数据库
|
|
dbConfig := database.Config{
|
|
Host: *dbHost,
|
|
Port: *dbPort,
|
|
User: *dbUser,
|
|
Password: *dbPassword,
|
|
DBName: *dbName,
|
|
SSLMode: "disable",
|
|
TimeZone: "Asia/Shanghai",
|
|
}
|
|
|
|
if err := database.Init(dbConfig); err != nil {
|
|
logger.Logger.Fatal(fmt.Sprintf("Failed to initialize database: %v", err))
|
|
}
|
|
logger.Logger.Info("Database initialized successfully")
|
|
|
|
// 初始化 Redis
|
|
redisClient := redis.NewClient(&redis.Options{
|
|
Addr: fmt.Sprintf("%s:%d", *redisHost, *redisPort),
|
|
Password: *redisPassword,
|
|
DB: *redisDB,
|
|
})
|
|
|
|
// 测试 Redis 连接
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
if err := redisClient.Ping(ctx).Err(); err != nil {
|
|
logger.Logger.Fatal(fmt.Sprintf("Failed to connect to Redis: %v", err))
|
|
}
|
|
cancel()
|
|
logger.Logger.Info("Redis connected successfully")
|
|
|
|
// 启动健康检查 HTTP 服务器
|
|
healthPort := *port + 1000
|
|
healthHandler = health.NewHandler("ai-chat-service", healthPort)
|
|
healthHandler.Start()
|
|
|
|
// 自动迁移数据库表
|
|
if err := autoMigrate(); err != nil {
|
|
logger.Logger.Fatal(fmt.Sprintf("Failed to migrate database: %v", err))
|
|
}
|
|
|
|
// 创建 Repository 层实例
|
|
personaRepo := repository.NewPostgreSQLPersonaRepository(database.GetDB())
|
|
configRepo := repository.NewPostgreSQLConfigRepository(database.GetDB())
|
|
shortTermMemoryRepo := repository.NewRedisMemoryRepository(redisClient, *contextTTL)
|
|
longTermMemoryRepo := repository.NewPostgreSQLMemoryRepository(database.GetDB())
|
|
logger.Logger.Info("Repository layer initialized")
|
|
|
|
// 从数据库加载 LLM 配置
|
|
loadCtx := context.Background()
|
|
llmConfigs, err := configRepo.GetByCategory(loadCtx, "llm")
|
|
if err != nil {
|
|
logger.Logger.Warn("Failed to load LLM configs from database, using env defaults", zap.Error(err))
|
|
}
|
|
|
|
// 获取 MiniMax 配置
|
|
miniMaxAPIKey := getEnv("MINIMAX_API_KEY", "")
|
|
miniMaxAPIURL := getEnv("MINIMAX_API_URL", "https://api.minimaxi.com/v1")
|
|
miniMaxModel := getEnv("MINIMAX_MODEL", "M2-her")
|
|
if val, ok := llmConfigs["minimax.api_key"]; ok && val != "" {
|
|
miniMaxAPIKey = val
|
|
}
|
|
if val, ok := llmConfigs["minimax.api_url"]; ok && val != "" {
|
|
miniMaxAPIURL = val
|
|
}
|
|
if val, ok := llmConfigs["minimax.model"]; ok && val != "" {
|
|
miniMaxModel = val
|
|
}
|
|
|
|
// 获取 Qwen 配置
|
|
qwenAPIKey := getEnv("QWEN_API_KEY", "")
|
|
qwenAPIURL := getEnv("QWEN_API_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
|
|
qwenModel := getEnv("QWEN_MODEL", "qwen-plus")
|
|
if val, ok := llmConfigs["qwen.api_key"]; ok && val != "" {
|
|
qwenAPIKey = val
|
|
}
|
|
if val, ok := llmConfigs["qwen.api_url"]; ok && val != "" {
|
|
qwenAPIURL = val
|
|
}
|
|
if val, ok := llmConfigs["qwen.model"]; ok && val != "" {
|
|
qwenModel = val
|
|
}
|
|
|
|
logger.Logger.Info("LLM configs loaded",
|
|
zap.String("minimax_url", miniMaxAPIURL),
|
|
zap.String("minimax_model", miniMaxModel),
|
|
zap.Bool("has_minimax_key", miniMaxAPIKey != ""),
|
|
zap.String("qwen_url", qwenAPIURL),
|
|
zap.String("qwen_model", qwenModel),
|
|
zap.Bool("has_qwen_key", qwenAPIKey != ""),
|
|
)
|
|
|
|
// 创建 Service 层实例
|
|
llmService := service.NewLLMService(
|
|
miniMaxAPIURL,
|
|
miniMaxAPIKey,
|
|
miniMaxModel,
|
|
qwenAPIURL,
|
|
qwenAPIKey,
|
|
qwenModel,
|
|
)
|
|
personaService := service.NewPersonaService(personaRepo)
|
|
memoryService := service.NewMemoryService(shortTermMemoryRepo, longTermMemoryRepo)
|
|
auditService := service.NewAuditService()
|
|
chatService := service.NewChatService(
|
|
llmService,
|
|
personaService,
|
|
memoryService,
|
|
auditService,
|
|
*contextTTL,
|
|
*triggerTurns,
|
|
)
|
|
logger.Logger.Info("Service layer initialized")
|
|
|
|
// 创建 Provider 层实例
|
|
aiChatProvider := provider.NewAIChatProvider(
|
|
chatService,
|
|
personaService,
|
|
memoryService,
|
|
auditService,
|
|
)
|
|
logger.Logger.Info("Provider layer initialized")
|
|
|
|
// 创建 Dubbo 服务器
|
|
srv, err := server.NewServer(
|
|
server.WithServerProtocol(
|
|
protocol.WithPort(*port),
|
|
protocol.WithTriple(),
|
|
),
|
|
)
|
|
if err != nil {
|
|
logger.Logger.Fatal(fmt.Sprintf("Failed to create Dubbo server: %v", err))
|
|
}
|
|
|
|
// 注册 AI Chat Service
|
|
if err := pbAIChat.RegisterAIChatServiceHandler(srv, aiChatProvider); err != nil {
|
|
logger.Logger.Fatal(fmt.Sprintf("Failed to register AI Chat Service: %v", err))
|
|
}
|
|
|
|
// 启动服务
|
|
if err := srv.Serve(); err != nil {
|
|
logger.Logger.Fatal(fmt.Sprintf("Failed to start AI Chat Service: %v", err))
|
|
}
|
|
|
|
logger.Logger.Info(fmt.Sprintf("AI Chat Service started successfully on port %d", *port))
|
|
|
|
// 等待退出信号
|
|
quit := make(chan os.Signal, 1)
|
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
|
<-quit
|
|
|
|
logger.Logger.Info("Shutting down AI Chat Service...")
|
|
|
|
if healthHandler != nil {
|
|
healthHandler.Stop()
|
|
}
|
|
|
|
redisClient.Close()
|
|
}
|
|
|
|
// autoMigrate 自动迁移数据库表
|
|
func autoMigrate() error {
|
|
db := database.GetDB()
|
|
if db == nil {
|
|
return fmt.Errorf("database is not initialized")
|
|
}
|
|
|
|
tables := []interface{}{
|
|
&model.Persona{},
|
|
&model.UserMemory{},
|
|
&model.Config{},
|
|
}
|
|
|
|
for _, table := range tables {
|
|
if err := db.AutoMigrate(table); err != nil {
|
|
// 如果约束删除失败(表已存在该约束),忽略并继续
|
|
logger.Logger.Warn(fmt.Sprintf("Migration warning for %T: %v", table, err))
|
|
}
|
|
}
|
|
|
|
logger.Logger.Info("Database migration completed successfully")
|
|
return nil
|
|
} |