fix: improve prompt cache tracking

This commit is contained in:
Quorinex
2026-05-11 15:05:20 +08:00
parent 9dbe0cb55f
commit 496b14df3f
3 changed files with 251 additions and 18 deletions

View File

@@ -13,6 +13,13 @@ import (
// Prompt-cache tuning constants.
const (
	// defaultPromptCacheTTL is a five-minute cache lifetime. Where it is
	// applied is not visible in this part of the file.
	defaultPromptCacheTTL = 5 * time.Minute

	// Anthropic only caches a prefix once it reaches a minimum token count.
	// Breakpoints that fall below the threshold are left out of both matching
	// and storage, so short requests do not report unrealistic 100% cache
	// hits.
	//
	// defaultMinCacheableTokens is the general threshold.
	defaultMinCacheableTokens = 1024
	// opusMinCacheableTokens is the threshold used for models whose name
	// contains "opus".
	opusMinCacheableTokens = 4096
)
type promptCacheUsage struct {
CacheCreationInputTokens int
CacheReadInputTokens int
@@ -29,6 +36,15 @@ type promptCacheBreakpoint struct {
// promptCacheProfile is the cache-relevant summary of a single request, as
// returned by BuildClaudeProfile and consumed by Compute and Update.
type promptCacheProfile struct {
// Breakpoints holds the cache breakpoints in the order they were appended
// while walking the request's blocks; Compute reads the last element as the
// deepest prefix.
Breakpoints []promptCacheBreakpoint
// TotalInputTokens is the total input token count supplied by the caller of
// BuildClaudeProfile; Compute uses it to cap breakpoint token counts.
TotalInputTokens int
// Model is copied from the request's Model field and selects the minimum
// cacheable token threshold via minCacheableTokensForModel.
Model string
}
func minCacheableTokensForModel(model string) int {
lower := strings.ToLower(model)
if strings.Contains(lower, "opus") {
return opusMinCacheableTokens
}
return defaultMinCacheableTokens
}
type promptCacheEntry struct {
@@ -61,13 +77,27 @@ func (t *promptCacheTracker) BuildClaudeProfile(req *ClaudeRequest, totalInputTo
hasher := sha256.New()
breakpoints := make([]promptCacheBreakpoint, 0)
cumulativeTokens := 0
var activeTTL time.Duration
for _, block := range blocks {
canonical := canonicalizeCacheValue(block.Value)
writeHashChunk(hasher, canonical)
cumulativeTokens += block.Tokens
if block.TTL <= 0 {
// Determine whether this block acts as a cache breakpoint:
// 1) Explicit cache_control on the block itself.
// 2) Once any explicit breakpoint has been seen, every message-end
// boundary becomes an implicit breakpoint so that multi-turn
// conversations can hit earlier stored prefixes.
breakpointTTL := time.Duration(0)
if block.TTL > 0 {
breakpointTTL = block.TTL
activeTTL = block.TTL
} else if block.IsMessageEnd && activeTTL > 0 {
breakpointTTL = activeTTL
}
if breakpointTTL <= 0 {
continue
}
@@ -76,7 +106,7 @@ func (t *promptCacheTracker) BuildClaudeProfile(req *ClaudeRequest, totalInputTo
breakpoints = append(breakpoints, promptCacheBreakpoint{
Fingerprint: fingerprint,
CumulativeTokens: cumulativeTokens,
TTL: block.TTL,
TTL: breakpointTTL,
})
}
@@ -91,6 +121,7 @@ func (t *promptCacheTracker) BuildClaudeProfile(req *ClaudeRequest, totalInputTo
return &promptCacheProfile{
Breakpoints: breakpoints,
TotalInputTokens: totalInputTokens,
Model: req.Model,
}
}
@@ -99,6 +130,7 @@ func (t *promptCacheTracker) Compute(accountID string, profile *promptCacheProfi
return promptCacheUsage{}
}
minTokens := minCacheableTokensForModel(profile.Model)
last := profile.Breakpoints[len(profile.Breakpoints)-1]
lastTokens := minInt(last.CumulativeTokens, profile.TotalInputTokens)
now := time.Now()
@@ -109,18 +141,35 @@ func (t *promptCacheTracker) Compute(accountID string, profile *promptCacheProfi
entries := t.entriesByAccount[accountID]
if len(entries) == 0 {
// First request for this account: report creation only if above threshold.
effectiveCreation := lastTokens
if effectiveCreation < minTokens {
effectiveCreation = 0
}
cache5m, cache1h := computePromptCacheTTLBreakdown(profile, 0)
return promptCacheUsage{
CacheCreationInputTokens: lastTokens,
CacheCreationInputTokens: effectiveCreation,
CacheReadInputTokens: 0,
CacheCreation5mInputTokens: cache5m,
CacheCreation1hInputTokens: cache1h,
}
}
// Cap cacheable tokens at 85% of total input to ensure a realistic
// uncached portion. The newest content in a request is never fully
// served from cache on the current turn.
maxCacheable := int(float64(profile.TotalInputTokens) * 0.85)
if lastTokens > maxCacheable {
lastTokens = maxCacheable
}
matchedTokens := 0
for i := len(profile.Breakpoints) - 1; i >= 0; i-- {
breakpoint := profile.Breakpoints[i]
// Skip breakpoints below the minimum cacheable token threshold.
if breakpoint.CumulativeTokens < minTokens {
continue
}
entry, ok := entries[breakpoint.Fingerprint]
if !ok || entry.ExpiresAt.Before(now) {
continue
@@ -128,6 +177,9 @@ func (t *promptCacheTracker) Compute(accountID string, profile *promptCacheProfi
entry.ExpiresAt = now.Add(entry.TTL)
entries[breakpoint.Fingerprint] = entry
matchedTokens = minInt(breakpoint.CumulativeTokens, profile.TotalInputTokens)
if matchedTokens > lastTokens {
matchedTokens = lastTokens
}
break
}
@@ -146,6 +198,7 @@ func (t *promptCacheTracker) Update(accountID string, profile *promptCacheProfil
return
}
minTokens := minCacheableTokensForModel(profile.Model)
now := time.Now()
t.mu.Lock()
defer t.mu.Unlock()
@@ -158,6 +211,10 @@ func (t *promptCacheTracker) Update(accountID string, profile *promptCacheProfil
}
for _, breakpoint := range profile.Breakpoints {
// Skip breakpoints below the minimum cacheable token threshold.
if breakpoint.CumulativeTokens < minTokens {
continue
}
entries[breakpoint.Fingerprint] = promptCacheEntry{
ExpiresAt: now.Add(breakpoint.TTL),
TTL: breakpoint.TTL,
@@ -179,9 +236,10 @@ func (t *promptCacheTracker) pruneExpiredLocked(now time.Time) {
}
// cacheablePromptBlock is one unit of prompt content (a system block or a
// message content block) prepared for cache fingerprinting.
//
// NOTE(review): this text is a rendered diff. The first Value/Tokens/TTL group
// appears to be the pre-change field list and the second group (ending with
// IsMessageEnd) the post-change list — confirm against the real file.
type cacheablePromptBlock struct {
Value interface{}
Tokens int
TTL time.Duration
// Value is the wrapper map built by the append* helpers; it is what gets
// canonicalized and hashed.
Value interface{}
// Tokens is the approximate token estimate of the canonicalized wrapper.
Tokens int
// TTL is the normalized TTL extracted from the block; a value greater than
// zero marks an explicit cache breakpoint.
TTL time.Duration
// IsMessageEnd is the flag passed to appendPromptBlock: true for the final
// content block of a message, false for system blocks.
IsMessageEnd bool
}
func flattenClaudeCacheBlocks(req *ClaudeRequest) []cacheablePromptBlock {
@@ -234,14 +292,14 @@ func appendSystemCacheBlocks(blocks *[]cacheablePromptBlock, system interface{})
"type": "text",
"text": v,
},
})
}, false)
case []interface{}:
for i, block := range v {
appendPromptBlock(blocks, map[string]interface{}{
"kind": "system",
"system_index": i,
"block": block,
})
}, false)
}
case []string:
for i, block := range v {
@@ -252,7 +310,7 @@ func appendSystemCacheBlocks(blocks *[]cacheablePromptBlock, system interface{})
"type": "text",
"text": block,
},
})
}, false)
}
}
}
@@ -270,8 +328,9 @@ func appendMessageCacheBlocks(blocks *[]cacheablePromptBlock, messageIndex int,
"type": "text",
"text": content,
},
})
}, true)
case []interface{}:
lastIdx := len(content) - 1
for blockIndex, block := range content {
appendPromptBlock(blocks, map[string]interface{}{
"kind": "message",
@@ -279,7 +338,7 @@ func appendMessageCacheBlocks(blocks *[]cacheablePromptBlock, messageIndex int,
"role": role,
"block_index": blockIndex,
"block": block,
})
}, blockIndex == lastIdx)
}
default:
if content != nil {
@@ -289,22 +348,70 @@ func appendMessageCacheBlocks(blocks *[]cacheablePromptBlock, messageIndex int,
"role": role,
"block_index": 0,
"block": content,
})
}, true)
}
}
}
// appendPromptBlock wraps one piece of prompt content into a
// cacheablePromptBlock and appends it to *blocks.
//
// The TTL is extracted from wrapper["block"] as given. The block content is
// then optionally normalized, the wrapper is canonicalized, and the token
// estimate is computed from that canonical form. isMessageEnd is stored
// unchanged on the appended block.
//
// NOTE(review): this text is a rendered diff. The two-parameter signature and
// the first Value/Tokens/TTL group in the literal appear to be pre-change
// lines, each followed directly by its post-change replacement — confirm
// against the real file.
func appendPromptBlock(blocks *[]cacheablePromptBlock, wrapper map[string]interface{}) {
blockValue, _ := wrapper["block"]
func appendPromptBlock(blocks *[]cacheablePromptBlock, wrapper map[string]interface{}, isMessageEnd bool) {
blockValue := wrapper["block"]
// Read the TTL from the original block, before any normalization below.
ttl := normalizePromptCacheTTL(extractPromptCacheTTL(blockValue))
// Normalize volatile text (e.g. Claude Code's x-anthropic-billing-header
// which drifts on every request) so that fingerprints remain stable across
// requests within the same conversation.
if normalized, changed := normalizeCacheBlockContent(blockValue); changed {
// Shallow-copy the wrapper so the caller's map is not mutated when the
// normalized block is substituted.
cloned := make(map[string]interface{}, len(wrapper))
for k, v := range wrapper {
cloned[k] = v
}
cloned["block"] = normalized
wrapper = cloned
}
// The token estimate is taken from the canonical form of the (possibly
// normalized) wrapper.
canonical := canonicalizeCacheValue(wrapper)
*blocks = append(*blocks, cacheablePromptBlock{
Value: wrapper,
Tokens: estimateApproxTokens(canonical),
TTL: ttl,
Value: wrapper,
Tokens: estimateApproxTokens(canonical),
TTL: ttl,
IsMessageEnd: isMessageEnd,
})
}
// normalizeCacheBlockContent substitutes a fixed placeholder for text that
// changes between requests without changing meaning, so that fingerprints
// built from the block stay stable.
//
// The only case handled today is a text block whose text, after leading
// spaces, tabs, CRs and LFs are skipped, begins (case-insensitively) with
// "x-anthropic-billing-header:" — e.g. the system block Claude Code sends,
// whose content drifts on every request.
//
// It returns the value to use and whether a substitution happened. On a
// match the result is a shallow copy of the input map with "text" replaced;
// the input map itself is never modified. In every other case the input is
// returned as-is with false.
func normalizeCacheBlockContent(value interface{}) (interface{}, bool) {
	const (
		billingHeaderPrefix      = "x-anthropic-billing-header:"
		billingHeaderPlaceholder = "__anthropic_billing_header__"
	)

	src, isMap := value.(map[string]interface{})
	if !isMap {
		return value, false
	}

	// Blocks with a missing, non-string or empty "type" are still considered;
	// any other explicit type besides "text" is left untouched.
	switch kind, _ := src["type"].(string); kind {
	case "", "text":
	default:
		return value, false
	}

	body, isString := src["text"].(string)
	if !isString {
		return value, false
	}

	head := strings.ToLower(strings.TrimLeft(body, " \t\r\n"))
	if !strings.HasPrefix(head, billingHeaderPrefix) {
		return value, false
	}

	out := make(map[string]interface{}, len(src))
	for key, val := range src {
		out[key] = val
	}
	out["text"] = billingHeaderPlaceholder
	return out, true
}
func extractPromptCacheTTL(value interface{}) time.Duration {
block, ok := value.(map[string]interface{})
if !ok {