feat: Enhance token counting and content parsing for messages

This commit is contained in:
1808837298@qq.com
2025-02-24 14:18:15 +08:00
parent cc1d6e1c05
commit 7ff4cebdbe
3 changed files with 40 additions and 26 deletions

View File

@@ -276,7 +276,7 @@ var ChannelBaseURLs = []string{
"https://api.cohere.ai", //34 "https://api.cohere.ai", //34
"https://api.minimax.chat", //35 "https://api.minimax.chat", //35
"", //36 "", //36
"", //37 "https://api.dify.ai", //37
"https://api.jina.ai", //38 "https://api.jina.ai", //38
"https://api.cloudflare.com", //39 "https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40 "https://api.siliconflow.cn", //40

View File

@@ -1,6 +1,9 @@
package dto package dto
import "encoding/json" import (
"encoding/json"
"strings"
)
type ResponseFormat struct { type ResponseFormat struct {
Type string `json:"type,omitempty"` Type string `json:"type,omitempty"`
@@ -153,11 +156,24 @@ func (m *Message) StringContent() string {
if m.parsedStringContent != nil { if m.parsedStringContent != nil {
return *m.parsedStringContent return *m.parsedStringContent
} }
var stringContent string var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil { if err := json.Unmarshal(m.Content, &stringContent); err == nil {
m.parsedStringContent = &stringContent
return stringContent return stringContent
} }
return string(m.Content)
contentStr := new(strings.Builder)
arrayContent := m.ParseContent()
for _, content := range arrayContent {
if content.Type == ContentTypeText {
contentStr.WriteString(content.Text)
}
}
stringContent = contentStr.String()
m.parsedStringContent = &stringContent
return stringContent
} }
func (m *Message) SetStringContent(content string) { func (m *Message) SetStringContent(content string) {

View File

@@ -78,6 +78,9 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
} }
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
if text == "" {
return 0
}
return len(tokenEncoder.Encode(text, nil, nil)) return len(tokenEncoder.Encode(text, nil, nil))
} }
@@ -282,30 +285,25 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
tokenNum += tokensPerMessage tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Role) tokenNum += getTokenNum(tokenEncoder, message.Role)
if len(message.Content) > 0 { if len(message.Content) > 0 {
if message.IsStringContent() { if message.Name != nil {
stringContent := message.StringContent() tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, stringContent) tokenNum += getTokenNum(tokenEncoder, *message.Name)
if message.Name != nil { }
tokenNum += tokensPerName arrayContent := message.ParseContent()
tokenNum += getTokenNum(tokenEncoder, *message.Name) for _, m := range arrayContent {
} if m.Type == dto.ContentTypeImageURL {
} else { imageUrl := m.ImageUrl.(dto.MessageImageUrl)
arrayContent := message.ParseContent() imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
for _, m := range arrayContent { if err != nil {
if m.Type == dto.ContentTypeImageURL { return 0, err
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else if m.Type == dto.ContentTypeInputAudio {
// TODO: 音频token数量计算
tokenNum += 100
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
} }
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else if m.Type == dto.ContentTypeInputAudio {
// TODO: 音频token数量计算
tokenNum += 100
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
} }
} }
} }