Some checks failed
Build Docker Image / build (push) Has been cancelled
Writes request body + response body of failed upstream calls to kiro_errors.log in the working directory. File is capped at 10MB; when the next write would exceed that, the file is truncated so only the most recent records are kept. Helps diagnose 400 "Improperly formed request" errors where both CW and Q reject the same payload. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
800 lines
22 KiB
Go
800 lines
22 KiB
Go
// Package proxy is the core of the Kiro API proxy.
// It calls the Kiro API and parses the AWS Event Stream response.
|
||
package proxy
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
cryptoRand "crypto/rand"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"kiro-go/config"
|
||
"log"
|
||
"net/http"
|
||
"net/url"
|
||
"strconv"
|
||
"strings"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"github.com/google/uuid"
|
||
)
|
||
|
||
// kiroEndpoint describes one upstream endpoint that can serve a request.
// Two endpoints are configured; CallKiroAPI falls back from one to the other
// when an attempt fails (for example on HTTP 429).
type kiroEndpoint struct {
	URL       string // full request URL
	Origin    string // copied into the payload's userInputMessage.origin before sending
	AmzTarget string // value of the X-Amz-Target request header
	Name      string // display name used in logs and error messages
}
|
||
|
||
var kiroEndpoints = []kiroEndpoint{
|
||
{
|
||
URL: "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse",
|
||
Origin: "AI_EDITOR",
|
||
AmzTarget: "AmazonCodeWhispererStreamingService.GenerateAssistantResponse",
|
||
Name: "CodeWhisperer",
|
||
},
|
||
{
|
||
URL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
|
||
Origin: "AI_EDITOR",
|
||
AmzTarget: "AmazonQDeveloperStreamingService.SendMessage",
|
||
Name: "AmazonQ",
|
||
},
|
||
}
|
||
|
||
// kiroHttpStore holds the package-wide HTTP client used for Kiro API calls.
// It is an atomic pointer so the client can be replaced at runtime (for
// example when the proxy is reconfigured) without additional locking.
var (
	kiroHttpStore atomic.Pointer[http.Client]
)
|
||
|
||
func init() {
|
||
InitKiroHttpClient("")
|
||
}
|
||
|
||
// buildKiroTransport returns the HTTP transport used for Kiro API calls.
//
// When proxyURL is non-empty and parses as a URL, requests are routed through
// that proxy and the HTTP/2 attempt is switched off. When proxyURL cannot be
// parsed, a warning is logged and the returned transport connects directly;
// the raw URL is not included in the log line because it may carry
// credentials.
func buildKiroTransport(proxyURL string) *http.Transport {
	t := &http.Transport{
		MaxIdleConns:        100,
		MaxIdleConnsPerHost: 20,
		IdleConnTimeout:     90 * time.Second,
		DisableCompression:  false,
		ForceAttemptHTTP2:   true,
	}
	if proxyURL == "" {
		return t
	}

	u, err := url.Parse(proxyURL)
	if err != nil {
		// url.Parse wraps the failure in *url.Error, whose message repeats the
		// whole input; unwrap it so a user:password@ part never reaches the log.
		if ue, ok := err.(*url.Error); ok {
			err = ue.Err
		}
		log.Printf("[KiroAPI] ignoring unparseable proxy URL, connecting directly: %v", err)
		return t
	}

	t.Proxy = http.ProxyURL(u)
	// Per the original author's note, the proxy does not support the HTTP/2
	// protocol upgrade, so do not attempt it.
	t.ForceAttemptHTTP2 = false
	return t
}
|
||
|
||
// InitKiroHttpClient 初始化(或重新初始化)Kiro API 的 HTTP 客户端
|
||
func InitKiroHttpClient(proxyURL string) {
|
||
client := &http.Client{
|
||
Timeout: 5 * time.Minute,
|
||
Transport: buildKiroTransport(proxyURL),
|
||
}
|
||
kiroHttpStore.Store(client)
|
||
}
|
||
|
||
// ==================== Request structures ====================
|
||
|
||
// KiroPayload Kiro API 请求体
|
||
type KiroPayload struct {
|
||
ConversationState struct {
|
||
ChatTriggerType string `json:"chatTriggerType"`
|
||
ConversationID string `json:"conversationId"`
|
||
CurrentMessage struct {
|
||
UserInputMessage KiroUserInputMessage `json:"userInputMessage"`
|
||
} `json:"currentMessage"`
|
||
History []KiroHistoryMessage `json:"history,omitempty"`
|
||
} `json:"conversationState"`
|
||
ProfileArn string `json:"profileArn,omitempty"`
|
||
InferenceConfig *InferenceConfig `json:"inferenceConfig,omitempty"`
|
||
}
|
||
|
||
type KiroUserInputMessage struct {
|
||
Content string `json:"content"`
|
||
ModelID string `json:"modelId,omitempty"`
|
||
Origin string `json:"origin"`
|
||
Images []KiroImage `json:"images,omitempty"`
|
||
UserInputMessageContext *UserInputMessageContext `json:"userInputMessageContext,omitempty"`
|
||
}
|
||
|
||
type UserInputMessageContext struct {
|
||
Tools []KiroToolWrapper `json:"tools,omitempty"`
|
||
ToolResults []KiroToolResult `json:"toolResults,omitempty"`
|
||
}
|
||
|
||
type KiroToolWrapper struct {
|
||
ToolSpecification struct {
|
||
Name string `json:"name"`
|
||
Description string `json:"description"`
|
||
InputSchema InputSchema `json:"inputSchema"`
|
||
} `json:"toolSpecification"`
|
||
}
|
||
|
||
// InputSchema wraps a tool's input schema under the "json" key. The schema
// value itself is untyped and encoded as-is.
type InputSchema struct {
	JSON interface{} `json:"json"`
}
|
||
|
||
type KiroToolResult struct {
|
||
ToolUseID string `json:"toolUseId"`
|
||
Content []KiroResultContent `json:"content"`
|
||
Status string `json:"status"`
|
||
}
|
||
|
||
// KiroResultContent is one text fragment of a tool result.
type KiroResultContent struct {
	Text string `json:"text"`
}
|
||
|
||
// KiroImage is an image attached to a user message.
type KiroImage struct {
	Format string `json:"format"`

	// Source holds the image data as a string.
	// NOTE(review): presumably base64-encoded — confirm against the callers.
	Source struct {
		Bytes string `json:"bytes"`
	} `json:"source"`
}
|
||
|
||
type KiroHistoryMessage struct {
|
||
UserInputMessage *KiroUserInputMessage `json:"userInputMessage,omitempty"`
|
||
AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"`
|
||
}
|
||
|
||
type KiroAssistantResponseMessage struct {
|
||
Content string `json:"content"`
|
||
ToolUses []KiroToolUse `json:"toolUses,omitempty"`
|
||
}
|
||
|
||
// KiroToolUse is one tool invocation: its ID, the tool name and the decoded
// input arguments. It is what the stream parser hands to the OnToolUse
// callback and what assistant history messages list under "toolUses".
type KiroToolUse struct {
	ToolUseID string                 `json:"toolUseId"`
	Name      string                 `json:"name"`
	Input     map[string]interface{} `json:"input"`
}
|
||
|
||
// InferenceConfig holds optional sampling parameters for a request.
// Every field uses omitempty, so a zero value (including an explicit
// temperature or topP of 0) is not sent.
type InferenceConfig struct {
	MaxTokens   int     `json:"maxTokens,omitempty"`
	Temperature float64 `json:"temperature,omitempty"`
	TopP        float64 `json:"topP,omitempty"`
}
|
||
|
||
// ==================== Stream Callbacks ====================
|
||
|
||
// KiroStreamCallback stream response callbacks
|
||
type KiroStreamCallback struct {
|
||
OnText func(text string, isThinking bool)
|
||
OnToolUse func(toolUse KiroToolUse)
|
||
OnComplete func(inputTokens, outputTokens int)
|
||
OnError func(err error)
|
||
OnCredits func(credits float64)
|
||
OnContextUsage func(percentage float64)
|
||
}
|
||
|
||
// ==================== API calls ====================
|
||
|
||
// getSortedEndpoints 根据首选端点配置排序端点列表
|
||
func getSortedEndpoints(preferred string) []kiroEndpoint {
|
||
if preferred == "amazonq" {
|
||
return []kiroEndpoint{kiroEndpoints[1], kiroEndpoints[0]}
|
||
}
|
||
if preferred == "codewhisperer" {
|
||
return []kiroEndpoint{kiroEndpoints[0], kiroEndpoints[1]}
|
||
}
|
||
// "auto" 或空值:默认顺序
|
||
return []kiroEndpoint{kiroEndpoints[0], kiroEndpoints[1]}
|
||
}
|
||
|
||
// CallKiroAPI 调用 Kiro API(流式),双端点自动 fallback
|
||
func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroStreamCallback) error {
|
||
if _, err := json.Marshal(payload); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 根据配置排序端点
|
||
endpoints := getSortedEndpoints(config.GetPreferredEndpoint())
|
||
invalidModelRetries := config.GetInvalidModelRetries()
|
||
firstByteTimeoutSec := config.GetFirstByteTimeoutSec()
|
||
firstByteRetries := config.GetFirstByteRetries()
|
||
|
||
modelID := payload.ConversationState.CurrentMessage.UserInputMessage.ModelID
|
||
accountLabel := account.Email
|
||
if accountLabel == "" {
|
||
accountLabel = account.ID
|
||
}
|
||
|
||
reqID := shortReqID()
|
||
epNames := make([]string, 0, len(endpoints))
|
||
for _, ep := range endpoints {
|
||
epNames = append(epNames, shortEndpoint(ep.Name))
|
||
}
|
||
log.Printf("[KiroAPI] REQ %s model=%s account=%s endpoints=%s", reqID, shortModel(modelID), accountLabel, strings.Join(epNames, ","))
|
||
|
||
requestStart := time.Now()
|
||
|
||
var lastErr error
|
||
var lastStatus string // 用于 FAIL 行总结
|
||
for _, ep := range endpoints {
|
||
payload.ConversationState.CurrentMessage.UserInputMessage.Origin = ep.Origin
|
||
epShort := shortEndpoint(ep.Name)
|
||
|
||
maxAttempts := invalidModelRetries + 1
|
||
if firstByteRetries+1 > maxAttempts {
|
||
maxAttempts = firstByteRetries + 1
|
||
}
|
||
invalidModelUsed := 0
|
||
firstByteUsed := 0
|
||
shouldFallback := false
|
||
|
||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||
reqBody, _ := json.Marshal(payload)
|
||
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
req, err := http.NewRequestWithContext(ctx, "POST", ep.URL, bytes.NewReader(reqBody))
|
||
if err != nil {
|
||
cancel()
|
||
lastErr = err
|
||
lastStatus = "ERR"
|
||
log.Printf("[KiroAPI] ERR %s %s/a%d new_request %v", reqID, epShort, attempt, err)
|
||
shouldFallback = true
|
||
break
|
||
}
|
||
|
||
host := ""
|
||
if parsedURL, parseErr := url.Parse(ep.URL); parseErr == nil {
|
||
host = parsedURL.Host
|
||
}
|
||
headerValues := buildStreamingHeaderValues(account, host)
|
||
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Accept", "*/*")
|
||
req.Header.Set("X-Amz-Target", ep.AmzTarget)
|
||
applyKiroBaseHeaders(req, account, headerValues)
|
||
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
|
||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||
req.Header.Set("Amz-Sdk-Request", fmt.Sprintf("attempt=%d; max=%d", attempt, maxAttempts))
|
||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||
|
||
attemptStart := time.Now()
|
||
|
||
resp, err := kiroHttpStore.Load().Do(req)
|
||
if err != nil {
|
||
cancel()
|
||
lastErr = err
|
||
lastStatus = "ERR"
|
||
log.Printf("[KiroAPI] ERR %s %s/a%d transport %s %v", reqID, epShort, attempt, fmtMs(time.Since(attemptStart)), err)
|
||
shouldFallback = true
|
||
break
|
||
}
|
||
|
||
if resp.StatusCode == 429 {
|
||
resp.Body.Close()
|
||
cancel()
|
||
lastErr = fmt.Errorf("quota exhausted on %s", ep.Name)
|
||
lastStatus = "429"
|
||
log.Printf("[KiroAPI] 429 %s %s/a%d quota_exhausted %s", reqID, epShort, attempt, fmtMs(time.Since(attemptStart)))
|
||
shouldFallback = true
|
||
break
|
||
}
|
||
|
||
if resp.StatusCode != 200 {
|
||
errBody, _ := io.ReadAll(resp.Body)
|
||
resp.Body.Close()
|
||
cancel()
|
||
lastErr = fmt.Errorf("HTTP %d from %s: %s", resp.StatusCode, ep.Name, string(errBody))
|
||
lastStatus = fmt.Sprintf("%d", resp.StatusCode)
|
||
bodyStr := string(errBody)
|
||
|
||
// 记录非 200 / 非 429 的请求体和响应体以便排查(本地滚动日志,上限 10MB)
|
||
if resp.StatusCode != 429 {
|
||
logKiroError(reqID, ep.Name, resp.StatusCode, accountLabel, modelID, reqBody, errBody)
|
||
}
|
||
|
||
if resp.StatusCode == 401 || resp.StatusCode == 403 {
|
||
log.Printf("[KiroAPI] %d %s %s/a%d auth_error %s %s", resp.StatusCode, reqID, epShort, attempt, fmtMs(time.Since(attemptStart)), truncateForLog(bodyStr, 200))
|
||
log.Printf("[KiroAPI] FAIL %s all endpoints failed %s last=%s", reqID, fmtMs(time.Since(requestStart)), lastStatus)
|
||
return lastErr
|
||
}
|
||
|
||
if resp.StatusCode == 400 && strings.Contains(bodyStr, "INVALID_MODEL_ID") {
|
||
if invalidModelUsed < invalidModelRetries {
|
||
invalidModelUsed++
|
||
log.Printf("[KiroAPI] 400 %s %s/a%d INVALID_MODEL_ID %s retry %d/%d", reqID, epShort, attempt, fmtMs(time.Since(attemptStart)), invalidModelUsed, invalidModelRetries)
|
||
continue
|
||
}
|
||
log.Printf("[KiroAPI] 400 %s %s/a%d INVALID_MODEL_ID %s exhausted → fallback", reqID, epShort, attempt, fmtMs(time.Since(attemptStart)))
|
||
shouldFallback = true
|
||
break
|
||
}
|
||
|
||
log.Printf("[KiroAPI] %d %s %s/a%d %s %s", resp.StatusCode, reqID, epShort, attempt, fmtMs(time.Since(attemptStart)), truncateForLog(bodyStr, 200))
|
||
shouldFallback = true
|
||
break
|
||
}
|
||
|
||
// 首字节超时
|
||
var firstByteReceived atomic.Bool
|
||
var firstByteTimedOut atomic.Bool
|
||
var firstByteAt time.Duration
|
||
var timer *time.Timer
|
||
if firstByteTimeoutSec > 0 {
|
||
timer = time.AfterFunc(time.Duration(firstByteTimeoutSec)*time.Second, func() {
|
||
if !firstByteReceived.Load() {
|
||
firstByteTimedOut.Store(true)
|
||
cancel()
|
||
}
|
||
})
|
||
}
|
||
|
||
onFirstByte := func() {
|
||
firstByteReceived.Store(true)
|
||
firstByteAt = time.Since(attemptStart)
|
||
if timer != nil {
|
||
timer.Stop()
|
||
}
|
||
}
|
||
|
||
err = parseEventStream(resp.Body, callback, onFirstByte)
|
||
resp.Body.Close()
|
||
if timer != nil {
|
||
timer.Stop()
|
||
}
|
||
cancel()
|
||
|
||
if err != nil && firstByteTimedOut.Load() && !firstByteReceived.Load() {
|
||
lastStatus = "TIMEOUT"
|
||
if firstByteUsed < firstByteRetries {
|
||
firstByteUsed++
|
||
lastErr = fmt.Errorf("first-byte timeout after %ds", firstByteTimeoutSec)
|
||
log.Printf("[KiroAPI] TIMEOUT %s %s/a%d first_byte>%ds retry %d/%d", reqID, epShort, attempt, firstByteTimeoutSec, firstByteUsed, firstByteRetries)
|
||
continue
|
||
}
|
||
lastErr = fmt.Errorf("first-byte timeout after %ds on %s", firstByteTimeoutSec, ep.Name)
|
||
log.Printf("[KiroAPI] TIMEOUT %s %s/a%d first_byte>%ds exhausted → fallback", reqID, epShort, attempt, firstByteTimeoutSec)
|
||
shouldFallback = true
|
||
break
|
||
}
|
||
|
||
status := "200"
|
||
if err != nil {
|
||
status = "ERR"
|
||
}
|
||
log.Printf("[KiroAPI] %s %s %s/a%d first_byte=%s total=%s", status, reqID, epShort, attempt, fmtMs(firstByteAt), fmtMs(time.Since(requestStart)))
|
||
return err
|
||
}
|
||
|
||
if !shouldFallback {
|
||
break
|
||
}
|
||
}
|
||
|
||
log.Printf("[KiroAPI] FAIL %s all endpoints failed %s last=%s", reqID, fmtMs(time.Since(requestStart)), lastStatus)
|
||
if lastErr != nil {
|
||
return lastErr
|
||
}
|
||
return fmt.Errorf("all endpoints failed")
|
||
}
|
||
|
||
// shortReqID returns a 6-character tag used to correlate the log lines of one
// request: three random bytes rendered as lowercase hex. If the random source
// fails, it falls back to the last six decimal digits of the current Unix
// time in nanoseconds.
func shortReqID() string {
	raw := make([]byte, 3)
	if _, err := cryptoRand.Read(raw); err == nil {
		return fmt.Sprintf("%x", raw)
	}
	return fmt.Sprintf("%06d", time.Now().UnixNano()%1000000)
}
|
||
|
||
// shortEndpoint abbreviates an endpoint name to two characters so log columns
// line up. Known names get fixed codes ("CW", and "Q " padded with a space);
// any other name is cut to its first two bytes, and names shorter than that
// are returned unchanged.
func shortEndpoint(name string) string {
	if name == "CodeWhisperer" {
		return "CW"
	}
	if name == "AmazonQ" {
		return "Q "
	}
	if len(name) < 2 {
		return name
	}
	return name[:2]
}
|
||
|
||
// shortModel shortens a model name for logging by dropping a leading
// "claude-" (claude-opus-4.7 becomes opus-4.7). An empty name is shown as
// "-"; anything else is returned unchanged.
func shortModel(m string) string {
	const prefix = "claude-"
	switch {
	case strings.HasPrefix(m, prefix):
		return strings.TrimPrefix(m, prefix)
	case m == "":
		return "-"
	default:
		return m
	}
}
|
||
|
||
// fmtMs renders a duration compactly for logs: whole milliseconds below one
// second ("250ms"), seconds with one decimal from one second up ("1.5s").
// Zero and negative durations are shown as "0ms".
func fmtMs(d time.Duration) string {
	switch {
	case d <= 0:
		return "0ms"
	case d < time.Second:
		return fmt.Sprintf("%dms", d.Milliseconds())
	default:
		return fmt.Sprintf("%.1fs", d.Seconds())
	}
}
|
||
|
||
// truncateForLog prepares s for a single log line: newlines become spaces,
// and if the result is longer than max bytes it is cut and suffixed with
// "...(truncated)".
//
// The cut never lands inside a multi-byte UTF-8 sequence, so the kept part
// may be up to three bytes shorter than max. A negative max is treated as 0.
func truncateForLog(s string, max int) string {
	s = strings.ReplaceAll(s, "\n", " ")
	if max < 0 {
		max = 0
	}
	if len(s) <= max {
		return s
	}

	// Back up over UTF-8 continuation bytes (10xxxxxx) so the cut falls on a
	// rune boundary. A rune has at most three continuation bytes; the bound
	// keeps non-UTF-8 input from being eaten further than that.
	cut := max
	for cut > 0 && max-cut < 3 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "...(truncated)"
}
|
||
|
||
// ==================== Event Stream parsing ====================
|
||
|
||
// parseEventStream 解析 AWS Event Stream 二进制格式
|
||
// onFirstByte 会在读完第一个完整 event-stream 包 prelude 时触发一次(只一次),
|
||
// 供外层判断「首字节是否已收到」,以决定首字节超时时是否应该重试。
|
||
func parseEventStream(body io.Reader, callback *KiroStreamCallback, onFirstByte func()) error {
|
||
// 不使用 bufio,直接读取避免缓冲延迟
|
||
var inputTokens, outputTokens int
|
||
var totalCredits float64
|
||
var currentToolUse *toolUseState
|
||
var lastAssistantContent string
|
||
var lastReasoningContent string
|
||
firstByteFired := false
|
||
|
||
for {
|
||
// Prelude: 12 bytes (total_len + headers_len + crc)
|
||
prelude := make([]byte, 12)
|
||
_, err := io.ReadFull(body, prelude)
|
||
if err == io.EOF {
|
||
break
|
||
}
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if !firstByteFired {
|
||
firstByteFired = true
|
||
if onFirstByte != nil {
|
||
onFirstByte()
|
||
}
|
||
}
|
||
|
||
totalLength := int(prelude[0])<<24 | int(prelude[1])<<16 | int(prelude[2])<<8 | int(prelude[3])
|
||
headersLength := int(prelude[4])<<24 | int(prelude[5])<<16 | int(prelude[6])<<8 | int(prelude[7])
|
||
|
||
if totalLength < 16 {
|
||
continue
|
||
}
|
||
|
||
// 读取剩余部分
|
||
remaining := totalLength - 12
|
||
msgBuf := make([]byte, remaining)
|
||
_, err = io.ReadFull(body, msgBuf)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if headersLength > len(msgBuf)-4 {
|
||
continue
|
||
}
|
||
|
||
eventType := extractEventType(msgBuf[0:headersLength])
|
||
payloadBytes := msgBuf[headersLength : len(msgBuf)-4]
|
||
if len(payloadBytes) == 0 {
|
||
continue
|
||
}
|
||
|
||
var event map[string]interface{}
|
||
if err := json.Unmarshal(payloadBytes, &event); err != nil {
|
||
continue
|
||
}
|
||
|
||
inputTokens, outputTokens = updateTokensFromEvent(event, inputTokens, outputTokens)
|
||
|
||
// 处理事件
|
||
switch eventType {
|
||
case "assistantResponseEvent":
|
||
if content, ok := event["content"].(string); ok && content != "" {
|
||
normalized := normalizeChunk(content, &lastAssistantContent)
|
||
if normalized != "" {
|
||
callback.OnText(normalized, false)
|
||
}
|
||
}
|
||
case "reasoningContentEvent":
|
||
if text, ok := event["text"].(string); ok && text != "" {
|
||
normalized := normalizeChunk(text, &lastReasoningContent)
|
||
if normalized != "" {
|
||
callback.OnText(normalized, true)
|
||
}
|
||
}
|
||
case "toolUseEvent":
|
||
currentToolUse = handleToolUseEvent(event, currentToolUse, callback)
|
||
case "meteringEvent":
|
||
if usage, ok := event["usage"].(float64); ok {
|
||
totalCredits += usage
|
||
}
|
||
case "contextUsageEvent":
|
||
if pct, ok := event["contextUsagePercentage"].(float64); ok {
|
||
if callback.OnContextUsage != nil {
|
||
callback.OnContextUsage(pct)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if callback.OnCredits != nil && totalCredits > 0 {
|
||
callback.OnCredits(totalCredits)
|
||
}
|
||
|
||
callback.OnComplete(inputTokens, outputTokens)
|
||
return nil
|
||
}
|
||
|
||
func updateTokensFromEvent(event map[string]interface{}, currentInputTokens, currentOutputTokens int) (int, int) {
|
||
candidates := []map[string]interface{}{event}
|
||
collectUsageMaps(event, &candidates)
|
||
|
||
inputTokens := currentInputTokens
|
||
outputTokens := currentOutputTokens
|
||
|
||
for _, usage := range candidates {
|
||
if usage == nil {
|
||
continue
|
||
}
|
||
|
||
if v, ok := readTokenNumber(usage,
|
||
"outputTokens", "completionTokens", "totalOutputTokens",
|
||
"output_tokens", "completion_tokens", "total_output_tokens",
|
||
); ok {
|
||
outputTokens = v
|
||
}
|
||
|
||
if v, ok := readTokenNumber(usage,
|
||
"inputTokens", "promptTokens", "totalInputTokens",
|
||
"input_tokens", "prompt_tokens", "total_input_tokens",
|
||
); ok {
|
||
inputTokens = v
|
||
continue
|
||
}
|
||
|
||
uncached, _ := readTokenNumber(usage, "uncachedInputTokens", "uncached_input_tokens")
|
||
cacheRead, _ := readTokenNumber(usage, "cacheReadInputTokens", "cache_read_input_tokens")
|
||
cacheWrite, _ := readTokenNumber(usage, "cacheWriteInputTokens", "cache_write_input_tokens", "cacheCreationInputTokens", "cache_creation_input_tokens")
|
||
if uncached+cacheRead+cacheWrite > 0 {
|
||
inputTokens = uncached + cacheRead + cacheWrite
|
||
continue
|
||
}
|
||
|
||
total, ok := readTokenNumber(usage, "totalTokens", "total_tokens")
|
||
if ok && total > 0 {
|
||
candidateOutput := outputTokens
|
||
if v, vok := readTokenNumber(usage,
|
||
"outputTokens", "completionTokens", "totalOutputTokens",
|
||
"output_tokens", "completion_tokens", "total_output_tokens",
|
||
); vok {
|
||
candidateOutput = v
|
||
}
|
||
if total-candidateOutput > 0 {
|
||
inputTokens = total - candidateOutput
|
||
}
|
||
}
|
||
}
|
||
|
||
return inputTokens, outputTokens
|
||
}
|
||
|
||
// getContextWindowSize returns the context window size, in tokens, assumed
// for model. Names containing a 4.6 or 4.7 version marker (written with a dot
// or a dash, matched case-insensitively) get 1,000,000; everything else gets
// 200,000.
func getContextWindowSize(model string) int {
	const (
		defaultWindow  = 200_000
		extendedWindow = 1_000_000
	)
	lowered := strings.ToLower(model)
	for _, marker := range []string{"4.6", "4-6", "4.7", "4-7"} {
		if strings.Contains(lowered, marker) {
			return extendedWindow
		}
	}
	return defaultWindow
}
|
||
|
||
// collectUsageMaps walks a decoded JSON value and appends to *out every
// object found under a key named "usage", "tokenUsage" or "token_usage"
// (compared case-insensitively), at any depth of nested objects and arrays.
// A matching object is also searched itself, so usage maps nested inside
// usage maps are collected too.
//
// Objects are Go maps, so the order in which sibling matches are appended
// follows map iteration order and is not fixed.
func collectUsageMaps(v interface{}, out *[]map[string]interface{}) {
	if list, isList := v.([]interface{}); isList {
		for _, item := range list {
			collectUsageMaps(item, out)
		}
		return
	}

	obj, isObj := v.(map[string]interface{})
	if !isObj {
		return
	}
	for key, child := range obj {
		switch strings.ToLower(key) {
		case "usage", "tokenusage", "token_usage":
			if usage, ok := child.(map[string]interface{}); ok {
				*out = append(*out, usage)
			}
		}
		collectUsageMaps(child, out)
	}
}
|
||
|
||
// normalizeChunk returns the part of chunk that is new relative to the chunk
// remembered in *previous, and updates *previous where appropriate. It lets
// the caller emit only fresh text when the upstream repeats or re-sends
// overlapping content.
//
//   - Empty chunk: returns "", *previous untouched.
//   - Nothing remembered yet: returns chunk and remembers it.
//   - chunk equals, or is a prefix of, the remembered text: returns "" and
//     keeps the remembered text.
//   - chunk extends the remembered text: returns only the extension and
//     remembers chunk.
//   - Otherwise: remembers chunk and returns it minus the longest prefix that
//     is also a suffix of the previously remembered text.
func normalizeChunk(chunk string, previous *string) string {
	if chunk == "" {
		return ""
	}

	last := *previous
	switch {
	case last == "":
		*previous = chunk
		return chunk
	case chunk == last, strings.HasPrefix(last, chunk):
		return ""
	case strings.HasPrefix(chunk, last):
		*previous = chunk
		return chunk[len(last):]
	}

	limit := len(last)
	if len(chunk) < limit {
		limit = len(chunk)
	}
	*previous = chunk

	// Try the longest candidate overlap first so the first hit is the maximum.
	for n := limit; n > 0; n-- {
		if strings.HasSuffix(last, chunk[:n]) {
			return chunk[n:]
		}
	}
	return chunk
}
|
||
|
||
// readTokenNumber returns the integer stored in m under the first of keys
// that is present and convertible, and reports whether one was found.
// A key whose value cannot be converted is skipped and the next key is tried.
func readTokenNumber(m map[string]interface{}, keys ...string) (int, bool) {
	for _, key := range keys {
		raw, present := m[key]
		if !present {
			continue
		}
		if n, ok := tokenValueToInt(raw); ok {
			return n, true
		}
	}
	return 0, false
}

// tokenValueToInt converts a decoded JSON value to an int. It accepts
// float64, int, int64, json.Number holding an integer, and strings holding an
// integer or a float; floats are truncated toward zero.
func tokenValueToInt(v interface{}) (int, bool) {
	switch n := v.(type) {
	case float64:
		return int(n), true
	case int:
		return n, true
	case int64:
		return int(n), true
	case json.Number:
		parsed, err := n.Int64()
		return int(parsed), err == nil
	case string:
		if parsed, err := strconv.Atoi(n); err == nil {
			return parsed, true
		}
		parsed, err := strconv.ParseFloat(n, 64)
		return int(parsed), err == nil
	}
	return 0, false
}
|
||
|
||
// ==================== Tool Use handling ====================
|
||
|
||
// toolUseState accumulates one tool invocation while its events are still
// arriving: the ID and name from the event that opened it, and the input
// received so far as raw JSON text.
type toolUseState struct {
	ToolUseID   string
	Name        string
	InputBuffer strings.Builder // must not be copied once written to
}
|
||
|
||
func handleToolUseEvent(event map[string]interface{}, current *toolUseState, callback *KiroStreamCallback) *toolUseState {
|
||
toolUseID, _ := event["toolUseId"].(string)
|
||
name, _ := event["name"].(string)
|
||
isStop, _ := event["stop"].(bool)
|
||
|
||
if toolUseID != "" && name != "" {
|
||
if current == nil {
|
||
current = &toolUseState{ToolUseID: toolUseID, Name: name}
|
||
} else if current.ToolUseID != toolUseID {
|
||
finishToolUse(current, callback)
|
||
current = &toolUseState{ToolUseID: toolUseID, Name: name}
|
||
}
|
||
}
|
||
|
||
if current != nil {
|
||
if input, ok := event["input"].(string); ok {
|
||
current.InputBuffer.WriteString(input)
|
||
} else if inputObj, ok := event["input"].(map[string]interface{}); ok {
|
||
data, _ := json.Marshal(inputObj)
|
||
current.InputBuffer.Reset()
|
||
current.InputBuffer.Write(data)
|
||
}
|
||
}
|
||
|
||
if isStop && current != nil {
|
||
finishToolUse(current, callback)
|
||
return nil
|
||
}
|
||
|
||
return current
|
||
}
|
||
|
||
func finishToolUse(state *toolUseState, callback *KiroStreamCallback) {
|
||
var input map[string]interface{}
|
||
if state.InputBuffer.Len() > 0 {
|
||
json.Unmarshal([]byte(state.InputBuffer.String()), &input)
|
||
}
|
||
if input == nil {
|
||
input = make(map[string]interface{})
|
||
}
|
||
callback.OnToolUse(KiroToolUse{
|
||
ToolUseID: state.ToolUseID,
|
||
Name: state.Name,
|
||
Input: input,
|
||
})
|
||
}
|
||
|
||
// extractEventType scans the header section of an event-stream frame and
// returns the value of the ":event-type" string header, or "" if the header
// is absent or the section is truncated or malformed.
//
// Each header is laid out as: name length (1 byte), name, value type
// (1 byte), value. The value's size depends on its type; the type numbers
// and sizes below follow the AWS event-stream encoding.
func extractEventType(headers []byte) string {
	offset := 0
	for offset < len(headers) {
		nameLen := int(headers[offset])
		offset++
		if offset+nameLen > len(headers) {
			return ""
		}
		name := string(headers[offset : offset+nameLen])
		offset += nameLen
		if offset >= len(headers) {
			return ""
		}
		valueType := headers[offset]
		offset++

		switch valueType {
		case 0, 1: // bool true / bool false: no value bytes
		case 2: // byte
			offset++
		case 3: // int16
			offset += 2
		case 4: // int32
			offset += 4
		case 5, 8: // int64, timestamp
			offset += 8
		case 9: // UUID
			offset += 16
		case 6, 7: // byte array, string: 2-byte big-endian length, then data
			if offset+2 > len(headers) {
				return ""
			}
			valueLen := int(headers[offset])<<8 | int(headers[offset+1])
			offset += 2
			if valueType == 7 {
				if offset+valueLen > len(headers) {
					return ""
				}
				if name == ":event-type" {
					return string(headers[offset : offset+valueLen])
				}
			}
			// A byte array running past the end simply ends the loop.
			offset += valueLen
		default:
			// Unknown type: its size is unknown, so scanning cannot continue.
			return ""
		}
	}
	return ""
}
|