refactor: token counter logic
This commit is contained in:
@@ -171,7 +171,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
|
||||
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||||
}
|
||||
}
|
||||
toolTokens, err := CountTokenInput(countStr, request.Model)
|
||||
toolTokens := CountTokenInput(countStr, request.Model)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -194,7 +194,7 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
|
||||
|
||||
// Count tokens in system message
|
||||
if request.System != "" {
|
||||
systemTokens, err := CountTokenInput(request.System, model)
|
||||
systemTokens := CountTokenInput(request.System, model)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -296,10 +296,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
||||
switch request.Type {
|
||||
case dto.RealtimeEventTypeSessionUpdate:
|
||||
if request.Session != nil {
|
||||
msgTokens, err := CountTextToken(request.Session.Instructions, model)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
msgTokens := CountTextToken(request.Session.Instructions, model)
|
||||
textToken += msgTokens
|
||||
}
|
||||
case dto.RealtimeEventResponseAudioDelta:
|
||||
@@ -311,10 +308,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
||||
audioToken += atk
|
||||
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
|
||||
// count text token
|
||||
tkm, err := CountTextToken(request.Delta, model)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error counting text token: %v", err)
|
||||
}
|
||||
tkm := CountTextToken(request.Delta, model)
|
||||
textToken += tkm
|
||||
case dto.RealtimeEventInputAudioBufferAppend:
|
||||
// count audio token
|
||||
@@ -329,10 +323,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
||||
case "message":
|
||||
for _, content := range request.Item.Content {
|
||||
if content.Type == "input_text" {
|
||||
tokens, err := CountTextToken(content.Text, model)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
tokens := CountTextToken(content.Text, model)
|
||||
textToken += tokens
|
||||
}
|
||||
}
|
||||
@@ -343,10 +334,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
||||
if !info.IsFirstRequest {
|
||||
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
|
||||
for _, tool := range info.RealtimeTools {
|
||||
toolTokens, err := CountTokenInput(tool, model)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
toolTokens := CountTokenInput(tool, model)
|
||||
textToken += 8
|
||||
textToken += toolTokens
|
||||
}
|
||||
@@ -409,7 +397,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
|
||||
return tokenNum, nil
|
||||
}
|
||||
|
||||
func CountTokenInput(input any, model string) (int, error) {
|
||||
func CountTokenInput(input any, model string) int {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
return CountTextToken(v, model)
|
||||
@@ -432,13 +420,13 @@ func CountTokenInput(input any, model string) (int, error) {
|
||||
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
||||
tokens := 0
|
||||
for _, message := range messages {
|
||||
tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
|
||||
tkm := CountTokenInput(message.Delta.GetContentString(), model)
|
||||
tokens += tkm
|
||||
if message.Delta.ToolCalls != nil {
|
||||
for _, tool := range message.Delta.ToolCalls {
|
||||
tkm, _ := CountTokenInput(tool.Function.Name, model)
|
||||
tkm := CountTokenInput(tool.Function.Name, model)
|
||||
tokens += tkm
|
||||
tkm, _ = CountTokenInput(tool.Function.Arguments, model)
|
||||
tkm = CountTokenInput(tool.Function.Arguments, model)
|
||||
tokens += tkm
|
||||
}
|
||||
}
|
||||
@@ -446,9 +434,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
|
||||
return tokens
|
||||
}
|
||||
|
||||
func CountTTSToken(text string, model string) (int, error) {
|
||||
func CountTTSToken(text string, model string) int {
|
||||
if strings.HasPrefix(model, "tts") {
|
||||
return utf8.RuneCountInString(text), nil
|
||||
return utf8.RuneCountInString(text)
|
||||
} else {
|
||||
return CountTextToken(text, model)
|
||||
}
|
||||
@@ -483,8 +471,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
|
||||
//}
|
||||
|
||||
// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
||||
func CountTextToken(text string, model string) (int, error) {
|
||||
var err error
|
||||
func CountTextToken(text string, model string) int {
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
return getTokenNum(tokenEncoder, text), err
|
||||
return getTokenNum(tokenEncoder, text)
|
||||
}
|
||||
|
||||
@@ -16,13 +16,13 @@ import (
|
||||
// return 0, errors.New("unknown relay mode")
|
||||
//}
|
||||
|
||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
|
||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = promptTokens
|
||||
ctkm, err := CountTextToken(responseText, modeName)
|
||||
ctkm := CountTextToken(responseText, modeName)
|
||||
usage.CompletionTokens = ctkm
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return usage, err
|
||||
return usage
|
||||
}
|
||||
|
||||
func ValidUsage(usage *dto.Usage) bool {
|
||||
|
||||
Reference in New Issue
Block a user