// Package proxy Kiro API 代理核心 // 负责调用 Kiro API 并解析 AWS Event Stream 响应 package proxy import ( "bytes" "context" cryptoRand "crypto/rand" "encoding/json" "fmt" "io" "kiro-go/config" "log" "net/http" "net/url" "strconv" "strings" "sync/atomic" "time" "github.com/google/uuid" ) // 双端点配置(429 时自动 fallback) type kiroEndpoint struct { URL string Origin string AmzTarget string Name string } var kiroEndpoints = []kiroEndpoint{ { URL: "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse", Origin: "AI_EDITOR", AmzTarget: "AmazonCodeWhispererStreamingService.GenerateAssistantResponse", Name: "CodeWhisperer", }, { URL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse", Origin: "AI_EDITOR", AmzTarget: "AmazonQDeveloperStreamingService.SendMessage", Name: "AmazonQ", }, } // 全局 HTTP 客户端,支持运行时更换(代理重配置) var kiroHttpStore atomic.Pointer[http.Client] func init() { InitKiroHttpClient("") } // buildKiroTransport 构建带可选代理的 Transport func buildKiroTransport(proxyURL string) *http.Transport { t := &http.Transport{ MaxIdleConns: 100, MaxIdleConnsPerHost: 20, IdleConnTimeout: 90 * time.Second, DisableCompression: false, ForceAttemptHTTP2: true, } if proxyURL != "" { if u, err := url.Parse(proxyURL); err == nil { t.Proxy = http.ProxyURL(u) // 代理不支持 HTTP/2 协议升级 t.ForceAttemptHTTP2 = false } } return t } // InitKiroHttpClient 初始化(或重新初始化)Kiro API 的 HTTP 客户端 func InitKiroHttpClient(proxyURL string) { client := &http.Client{ Timeout: 5 * time.Minute, Transport: buildKiroTransport(proxyURL), } kiroHttpStore.Store(client) } // ==================== 请求结构 ==================== // KiroPayload Kiro API 请求体 type KiroPayload struct { ConversationState struct { ChatTriggerType string `json:"chatTriggerType"` ConversationID string `json:"conversationId"` CurrentMessage struct { UserInputMessage KiroUserInputMessage `json:"userInputMessage"` } `json:"currentMessage"` History []KiroHistoryMessage `json:"history,omitempty"` } `json:"conversationState"` ProfileArn string `json:"profileArn,omitempty"` InferenceConfig *InferenceConfig `json:"inferenceConfig,omitempty"` } type KiroUserInputMessage struct { Content string `json:"content"` ModelID string `json:"modelId,omitempty"` Origin string `json:"origin"` Images []KiroImage `json:"images,omitempty"` UserInputMessageContext *UserInputMessageContext `json:"userInputMessageContext,omitempty"` } type UserInputMessageContext struct { Tools []KiroToolWrapper `json:"tools,omitempty"` ToolResults []KiroToolResult `json:"toolResults,omitempty"` } type KiroToolWrapper struct { ToolSpecification struct { Name string `json:"name"` Description string `json:"description"` InputSchema InputSchema `json:"inputSchema"` } `json:"toolSpecification"` } type InputSchema struct { JSON interface{} `json:"json"` } type KiroToolResult struct { ToolUseID string `json:"toolUseId"` Content []KiroResultContent `json:"content"` Status string `json:"status"` } type KiroResultContent struct { Text string `json:"text"` } type KiroImage struct { Format string `json:"format"` Source struct { Bytes string `json:"bytes"` } `json:"source"` } type KiroHistoryMessage struct { UserInputMessage *KiroUserInputMessage `json:"userInputMessage,omitempty"` AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"` } type KiroAssistantResponseMessage struct { Content string `json:"content"` ToolUses []KiroToolUse `json:"toolUses,omitempty"` } type KiroToolUse struct { ToolUseID string `json:"toolUseId"` Name string `json:"name"` Input map[string]interface{} `json:"input"` } type InferenceConfig struct { MaxTokens int `json:"maxTokens,omitempty"` Temperature float64 `json:"temperature,omitempty"` TopP float64 `json:"topP,omitempty"` } // ==================== Stream Callbacks ==================== // KiroStreamCallback stream response callbacks type KiroStreamCallback struct { OnText func(text string, isThinking bool) OnToolUse func(toolUse KiroToolUse) OnComplete func(inputTokens, outputTokens int) OnError func(err error) OnCredits func(credits float64) OnContextUsage func(percentage float64) } // ==================== API 调用 ==================== // getSortedEndpoints 根据首选端点配置排序端点列表 func getSortedEndpoints(preferred string) []kiroEndpoint { if preferred == "amazonq" { return []kiroEndpoint{kiroEndpoints[1], kiroEndpoints[0]} } if preferred == "codewhisperer" { return []kiroEndpoint{kiroEndpoints[0], kiroEndpoints[1]} } // "auto" 或空值:默认顺序 return []kiroEndpoint{kiroEndpoints[0], kiroEndpoints[1]} } // CallKiroAPI 调用 Kiro API(流式),双端点自动 fallback func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroStreamCallback) error { if _, err := json.Marshal(payload); err != nil { return err } // 根据配置排序端点 endpoints := getSortedEndpoints(config.GetPreferredEndpoint()) invalidModelRetries := config.GetInvalidModelRetries() firstByteTimeoutSec := config.GetFirstByteTimeoutSec() firstByteRetries := config.GetFirstByteRetries() modelID := payload.ConversationState.CurrentMessage.UserInputMessage.ModelID accountLabel := account.Email if accountLabel == "" { accountLabel = account.ID } reqID := shortReqID() epNames := make([]string, 0, len(endpoints)) for _, ep := range endpoints { epNames = append(epNames, shortEndpoint(ep.Name)) } log.Printf("[KiroAPI] REQ %s model=%s account=%s endpoints=%s", reqID, shortModel(modelID), accountLabel, strings.Join(epNames, ",")) requestStart := time.Now() var lastErr error var lastStatus string // 用于 FAIL 行总结 for _, ep := range endpoints { payload.ConversationState.CurrentMessage.UserInputMessage.Origin = ep.Origin epShort := shortEndpoint(ep.Name) maxAttempts := invalidModelRetries + 1 if firstByteRetries+1 > maxAttempts { maxAttempts = firstByteRetries + 1 } invalidModelUsed := 0 firstByteUsed := 0 shouldFallback := false for attempt := 1; attempt <= maxAttempts; attempt++ { reqBody, _ := json.Marshal(payload) ctx, cancel := context.WithCancel(context.Background()) req, err := http.NewRequestWithContext(ctx, "POST", ep.URL, bytes.NewReader(reqBody)) if err != nil { cancel() lastErr = err lastStatus = "ERR" log.Printf("[KiroAPI] ERR %s %s/a%d new_request %v", reqID, epShort, attempt, err) shouldFallback = true break } host := "" if parsedURL, parseErr := url.Parse(ep.URL); parseErr == nil { host = parsedURL.Host } headerValues := buildStreamingHeaderValues(account, host) req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "*/*") req.Header.Set("X-Amz-Target", ep.AmzTarget) applyKiroBaseHeaders(req, account, headerValues) req.Header.Set("x-amzn-kiro-agent-mode", "vibe") req.Header.Set("x-amzn-codewhisperer-optout", "true") req.Header.Set("Amz-Sdk-Request", fmt.Sprintf("attempt=%d; max=%d", attempt, maxAttempts)) req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) attemptStart := time.Now() resp, err := kiroHttpStore.Load().Do(req) if err != nil { cancel() lastErr = err lastStatus = "ERR" log.Printf("[KiroAPI] ERR %s %s/a%d transport %s %v", reqID, epShort, attempt, fmtMs(time.Since(attemptStart)), err) shouldFallback = true break } if resp.StatusCode == 429 { resp.Body.Close() cancel() lastErr = fmt.Errorf("quota exhausted on %s", ep.Name) lastStatus = "429" log.Printf("[KiroAPI] 429 %s %s/a%d quota_exhausted %s", reqID, epShort, attempt, fmtMs(time.Since(attemptStart))) shouldFallback = true break } if resp.StatusCode != 200 { errBody, _ := io.ReadAll(resp.Body) resp.Body.Close() cancel() lastErr = fmt.Errorf("HTTP %d from %s: %s", resp.StatusCode, ep.Name, string(errBody)) lastStatus = fmt.Sprintf("%d", resp.StatusCode) bodyStr := string(errBody) // 记录非 200 / 非 429 的请求体和响应体以便排查(本地滚动日志,上限 10MB) if resp.StatusCode != 429 { logKiroError(reqID, ep.Name, resp.StatusCode, accountLabel, modelID, reqBody, errBody) } if resp.StatusCode == 401 || resp.StatusCode == 403 { log.Printf("[KiroAPI] %d %s %s/a%d auth_error %s %s", resp.StatusCode, reqID, epShort, attempt, fmtMs(time.Since(attemptStart)), truncateForLog(bodyStr, 200)) log.Printf("[KiroAPI] FAIL %s all endpoints failed %s last=%s", reqID, fmtMs(time.Since(requestStart)), lastStatus) return lastErr } if resp.StatusCode == 400 && strings.Contains(bodyStr, "INVALID_MODEL_ID") { if invalidModelUsed < invalidModelRetries { invalidModelUsed++ log.Printf("[KiroAPI] 400 %s %s/a%d INVALID_MODEL_ID %s retry %d/%d", reqID, epShort, attempt, fmtMs(time.Since(attemptStart)), invalidModelUsed, invalidModelRetries) continue } log.Printf("[KiroAPI] 400 %s %s/a%d INVALID_MODEL_ID %s exhausted → fallback", reqID, epShort, attempt, fmtMs(time.Since(attemptStart))) shouldFallback = true break } log.Printf("[KiroAPI] %d %s %s/a%d %s %s", resp.StatusCode, reqID, epShort, attempt, fmtMs(time.Since(attemptStart)), truncateForLog(bodyStr, 200)) shouldFallback = true break } // 首字节超时 var firstByteReceived atomic.Bool var firstByteTimedOut atomic.Bool var firstByteAt time.Duration var timer *time.Timer if firstByteTimeoutSec > 0 { timer = time.AfterFunc(time.Duration(firstByteTimeoutSec)*time.Second, func() { if !firstByteReceived.Load() { firstByteTimedOut.Store(true) cancel() } }) } onFirstByte := func() { firstByteReceived.Store(true) firstByteAt = time.Since(attemptStart) if timer != nil { timer.Stop() } } err = parseEventStream(resp.Body, callback, onFirstByte) resp.Body.Close() if timer != nil { timer.Stop() } cancel() if err != nil && firstByteTimedOut.Load() && !firstByteReceived.Load() { lastStatus = "TIMEOUT" if firstByteUsed < firstByteRetries { firstByteUsed++ lastErr = fmt.Errorf("first-byte timeout after %ds", firstByteTimeoutSec) log.Printf("[KiroAPI] TIMEOUT %s %s/a%d first_byte>%ds retry %d/%d", reqID, epShort, attempt, firstByteTimeoutSec, firstByteUsed, firstByteRetries) continue } lastErr = fmt.Errorf("first-byte timeout after %ds on %s", firstByteTimeoutSec, ep.Name) log.Printf("[KiroAPI] TIMEOUT %s %s/a%d first_byte>%ds exhausted → fallback", reqID, epShort, attempt, firstByteTimeoutSec) shouldFallback = true break } status := "200" if err != nil { status = "ERR" } log.Printf("[KiroAPI] %s %s %s/a%d first_byte=%s total=%s", status, reqID, epShort, attempt, fmtMs(firstByteAt), fmtMs(time.Since(requestStart))) return err } if !shouldFallback { break } } log.Printf("[KiroAPI] FAIL %s all endpoints failed %s last=%s", reqID, fmtMs(time.Since(requestStart)), lastStatus) if lastErr != nil { return lastErr } return fmt.Errorf("all endpoints failed") } // shortReqID 生成 6 字符请求标识(base36) func shortReqID() string { var buf [3]byte if _, err := cryptoRand.Read(buf[:]); err != nil { return fmt.Sprintf("%06d", time.Now().UnixNano()%1000000) } return fmt.Sprintf("%02x%02x%02x", buf[0], buf[1], buf[2]) } // shortEndpoint 把端点名缩短到 2 字符便于视觉对齐 func shortEndpoint(name string) string { switch name { case "CodeWhisperer": return "CW" case "AmazonQ": return "Q " default: if len(name) >= 2 { return name[:2] } return name } } // shortModel 把长模型名截短:claude-opus-4.7 → opus-4.7 func shortModel(m string) string { if strings.HasPrefix(m, "claude-") { return m[len("claude-"):] } if m == "" { return "-" } return m } // fmtMs 把耗时格式化成紧凑字符串:<1s 用 ms,>=1s 用 1 位小数 s func fmtMs(d time.Duration) string { if d <= 0 { return "0ms" } if d < time.Second { return fmt.Sprintf("%dms", d.Milliseconds()) } return fmt.Sprintf("%.1fs", d.Seconds()) } func truncateForLog(s string, max int) string { s = strings.ReplaceAll(s, "\n", " ") if len(s) <= max { return s } return s[:max] + "...(truncated)" } // ==================== Event Stream 解析 ==================== // parseEventStream 解析 AWS Event Stream 二进制格式 // onFirstByte 会在读完第一个完整 event-stream 包 prelude 时触发一次(只一次), // 供外层判断「首字节是否已收到」,以决定首字节超时时是否应该重试。 func parseEventStream(body io.Reader, callback *KiroStreamCallback, onFirstByte func()) error { // 不使用 bufio,直接读取避免缓冲延迟 var inputTokens, outputTokens int var totalCredits float64 var currentToolUse *toolUseState var lastAssistantContent string var lastReasoningContent string firstByteFired := false for { // Prelude: 12 bytes (total_len + headers_len + crc) prelude := make([]byte, 12) _, err := io.ReadFull(body, prelude) if err == io.EOF { break } if err != nil { return err } if !firstByteFired { firstByteFired = true if onFirstByte != nil { onFirstByte() } } totalLength := int(prelude[0])<<24 | int(prelude[1])<<16 | int(prelude[2])<<8 | int(prelude[3]) headersLength := int(prelude[4])<<24 | int(prelude[5])<<16 | int(prelude[6])<<8 | int(prelude[7]) if totalLength < 16 { continue } // 读取剩余部分 remaining := totalLength - 12 msgBuf := make([]byte, remaining) _, err = io.ReadFull(body, msgBuf) if err != nil { return err } if headersLength > len(msgBuf)-4 { continue } eventType := extractEventType(msgBuf[0:headersLength]) payloadBytes := msgBuf[headersLength : len(msgBuf)-4] if len(payloadBytes) == 0 { continue } var event map[string]interface{} if err := json.Unmarshal(payloadBytes, &event); err != nil { continue } inputTokens, outputTokens = updateTokensFromEvent(event, inputTokens, outputTokens) // 处理事件 switch eventType { case "assistantResponseEvent": if content, ok := event["content"].(string); ok && content != "" { normalized := normalizeChunk(content, &lastAssistantContent) if normalized != "" { callback.OnText(normalized, false) } } case "reasoningContentEvent": if text, ok := event["text"].(string); ok && text != "" { normalized := normalizeChunk(text, &lastReasoningContent) if normalized != "" { callback.OnText(normalized, true) } } case "toolUseEvent": currentToolUse = handleToolUseEvent(event, currentToolUse, callback) case "meteringEvent": if usage, ok := event["usage"].(float64); ok { totalCredits += usage } case "contextUsageEvent": if pct, ok := event["contextUsagePercentage"].(float64); ok { if callback.OnContextUsage != nil { callback.OnContextUsage(pct) } } } } if callback.OnCredits != nil && totalCredits > 0 { callback.OnCredits(totalCredits) } callback.OnComplete(inputTokens, outputTokens) return nil } func updateTokensFromEvent(event map[string]interface{}, currentInputTokens, currentOutputTokens int) (int, int) { candidates := []map[string]interface{}{event} collectUsageMaps(event, &candidates) inputTokens := currentInputTokens outputTokens := currentOutputTokens for _, usage := range candidates { if usage == nil { continue } if v, ok := readTokenNumber(usage, "outputTokens", "completionTokens", "totalOutputTokens", "output_tokens", "completion_tokens", "total_output_tokens", ); ok { outputTokens = v } if v, ok := readTokenNumber(usage, "inputTokens", "promptTokens", "totalInputTokens", "input_tokens", "prompt_tokens", "total_input_tokens", ); ok { inputTokens = v continue } uncached, _ := readTokenNumber(usage, "uncachedInputTokens", "uncached_input_tokens") cacheRead, _ := readTokenNumber(usage, "cacheReadInputTokens", "cache_read_input_tokens") cacheWrite, _ := readTokenNumber(usage, "cacheWriteInputTokens", "cache_write_input_tokens", "cacheCreationInputTokens", "cache_creation_input_tokens") if uncached+cacheRead+cacheWrite > 0 { inputTokens = uncached + cacheRead + cacheWrite continue } total, ok := readTokenNumber(usage, "totalTokens", "total_tokens") if ok && total > 0 { candidateOutput := outputTokens if v, vok := readTokenNumber(usage, "outputTokens", "completionTokens", "totalOutputTokens", "output_tokens", "completion_tokens", "total_output_tokens", ); vok { candidateOutput = v } if total-candidateOutput > 0 { inputTokens = total - candidateOutput } } } return inputTokens, outputTokens } // getContextWindowSize returns the context window size (in tokens) for a model. func getContextWindowSize(model string) int { m := strings.ToLower(model) // sonnet-4.6, opus-4.6, opus-4.7 all have 1M context windows if strings.Contains(m, "4.6") || strings.Contains(m, "4-6") || strings.Contains(m, "4.7") || strings.Contains(m, "4-7") { return 1_000_000 } return 200_000 } func collectUsageMaps(v interface{}, out *[]map[string]interface{}) { switch t := v.(type) { case map[string]interface{}: for k, child := range t { lk := strings.ToLower(k) if lk == "usage" || lk == "tokenusage" || lk == "token_usage" { if m, ok := child.(map[string]interface{}); ok { *out = append(*out, m) } } collectUsageMaps(child, out) } case []interface{}: for _, child := range t { collectUsageMaps(child, out) } } } func normalizeChunk(chunk string, previous *string) string { if chunk == "" { return "" } prev := *previous if prev == "" { *previous = chunk return chunk } if chunk == prev { return "" } if strings.HasPrefix(chunk, prev) { delta := chunk[len(prev):] *previous = chunk return delta } if strings.HasPrefix(prev, chunk) { return "" } maxOverlap := 0 maxLen := len(prev) if len(chunk) < maxLen { maxLen = len(chunk) } for i := maxLen; i > 0; i-- { if strings.HasSuffix(prev, chunk[:i]) { maxOverlap = i break } } *previous = chunk if maxOverlap > 0 { return chunk[maxOverlap:] } return chunk } func readTokenNumber(m map[string]interface{}, keys ...string) (int, bool) { for _, k := range keys { v, ok := m[k] if !ok { continue } switch n := v.(type) { case float64: return int(n), true case int: return n, true case int64: return int(n), true case json.Number: if parsed, err := n.Int64(); err == nil { return int(parsed), true } case string: if parsed, err := strconv.Atoi(n); err == nil { return parsed, true } if parsed, err := strconv.ParseFloat(n, 64); err == nil { return int(parsed), true } } } return 0, false } // ==================== Tool Use 处理 ==================== type toolUseState struct { ToolUseID string Name string InputBuffer strings.Builder } func handleToolUseEvent(event map[string]interface{}, current *toolUseState, callback *KiroStreamCallback) *toolUseState { toolUseID, _ := event["toolUseId"].(string) name, _ := event["name"].(string) isStop, _ := event["stop"].(bool) if toolUseID != "" && name != "" { if current == nil { current = &toolUseState{ToolUseID: toolUseID, Name: name} } else if current.ToolUseID != toolUseID { finishToolUse(current, callback) current = &toolUseState{ToolUseID: toolUseID, Name: name} } } if current != nil { if input, ok := event["input"].(string); ok { current.InputBuffer.WriteString(input) } else if inputObj, ok := event["input"].(map[string]interface{}); ok { data, _ := json.Marshal(inputObj) current.InputBuffer.Reset() current.InputBuffer.Write(data) } } if isStop && current != nil { finishToolUse(current, callback) return nil } return current } func finishToolUse(state *toolUseState, callback *KiroStreamCallback) { var input map[string]interface{} if state.InputBuffer.Len() > 0 { json.Unmarshal([]byte(state.InputBuffer.String()), &input) } if input == nil { input = make(map[string]interface{}) } callback.OnToolUse(KiroToolUse{ ToolUseID: state.ToolUseID, Name: state.Name, Input: input, }) } // extractEventType 从 headers 中提取事件类型 func extractEventType(headers []byte) string { offset := 0 for offset < len(headers) { if offset >= len(headers) { break } nameLen := int(headers[offset]) offset++ if offset+nameLen > len(headers) { break } name := string(headers[offset : offset+nameLen]) offset += nameLen if offset >= len(headers) { break } valueType := headers[offset] offset++ if valueType == 7 { // String if offset+2 > len(headers) { break } valueLen := int(headers[offset])<<8 | int(headers[offset+1]) offset += 2 if offset+valueLen > len(headers) { break } value := string(headers[offset : offset+valueLen]) offset += valueLen if name == ":event-type" { return value } continue } // 跳过其他类型 skipSizes := map[byte]int{0: 0, 1: 0, 2: 1, 3: 2, 4: 4, 5: 8, 8: 8, 9: 16} if valueType == 6 { if offset+2 > len(headers) { break } l := int(headers[offset])<<8 | int(headers[offset+1]) offset += 2 + l } else if skip, ok := skipSizes[valueType]; ok { offset += skip } else { break } } return "" }