Revert "feat: update Usage struct to support dynamic token handling with ceil function #1503"
This reverts commit 71c39c9893.
This commit is contained in:
@@ -3,8 +3,6 @@ package dto
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -204,124 +202,13 @@ type Usage struct {
|
|||||||
|
|
||||||
PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
|
PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
|
||||||
CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
|
CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
|
||||||
InputTokens any `json:"input_tokens"`
|
InputTokens int `json:"input_tokens"`
|
||||||
OutputTokens any `json:"output_tokens"`
|
OutputTokens int `json:"output_tokens"`
|
||||||
//CacheReadInputTokens any `json:"cache_read_input_tokens,omitempty"`
|
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
|
||||||
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
|
|
||||||
// OpenRouter Params
|
// OpenRouter Params
|
||||||
Cost any `json:"cost,omitempty"`
|
Cost any `json:"cost,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Usage) UnmarshalJSON(data []byte) error {
|
|
||||||
// first normal unmarshal
|
|
||||||
if err := common.Unmarshal(data, u); err != nil {
|
|
||||||
return fmt.Errorf("unmarshal Usage failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// then ceil the input and output tokens
|
|
||||||
ceil := func(val any) int {
|
|
||||||
switch v := val.(type) {
|
|
||||||
case float64:
|
|
||||||
return int(math.Ceil(v))
|
|
||||||
case int:
|
|
||||||
return v
|
|
||||||
case string:
|
|
||||||
var intVal int
|
|
||||||
_, err := fmt.Sscanf(v, "%d", &intVal)
|
|
||||||
if err != nil {
|
|
||||||
return 0 // or handle error appropriately
|
|
||||||
}
|
|
||||||
return intVal
|
|
||||||
default:
|
|
||||||
return 0 // or handle error appropriately
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// input_tokens must be int
|
|
||||||
if u.InputTokens != nil {
|
|
||||||
u.InputTokens = ceil(u.InputTokens)
|
|
||||||
}
|
|
||||||
if u.OutputTokens != nil {
|
|
||||||
u.OutputTokens = ceil(u.OutputTokens)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *Usage) GetInputTokens() int {
|
|
||||||
if u.InputTokens == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
switch v := u.InputTokens.(type) {
|
|
||||||
case int:
|
|
||||||
return v
|
|
||||||
case float64:
|
|
||||||
return int(math.Ceil(v))
|
|
||||||
case string:
|
|
||||||
var intVal int
|
|
||||||
_, err := fmt.Sscanf(v, "%d", &intVal)
|
|
||||||
if err != nil {
|
|
||||||
return 0 // or handle error appropriately
|
|
||||||
}
|
|
||||||
return intVal
|
|
||||||
default:
|
|
||||||
return 0 // or handle error appropriately
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *Usage) GetOutputTokens() int {
|
|
||||||
if u.OutputTokens == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
switch v := u.OutputTokens.(type) {
|
|
||||||
case int:
|
|
||||||
return v
|
|
||||||
case float64:
|
|
||||||
return int(math.Ceil(v))
|
|
||||||
case string:
|
|
||||||
var intVal int
|
|
||||||
_, err := fmt.Sscanf(v, "%d", &intVal)
|
|
||||||
if err != nil {
|
|
||||||
return 0 // or handle error appropriately
|
|
||||||
}
|
|
||||||
return intVal
|
|
||||||
default:
|
|
||||||
return 0 // or handle error appropriately
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//func (u *Usage) MarshalJSON() ([]byte, error) {
|
|
||||||
// ceil := func(val any) int {
|
|
||||||
// switch v := val.(type) {
|
|
||||||
// case float64:
|
|
||||||
// return int(math.Ceil(v))
|
|
||||||
// case int:
|
|
||||||
// return v
|
|
||||||
// case string:
|
|
||||||
// var intVal int
|
|
||||||
// _, err := fmt.Sscanf(v, "%d", &intVal)
|
|
||||||
// if err != nil {
|
|
||||||
// return 0 // or handle error appropriately
|
|
||||||
// }
|
|
||||||
// return intVal
|
|
||||||
// default:
|
|
||||||
// return 0 // or handle error appropriately
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // input_tokens must be int
|
|
||||||
// if u.InputTokens != nil {
|
|
||||||
// u.InputTokens = ceil(u.InputTokens)
|
|
||||||
// }
|
|
||||||
// if u.OutputTokens != nil {
|
|
||||||
// u.OutputTokens = ceil(u.OutputTokens)
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // done
|
|
||||||
// return common.Marshal(u)
|
|
||||||
//}
|
|
||||||
|
|
||||||
type InputTokenDetails struct {
|
type InputTokenDetails struct {
|
||||||
CachedTokens int `json:"cached_tokens"`
|
CachedTokens int `json:"cached_tokens"`
|
||||||
CachedCreationTokens int `json:"-"`
|
CachedCreationTokens int `json:"-"`
|
||||||
|
|||||||
@@ -570,11 +570,11 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
|||||||
// because the upstream has already consumed resources and returned content
|
// because the upstream has already consumed resources and returned content
|
||||||
// We should still perform billing even if parsing fails
|
// We should still perform billing even if parsing fails
|
||||||
// format
|
// format
|
||||||
if usageResp.GetInputTokens() > 0 {
|
if usageResp.InputTokens > 0 {
|
||||||
usageResp.PromptTokens += usageResp.GetInputTokens()
|
usageResp.PromptTokens += usageResp.InputTokens
|
||||||
}
|
}
|
||||||
if usageResp.GetOutputTokens() > 0 {
|
if usageResp.OutputTokens > 0 {
|
||||||
usageResp.CompletionTokens += usageResp.GetOutputTokens()
|
usageResp.CompletionTokens += usageResp.OutputTokens
|
||||||
}
|
}
|
||||||
if usageResp.InputTokensDetails != nil {
|
if usageResp.InputTokensDetails != nil {
|
||||||
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
|
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
|
||||||
|
|||||||
@@ -38,8 +38,8 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
|||||||
// compute usage
|
// compute usage
|
||||||
usage := dto.Usage{}
|
usage := dto.Usage{}
|
||||||
if responsesResponse.Usage != nil {
|
if responsesResponse.Usage != nil {
|
||||||
usage.PromptTokens = responsesResponse.Usage.GetInputTokens()
|
usage.PromptTokens = responsesResponse.Usage.InputTokens
|
||||||
usage.CompletionTokens = responsesResponse.Usage.GetOutputTokens()
|
usage.CompletionTokens = responsesResponse.Usage.OutputTokens
|
||||||
usage.TotalTokens = responsesResponse.Usage.TotalTokens
|
usage.TotalTokens = responsesResponse.Usage.TotalTokens
|
||||||
if responsesResponse.Usage.InputTokensDetails != nil {
|
if responsesResponse.Usage.InputTokensDetails != nil {
|
||||||
usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
|
usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
|
||||||
@@ -70,8 +70,8 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
|||||||
switch streamResponse.Type {
|
switch streamResponse.Type {
|
||||||
case "response.completed":
|
case "response.completed":
|
||||||
if streamResponse.Response.Usage != nil {
|
if streamResponse.Response.Usage != nil {
|
||||||
usage.PromptTokens = streamResponse.Response.Usage.GetInputTokens()
|
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||||
usage.CompletionTokens = streamResponse.Response.Usage.GetOutputTokens()
|
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||||
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
||||||
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
||||||
|
|||||||
Reference in New Issue
Block a user