diff --git a/relay/audio_handler.go b/relay/audio_handler.go
index e55de042..96cf1019 100644
--- a/relay/audio_handler.go
+++ b/relay/audio_handler.go
@@ -66,7 +66,4 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	promptTokens := 0
 	preConsumedTokens := common.PreConsumedQuota
 	if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
-		promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
-		}
+		promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index ba20adea..5e15d3a2 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -549,7 +549,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 
 func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
 	if requestMode == RequestModeCompletion {
-		claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
+		claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
 	} else {
 		if claudeInfo.Usage.PromptTokens == 0 {
 			// upstream error
@@ -558,7 +558,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
 			if common.DebugEnabled {
 				common.SysError("claude response usage is not complete, maybe upstream error")
 			}
-			claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
+			claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
 		}
 	}
 
@@ -618,7 +618,4 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 		}
 	}
 	if requestMode == RequestModeCompletion {
-		completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
-		}
+		completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go
index a487429c..50d4928a 100644
--- a/relay/channel/cloudflare/relay_cloudflare.go
+++ b/relay/channel/cloudflare/relay_cloudflare.go
@@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 	if err := scanner.Err(); err != nil {
 		common.LogError(c, "error_scanning_stream_response: "+err.Error())
 	}
-	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	if info.ShouldIncludeUsage {
 		response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
 		err := helper.ObjectData(c, response)
@@ -108,7 +108,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
 	for _, choice := range response.Choices {
 		responseText += choice.Message.StringContent()
 	}
-	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	response.Usage = *usage
 	response.Id = helper.GetResponseID(c)
 	jsonResponse, err := json.Marshal(response)
@@ -150,7 +150,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
 
 	usage := &dto.Usage{}
 	usage.PromptTokens = info.PromptTokens
-	usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
+	usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 
 	return nil, usage
diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go
index 8a044bf2..29064242 100644
--- a/relay/channel/cohere/relay-cohere.go
+++ b/relay/channel/cohere/relay-cohere.go
@@ -162,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 		}
 	})
 	if usage.PromptTokens == 0 {
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	}
 	return nil, usage
 }
diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go
index 6db40213..e9719cb9 100644
--- a/relay/channel/coze/relay-coze.go
+++ b/relay/channel/coze/relay-coze.go
@@ -144,7 +144,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
 
 	if usage.TotalTokens == 0 {
 		usage.PromptTokens = info.PromptTokens
-		usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
+		usage.CompletionTokens = service.CountTextToken(responseText, "gpt-3.5-turbo")
 		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	}
 
diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go
index 93e3e8d6..b3ae5927 100644
--- a/relay/channel/dify/relay-dify.go
+++ b/relay/channel/dify/relay-dify.go
@@ -250,7 +250,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 	}
 	if usage.TotalTokens == 0 {
 		usage.PromptTokens = info.PromptTokens
-		usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
+		usage.CompletionTokens = service.CountTextToken(responseText, "gpt-3.5-turbo")
 		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	}
 	usage.CompletionTokens += nodeToken
diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go
index 3a017a11..1a497b9f 100644
--- a/relay/channel/gemini/relay-gemini-native.go
+++ b/relay/channel/gemini/relay-gemini-native.go
@@ -9,6 +9,7 @@ import (
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/service"
+	"strings"
 
 	"github.com/gin-gonic/gin"
 )
@@ -75,8 +76,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
 
 	helper.SetEventStreamHeaders(c)
 
-	// locally counted completion tokens
-	localCompletionTokens := 0
+	responseText := strings.Builder{}
 
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var geminiResponse GeminiChatResponse
@@ -92,12 +92,9 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
 
 			if part.InlineData != nil && part.InlineData.MimeType != "" {
 				imageCount++
 			}
-			// count completion tokens locally
-			textTokens, err := service.CountTextToken(part.Text, info.UpstreamModelName)
-			if err != nil {
-				common.LogError(c, "error counting text token: "+err.Error())
+			if part.Text != "" {
+				responseText.WriteString(part.Text)
 			}
-			localCompletionTokens += textTokens
 		}
 	}
@@ -133,13 +130,9 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
 
 	// if usage.CompletionTokens is 0, fall back to the locally counted completion tokens
 	if usage.CompletionTokens == 0 {
-		usage.CompletionTokens = localCompletionTokens
-		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+		usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
 	}
 
-	// calculate the final usage
-	// usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
-
 	// no trailing [DONE] for the stream, since the Gemini API never sends one
 	//helper.Done(c)

diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 4dc0fc60..71590cd6 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -8,7 +8,6 @@ import (
 	"math"
 	"mime/multipart"
 	"net/http"
-	"path/filepath"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/dto"
@@ -16,6 +15,7 @@ import (
 	"one-api/relay/helper"
 	"one-api/service"
 	"os"
+	"path/filepath"
 	"strings"
 
 	"github.com/bytedance/gopkg/util/gopool"
@@ -181,7 +181,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	}
 
 	if !containStreamUsage {
-		usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
 		usage.CompletionTokens += toolCount * 7
 	} else {
 		if info.ChannelType == common.ChannelTypeDeepSeek {
@@ -216,7 +216,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 			StatusCode: resp.StatusCode,
 		}, nil
 	}
-
+
 	forceFormat := false
 	if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
 		forceFormat = forceFmt
 	}
@@ -225,7 +225,7 @@
 	if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
 		completionTokens := 0
 		for _, choice := range simpleResponse.Choices {
-			ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
+			ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
 			completionTokens += ctkm
 		}
 		simpleResponse.Usage = dto.Usage{
@@ -276,9 +276,9 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	// the status code has been judged before, if there is a body reading failure,
 	// it should be regarded as a non-recoverable error, so it should not return err for external retry.
-	// Analogous to nginx's load balancing, it will only retry if it can't be requested or 
-	// if the upstream returns a specific status code, once the upstream has already written the header, 
-	// the subsequent failure of the response body should be regarded as a non-recoverable error, 
+	// Analogous to nginx's load balancing, it will only retry if it can't be requested or
+	// if the upstream returns a specific status code, once the upstream has already written the header,
+	// the subsequent failure of the response body should be regarded as a non-recoverable error,
 	// and can be terminated directly.
 	defer resp.Body.Close()
 	usage := &dto.Usage{}
@@ -346,12 +346,12 @@ func countAudioTokens(c *gin.Context) (int, error) {
 	if err = c.ShouldBind(&reqBody); err != nil {
 		return 0, errors.WithStack(err)
 	}
-	ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
+	ext := filepath.Ext(reqBody.File.Filename) // get the file extension
 	reqFp, err := reqBody.File.Open()
 	if err != nil {
 		return 0, errors.WithStack(err)
 	}
-	defer reqFp.Close() 
+	defer reqFp.Close()
 
 	tmpFp, err := os.CreateTemp("", "audio-*"+ext)
 	if err != nil {
diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go
index 1d1e060e..da9382c3 100644
--- a/relay/channel/openai/relay_responses.go
+++ b/relay/channel/openai/relay_responses.go
@@ -110,7 +110,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
 		tempStr := responseTextBuilder.String()
 		if len(tempStr) > 0 {
 			// abnormal termination: fall back to the token count of the output text
-			completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
+			completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
 			usage.CompletionTokens = completionTokens
 		}
 	}
diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go
index 3a06e7ee..aee4a307 100644
--- a/relay/channel/palm/adaptor.go
+++ b/relay/channel/palm/adaptor.go
@@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		var responseText string
 		err, responseText = palmStreamHandler(c, resp)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go
index 0c6f8641..9d3dbd67 100644
--- a/relay/channel/palm/relay-palm.go
+++ b/relay/channel/palm/relay-palm.go
@@ -155,7 +155,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 		}, nil
 	}
 	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-	completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
+	completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
 	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,
diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go
index 44718a25..7ea3aae7 100644
--- a/relay/channel/tencent/adaptor.go
+++ b/relay/channel/tencent/adaptor.go
@@ -98,7 +98,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		var responseText string
 		err, responseText = tencentStreamHandler(c, resp)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		err, usage = tencentHandler(c, resp)
 	}
diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go
index e019c2dc..408160fb 100644
--- a/relay/channel/xai/text.go
+++ b/relay/channel/xai/text.go
@@ -68,7 +68,7 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	})
 
 	if !containStreamUsage {
-		usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
 		usage.CompletionTokens += toolCount * 7
 	}
 
diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go
index fbf4990a..849c70da 100644
--- a/relay/embedding_handler.go
+++ b/relay/embedding_handler.go
@@ -15,7 +15,7 @@ import (
 )
 
 func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
-	token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
+	token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
 	return token
 }
 
diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go
index fa41cc7b..14d58cc5 100644
--- a/relay/gemini_handler.go
+++ b/relay/gemini_handler.go
@@ -59,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
 	return sensitiveWords, err
 }
 
-func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
+func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
 	// count the input tokens
 	var inputTexts []string
 	for _, content := range req.Contents {
@@ -71,9 +71,9 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
 	}
 
 	inputText := strings.Join(inputTexts, "\n")
-	inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName)
+	inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
 	info.PromptTokens = inputTokens
-	return inputTokens, err
+	return inputTokens
 }
 
 func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -106,7 +106,4 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		promptTokens := value.(int)
 		relayInfo.SetPromptTokens(promptTokens)
 	} else {
-		promptTokens, err := getGeminiInputTokens(req, relayInfo)
-		if err != nil {
-			return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
-		}
+		promptTokens := getGeminiInputTokens(req, relayInfo)
diff --git a/relay/relay-text.go b/relay/relay-text.go
index bf5a0259..db8d0d3b 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -251,11 +251,11 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
 	case relayconstant.RelayModeChatCompletions:
 		promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
 	case relayconstant.RelayModeCompletions:
-		promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
+		promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
 	case relayconstant.RelayModeModerations:
-		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
+		promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
 	case relayconstant.RelayModeEmbeddings:
-		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
+		promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
 	default:
 		err = errors.New("unknown relay mode")
 		promptTokens = 0
diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go
index 4d02c84f..319811b8 100644
--- a/relay/rerank_handler.go
+++ b/relay/rerank_handler.go
@@ -14,12 +14,10 @@ import (
 )
 
 func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
-	token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
+	token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
 	for _, document := range rerankRequest.Documents {
-		tkm, err := service.CountTokenInput(document, rerankRequest.Model)
-		if err == nil {
-			token += tkm
-		}
+		tkm := service.CountTokenInput(document, rerankRequest.Model)
+		token += tkm
 	}
 	return token
 }
diff --git a/relay/responses_handler.go b/relay/responses_handler.go
index 8e8a3451..9d4adf49 100644
--- a/relay/responses_handler.go
+++ b/relay/responses_handler.go
@@ -40,10 +40,10 @@ func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycom
 	return sensitiveWords, err
 }
 
-func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) {
-	inputTokens, err := service.CountTokenInput(req.Input, req.Model)
+func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int {
+	inputTokens := service.CountTokenInput(req.Input, req.Model)
 	info.PromptTokens = inputTokens
-	return inputTokens, err
+	return inputTokens
 }
 
 func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -72,7 +72,4 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) 
 		promptTokens := value.(int)
 		relayInfo.SetPromptTokens(promptTokens)
 	} else {
-		promptTokens, err := getInputTokens(req, relayInfo)
-		if err != nil {
-			return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
-		}
+		promptTokens := getInputTokens(req, relayInfo)
diff --git a/service/token_counter.go b/service/token_counter.go
index 82de0a05..53c6c2fa 100644
--- a/service/token_counter.go
+++ b/service/token_counter.go
@@ -171,7 +171,4 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
 			countStr += fmt.Sprintf("%v", tool.Function.Parameters)
 		}
 	}
-	toolTokens, err := CountTokenInput(countStr, request.Model)
-	if err != nil {
-		return 0, err
-	}
+	toolTokens := CountTokenInput(countStr, request.Model)
@@ -194,7 +194,4 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
 
 	// Count tokens in system message
 	if request.System != "" {
-		systemTokens, err := CountTokenInput(request.System, model)
-		if err != nil {
-			return 0, err
-		}
+		systemTokens := CountTokenInput(request.System, model)
@@ -296,10 +296,7 @@
 	switch request.Type {
 	case dto.RealtimeEventTypeSessionUpdate:
 		if request.Session != nil {
-			msgTokens, err := CountTextToken(request.Session.Instructions, model)
-			if err != nil {
-				return 0, 0, err
-			}
+			msgTokens := CountTextToken(request.Session.Instructions, model)
 			textToken += msgTokens
 		}
 	case dto.RealtimeEventResponseAudioDelta:
@@ -311,10 +308,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, 
 		audioToken += atk
 	case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
 		// count text token
-		tkm, err := CountTextToken(request.Delta, model)
-		if err != nil {
-			return 0, 0, fmt.Errorf("error counting text token: %v", err)
-		}
+		tkm := CountTextToken(request.Delta, model)
 		textToken += tkm
 	case dto.RealtimeEventInputAudioBufferAppend:
 		// count audio token
@@ -329,10 +323,7 @@
 		case "message":
			for _, content := range request.Item.Content {
 				if content.Type == "input_text" {
-					tokens, err := CountTextToken(content.Text, model)
-					if err != nil {
-						return 0, 0, err
-					}
+					tokens := CountTextToken(content.Text, model)
 					textToken += tokens
 				}
 			}
@@ -343,10 +334,7 @@
 	if !info.IsFirstRequest {
 		if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
 			for _, tool := range info.RealtimeTools {
-				toolTokens, err := CountTokenInput(tool, model)
-				if err != nil {
-					return 0, 0, err
-				}
+				toolTokens := CountTokenInput(tool, model)
 				textToken += 8
 				textToken += toolTokens
 			}
@@ -409,7 +397,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
 	return tokenNum, nil
 }
 
-func CountTokenInput(input any, model string) (int, error) {
+func CountTokenInput(input any, model string) int {
 	switch v := input.(type) {
 	case string:
 		return CountTextToken(v, model)
@@ -432,13 +420,13 @@
 func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
 	tokens := 0
 	for _, message := range messages {
-		tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
+		tkm := CountTokenInput(message.Delta.GetContentString(), model)
 		tokens += tkm
 		if message.Delta.ToolCalls != nil {
 			for _, tool := range message.Delta.ToolCalls {
-				tkm, _ := CountTokenInput(tool.Function.Name, model)
+				tkm := CountTokenInput(tool.Function.Name, model)
 				tokens += tkm
-				tkm, _ = CountTokenInput(tool.Function.Arguments, model)
+				tkm = CountTokenInput(tool.Function.Arguments, model)
 				tokens += tkm
 			}
 		}
@@ -446,9 +434,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
 	return tokens
 }
 
-func CountTTSToken(text string, model string) (int, error) {
+func CountTTSToken(text string, model string) int {
 	if strings.HasPrefix(model, "tts") {
-		return utf8.RuneCountInString(text), nil
+		return utf8.RuneCountInString(text)
 	} else {
 		return CountTextToken(text, model)
 	}
@@ -483,8 +471,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) 
 //}
 
-// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
-func CountTextToken(text string, model string) (int, error) {
-	var err error
+// CountTextToken counts the number of tokens in the given text
+func CountTextToken(text string, model string) int {
+	if text == "" {
+		return 0
+	}
 	tokenEncoder := getTokenEncoder(model)
-	return getTokenNum(tokenEncoder, text), err
+	return getTokenNum(tokenEncoder, text)
 }
diff --git a/service/usage_helpr.go b/service/usage_helpr.go
index c52e1e15..ca9c0830 100644
--- a/service/usage_helpr.go
+++ b/service/usage_helpr.go
@@ -16,13 +16,13 @@ import (
 //	return 0, errors.New("unknown relay mode")
 //}
 
-func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
+func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
 	usage := &dto.Usage{}
 	usage.PromptTokens = promptTokens
-	ctkm, err := CountTextToken(responseText, modeName)
+	ctkm := CountTextToken(responseText, modeName)
 	usage.CompletionTokens = ctkm
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
-	return usage, err
+	return usage
 }
 
 func ValidUsage(usage *dto.Usage) bool {
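
Note for reviewers: the net effect of this refactor is that the token-counting helpers (CountTextToken, CountTokenInput, CountTTSToken, ResponseText2Usage) become infallible, so call sites collapse from "n, err := ..." plus an error branch that could never fire into a plain assignment. Below is a minimal, self-contained sketch of the new contract. Only the signatures mirror this diff; the Usage type is a stand-in for dto.Usage, and the rune count is a placeholder for the real tiktoken-backed encoder lookup.

package main

import (
	"fmt"
	"unicode/utf8"
)

// Usage is a stand-in for dto.Usage.
type Usage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

// countTextToken mirrors the new service.CountTextToken shape: it returns a
// bare int, so callers no longer thread an error through. The rune count is
// a placeholder for the real per-model encoder.
func countTextToken(text string, model string) int {
	if text == "" {
		return 0
	}
	return utf8.RuneCountInString(text)
}

// responseText2Usage mirrors the new service.ResponseText2Usage shape:
// it builds a complete Usage from the response text and the prompt tokens.
func responseText2Usage(responseText string, model string, promptTokens int) *Usage {
	usage := &Usage{PromptTokens: promptTokens}
	usage.CompletionTokens = countTextToken(responseText, model)
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	return usage
}

func main() {
	// before: usage, err := ... followed by a dead error check
	// after:  a plain assignment
	usage := responseText2Usage("hello world", "gpt-4o", 12)
	fmt.Printf("prompt=%d completion=%d total=%d\n",
		usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
}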