diff --git a/common/constants.go b/common/constants.go index 04fb1b9a..bcab24fc 100644 --- a/common/constants.go +++ b/common/constants.go @@ -276,7 +276,7 @@ var ChannelBaseURLs = []string{ "https://api.cohere.ai", //34 "https://api.minimax.chat", //35 "", //36 - "", //37 + "https://api.dify.ai", //37 "https://api.jina.ai", //38 "https://api.cloudflare.com", //39 "https://api.siliconflow.cn", //40 diff --git a/dto/openai_request.go b/dto/openai_request.go index 028e0286..88cb6c30 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -1,6 +1,9 @@ package dto -import "encoding/json" +import ( + "encoding/json" + "strings" +) type ResponseFormat struct { Type string `json:"type,omitempty"` @@ -153,11 +156,24 @@ func (m *Message) StringContent() string { if m.parsedStringContent != nil { return *m.parsedStringContent } + var stringContent string if err := json.Unmarshal(m.Content, &stringContent); err == nil { + m.parsedStringContent = &stringContent return stringContent } - return string(m.Content) + + contentStr := new(strings.Builder) + arrayContent := m.ParseContent() + for _, content := range arrayContent { + if content.Type == ContentTypeText { + contentStr.WriteString(content.Text) + } + } + stringContent = contentStr.String() + m.parsedStringContent = &stringContent + + return stringContent } func (m *Message) SetStringContent(content string) { diff --git a/service/token_counter.go b/service/token_counter.go index 319c9b11..0a7e6de3 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -78,6 +78,9 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken { } func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { + if text == "" { + return 0 + } return len(tokenEncoder.Encode(text, nil, nil)) } @@ -282,30 +285,25 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod tokenNum += tokensPerMessage tokenNum += getTokenNum(tokenEncoder, message.Role) if len(message.Content) > 0 { - if message.IsStringContent() { - stringContent := message.StringContent() - tokenNum += getTokenNum(tokenEncoder, stringContent) - if message.Name != nil { - tokenNum += tokensPerName - tokenNum += getTokenNum(tokenEncoder, *message.Name) - } - } else { - arrayContent := message.ParseContent() - for _, m := range arrayContent { - if m.Type == dto.ContentTypeImageURL { - imageUrl := m.ImageUrl.(dto.MessageImageUrl) - imageTokenNum, err := getImageToken(info, &imageUrl, model, stream) - if err != nil { - return 0, err - } - tokenNum += imageTokenNum - log.Printf("image token num: %d", imageTokenNum) - } else if m.Type == dto.ContentTypeInputAudio { - // TODO: 音频token数量计算 - tokenNum += 100 - } else { - tokenNum += getTokenNum(tokenEncoder, m.Text) + if message.Name != nil { + tokenNum += tokensPerName + tokenNum += getTokenNum(tokenEncoder, *message.Name) + } + arrayContent := message.ParseContent() + for _, m := range arrayContent { + if m.Type == dto.ContentTypeImageURL { + imageUrl := m.ImageUrl.(dto.MessageImageUrl) + imageTokenNum, err := getImageToken(info, &imageUrl, model, stream) + if err != nil { + return 0, err } + tokenNum += imageTokenNum + log.Printf("image token num: %d", imageTokenNum) + } else if m.Type == dto.ContentTypeInputAudio { + // TODO: 音频token数量计算 + tokenNum += 100 + } else { + tokenNum += getTokenNum(tokenEncoder, m.Text) } } }