fix: prompt calculation

Users will now correctly get an estimated prompt usage when the upstream returns either zero usage or no usage at all.
Author: funnycups
Date: 2025-08-16 22:54:00 +08:00
Committed by: GitHub
parent 206ed55db4
commit e3473e3c39
2 changed files with 12 additions and 6 deletions


@@ -128,6 +128,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 		return
 	}
+	relayInfo.SetPromptTokens(tokens)
 	priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
 	if err != nil {
 		newAPIError = types.NewError(err, types.ErrorCodeModelPriceError)
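The hunk above counts prompt tokens locally and records the estimate on the relay info before pricing, so the handler in the second file can fall back to info.PromptTokens when the upstream omits usage. Below is a minimal, self-contained sketch of that idea; RelayInfo and countTextToken are hypothetical stand-ins for relaycommon.RelayInfo and service.CountTextToken, whose real definitions are not part of this diff.

package main

import (
	"fmt"
	"strings"
)

// RelayInfo is a hypothetical stand-in for relaycommon.RelayInfo; only the
// fields needed for this sketch are included.
type RelayInfo struct {
	PromptTokens      int
	UpstreamModelName string
}

// SetPromptTokens records the locally estimated prompt token count so that
// downstream handlers can fall back to it when the upstream omits usage.
func (r *RelayInfo) SetPromptTokens(tokens int) {
	r.PromptTokens = tokens
}

// countTextToken is a crude placeholder for service.CountTextToken: it
// counts whitespace-separated words rather than real model tokens.
func countTextToken(text, model string) int {
	return len(strings.Fields(text))
}

func main() {
	info := &RelayInfo{UpstreamModelName: "gpt-4o"}

	prompt := "Translate the following sentence into French: Hello, world!"
	tokens := countTextToken(prompt, info.UpstreamModelName)

	// Store the estimate before pricing, as the hunk above does.
	info.SetPromptTokens(tokens)

	fmt.Printf("estimated prompt tokens stored on relay info: %d\n", info.PromptTokens)
}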


@@ -197,22 +197,26 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 		forceFormat = true
 	}
-	if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
-		completionTokens := 0
-		for _, choice := range simpleResponse.Choices {
-			ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
-			completionTokens += ctkm
+	usageModified := false
+	if simpleResponse.Usage.PromptTokens == 0 {
+		completionTokens := simpleResponse.Usage.CompletionTokens
+		if completionTokens == 0 {
+			for _, choice := range simpleResponse.Choices {
+				ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
+				completionTokens += ctkm
+			}
 		}
 		simpleResponse.Usage = dto.Usage{
 			PromptTokens:     info.PromptTokens,
 			CompletionTokens: completionTokens,
 			TotalTokens:      info.PromptTokens + completionTokens,
 		}
+		usageModified = true
 	}
 	switch info.RelayFormat {
 	case types.RelayFormatOpenAI:
-		if forceFormat {
+		if forceFormat || usageModified {
 			responseBody, err = common.Marshal(simpleResponse)
 			if err != nil {
 				return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
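The hunk above widens the fallback: the usage block is now rebuilt whenever the upstream reports zero prompt tokens, instead of only when the total or the whole usage object is zero. The upstream completion count is kept if present, otherwise completion tokens are re-counted from the choices, the prompt side comes from the estimate stored in the first hunk, and the new usageModified flag makes the handler re-marshal the corrected body even when forceFormat is false. A minimal sketch of that fallback follows; Usage, Choice, countTextToken and backfillUsage are simplified, hypothetical stand-ins for the dto types and service helpers referenced in the diff.

package main

import (
	"fmt"
	"strings"
)

// Usage and Choice are hypothetical, simplified stand-ins for the dto types
// used by the real handler.
type Usage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

type Choice struct {
	Content          string
	ReasoningContent string
}

// countTextToken is a crude placeholder for service.CountTextToken: it
// counts whitespace-separated words rather than real model tokens.
func countTextToken(text, model string) int {
	return len(strings.Fields(text))
}

// backfillUsage mirrors the logic added in the hunk above: when the upstream
// reports no prompt tokens, rebuild the usage from the local prompt estimate,
// reusing the upstream completion count when it exists and re-counting the
// choices otherwise. The second return value plays the role of usageModified.
func backfillUsage(upstream Usage, choices []Choice, promptEstimate int, model string) (Usage, bool) {
	if upstream.PromptTokens != 0 {
		return upstream, false
	}
	completionTokens := upstream.CompletionTokens
	if completionTokens == 0 {
		for _, c := range choices {
			completionTokens += countTextToken(c.Content+c.ReasoningContent, model)
		}
	}
	return Usage{
		PromptTokens:     promptEstimate,
		CompletionTokens: completionTokens,
		TotalTokens:      promptEstimate + completionTokens,
	}, true
}

func main() {
	// Upstream returned no usage at all: both counts are rebuilt locally.
	usage, modified := backfillUsage(Usage{}, []Choice{{Content: "Bonjour, le monde !"}}, 12, "gpt-4o")
	fmt.Printf("usage=%+v modified=%v\n", usage, modified)

	// Upstream returned completion tokens but a zero prompt count: only the
	// prompt side is filled in from the local estimate.
	usage, modified = backfillUsage(Usage{CompletionTokens: 7}, nil, 12, "gpt-4o")
	fmt.Printf("usage=%+v modified=%v\n", usage, modified)
}

Under the old condition, a response reporting completion tokens but a zero prompt count slipped through untouched and was billed with zero prompt usage; keying the check on PromptTokens alone is what lets the locally estimated value take effect.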