chore: sync dev branch proxy and workflow updates
This commit is contained in:
511
proxy/cache_tracker.go
Normal file
511
proxy/cache_tracker.go
Normal file
@@ -0,0 +1,511 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultPromptCacheTTL = 5 * time.Minute
|
||||
|
||||
type promptCacheUsage struct {
|
||||
CacheCreationInputTokens int
|
||||
CacheReadInputTokens int
|
||||
CacheCreation5mInputTokens int
|
||||
CacheCreation1hInputTokens int
|
||||
}
|
||||
|
||||
type promptCacheBreakpoint struct {
|
||||
Fingerprint [32]byte
|
||||
CumulativeTokens int
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
type promptCacheProfile struct {
|
||||
Breakpoints []promptCacheBreakpoint
|
||||
TotalInputTokens int
|
||||
}
|
||||
|
||||
type promptCacheEntry struct {
|
||||
ExpiresAt time.Time
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
type promptCacheTracker struct {
|
||||
mu sync.Mutex
|
||||
entriesByAccount map[string]map[[32]byte]promptCacheEntry
|
||||
maxSupportedTTL time.Duration
|
||||
}
|
||||
|
||||
func newPromptCacheTracker(maxTTL time.Duration) *promptCacheTracker {
|
||||
if maxTTL <= 0 {
|
||||
maxTTL = defaultPromptCacheTTL
|
||||
}
|
||||
return &promptCacheTracker{
|
||||
entriesByAccount: make(map[string]map[[32]byte]promptCacheEntry),
|
||||
maxSupportedTTL: maxTTL,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *promptCacheTracker) BuildClaudeProfile(req *ClaudeRequest, totalInputTokens int) *promptCacheProfile {
|
||||
blocks := flattenClaudeCacheBlocks(req)
|
||||
if len(blocks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
hasher := sha256.New()
|
||||
breakpoints := make([]promptCacheBreakpoint, 0)
|
||||
cumulativeTokens := 0
|
||||
|
||||
for _, block := range blocks {
|
||||
canonical := canonicalizeCacheValue(block.Value)
|
||||
writeHashChunk(hasher, canonical)
|
||||
cumulativeTokens += block.Tokens
|
||||
|
||||
if block.TTL <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var fingerprint [32]byte
|
||||
copy(fingerprint[:], hasher.Sum(nil))
|
||||
breakpoints = append(breakpoints, promptCacheBreakpoint{
|
||||
Fingerprint: fingerprint,
|
||||
CumulativeTokens: cumulativeTokens,
|
||||
TTL: block.TTL,
|
||||
})
|
||||
}
|
||||
|
||||
if len(breakpoints) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if totalInputTokens < cumulativeTokens {
|
||||
totalInputTokens = cumulativeTokens
|
||||
}
|
||||
|
||||
return &promptCacheProfile{
|
||||
Breakpoints: breakpoints,
|
||||
TotalInputTokens: totalInputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *promptCacheTracker) Compute(accountID string, profile *promptCacheProfile) promptCacheUsage {
|
||||
if t == nil || profile == nil || len(profile.Breakpoints) == 0 || accountID == "" {
|
||||
return promptCacheUsage{}
|
||||
}
|
||||
|
||||
last := profile.Breakpoints[len(profile.Breakpoints)-1]
|
||||
lastTokens := minInt(last.CumulativeTokens, profile.TotalInputTokens)
|
||||
now := time.Now()
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.pruneExpiredLocked(now)
|
||||
|
||||
entries := t.entriesByAccount[accountID]
|
||||
if len(entries) == 0 {
|
||||
cache5m, cache1h := computePromptCacheTTLBreakdown(profile, 0)
|
||||
return promptCacheUsage{
|
||||
CacheCreationInputTokens: lastTokens,
|
||||
CacheReadInputTokens: 0,
|
||||
CacheCreation5mInputTokens: cache5m,
|
||||
CacheCreation1hInputTokens: cache1h,
|
||||
}
|
||||
}
|
||||
|
||||
matchedTokens := 0
|
||||
for i := len(profile.Breakpoints) - 1; i >= 0; i-- {
|
||||
breakpoint := profile.Breakpoints[i]
|
||||
entry, ok := entries[breakpoint.Fingerprint]
|
||||
if !ok || entry.ExpiresAt.Before(now) {
|
||||
continue
|
||||
}
|
||||
entry.ExpiresAt = now.Add(entry.TTL)
|
||||
entries[breakpoint.Fingerprint] = entry
|
||||
matchedTokens = minInt(breakpoint.CumulativeTokens, profile.TotalInputTokens)
|
||||
break
|
||||
}
|
||||
|
||||
creation := maxInt(lastTokens-matchedTokens, 0)
|
||||
cache5m, cache1h := computePromptCacheTTLBreakdown(profile, matchedTokens)
|
||||
return promptCacheUsage{
|
||||
CacheCreationInputTokens: creation,
|
||||
CacheReadInputTokens: matchedTokens,
|
||||
CacheCreation5mInputTokens: cache5m,
|
||||
CacheCreation1hInputTokens: cache1h,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *promptCacheTracker) Update(accountID string, profile *promptCacheProfile) {
|
||||
if t == nil || profile == nil || len(profile.Breakpoints) == 0 || accountID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.pruneExpiredLocked(now)
|
||||
|
||||
entries := t.entriesByAccount[accountID]
|
||||
if entries == nil {
|
||||
entries = make(map[[32]byte]promptCacheEntry)
|
||||
t.entriesByAccount[accountID] = entries
|
||||
}
|
||||
|
||||
for _, breakpoint := range profile.Breakpoints {
|
||||
entries[breakpoint.Fingerprint] = promptCacheEntry{
|
||||
ExpiresAt: now.Add(breakpoint.TTL),
|
||||
TTL: breakpoint.TTL,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *promptCacheTracker) pruneExpiredLocked(now time.Time) {
|
||||
for accountID, entries := range t.entriesByAccount {
|
||||
for fingerprint, entry := range entries {
|
||||
if !entry.ExpiresAt.After(now) {
|
||||
delete(entries, fingerprint)
|
||||
}
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
delete(t.entriesByAccount, accountID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type cacheablePromptBlock struct {
|
||||
Value interface{}
|
||||
Tokens int
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
func flattenClaudeCacheBlocks(req *ClaudeRequest) []cacheablePromptBlock {
|
||||
blocks := make([]cacheablePromptBlock, 0)
|
||||
blocks = append(blocks, buildCachePreludeBlock(req))
|
||||
|
||||
for toolIndex, tool := range req.Tools {
|
||||
toolValue := map[string]interface{}{
|
||||
"kind": "tool",
|
||||
"tool_index": toolIndex,
|
||||
"name": tool.Name,
|
||||
"description": tool.Description,
|
||||
"input_schema": tool.InputSchema,
|
||||
}
|
||||
blocks = append(blocks, cacheablePromptBlock{
|
||||
Value: toolValue,
|
||||
Tokens: estimateApproxTokens(canonicalizeCacheValue(toolValue)),
|
||||
TTL: normalizePromptCacheTTL(extractPromptCacheTTL(tool)),
|
||||
})
|
||||
}
|
||||
|
||||
appendSystemCacheBlocks(&blocks, req.System)
|
||||
|
||||
for messageIndex, msg := range req.Messages {
|
||||
appendMessageCacheBlocks(&blocks, messageIndex, msg)
|
||||
}
|
||||
|
||||
return blocks
|
||||
}
|
||||
|
||||
func buildCachePreludeBlock(req *ClaudeRequest) cacheablePromptBlock {
|
||||
prelude := map[string]interface{}{
|
||||
"kind": "request_prelude",
|
||||
"model": req.Model,
|
||||
"tool_choice": req.ToolChoice,
|
||||
}
|
||||
return cacheablePromptBlock{
|
||||
Value: prelude,
|
||||
Tokens: estimateApproxTokens(canonicalizeCacheValue(prelude)),
|
||||
}
|
||||
}
|
||||
|
||||
func appendSystemCacheBlocks(blocks *[]cacheablePromptBlock, system interface{}) {
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
appendPromptBlock(blocks, map[string]interface{}{
|
||||
"kind": "system",
|
||||
"system_index": 0,
|
||||
"block": map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": v,
|
||||
},
|
||||
})
|
||||
case []interface{}:
|
||||
for i, block := range v {
|
||||
appendPromptBlock(blocks, map[string]interface{}{
|
||||
"kind": "system",
|
||||
"system_index": i,
|
||||
"block": block,
|
||||
})
|
||||
}
|
||||
case []string:
|
||||
for i, block := range v {
|
||||
appendPromptBlock(blocks, map[string]interface{}{
|
||||
"kind": "system",
|
||||
"system_index": i,
|
||||
"block": map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": block,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func appendMessageCacheBlocks(blocks *[]cacheablePromptBlock, messageIndex int, msg ClaudeMessage) {
|
||||
role := msg.Role
|
||||
switch content := msg.Content.(type) {
|
||||
case string:
|
||||
appendPromptBlock(blocks, map[string]interface{}{
|
||||
"kind": "message",
|
||||
"message_index": messageIndex,
|
||||
"role": role,
|
||||
"block_index": 0,
|
||||
"block": map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": content,
|
||||
},
|
||||
})
|
||||
case []interface{}:
|
||||
for blockIndex, block := range content {
|
||||
appendPromptBlock(blocks, map[string]interface{}{
|
||||
"kind": "message",
|
||||
"message_index": messageIndex,
|
||||
"role": role,
|
||||
"block_index": blockIndex,
|
||||
"block": block,
|
||||
})
|
||||
}
|
||||
default:
|
||||
if content != nil {
|
||||
appendPromptBlock(blocks, map[string]interface{}{
|
||||
"kind": "message",
|
||||
"message_index": messageIndex,
|
||||
"role": role,
|
||||
"block_index": 0,
|
||||
"block": content,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func appendPromptBlock(blocks *[]cacheablePromptBlock, wrapper map[string]interface{}) {
|
||||
blockValue, _ := wrapper["block"]
|
||||
ttl := normalizePromptCacheTTL(extractPromptCacheTTL(blockValue))
|
||||
canonical := canonicalizeCacheValue(wrapper)
|
||||
*blocks = append(*blocks, cacheablePromptBlock{
|
||||
Value: wrapper,
|
||||
Tokens: estimateApproxTokens(canonical),
|
||||
TTL: ttl,
|
||||
})
|
||||
}
|
||||
|
||||
func extractPromptCacheTTL(value interface{}) time.Duration {
|
||||
block, ok := value.(map[string]interface{})
|
||||
if !ok {
|
||||
if raw, err := json.Marshal(value); err == nil {
|
||||
var decoded map[string]interface{}
|
||||
if json.Unmarshal(raw, &decoded) == nil {
|
||||
block = decoded
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
|
||||
rawCache, ok := block["cache_control"]
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
cacheControl, ok := rawCache.(map[string]interface{})
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
cacheType, _ := cacheControl["type"].(string)
|
||||
if !strings.EqualFold(cacheType, "ephemeral") {
|
||||
return 0
|
||||
}
|
||||
|
||||
if ttl, ok := parsePromptCacheTTLValue(cacheControl["ttl"]); ok {
|
||||
return ttl
|
||||
}
|
||||
return defaultPromptCacheTTL
|
||||
}
|
||||
|
||||
func parsePromptCacheTTLValue(value interface{}) (time.Duration, bool) {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
trimmed := strings.TrimSpace(strings.ToLower(v))
|
||||
if trimmed == "" {
|
||||
return 0, false
|
||||
}
|
||||
if d, err := time.ParseDuration(trimmed); err == nil {
|
||||
return d, true
|
||||
}
|
||||
if seconds, err := strconv.Atoi(trimmed); err == nil {
|
||||
return time.Duration(seconds) * time.Second, true
|
||||
}
|
||||
case float64:
|
||||
if v > 0 {
|
||||
return time.Duration(v) * time.Second, true
|
||||
}
|
||||
case int:
|
||||
if v > 0 {
|
||||
return time.Duration(v) * time.Second, true
|
||||
}
|
||||
case int64:
|
||||
if v > 0 {
|
||||
return time.Duration(v) * time.Second, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func normalizePromptCacheTTL(ttl time.Duration) time.Duration {
|
||||
if ttl <= 0 {
|
||||
return 0
|
||||
}
|
||||
if ttl > time.Hour {
|
||||
return time.Hour
|
||||
}
|
||||
if ttl > defaultPromptCacheTTL {
|
||||
return time.Hour
|
||||
}
|
||||
return defaultPromptCacheTTL
|
||||
}
|
||||
|
||||
func computePromptCacheTTLBreakdown(profile *promptCacheProfile, matchedTokens int) (int, int) {
|
||||
if profile == nil || len(profile.Breakpoints) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
cache5m := 0
|
||||
cache1h := 0
|
||||
previous := matchedTokens
|
||||
for _, breakpoint := range profile.Breakpoints {
|
||||
current := minInt(breakpoint.CumulativeTokens, profile.TotalInputTokens)
|
||||
if current <= previous {
|
||||
continue
|
||||
}
|
||||
delta := current - previous
|
||||
if breakpoint.TTL >= time.Hour {
|
||||
cache1h += delta
|
||||
} else {
|
||||
cache5m += delta
|
||||
}
|
||||
previous = current
|
||||
}
|
||||
return cache5m, cache1h
|
||||
}
|
||||
|
||||
func billedClaudeInputTokens(inputTokens int, usage promptCacheUsage) int {
|
||||
return maxInt(inputTokens-usage.CacheCreationInputTokens-usage.CacheReadInputTokens, 0)
|
||||
}
|
||||
|
||||
func buildClaudeUsageMap(inputTokens, outputTokens int, usage promptCacheUsage, includeCache bool) map[string]interface{} {
|
||||
result := map[string]interface{}{
|
||||
"input_tokens": billedClaudeInputTokens(inputTokens, usage),
|
||||
"output_tokens": outputTokens,
|
||||
}
|
||||
if !includeCache {
|
||||
return result
|
||||
}
|
||||
result["cache_creation_input_tokens"] = usage.CacheCreationInputTokens
|
||||
result["cache_read_input_tokens"] = usage.CacheReadInputTokens
|
||||
result["cache_creation"] = map[string]int{
|
||||
"ephemeral_5m_input_tokens": usage.CacheCreation5mInputTokens,
|
||||
"ephemeral_1h_input_tokens": usage.CacheCreation1hInputTokens,
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func canonicalizeCacheValue(value interface{}) string {
|
||||
var buf bytes.Buffer
|
||||
writeCanonicalJSON(&buf, value)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func writeCanonicalJSON(buf *bytes.Buffer, value interface{}) {
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
buf.WriteString("null")
|
||||
case string:
|
||||
encoded, _ := json.Marshal(v)
|
||||
buf.Write(encoded)
|
||||
case bool:
|
||||
if v {
|
||||
buf.WriteString("true")
|
||||
} else {
|
||||
buf.WriteString("false")
|
||||
}
|
||||
case float64, float32, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, json.Number:
|
||||
encoded, _ := json.Marshal(v)
|
||||
buf.Write(encoded)
|
||||
case []interface{}:
|
||||
buf.WriteByte('[')
|
||||
for i, item := range v {
|
||||
if i > 0 {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
writeCanonicalJSON(buf, item)
|
||||
}
|
||||
buf.WriteByte(']')
|
||||
case map[string]interface{}:
|
||||
buf.WriteByte('{')
|
||||
keys := make([]string, 0, len(v))
|
||||
for key := range v {
|
||||
if key == "cache_control" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
for i, key := range keys {
|
||||
if i > 0 {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
encoded, _ := json.Marshal(key)
|
||||
buf.Write(encoded)
|
||||
buf.WriteByte(':')
|
||||
writeCanonicalJSON(buf, v[key])
|
||||
}
|
||||
buf.WriteByte('}')
|
||||
default:
|
||||
encoded, _ := json.Marshal(v)
|
||||
buf.Write(encoded)
|
||||
}
|
||||
}
|
||||
|
||||
func writeHashChunk(hasher hashWriter, chunk string) {
|
||||
length := strconv.Itoa(len(chunk))
|
||||
hasher.Write([]byte(length))
|
||||
hasher.Write([]byte{0})
|
||||
hasher.Write([]byte(chunk))
|
||||
hasher.Write([]byte{0})
|
||||
}
|
||||
|
||||
type hashWriter interface {
|
||||
Write([]byte) (int, error)
|
||||
Sum([]byte) []byte
|
||||
}
|
||||
|
||||
func minInt(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func maxInt(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
Reference in New Issue
Block a user