refactor: remove buffered stream mode, keep contextUsageEvent for accurate input tokens

This commit is contained in:
Quorinex
2026-05-11 19:47:39 +08:00
parent 31aa6aa421
commit 0203357b34
2 changed files with 8 additions and 301 deletions

View File

@@ -262,12 +262,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
h.handleClaudeMessages(w, r)
case path == "/cc/v1/messages":
if !h.validateApiKey(r) {
h.sendClaudeError(w, 401, "authentication_error", "Invalid or missing API key")
return
}
h.handleClaudeMessagesBuffered(w, r)
case path == "/v1/messages/count_tokens" || path == "/messages/count_tokens":
if !h.validateApiKey(r) {
h.sendClaudeError(w, 401, "authentication_error", "Invalid or missing API key")
@@ -637,13 +631,9 @@ func (h *Handler) handleClaudeMessagesInternal(w http.ResponseWriter, r *http.Re
// 转换请求
kiroPayload := ClaudeToKiro(&req, thinking)
// 流式或非流式SDK 客户端Claude Code、opencode 等)自动使用缓冲模式以获取精确 message_start
// Stream or non-stream
if req.Stream {
if isAnthropicSDKRequest(r) {
h.handleClaudeStreamBuffered(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens, cacheUsage, cacheProfile)
} else {
h.handleClaudeStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens, cacheUsage, cacheProfile)
}
h.handleClaudeStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens, cacheUsage, cacheProfile)
} else {
h.handleClaudeNonStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens, cacheUsage, cacheProfile)
}
@@ -1058,290 +1048,6 @@ func (h *Handler) sendSSE(w http.ResponseWriter, flusher http.Flusher, event str
flusher.Flush()
}
// isAnthropicSDKRequest 检测请求是否来自基于 Anthropic 官方 SDK 的客户端
// (Claude Code、opencode、Roo Code 等),这类客户端读取 message_start.input_tokens 来展示上下文用量
func isAnthropicSDKRequest(r *http.Request) bool {
if r.Header.Get("x-stainless-lang") != "" {
return true
}
ua := strings.ToLower(r.Header.Get("User-Agent"))
return strings.Contains(ua, "claude") || strings.Contains(ua, "anthropic-sdk")
}
// handleClaudeMessagesBuffered Claude API 缓冲模式处理(/cc/v1/messages 及自动识别的 SDK 客户端)
func (h *Handler) handleClaudeMessagesBuffered(w http.ResponseWriter, r *http.Request) {
h.handleClaudeMessagesInternalBuffered(w, r)
}
func (h *Handler) handleClaudeMessagesInternalBuffered(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, "Method Not Allowed", 405)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
h.sendClaudeError(w, 400, "invalid_request_error", "Failed to read request body")
return
}
var req ClaudeRequest
if err := json.Unmarshal(body, &req); err != nil {
h.sendClaudeError(w, 400, "invalid_request_error", "Invalid JSON: "+err.Error())
return
}
if msg := validateClaudeRequestShape(&req); msg != "" {
h.sendClaudeError(w, 400, "invalid_request_error", msg)
return
}
account := h.pool.GetNext()
if account == nil {
h.sendClaudeError(w, 503, "api_error", "No available accounts")
return
}
if err := h.ensureValidToken(account); err != nil {
h.sendClaudeError(w, 503, "api_error", "Token refresh failed: "+err.Error())
return
}
thinkingCfg := config.GetThinkingConfig()
actualModel, thinking := ParseModelAndThinking(req.Model, thinkingCfg.Suffix)
req.Model = actualModel
estimatedInputTokens := estimateClaudeRequestInputTokens(&req)
cacheProfile := h.promptCache.BuildClaudeProfile(&req, estimatedInputTokens)
cacheUsage := h.promptCache.Compute(account.ID, cacheProfile)
kiroPayload := ClaudeToKiro(&req, thinking)
if req.Stream {
h.handleClaudeStreamBuffered(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens, cacheUsage, cacheProfile)
} else {
h.handleClaudeNonStream(w, account, kiroPayload, req.Model, thinking, estimatedInputTokens, cacheUsage, cacheProfile)
}
}
// handleClaudeStreamBuffered Claude 缓冲流式响应
// 等待上游流完成后得到精确 input_tokens回填 message_start 后一次性推送所有 SSE 事件
// 等待期间每 25 秒发送 ping 事件保活
func (h *Handler) handleClaudeStreamBuffered(w http.ResponseWriter, account *config.Account, payload *KiroPayload, model string, thinking bool, estimatedInputTokens int, cacheUsage promptCacheUsage, cacheProfile *promptCacheProfile) {
w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
flusher, ok := w.(http.Flusher)
if !ok {
h.sendClaudeError(w, 500, "api_error", "Streaming not supported")
return
}
// ping 保活 goroutine25 秒间隔,防止客户端超时断开)
pingStop := make(chan struct{})
var stopOnce sync.Once
stopPing := func() { stopOnce.Do(func() { close(pingStop) }) }
defer stopPing()
go func() {
ticker := time.NewTicker(25 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
fmt.Fprintf(w, "event: ping\ndata: {}\n\n")
flusher.Flush()
case <-pingStop:
return
}
}
}()
// 缓冲阶段:收集所有内容
var contentBuilder strings.Builder
var thinkingBuilder strings.Builder
var toolUses []KiroToolUse
var inputTokens, outputTokens int
var credits float64
var realInputTokens int
callback := &KiroStreamCallback{
OnText: func(text string, isThinking bool) {
if isThinking {
thinkingBuilder.WriteString(text)
} else {
contentBuilder.WriteString(text)
}
},
OnToolUse: func(tu KiroToolUse) {
toolUses = append(toolUses, tu)
},
OnComplete: func(inTok, outTok int) {
inputTokens = inTok
outputTokens = outTok
},
OnError: func(err error) {
h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "quota"))
},
OnCredits: func(c float64) {
credits = c
},
OnContextUsage: func(pct float64) {
realInputTokens = int(pct * float64(getContextWindowSize(model)) / 100.0)
},
}
err := CallKiroAPI(account, payload, callback)
stopPing()
if err != nil {
h.recordFailure()
h.pool.RecordError(account.ID, strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "quota"))
h.sendSSE(w, flusher, "error", map[string]interface{}{
"type": "error",
"error": map[string]string{"type": "api_error", "message": err.Error()},
})
return
}
// 确定精确 input_tokens
finalInputTokens := estimatedInputTokens
if realInputTokens > 0 {
finalInputTokens = realInputTokens
} else if inputTokens > 0 {
finalInputTokens = inputTokens
}
// 处理 thinking 内容
thinkingFormat := config.GetThinkingConfig().ClaudeFormat
rawContent := contentBuilder.String()
rawThinking := thinkingBuilder.String()
outputContent, extractedReasoning := extractThinkingFromContent(rawContent)
thinkingOutput := rawThinking
if thinking && thinkingOutput == "" && extractedReasoning != "" {
thinkingOutput = extractedReasoning
}
if !thinking {
thinkingOutput = ""
}
outputTokens = estimateClaudeOutputTokens(outputContent, thinkingOutput, toolUses)
h.recordSuccess(finalInputTokens, outputTokens, credits)
h.pool.RecordSuccess(account.ID)
h.pool.UpdateStats(account.ID, finalInputTokens+outputTokens, credits)
h.promptCache.Update(account.ID, cacheProfile)
msgID := "msg_" + uuid.New().String()
contentIndex := 0
// 推送阶段message_start 携带精确 input_tokens
h.sendSSE(w, flusher, "message_start", map[string]interface{}{
"type": "message_start",
"message": map[string]interface{}{
"id": msgID,
"type": "message",
"role": "assistant",
"content": []interface{}{},
"model": model,
"stop_reason": nil,
"stop_sequence": nil,
"usage": buildClaudeUsageMap(finalInputTokens, 0, cacheUsage, cacheProfile != nil),
},
})
h.sendSSE(w, flusher, "ping", map[string]interface{}{"type": "ping"})
// 推送 thinking 块
if thinking && thinkingOutput != "" {
switch thinkingFormat {
case "think":
h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
"type": "content_block_start", "index": contentIndex,
"content_block": map[string]string{"type": "text", "text": ""},
})
h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
"type": "content_block_delta", "index": contentIndex,
"delta": map[string]string{"type": "text_delta", "text": "<think>" + thinkingOutput + "</think>"},
})
h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
"type": "content_block_stop", "index": contentIndex,
})
contentIndex++
case "reasoning_content":
h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
"type": "content_block_start", "index": contentIndex,
"content_block": map[string]string{"type": "text", "text": ""},
})
h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
"type": "content_block_delta", "index": contentIndex,
"delta": map[string]string{"type": "text_delta", "text": thinkingOutput},
})
h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
"type": "content_block_stop", "index": contentIndex,
})
contentIndex++
default: // native thinking block
h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
"type": "content_block_start", "index": contentIndex,
"content_block": map[string]string{"type": "thinking", "thinking": ""},
})
h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
"type": "content_block_delta", "index": contentIndex,
"delta": map[string]string{"type": "thinking_delta", "thinking": thinkingOutput},
})
h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
"type": "content_block_stop", "index": contentIndex,
})
contentIndex++
}
}
// 推送文本块
if outputContent != "" {
h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
"type": "content_block_start", "index": contentIndex,
"content_block": map[string]string{"type": "text", "text": ""},
})
h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
"type": "content_block_delta", "index": contentIndex,
"delta": map[string]string{"type": "text_delta", "text": outputContent},
})
h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
"type": "content_block_stop", "index": contentIndex,
})
contentIndex++
}
// 推送工具调用块
for _, tu := range toolUses {
inputJSON, _ := json.Marshal(tu.Input)
h.sendSSE(w, flusher, "content_block_start", map[string]interface{}{
"type": "content_block_start", "index": contentIndex,
"content_block": map[string]interface{}{
"type": "tool_use", "id": tu.ToolUseID, "name": tu.Name, "input": map[string]interface{}{},
},
})
h.sendSSE(w, flusher, "content_block_delta", map[string]interface{}{
"type": "content_block_delta", "index": contentIndex,
"delta": map[string]interface{}{"type": "input_json_delta", "partial_json": string(inputJSON)},
})
h.sendSSE(w, flusher, "content_block_stop", map[string]interface{}{
"type": "content_block_stop", "index": contentIndex,
})
contentIndex++
}
stopReason := "end_turn"
if len(toolUses) > 0 {
stopReason = "tool_use"
}
h.sendSSE(w, flusher, "message_delta", map[string]interface{}{
"type": "message_delta",
"delta": map[string]interface{}{"stop_reason": stopReason},
"usage": buildClaudeUsageMap(finalInputTokens, outputTokens, cacheUsage, cacheProfile != nil),
})
h.sendSSE(w, flusher, "message_stop", map[string]interface{}{"type": "message_stop"})
}
// backgroundStatsSaver 后台定时保存统计数据
func (h *Handler) backgroundStatsSaver() {
ticker := time.NewTicker(30 * time.Second)

View File

@@ -132,9 +132,9 @@ type InferenceConfig struct {
TopP float64 `json:"topP,omitempty"`
}
// ==================== 流式回调 ====================
// ==================== Stream Callbacks ====================
// KiroStreamCallback 流式响应回调
// KiroStreamCallback stream response callbacks
type KiroStreamCallback struct {
OnText func(text string, isThinking bool)
OnToolUse func(toolUse KiroToolUse)
@@ -377,11 +377,12 @@ func updateTokensFromEvent(event map[string]interface{}, currentInputTokens, cur
return inputTokens, outputTokens
}
// getContextWindowSize 返回模型的上下文窗口大小token 数)
// Kiro 托管的 Claude 模型窗口由 AWS 硬性规定,此处与官方保持一致
// getContextWindowSize returns the context window size (in tokens) for a model.
func getContextWindowSize(model string) int {
m := strings.ToLower(model)
if strings.Contains(m, "4.6") || strings.Contains(m, "4-6") {
// sonnet-4.6, opus-4.6, opus-4.7 all have 1M context windows
if strings.Contains(m, "4.6") || strings.Contains(m, "4-6") ||
strings.Contains(m, "4.7") || strings.Contains(m, "4-7") {
return 1_000_000
}
return 200_000