Some checks failed
Build Docker Image / build (push) Has been cancelled
Kiro backend does not support Anthropic prompt cache protocol. The local cache tracker simulates cache hits/creation for Claude Code compatibility, but subtracting those values from input_tokens caused the reported input_tokens to drop to single digits. input_tokens now reflects the real value; cache_creation_input_tokens and cache_read_input_tokens are still reported for protocol compliance. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
615 lines
16 KiB
Go
615 lines
16 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/sha256"
|
|
"encoding/json"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Tuning constants for the simulated prompt cache.
const (
	// defaultPromptCacheTTL is the lifetime used for an ephemeral breakpoint
	// whose cache_control carries no usable ttl, and the fallback for a
	// tracker created with a non-positive maximum TTL.
	defaultPromptCacheTTL = 5 * time.Minute

	// Anthropic requires cached prefixes to reach a minimum token count before
	// caching takes effect. Breakpoints below this threshold are excluded from
	// matching and storage to avoid reporting unrealistic 100% cache hits on
	// short requests.
	defaultMinCacheableTokens = 1024

	// opusMinCacheableTokens is the threshold used when the model name
	// contains "opus".
	opusMinCacheableTokens = 4096
)
|
|
|
|
// promptCacheUsage carries the simulated cache counters that are reported in
// a Claude-style usage object.
type promptCacheUsage struct {
	// CacheCreationInputTokens is the number of tokens reported as newly
	// written to the cache.
	CacheCreationInputTokens int
	// CacheReadInputTokens is the number of tokens reported as served from a
	// previously stored prefix.
	CacheReadInputTokens int
	// CacheCreation5mInputTokens is the part of the creation attributed to
	// breakpoints with a TTL shorter than one hour.
	CacheCreation5mInputTokens int
	// CacheCreation1hInputTokens is the part of the creation attributed to
	// breakpoints with a TTL of one hour or more.
	CacheCreation1hInputTokens int
}
|
|
|
|
// promptCacheBreakpoint describes one cacheable prefix of a request.
type promptCacheBreakpoint struct {
	// Fingerprint is the SHA-256 digest of every block up to and including
	// the breakpoint.
	Fingerprint [32]byte
	// CumulativeTokens is the estimated token count of that same prefix.
	CumulativeTokens int
	// TTL is the cache lifetime that applies to the prefix.
	TTL time.Duration
}
|
|
|
|
type promptCacheProfile struct {
|
|
Breakpoints []promptCacheBreakpoint
|
|
TotalInputTokens int
|
|
Model string
|
|
}
|
|
|
|
func minCacheableTokensForModel(model string) int {
|
|
lower := strings.ToLower(model)
|
|
if strings.Contains(lower, "opus") {
|
|
return opusMinCacheableTokens
|
|
}
|
|
return defaultMinCacheableTokens
|
|
}
|
|
|
|
// promptCacheEntry is the stored state of one cached prefix fingerprint.
type promptCacheEntry struct {
	// ExpiresAt is the instant at which the entry stops being usable.
	ExpiresAt time.Time
	// TTL is the lifetime re-applied to ExpiresAt whenever the entry is hit.
	TTL time.Duration
}
|
|
|
|
type promptCacheTracker struct {
|
|
mu sync.Mutex
|
|
entriesByAccount map[string]map[[32]byte]promptCacheEntry
|
|
maxSupportedTTL time.Duration
|
|
}
|
|
|
|
func newPromptCacheTracker(maxTTL time.Duration) *promptCacheTracker {
|
|
if maxTTL <= 0 {
|
|
maxTTL = defaultPromptCacheTTL
|
|
}
|
|
return &promptCacheTracker{
|
|
entriesByAccount: make(map[string]map[[32]byte]promptCacheEntry),
|
|
maxSupportedTTL: maxTTL,
|
|
}
|
|
}
|
|
|
|
func (t *promptCacheTracker) BuildClaudeProfile(req *ClaudeRequest, totalInputTokens int) *promptCacheProfile {
|
|
blocks := flattenClaudeCacheBlocks(req)
|
|
if len(blocks) == 0 {
|
|
return nil
|
|
}
|
|
|
|
hasher := sha256.New()
|
|
breakpoints := make([]promptCacheBreakpoint, 0)
|
|
cumulativeTokens := 0
|
|
var activeTTL time.Duration
|
|
|
|
for _, block := range blocks {
|
|
canonical := canonicalizeCacheValue(block.Value)
|
|
writeHashChunk(hasher, canonical)
|
|
cumulativeTokens += block.Tokens
|
|
|
|
// Determine whether this block acts as a cache breakpoint:
|
|
// 1) Explicit cache_control on the block itself.
|
|
// 2) Once any explicit breakpoint has been seen, every message-end
|
|
// boundary becomes an implicit breakpoint so that multi-turn
|
|
// conversations can hit earlier stored prefixes.
|
|
breakpointTTL := time.Duration(0)
|
|
if block.TTL > 0 {
|
|
breakpointTTL = block.TTL
|
|
activeTTL = block.TTL
|
|
} else if block.IsMessageEnd && activeTTL > 0 {
|
|
breakpointTTL = activeTTL
|
|
}
|
|
|
|
if breakpointTTL <= 0 {
|
|
continue
|
|
}
|
|
|
|
var fingerprint [32]byte
|
|
copy(fingerprint[:], hasher.Sum(nil))
|
|
breakpoints = append(breakpoints, promptCacheBreakpoint{
|
|
Fingerprint: fingerprint,
|
|
CumulativeTokens: cumulativeTokens,
|
|
TTL: breakpointTTL,
|
|
})
|
|
}
|
|
|
|
if len(breakpoints) == 0 {
|
|
return nil
|
|
}
|
|
|
|
if totalInputTokens < cumulativeTokens {
|
|
totalInputTokens = cumulativeTokens
|
|
}
|
|
|
|
return &promptCacheProfile{
|
|
Breakpoints: breakpoints,
|
|
TotalInputTokens: totalInputTokens,
|
|
Model: req.Model,
|
|
}
|
|
}
|
|
|
|
func (t *promptCacheTracker) Compute(accountID string, profile *promptCacheProfile) promptCacheUsage {
|
|
if t == nil || profile == nil || len(profile.Breakpoints) == 0 || accountID == "" {
|
|
return promptCacheUsage{}
|
|
}
|
|
|
|
minTokens := minCacheableTokensForModel(profile.Model)
|
|
last := profile.Breakpoints[len(profile.Breakpoints)-1]
|
|
lastTokens := minInt(last.CumulativeTokens, profile.TotalInputTokens)
|
|
now := time.Now()
|
|
|
|
t.mu.Lock()
|
|
defer t.mu.Unlock()
|
|
t.pruneExpiredLocked(now)
|
|
|
|
entries := t.entriesByAccount[accountID]
|
|
if len(entries) == 0 {
|
|
// First request for this account: report creation only if above threshold.
|
|
effectiveCreation := lastTokens
|
|
if effectiveCreation < minTokens {
|
|
effectiveCreation = 0
|
|
}
|
|
cache5m, cache1h := computePromptCacheTTLBreakdown(profile, 0)
|
|
return promptCacheUsage{
|
|
CacheCreationInputTokens: effectiveCreation,
|
|
CacheReadInputTokens: 0,
|
|
CacheCreation5mInputTokens: cache5m,
|
|
CacheCreation1hInputTokens: cache1h,
|
|
}
|
|
}
|
|
|
|
// Cap cacheable tokens at 85% of total input to ensure a realistic
|
|
// uncached portion. The newest content in a request is never fully
|
|
// served from cache on the current turn.
|
|
maxCacheable := int(float64(profile.TotalInputTokens) * 0.85)
|
|
if lastTokens > maxCacheable {
|
|
lastTokens = maxCacheable
|
|
}
|
|
|
|
matchedTokens := 0
|
|
for i := len(profile.Breakpoints) - 1; i >= 0; i-- {
|
|
breakpoint := profile.Breakpoints[i]
|
|
// Skip breakpoints below the minimum cacheable token threshold.
|
|
if breakpoint.CumulativeTokens < minTokens {
|
|
continue
|
|
}
|
|
entry, ok := entries[breakpoint.Fingerprint]
|
|
if !ok || entry.ExpiresAt.Before(now) {
|
|
continue
|
|
}
|
|
entry.ExpiresAt = now.Add(entry.TTL)
|
|
entries[breakpoint.Fingerprint] = entry
|
|
matchedTokens = minInt(breakpoint.CumulativeTokens, profile.TotalInputTokens)
|
|
if matchedTokens > lastTokens {
|
|
matchedTokens = lastTokens
|
|
}
|
|
break
|
|
}
|
|
|
|
creation := maxInt(lastTokens-matchedTokens, 0)
|
|
cache5m, cache1h := computePromptCacheTTLBreakdown(profile, matchedTokens)
|
|
return promptCacheUsage{
|
|
CacheCreationInputTokens: creation,
|
|
CacheReadInputTokens: matchedTokens,
|
|
CacheCreation5mInputTokens: cache5m,
|
|
CacheCreation1hInputTokens: cache1h,
|
|
}
|
|
}
|
|
|
|
func (t *promptCacheTracker) Update(accountID string, profile *promptCacheProfile) {
|
|
if t == nil || profile == nil || len(profile.Breakpoints) == 0 || accountID == "" {
|
|
return
|
|
}
|
|
|
|
minTokens := minCacheableTokensForModel(profile.Model)
|
|
now := time.Now()
|
|
t.mu.Lock()
|
|
defer t.mu.Unlock()
|
|
t.pruneExpiredLocked(now)
|
|
|
|
entries := t.entriesByAccount[accountID]
|
|
if entries == nil {
|
|
entries = make(map[[32]byte]promptCacheEntry)
|
|
t.entriesByAccount[accountID] = entries
|
|
}
|
|
|
|
for _, breakpoint := range profile.Breakpoints {
|
|
// Skip breakpoints below the minimum cacheable token threshold.
|
|
if breakpoint.CumulativeTokens < minTokens {
|
|
continue
|
|
}
|
|
entries[breakpoint.Fingerprint] = promptCacheEntry{
|
|
ExpiresAt: now.Add(breakpoint.TTL),
|
|
TTL: breakpoint.TTL,
|
|
}
|
|
}
|
|
}
|
|
|
|
func (t *promptCacheTracker) pruneExpiredLocked(now time.Time) {
|
|
for accountID, entries := range t.entriesByAccount {
|
|
for fingerprint, entry := range entries {
|
|
if !entry.ExpiresAt.After(now) {
|
|
delete(entries, fingerprint)
|
|
}
|
|
}
|
|
if len(entries) == 0 {
|
|
delete(t.entriesByAccount, accountID)
|
|
}
|
|
}
|
|
}
|
|
|
|
// cacheablePromptBlock is one unit of the flattened request that feeds the
// running prefix hash.
type cacheablePromptBlock struct {
	// Value is the wrapper that gets canonicalized and hashed.
	Value interface{}
	// Tokens is the estimated token count of the canonical form of Value.
	Tokens int
	// TTL is the normalized cache_control TTL of the block; zero means the
	// block carries no explicit breakpoint.
	TTL time.Duration
	// IsMessageEnd marks the last block of a message.
	IsMessageEnd bool
}
|
|
|
|
func flattenClaudeCacheBlocks(req *ClaudeRequest) []cacheablePromptBlock {
|
|
blocks := make([]cacheablePromptBlock, 0)
|
|
blocks = append(blocks, buildCachePreludeBlock(req))
|
|
|
|
for toolIndex, tool := range req.Tools {
|
|
toolValue := map[string]interface{}{
|
|
"kind": "tool",
|
|
"tool_index": toolIndex,
|
|
"name": tool.Name,
|
|
"description": tool.Description,
|
|
"input_schema": tool.InputSchema,
|
|
}
|
|
blocks = append(blocks, cacheablePromptBlock{
|
|
Value: toolValue,
|
|
Tokens: estimateApproxTokens(canonicalizeCacheValue(toolValue)),
|
|
TTL: normalizePromptCacheTTL(extractPromptCacheTTL(tool)),
|
|
})
|
|
}
|
|
|
|
appendSystemCacheBlocks(&blocks, req.System)
|
|
|
|
for messageIndex, msg := range req.Messages {
|
|
appendMessageCacheBlocks(&blocks, messageIndex, msg)
|
|
}
|
|
|
|
return blocks
|
|
}
|
|
|
|
func buildCachePreludeBlock(req *ClaudeRequest) cacheablePromptBlock {
|
|
prelude := map[string]interface{}{
|
|
"kind": "request_prelude",
|
|
"model": req.Model,
|
|
"tool_choice": req.ToolChoice,
|
|
}
|
|
return cacheablePromptBlock{
|
|
Value: prelude,
|
|
Tokens: estimateApproxTokens(canonicalizeCacheValue(prelude)),
|
|
}
|
|
}
|
|
|
|
func appendSystemCacheBlocks(blocks *[]cacheablePromptBlock, system interface{}) {
|
|
switch v := system.(type) {
|
|
case string:
|
|
appendPromptBlock(blocks, map[string]interface{}{
|
|
"kind": "system",
|
|
"system_index": 0,
|
|
"block": map[string]interface{}{
|
|
"type": "text",
|
|
"text": v,
|
|
},
|
|
}, false)
|
|
case []interface{}:
|
|
for i, block := range v {
|
|
appendPromptBlock(blocks, map[string]interface{}{
|
|
"kind": "system",
|
|
"system_index": i,
|
|
"block": block,
|
|
}, false)
|
|
}
|
|
case []string:
|
|
for i, block := range v {
|
|
appendPromptBlock(blocks, map[string]interface{}{
|
|
"kind": "system",
|
|
"system_index": i,
|
|
"block": map[string]interface{}{
|
|
"type": "text",
|
|
"text": block,
|
|
},
|
|
}, false)
|
|
}
|
|
}
|
|
}
|
|
|
|
func appendMessageCacheBlocks(blocks *[]cacheablePromptBlock, messageIndex int, msg ClaudeMessage) {
|
|
role := msg.Role
|
|
switch content := msg.Content.(type) {
|
|
case string:
|
|
appendPromptBlock(blocks, map[string]interface{}{
|
|
"kind": "message",
|
|
"message_index": messageIndex,
|
|
"role": role,
|
|
"block_index": 0,
|
|
"block": map[string]interface{}{
|
|
"type": "text",
|
|
"text": content,
|
|
},
|
|
}, true)
|
|
case []interface{}:
|
|
lastIdx := len(content) - 1
|
|
for blockIndex, block := range content {
|
|
appendPromptBlock(blocks, map[string]interface{}{
|
|
"kind": "message",
|
|
"message_index": messageIndex,
|
|
"role": role,
|
|
"block_index": blockIndex,
|
|
"block": block,
|
|
}, blockIndex == lastIdx)
|
|
}
|
|
default:
|
|
if content != nil {
|
|
appendPromptBlock(blocks, map[string]interface{}{
|
|
"kind": "message",
|
|
"message_index": messageIndex,
|
|
"role": role,
|
|
"block_index": 0,
|
|
"block": content,
|
|
}, true)
|
|
}
|
|
}
|
|
}
|
|
|
|
func appendPromptBlock(blocks *[]cacheablePromptBlock, wrapper map[string]interface{}, isMessageEnd bool) {
|
|
blockValue := wrapper["block"]
|
|
ttl := normalizePromptCacheTTL(extractPromptCacheTTL(blockValue))
|
|
|
|
// Normalize volatile text (e.g. Claude Code's x-anthropic-billing-header
|
|
// which drifts on every request) so that fingerprints remain stable across
|
|
// requests within the same conversation.
|
|
if normalized, changed := normalizeCacheBlockContent(blockValue); changed {
|
|
cloned := make(map[string]interface{}, len(wrapper))
|
|
for k, v := range wrapper {
|
|
cloned[k] = v
|
|
}
|
|
cloned["block"] = normalized
|
|
wrapper = cloned
|
|
}
|
|
|
|
canonical := canonicalizeCacheValue(wrapper)
|
|
*blocks = append(*blocks, cacheablePromptBlock{
|
|
Value: wrapper,
|
|
Tokens: estimateApproxTokens(canonical),
|
|
TTL: ttl,
|
|
IsMessageEnd: isMessageEnd,
|
|
})
|
|
}
|
|
|
|
// normalizeCacheBlockContent swaps volatile but semantically irrelevant text
// for a fixed placeholder so the cumulative fingerprint stays stable across
// requests of the same session. At present it handles one case: a text block
// (type "text", or no type at all) whose text, after leading whitespace,
// starts with "x-anthropic-billing-header:" in any letter case.
//
// It returns a shallow copy of the block with the text replaced and true, or
// the original value and false when nothing needs to change. The input map is
// never modified.
func normalizeCacheBlockContent(value interface{}) (interface{}, bool) {
	const (
		billingPrefix = "x-anthropic-billing-header:"
		placeholder   = "__anthropic_billing_header__"
	)

	fields, isMap := value.(map[string]interface{})
	if !isMap {
		return value, false
	}

	// A missing, empty or non-string type is treated like a text block.
	if kind, _ := fields["type"].(string); kind != "" && kind != "text" {
		return value, false
	}

	body, hasText := fields["text"].(string)
	if !hasText {
		return value, false
	}

	head := strings.ToLower(strings.TrimLeft(body, " \t\r\n"))
	if !strings.HasPrefix(head, billingPrefix) {
		return value, false
	}

	replaced := make(map[string]interface{}, len(fields))
	for key, val := range fields {
		replaced[key] = val
	}
	replaced["text"] = placeholder
	return replaced, true
}
|
|
|
|
func extractPromptCacheTTL(value interface{}) time.Duration {
|
|
block, ok := value.(map[string]interface{})
|
|
if !ok {
|
|
if raw, err := json.Marshal(value); err == nil {
|
|
var decoded map[string]interface{}
|
|
if json.Unmarshal(raw, &decoded) == nil {
|
|
block = decoded
|
|
ok = true
|
|
}
|
|
}
|
|
}
|
|
if !ok {
|
|
return 0
|
|
}
|
|
|
|
rawCache, ok := block["cache_control"]
|
|
if !ok {
|
|
return 0
|
|
}
|
|
cacheControl, ok := rawCache.(map[string]interface{})
|
|
if !ok {
|
|
return 0
|
|
}
|
|
cacheType, _ := cacheControl["type"].(string)
|
|
if !strings.EqualFold(cacheType, "ephemeral") {
|
|
return 0
|
|
}
|
|
|
|
if ttl, ok := parsePromptCacheTTLValue(cacheControl["ttl"]); ok {
|
|
return ttl
|
|
}
|
|
return defaultPromptCacheTTL
|
|
}
|
|
|
|
// parsePromptCacheTTLValue interprets a cache_control "ttl" value.
//
// Accepted forms:
//   - string: a Go duration such as "5m" or "1h" (case-insensitive,
//     surrounding whitespace ignored), or a bare integer number of seconds.
//   - float64 / json.Number: a number of seconds; fractions are preserved.
//   - int / int64: a whole number of seconds.
//
// It returns the duration and true only when the value describes a strictly
// positive TTL. Zero, negative, empty, unparseable and unsupported values all
// return (0, false), regardless of whether they arrive as a string or as a
// number.
func parsePromptCacheTTLValue(value interface{}) (time.Duration, bool) {
	var ttl time.Duration

	switch v := value.(type) {
	case string:
		text := strings.ToLower(strings.TrimSpace(v))
		if text == "" {
			return 0, false
		}
		if d, err := time.ParseDuration(text); err == nil {
			ttl = d
		} else if seconds, err := strconv.Atoi(text); err == nil {
			ttl = time.Duration(seconds) * time.Second
		}
	case json.Number:
		// Produced by decoders configured with UseNumber.
		if seconds, err := v.Float64(); err == nil && seconds > 0 {
			ttl = time.Duration(seconds * float64(time.Second))
		}
	case float64:
		// Scale before converting so that 1.5 means 1.5s rather than 1s.
		// The v > 0 test also rejects NaN.
		if v > 0 {
			ttl = time.Duration(v * float64(time.Second))
		}
	case int:
		ttl = time.Duration(v) * time.Second
	case int64:
		ttl = time.Duration(v) * time.Second
	}

	// One positivity rule for every input form.
	if ttl <= 0 {
		return 0, false
	}
	return ttl, true
}
|
|
|
|
func normalizePromptCacheTTL(ttl time.Duration) time.Duration {
|
|
if ttl <= 0 {
|
|
return 0
|
|
}
|
|
if ttl > time.Hour {
|
|
return time.Hour
|
|
}
|
|
if ttl > defaultPromptCacheTTL {
|
|
return time.Hour
|
|
}
|
|
return defaultPromptCacheTTL
|
|
}
|
|
|
|
func computePromptCacheTTLBreakdown(profile *promptCacheProfile, matchedTokens int) (int, int) {
|
|
if profile == nil || len(profile.Breakpoints) == 0 {
|
|
return 0, 0
|
|
}
|
|
|
|
cache5m := 0
|
|
cache1h := 0
|
|
previous := matchedTokens
|
|
for _, breakpoint := range profile.Breakpoints {
|
|
current := minInt(breakpoint.CumulativeTokens, profile.TotalInputTokens)
|
|
if current <= previous {
|
|
continue
|
|
}
|
|
delta := current - previous
|
|
if breakpoint.TTL >= time.Hour {
|
|
cache1h += delta
|
|
} else {
|
|
cache5m += delta
|
|
}
|
|
previous = current
|
|
}
|
|
return cache5m, cache1h
|
|
}
|
|
|
|
func buildClaudeUsageMap(inputTokens, outputTokens int, usage promptCacheUsage, includeCache bool) map[string]interface{} {
|
|
result := map[string]interface{}{
|
|
"input_tokens": inputTokens,
|
|
"output_tokens": outputTokens,
|
|
}
|
|
if !includeCache {
|
|
return result
|
|
}
|
|
result["cache_creation_input_tokens"] = usage.CacheCreationInputTokens
|
|
result["cache_read_input_tokens"] = usage.CacheReadInputTokens
|
|
result["cache_creation"] = map[string]int{
|
|
"ephemeral_5m_input_tokens": usage.CacheCreation5mInputTokens,
|
|
"ephemeral_1h_input_tokens": usage.CacheCreation1hInputTokens,
|
|
}
|
|
return result
|
|
}
|
|
|
|
func canonicalizeCacheValue(value interface{}) string {
|
|
var buf bytes.Buffer
|
|
writeCanonicalJSON(&buf, value)
|
|
return buf.String()
|
|
}
|
|
|
|
// writeCanonicalJSON appends a deterministic JSON rendering of value to buf.
//
//   - map[string]interface{}: keys are written in sorted order and every
//     "cache_control" key is dropped, at any nesting depth.
//   - []interface{}: elements are rendered recursively in order.
//   - nil: written as null.
//   - everything else (strings, booleans, numbers, json.Number, typed slices,
//     structs, ...): written exactly as encoding/json marshals it.
//
// A value that encoding/json cannot marshal contributes no bytes.
func writeCanonicalJSON(buf *bytes.Buffer, value interface{}) {
	switch typed := value.(type) {
	case nil:
		buf.WriteString("null")

	case map[string]interface{}:
		names := make([]string, 0, len(typed))
		for name := range typed {
			if name != "cache_control" {
				names = append(names, name)
			}
		}
		sort.Strings(names)

		buf.WriteByte('{')
		for pos, name := range names {
			if pos != 0 {
				buf.WriteByte(',')
			}
			// Marshalling a string key cannot fail.
			quoted, _ := json.Marshal(name)
			buf.Write(quoted)
			buf.WriteByte(':')
			writeCanonicalJSON(buf, typed[name])
		}
		buf.WriteByte('}')

	case []interface{}:
		buf.WriteByte('[')
		for pos := range typed {
			if pos != 0 {
				buf.WriteByte(',')
			}
			writeCanonicalJSON(buf, typed[pos])
		}
		buf.WriteByte(']')

	default:
		// Scalars and all remaining types share the encoding/json rendering.
		// On a marshal error the result is empty and nothing is written.
		raw, _ := json.Marshal(typed)
		buf.Write(raw)
	}
}
|
|
|
|
// writeHashChunk feeds one chunk into hasher using a length-prefixed framing:
// the decimal byte length, a NUL byte, the chunk itself, and a closing NUL
// byte. The framing keeps differently split chunk sequences from producing
// the same byte stream. Write errors are ignored.
func writeHashChunk(hasher hashWriter, chunk string) {
	frame := [][]byte{
		[]byte(strconv.Itoa(len(chunk))),
		{0},
		[]byte(chunk),
		{0},
	}
	for _, part := range frame {
		hasher.Write(part)
	}
}

// hashWriter is the subset of a hash's methods used for prefix hashing:
// Write to feed data and Sum to read the digest.
type hashWriter interface {
	Write([]byte) (int, error)
	Sum([]byte) []byte
}
|
|
|
|
// minInt returns the smaller of a and b.
func minInt(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
|
|
|
|
// maxInt returns the larger of a and b.
func maxInt(a, b int) int {
	if b >= a {
		return b
	}
	return a
}
|