fix: accurate input_tokens via contextUsageEvent + smart routing for SDK clients

This commit is contained in:
Naive YH
2026-05-11 17:23:21 +08:00
parent acc5fe45ce
commit 31aa6aa421
2 changed files with 346 additions and 11 deletions

View File

@@ -262,6 +262,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
h.handleClaudeMessages(w, r)
case path == "/cc/v1/messages":
if !h.validateApiKey(r) {
h.sendClaudeError(w, 401, "authentication_error", "Invalid or missing API key")
return
}
h.handleClaudeMessagesBuffered(w, r)
case path == "/v1/messages/count_tokens" || path == "/messages/count_tokens":
if !h.validateApiKey(r) {
h.sendClaudeError(w, 401, "authentication_error", "Invalid or missing API key")
@@ -631,9 +637,13 @@ func (h *Handler) handleClaudeMessagesInternal(w http.ResponseWriter, r *http.Re
// 转换请求
kiroPayload := ClaudeToKiro(&req, thinking)
// 流式或非流式
// 流式或非流式SDK 客户端Claude Code、opencode 等)自动使用缓冲模式以获取精确 message_start
if req.Stream {
h.handleClaudeStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens, cacheUsage, cacheProfile)
if isAnthropicSDKRequest(r) {
h.handleClaudeStreamBuffered(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens, cacheUsage, cacheProfile)
} else {
h.handleClaudeStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens, cacheUsage, cacheProfile)
}
} else {
h.handleClaudeNonStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens, cacheUsage, cacheProfile)
}
@@ -657,6 +667,7 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
msgID := "msg_" + uuid.New().String()
var inputTokens, outputTokens int
var credits float64
var realInputTokens int
var toolUses []KiroToolUse
var nextContentIndex int
var rawContentBuilder strings.Builder
@@ -978,6 +989,9 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
OnCredits: func(c float64) {
credits = c
},
OnContextUsage: func(pct float64) {
realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0)
},
}
err := CallKiroAPI(account, payload, callback)
@@ -999,7 +1013,9 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
}
closeActiveBlock()
if inputTokens <= 0 {
if realInputTokens > 0 {
inputTokens = realInputTokens
} else if inputTokens <= 0 {
inputTokens = estimatedInputTokens
}
outputContent, extractedReasoning := extractThinkingFromContent(rawContentBuilder.String())
@@ -1042,6 +1058,290 @@ func (h *Handler) sendSSE(w http.ResponseWriter, flusher http.Flusher, event str
flusher.Flush()
}
// isAnthropicSDKRequest 检测请求是否来自基于 Anthropic 官方 SDK 的客户端
// (Claude Code、opencode、Roo Code 等),这类客户端读取 message_start.input_tokens 来展示上下文用量
func isAnthropicSDKRequest(r *http.Request) bool {
if r.Header.Get("x-stainless-lang") != "" {
return true
}
ua := strings.ToLower(r.Header.Get("User-Agent"))
return strings.Contains(ua, "claude") || strings.Contains(ua, "anthropic-sdk")
}
// handleClaudeMessagesBuffered Claude API 缓冲模式处理(/cc/v1/messages 及自动识别的 SDK 客户端)
func (h *Handler) handleClaudeMessagesBuffered(w http.ResponseWriter, r *http.Request) {
h.handleClaudeMessagesInternalBuffered(w, r)
}
func (h *Handler) handleClaudeMessagesInternalBuffered(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, "Method Not Allowed", 405)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
h.sendClaudeError(w, 400, "invalid_request_error", "Failed to read request body")
return
}
var req ClaudeRequest
if err := json.Unmarshal(body, &req); err != nil {
h.sendClaudeError(w, 400, "invalid_request_error", "Invalid JSON: "+err.Error())
return
}
if msg := validateClaudeRequestShape(&req); msg != "" {
h.sendClaudeError(w, 400, "invalid_request_error", msg)
return
}
account := h.pool.GetNext()
if account == nil {
h.sendClaudeError(w, 503, "api_error", "No available accounts")
return
}
if err := h.ensureValidToken(account); err != nil {
h.sendClaudeError(w, 503, "api_error", "Token refresh failed: "+err.Error())
return
}
thinkingCfg := config.GetThinkingConfig()
actualModel, thinking := ParseModelAndThinking(req.Model, thinkingCfg.Suffix)
req.Model = actualModel
estimatedInputTokens := estimateClaudeRequestInputTokens(&req)
cacheProfile := h.promptCache.BuildClaudeProfile(&req, estimatedInputTokens)
cacheUsage := h.promptCache.Compute(account.ID, cacheProfile)
kiroPayload := ClaudeToKiro(&req, thinking)
if req.Stream {
h.handleClaudeStreamBuffered(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens, cacheUsage, cacheProfile)
} else {
h.handleClaudeNonStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens, cacheUsage, cacheProfile)
}
}
// handleClaudeStreamBuffered Claude 缓冲流式响应
// 等待上游流完成后得到精确 input_tokens回填 message_start 后一次性推送所有 SSE 事件
// 等待期间每 25 秒发送 ping 事件保活
func (h *Handler) handleClaudeStreamBuffered(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int, cacheUsage promptCacheUsage, cacheProfile *promptCacheProfile) {
w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
flusher, ok := w.(http.Flusher)
if !ok {
h.sendClaudeError(w, 500, "api_error", "Streaming not supported")
return
}
// ping 保活 goroutine25 秒间隔,防止客户端超时断开)
pingStop := make(chan struct{})
var stopOnce sync.Once
stopPing := func() { stopOnce.Do(func() { close(pingStop) }) }
defer stopPing()
go func() {
ticker := time.NewTicker(25 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
fmt.Fprintf(w, "event: ping\ndata: {}\n\n")
flusher.Flush()
case <-pingStop:
return
}
}
}()
// 缓冲阶段:收集所有内容
var contentBuilder strings.Builder
var thinkingBuilder strings.Builder
var toolUses []KiroToolUse
var inputTokens, outputTokens int
var credits float64
var realInputTokens int
callback := &KiroStreamCallback{
OnText: func(text string, isThinking bool) {
if isThinking {
thinkingBuilder.WriteString(text)
} else {
contentBuilder.WriteString(text)
}
},
OnToolUse: func(tu KiroToolUse) {
toolUses = append(toolUses, tu)
},
OnComplete: func(inTok, outTok int) {
inputTokens = inTok
outputTokens = outTok
},
OnError: func(err error) {
h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "quota"))
},
OnCredits: func(c float64) {
credits = c
},
OnContextUsage: func(pct float64) {
realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0)
},
}
err := CallKiroAPI(account, payload, callback)
stopPing()
if err != nil {
h.recordFailure()
h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "quota"))
h.sendSSE(w, flusher, "error", map[string]interface{}{
"type": "error",
"error": map[string]string{"type": "api_error", "message": err.Error()},
})
return
}
// 确定精确 input_tokens
finalInputTokens := estimatedInputTokens
if realInputTokens > 0 {
finalInputTokens = realInputTokens
} else if inputTokens > 0 {
finalInputTokens = inputTokens
}
// 处理 thinking 内容
thinkingFormat := config.GetThinkingConfig().ClaudeFormat
rawContent := contentBuilder.String()
rawThinking := thinkingBuilder.String()
outputContent, extractedReasoning := extractThinkingFromContent(rawContent)
thinkingOutput := rawThinking
if thinking && thinkingOutput == "" && extractedReasoning != "" {
thinkingOutput = extractedReasoning
}
if !thinking {
thinkingOutput = ""
}
outputTokens = estimateClaudeOutputTokens(outputContent, thinkingOutput, toolUses)
h.recordSuccess(finalInputTokens, outputTokens, credits)
h.pool.RecordSuccess(account.ID)
h.pool.UpdateStats(account.ID, finalInputTokens+outputTokens, credits)
h.promptCache.Update(account.ID, cacheProfile)
msgID := "msg_" + uuid.New().String()
contentIndex := 0
// 推送阶段message_start 携带精确 input_tokens
h.sendSSE(w, flusher, "message_start", map[string]interface{}{
"type": "message_start",
"message": map[string]interface{}{
"id": msgID,
"type": "message",
"role": "assistant",
"content": []interface{}{},
"model": model,
"stop_reason": nil,
"stop_sequence": nil,
"usage": buildClaudeUsageMap(finalInputTokens, 0, cacheUsage, cacheProfile != nil),
},
})
h.sendSSE(w, flusher, "ping", map[string]interface{}{"type": "ping"})
// 推送 thinking 块
if thinking && thinkingOutput != "" {
switch thinkingFormat {
case "think":
h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
"type": "content_block_start", "index": contentIndex,
"content_block": map[string]string{"type": "text", "text": ""},
})
h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
"type": "content_block_delta", "index": contentIndex,
"delta": map[string]string{"type": "text_delta", "text": "<think>" + thinkingOutput + "</think>"},
})
h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
"type": "content_block_stop", "index": contentIndex,
})
contentIndex++
case "reasoning_content":
h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
"type": "content_block_start", "index": contentIndex,
"content_block": map[string]string{"type": "text", "text": ""},
})
h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
"type": "content_block_delta", "index": contentIndex,
"delta": map[string]string{"type": "text_delta", "text": thinkingOutput},
})
h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
"type": "content_block_stop", "index": contentIndex,
})
contentIndex++
default: // native thinking block
h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
"type": "content_block_start", "index": contentIndex,
"content_block": map[string]string{"type": "thinking", "thinking": ""},
})
h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
"type": "content_block_delta", "index": contentIndex,
"delta": map[string]string{"type": "thinking_delta", "thinking": thinkingOutput},
})
h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
"type": "content_block_stop", "index": contentIndex,
})
contentIndex++
}
}
// 推送文本块
if outputContent != "" {
h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
"type": "content_block_start", "index": contentIndex,
"content_block": map[string]string{"type": "text", "text": ""},
})
h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
"type": "content_block_delta", "index": contentIndex,
"delta": map[string]string{"type": "text_delta", "text": outputContent},
})
h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
"type": "content_block_stop", "index": contentIndex,
})
contentIndex++
}
// 推送工具调用块
for _, tu := range toolUses {
inputJSON, _ := json.Marshal(tu.Input)
h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
"type": "content_block_start", "index": contentIndex,
"content_block": map[string]interface{}{
"type": "tool_use", "id": tu.ToolUseID, "name": tu.Name, "input": map[string]interface{}{},
},
})
h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
"type": "content_block_delta", "index": contentIndex,
"delta": map[string]interface{}{"type": "input_json_delta", "partial_json": string(inputJSON)},
})
h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
"type": "content_block_stop", "index": contentIndex,
})
contentIndex++
}
stopReason := "end_turn"
if len(toolUses) > 0 {
stopReason = "tool_use"
}
h.sendSSE(w, flusher, "message_delta", map[string]interface{}{
"type": "message_delta",
"delta": map[string]interface{}{"stop_reason": stopReason},
"usage": buildClaudeUsageMap(finalInputTokens, outputTokens, cacheUsage, cacheProfile != nil),
})
h.sendSSE(w, flusher, "message_stop", map[string]interface{}{"type": "message_stop"})
}
// backgroundStatsSaver 后台定时保存统计数据
func (h *Handler) backgroundStatsSaver() {
ticker := time.NewTicker(30 * time.Second)
@@ -1103,6 +1403,7 @@ func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.A
var toolUses []KiroToolUse
var inputTokens, outputTokens int
var credits float64
var realInputTokens int
callback := &KiroStreamCallback{
OnText: func(text string, isThinking bool) {
@@ -1125,6 +1426,9 @@ func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.A
OnCredits: func(c float64) {
credits = c
},
OnContextUsage: func(pct float64) {
realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0)
},
}
err := CallKiroAPI(account, payload, callback)
@@ -1145,7 +1449,9 @@ func (h *Handler) handleClaudeNonStream(w http.ResponseWriter, account *config.A
thinkingContent = ""
}
if inputTokens <= 0 {
if realInputTokens > 0 {
inputTokens = realInputTokens
} else if inputTokens <= 0 {
inputTokens = estimatedInputTokens
}
outputTokens = estimateClaudeOutputTokens(finalContent, thinkingContent, toolUses)
@@ -1262,6 +1568,7 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
var toolCallIndex int
var inputTokens, outputTokens int
var credits float64
var realInputTokens int
var rawContentBuilder strings.Builder
var rawReasoningBuilder strings.Builder
@@ -1554,6 +1861,9 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
OnCredits: func(c float64) {
credits = c
},
OnContextUsage: func(pct float64) {
realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0)
},
}
err := CallKiroAPI(account, payload, callback)
@@ -1570,7 +1880,9 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
eventThinkingOpen = false
}
if inputTokens <= 0 {
if realInputTokens > 0 {
inputTokens = realInputTokens
} else if inputTokens <= 0 {
inputTokens = estimatedInputTokens
}
outputContent, extractedReasoning := extractThinkingFromContent(rawContentBuilder.String())
@@ -1626,6 +1938,7 @@ func (h *Handler) handleOpenAINonStream(w http.ResponseWriter, account *config.A
var toolUses []KiroToolUse
var inputTokens, outputTokens int
var credits float64
var realInputTokens int
callback := &KiroStreamCallback{
OnText: func(text string, isThinking bool) {
@@ -1639,6 +1952,9 @@ func (h *Handler) handleOpenAINonStream(w http.ResponseWriter, account *config.A
OnComplete: func(inTok, outTok int) { inputTokens = inTok; outputTokens = outTok },
OnError: func(err error) { h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429")) },
OnCredits: func(c float64) { credits = c },
OnContextUsage: func(pct float64) {
realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0)
},
}
err := CallKiroAPI(account, payload, callback)
@@ -1657,7 +1973,9 @@ func (h *Handler) handleOpenAINonStream(w http.ResponseWriter, account *config.A
reasoningContent = ""
}
if inputTokens <= 0 {
if realInputTokens > 0 {
inputTokens = realInputTokens
} else if inputTokens <= 0 {
inputTokens = estimatedInputTokens
}
outputTokens = estimateOpenAIOutputTokens(finalContent, reasoningContent, toolUses)