From 496b14df3fbf4110d8837821041b0c8b4a191913 Mon Sep 17 00:00:00 2001 From: Quorinex Date: Mon, 11 May 2026 15:05:20 +0800 Subject: [PATCH] fix: improve prompt cache tracking --- LICENSE | 21 ++++++ proxy/cache_tracker.go | 141 +++++++++++++++++++++++++++++++----- proxy/cache_tracker_test.go | 107 ++++++++++++++++++++++++++- 3 files changed, 251 insertions(+), 18 deletions(-) create mode 100644 LICENSE diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..1bf685b --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Quorinex + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/proxy/cache_tracker.go b/proxy/cache_tracker.go index 338f208..582754b 100644 --- a/proxy/cache_tracker.go +++ b/proxy/cache_tracker.go @@ -13,6 +13,13 @@ import ( const defaultPromptCacheTTL = 5 * time.Minute +// Anthropic requires cached prefixes to reach a minimum token count before +// caching takes effect. Breakpoints below this threshold are excluded from +// matching and storage to avoid reporting unrealistic 100% cache hits on +// short requests. +const defaultMinCacheableTokens = 1024 +const opusMinCacheableTokens = 4096 + type promptCacheUsage struct { CacheCreationInputTokens int CacheReadInputTokens int @@ -29,6 +36,15 @@ type promptCacheBreakpoint struct { type promptCacheProfile struct { Breakpoints []promptCacheBreakpoint TotalInputTokens int + Model string +} + +func minCacheableTokensForModel(model string) int { + lower := strings.ToLower(model) + if strings.Contains(lower, "opus") { + return opusMinCacheableTokens + } + return defaultMinCacheableTokens } type promptCacheEntry struct { @@ -61,13 +77,27 @@ func (t *promptCacheTracker) BuildClaudeProfile(req *ClaudeRequest, totalInputTo hasher := sha256.New() breakpoints := make([]promptCacheBreakpoint, 0) cumulativeTokens := 0 + var activeTTL time.Duration for _, block := range blocks { canonical := canonicalizeCacheValue(block.Value) writeHashChunk(hasher, canonical) cumulativeTokens += block.Tokens - if block.TTL <= 0 { + // Determine whether this block acts as a cache breakpoint: + // 1) Explicit cache_control on the block itself. + // 2) Once any explicit breakpoint has been seen, every message-end + // boundary becomes an implicit breakpoint so that multi-turn + // conversations can hit earlier stored prefixes. + breakpointTTL := time.Duration(0) + if block.TTL > 0 { + breakpointTTL = block.TTL + activeTTL = block.TTL + } else if block.IsMessageEnd && activeTTL > 0 { + breakpointTTL = activeTTL + } + + if breakpointTTL <= 0 { continue } @@ -76,7 +106,7 @@ func (t *promptCacheTracker) BuildClaudeProfile(req *ClaudeRequest, totalInputTo breakpoints = append(breakpoints, promptCacheBreakpoint{ Fingerprint: fingerprint, CumulativeTokens: cumulativeTokens, - TTL: block.TTL, + TTL: breakpointTTL, }) } @@ -91,6 +121,7 @@ func (t *promptCacheTracker) BuildClaudeProfile(req *ClaudeRequest, totalInputTo return &promptCacheProfile{ Breakpoints: breakpoints, TotalInputTokens: totalInputTokens, + Model: req.Model, } } @@ -99,6 +130,7 @@ func (t *promptCacheTracker) Compute(accountID string, profile *promptCacheProfi return promptCacheUsage{} } + minTokens := minCacheableTokensForModel(profile.Model) last := profile.Breakpoints[len(profile.Breakpoints)-1] lastTokens := minInt(last.CumulativeTokens, profile.TotalInputTokens) now := time.Now() @@ -109,18 +141,35 @@ func (t *promptCacheTracker) Compute(accountID string, profile *promptCacheProfi entries := t.entriesByAccount[accountID] if len(entries) == 0 { + // First request for this account: report creation only if above threshold. + effectiveCreation := lastTokens + if effectiveCreation < minTokens { + effectiveCreation = 0 + } cache5m, cache1h := computePromptCacheTTLBreakdown(profile, 0) return promptCacheUsage{ - CacheCreationInputTokens: lastTokens, + CacheCreationInputTokens: effectiveCreation, CacheReadInputTokens: 0, CacheCreation5mInputTokens: cache5m, CacheCreation1hInputTokens: cache1h, } } + // Cap cacheable tokens at 85% of total input to ensure a realistic + // uncached portion. The newest content in a request is never fully + // served from cache on the current turn. + maxCacheable := int(float64(profile.TotalInputTokens) * 0.85) + if lastTokens > maxCacheable { + lastTokens = maxCacheable + } + matchedTokens := 0 for i := len(profile.Breakpoints) - 1; i >= 0; i-- { breakpoint := profile.Breakpoints[i] + // Skip breakpoints below the minimum cacheable token threshold. + if breakpoint.CumulativeTokens < minTokens { + continue + } entry, ok := entries[breakpoint.Fingerprint] if !ok || entry.ExpiresAt.Before(now) { continue @@ -128,6 +177,9 @@ func (t *promptCacheTracker) Compute(accountID string, profile *promptCacheProfi entry.ExpiresAt = now.Add(entry.TTL) entries[breakpoint.Fingerprint] = entry matchedTokens = minInt(breakpoint.CumulativeTokens, profile.TotalInputTokens) + if matchedTokens > lastTokens { + matchedTokens = lastTokens + } break } @@ -146,6 +198,7 @@ func (t *promptCacheTracker) Update(accountID string, profile *promptCacheProfil return } + minTokens := minCacheableTokensForModel(profile.Model) now := time.Now() t.mu.Lock() defer t.mu.Unlock() @@ -158,6 +211,10 @@ func (t *promptCacheTracker) Update(accountID string, profile *promptCacheProfil } for _, breakpoint := range profile.Breakpoints { + // Skip breakpoints below the minimum cacheable token threshold. + if breakpoint.CumulativeTokens < minTokens { + continue + } entries[breakpoint.Fingerprint] = promptCacheEntry{ ExpiresAt: now.Add(breakpoint.TTL), TTL: breakpoint.TTL, @@ -179,9 +236,10 @@ func (t *promptCacheTracker) pruneExpiredLocked(now time.Time) { } type cacheablePromptBlock struct { - Value interface{} - Tokens int - TTL time.Duration + Value interface{} + Tokens int + TTL time.Duration + IsMessageEnd bool } func flattenClaudeCacheBlocks(req *ClaudeRequest) []cacheablePromptBlock { @@ -234,14 +292,14 @@ func appendSystemCacheBlocks(blocks *[]cacheablePromptBlock, system interface{}) "type": "text", "text": v, }, - }) + }, false) case []interface{}: for i, block := range v { appendPromptBlock(blocks, map[string]interface{}{ "kind": "system", "system_index": i, "block": block, - }) + }, false) } case []string: for i, block := range v { @@ -252,7 +310,7 @@ func appendSystemCacheBlocks(blocks *[]cacheablePromptBlock, system interface{}) "type": "text", "text": block, }, - }) + }, false) } } } @@ -270,8 +328,9 @@ func appendMessageCacheBlocks(blocks *[]cacheablePromptBlock, messageIndex int, "type": "text", "text": content, }, - }) + }, true) case []interface{}: + lastIdx := len(content) - 1 for blockIndex, block := range content { appendPromptBlock(blocks, map[string]interface{}{ "kind": "message", @@ -279,7 +338,7 @@ func appendMessageCacheBlocks(blocks *[]cacheablePromptBlock, messageIndex int, "role": role, "block_index": blockIndex, "block": block, - }) + }, blockIndex == lastIdx) } default: if content != nil { @@ -289,22 +348,70 @@ func appendMessageCacheBlocks(blocks *[]cacheablePromptBlock, messageIndex int, "role": role, "block_index": 0, "block": content, - }) + }, true) } } } -func appendPromptBlock(blocks *[]cacheablePromptBlock, wrapper map[string]interface{}) { - blockValue, _ := wrapper["block"] +func appendPromptBlock(blocks *[]cacheablePromptBlock, wrapper map[string]interface{}, isMessageEnd bool) { + blockValue := wrapper["block"] ttl := normalizePromptCacheTTL(extractPromptCacheTTL(blockValue)) + + // Normalize volatile text (e.g. Claude Code's x-anthropic-billing-header + // which drifts on every request) so that fingerprints remain stable across + // requests within the same conversation. + if normalized, changed := normalizeCacheBlockContent(blockValue); changed { + cloned := make(map[string]interface{}, len(wrapper)) + for k, v := range wrapper { + cloned[k] = v + } + cloned["block"] = normalized + wrapper = cloned + } + canonical := canonicalizeCacheValue(wrapper) *blocks = append(*blocks, cacheablePromptBlock{ - Value: wrapper, - Tokens: estimateApproxTokens(canonical), - TTL: ttl, + Value: wrapper, + Tokens: estimateApproxTokens(canonical), + TTL: ttl, + IsMessageEnd: isMessageEnd, }) } +// normalizeCacheBlockContent replaces volatile but semantically irrelevant +// fields with a placeholder so that the cumulative fingerprint stays stable +// across requests in the same session. Currently handles: +// - Claude Code's "x-anthropic-billing-header: ..." system text block +// whose content drifts on every request (version, telemetry hash, etc.) +func normalizeCacheBlockContent(value interface{}) (interface{}, bool) { + blockMap, ok := value.(map[string]interface{}) + if !ok { + return value, false + } + + // Only normalize text blocks (or blocks without an explicit type but containing text). + if t, ok := blockMap["type"].(string); ok && t != "" && t != "text" { + return value, false + } + + text, ok := blockMap["text"].(string) + if !ok { + return value, false + } + + trimmed := strings.TrimLeft(text, " \t\r\n") + if !strings.HasPrefix(strings.ToLower(trimmed), "x-anthropic-billing-header:") { + return value, false + } + + cloned := make(map[string]interface{}, len(blockMap)) + for k, v := range blockMap { + cloned[k] = v + } + cloned["text"] = "__anthropic_billing_header__" + return cloned, true +} + func extractPromptCacheTTL(value interface{}) time.Duration { block, ok := value.(map[string]interface{}) if !ok { diff --git a/proxy/cache_tracker_test.go b/proxy/cache_tracker_test.go index 1beba02..aa620c8 100644 --- a/proxy/cache_tracker_test.go +++ b/proxy/cache_tracker_test.go @@ -1,18 +1,20 @@ package proxy import ( + "strings" "testing" "time" ) func TestPromptCacheTrackerComputeAndUpdate(t *testing.T) { tracker := newPromptCacheTracker(time.Hour) + longSystem := strings.Repeat("You are a helpful coding assistant with deep knowledge of Go, Rust, Python, and TypeScript. ", 80) req := &ClaudeRequest{ Model: "claude-sonnet-4.5", System: []interface{}{ map[string]interface{}{ "type": "text", - "text": "system prompt", + "text": longSystem, "cache_control": map[string]interface{}{ "type": "ephemeral", }, @@ -71,3 +73,106 @@ func TestBuildClaudeUsageMapIncludesCacheFields(t *testing.T) { t.Fatalf("unexpected ttl breakdown: %#v", creation) } } + +// TestPromptCacheStableAcrossBillingHeaderDrift verifies that Claude Code's +// per-request "x-anthropic-billing-header: cc_version=...; cch=...;" system +// block (whose content drifts on every request) does not break cache hits. +// The normalization logic should ensure the same conversation still matches. +func TestPromptCacheStableAcrossBillingHeaderDrift(t *testing.T) { + tracker := newPromptCacheTracker(time.Hour) + mainSystem := strings.Repeat("You are a helpful coding assistant with deep knowledge of Go, Rust, Python, and TypeScript. ", 80) + + build := func(billingHdr string) *ClaudeRequest { + return &ClaudeRequest{ + Model: "claude-sonnet-4.5", + System: []interface{}{ + map[string]interface{}{ + "type": "text", + "text": billingHdr, + }, + map[string]interface{}{ + "type": "text", + "text": mainSystem, + "cache_control": map[string]interface{}{ + "type": "ephemeral", + }, + }, + }, + Messages: []ClaudeMessage{{Role: "user", Content: "hello world"}}, + } + } + + req1 := build("x-anthropic-billing-header: cc_version=2.1.87.1; cch=aaaa;") + profile1 := tracker.BuildClaudeProfile(req1, 2048) + if profile1 == nil { + t.Fatalf("profile1 should be built") + } + first := tracker.Compute("acct-1", profile1) + if first.CacheReadInputTokens != 0 { + t.Fatalf("expected no cache read on first request, got %+v", first) + } + tracker.Update("acct-1", profile1) + + req2 := build("x-anthropic-billing-header: cc_version=2.1.87.42; cch=bbbb; padding=xxyyzz;") + profile2 := tracker.BuildClaudeProfile(req2, 2048) + if profile2 == nil { + t.Fatalf("profile2 should be built") + } + second := tracker.Compute("acct-1", profile2) + if second.CacheReadInputTokens == 0 { + t.Fatalf("expected cache read after billing header drift, got %+v", second) + } +} + +// TestPromptCacheImplicitBreakpointAtMessageEnd verifies that once any +// explicit cache_control breakpoint has been seen, subsequent message-end +// boundaries act as implicit breakpoints. This allows multi-turn conversations +// to hit earlier stored prefix fingerprints even when the newest messages +// lack explicit cache_control. +func TestPromptCacheImplicitBreakpointAtMessageEnd(t *testing.T) { + tracker := newPromptCacheTracker(time.Hour) + systemText := strings.Repeat("You are a helpful coding assistant with deep knowledge of Go, Rust, Python, and TypeScript. ", 80) + + baseSystem := []interface{}{ + map[string]interface{}{ + "type": "text", + "text": systemText, + "cache_control": map[string]interface{}{ + "type": "ephemeral", + }, + }, + } + + // Round 1: single user message. + req1 := &ClaudeRequest{ + Model: "claude-sonnet-4.5", + System: baseSystem, + Messages: []ClaudeMessage{{Role: "user", Content: "question one"}}, + } + profile1 := tracker.BuildClaudeProfile(req1, 2048) + if profile1 == nil { + t.Fatalf("profile1 should be built") + } + tracker.Update("acct-1", profile1) + + // Round 2: conversation continues with new messages. The latest user + // message has no explicit cache_control; it should still hit the stored + // prefix via the implicit message-end breakpoint. + req2 := &ClaudeRequest{ + Model: "claude-sonnet-4.5", + System: baseSystem, + Messages: []ClaudeMessage{ + {Role: "user", Content: "question one"}, + {Role: "assistant", Content: "answer one"}, + {Role: "user", Content: "follow-up question"}, + }, + } + profile2 := tracker.BuildClaudeProfile(req2, 4096) + if profile2 == nil { + t.Fatalf("profile2 should be built") + } + result := tracker.Compute("acct-1", profile2) + if result.CacheReadInputTokens == 0 { + t.Fatalf("expected cache read via implicit message-end breakpoint, got %+v", result) + } +}