312 lines
7.0 KiB
Go
312 lines
7.0 KiB
Go
package service
|
|
|
|
import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/topfans/backend/services/aiChatService/model"
)
|
|
|
|
// StreamReader yields a chat reply incrementally.
//
// Next returns the next piece of content; done reports that the stream has
// finished (either normally or because err is non-nil). Close releases the
// underlying connection.
type StreamReader interface {
	Next() (content string, done bool, err error)
	Close() error
}
|
|
|
|
// SensitiveContentError reports that an upstream provider's content-safety
// filter rejected the request. Callers can detect it with errors.As.
type SensitiveContentError struct {
	// Message is the text returned by Error.
	Message string
}

// Compile-time check that *SensitiveContentError implements error.
var _ error = (*SensitiveContentError)(nil)

// Error implements the error interface by returning the stored message.
func (e *SensitiveContentError) Error() string {
	msg := e.Message
	return msg
}
|
|
|
|
// LLMService 大模型服务
|
|
type LLMService struct {
|
|
miniMaxClient *http.Client
|
|
qwenClient *http.Client
|
|
|
|
miniMaxURL string
|
|
miniMaxKey string
|
|
miniMaxModel string
|
|
|
|
qwenURL string
|
|
qwenKey string
|
|
qwenModel string
|
|
|
|
failCount int
|
|
fallbackCount int
|
|
useBackup bool
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// MiniMaxMessage is one chat message in a completion request: the speaker
// role and the message text.
type MiniMaxMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// MiniMaxRequest is the JSON body POSTed to a "/chat/completions" endpoint.
// It is used for both the MiniMax and the Qwen calls in this file.
type MiniMaxRequest struct {
	Model    string           `json:"model"`    // model name to run
	Messages []MiniMaxMessage `json:"messages"` // conversation so far
	Stream   bool             `json:"stream"`   // request a streamed reply
}
|
|
|
|
// MiniMaxResponse is a minimal view of a streamed completion chunk that
// decodes only choices[].delta.content. It is not referenced in this file;
// MiniMaxStreamReader decodes MiniMaxSSERecord instead.
type MiniMaxResponse struct {
	Choices []struct {
		// Delta carries the incremental piece of the reply.
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
}
|
|
|
|
// MiniMaxSSERecord is the JSON payload of one "data:" line in the streamed
// response. MiniMaxStreamReader reads the first choice's delta content.
type MiniMaxSSERecord struct {
	ID     string `json:"id"`
	Object string `json:"object"`
	Model  string `json:"model"`

	Choices []struct {
		Index int `json:"index"`
		// Delta carries the incremental piece of the reply.
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
}
|
|
|
|
// NewLLMService 创建大模型服务
|
|
func NewLLMService(miniMaxURL, miniMaxKey, miniMaxModel, qwenURL, qwenKey, qwenModel string) *LLMService {
|
|
return &LLMService{
|
|
miniMaxClient: &http.Client{Timeout: 60 * time.Second},
|
|
qwenClient: &http.Client{Timeout: 60 * time.Second},
|
|
miniMaxURL: miniMaxURL,
|
|
miniMaxKey: miniMaxKey,
|
|
miniMaxModel: miniMaxModel,
|
|
qwenURL: qwenURL,
|
|
qwenKey: qwenKey,
|
|
qwenModel: qwenModel,
|
|
}
|
|
}
|
|
|
|
// StreamChat starts a streaming chat completion for messages.
//
// The Qwen fallback is currently disabled, so this always calls MiniMax.
// On failure the returned StreamReader is a literal nil rather than an
// interface wrapping a nil pointer, so callers can safely test reader == nil.
// A *SensitiveContentError means MiniMax's safety filter rejected the content.
func (s *LLMService) StreamChat(ctx context.Context, messages []model.Message) (StreamReader, error) {
	// Qwen fallback is disabled.
	reader, err := s.streamChatMiniMax(ctx, messages)
	if err != nil {
		// Returning `reader` here would box a nil *MiniMaxStreamReader into a
		// non-nil StreamReader interface value.
		return nil, err
	}
	return reader, nil
}
|
|
|
|
// StreamChatWithBackup is meant to stream a chat completion with fallback to
// the backup model. That fallback is currently disabled, so despite the name
// this behaves exactly like StreamChat and always calls MiniMax.
//
// On failure the returned StreamReader is a literal nil rather than an
// interface wrapping a nil pointer.
func (s *LLMService) StreamChatWithBackup(ctx context.Context, messages []model.Message) (StreamReader, error) {
	// Qwen fallback is disabled; go straight to MiniMax.
	reader, err := s.streamChatMiniMax(ctx, messages)
	if err != nil {
		// Avoid the nil-pointer-in-interface trap.
		return nil, err
	}
	return reader, nil
}
|
|
|
|
// streamChatMiniMax MiniMax 流式聊天
|
|
func (s *LLMService) streamChatMiniMax(ctx context.Context, messages []model.Message) (*MiniMaxStreamReader, error) {
|
|
reqMessages := make([]MiniMaxMessage, len(messages))
|
|
for i, m := range messages {
|
|
reqMessages[i] = MiniMaxMessage{
|
|
Role: m.Role,
|
|
Content: m.Content,
|
|
}
|
|
}
|
|
|
|
reqBody := MiniMaxRequest{
|
|
Model: s.miniMaxModel,
|
|
Messages: reqMessages,
|
|
Stream: true,
|
|
}
|
|
|
|
jsonData, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", s.miniMaxURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+s.miniMaxKey)
|
|
|
|
resp, err := s.miniMaxClient.Do(req)
|
|
if err != nil {
|
|
s.incrementFailCount()
|
|
return nil, fmt.Errorf("failed to call MiniMax: %w", err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
resp.Body.Close()
|
|
bodyStr := string(body)
|
|
|
|
// 检查是否是敏感内容错误 (1027)
|
|
if strings.Contains(bodyStr, "1027") || strings.Contains(bodyStr, "new_sensitive") {
|
|
return nil, &SensitiveContentError{Message: "content blocked by MiniMax safety filter"}
|
|
}
|
|
|
|
s.incrementFailCount()
|
|
// 如果是 401/403 等错误,切换到备用模型
|
|
if resp.StatusCode == 401 || resp.StatusCode == 403 {
|
|
s.SwitchToBackup()
|
|
}
|
|
return nil, fmt.Errorf("MiniMax API error: %d, body: %s", resp.StatusCode, bodyStr)
|
|
}
|
|
|
|
s.resetFailCount()
|
|
return &MiniMaxStreamReader{
|
|
reader: resp.Body,
|
|
decoder: NewSSEDecoder(resp.Body),
|
|
}, nil
|
|
}
|
|
|
|
// streamChatQwen 通义流式聊天
|
|
func (s *LLMService) streamChatQwen(ctx context.Context, messages []model.Message) (*MiniMaxStreamReader, error) {
|
|
reqMessages := make([]MiniMaxMessage, len(messages))
|
|
for i, m := range messages {
|
|
reqMessages[i] = MiniMaxMessage{
|
|
Role: m.Role,
|
|
Content: m.Content,
|
|
}
|
|
}
|
|
|
|
reqBody := MiniMaxRequest{
|
|
Model: s.qwenModel,
|
|
Messages: reqMessages,
|
|
Stream: true,
|
|
}
|
|
|
|
jsonData, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", s.qwenURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+s.qwenKey)
|
|
|
|
resp, err := s.qwenClient.Do(req)
|
|
if err != nil {
|
|
s.incrementFailCount()
|
|
return nil, fmt.Errorf("failed to call Qwen: %w", err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
s.incrementFailCount()
|
|
body, _ := io.ReadAll(resp.Body)
|
|
resp.Body.Close()
|
|
return nil, fmt.Errorf("Qwen API error: %d, body: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
s.resetFailCount()
|
|
return &MiniMaxStreamReader{
|
|
reader: resp.Body,
|
|
decoder: NewSSEDecoder(resp.Body),
|
|
}, nil
|
|
}
|
|
|
|
func (s *LLMService) incrementFailCount() {
|
|
s.mu.Lock()
|
|
s.failCount++
|
|
if s.failCount >= 3 {
|
|
s.useBackup = true
|
|
}
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (s *LLMService) resetFailCount() {
|
|
s.mu.Lock()
|
|
s.failCount = 0
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (s *LLMService) SwitchToBackup() {
|
|
s.mu.Lock()
|
|
s.useBackup = true
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
// MiniMaxStreamReader MiniMax 流式读取器
|
|
type MiniMaxStreamReader struct {
|
|
reader io.ReadCloser
|
|
decoder *SSEDecoder
|
|
buffer string
|
|
}
|
|
|
|
// Next returns the next non-empty content delta from the stream.
//
// It reads lines from the decoder and only looks at SSE "data:" lines:
//   - "[DONE]" ends the stream and returns ("", true, nil).
//   - A JSON record whose first choice has non-empty delta content returns
//     (content, false, nil).
//   - Anything else is skipped: comments, other fields, empty deltas, and
//     payloads that are not valid JSON.
//
// If the decoder returns an error, including io.EOF when the body ends
// without "[DONE]", the body is closed and ("", true, err) is returned.
func (r *MiniMaxStreamReader) Next() (content string, done bool, err error) {
	for {
		line, readErr := r.decoder.Next()
		if readErr != nil {
			// Best-effort close; the read error is the more useful one to report.
			_ = r.Close()
			return "", true, readErr
		}

		// Per the SSE format, the space after "data:" is optional, so accept
		// both "data: {...}" and "data:{...}".
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")
		if data == "[DONE]" {
			return "", true, nil
		}

		var record MiniMaxSSERecord
		if json.Unmarshal([]byte(data), &record) != nil {
			// Skip malformed or unexpected records rather than abort the stream.
			continue
		}
		if len(record.Choices) > 0 && record.Choices[0].Delta.Content != "" {
			return record.Choices[0].Delta.Content, false, nil
		}
	}
}
|
|
|
|
func (r *MiniMaxStreamReader) Close() error {
|
|
return r.reader.Close()
|
|
}
|
|
|
|
// SSEDecoder SSE 解码器
|
|
type SSEDecoder struct {
|
|
reader io.Reader
|
|
}
|
|
|
|
func NewSSEDecoder(reader io.Reader) *SSEDecoder {
|
|
return &SSEDecoder{reader: reader}
|
|
}
|
|
|
|
func (d *SSEDecoder) Next() (string, error) {
|
|
var line []byte
|
|
buf := make([]byte, 1)
|
|
|
|
for {
|
|
n, err := d.reader.Read(buf)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if n == 0 {
|
|
continue
|
|
}
|
|
|
|
if buf[0] == '\n' {
|
|
break
|
|
}
|
|
line = append(line, buf[0])
|
|
}
|
|
|
|
return strings.TrimRight(string(line), "\r"), nil
|
|
} |