feat: update Usage struct to support dynamic token handling with ceil function #1503

This commit is contained in:
CaIon
2025-08-07 15:40:12 +08:00
parent 0c0caad827
commit 71c39c9893
3 changed files with 124 additions and 11 deletions

View File

@@ -3,6 +3,8 @@ package dto
import (
"encoding/json"
"fmt"
"math"
"one-api/common"
"one-api/types"
)
@@ -202,13 +204,124 @@ type Usage struct {
PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
InputTokens any `json:"input_tokens"`
OutputTokens any `json:"output_tokens"`
//CacheReadInputTokens any `json:"cache_read_input_tokens,omitempty"`
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
// OpenRouter Params
Cost any `json:"cost,omitempty"`
}
func (u *Usage) UnmarshalJSON(data []byte) error {
// first normal unmarshal
if err := common.Unmarshal(data, u); err != nil {
return fmt.Errorf("unmarshal Usage failed: %w", err)
}
// then ceil the input and output tokens
ceil := func(val any) int {
switch v := val.(type) {
case float64:
return int(math.Ceil(v))
case int:
return v
case string:
var intVal int
_, err := fmt.Sscanf(v, "%d", &intVal)
if err != nil {
return 0 // or handle error appropriately
}
return intVal
default:
return 0 // or handle error appropriately
}
}
// input_tokens must be int
if u.InputTokens != nil {
u.InputTokens = ceil(u.InputTokens)
}
if u.OutputTokens != nil {
u.OutputTokens = ceil(u.OutputTokens)
}
return nil
}
func (u *Usage) GetInputTokens() int {
if u.InputTokens == nil {
return 0
}
switch v := u.InputTokens.(type) {
case int:
return v
case float64:
return int(math.Ceil(v))
case string:
var intVal int
_, err := fmt.Sscanf(v, "%d", &intVal)
if err != nil {
return 0 // or handle error appropriately
}
return intVal
default:
return 0 // or handle error appropriately
}
}
func (u *Usage) GetOutputTokens() int {
if u.OutputTokens == nil {
return 0
}
switch v := u.OutputTokens.(type) {
case int:
return v
case float64:
return int(math.Ceil(v))
case string:
var intVal int
_, err := fmt.Sscanf(v, "%d", &intVal)
if err != nil {
return 0 // or handle error appropriately
}
return intVal
default:
return 0 // or handle error appropriately
}
}
//func (u *Usage) MarshalJSON() ([]byte, error) {
// ceil := func(val any) int {
// switch v := val.(type) {
// case float64:
// return int(math.Ceil(v))
// case int:
// return v
// case string:
// var intVal int
// _, err := fmt.Sscanf(v, "%d", &intVal)
// if err != nil {
// return 0 // or handle error appropriately
// }
// return intVal
// default:
// return 0 // or handle error appropriately
// }
// }
//
// // input_tokens must be int
// if u.InputTokens != nil {
// u.InputTokens = ceil(u.InputTokens)
// }
// if u.OutputTokens != nil {
// u.OutputTokens = ceil(u.OutputTokens)
// }
//
// // done
// return common.Marshal(u)
//}
type InputTokenDetails struct {
CachedTokens int `json:"cached_tokens"`
CachedCreationTokens int `json:"-"`

View File

@@ -570,11 +570,11 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
// because the upstream has already consumed resources and returned content
// We should still perform billing even if parsing fails
// format
if usageResp.InputTokens > 0 {
usageResp.PromptTokens += usageResp.InputTokens
if usageResp.GetInputTokens() > 0 {
usageResp.PromptTokens += usageResp.GetInputTokens()
}
if usageResp.OutputTokens > 0 {
usageResp.CompletionTokens += usageResp.OutputTokens
if usageResp.GetOutputTokens() > 0 {
usageResp.CompletionTokens += usageResp.GetOutputTokens()
}
if usageResp.InputTokensDetails != nil {
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens

View File

@@ -38,8 +38,8 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
// compute usage
usage := dto.Usage{}
if responsesResponse.Usage != nil {
usage.PromptTokens = responsesResponse.Usage.InputTokens
usage.CompletionTokens = responsesResponse.Usage.OutputTokens
usage.PromptTokens = responsesResponse.Usage.GetInputTokens()
usage.CompletionTokens = responsesResponse.Usage.GetOutputTokens()
usage.TotalTokens = responsesResponse.Usage.TotalTokens
if responsesResponse.Usage.InputTokensDetails != nil {
usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
@@ -70,8 +70,8 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
switch streamResponse.Type {
case "response.completed":
if streamResponse.Response.Usage != nil {
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
usage.PromptTokens = streamResponse.Response.Usage.GetInputTokens()
usage.CompletionTokens = streamResponse.Response.Usage.GetOutputTokens()
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
if streamResponse.Response.Usage.InputTokensDetails != nil {
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens