From a56d9ea98bb4d7adabd829b393e6585151060c91 Mon Sep 17 00:00:00 2001
From: creamlike1024
Date: Fri, 20 Jun 2025 23:01:10 +0800
Subject: [PATCH 1/4] fix: interrupted requests in gemini native-format stream
 mode were not billed
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 relay/channel/gemini/relay-gemini-native.go | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go
index cf7920dc..3a017a11 100644
--- a/relay/channel/gemini/relay-gemini-native.go
+++ b/relay/channel/gemini/relay-gemini-native.go
@@ -75,6 +75,9 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
 
 	helper.SetEventStreamHeaders(c)
 
+	// completion tokens counted locally
+	localCompletionTokens := 0
+
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var geminiResponse GeminiChatResponse
 		err := common.DecodeJsonStr(data, &geminiResponse)
@@ -89,6 +92,12 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
 			if part.InlineData != nil && part.InlineData.MimeType != "" {
 				imageCount++
 			}
+			// count completion tokens locally
+			textTokens, err := service.CountTextToken(part.Text, info.UpstreamModelName)
+			if err != nil {
+				common.LogError(c, "error counting text token: "+err.Error())
+			}
+			localCompletionTokens += textTokens
 		}
 	}
 
@@ -122,6 +131,12 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
 		}
 	}
 
+	// if usage.CompletionTokens is 0, fall back to the locally counted completion tokens
+	if usage.CompletionTokens == 0 {
+		usage.CompletionTokens = localCompletionTokens
+		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+	}
+
 	// compute the final usage
 	// usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
 
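A standalone sketch of the billing fix above: when the client aborts a native-format stream, Gemini never delivers its final usage chunk, so the handler now tokenizes whatever text it has already relayed and bills from that count. The types and tokenizer below are simplified stand-ins for the project's dto.Usage and its tiktoken-based counter, not the actual API:

    package main

    import "fmt"

    // Usage is a simplified stand-in for the project's dto.Usage.
    type Usage struct {
        PromptTokens     int
        CompletionTokens int
        TotalTokens      int
    }

    // finalizeUsage bills from locally counted tokens whenever the upstream
    // never reported completion tokens (e.g. the stream was interrupted).
    func finalizeUsage(u Usage, streamedText string, promptTokens int, countTokens func(string) int) Usage {
        if u.CompletionTokens == 0 {
            u.PromptTokens = promptTokens
            u.CompletionTokens = countTokens(streamedText)
            u.TotalTokens = u.PromptTokens + u.CompletionTokens
        }
        return u
    }

    func main() {
        // toy tokenizer: roughly 4 characters per token
        count := func(s string) int { return (len(s) + 3) / 4 }
        // interrupted stream: no upstream usage, but text was already sent
        fmt.Println(finalizeUsage(Usage{}, "partial answer", 12, count)) // {12 4 16}
    }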
From a9e5d99ea3ec6ed1591b7e4006e51de9a718c1ba Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Sat, 21 Jun 2025 00:54:40 +0800
Subject: [PATCH 2/4] refactor: token counter logic

---
 relay/audio_handler.go                       |  2 +-
 relay/channel/claude/relay-claude.go         |  6 +--
 relay/channel/cloudflare/relay_cloudflare.go |  6 +--
 relay/channel/cohere/relay-cohere.go         |  2 +-
 relay/channel/coze/relay-coze.go             |  2 +-
 relay/channel/dify/relay-dify.go             |  2 +-
 relay/channel/gemini/relay-gemini-native.go  | 17 +++----
 relay/channel/openai/relay-openai.go         | 18 ++++----
 relay/channel/openai/relay_responses.go      |  2 +-
 relay/channel/palm/adaptor.go                |  2 +-
 relay/channel/palm/relay-palm.go             |  2 +-
 relay/channel/tencent/adaptor.go             |  2 +-
 relay/channel/xai/text.go                    |  2 +-
 relay/embedding_handler.go                   |  2 +-
 relay/gemini_handler.go                      |  8 ++--
 relay/relay-text.go                          |  6 +--
 relay/rerank_handler.go                      |  8 ++--
 relay/responses_handler.go                   |  8 ++--
 service/token_counter.go                     | 44 ++++++++------------
 service/usage_helpr.go                       |  6 +--
 20 files changed, 64 insertions(+), 83 deletions(-)

diff --git a/relay/audio_handler.go b/relay/audio_handler.go
index e55de042..96cf1019 100644
--- a/relay/audio_handler.go
+++ b/relay/audio_handler.go
@@ -66,7 +66,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	promptTokens := 0
 	preConsumedTokens := common.PreConsumedQuota
 	if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
-		promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
+		promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
 		if err != nil {
 			return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
 		}
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index ba20adea..5e15d3a2 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -549,7 +549,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 
 func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
 	if requestMode == RequestModeCompletion {
-		claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
+		claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
 	} else {
 		if claudeInfo.Usage.PromptTokens == 0 {
 			// upstream error
@@ -558,7 +558,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
 			if common.DebugEnabled {
 				common.SysError("claude response usage is not complete, maybe upstream error")
 			}
-			claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
+			claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
 		}
 	}
 
@@ -618,7 +618,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 		}
 	}
 	if requestMode == RequestModeCompletion {
-		completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
+		completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
 		if err != nil {
 			return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
 		}
diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go
index a487429c..50d4928a 100644
--- a/relay/channel/cloudflare/relay_cloudflare.go
+++ b/relay/channel/cloudflare/relay_cloudflare.go
@@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 	if err := scanner.Err(); err != nil {
 		common.LogError(c, "error_scanning_stream_response: "+err.Error())
 	}
-	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	if info.ShouldIncludeUsage {
 		response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
 		err := helper.ObjectData(c, response)
@@ -108,7 +108,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
 	for _, choice := range response.Choices {
 		responseText += choice.Message.StringContent()
 	}
-	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	response.Usage = *usage
 	response.Id = helper.GetResponseID(c)
 	jsonResponse, err := json.Marshal(response)
@@ -150,7 +150,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
 	usage := &dto.Usage{}
 	usage.PromptTokens = info.PromptTokens
-	usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
+	usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return nil, usage
diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go
index 8a044bf2..29064242 100644
--- a/relay/channel/cohere/relay-cohere.go
+++ b/relay/channel/cohere/relay-cohere.go
@@ -162,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 		}
 	})
 	if usage.PromptTokens == 0 {
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	}
 	return nil, usage
 }
diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go
index 6db40213..e9719cb9 100644
--- a/relay/channel/coze/relay-coze.go
+++ b/relay/channel/coze/relay-coze.go
@@ -144,7 +144,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
 
 	if usage.TotalTokens == 0 {
 		usage.PromptTokens = info.PromptTokens
-		usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
+		usage.CompletionTokens = service.CountTextToken("gpt-3.5-turbo", responseText)
 		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	}
 
diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go
index 93e3e8d6..b3ae5927 100644
--- a/relay/channel/dify/relay-dify.go
+++ b/relay/channel/dify/relay-dify.go
@@ -250,7 +250,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 	}
 	if usage.TotalTokens == 0 {
 		usage.PromptTokens = info.PromptTokens
-		usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
+		usage.CompletionTokens = service.CountTextToken("gpt-3.5-turbo", responseText)
 		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	}
 	usage.CompletionTokens += nodeToken
diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go
index 3a017a11..1a497b9f 100644
--- a/relay/channel/gemini/relay-gemini-native.go
+++ b/relay/channel/gemini/relay-gemini-native.go
@@ -9,6 +9,7 @@ import (
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/service"
+	"strings"
 
 	"github.com/gin-gonic/gin"
 )
@@ -75,8 +76,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
 
 	helper.SetEventStreamHeaders(c)
 
-	// completion tokens counted locally
-	localCompletionTokens := 0
+	responseText := strings.Builder{}
 
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var geminiResponse GeminiChatResponse
@@ -92,12 +92,9 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
 			if part.InlineData != nil && part.InlineData.MimeType != "" {
 				imageCount++
 			}
-			// count completion tokens locally
-			textTokens, err := service.CountTextToken(part.Text, info.UpstreamModelName)
-			if err != nil {
-				common.LogError(c, "error counting text token: "+err.Error())
+			if part.Text != "" {
+				responseText.WriteString(part.Text)
 			}
-			localCompletionTokens += textTokens
 		}
 	}
 
@@ -133,13 +130,9 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
 	}
 
 	// if usage.CompletionTokens is 0, fall back to the locally counted completion tokens
 	if usage.CompletionTokens == 0 {
-		usage.CompletionTokens = localCompletionTokens
-		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+		usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
 	}
 
-	// compute the final usage
-	// usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
-
 	// do not append [Done] at the end of the stream; the Gemini API never sends it
 	//helper.Done(c)
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 4dc0fc60..71590cd6 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -8,7 +8,6 @@ import (
 	"math"
 	"mime/multipart"
 	"net/http"
-	"path/filepath"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/dto"
@@ -16,6 +15,7 @@ import (
 	"one-api/relay/helper"
 	"one-api/service"
 	"os"
+	"path/filepath"
 	"strings"
 
 	"github.com/bytedance/gopkg/util/gopool"
@@ -181,7 +181,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	}
 
 	if !containStreamUsage {
-		usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
 		usage.CompletionTokens += toolCount * 7
 	} else {
 		if info.ChannelType == common.ChannelTypeDeepSeek {
@@ -216,7 +216,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 			StatusCode: resp.StatusCode,
 		}, nil
 	}
-	
+
 	forceFormat := false
 	if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
 		forceFormat = forceFmt
@@ -225,7 +225,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 	if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
 		completionTokens := 0
 		for _, choice := range simpleResponse.Choices {
-			ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
+			ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
 			completionTokens += ctkm
 		}
 		simpleResponse.Usage = dto.Usage{
@@ -276,9 +276,9 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	// the status code has been judged before, if there is a body reading failure,
 	// it should be regarded as a non-recoverable error, so it should not return err for external retry.
-	// Analogous to nginx's load balancing, it will only retry if it can't be requested or
-	// if the upstream returns a specific status code, once the upstream has already written the header,
-	// the subsequent failure of the response body should be regarded as a non-recoverable error,
+	// Analogous to nginx's load balancing, it will only retry if it can't be requested or
+	// if the upstream returns a specific status code, once the upstream has already written the header,
+	// the subsequent failure of the response body should be regarded as a non-recoverable error,
 	// and can be terminated directly.
 	defer resp.Body.Close()
 	usage := &dto.Usage{}
@@ -346,12 +346,12 @@ func countAudioTokens(c *gin.Context) (int, error) {
 	if err = c.ShouldBind(&reqBody); err != nil {
 		return 0, errors.WithStack(err)
 	}
-	ext := filepath.Ext(reqBody.File.Filename) // get the file extension
+	ext := filepath.Ext(reqBody.File.Filename) // get the file extension
 	reqFp, err := reqBody.File.Open()
 	if err != nil {
 		return 0, errors.WithStack(err)
 	}
-	defer reqFp.Close()
+	defer reqFp.Close()
 
 	tmpFp, err := os.CreateTemp("", "audio-*"+ext)
 	if err != nil {
diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go
index 1d1e060e..da9382c3 100644
--- a/relay/channel/openai/relay_responses.go
+++ b/relay/channel/openai/relay_responses.go
@@ -110,7 +110,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
 		tempStr := responseTextBuilder.String()
 		if len(tempStr) > 0 {
 			// abnormal termination: fall back to the token count of the output text
-			completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
+			completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
 			usage.CompletionTokens = completionTokens
 		}
 	}
diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go
index 3a06e7ee..aee4a307 100644
--- a/relay/channel/palm/adaptor.go
+++ b/relay/channel/palm/adaptor.go
@@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		var responseText string
 		err, responseText = palmStreamHandler(c, resp)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go
index 0c6f8641..9d3dbd67 100644
--- a/relay/channel/palm/relay-palm.go
+++ b/relay/channel/palm/relay-palm.go
@@ -155,7 +155,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 		}, nil
 	}
 	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-	completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
+	completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
 	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,
diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go
index 44718a25..7ea3aae7 100644
--- a/relay/channel/tencent/adaptor.go
+++ b/relay/channel/tencent/adaptor.go
@@ -98,7 +98,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		var responseText string
 		err, responseText = tencentStreamHandler(c, resp)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		err, usage = tencentHandler(c, resp)
 	}
diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go
index e019c2dc..408160fb 100644
--- a/relay/channel/xai/text.go
+++ b/relay/channel/xai/text.go
@@ -68,7 +68,7 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	})
 
 	if !containStreamUsage {
-		usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
 		usage.CompletionTokens += toolCount * 7
 	}
 
diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go
index fbf4990a..849c70da 100644
--- a/relay/embedding_handler.go
+++ b/relay/embedding_handler.go
@@ -15,7 +15,7 @@ import (
 )
 
 func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
-	token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
+	token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
 	return token
 }
diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go
index fa41cc7b..14d58cc5 100644
--- a/relay/gemini_handler.go
+++ b/relay/gemini_handler.go
@@ -59,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
 	return sensitiveWords, err
 }
 
-func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
+func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
 	// count input tokens
 	var inputTexts []string
 	for _, content := range req.Contents {
@@ -71,9 +71,9 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
 	}
 
 	inputText := strings.Join(inputTexts, "\n")
-	inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName)
+	inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
 	info.PromptTokens = inputTokens
-	return inputTokens, err
+	return inputTokens
 }
 
 func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -106,7 +106,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		promptTokens := value.(int)
 		relayInfo.SetPromptTokens(promptTokens)
 	} else {
-		promptTokens, err := getGeminiInputTokens(req, relayInfo)
+		promptTokens := getGeminiInputTokens(req, relayInfo)
 		if err != nil {
 			return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
 		}
diff --git a/relay/relay-text.go b/relay/relay-text.go
index bf5a0259..db8d0d3b 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -251,11 +251,11 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
 	case relayconstant.RelayModeChatCompletions:
 		promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
 	case relayconstant.RelayModeCompletions:
-		promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
+		promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
 	case relayconstant.RelayModeModerations:
-		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
+		promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
 	case relayconstant.RelayModeEmbeddings:
-		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
+		promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
 	default:
 		err = errors.New("unknown relay mode")
 		promptTokens = 0
diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go
index 4d02c84f..319811b8 100644
--- a/relay/rerank_handler.go
+++ b/relay/rerank_handler.go
@@ -14,12 +14,10 @@ import (
 )
 
 func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
-	token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
+	token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
 	for _, document := range rerankRequest.Documents {
-		tkm, err := service.CountTokenInput(document, rerankRequest.Model)
-		if err == nil {
-			token += tkm
-		}
+		tkm := service.CountTokenInput(document, rerankRequest.Model)
+		token += tkm
 	}
 	return token
 }
diff --git a/relay/responses_handler.go b/relay/responses_handler.go
index 8e8a3451..9d4adf49 100644
--- a/relay/responses_handler.go
+++ b/relay/responses_handler.go
@@ -40,10 +40,10 @@ func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycom
 	return sensitiveWords, err
 }
 
-func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) {
-	inputTokens, err := service.CountTokenInput(req.Input, req.Model)
+func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int {
+	inputTokens := service.CountTokenInput(req.Input, req.Model)
 	info.PromptTokens = inputTokens
-	return inputTokens, err
+	return inputTokens
 }
 
 func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -72,7 +72,7 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
 		promptTokens := value.(int)
 		relayInfo.SetPromptTokens(promptTokens)
 	} else {
-		promptTokens, err := getInputTokens(req, relayInfo)
+		promptTokens := getInputTokens(req, relayInfo)
 		if err != nil {
 			return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
 		}
diff --git a/service/token_counter.go b/service/token_counter.go
index 82de0a05..53c6c2fa 100644
--- a/service/token_counter.go
+++ b/service/token_counter.go
@@ -171,7 +171,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
 				countStr += fmt.Sprintf("%v", tool.Function.Parameters)
 			}
 		}
-		toolTokens, err := CountTokenInput(countStr, request.Model)
+		toolTokens := CountTokenInput(countStr, request.Model)
 		if err != nil {
 			return 0, err
 		}
@@ -194,7 +194,7 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
 
 	// Count tokens in system message
 	if request.System != "" {
-		systemTokens, err := CountTokenInput(request.System, model)
+		systemTokens := CountTokenInput(request.System, model)
 		if err != nil {
 			return 0, err
 		}
@@ -296,10 +296,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
 	switch request.Type {
 	case dto.RealtimeEventTypeSessionUpdate:
 		if request.Session != nil {
-			msgTokens, err := CountTextToken(request.Session.Instructions, model)
-			if err != nil {
-				return 0, 0, err
-			}
+			msgTokens := CountTextToken(request.Session.Instructions, model)
 			textToken += msgTokens
 		}
 	case dto.RealtimeEventResponseAudioDelta:
@@ -311,10 +308,7 @@
 		audioToken += atk
 	case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
 		// count text token
-		tkm, err := CountTextToken(request.Delta, model)
-		if err != nil {
-			return 0, 0, fmt.Errorf("error counting text token: %v", err)
-		}
+		tkm := CountTextToken(request.Delta, model)
 		textToken += tkm
 	case dto.RealtimeEventInputAudioBufferAppend:
 		// count audio token
@@ -329,10 +323,7 @@
 		case "message":
 			for _, content := range request.Item.Content {
 				if content.Type == "input_text" {
-					tokens, err := CountTextToken(content.Text, model)
-					if err != nil {
-						return 0, 0, err
-					}
+					tokens := CountTextToken(content.Text, model)
 					textToken += tokens
 				}
 			}
@@ -343,10 +334,7 @@
 	if !info.IsFirstRequest {
 		if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
 			for _, tool := range info.RealtimeTools {
-				toolTokens, err := CountTokenInput(tool, model)
-				if err != nil {
-					return 0, 0, err
-				}
+				toolTokens := CountTokenInput(tool, model)
 				textToken += 8
 				textToken += toolTokens
 			}
@@ -409,7 +397,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
 	return tokenNum, nil
 }
 
-func CountTokenInput(input any, model string) (int, error) {
+func CountTokenInput(input any, model string) int {
 	switch v := input.(type) {
 	case string:
 		return CountTextToken(v, model)
@@ -432,13 +420,13 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
 	tokens := 0
 	for _, message := range messages {
-		tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
+		tkm := CountTokenInput(message.Delta.GetContentString(), model)
 		tokens += tkm
 		if message.Delta.ToolCalls != nil {
 			for _, tool := range message.Delta.ToolCalls {
-				tkm, _ := CountTokenInput(tool.Function.Name, model)
+				tkm := CountTokenInput(tool.Function.Name, model)
 				tokens += tkm
-				tkm, _ = CountTokenInput(tool.Function.Arguments, model)
+				tkm = CountTokenInput(tool.Function.Arguments, model)
 				tokens += tkm
 			}
 		}
@@ -446,9 +434,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
 	return tokens
 }
 
-func CountTTSToken(text string, model string) (int, error) {
+func CountTTSToken(text string, model string) int {
 	if strings.HasPrefix(model, "tts") {
-		return utf8.RuneCountInString(text), nil
+		return utf8.RuneCountInString(text)
 	} else {
 		return CountTextToken(text, model)
 	}
@@ -483,8 +471,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
 //}
 
 // CountTextToken counts the tokens in a text; an error is returned only when the text contains sensitive words, together with the token count
-func CountTextToken(text string, model string) (int, error) {
-	var err error
+func CountTextToken(text string, model string) int {
+	if text == "" {
+		return 0
+	}
 	tokenEncoder := getTokenEncoder(model)
-	return getTokenNum(tokenEncoder, text), err
+	return getTokenNum(tokenEncoder, text)
 }
diff --git a/service/usage_helpr.go b/service/usage_helpr.go
index c52e1e15..ca9c0830 100644
--- a/service/usage_helpr.go
+++ b/service/usage_helpr.go
@@ -16,13 +16,13 @@ import (
 //	return 0, errors.New("unknown relay mode")
 //}
 
-func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
+func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
 	usage := &dto.Usage{}
 	usage.PromptTokens = promptTokens
-	ctkm, err := CountTextToken(responseText, modeName)
+	ctkm := CountTextToken(responseText, modeName)
 	usage.CompletionTokens = ctkm
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
-	return usage, err
+	return usage
 }
 
 func ValidUsage(usage *dto.Usage) bool {
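The core of the refactor is visible in the last two hunks above: CountTextToken (and with it CountTokenInput, CountTTSToken and ResponseText2Usage) cannot actually fail, so the error returns were dropped and every call site loses either a discarded `_` or a dead `if err != nil` branch. A standalone sketch of the new calling convention; the rune-based counter is an assumed stand-in for the real tiktoken encoder:

    package main

    import (
        "fmt"
        "unicode/utf8"
    )

    // countTextToken mirrors the refactored service.CountTextToken: infallible,
    // with the empty string short-circuited to zero tokens.
    func countTextToken(text string, model string) int {
        if text == "" {
            return 0
        }
        // stand-in for the tiktoken encoder lookup: assume ~4 runes per token
        return (utf8.RuneCountInString(text) + 3) / 4
    }

    func main() {
        // before: ctkm, _ := service.CountTextToken(text, model) (error never acted on)
        // after:  ctkm := service.CountTextToken(text, model)
        fmt.Println(countTextToken("hello world", "gpt-4o")) // 3
        fmt.Println(countTextToken("", "gpt-4o"))            // 0
    }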
From 0708452939d610a73fd8530f6b8ddf48c2cf2ff2 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Sat, 21 Jun 2025 01:08:15 +0800
Subject: [PATCH 3/4] fix: improve usage calculation in
 GeminiTextGenerationStreamHandler

---
 relay/channel/gemini/relay-gemini-native.go | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go
index 1a497b9f..39757cef 100644
--- a/relay/channel/gemini/relay-gemini-native.go
+++ b/relay/channel/gemini/relay-gemini-native.go
@@ -130,7 +130,13 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
 
 	// if usage.CompletionTokens is 0, fall back to the locally counted completion tokens
 	if usage.CompletionTokens == 0 {
-		usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
+		str := responseText.String()
+		if len(str) > 0 {
+			usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
+		} else {
+			// empty completion, no usage to bill
+			usage = &dto.Usage{}
+		}
 	}
 
 	// do not append [Done] at the end of the stream; the Gemini API never sends it
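The reasoning behind the new branch: ResponseText2Usage (patch 2) always charges promptTokens, even for empty text, so an interrupted stream that produced nothing would still have billed the whole prompt. Returning a zeroed dto.Usage instead bills nothing. A toy demonstration of the difference, with simplified types:

    package main

    import "fmt"

    type Usage struct{ PromptTokens, CompletionTokens, TotalTokens int }

    // text2Usage mirrors ResponseText2Usage from patch 2: it always bills the
    // prompt, even when the completion text is empty.
    func text2Usage(text string, promptTokens int) *Usage {
        u := &Usage{PromptTokens: promptTokens}
        u.CompletionTokens = len(text) / 4 // toy counter
        u.TotalTokens = u.PromptTokens + u.CompletionTokens
        return u
    }

    func main() {
        fmt.Println(*text2Usage("", 500)) // {500 0 500}: the prompt is billed for nothing
        empty := &Usage{}                 // patch 3's choice for empty completions
        fmt.Println(*empty)               // {0 0 0}: bill nothing at all
    }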
From 7afd3f97eec60111f18d231dcc9f9a6bc20045f5 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Sat, 21 Jun 2025 01:16:54 +0800
Subject: [PATCH 4/4] fix: remove unnecessary error handling in token counting
 functions

---
 relay/audio_handler.go               |  3 ---
 relay/channel/claude/relay-claude.go |  3 ---
 relay/channel/coze/relay-coze.go     | 12 +++++-------
 relay/channel/dify/relay-dify.go     |  9 +--------
 relay/responses_handler.go           |  3 ---
 5 files changed, 6 insertions(+), 24 deletions(-)

diff --git a/relay/audio_handler.go b/relay/audio_handler.go
index 96cf1019..c1ce1a02 100644
--- a/relay/audio_handler.go
+++ b/relay/audio_handler.go
@@ -67,9 +67,6 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	preConsumedTokens := common.PreConsumedQuota
 	if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
 		promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
-		}
 		preConsumedTokens = promptTokens
 		relayInfo.PromptTokens = promptTokens
 	}
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index 5e15d3a2..406ebc8a 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -619,9 +619,6 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 	}
 	if requestMode == RequestModeCompletion {
 		completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
-		}
 		claudeInfo.Usage.PromptTokens = info.PromptTokens
 		claudeInfo.Usage.CompletionTokens = completionTokens
 		claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go
index e9719cb9..ac76476f 100644
--- a/relay/channel/coze/relay-coze.go
+++ b/relay/channel/coze/relay-coze.go
@@ -106,7 +106,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
 
 	var currentEvent string
 	var currentData string
-	var usage dto.Usage
+	var usage = &dto.Usage{}
 
 	for scanner.Scan() {
 		line := scanner.Text()
@@ -114,7 +114,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
 		if line == "" {
 			if currentEvent != "" && currentData != "" {
 				// handle last event
-				handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
+				handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
 				currentEvent = ""
 				currentData = ""
 			}
@@ -134,7 +134,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
 
 	// Last event
 	if currentEvent != "" && currentData != "" {
		handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
 	}
 
 	if err := scanner.Err(); err != nil {
@@ -143,12 +143,10 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
 	helper.Done(c)
 
 	if usage.TotalTokens == 0 {
-		usage.PromptTokens = info.PromptTokens
-		usage.CompletionTokens = service.CountTextToken("gpt-3.5-turbo", responseText)
-		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
 	}
 
-	return nil, &usage
+	return nil, usage
 }
 
 func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go
index b3ae5927..115aed1b 100644
--- a/relay/channel/dify/relay-dify.go
+++ b/relay/channel/dify/relay-dify.go
@@ -243,15 +243,8 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 		return true
 	})
 	helper.Done(c)
-	err := resp.Body.Close()
-	if err != nil {
-		// return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-		common.SysError("close_response_body_failed: " + err.Error())
-	}
 	if usage.TotalTokens == 0 {
-		usage.PromptTokens = info.PromptTokens
-		usage.CompletionTokens = service.CountTextToken("gpt-3.5-turbo", responseText)
-		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	}
 	usage.CompletionTokens += nodeToken
 	return nil, usage
diff --git a/relay/responses_handler.go b/relay/responses_handler.go
index 9d4adf49..e744e354 100644
--- a/relay/responses_handler.go
+++ b/relay/responses_handler.go
@@ -73,9 +73,6 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
 		relayInfo.SetPromptTokens(promptTokens)
 	} else {
 		promptTokens := getInputTokens(req, relayInfo)
-		if err != nil {
-			return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
-		}
 		c.Set("prompt_tokens", promptTokens)
 	}
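A closing observation on the coze and dify hunks: routing their fallbacks through service.ResponseText2Usage also quietly fixes an argument-order bug, since the removed code called service.CountTextToken("gpt-3.5-turbo", responseText) while the signature is CountTextToken(text, model), i.e. the model name was being tokenized as the text. After the series, every "upstream sent no usage" path goes through one helper; below is a simplified, self-contained sketch of that contract (toy types and counter, not the project's code):

    package main

    import "fmt"

    type Usage struct {
        PromptTokens     int
        CompletionTokens int
        TotalTokens      int
    }

    // responseText2Usage mirrors the shape of service.ResponseText2Usage after
    // this series: one infallible helper that assembles fallback usage, instead
    // of each channel hand-filling the three fields.
    func responseText2Usage(responseText string, model string, promptTokens int,
        countTextToken func(text, model string) int) *Usage {
        u := &Usage{PromptTokens: promptTokens}
        u.CompletionTokens = countTextToken(responseText, model) // text first, then model
        u.TotalTokens = u.PromptTokens + u.CompletionTokens
        return u
    }

    func main() {
        count := func(text, model string) int { return (len(text) + 3) / 4 } // toy counter
        fmt.Println(*responseText2Usage("streamed reply text", "gpt-4o", 20, count)) // {20 5 25}
    }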