feat: Enhance token counting and content parsing for messages
This commit is contained in:
@@ -276,7 +276,7 @@ var ChannelBaseURLs = []string{
|
|||||||
"https://api.cohere.ai", //34
|
"https://api.cohere.ai", //34
|
||||||
"https://api.minimax.chat", //35
|
"https://api.minimax.chat", //35
|
||||||
"", //36
|
"", //36
|
||||||
"", //37
|
"https://api.dify.ai", //37
|
||||||
"https://api.jina.ai", //38
|
"https://api.jina.ai", //38
|
||||||
"https://api.cloudflare.com", //39
|
"https://api.cloudflare.com", //39
|
||||||
"https://api.siliconflow.cn", //40
|
"https://api.siliconflow.cn", //40
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
import "encoding/json"
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
type ResponseFormat struct {
|
type ResponseFormat struct {
|
||||||
Type string `json:"type,omitempty"`
|
Type string `json:"type,omitempty"`
|
||||||
@@ -153,11 +156,24 @@ func (m *Message) StringContent() string {
|
|||||||
if m.parsedStringContent != nil {
|
if m.parsedStringContent != nil {
|
||||||
return *m.parsedStringContent
|
return *m.parsedStringContent
|
||||||
}
|
}
|
||||||
|
|
||||||
var stringContent string
|
var stringContent string
|
||||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||||
|
m.parsedStringContent = &stringContent
|
||||||
return stringContent
|
return stringContent
|
||||||
}
|
}
|
||||||
return string(m.Content)
|
|
||||||
|
contentStr := new(strings.Builder)
|
||||||
|
arrayContent := m.ParseContent()
|
||||||
|
for _, content := range arrayContent {
|
||||||
|
if content.Type == ContentTypeText {
|
||||||
|
contentStr.WriteString(content.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stringContent = contentStr.String()
|
||||||
|
m.parsedStringContent = &stringContent
|
||||||
|
|
||||||
|
return stringContent
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Message) SetStringContent(content string) {
|
func (m *Message) SetStringContent(content string) {
|
||||||
|
|||||||
@@ -78,6 +78,9 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||||
|
if text == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
return len(tokenEncoder.Encode(text, nil, nil))
|
return len(tokenEncoder.Encode(text, nil, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -282,30 +285,25 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
|
|||||||
tokenNum += tokensPerMessage
|
tokenNum += tokensPerMessage
|
||||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||||
if len(message.Content) > 0 {
|
if len(message.Content) > 0 {
|
||||||
if message.IsStringContent() {
|
if message.Name != nil {
|
||||||
stringContent := message.StringContent()
|
tokenNum += tokensPerName
|
||||||
tokenNum += getTokenNum(tokenEncoder, stringContent)
|
tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
||||||
if message.Name != nil {
|
}
|
||||||
tokenNum += tokensPerName
|
arrayContent := message.ParseContent()
|
||||||
tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
for _, m := range arrayContent {
|
||||||
}
|
if m.Type == dto.ContentTypeImageURL {
|
||||||
} else {
|
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
|
||||||
arrayContent := message.ParseContent()
|
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
|
||||||
for _, m := range arrayContent {
|
if err != nil {
|
||||||
if m.Type == dto.ContentTypeImageURL {
|
return 0, err
|
||||||
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
|
|
||||||
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
tokenNum += imageTokenNum
|
|
||||||
log.Printf("image token num: %d", imageTokenNum)
|
|
||||||
} else if m.Type == dto.ContentTypeInputAudio {
|
|
||||||
// TODO: 音频token数量计算
|
|
||||||
tokenNum += 100
|
|
||||||
} else {
|
|
||||||
tokenNum += getTokenNum(tokenEncoder, m.Text)
|
|
||||||
}
|
}
|
||||||
|
tokenNum += imageTokenNum
|
||||||
|
log.Printf("image token num: %d", imageTokenNum)
|
||||||
|
} else if m.Type == dto.ContentTypeInputAudio {
|
||||||
|
// TODO: 音频token数量计算
|
||||||
|
tokenNum += 100
|
||||||
|
} else {
|
||||||
|
tokenNum += getTokenNum(tokenEncoder, m.Text)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user