Files
kirogo/proxy/cache_tracker.go
huangzhenpc e8ab5b11e7
Some checks failed
Build Docker Image / build (push) Has been cancelled
fix: stop deducting simulated cache tokens from input_tokens
Kiro backend does not support Anthropic prompt cache protocol.
The local cache tracker simulates cache hits/creation for Claude Code
compatibility, but subtracting those values from input_tokens caused
the reported input_tokens to drop to single digits.

input_tokens now reflects the real value; cache_creation_input_tokens
and cache_read_input_tokens are still reported for protocol compliance.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:14:51 +08:00

615 lines
16 KiB
Go

package proxy
import (
"bytes"
"crypto/sha256"
"encoding/json"
"sort"
"strconv"
"strings"
"sync"
"time"
)
const (
	// defaultPromptCacheTTL is the lifetime applied when a cache_control
	// marker names no TTL, and the fallback the tracker constructor uses
	// when it is given a non-positive maximum.
	defaultPromptCacheTTL = 5 * time.Minute

	// Anthropic only caches a prefix once it reaches a minimum token count.
	// Breakpoints shorter than the applicable threshold are neither matched
	// nor stored, which keeps short requests from being reported as
	// unrealistic 100% cache hits.
	//
	// defaultMinCacheableTokens applies to every model except those whose
	// name contains "opus", which use opusMinCacheableTokens.
	defaultMinCacheableTokens = 1024
	opusMinCacheableTokens    = 4096
)
// promptCacheUsage carries the simulated prompt-cache token counters that
// buildClaudeUsageMap copies into a Claude-style usage payload.
type promptCacheUsage struct {
// CacheCreationInputTokens is emitted as "cache_creation_input_tokens".
CacheCreationInputTokens int
// CacheReadInputTokens is emitted as "cache_read_input_tokens".
CacheReadInputTokens int
// CacheCreation5mInputTokens is emitted as "ephemeral_5m_input_tokens"
// inside the "cache_creation" object.
CacheCreation5mInputTokens int
// CacheCreation1hInputTokens is emitted as "ephemeral_1h_input_tokens"
// inside the "cache_creation" object.
CacheCreation1hInputTokens int
}
// promptCacheBreakpoint describes one cacheable prefix of a request, as
// recorded by BuildClaudeProfile.
type promptCacheBreakpoint struct {
// Fingerprint is the running SHA-256 digest over every canonicalized block
// from the start of the request up to and including the breakpoint block.
Fingerprint [32]byte
// CumulativeTokens is the summed token estimate of that same prefix.
CumulativeTokens int
// TTL is the cache lifetime attached to this breakpoint; Update uses it to
// set the stored entry's expiry.
TTL time.Duration
}
// promptCacheProfile is the per-request summary that Compute and Update
// operate on.
type promptCacheProfile struct {
// Breakpoints lists the request's cacheable prefixes in the order
// BuildClaudeProfile encountered them (earliest prefix first).
Breakpoints []promptCacheBreakpoint
// TotalInputTokens is the request's input token count. BuildClaudeProfile
// raises it to the summed block estimate when the caller's figure is lower.
TotalInputTokens int
// Model is the requested model name; it selects the minimum cacheable
// prefix size via minCacheableTokensForModel.
Model string
}
func minCacheableTokensForModel(model string) int {
lower := strings.ToLower(model)
if strings.Contains(lower, "opus") {
return opusMinCacheableTokens
}
return defaultMinCacheableTokens
}
// promptCacheEntry records one stored prefix fingerprint inside the tracker.
type promptCacheEntry struct {
// ExpiresAt is the instant at which the entry stops being usable;
// pruneExpiredLocked deletes entries whose ExpiresAt is not after "now".
ExpiresAt time.Time
// TTL is the lifetime the entry was stored with. Compute adds it to the
// current time to refresh ExpiresAt whenever the entry is matched.
TTL time.Duration
}
// promptCacheTracker remembers, per account, the prefix fingerprints of
// recent requests so that later requests sharing a prefix can be reported as
// cache reads.
type promptCacheTracker struct {
// mu guards entriesByAccount; Compute and Update hold it for their whole
// critical section.
mu sync.Mutex
// entriesByAccount maps account ID -> prefix fingerprint -> stored entry.
entriesByAccount map[string]map[[32]byte]promptCacheEntry
// maxSupportedTTL is set by newPromptCacheTracker.
// NOTE(review): nothing in the visible part of this file reads it; confirm
// whether it is consumed elsewhere or is vestigial.
maxSupportedTTL time.Duration
}
// newPromptCacheTracker returns an empty tracker. maxTTL is recorded as the
// tracker's maximum supported TTL; a zero or negative value selects
// defaultPromptCacheTTL instead.
func newPromptCacheTracker(maxTTL time.Duration) *promptCacheTracker {
	tracker := &promptCacheTracker{
		entriesByAccount: map[string]map[[32]byte]promptCacheEntry{},
		maxSupportedTTL:  defaultPromptCacheTTL,
	}
	if maxTTL > 0 {
		tracker.maxSupportedTTL = maxTTL
	}
	return tracker
}
// BuildClaudeProfile walks the request's flattened cache blocks while keeping
// a running SHA-256 over their canonical forms and a running token estimate,
// and records a breakpoint (digest so far + tokens so far) wherever caching
// applies.
//
// A block is a breakpoint when either:
//   - it carries its own TTL (an explicit cache_control marker), or
//   - it ends a message and some earlier block already carried a TTL; the
//     most recent explicit TTL is then inherited, so that multi-turn
//     conversations can match prefixes stored by earlier turns.
//
// It returns nil when the request yields no blocks or no breakpoints.
// totalInputTokens is raised to the summed block estimate when it is smaller.
// The receiver is not used.
func (t *promptCacheTracker) BuildClaudeProfile(req *ClaudeRequest, totalInputTokens int) *promptCacheProfile {
	blocks := flattenClaudeCacheBlocks(req)
	if len(blocks) == 0 {
		return nil
	}

	var (
		found     []promptCacheBreakpoint
		tokensSum int
		stickyTTL time.Duration // last explicit TTL seen so far
	)
	digest := sha256.New()

	for i := range blocks {
		blk := &blocks[i]
		writeHashChunk(digest, canonicalizeCacheValue(blk.Value))
		tokensSum += blk.Tokens

		var ttl time.Duration
		switch {
		case blk.TTL > 0:
			ttl = blk.TTL
			stickyTTL = blk.TTL
		case blk.IsMessageEnd && stickyTTL > 0:
			ttl = stickyTTL
		}
		if ttl <= 0 {
			continue
		}

		// Sum(nil) snapshots the digest without resetting it, so hashing
		// simply continues with the next block.
		bp := promptCacheBreakpoint{CumulativeTokens: tokensSum, TTL: ttl}
		copy(bp.Fingerprint[:], digest.Sum(nil))
		found = append(found, bp)
	}

	if len(found) == 0 {
		return nil
	}

	total := totalInputTokens
	if total < tokensSum {
		total = tokensSum
	}
	return &promptCacheProfile{
		Breakpoints:      found,
		TotalInputTokens: total,
		Model:            req.Model,
	}
}
// Compute reports the simulated prompt-cache usage for one request. It does
// not store the request's prefixes (Update does that), but it does prune
// expired entries and refresh the expiry of the entry it matches.
//
// The zero usage is returned when the tracker, profile or account ID is
// missing, when the profile has no breakpoints, or when the prefix up to the
// last breakpoint is smaller than the model's minimum cacheable size (such a
// prefix is never stored by Update, so reporting it as created would be
// misleading).
//
// Otherwise:
//   - if the account has no live entries, the whole prefix up to the last
//     breakpoint is reported as cache creation;
//   - if it has, the longest stored prefix that matches one of the request's
//     breakpoints is reported as a cache read, and the remainder up to the
//     last breakpoint, capped at 85% of the total input, as cache creation.
//
// The 5m/1h breakdown is computed against the same (possibly capped) upper
// bound as CacheCreationInputTokens, so for profiles whose breakpoints are in
// increasing token order the two breakdown fields sum to it.
func (t *promptCacheTracker) Compute(accountID string, profile *promptCacheProfile) promptCacheUsage {
	if t == nil || profile == nil || len(profile.Breakpoints) == 0 || accountID == "" {
		return promptCacheUsage{}
	}
	minTokens := minCacheableTokensForModel(profile.Model)
	last := profile.Breakpoints[len(profile.Breakpoints)-1]
	lastTokens := minInt(last.CumulativeTokens, profile.TotalInputTokens)
	now := time.Now()

	t.mu.Lock()
	defer t.mu.Unlock()
	t.pruneExpiredLocked(now)

	// Below the minimum cacheable size nothing is created or read. Returning
	// the zero value keeps the TTL breakdown consistent with the zero
	// creation count on both the first-request and the follow-up path.
	if lastTokens < minTokens {
		return promptCacheUsage{}
	}

	entries := t.entriesByAccount[accountID]
	if len(entries) == 0 {
		// First request for this account: everything up to the last
		// breakpoint counts as newly created.
		cache5m, cache1h := computePromptCacheTTLBreakdown(profile, 0)
		return promptCacheUsage{
			CacheCreationInputTokens:   lastTokens,
			CacheReadInputTokens:       0,
			CacheCreation5mInputTokens: cache5m,
			CacheCreation1hInputTokens: cache1h,
		}
	}

	// Cap cacheable tokens at 85% of total input to ensure a realistic
	// uncached portion. The newest content in a request is never fully
	// served from cache on the current turn.
	maxCacheable := int(float64(profile.TotalInputTokens) * 0.85)
	if lastTokens > maxCacheable {
		lastTokens = maxCacheable
	}

	// Search from the longest prefix down so the largest stored match wins.
	matchedTokens := 0
	for i := len(profile.Breakpoints) - 1; i >= 0; i-- {
		breakpoint := profile.Breakpoints[i]
		// Skip breakpoints below the minimum cacheable token threshold.
		if breakpoint.CumulativeTokens < minTokens {
			continue
		}
		entry, ok := entries[breakpoint.Fingerprint]
		if !ok || entry.ExpiresAt.Before(now) {
			continue
		}
		// A hit extends the entry's lifetime by its original TTL.
		entry.ExpiresAt = now.Add(entry.TTL)
		entries[breakpoint.Fingerprint] = entry
		matchedTokens = minInt(breakpoint.CumulativeTokens, profile.TotalInputTokens)
		if matchedTokens > lastTokens {
			matchedTokens = lastTokens
		}
		break
	}

	creation := maxInt(lastTokens-matchedTokens, 0)

	// Run the breakdown on a shallow copy whose total is the capped upper
	// bound; otherwise the breakdown would count tokens beyond the 85% cap
	// and exceed the creation figure reported alongside it.
	capped := *profile
	capped.TotalInputTokens = lastTokens
	cache5m, cache1h := computePromptCacheTTLBreakdown(&capped, matchedTokens)

	return promptCacheUsage{
		CacheCreationInputTokens:   creation,
		CacheReadInputTokens:       matchedTokens,
		CacheCreation5mInputTokens: cache5m,
		CacheCreation1hInputTokens: cache1h,
	}
}
// Update stores the request's breakpoints for the account so that later
// requests can match them. Each stored entry expires one TTL from now;
// re-storing an existing fingerprint overwrites it. Breakpoints whose prefix
// is shorter than the model's minimum cacheable size are not stored.
//
// It is a no-op when the tracker, profile or account ID is missing or the
// profile has no breakpoints. Expired entries are pruned as a side effect.
func (t *promptCacheTracker) Update(accountID string, profile *promptCacheProfile) {
	if t == nil || profile == nil {
		return
	}
	if len(profile.Breakpoints) == 0 || accountID == "" {
		return
	}

	threshold := minCacheableTokensForModel(profile.Model)
	stamp := time.Now()

	t.mu.Lock()
	defer t.mu.Unlock()

	t.pruneExpiredLocked(stamp)

	bucket := t.entriesByAccount[accountID]
	if bucket == nil {
		bucket = map[[32]byte]promptCacheEntry{}
		t.entriesByAccount[accountID] = bucket
	}

	for i := range profile.Breakpoints {
		bp := profile.Breakpoints[i]
		if bp.CumulativeTokens >= threshold {
			bucket[bp.Fingerprint] = promptCacheEntry{
				ExpiresAt: stamp.Add(bp.TTL),
				TTL:       bp.TTL,
			}
		}
	}
}
// pruneExpiredLocked deletes every entry whose expiry is at or before now,
// then drops any account left with no entries. The visible callers (Compute
// and Update) invoke it while holding t.mu; the method itself does not lock.
func (t *promptCacheTracker) pruneExpiredLocked(now time.Time) {
	for account, bucket := range t.entriesByAccount {
		for key := range bucket {
			if bucket[key].ExpiresAt.After(now) {
				continue // still live
			}
			delete(bucket, key)
		}
		if len(bucket) > 0 {
			continue
		}
		delete(t.entriesByAccount, account)
	}
}
// cacheablePromptBlock is one unit of request content fed, in order, into
// the running prefix hash by BuildClaudeProfile.
type cacheablePromptBlock struct {
// Value is the structure that is canonicalized and hashed.
Value interface{}
// Tokens is the token estimate computed from Value's canonical form.
Tokens int
// TTL is the normalized TTL taken from the block's cache_control marker;
// zero means the block carries no usable marker.
TTL time.Duration
// IsMessageEnd reports whether the block is the last block of a message.
// System and tool blocks are always false.
IsMessageEnd bool
}
// flattenClaudeCacheBlocks linearizes a request into the ordered block list
// that the prefix hash is computed over: first the request prelude, then one
// block per tool, then the system prompt blocks, then every message's blocks.
// The result always contains at least the prelude block.
func flattenClaudeCacheBlocks(req *ClaudeRequest) []cacheablePromptBlock {
	flat := []cacheablePromptBlock{buildCachePreludeBlock(req)}

	for toolIndex, tool := range req.Tools {
		described := map[string]interface{}{
			"kind":         "tool",
			"tool_index":   toolIndex,
			"name":         tool.Name,
			"description":  tool.Description,
			"input_schema": tool.InputSchema,
		}
		var entry cacheablePromptBlock
		entry.Value = described
		entry.Tokens = estimateApproxTokens(canonicalizeCacheValue(described))
		// The TTL is read from the tool definition itself, not from the
		// reduced map above, which does not copy cache_control.
		entry.TTL = normalizePromptCacheTTL(extractPromptCacheTTL(tool))
		flat = append(flat, entry)
	}

	appendSystemCacheBlocks(&flat, req.System)

	for messageIndex, msg := range req.Messages {
		appendMessageCacheBlocks(&flat, messageIndex, msg)
	}
	return flat
}
// buildCachePreludeBlock returns the first block of every flattened request.
// It folds the model name and tool_choice into the prefix hash, so requests
// that differ in either never share a fingerprint. The block has no TTL and
// is not a message end.
func buildCachePreludeBlock(req *ClaudeRequest) cacheablePromptBlock {
	var block cacheablePromptBlock
	block.Value = map[string]interface{}{
		"kind":        "request_prelude",
		"model":       req.Model,
		"tool_choice": req.ToolChoice,
	}
	block.Tokens = estimateApproxTokens(canonicalizeCacheValue(block.Value))
	return block
}
// appendSystemCacheBlocks appends one block per system prompt part.
//
// Accepted shapes of system:
//   - string: a single text block at index 0;
//   - []interface{}: each element is used as the block as-is;
//   - []string: each element becomes a text block.
//
// Any other type (including nil) appends nothing. System blocks are never
// flagged as message ends.
func appendSystemCacheBlocks(blocks *[]cacheablePromptBlock, system interface{}) {
	add := func(index int, block interface{}) {
		appendPromptBlock(blocks, map[string]interface{}{
			"kind":         "system",
			"system_index": index,
			"block":        block,
		}, false)
	}
	asText := func(text string) map[string]interface{} {
		return map[string]interface{}{
			"type": "text",
			"text": text,
		}
	}

	switch parts := system.(type) {
	case string:
		add(0, asText(parts))
	case []interface{}:
		for i, part := range parts {
			add(i, part)
		}
	case []string:
		for i, part := range parts {
			add(i, asText(part))
		}
	}
}
// appendMessageCacheBlocks appends the blocks of a single message.
//
// Accepted shapes of msg.Content:
//   - string: one text block, flagged as the message end;
//   - []interface{}: one block per element, only the last flagged as the
//     message end (an empty slice appends nothing);
//   - nil: nothing is appended;
//   - anything else: the value is used as a single block, flagged as the
//     message end.
func appendMessageCacheBlocks(blocks *[]cacheablePromptBlock, messageIndex int, msg ClaudeMessage) {
	emit := func(blockIndex int, block interface{}, endsMessage bool) {
		appendPromptBlock(blocks, map[string]interface{}{
			"kind":          "message",
			"message_index": messageIndex,
			"role":          msg.Role,
			"block_index":   blockIndex,
			"block":         block,
		}, endsMessage)
	}

	switch content := msg.Content.(type) {
	case nil:
		// No content: contributes nothing to the prefix.
	case string:
		emit(0, map[string]interface{}{
			"type": "text",
			"text": content,
		}, true)
	case []interface{}:
		final := len(content) - 1
		for idx, item := range content {
			emit(idx, item, idx == final)
		}
	default:
		emit(0, content, true)
	}
}
// appendPromptBlock turns a wrapper map (whose "block" key holds the actual
// content) into a cacheablePromptBlock and appends it to *blocks.
//
// The TTL is read from the original content. If the content is volatile text
// that normalizeCacheBlockContent rewrites (e.g. Claude Code's
// x-anthropic-billing-header, which drifts on every request), the stored and
// hashed value is a copy of the wrapper holding the normalized content, so
// that fingerprints stay stable within a conversation. The caller's wrapper
// map is never modified.
func appendPromptBlock(blocks *[]cacheablePromptBlock, wrapper map[string]interface{}, isMessageEnd bool) {
	raw := wrapper["block"]
	entry := cacheablePromptBlock{
		TTL:          normalizePromptCacheTTL(extractPromptCacheTTL(raw)),
		IsMessageEnd: isMessageEnd,
	}

	stable := wrapper
	if replacement, changed := normalizeCacheBlockContent(raw); changed {
		stable = make(map[string]interface{}, len(wrapper))
		for key, val := range wrapper {
			stable[key] = val
		}
		stable["block"] = replacement
	}

	entry.Value = stable
	entry.Tokens = estimateApproxTokens(canonicalizeCacheValue(stable))
	*blocks = append(*blocks, entry)
}
// normalizeCacheBlockContent swaps volatile but semantically irrelevant text
// for a fixed placeholder so the cumulative fingerprint does not change
// between requests of the same session. The only case handled today is
// Claude Code's "x-anthropic-billing-header: ..." text block, whose content
// (version, telemetry hash, etc.) drifts on every request.
//
// A value is rewritten only if it is a map[string]interface{} that
//   - has no "type", an empty or non-string "type", or "type" == "text", and
//   - has a string "text" which, after leading spaces, tabs, CRs and LFs are
//     skipped, starts with that header name (compared case-insensitively).
//
// On a rewrite it returns a shallow copy of the map with "text" replaced and
// true; the input map is left untouched. Otherwise it returns the input
// unchanged and false.
func normalizeCacheBlockContent(value interface{}) (interface{}, bool) {
	const (
		volatilePrefix = "x-anthropic-billing-header:"
		placeholder    = "__anthropic_billing_header__"
	)

	fields, isMap := value.(map[string]interface{})
	if !isMap {
		return value, false
	}

	// A non-string "type" yields kind == "" and is treated like a missing one.
	if kind, _ := fields["type"].(string); kind != "" && kind != "text" {
		return value, false
	}

	body, isString := fields["text"].(string)
	if !isString {
		return value, false
	}

	head := strings.ToLower(strings.TrimLeft(body, " \t\r\n"))
	if !strings.HasPrefix(head, volatilePrefix) {
		return value, false
	}

	rewritten := make(map[string]interface{}, len(fields))
	for key, val := range fields {
		rewritten[key] = val
	}
	rewritten["text"] = placeholder
	return rewritten, true
}
// extractPromptCacheTTL returns the TTL requested by value's cache_control
// marker, or 0 when there is none.
//
// value is used directly when it is a map[string]interface{}; any other
// value is round-tripped through JSON to obtain such a map, and a failure of
// either step yields 0. The marker counts only if "cache_control" is an
// object whose "type" equals "ephemeral" (case-insensitively). Its "ttl" is
// then parsed with parsePromptCacheTTLValue; if that does not produce a
// value, defaultPromptCacheTTL is returned.
func extractPromptCacheTTL(value interface{}) time.Duration {
	fields, isMap := value.(map[string]interface{})
	if !isMap {
		encoded, err := json.Marshal(value)
		if err != nil {
			return 0
		}
		var decoded map[string]interface{}
		if err := json.Unmarshal(encoded, &decoded); err != nil {
			return 0
		}
		fields = decoded
	}

	// A missing key yields a nil interface, which fails the assertion too.
	control, isObject := fields["cache_control"].(map[string]interface{})
	if !isObject {
		return 0
	}

	kind, _ := control["type"].(string)
	if !strings.EqualFold(kind, "ephemeral") {
		return 0
	}

	ttl, parsed := parsePromptCacheTTLValue(control["ttl"])
	if !parsed {
		return defaultPromptCacheTTL
	}
	return ttl
}
// parsePromptCacheTTLValue interprets a cache_control "ttl" value.
//
// Accepted forms:
//   - string: a Go duration ("5m", "1h"; case and surrounding whitespace are
//     ignored) or a bare integer number of seconds ("300");
//   - float64, int, int64: a number of seconds (fractions are honoured for
//     float64).
//
// It returns (ttl, true) only for a strictly positive TTL. Everything else
// (nil, other types, empty or unparsable strings, zero, negative values and
// NaN) yields (0, false), so callers can fall back to their default. Second
// counts too large for a time.Duration are clamped to the largest
// representable whole number of seconds instead of overflowing.
func parsePromptCacheTTLValue(value interface{}) (time.Duration, bool) {
	// Largest whole number of seconds that fits in a time.Duration.
	const maxSeconds = int64(1<<63-1) / int64(time.Second)

	fromSeconds := func(seconds int64) (time.Duration, bool) {
		if seconds <= 0 {
			return 0, false
		}
		if seconds > maxSeconds {
			seconds = maxSeconds
		}
		return time.Duration(seconds) * time.Second, true
	}

	switch v := value.(type) {
	case string:
		trimmed := strings.TrimSpace(strings.ToLower(v))
		if trimmed == "" {
			return 0, false
		}
		if d, err := time.ParseDuration(trimmed); err == nil {
			// Match the numeric branches: "0s" or "-5m" is not a usable TTL.
			if d <= 0 {
				return 0, false
			}
			return d, true
		}
		if seconds, err := strconv.Atoi(trimmed); err == nil {
			return fromSeconds(int64(seconds))
		}
	case float64:
		// Written as !(v > 0) so that NaN is rejected as well.
		if !(v > 0) {
			return 0, false
		}
		if v >= float64(maxSeconds) {
			return time.Duration(maxSeconds) * time.Second, true
		}
		// Scale before converting so fractional seconds are not truncated.
		d := time.Duration(v * float64(time.Second))
		if d <= 0 {
			return 0, false
		}
		return d, true
	case int:
		return fromSeconds(int64(v))
	case int64:
		return fromSeconds(v)
	}
	return 0, false
}
// normalizePromptCacheTTL snaps a requested TTL onto one of the two supported
// lifetimes: anything up to defaultPromptCacheTTL becomes
// defaultPromptCacheTTL, anything longer becomes one hour. Zero and negative
// inputs map to 0.
func normalizePromptCacheTTL(ttl time.Duration) time.Duration {
	switch {
	case ttl <= 0:
		return 0
	case ttl <= defaultPromptCacheTTL:
		return defaultPromptCacheTTL
	default:
		return time.Hour
	}
}
// computePromptCacheTTLBreakdown splits the tokens lying beyond matchedTokens
// into a 5-minute and a 1-hour bucket, returned in that order.
//
// Breakpoints are visited in order. Each one that reaches further than the
// tokens already accounted for (its cumulative count, clamped to the
// profile's total input) contributes the difference to the bucket chosen by
// its own TTL: one hour or longer goes to the 1h bucket, anything shorter to
// the 5m bucket. A nil profile or one without breakpoints yields (0, 0).
func computePromptCacheTTLBreakdown(profile *promptCacheProfile, matchedTokens int) (int, int) {
	if profile == nil {
		return 0, 0
	}

	var fiveMinute, oneHour int
	covered := matchedTokens
	for i := range profile.Breakpoints {
		bp := &profile.Breakpoints[i]

		reach := bp.CumulativeTokens
		if profile.TotalInputTokens <= reach {
			reach = profile.TotalInputTokens
		}
		if reach <= covered {
			continue
		}

		if bp.TTL >= time.Hour {
			oneHour += reach - covered
		} else {
			fiveMinute += reach - covered
		}
		covered = reach
	}
	return fiveMinute, oneHour
}
// buildClaudeUsageMap assembles a Claude-style "usage" object.
//
// "input_tokens" and "output_tokens" are always present and carry the given
// values unchanged. When includeCache is true the cache fields from usage
// are added as well: "cache_creation_input_tokens",
// "cache_read_input_tokens" and a nested "cache_creation" map[string]int
// holding the 5m/1h split.
func buildClaudeUsageMap(inputTokens, outputTokens int, usage promptCacheUsage, includeCache bool) map[string]interface{} {
	payload := map[string]interface{}{
		"input_tokens":  inputTokens,
		"output_tokens": outputTokens,
	}
	if includeCache {
		payload["cache_creation_input_tokens"] = usage.CacheCreationInputTokens
		payload["cache_read_input_tokens"] = usage.CacheReadInputTokens
		payload["cache_creation"] = map[string]int{
			"ephemeral_5m_input_tokens": usage.CacheCreation5mInputTokens,
			"ephemeral_1h_input_tokens": usage.CacheCreation1hInputTokens,
		}
	}
	return payload
}
// canonicalizeCacheValue returns the canonical JSON text that
// writeCanonicalJSON produces for value.
func canonicalizeCacheValue(value interface{}) string {
	out := new(bytes.Buffer)
	writeCanonicalJSON(out, value)
	return out.String()
}
// writeCanonicalJSON appends a deterministic JSON rendering of value to buf.
//
// map[string]interface{} values are written with their keys sorted and with
// any "cache_control" key omitted, at every nesting level reached through
// such maps and []interface{} slices. Everything else (strings, numbers,
// json.Number and any other type) is rendered by json.Marshal; if marshalling
// fails, nothing is written for that value.
func writeCanonicalJSON(buf *bytes.Buffer, value interface{}) {
	switch typed := value.(type) {
	case nil:
		buf.WriteString("null")

	case bool:
		literal := "false"
		if typed {
			literal = "true"
		}
		buf.WriteString(literal)

	case []interface{}:
		buf.WriteByte('[')
		for idx := range typed {
			if idx != 0 {
				buf.WriteByte(',')
			}
			writeCanonicalJSON(buf, typed[idx])
		}
		buf.WriteByte(']')

	case map[string]interface{}:
		names := make([]string, 0, len(typed))
		for name := range typed {
			if name != "cache_control" {
				names = append(names, name)
			}
		}
		sort.Strings(names)

		buf.WriteByte('{')
		separator := ""
		for _, name := range names {
			buf.WriteString(separator)
			separator = ","

			label, _ := json.Marshal(name)
			buf.Write(label)
			buf.WriteByte(':')
			writeCanonicalJSON(buf, typed[name])
		}
		buf.WriteByte('}')

	default:
		// Scalars and any type not handled above go straight to
		// json.Marshal; a marshalling error leaves buf untouched.
		raw, _ := json.Marshal(typed)
		buf.Write(raw)
	}
}
// writeHashChunk feeds one chunk into hasher using a framed encoding: the
// chunk's byte length in decimal, a NUL byte, the chunk itself, and a
// trailing NUL byte. The framing keeps adjacent chunks from running together
// in the hashed stream. Errors returned by Write are ignored.
func writeHashChunk(hasher hashWriter, chunk string) {
	frame := [][]byte{
		[]byte(strconv.Itoa(len(chunk))),
		{0},
		[]byte(chunk),
		{0},
	}
	for _, part := range frame {
		hasher.Write(part)
	}
}
// hashWriter is the minimal hashing surface writeHashChunk and its caller
// rely on. BuildClaudeProfile passes the value returned by sha256.New.
type hashWriter interface {
// Write feeds more bytes into the hash.
Write([]byte) (int, error)
// Sum appends the current digest to its argument and returns the result.
// BuildClaudeProfile relies on being able to keep writing to the hash after
// calling it.
Sum([]byte) []byte
}
// minInt returns the smaller of a and b.
func minInt(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
// maxInt returns the larger of a and b.
func maxInt(a, b int) int {
	if b >= a {
		return b
	}
	return a
}