fix: improve prompt cache tracking
This commit is contained in:
@@ -13,6 +13,13 @@ import (
|
||||
|
||||
const defaultPromptCacheTTL = 5 * time.Minute
|
||||
|
||||
// Anthropic requires cached prefixes to reach a minimum token count before
|
||||
// caching takes effect. Breakpoints below this threshold are excluded from
|
||||
// matching and storage to avoid reporting unrealistic 100% cache hits on
|
||||
// short requests.
|
||||
const defaultMinCacheableTokens = 1024
|
||||
const opusMinCacheableTokens = 4096
|
||||
|
||||
type promptCacheUsage struct {
|
||||
CacheCreationInputTokens int
|
||||
CacheReadInputTokens int
|
||||
@@ -29,6 +36,15 @@ type promptCacheBreakpoint struct {
|
||||
type promptCacheProfile struct {
|
||||
Breakpoints []promptCacheBreakpoint
|
||||
TotalInputTokens int
|
||||
Model string
|
||||
}
|
||||
|
||||
func minCacheableTokensForModel(model string) int {
|
||||
lower := strings.ToLower(model)
|
||||
if strings.Contains(lower, "opus") {
|
||||
return opusMinCacheableTokens
|
||||
}
|
||||
return defaultMinCacheableTokens
|
||||
}
|
||||
|
||||
type promptCacheEntry struct {
|
||||
@@ -61,13 +77,27 @@ func (t *promptCacheTracker) BuildClaudeProfile(req *ClaudeRequest, totalInputTo
|
||||
hasher := sha256.New()
|
||||
breakpoints := make([]promptCacheBreakpoint, 0)
|
||||
cumulativeTokens := 0
|
||||
var activeTTL time.Duration
|
||||
|
||||
for _, block := range blocks {
|
||||
canonical := canonicalizeCacheValue(block.Value)
|
||||
writeHashChunk(hasher, canonical)
|
||||
cumulativeTokens += block.Tokens
|
||||
|
||||
if block.TTL <= 0 {
|
||||
// Determine whether this block acts as a cache breakpoint:
|
||||
// 1) Explicit cache_control on the block itself.
|
||||
// 2) Once any explicit breakpoint has been seen, every message-end
|
||||
// boundary becomes an implicit breakpoint so that multi-turn
|
||||
// conversations can hit earlier stored prefixes.
|
||||
breakpointTTL := time.Duration(0)
|
||||
if block.TTL > 0 {
|
||||
breakpointTTL = block.TTL
|
||||
activeTTL = block.TTL
|
||||
} else if block.IsMessageEnd && activeTTL > 0 {
|
||||
breakpointTTL = activeTTL
|
||||
}
|
||||
|
||||
if breakpointTTL <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -76,7 +106,7 @@ func (t *promptCacheTracker) BuildClaudeProfile(req *ClaudeRequest, totalInputTo
|
||||
breakpoints = append(breakpoints, promptCacheBreakpoint{
|
||||
Fingerprint: fingerprint,
|
||||
CumulativeTokens: cumulativeTokens,
|
||||
TTL: block.TTL,
|
||||
TTL: breakpointTTL,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -91,6 +121,7 @@ func (t *promptCacheTracker) BuildClaudeProfile(req *ClaudeRequest, totalInputTo
|
||||
return &promptCacheProfile{
|
||||
Breakpoints: breakpoints,
|
||||
TotalInputTokens: totalInputTokens,
|
||||
Model: req.Model,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,6 +130,7 @@ func (t *promptCacheTracker) Compute(accountID string, profile *promptCacheProfi
|
||||
return promptCacheUsage{}
|
||||
}
|
||||
|
||||
minTokens := minCacheableTokensForModel(profile.Model)
|
||||
last := profile.Breakpoints[len(profile.Breakpoints)-1]
|
||||
lastTokens := minInt(last.CumulativeTokens, profile.TotalInputTokens)
|
||||
now := time.Now()
|
||||
@@ -109,18 +141,35 @@ func (t *promptCacheTracker) Compute(accountID string, profile *promptCacheProfi
|
||||
|
||||
entries := t.entriesByAccount[accountID]
|
||||
if len(entries) == 0 {
|
||||
// First request for this account: report creation only if above threshold.
|
||||
effectiveCreation := lastTokens
|
||||
if effectiveCreation < minTokens {
|
||||
effectiveCreation = 0
|
||||
}
|
||||
cache5m, cache1h := computePromptCacheTTLBreakdown(profile, 0)
|
||||
return promptCacheUsage{
|
||||
CacheCreationInputTokens: lastTokens,
|
||||
CacheCreationInputTokens: effectiveCreation,
|
||||
CacheReadInputTokens: 0,
|
||||
CacheCreation5mInputTokens: cache5m,
|
||||
CacheCreation1hInputTokens: cache1h,
|
||||
}
|
||||
}
|
||||
|
||||
// Cap cacheable tokens at 85% of total input to ensure a realistic
|
||||
// uncached portion. The newest content in a request is never fully
|
||||
// served from cache on the current turn.
|
||||
maxCacheable := int(float64(profile.TotalInputTokens) * 0.85)
|
||||
if lastTokens > maxCacheable {
|
||||
lastTokens = maxCacheable
|
||||
}
|
||||
|
||||
matchedTokens := 0
|
||||
for i := len(profile.Breakpoints) - 1; i >= 0; i-- {
|
||||
breakpoint := profile.Breakpoints[i]
|
||||
// Skip breakpoints below the minimum cacheable token threshold.
|
||||
if breakpoint.CumulativeTokens < minTokens {
|
||||
continue
|
||||
}
|
||||
entry, ok := entries[breakpoint.Fingerprint]
|
||||
if !ok || entry.ExpiresAt.Before(now) {
|
||||
continue
|
||||
@@ -128,6 +177,9 @@ func (t *promptCacheTracker) Compute(accountID string, profile *promptCacheProfi
|
||||
entry.ExpiresAt = now.Add(entry.TTL)
|
||||
entries[breakpoint.Fingerprint] = entry
|
||||
matchedTokens = minInt(breakpoint.CumulativeTokens, profile.TotalInputTokens)
|
||||
if matchedTokens > lastTokens {
|
||||
matchedTokens = lastTokens
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
@@ -146,6 +198,7 @@ func (t *promptCacheTracker) Update(accountID string, profile *promptCacheProfil
|
||||
return
|
||||
}
|
||||
|
||||
minTokens := minCacheableTokensForModel(profile.Model)
|
||||
now := time.Now()
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
@@ -158,6 +211,10 @@ func (t *promptCacheTracker) Update(accountID string, profile *promptCacheProfil
|
||||
}
|
||||
|
||||
for _, breakpoint := range profile.Breakpoints {
|
||||
// Skip breakpoints below the minimum cacheable token threshold.
|
||||
if breakpoint.CumulativeTokens < minTokens {
|
||||
continue
|
||||
}
|
||||
entries[breakpoint.Fingerprint] = promptCacheEntry{
|
||||
ExpiresAt: now.Add(breakpoint.TTL),
|
||||
TTL: breakpoint.TTL,
|
||||
@@ -179,9 +236,10 @@ func (t *promptCacheTracker) pruneExpiredLocked(now time.Time) {
|
||||
}
|
||||
|
||||
type cacheablePromptBlock struct {
|
||||
Value interface{}
|
||||
Tokens int
|
||||
TTL time.Duration
|
||||
Value interface{}
|
||||
Tokens int
|
||||
TTL time.Duration
|
||||
IsMessageEnd bool
|
||||
}
|
||||
|
||||
func flattenClaudeCacheBlocks(req *ClaudeRequest) []cacheablePromptBlock {
|
||||
@@ -234,14 +292,14 @@ func appendSystemCacheBlocks(blocks *[]cacheablePromptBlock, system interface{})
|
||||
"type": "text",
|
||||
"text": v,
|
||||
},
|
||||
})
|
||||
}, false)
|
||||
case []interface{}:
|
||||
for i, block := range v {
|
||||
appendPromptBlock(blocks, map[string]interface{}{
|
||||
"kind": "system",
|
||||
"system_index": i,
|
||||
"block": block,
|
||||
})
|
||||
}, false)
|
||||
}
|
||||
case []string:
|
||||
for i, block := range v {
|
||||
@@ -252,7 +310,7 @@ func appendSystemCacheBlocks(blocks *[]cacheablePromptBlock, system interface{})
|
||||
"type": "text",
|
||||
"text": block,
|
||||
},
|
||||
})
|
||||
}, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -270,8 +328,9 @@ func appendMessageCacheBlocks(blocks *[]cacheablePromptBlock, messageIndex int,
|
||||
"type": "text",
|
||||
"text": content,
|
||||
},
|
||||
})
|
||||
}, true)
|
||||
case []interface{}:
|
||||
lastIdx := len(content) - 1
|
||||
for blockIndex, block := range content {
|
||||
appendPromptBlock(blocks, map[string]interface{}{
|
||||
"kind": "message",
|
||||
@@ -279,7 +338,7 @@ func appendMessageCacheBlocks(blocks *[]cacheablePromptBlock, messageIndex int,
|
||||
"role": role,
|
||||
"block_index": blockIndex,
|
||||
"block": block,
|
||||
})
|
||||
}, blockIndex == lastIdx)
|
||||
}
|
||||
default:
|
||||
if content != nil {
|
||||
@@ -289,22 +348,70 @@ func appendMessageCacheBlocks(blocks *[]cacheablePromptBlock, messageIndex int,
|
||||
"role": role,
|
||||
"block_index": 0,
|
||||
"block": content,
|
||||
})
|
||||
}, true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func appendPromptBlock(blocks *[]cacheablePromptBlock, wrapper map[string]interface{}) {
|
||||
blockValue, _ := wrapper["block"]
|
||||
func appendPromptBlock(blocks *[]cacheablePromptBlock, wrapper map[string]interface{}, isMessageEnd bool) {
|
||||
blockValue := wrapper["block"]
|
||||
ttl := normalizePromptCacheTTL(extractPromptCacheTTL(blockValue))
|
||||
|
||||
// Normalize volatile text (e.g. Claude Code's x-anthropic-billing-header
|
||||
// which drifts on every request) so that fingerprints remain stable across
|
||||
// requests within the same conversation.
|
||||
if normalized, changed := normalizeCacheBlockContent(blockValue); changed {
|
||||
cloned := make(map[string]interface{}, len(wrapper))
|
||||
for k, v := range wrapper {
|
||||
cloned[k] = v
|
||||
}
|
||||
cloned["block"] = normalized
|
||||
wrapper = cloned
|
||||
}
|
||||
|
||||
canonical := canonicalizeCacheValue(wrapper)
|
||||
*blocks = append(*blocks, cacheablePromptBlock{
|
||||
Value: wrapper,
|
||||
Tokens: estimateApproxTokens(canonical),
|
||||
TTL: ttl,
|
||||
Value: wrapper,
|
||||
Tokens: estimateApproxTokens(canonical),
|
||||
TTL: ttl,
|
||||
IsMessageEnd: isMessageEnd,
|
||||
})
|
||||
}
|
||||
|
||||
// normalizeCacheBlockContent replaces volatile but semantically irrelevant
|
||||
// fields with a placeholder so that the cumulative fingerprint stays stable
|
||||
// across requests in the same session. Currently handles:
|
||||
// - Claude Code's "x-anthropic-billing-header: ..." system text block
|
||||
// whose content drifts on every request (version, telemetry hash, etc.)
|
||||
func normalizeCacheBlockContent(value interface{}) (interface{}, bool) {
|
||||
blockMap, ok := value.(map[string]interface{})
|
||||
if !ok {
|
||||
return value, false
|
||||
}
|
||||
|
||||
// Only normalize text blocks (or blocks without an explicit type but containing text).
|
||||
if t, ok := blockMap["type"].(string); ok && t != "" && t != "text" {
|
||||
return value, false
|
||||
}
|
||||
|
||||
text, ok := blockMap["text"].(string)
|
||||
if !ok {
|
||||
return value, false
|
||||
}
|
||||
|
||||
trimmed := strings.TrimLeft(text, " \t\r\n")
|
||||
if !strings.HasPrefix(strings.ToLower(trimmed), "x-anthropic-billing-header:") {
|
||||
return value, false
|
||||
}
|
||||
|
||||
cloned := make(map[string]interface{}, len(blockMap))
|
||||
for k, v := range blockMap {
|
||||
cloned[k] = v
|
||||
}
|
||||
cloned["text"] = "__anthropic_billing_header__"
|
||||
return cloned, true
|
||||
}
|
||||
|
||||
func extractPromptCacheTTL(value interface{}) time.Duration {
|
||||
block, ok := value.(map[string]interface{})
|
||||
if !ok {
|
||||
|
||||
Reference in New Issue
Block a user