diff --git a/controller/relay.go b/controller/relay.go index c83127dd..c055ef71 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -128,6 +128,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { return } + relayInfo.SetPromptTokens(tokens) + priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta) if err != nil { newAPIError = types.NewError(err, types.ErrorCodeModelPriceError) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 00dde46d..0debe48f 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -197,22 +197,26 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo forceFormat = true } - if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { - completionTokens := 0 - for _, choice := range simpleResponse.Choices { - ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) - completionTokens += ctkm + usageModified := false + if simpleResponse.Usage.PromptTokens == 0 { + completionTokens := simpleResponse.Usage.CompletionTokens + if completionTokens == 0 { + for _, choice := range simpleResponse.Choices { + ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) + completionTokens += ctkm + } } simpleResponse.Usage = dto.Usage{ PromptTokens: info.PromptTokens, CompletionTokens: completionTokens, TotalTokens: info.PromptTokens + completionTokens, } + usageModified = true } switch info.RelayFormat { case types.RelayFormatOpenAI: - if forceFormat { + if forceFormat || usageModified { responseBody, err = common.Marshal(simpleResponse) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody)