From 71c39c98936417a6b2fd38f5a635d1d2bad11c24 Mon Sep 17 00:00:00 2001
From: CaIon
Date: Thu, 7 Aug 2025 15:40:12 +0800
Subject: [PATCH] feat: update Usage struct to support dynamic token handling with ceil function #1503

---
 dto/openai_response.go                  | 119 +++++++++++++++++++++++-
 relay/channel/openai/relay-openai.go    |   8 +-
 relay/channel/openai/relay_responses.go |   8 +-
 3 files changed, 124 insertions(+), 11 deletions(-)

diff --git a/dto/openai_response.go b/dto/openai_response.go
index b050cd03..7e6ee584 100644
--- a/dto/openai_response.go
+++ b/dto/openai_response.go
@@ -3,6 +3,8 @@ package dto
 import (
     "encoding/json"
     "fmt"
+    "math"
+    "one-api/common"
     "one-api/types"
 )
@@ -202,13 +204,124 @@ type Usage struct {
     PromptTokensDetails    InputTokenDetails  `json:"prompt_tokens_details"`
     CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`

-    InputTokens        int                `json:"input_tokens"`
-    OutputTokens       int                `json:"output_tokens"`
-    InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
+    InputTokens  any `json:"input_tokens"`
+    OutputTokens any `json:"output_tokens"`
+    //CacheReadInputTokens any `json:"cache_read_input_tokens,omitempty"`
+    InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`

     // OpenRouter Params
     Cost any `json:"cost,omitempty"`
 }

+func (u *Usage) UnmarshalJSON(data []byte) error {
+    // first normal unmarshal
+    if err := common.Unmarshal(data, u); err != nil {
+        return fmt.Errorf("unmarshal Usage failed: %w", err)
+    }
+
+    // then ceil the input and output tokens
+    ceil := func(val any) int {
+        switch v := val.(type) {
+        case float64:
+            return int(math.Ceil(v))
+        case int:
+            return v
+        case string:
+            var intVal int
+            _, err := fmt.Sscanf(v, "%d", &intVal)
+            if err != nil {
+                return 0 // or handle error appropriately
+            }
+            return intVal
+        default:
+            return 0 // or handle error appropriately
+        }
+    }
+
+    // input_tokens must be int
+    if u.InputTokens != nil {
+        u.InputTokens = ceil(u.InputTokens)
+    }
+    if u.OutputTokens != nil {
+        u.OutputTokens = ceil(u.OutputTokens)
+    }
+    return nil
+}
+
+func (u *Usage) GetInputTokens() int {
+    if u.InputTokens == nil {
+        return 0
+    }
+
+    switch v := u.InputTokens.(type) {
+    case int:
+        return v
+    case float64:
+        return int(math.Ceil(v))
+    case string:
+        var intVal int
+        _, err := fmt.Sscanf(v, "%d", &intVal)
+        if err != nil {
+            return 0 // or handle error appropriately
+        }
+        return intVal
+    default:
+        return 0 // or handle error appropriately
+    }
+}
+
+func (u *Usage) GetOutputTokens() int {
+    if u.OutputTokens == nil {
+        return 0
+    }
+
+    switch v := u.OutputTokens.(type) {
+    case int:
+        return v
+    case float64:
+        return int(math.Ceil(v))
+    case string:
+        var intVal int
+        _, err := fmt.Sscanf(v, "%d", &intVal)
+        if err != nil {
+            return 0 // or handle error appropriately
+        }
+        return intVal
+    default:
+        return 0 // or handle error appropriately
+    }
+}
+
+//func (u *Usage) MarshalJSON() ([]byte, error) {
+//    ceil := func(val any) int {
+//        switch v := val.(type) {
+//        case float64:
+//            return int(math.Ceil(v))
+//        case int:
+//            return v
+//        case string:
+//            var intVal int
+//            _, err := fmt.Sscanf(v, "%d", &intVal)
+//            if err != nil {
+//                return 0 // or handle error appropriately
+//            }
+//            return intVal
+//        default:
+//            return 0 // or handle error appropriately
+//        }
+//    }
+//
+//    // input_tokens must be int
+//    if u.InputTokens != nil {
+//        u.InputTokens = ceil(u.InputTokens)
+//    }
+//    if u.OutputTokens != nil {
+//        u.OutputTokens = ceil(u.OutputTokens)
+//    }
+//
+//    // done
+//    return common.Marshal(u)
+//}
+
 type InputTokenDetails struct {
     CachedTokens         int `json:"cached_tokens"`
     CachedCreationTokens int `json:"-"`
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 9ae0a200..f5e29209 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -570,11 +570,11 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
     // because the upstream has already consumed resources and returned content
     // We should still perform billing even if parsing fails
     // format
-    if usageResp.InputTokens > 0 {
-        usageResp.PromptTokens += usageResp.InputTokens
+    if usageResp.GetInputTokens() > 0 {
+        usageResp.PromptTokens += usageResp.GetInputTokens()
     }
-    if usageResp.OutputTokens > 0 {
-        usageResp.CompletionTokens += usageResp.OutputTokens
+    if usageResp.GetOutputTokens() > 0 {
+        usageResp.CompletionTokens += usageResp.GetOutputTokens()
     }
     if usageResp.InputTokensDetails != nil {
         usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go
index bae6fcb6..2c996f91 100644
--- a/relay/channel/openai/relay_responses.go
+++ b/relay/channel/openai/relay_responses.go
@@ -38,8 +38,8 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
     // compute usage
     usage := dto.Usage{}
     if responsesResponse.Usage != nil {
-        usage.PromptTokens = responsesResponse.Usage.InputTokens
-        usage.CompletionTokens = responsesResponse.Usage.OutputTokens
+        usage.PromptTokens = responsesResponse.Usage.GetInputTokens()
+        usage.CompletionTokens = responsesResponse.Usage.GetOutputTokens()
         usage.TotalTokens = responsesResponse.Usage.TotalTokens
         if responsesResponse.Usage.InputTokensDetails != nil {
             usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
@@ -70,8 +70,8 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
             switch streamResponse.Type {
             case "response.completed":
                 if streamResponse.Response.Usage != nil {
-                    usage.PromptTokens = streamResponse.Response.Usage.InputTokens
-                    usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
+                    usage.PromptTokens = streamResponse.Response.Usage.GetInputTokens()
+                    usage.CompletionTokens = streamResponse.Response.Usage.GetOutputTokens()
                     usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
                     if streamResponse.Response.Usage.InputTokensDetails != nil {
                         usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
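
Note on the coercion pattern this patch uses (not part of the diff): once InputTokens and OutputTokens are typed as `any`, encoding/json decodes JSON numbers into float64 and leaves JSON strings as string, so upstreams that report fractional or string token counts no longer break unmarshaling; GetInputTokens/GetOutputTokens then round up to an int for billing. The following is a minimal standalone sketch of that idea against the standard library only. It does not use the repository's common.Unmarshal helper, the type and function names (usage, toTokenCount) are illustrative, and string parsing uses strconv.Atoi instead of the patch's fmt.Sscanf, which is equivalent for plain decimal integers.

package main

import (
    "encoding/json"
    "fmt"
    "math"
    "strconv"
)

// usage mirrors only the two fields the patch retypes to `any`.
type usage struct {
    InputTokens  any `json:"input_tokens"`
    OutputTokens any `json:"output_tokens"`
}

// toTokenCount coerces a decoded JSON value to an int, rounding
// fractional counts up, in the same spirit as the patch's ceil helper.
func toTokenCount(val any) int {
    switch v := val.(type) {
    case float64:
        // encoding/json decodes every JSON number into float64
        // when the destination type is any.
        return int(math.Ceil(v))
    case int:
        return v
    case string:
        n, err := strconv.Atoi(v)
        if err != nil {
            return 0
        }
        return n
    default:
        return 0
    }
}

func main() {
    // A hypothetical upstream reporting a fractional input count
    // and a string output count.
    raw := []byte(`{"input_tokens": 12.3, "output_tokens": "45"}`)

    var u usage
    if err := json.Unmarshal(raw, &u); err != nil {
        panic(err)
    }

    fmt.Println(toTokenCount(u.InputTokens))  // 13
    fmt.Println(toTokenCount(u.OutputTokens)) // 45
}

Running the sketch prints 13 and 45: the fractional input count is rounded up and the string output count is parsed to an integer, which mirrors what GetInputTokens and GetOutputTokens return to the relay handlers above.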