topfans/backend/services/aiChatService/main.go
2026-05-28 12:00:19 +08:00

265 lines
7.7 KiB
Go

package main
import (
"context"
"flag"
"fmt"
"os"
"os/signal"
"strconv"
"syscall"
"time"
_ "dubbo.apache.org/dubbo-go/v3/imports"
_ "dubbo.apache.org/dubbo-go/v3/protocol/triple"
"dubbo.apache.org/dubbo-go/v3/protocol"
"dubbo.apache.org/dubbo-go/v3/server"
"github.com/joho/godotenv"
"github.com/redis/go-redis/v9"
"github.com/topfans/backend/pkg/database"
"github.com/topfans/backend/pkg/health"
"github.com/topfans/backend/pkg/logger"
"github.com/topfans/backend/services/aiChatService/model"
"github.com/topfans/backend/services/aiChatService/provider"
"github.com/topfans/backend/services/aiChatService/repository"
"github.com/topfans/backend/services/aiChatService/service"
pbAIChat "github.com/topfans/backend/pkg/proto/ai_chat"
"go.uber.org/zap"
)
// Command-line flags for the service.
//
// Every default below is taken from the environment variable named in the
// getEnv/getEnvInt call, falling back to the literal next to it. These
// initialisers run during package initialisation, i.e. before main starts.
var (
// Port the Dubbo server listens on; main also derives the health-check port
// from it (port + 1000).
port = flag.Int("port", getEnvInt("PORT", 20008), "Dubbo service port")
// Database connection settings, passed to database.Init in main.
dbHost = flag.String("db-host", getEnv("DB_HOST", "localhost"), "Database host")
dbPort = flag.Int("db-port", getEnvInt("DB_PORT", 5432), "Database port")
dbUser = flag.String("db-user", getEnv("DB_USER", "postgres"), "Database user")
dbPassword = flag.String("db-password", getEnv("DB_PASSWORD", ""), "Database password")
dbName = flag.String("db-name", getEnv("DB_NAME", "top-fans"), "Database name")
// Redis connection settings, used to build the Redis client in main.
redisHost = flag.String("redis-host", getEnv("REDIS_HOST", "127.0.0.1"), "Redis host")
redisPort = flag.Int("redis-port", getEnvInt("REDIS_PORT", 6379), "Redis port")
redisPassword = flag.String("redis-password", getEnv("REDIS_PASSWORD", ""), "Redis password")
redisDB = flag.Int("redis-db", getEnvInt("REDIS_DB", 0), "Redis db")
// Chat tuning knobs: both are handed to service.NewChatService, and
// contextTTL is additionally given to the Redis memory repository.
contextTTL = flag.Int("context-ttl", getEnvInt("CONTEXT_TTL", 86400), "Context TTL in seconds")
triggerTurns = flag.Int("trigger-turns", getEnvInt("TRIGGER_TURNS", 5), "Turns to trigger memory extraction")
// healthHandler is the health-check HTTP handler; main creates and starts it,
// and stops it again in its shutdown path.
healthHandler *health.Handler
)
// getEnv returns the value of the environment variable key, or fallback when
// the variable is unset or set to the empty string.
func getEnv(key, fallback string) string {
	value, present := os.LookupEnv(key)
	if !present || value == "" {
		return fallback
	}
	return value
}
// getEnvInt returns the environment variable key parsed as a base-10 integer.
// It returns fallback when the variable is unset, empty, or not a valid
// integer.
func getEnvInt(key string, fallback int) int {
	raw := os.Getenv(key)
	if raw == "" {
		return fallback
	}
	parsed, err := strconv.Atoi(raw)
	if err != nil {
		return fallback
	}
	return parsed
}
// main wires up and runs the AI chat service. In order, it:
//
//  1. loads .env and resolves flags (command line > environment > built-in
//     default),
//  2. initialises the logger, the database and the Redis client,
//  3. starts the health-check HTTP handler on port+1000 and migrates the
//     schema,
//  4. builds the repository, service and provider layers (LLM settings come
//     from the "llm" config category in the database, falling back to
//     environment variables),
//  5. registers the provider on a Dubbo server using the Triple protocol and
//     serves it,
//  6. waits for SIGINT/SIGTERM and releases the health handler and Redis
//     client.
//
// Any startup failure is fatal and terminates the process.
func main() {
	// Best-effort: the error returned by Load (for example when there is no
	// .env file) is intentionally ignored.
	_ = godotenv.Load()

	// The flag defaults were computed during package initialisation, before
	// .env was loaded above. Re-read the environment here so values that only
	// exist in .env still act as defaults. Flags given on the command line
	// are applied by flag.Parse below and keep the highest precedence.
	*port = getEnvInt("PORT", *port)
	*dbHost = getEnv("DB_HOST", *dbHost)
	*dbPort = getEnvInt("DB_PORT", *dbPort)
	*dbUser = getEnv("DB_USER", *dbUser)
	*dbPassword = getEnv("DB_PASSWORD", *dbPassword)
	*dbName = getEnv("DB_NAME", *dbName)
	*redisHost = getEnv("REDIS_HOST", *redisHost)
	*redisPort = getEnvInt("REDIS_PORT", *redisPort)
	*redisPassword = getEnv("REDIS_PASSWORD", *redisPassword)
	*redisDB = getEnvInt("REDIS_DB", *redisDB)
	*contextTTL = getEnvInt("CONTEXT_TTL", *contextTTL)
	*triggerTurns = getEnvInt("TRIGGER_TURNS", *triggerTurns)
	flag.Parse()

	env := os.Getenv("ENV")
	if env == "" {
		env = "development"
	}

	// The logger is not usable yet, so a failure here can only panic.
	if err := logger.Init(logger.Config{
		ServiceName: "ai-chat-service",
		Environment: env,
		LogLevel:    os.Getenv("LOG_LEVEL"),
	}); err != nil {
		panic(fmt.Sprintf("Failed to initialize logger: %v", err))
	}
	defer logger.Sync()

	logger.Logger.Info("Starting AI Chat Service...")

	// Initialise the database.
	dbConfig := database.Config{
		Host:     *dbHost,
		Port:     *dbPort,
		User:     *dbUser,
		Password: *dbPassword,
		DBName:   *dbName,
		SSLMode:  "disable",
		TimeZone: "Asia/Shanghai",
	}
	if err := database.Init(dbConfig); err != nil {
		logger.Logger.Fatal(fmt.Sprintf("Failed to initialize database: %v", err))
	}
	logger.Logger.Info("Database initialized successfully")

	// Initialise Redis.
	redisClient := redis.NewClient(&redis.Options{
		Addr:     fmt.Sprintf("%s:%d", *redisHost, *redisPort),
		Password: *redisPassword,
		DB:       *redisDB,
	})

	// Verify the Redis connection with a bounded ping. The context is
	// released before the result is inspected so it is cancelled on both the
	// success and the failure path.
	pingCtx, cancelPing := context.WithTimeout(context.Background(), 5*time.Second)
	pingErr := redisClient.Ping(pingCtx).Err()
	cancelPing()
	if pingErr != nil {
		logger.Logger.Fatal(fmt.Sprintf("Failed to connect to Redis: %v", pingErr))
	}
	logger.Logger.Info("Redis connected successfully")

	// Start the health-check HTTP server on a port derived from the Dubbo port.
	healthPort := *port + 1000
	healthHandler = health.NewHandler("ai-chat-service", healthPort)
	healthHandler.Start()

	// Auto-migrate the database tables.
	if err := autoMigrate(); err != nil {
		logger.Logger.Fatal(fmt.Sprintf("Failed to migrate database: %v", err))
	}

	// Build the repository layer.
	personaRepo := repository.NewPostgreSQLPersonaRepository(database.GetDB())
	configRepo := repository.NewPostgreSQLConfigRepository(database.GetDB())
	shortTermMemoryRepo := repository.NewRedisMemoryRepository(redisClient, *contextTTL)
	longTermMemoryRepo := repository.NewPostgreSQLMemoryRepository(database.GetDB())
	logger.Logger.Info("Repository layer initialized")

	// Load the LLM configuration from the database. A failure is not fatal:
	// the environment-based values below are used instead.
	llmConfigs, err := configRepo.GetByCategory(context.Background(), "llm")
	if err != nil {
		logger.Logger.Warn("Failed to load LLM configs from database, using env defaults", zap.Error(err))
	}

	// llmSetting prefers a non-empty database value for key and otherwise
	// returns envDefault.
	llmSetting := func(key, envDefault string) string {
		if val, ok := llmConfigs[key]; ok && val != "" {
			return val
		}
		return envDefault
	}

	// MiniMax settings.
	miniMaxAPIKey := llmSetting("minimax.api_key", getEnv("MINIMAX_API_KEY", ""))
	miniMaxAPIURL := llmSetting("minimax.api_url", getEnv("MINIMAX_API_URL", "https://api.minimaxi.com/v1"))
	miniMaxModel := llmSetting("minimax.model", getEnv("MINIMAX_MODEL", "M2-her"))

	// Qwen settings.
	qwenAPIKey := llmSetting("qwen.api_key", getEnv("QWEN_API_KEY", ""))
	qwenAPIURL := llmSetting("qwen.api_url", getEnv("QWEN_API_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"))
	qwenModel := llmSetting("qwen.model", getEnv("QWEN_MODEL", "qwen-plus"))

	// API keys are never logged; only whether each one is present.
	logger.Logger.Info("LLM configs loaded",
		zap.String("minimax_url", miniMaxAPIURL),
		zap.String("minimax_model", miniMaxModel),
		zap.Bool("has_minimax_key", miniMaxAPIKey != ""),
		zap.String("qwen_url", qwenAPIURL),
		zap.String("qwen_model", qwenModel),
		zap.Bool("has_qwen_key", qwenAPIKey != ""),
	)

	// Build the service layer.
	llmService := service.NewLLMService(
		miniMaxAPIURL,
		miniMaxAPIKey,
		miniMaxModel,
		qwenAPIURL,
		qwenAPIKey,
		qwenModel,
	)
	personaService := service.NewPersonaService(personaRepo)
	memoryService := service.NewMemoryService(shortTermMemoryRepo, longTermMemoryRepo)
	auditService := service.NewAuditService()
	chatService := service.NewChatService(
		llmService,
		personaService,
		memoryService,
		auditService,
		*contextTTL,
		*triggerTurns,
	)
	logger.Logger.Info("Service layer initialized")

	// Build the provider layer.
	aiChatProvider := provider.NewAIChatProvider(
		chatService,
		personaService,
		memoryService,
		auditService,
	)
	logger.Logger.Info("Provider layer initialized")

	// Create the Dubbo server (Triple protocol on the configured port).
	srv, err := server.NewServer(
		server.WithServerProtocol(
			protocol.WithPort(*port),
			protocol.WithTriple(),
		),
	)
	if err != nil {
		logger.Logger.Fatal(fmt.Sprintf("Failed to create Dubbo server: %v", err))
	}

	// Register the AI Chat Service handler.
	if err := pbAIChat.RegisterAIChatServiceHandler(srv, aiChatProvider); err != nil {
		logger.Logger.Fatal(fmt.Sprintf("Failed to register AI Chat Service: %v", err))
	}

	// Start serving.
	// NOTE(review): dubbo-go v3's Server.Serve is believed to block for the
	// lifetime of the process. If so, everything below this call only runs
	// when Serve returns, and shutdown is effectively handled by dubbo-go
	// itself. Confirm against the dubbo-go version in go.mod before relying
	// on the signal handling and cleanup that follow.
	if err := srv.Serve(); err != nil {
		logger.Logger.Fatal(fmt.Sprintf("Failed to start AI Chat Service: %v", err))
	}
	logger.Logger.Info(fmt.Sprintf("AI Chat Service started successfully on port %d", *port))

	// Wait for a termination signal.
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	<-quit

	logger.Logger.Info("Shutting down AI Chat Service...")
	if healthHandler != nil {
		healthHandler.Stop()
	}
	if err := redisClient.Close(); err != nil {
		logger.Logger.Warn("Failed to close Redis client", zap.Error(err))
	}
}
// autoMigrate calls AutoMigrate on the shared database handle for each model
// owned by this service (Persona, UserMemory, Config).
//
// It returns an error only when the database handle has not been initialised.
// A failure to migrate an individual model is deliberately non-fatal: it is
// logged as a warning and the remaining models are still processed. The final
// log line reports success only when every model migrated without error.
func autoMigrate() error {
	db := database.GetDB()
	if db == nil {
		return fmt.Errorf("database is not initialized")
	}

	models := []interface{}{
		&model.Persona{},
		&model.UserMemory{},
		&model.Config{},
	}

	failed := 0
	for _, m := range models {
		if err := db.AutoMigrate(m); err != nil {
			// Best-effort: for example, dropping a constraint can fail when
			// the table already has it. Record the problem and carry on.
			failed++
			logger.Logger.Warn(fmt.Sprintf("Migration warning for %T: %v", m, err))
		}
	}

	if failed > 0 {
		// Do not claim success when at least one model could not be migrated.
		logger.Logger.Warn("Database migration completed with warnings",
			zap.Int("failed_models", failed),
			zap.Int("total_models", len(models)),
		)
		return nil
	}

	logger.Logger.Info("Database migration completed successfully")
	return nil
}