topfans/backend/services/aiChatService/service/prompt_builder.go
2026-05-28 12:00:19 +08:00

180 lines
4.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"github.com/topfans/backend/services/aiChatService/model"
)
// Token budget limits, expressed in estimated tokens (see Tokenizer).
const (
	// Per-part ceilings.
	MaxSystemTokens        = 4000  // system prompt
	MaxMemoryTokens        = 2000  // recalled memory
	MaxHistoryTokens       = 24000 // conversation history
	MaxSingleMessageTokens = 4000  // any single message

	MaxTotalTokens = 32000 // overall ceiling
	ReservedTokens = 2000  // headroom held back from the history budget

	// MinHistoryMessages is the minimum number of message PAIRS that history
	// trimming keeps regardless of budget (despite the name, it is compared
	// against a pair count, not a message count).
	MinHistoryMessages = 4
)
// Tokenizer estimates token counts with a cheap character-class heuristic
// rather than a real model tokenizer. It has no state, so the zero value is
// ready to use.
type Tokenizer struct{}

// EstimateTokens returns a heuristic token count for text. Every ASCII rune
// (letters, digits, punctuation, whitespace, control characters) costs 1.
// Every other rune costs 2; this includes CJK ideographs and the U+FFFD
// replacement rune that range-over-string yields for invalid UTF-8 bytes.
func (t *Tokenizer) EstimateTokens(text string) int {
	total := 0
	for _, ch := range text {
		if ch >= 0x80 {
			total += 2
			continue
		}
		total++
	}
	return total
}
// EstimateMessagesTokens returns the estimated token total for messages.
// Each message contributes the estimate of its Role, plus the estimate of its
// Content, plus a flat per-message overhead of 10 tokens.
func (t *Tokenizer) EstimateMessagesTokens(messages []model.Message) int {
	const perMessageOverhead = 10
	sum := 0
	for i := range messages {
		sum += perMessageOverhead
		sum += t.EstimateTokens(messages[i].Role)
		sum += t.EstimateTokens(messages[i].Content)
	}
	return sum
}
// BuildPrompt assembles the message list sent to the model. It returns that
// list together with its estimated token count.
//
// The returned messages are, in order:
//  1. a "system" message holding systemPrompt;
//  2. if userCoreInfo is non-empty, a second "system" message holding the
//     "# 用户核心记忆\n" ("user core memory") header followed by userCoreInfo;
//  3. history, trimmed to the history budget;
//  4. a "user" message holding userInput.
//
// The history budget is computed as follows:
//   - Start from MaxHistoryTokens.
//   - Subtract the estimate for userCoreInfo.
//   - Subtract ReservedTokens, enlarged by however much the system prompt
//     exceeds MaxSystemTokens.
//   - Never go below 500.
//
// NOTE(review): userInput's own tokens and the memory header are not
// subtracted from that budget. A very long input can therefore push the
// total past MaxTotalTokens — confirm whether ReservedTokens is meant to
// cover it.
func BuildPrompt(
	systemPrompt string,
	userCoreInfo string,
	history []model.Message,
	userInput string,
	tokenizer *Tokenizer,
) ([]model.Message, int) {
	const minHistoryBudget = 500

	// Whatever the system prompt uses beyond its own allowance is taken out of
	// the history budget.
	headroom := ReservedTokens
	if overflow := tokenizer.EstimateTokens(systemPrompt) - MaxSystemTokens; overflow > 0 {
		headroom += overflow
	}

	budget := MaxHistoryTokens - tokenizer.EstimateTokens(userCoreInfo) - headroom
	if budget < minHistoryBudget {
		budget = minHistoryBudget
	}

	kept := trimHistoryToTokenLimit(history, budget, tokenizer)

	out := []model.Message{{Role: "system", Content: systemPrompt}}
	if len(userCoreInfo) > 0 {
		out = append(out, model.Message{Role: "system", Content: "# 用户核心记忆\n" + userCoreInfo})
	}
	out = append(out, kept...)
	out = append(out, model.Message{Role: "user", Content: userInput})

	return out, tokenizer.EstimateMessagesTokens(out)
}
// trimHistoryToTokenLimit returns the most recent part of history that fits
// within maxTokens.
//
// If history is empty or already fits, history itself is returned.
// Otherwise history is walked from the newest message backwards in steps of
// two. Each message is therefore kept or dropped together with the one
// before it. NOTE(review): this presumably assumes history alternates
// user/assistant messages — confirm with callers.
//
// The newest MinHistoryMessages pairs are always kept, even if they exceed
// maxTokens. Older pairs are added only while the running total stays within
// maxTokens. The result is always a contiguous suffix of history, returned
// as a freshly allocated slice.
//
// Pair cost is measured with Tokenizer.EstimateMessagesTokens (role +
// content + per-message overhead). This is the same measure used by the
// initial fits-check.
func trimHistoryToTokenLimit(history []model.Message, maxTokens int, tokenizer *Tokenizer) []model.Message {
	if len(history) == 0 || tokenizer.EstimateMessagesTokens(history) <= maxTokens {
		return history
	}

	// history[start:] is the suffix kept so far. Tracking an index instead of
	// prepending to a result slice keeps the walk linear.
	start := len(history)
	usedTokens := 0
	for end := len(history); end > 0; end -= 2 {
		lo := end - 2
		if lo < 0 {
			lo = 0 // odd-length history: the oldest message is a lone "pair"
		}
		pairTokens := tokenizer.EstimateMessagesTokens(history[lo:end])
		keptPairs := (len(history) - start) / 2
		if keptPairs >= MinHistoryMessages && usedTokens+pairTokens > maxTokens {
			break
		}
		start = lo
		usedTokens += pairTokens
	}

	kept := make([]model.Message, len(history)-start)
	copy(kept, history[start:])
	return kept
}
// truncateMessageIfNeeded shortens content so that its token estimate stays
// within maxTokens.
//
// If content already fits, it is returned unchanged. Otherwise the function
// returns the longest rune prefix that fits, followed by the "...(已截断)"
// ("truncated") marker. The marker's own cost is reserved out of maxTokens,
// so with the current per-rune estimate the result stays within maxTokens
// (for maxTokens >= 0). When maxTokens is too small to hold even the marker,
// the marker is dropped. In that case only the longest fitting prefix
// (possibly "") is returned.
func truncateMessageIfNeeded(content string, maxTokens int, tokenizer *Tokenizer) string {
	if tokenizer.EstimateTokens(content) <= maxTokens {
		return content
	}

	const marker = "...(已截断)"
	suffix := marker
	budget := maxTokens - tokenizer.EstimateTokens(marker)
	if budget < 0 {
		// No room for the marker; spend the whole allowance on content.
		suffix = ""
		budget = maxTokens
	}

	// Binary search for the largest n such that runes[:n] fits in budget.
	// This relies only on the estimate growing monotonically with prefix
	// length.
	runes := []rune(content)
	lo, hi := 0, len(runes)
	for lo < hi {
		mid := (lo + hi + 1) / 2
		if tokenizer.EstimateTokens(string(runes[:mid])) <= budget {
			lo = mid
		} else {
			hi = mid - 1
		}
	}
	return string(runes[:lo]) + suffix
}
// EstimateTurnCount counts conversation turns in messages. A turn is a
// "user" message immediately followed by an "assistant" message.
//
// Every adjacent pair is examined, not just pairs at even offsets. Leading
// "system" messages (as produced by BuildPrompt) or an unpaired message
// earlier in the list therefore do not hide later turns. A message cannot be
// both "user" and "assistant", so no message is counted in two turns.
func EstimateTurnCount(messages []model.Message) int {
	turns := 0
	for i := 0; i+1 < len(messages); i++ {
		if messages[i].Role == "user" && messages[i+1].Role == "assistant" {
			turns++
		}
	}
	return turns
}
// IsNoNeedLLMCall reports whether input carries no real content and can
// therefore be handled without calling the model. It returns true when
// every rune is one of:
//   - an ASCII digit;
//   - whitespace: space, tab, CR or LF;
//   - a question mark or full stop: '?', full-width '？', '.' or the
//     ideographic '。'.
//
// Empty input also returns true.
func IsNoNeedLLMCall(input string) bool {
	for _, r := range input {
		if r >= '0' && r <= '9' {
			continue
		}
		switch r {
		case ' ', '\t', '\r', '\n', '?', '？', '.', '。':
			continue
		}
		return false
	}
	return true
}