feat: Enhance token counting and content parsing for messages

This commit is contained in:
1808837298@qq.com
2025-02-24 14:18:15 +08:00
parent cc1d6e1c05
commit 7ff4cebdbe
3 changed files with 40 additions and 26 deletions

View File

@@ -276,7 +276,7 @@ var ChannelBaseURLs = []string{
"https://api.cohere.ai", //34 "https://api.cohere.ai", //34
"https://api.minimax.chat", //35 "https://api.minimax.chat", //35
"", //36 "", //36
"", //37 "https://api.dify.ai", //37
"https://api.jina.ai", //38 "https://api.jina.ai", //38
"https://api.cloudflare.com", //39 "https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40 "https://api.siliconflow.cn", //40

View File

@@ -1,6 +1,9 @@
package dto package dto
import "encoding/json" import (
"encoding/json"
"strings"
)
type ResponseFormat struct { type ResponseFormat struct {
Type string `json:"type,omitempty"` Type string `json:"type,omitempty"`
@@ -153,11 +156,24 @@ func (m *Message) StringContent() string {
if m.parsedStringContent != nil { if m.parsedStringContent != nil {
return *m.parsedStringContent return *m.parsedStringContent
} }
var stringContent string var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil { if err := json.Unmarshal(m.Content, &stringContent); err == nil {
m.parsedStringContent = &stringContent
return stringContent return stringContent
} }
return string(m.Content)
contentStr := new(strings.Builder)
arrayContent := m.ParseContent()
for _, content := range arrayContent {
if content.Type == ContentTypeText {
contentStr.WriteString(content.Text)
}
}
stringContent = contentStr.String()
m.parsedStringContent = &stringContent
return stringContent
} }
func (m *Message) SetStringContent(content string) { func (m *Message) SetStringContent(content string) {

View File

@@ -78,6 +78,9 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
} }
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
if text == "" {
return 0
}
return len(tokenEncoder.Encode(text, nil, nil)) return len(tokenEncoder.Encode(text, nil, nil))
} }
@@ -282,14 +285,10 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
tokenNum += tokensPerMessage tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Role) tokenNum += getTokenNum(tokenEncoder, message.Role)
if len(message.Content) > 0 { if len(message.Content) > 0 {
if message.IsStringContent() {
stringContent := message.StringContent()
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil { if message.Name != nil {
tokenNum += tokensPerName tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name) tokenNum += getTokenNum(tokenEncoder, *message.Name)
} }
} else {
arrayContent := message.ParseContent() arrayContent := message.ParseContent()
for _, m := range arrayContent { for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL { if m.Type == dto.ContentTypeImageURL {
@@ -309,7 +308,6 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
} }
} }
} }
}
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
return tokenNum, nil return tokenNum, nil
} }