Merge pull request #555 from sususu98/fix/gemini-thoughts-token-billing
fix: include Gemini thoughtsTokenCount in output token billing
This commit is contained in:
@@ -155,6 +155,7 @@ type GeminiUsageMetadata struct {
|
|||||||
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
|
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
|
||||||
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
|
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
|
||||||
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
||||||
|
ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
|
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
|
||||||
|
|||||||
@@ -279,7 +279,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
|
|||||||
if geminiResp.UsageMetadata != nil {
|
if geminiResp.UsageMetadata != nil {
|
||||||
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
||||||
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||||
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
|
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
|
||||||
usage.CacheReadInputTokens = cached
|
usage.CacheReadInputTokens = cached
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
|||||||
if geminiResp.UsageMetadata != nil {
|
if geminiResp.UsageMetadata != nil {
|
||||||
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
||||||
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||||
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
|
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
|
||||||
p.cacheReadTokens = cached
|
p.cacheReadTokens = cached
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -146,7 +146,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
|||||||
if v1Resp.Response.UsageMetadata != nil {
|
if v1Resp.Response.UsageMetadata != nil {
|
||||||
cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount
|
cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount
|
||||||
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
|
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
|
||||||
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount
|
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount
|
||||||
usage.CacheReadInputTokens = cached
|
usage.CacheReadInputTokens = cached
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -553,6 +553,75 @@ func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) {
|
|||||||
require.NotContains(t, body, "event: error")
|
require.NotContains(t, body, "event: error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestHandleGeminiStreamingResponse_ThoughtsTokenCount
|
||||||
|
// 验证:Gemini 流式转发时 thoughtsTokenCount 被计入 OutputTokens
|
||||||
|
func TestHandleGeminiStreamingResponse_ThoughtsTokenCount(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":20,"thoughtsTokenCount":50}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":30,"thoughtsTokenCount":80,"cachedContentTokenCount":10}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now())
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, result.usage)
|
||||||
|
// promptTokenCount=100, cachedContentTokenCount=10 → InputTokens=90
|
||||||
|
require.Equal(t, 90, result.usage.InputTokens)
|
||||||
|
// candidatesTokenCount=30 + thoughtsTokenCount=80 → OutputTokens=110
|
||||||
|
require.Equal(t, 110, result.usage.OutputTokens)
|
||||||
|
require.Equal(t, 10, result.usage.CacheReadInputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleClaudeStreamingResponse_ThoughtsTokenCount
|
||||||
|
// 验证:Gemini→Claude 流式转换时 thoughtsTokenCount 被计入 OutputTokens
|
||||||
|
func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
svc := newAntigravityTestService(&config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
})
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":50,"candidatesTokenCount":10,"thoughtsTokenCount":25}}}`)
|
||||||
|
fmt.Fprintln(pw, "")
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro")
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, result.usage)
|
||||||
|
// promptTokenCount=50 → InputTokens=50
|
||||||
|
require.Equal(t, 50, result.usage.InputTokens)
|
||||||
|
// candidatesTokenCount=10 + thoughtsTokenCount=25 → OutputTokens=35
|
||||||
|
require.Equal(t, 35, result.usage.OutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
// --- 流式客户端断开检测测试 ---
|
// --- 流式客户端断开检测测试 ---
|
||||||
|
|
||||||
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
|
// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage
|
||||||
|
|||||||
@@ -2663,11 +2663,12 @@ func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage {
|
|||||||
prompt, _ := asInt(usageMeta["promptTokenCount"])
|
prompt, _ := asInt(usageMeta["promptTokenCount"])
|
||||||
cand, _ := asInt(usageMeta["candidatesTokenCount"])
|
cand, _ := asInt(usageMeta["candidatesTokenCount"])
|
||||||
cached, _ := asInt(usageMeta["cachedContentTokenCount"])
|
cached, _ := asInt(usageMeta["cachedContentTokenCount"])
|
||||||
|
thoughts, _ := asInt(usageMeta["thoughtsTokenCount"])
|
||||||
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
||||||
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
||||||
return &ClaudeUsage{
|
return &ClaudeUsage{
|
||||||
InputTokens: prompt - cached,
|
InputTokens: prompt - cached,
|
||||||
OutputTokens: cand,
|
OutputTokens: cand + thoughts,
|
||||||
CacheReadInputTokens: cached,
|
CacheReadInputTokens: cached,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
|
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
|
||||||
@@ -203,3 +205,70 @@ func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing
|
|||||||
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
|
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtractGeminiUsage_ThoughtsTokenCount(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
resp map[string]any
|
||||||
|
wantInput int
|
||||||
|
wantOutput int
|
||||||
|
wantCacheRead int
|
||||||
|
wantNil bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with thoughtsTokenCount",
|
||||||
|
resp: map[string]any{
|
||||||
|
"usageMetadata": map[string]any{
|
||||||
|
"promptTokenCount": float64(100),
|
||||||
|
"candidatesTokenCount": float64(20),
|
||||||
|
"thoughtsTokenCount": float64(50),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantInput: 100,
|
||||||
|
wantOutput: 70,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with thoughtsTokenCount and cache",
|
||||||
|
resp: map[string]any{
|
||||||
|
"usageMetadata": map[string]any{
|
||||||
|
"promptTokenCount": float64(100),
|
||||||
|
"candidatesTokenCount": float64(20),
|
||||||
|
"cachedContentTokenCount": float64(30),
|
||||||
|
"thoughtsTokenCount": float64(50),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantInput: 70,
|
||||||
|
wantOutput: 70,
|
||||||
|
wantCacheRead: 30,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "without thoughtsTokenCount (old model)",
|
||||||
|
resp: map[string]any{
|
||||||
|
"usageMetadata": map[string]any{
|
||||||
|
"promptTokenCount": float64(100),
|
||||||
|
"candidatesTokenCount": float64(20),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantInput: 100,
|
||||||
|
wantOutput: 20,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no usageMetadata",
|
||||||
|
resp: map[string]any{},
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
usage := extractGeminiUsage(tt.resp)
|
||||||
|
if tt.wantNil {
|
||||||
|
require.Nil(t, usage)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NotNil(t, usage)
|
||||||
|
require.Equal(t, tt.wantInput, usage.InputTokens)
|
||||||
|
require.Equal(t, tt.wantOutput, usage.OutputTokens)
|
||||||
|
require.Equal(t, tt.wantCacheRead, usage.CacheReadInputTokens)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user