From d21d70a5cf384c97538a165db0ad474213b70949 Mon Sep 17 00:00:00 2001 From: sususu98 Date: Wed, 11 Feb 2026 15:41:54 +0800 Subject: [PATCH] fix: include Gemini thoughtsTokenCount in output token billing Gemini 2.5 Pro/Flash thinking models return thoughtsTokenCount separately from candidatesTokenCount in usageMetadata, but this field was not parsed or included in billing calculations, causing thinking tokens to be unbilled. - Add ThoughtsTokenCount field to GeminiUsageMetadata struct - Include thoughtsTokenCount in OutputTokens across all 3 Gemini usage parsing paths (non-streaming, streaming, compat layer) - Add tests covering thinking token scenarios Closes #554 --- .../internal/pkg/antigravity/gemini_types.go | 1 + .../pkg/antigravity/response_transformer.go | 2 +- .../pkg/antigravity/stream_transformer.go | 4 +- .../antigravity_gateway_service_test.go | 69 +++++++++++++++++++ .../service/gemini_messages_compat_service.go | 3 +- .../gemini_messages_compat_service_test.go | 69 +++++++++++++++++++ 6 files changed, 144 insertions(+), 4 deletions(-) diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go index c1cc998c..32495827 100644 --- a/backend/internal/pkg/antigravity/gemini_types.go +++ b/backend/internal/pkg/antigravity/gemini_types.go @@ -155,6 +155,7 @@ type GeminiUsageMetadata struct { CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"` CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"` TotalTokenCount int `json:"totalTokenCount,omitempty"` + ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费) } // GeminiGroundingMetadata Gemini grounding 元数据(Web Search) diff --git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go index eb16f09d..463033f1 100644 --- a/backend/internal/pkg/antigravity/response_transformer.go +++ b/backend/internal/pkg/antigravity/response_transformer.go @@ -279,7 +279,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon if geminiResp.UsageMetadata != nil { cached := geminiResp.UsageMetadata.CachedContentTokenCount usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached - usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount usage.CacheReadInputTokens = cached } diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go index b384658a..677435ad 100644 --- a/backend/internal/pkg/antigravity/stream_transformer.go +++ b/backend/internal/pkg/antigravity/stream_transformer.go @@ -85,7 +85,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { if geminiResp.UsageMetadata != nil { cached := geminiResp.UsageMetadata.CachedContentTokenCount p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached - p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount p.cacheReadTokens = cached } @@ -146,7 +146,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte if v1Resp.Response.UsageMetadata != nil { cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached - usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount usage.CacheReadInputTokens = cached } diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index a6a349c1..b312e5ca 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -553,6 +553,75 @@ func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) { require.NotContains(t, body, "event: error") } +// TestHandleGeminiStreamingResponse_ThoughtsTokenCount +// 验证:Gemini 流式转发时 thoughtsTokenCount 被计入 OutputTokens +func TestHandleGeminiStreamingResponse_ThoughtsTokenCount(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":20,"thoughtsTokenCount":50}}`) + fmt.Fprintln(pw, "") + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":30,"thoughtsTokenCount":80,"cachedContentTokenCount":10}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + // promptTokenCount=100, cachedContentTokenCount=10 → InputTokens=90 + require.Equal(t, 90, result.usage.InputTokens) + // candidatesTokenCount=30 + thoughtsTokenCount=80 → OutputTokens=110 + require.Equal(t, 110, result.usage.OutputTokens) + require.Equal(t, 10, result.usage.CacheReadInputTokens) +} + +// TestHandleClaudeStreamingResponse_ThoughtsTokenCount +// 验证:Gemini→Claude 流式转换时 thoughtsTokenCount 被计入 OutputTokens +func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":50,"candidatesTokenCount":10,"thoughtsTokenCount":25}}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro") + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + // promptTokenCount=50 → InputTokens=50 + require.Equal(t, 50, result.usage.InputTokens) + // candidatesTokenCount=10 + thoughtsTokenCount=25 → OutputTokens=35 + require.Equal(t, 35, result.usage.OutputTokens) +} + // --- 流式客户端断开检测测试 --- // TestStreamUpstreamResponse_ClientDisconnectDrainsUsage diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 7fa375ca..f3abd1dc 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -2663,11 +2663,12 @@ func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage { prompt, _ := asInt(usageMeta["promptTokenCount"]) cand, _ := asInt(usageMeta["candidatesTokenCount"]) cached, _ := asInt(usageMeta["cachedContentTokenCount"]) + thoughts, _ := asInt(usageMeta["thoughtsTokenCount"]) // 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount, // 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去 return &ClaudeUsage{ InputTokens: prompt - cached, - OutputTokens: cand, + OutputTokens: cand + thoughts, CacheReadInputTokens: cached, } } diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go index f31b40ec..5bc26973 100644 --- a/backend/internal/service/gemini_messages_compat_service_test.go +++ b/backend/internal/service/gemini_messages_compat_service_test.go @@ -4,6 +4,8 @@ import ( "encoding/json" "strings" "testing" + + "github.com/stretchr/testify/require" ) // TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换 @@ -203,3 +205,70 @@ func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s) } } + +func TestExtractGeminiUsage_ThoughtsTokenCount(t *testing.T) { + tests := []struct { + name string + resp map[string]any + wantInput int + wantOutput int + wantCacheRead int + wantNil bool + }{ + { + name: "with thoughtsTokenCount", + resp: map[string]any{ + "usageMetadata": map[string]any{ + "promptTokenCount": float64(100), + "candidatesTokenCount": float64(20), + "thoughtsTokenCount": float64(50), + }, + }, + wantInput: 100, + wantOutput: 70, + }, + { + name: "with thoughtsTokenCount and cache", + resp: map[string]any{ + "usageMetadata": map[string]any{ + "promptTokenCount": float64(100), + "candidatesTokenCount": float64(20), + "cachedContentTokenCount": float64(30), + "thoughtsTokenCount": float64(50), + }, + }, + wantInput: 70, + wantOutput: 70, + wantCacheRead: 30, + }, + { + name: "without thoughtsTokenCount (old model)", + resp: map[string]any{ + "usageMetadata": map[string]any{ + "promptTokenCount": float64(100), + "candidatesTokenCount": float64(20), + }, + }, + wantInput: 100, + wantOutput: 20, + }, + { + name: "no usageMetadata", + resp: map[string]any{}, + wantNil: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + usage := extractGeminiUsage(tt.resp) + if tt.wantNil { + require.Nil(t, usage) + return + } + require.NotNil(t, usage) + require.Equal(t, tt.wantInput, usage.InputTokens) + require.Equal(t, tt.wantOutput, usage.OutputTokens) + require.Equal(t, tt.wantCacheRead, usage.CacheReadInputTokens) + }) + } +}