chore: sync dev branch proxy and workflow updates

This commit is contained in:
Quorinex
2026-05-10 18:57:40 +08:00
parent a063efd494
commit a24529d783
13 changed files with 1062 additions and 100 deletions

511
proxy/cache_tracker.go Normal file
View File

@@ -0,0 +1,511 @@
package proxy
import (
"bytes"
"crypto/sha256"
"encoding/json"
"sort"
"strconv"
"strings"
"sync"
"time"
)
const defaultPromptCacheTTL = 5 * time.Minute
type promptCacheUsage struct {
CacheCreationInputTokens int
CacheReadInputTokens int
CacheCreation5mInputTokens int
CacheCreation1hInputTokens int
}
type promptCacheBreakpoint struct {
Fingerprint [32]byte
CumulativeTokens int
TTL time.Duration
}
type promptCacheProfile struct {
Breakpoints []promptCacheBreakpoint
TotalInputTokens int
}
type promptCacheEntry struct {
ExpiresAt time.Time
TTL time.Duration
}
type promptCacheTracker struct {
mu sync.Mutex
entriesByAccount map[string]map[[32]byte]promptCacheEntry
maxSupportedTTL time.Duration
}
func newPromptCacheTracker(maxTTL time.Duration) *promptCacheTracker {
if maxTTL <= 0 {
maxTTL = defaultPromptCacheTTL
}
return &promptCacheTracker{
entriesByAccount: make(map[string]map[[32]byte]promptCacheEntry),
maxSupportedTTL: maxTTL,
}
}
func (t *promptCacheTracker) BuildClaudeProfile(req *ClaudeRequest, totalInputTokens int) *promptCacheProfile {
blocks := flattenClaudeCacheBlocks(req)
if len(blocks) == 0 {
return nil
}
hasher := sha256.New()
breakpoints := make([]promptCacheBreakpoint, 0)
cumulativeTokens := 0
for _, block := range blocks {
canonical := canonicalizeCacheValue(block.Value)
writeHashChunk(hasher, canonical)
cumulativeTokens += block.Tokens
if block.TTL <= 0 {
continue
}
var fingerprint [32]byte
copy(fingerprint[:], hasher.Sum(nil))
breakpoints = append(breakpoints, promptCacheBreakpoint{
Fingerprint: fingerprint,
CumulativeTokens: cumulativeTokens,
TTL: block.TTL,
})
}
if len(breakpoints) == 0 {
return nil
}
if totalInputTokens < cumulativeTokens {
totalInputTokens = cumulativeTokens
}
return &promptCacheProfile{
Breakpoints: breakpoints,
TotalInputTokens: totalInputTokens,
}
}
func (t *promptCacheTracker) Compute(accountID string, profile *promptCacheProfile) promptCacheUsage {
if t == nil || profile == nil || len(profile.Breakpoints) == 0 || accountID == "" {
return promptCacheUsage{}
}
last := profile.Breakpoints[len(profile.Breakpoints)-1]
lastTokens := minInt(last.CumulativeTokens, profile.TotalInputTokens)
now := time.Now()
t.mu.Lock()
defer t.mu.Unlock()
t.pruneExpiredLocked(now)
entries := t.entriesByAccount[accountID]
if len(entries) == 0 {
cache5m, cache1h := computePromptCacheTTLBreakdown(profile, 0)
return promptCacheUsage{
CacheCreationInputTokens: lastTokens,
CacheReadInputTokens: 0,
CacheCreation5mInputTokens: cache5m,
CacheCreation1hInputTokens: cache1h,
}
}
matchedTokens := 0
for i := len(profile.Breakpoints) - 1; i >= 0; i-- {
breakpoint := profile.Breakpoints[i]
entry, ok := entries[breakpoint.Fingerprint]
if !ok || entry.ExpiresAt.Before(now) {
continue
}
entry.ExpiresAt = now.Add(entry.TTL)
entries[breakpoint.Fingerprint] = entry
matchedTokens = minInt(breakpoint.CumulativeTokens, profile.TotalInputTokens)
break
}
creation := maxInt(lastTokens-matchedTokens, 0)
cache5m, cache1h := computePromptCacheTTLBreakdown(profile, matchedTokens)
return promptCacheUsage{
CacheCreationInputTokens: creation,
CacheReadInputTokens: matchedTokens,
CacheCreation5mInputTokens: cache5m,
CacheCreation1hInputTokens: cache1h,
}
}
func (t *promptCacheTracker) Update(accountID string, profile *promptCacheProfile) {
if t == nil || profile == nil || len(profile.Breakpoints) == 0 || accountID == "" {
return
}
now := time.Now()
t.mu.Lock()
defer t.mu.Unlock()
t.pruneExpiredLocked(now)
entries := t.entriesByAccount[accountID]
if entries == nil {
entries = make(map[[32]byte]promptCacheEntry)
t.entriesByAccount[accountID] = entries
}
for _, breakpoint := range profile.Breakpoints {
entries[breakpoint.Fingerprint] = promptCacheEntry{
ExpiresAt: now.Add(breakpoint.TTL),
TTL: breakpoint.TTL,
}
}
}
func (t *promptCacheTracker) pruneExpiredLocked(now time.Time) {
for accountID, entries := range t.entriesByAccount {
for fingerprint, entry := range entries {
if !entry.ExpiresAt.After(now) {
delete(entries, fingerprint)
}
}
if len(entries) == 0 {
delete(t.entriesByAccount, accountID)
}
}
}
type cacheablePromptBlock struct {
Value interface{}
Tokens int
TTL time.Duration
}
func flattenClaudeCacheBlocks(req *ClaudeRequest) []cacheablePromptBlock {
blocks := make([]cacheablePromptBlock, 0)
blocks = append(blocks, buildCachePreludeBlock(req))
for toolIndex, tool := range req.Tools {
toolValue := map[string]interface{}{
"kind": "tool",
"tool_index": toolIndex,
"name": tool.Name,
"description": tool.Description,
"input_schema": tool.InputSchema,
}
blocks = append(blocks, cacheablePromptBlock{
Value: toolValue,
Tokens: estimateApproxTokens(canonicalizeCacheValue(toolValue)),
TTL: normalizePromptCacheTTL(extractPromptCacheTTL(tool)),
})
}
appendSystemCacheBlocks(&blocks, req.System)
for messageIndex, msg := range req.Messages {
appendMessageCacheBlocks(&blocks, messageIndex, msg)
}
return blocks
}
func buildCachePreludeBlock(req *ClaudeRequest) cacheablePromptBlock {
prelude := map[string]interface{}{
"kind": "request_prelude",
"model": req.Model,
"tool_choice": req.ToolChoice,
}
return cacheablePromptBlock{
Value: prelude,
Tokens: estimateApproxTokens(canonicalizeCacheValue(prelude)),
}
}
func appendSystemCacheBlocks(blocks *[]cacheablePromptBlock, system interface{}) {
switch v := system.(type) {
case string:
appendPromptBlock(blocks, map[string]interface{}{
"kind": "system",
"system_index": 0,
"block": map[string]interface{}{
"type": "text",
"text": v,
},
})
case []interface{}:
for i, block := range v {
appendPromptBlock(blocks, map[string]interface{}{
"kind": "system",
"system_index": i,
"block": block,
})
}
case []string:
for i, block := range v {
appendPromptBlock(blocks, map[string]interface{}{
"kind": "system",
"system_index": i,
"block": map[string]interface{}{
"type": "text",
"text": block,
},
})
}
}
}
func appendMessageCacheBlocks(blocks *[]cacheablePromptBlock, messageIndex int, msg ClaudeMessage) {
role := msg.Role
switch content := msg.Content.(type) {
case string:
appendPromptBlock(blocks, map[string]interface{}{
"kind": "message",
"message_index": messageIndex,
"role": role,
"block_index": 0,
"block": map[string]interface{}{
"type": "text",
"text": content,
},
})
case []interface{}:
for blockIndex, block := range content {
appendPromptBlock(blocks, map[string]interface{}{
"kind": "message",
"message_index": messageIndex,
"role": role,
"block_index": blockIndex,
"block": block,
})
}
default:
if content != nil {
appendPromptBlock(blocks, map[string]interface{}{
"kind": "message",
"message_index": messageIndex,
"role": role,
"block_index": 0,
"block": content,
})
}
}
}
func appendPromptBlock(blocks *[]cacheablePromptBlock, wrapper map[string]interface{}) {
blockValue, _ := wrapper["block"]
ttl := normalizePromptCacheTTL(extractPromptCacheTTL(blockValue))
canonical := canonicalizeCacheValue(wrapper)
*blocks = append(*blocks, cacheablePromptBlock{
Value: wrapper,
Tokens: estimateApproxTokens(canonical),
TTL: ttl,
})
}
func extractPromptCacheTTL(value interface{}) time.Duration {
block, ok := value.(map[string]interface{})
if !ok {
if raw, err := json.Marshal(value); err == nil {
var decoded map[string]interface{}
if json.Unmarshal(raw, &decoded) == nil {
block = decoded
ok = true
}
}
}
if !ok {
return 0
}
rawCache, ok := block["cache_control"]
if !ok {
return 0
}
cacheControl, ok := rawCache.(map[string]interface{})
if !ok {
return 0
}
cacheType, _ := cacheControl["type"].(string)
if !strings.EqualFold(cacheType, "ephemeral") {
return 0
}
if ttl, ok := parsePromptCacheTTLValue(cacheControl["ttl"]); ok {
return ttl
}
return defaultPromptCacheTTL
}
func parsePromptCacheTTLValue(value interface{}) (time.Duration, bool) {
switch v := value.(type) {
case string:
trimmed := strings.TrimSpace(strings.ToLower(v))
if trimmed == "" {
return 0, false
}
if d, err := time.ParseDuration(trimmed); err == nil {
return d, true
}
if seconds, err := strconv.Atoi(trimmed); err == nil {
return time.Duration(seconds) * time.Second, true
}
case float64:
if v > 0 {
return time.Duration(v) * time.Second, true
}
case int:
if v > 0 {
return time.Duration(v) * time.Second, true
}
case int64:
if v > 0 {
return time.Duration(v) * time.Second, true
}
}
return 0, false
}
func normalizePromptCacheTTL(ttl time.Duration) time.Duration {
if ttl <= 0 {
return 0
}
if ttl > time.Hour {
return time.Hour
}
if ttl > defaultPromptCacheTTL {
return time.Hour
}
return defaultPromptCacheTTL
}
func computePromptCacheTTLBreakdown(profile *promptCacheProfile, matchedTokens int) (int, int) {
if profile == nil || len(profile.Breakpoints) == 0 {
return 0, 0
}
cache5m := 0
cache1h := 0
previous := matchedTokens
for _, breakpoint := range profile.Breakpoints {
current := minInt(breakpoint.CumulativeTokens, profile.TotalInputTokens)
if current <= previous {
continue
}
delta := current - previous
if breakpoint.TTL >= time.Hour {
cache1h += delta
} else {
cache5m += delta
}
previous = current
}
return cache5m, cache1h
}
func billedClaudeInputTokens(inputTokens int, usage promptCacheUsage) int {
return maxInt(inputTokens-usage.CacheCreationInputTokens-usage.CacheReadInputTokens, 0)
}
func buildClaudeUsageMap(inputTokens, outputTokens int, usage promptCacheUsage, includeCache bool) map[string]interface{} {
result := map[string]interface{}{
"input_tokens": billedClaudeInputTokens(inputTokens, usage),
"output_tokens": outputTokens,
}
if !includeCache {
return result
}
result["cache_creation_input_tokens"] = usage.CacheCreationInputTokens
result["cache_read_input_tokens"] = usage.CacheReadInputTokens
result["cache_creation"] = map[string]int{
"ephemeral_5m_input_tokens": usage.CacheCreation5mInputTokens,
"ephemeral_1h_input_tokens": usage.CacheCreation1hInputTokens,
}
return result
}
func canonicalizeCacheValue(value interface{}) string {
var buf bytes.Buffer
writeCanonicalJSON(&buf, value)
return buf.String()
}
func writeCanonicalJSON(buf *bytes.Buffer, value interface{}) {
switch v := value.(type) {
case nil:
buf.WriteString("null")
case string:
encoded, _ := json.Marshal(v)
buf.Write(encoded)
case bool:
if v {
buf.WriteString("true")
} else {
buf.WriteString("false")
}
case float64, float32, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, json.Number:
encoded, _ := json.Marshal(v)
buf.Write(encoded)
case []interface{}:
buf.WriteByte('[')
for i, item := range v {
if i > 0 {
buf.WriteByte(',')
}
writeCanonicalJSON(buf, item)
}
buf.WriteByte(']')
case map[string]interface{}:
buf.WriteByte('{')
keys := make([]string, 0, len(v))
for key := range v {
if key == "cache_control" {
continue
}
keys = append(keys, key)
}
sort.Strings(keys)
for i, key := range keys {
if i > 0 {
buf.WriteByte(',')
}
encoded, _ := json.Marshal(key)
buf.Write(encoded)
buf.WriteByte(':')
writeCanonicalJSON(buf, v[key])
}
buf.WriteByte('}')
default:
encoded, _ := json.Marshal(v)
buf.Write(encoded)
}
}
func writeHashChunk(hasher hashWriter, chunk string) {
length := strconv.Itoa(len(chunk))
hasher.Write([]byte(length))
hasher.Write([]byte{0})
hasher.Write([]byte(chunk))
hasher.Write([]byte{0})
}
type hashWriter interface {
Write([]byte) (int, error)
Sum([]byte) []byte
}
func minInt(a, b int) int {
if a < b {
return a
}
return b
}
func maxInt(a, b int) int {
if a > b {
return a
}
return b
}