fix: improve prompt cache tracking

This commit is contained in:
Quorinex
2026-05-11 15:05:20 +08:00
parent 9dbe0cb55f
commit 496b14df3f
3 changed files with 251 additions and 18 deletions

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2026 Quorinex
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -13,6 +13,13 @@ import (
const (
	// defaultPromptCacheTTL is a five-minute cache lifetime.
	// NOTE(review): presumably the TTL applied when a cache_control block
	// does not name one — confirm against normalizePromptCacheTTL.
	defaultPromptCacheTTL = 5 * time.Minute

	// Minimum cumulative prefix sizes, in tokens, for a breakpoint to count
	// as cacheable. Anthropic does not cache a prefix until it reaches a
	// minimum token count, so breakpoints under the applicable threshold are
	// left out of both matching and storage; otherwise short requests would
	// be reported as unrealistic 100% cache hits.
	defaultMinCacheableTokens = 1024
	opusMinCacheableTokens    = 4096
)
type promptCacheUsage struct {
CacheCreationInputTokens int
CacheReadInputTokens int
@@ -29,6 +36,15 @@ type promptCacheBreakpoint struct {
type promptCacheProfile struct {
Breakpoints []promptCacheBreakpoint
TotalInputTokens int
Model string
}
func minCacheableTokensForModel(model string) int {
lower := strings.ToLower(model)
if strings.Contains(lower, "opus") {
return opusMinCacheableTokens
}
return defaultMinCacheableTokens
}
type promptCacheEntry struct {
@@ -61,13 +77,27 @@ func (t *promptCacheTracker) BuildClaudeProfile(req *ClaudeRequest, totalInputTo
hasher := sha256.New()
breakpoints := make([]promptCacheBreakpoint, 0)
cumulativeTokens := 0
var activeTTL time.Duration
for _, block := range blocks {
canonical := canonicalizeCacheValue(block.Value)
writeHashChunk(hasher, canonical)
cumulativeTokens += block.Tokens
if block.TTL <= 0 {
// Determine whether this block acts as a cache breakpoint:
// 1) Explicit cache_control on the block itself.
// 2) Once any explicit breakpoint has been seen, every message-end
// boundary becomes an implicit breakpoint so that multi-turn
// conversations can hit earlier stored prefixes.
breakpointTTL := time.Duration(0)
if block.TTL > 0 {
breakpointTTL = block.TTL
activeTTL = block.TTL
} else if block.IsMessageEnd && activeTTL > 0 {
breakpointTTL = activeTTL
}
if breakpointTTL <= 0 {
continue
}
@@ -76,7 +106,7 @@ func (t *promptCacheTracker) BuildClaudeProfile(req *ClaudeRequest, totalInputTo
breakpoints = append(breakpoints, promptCacheBreakpoint{
Fingerprint: fingerprint,
CumulativeTokens: cumulativeTokens,
TTL: block.TTL,
TTL: breakpointTTL,
})
}
@@ -91,6 +121,7 @@ func (t *promptCacheTracker) BuildClaudeProfile(req *ClaudeRequest, totalInputTo
return &promptCacheProfile{
Breakpoints: breakpoints,
TotalInputTokens: totalInputTokens,
Model: req.Model,
}
}
@@ -99,6 +130,7 @@ func (t *promptCacheTracker) Compute(accountID string, profile *promptCacheProfi
return promptCacheUsage{}
}
minTokens := minCacheableTokensForModel(profile.Model)
last := profile.Breakpoints[len(profile.Breakpoints)-1]
lastTokens := minInt(last.CumulativeTokens, profile.TotalInputTokens)
now := time.Now()
@@ -109,18 +141,35 @@ func (t *promptCacheTracker) Compute(accountID string, profile *promptCacheProfi
entries := t.entriesByAccount[accountID]
if len(entries) == 0 {
// First request for this account: report creation only if above threshold.
effectiveCreation := lastTokens
if effectiveCreation < minTokens {
effectiveCreation = 0
}
cache5m, cache1h := computePromptCacheTTLBreakdown(profile, 0)
return promptCacheUsage{
CacheCreationInputTokens: lastTokens,
CacheCreationInputTokens: effectiveCreation,
CacheReadInputTokens: 0,
CacheCreation5mInputTokens: cache5m,
CacheCreation1hInputTokens: cache1h,
}
}
// Cap cacheable tokens at 85% of total input to ensure a realistic
// uncached portion. The newest content in a request is never fully
// served from cache on the current turn.
maxCacheable := int(float64(profile.TotalInputTokens) * 0.85)
if lastTokens > maxCacheable {
lastTokens = maxCacheable
}
matchedTokens := 0
for i := len(profile.Breakpoints) - 1; i >= 0; i-- {
breakpoint := profile.Breakpoints[i]
// Skip breakpoints below the minimum cacheable token threshold.
if breakpoint.CumulativeTokens < minTokens {
continue
}
entry, ok := entries[breakpoint.Fingerprint]
if !ok || entry.ExpiresAt.Before(now) {
continue
@@ -128,6 +177,9 @@ func (t *promptCacheTracker) Compute(accountID string, profile *promptCacheProfi
entry.ExpiresAt = now.Add(entry.TTL)
entries[breakpoint.Fingerprint] = entry
matchedTokens = minInt(breakpoint.CumulativeTokens, profile.TotalInputTokens)
if matchedTokens > lastTokens {
matchedTokens = lastTokens
}
break
}
@@ -146,6 +198,7 @@ func (t *promptCacheTracker) Update(accountID string, profile *promptCacheProfil
return
}
minTokens := minCacheableTokensForModel(profile.Model)
now := time.Now()
t.mu.Lock()
defer t.mu.Unlock()
@@ -158,6 +211,10 @@ func (t *promptCacheTracker) Update(accountID string, profile *promptCacheProfil
}
for _, breakpoint := range profile.Breakpoints {
// Skip breakpoints below the minimum cacheable token threshold.
if breakpoint.CumulativeTokens < minTokens {
continue
}
entries[breakpoint.Fingerprint] = promptCacheEntry{
ExpiresAt: now.Add(breakpoint.TTL),
TTL: breakpoint.TTL,
@@ -179,9 +236,10 @@ func (t *promptCacheTracker) pruneExpiredLocked(now time.Time) {
}
type cacheablePromptBlock struct {
Value interface{}
Tokens int
TTL time.Duration
Value interface{}
Tokens int
TTL time.Duration
IsMessageEnd bool
}
func flattenClaudeCacheBlocks(req *ClaudeRequest) []cacheablePromptBlock {
@@ -234,14 +292,14 @@ func appendSystemCacheBlocks(blocks *[]cacheablePromptBlock, system interface{})
"type": "text",
"text": v,
},
})
}, false)
case []interface{}:
for i, block := range v {
appendPromptBlock(blocks, map[string]interface{}{
"kind": "system",
"system_index": i,
"block": block,
})
}, false)
}
case []string:
for i, block := range v {
@@ -252,7 +310,7 @@ func appendSystemCacheBlocks(blocks *[]cacheablePromptBlock, system interface{})
"type": "text",
"text": block,
},
})
}, false)
}
}
}
@@ -270,8 +328,9 @@ func appendMessageCacheBlocks(blocks *[]cacheablePromptBlock, messageIndex int,
"type": "text",
"text": content,
},
})
}, true)
case []interface{}:
lastIdx := len(content) - 1
for blockIndex, block := range content {
appendPromptBlock(blocks, map[string]interface{}{
"kind": "message",
@@ -279,7 +338,7 @@ func appendMessageCacheBlocks(blocks *[]cacheablePromptBlock, messageIndex int,
"role": role,
"block_index": blockIndex,
"block": block,
})
}, blockIndex == lastIdx)
}
default:
if content != nil {
@@ -289,22 +348,70 @@ func appendMessageCacheBlocks(blocks *[]cacheablePromptBlock, messageIndex int,
"role": role,
"block_index": 0,
"block": content,
})
}, true)
}
}
}
func appendPromptBlock(blocks *[]cacheablePromptBlock, wrapper map[string]interface{}) {
blockValue, _ := wrapper["block"]
func appendPromptBlock(blocks *[]cacheablePromptBlock, wrapper map[string]interface{}, isMessageEnd bool) {
blockValue := wrapper["block"]
ttl := normalizePromptCacheTTL(extractPromptCacheTTL(blockValue))
// Normalize volatile text (e.g. Claude Code's x-anthropic-billing-header
// which drifts on every request) so that fingerprints remain stable across
// requests within the same conversation.
if normalized, changed := normalizeCacheBlockContent(blockValue); changed {
cloned := make(map[string]interface{}, len(wrapper))
for k, v := range wrapper {
cloned[k] = v
}
cloned["block"] = normalized
wrapper = cloned
}
canonical := canonicalizeCacheValue(wrapper)
*blocks = append(*blocks, cacheablePromptBlock{
Value: wrapper,
Tokens: estimateApproxTokens(canonical),
TTL: ttl,
Value: wrapper,
Tokens: estimateApproxTokens(canonical),
TTL: ttl,
IsMessageEnd: isMessageEnd,
})
}
// normalizeCacheBlockContent swaps volatile, semantically irrelevant block
// text for a fixed placeholder so the cumulative fingerprint does not change
// between requests of the same session.
//
// The only case handled today is Claude Code's
// "x-anthropic-billing-header: ..." text block, whose content (version,
// telemetry hash, etc.) drifts on every request.
//
// The value is rewritten only when all of the following hold:
//   - it is a map[string]interface{};
//   - its "type" is absent, not a string, empty, or exactly "text";
//   - its "text" is a string that, after stripping leading spaces, tabs, CR
//     and LF, starts with "x-anthropic-billing-header:" (case-insensitive).
//
// On a rewrite the result is a shallow copy of the map with "text" replaced
// and the second return value is true; the caller's map is never mutated.
// Otherwise the input is returned untouched together with false.
func normalizeCacheBlockContent(value interface{}) (interface{}, bool) {
	const (
		billingHeaderPrefix      = "x-anthropic-billing-header:"
		billingHeaderPlaceholder = "__anthropic_billing_header__"
	)

	src, isMap := value.(map[string]interface{})
	if !isMap {
		return value, false
	}

	// An explicit, non-empty type other than "text" rules the block out;
	// untyped blocks are still inspected as long as they carry text.
	if kind, isString := src["type"].(string); isString && kind != "" && kind != "text" {
		return value, false
	}

	body, hasText := src["text"].(string)
	if !hasText {
		return value, false
	}

	head := strings.ToLower(strings.TrimLeft(body, " \t\r\n"))
	if !strings.HasPrefix(head, billingHeaderPrefix) {
		return value, false
	}

	// Shallow-copy so sibling fields (e.g. cache_control) survive while the
	// original request payload stays unmodified.
	out := make(map[string]interface{}, len(src))
	for key, item := range src {
		out[key] = item
	}
	out["text"] = billingHeaderPlaceholder
	return out, true
}
func extractPromptCacheTTL(value interface{}) time.Duration {
block, ok := value.(map[string]interface{})
if !ok {

View File

@@ -1,18 +1,20 @@
package proxy
import (
"strings"
"testing"
"time"
)
func TestPromptCacheTrackerComputeAndUpdate(t *testing.T) {
tracker := newPromptCacheTracker(time.Hour)
longSystem := strings.Repeat("You are a helpful coding assistant with deep knowledge of Go, Rust, Python, and TypeScript. ", 80)
req := &ClaudeRequest{
Model: "claude-sonnet-4.5",
System: []interface{}{
map[string]interface{}{
"type": "text",
"text": "system prompt",
"text": longSystem,
"cache_control": map[string]interface{}{
"type": "ephemeral",
},
@@ -71,3 +73,106 @@ func TestBuildClaudeUsageMapIncludesCacheFields(t *testing.T) {
t.Fatalf("unexpected ttl breakdown: %#v", creation)
}
}
// TestPromptCacheStableAcrossBillingHeaderDrift checks that a cache hit
// survives changes to Claude Code's per-request
// "x-anthropic-billing-header: cc_version=...; cch=...;" system block.
// Two requests that differ only in that leading block are sent for the same
// account; after the first one is recorded, the second must report a
// non-zero cache read.
func TestPromptCacheStableAcrossBillingHeaderDrift(t *testing.T) {
	const (
		account          = "acct-1"
		totalInputTokens = 2048
	)

	tracker := newPromptCacheTracker(time.Hour)
	cachedSystem := strings.Repeat("You are a helpful coding assistant with deep knowledge of Go, Rust, Python, and TypeScript. ", 80)

	// requestWith returns a request whose first system block is the drifting
	// billing header and whose second system block carries the explicit
	// ephemeral cache_control marker.
	requestWith := func(header string) *ClaudeRequest {
		driftingBlock := map[string]interface{}{
			"type": "text",
			"text": header,
		}
		cachedBlock := map[string]interface{}{
			"type": "text",
			"text": cachedSystem,
			"cache_control": map[string]interface{}{
				"type": "ephemeral",
			},
		}
		return &ClaudeRequest{
			Model:    "claude-sonnet-4.5",
			System:   []interface{}{driftingBlock, cachedBlock},
			Messages: []ClaudeMessage{{Role: "user", Content: "hello world"}},
		}
	}

	// First request: nothing stored yet, so no tokens may be read from cache.
	profile1 := tracker.BuildClaudeProfile(
		requestWith("x-anthropic-billing-header: cc_version=2.1.87.1; cch=aaaa;"),
		totalInputTokens,
	)
	if profile1 == nil {
		t.Fatalf("profile1 should be built")
	}
	if first := tracker.Compute(account, profile1); first.CacheReadInputTokens != 0 {
		t.Fatalf("expected no cache read on first request, got %+v", first)
	}
	tracker.Update(account, profile1)

	// Second request: same conversation, different billing header contents.
	profile2 := tracker.BuildClaudeProfile(
		requestWith("x-anthropic-billing-header: cc_version=2.1.87.42; cch=bbbb; padding=xxyyzz;"),
		totalInputTokens,
	)
	if profile2 == nil {
		t.Fatalf("profile2 should be built")
	}
	if second := tracker.Compute(account, profile2); second.CacheReadInputTokens == 0 {
		t.Fatalf("expected cache read after billing header drift, got %+v", second)
	}
}
// TestPromptCacheImplicitBreakpointAtMessageEnd verifies that once any
// explicit cache_control breakpoint has been seen, subsequent message-end
// boundaries act as implicit breakpoints. This allows multi-turn conversations
// to hit earlier stored prefix fingerprints even when the newest messages
// lack explicit cache_control.
//
// A non-zero cache read alone does not prove this: the explicit system
// breakpoint is shared by both rounds and would produce a hit by itself. The
// test therefore also inspects the built profiles and requires that the
// message-end breakpoint recorded in round 1 reappears in round 2.
func TestPromptCacheImplicitBreakpointAtMessageEnd(t *testing.T) {
	tracker := newPromptCacheTracker(time.Hour)

	systemText := strings.Repeat("You are a helpful coding assistant with deep knowledge of Go, Rust, Python, and TypeScript. ", 80)
	baseSystem := []interface{}{
		map[string]interface{}{
			"type": "text",
			"text": systemText,
			"cache_control": map[string]interface{}{
				"type": "ephemeral",
			},
		},
	}

	// Round 1: single user message.
	req1 := &ClaudeRequest{
		Model:    "claude-sonnet-4.5",
		System:   baseSystem,
		Messages: []ClaudeMessage{{Role: "user", Content: "question one"}},
	}
	profile1 := tracker.BuildClaudeProfile(req1, 2048)
	if profile1 == nil {
		t.Fatalf("profile1 should be built")
	}
	// Only the system block carries cache_control, so any breakpoint past the
	// first can only come from the implicit message-end rule.
	if len(profile1.Breakpoints) < 2 {
		t.Fatalf("expected an implicit message-end breakpoint in round 1, got %d breakpoint(s)", len(profile1.Breakpoints))
	}
	tracker.Update("acct-1", profile1)

	// Round 2: conversation continues with new messages. The latest user
	// message has no explicit cache_control; it should still hit the stored
	// prefix via the implicit message-end breakpoint.
	req2 := &ClaudeRequest{
		Model:  "claude-sonnet-4.5",
		System: baseSystem,
		Messages: []ClaudeMessage{
			{Role: "user", Content: "question one"},
			{Role: "assistant", Content: "answer one"},
			{Role: "user", Content: "follow-up question"},
		},
	}
	profile2 := tracker.BuildClaudeProfile(req2, 4096)
	if profile2 == nil {
		t.Fatalf("profile2 should be built")
	}

	// The prefix ending at the first user message is identical in both
	// rounds, so round 2 must expose the same fingerprint as a breakpoint.
	storedMessageEnd := profile1.Breakpoints[len(profile1.Breakpoints)-1]
	sharedPrefixFound := false
	for _, candidate := range profile2.Breakpoints {
		if candidate.Fingerprint == storedMessageEnd.Fingerprint {
			sharedPrefixFound = true
			break
		}
	}
	if !sharedPrefixFound {
		t.Fatalf("round 2 profile does not contain the round 1 message-end breakpoint: %+v", profile2.Breakpoints)
	}

	result := tracker.Compute("acct-1", profile2)
	if result.CacheReadInputTokens == 0 {
		t.Fatalf("expected cache read via implicit message-end breakpoint, got %+v", result)
	}
}