Files
kirogo/proxy/kiro.go

574 lines
15 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Package proxy Kiro API 代理核心
// 负责调用 Kiro API 并解析 AWS Event Stream 响应
package proxy
import (
"bytes"
"encoding/json"
"fmt"
"io"
"kiro-go/config"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/google/uuid"
)
// kiroEndpoint describes one upstream Kiro streaming endpoint. Two endpoints
// are configured so that CallKiroAPI can fall back to the other one (for
// example when the first answers 429).
type kiroEndpoint struct {
	// URL is the POST target. Origin is written into the request body's
	// userInputMessage.origin. AmzTarget is sent as the X-Amz-Target header.
	// Name only appears in log and error messages.
	URL, Origin, AmzTarget, Name string
}
// kiroEndpoints lists the two upstream Kiro endpoints in their default order:
// index 0 is CodeWhisperer and index 1 is Amazon Q. getSortedEndpoints picks
// entries by position, so this order must be preserved.
var kiroEndpoints = []kiroEndpoint{
	{
		URL:       "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse",
		Origin:    "AI_EDITOR",
		AmzTarget: "AmazonCodeWhispererStreamingService.GenerateAssistantResponse",
		Name:      "CodeWhisperer",
	},
	{
		URL:       "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
		Origin:    "CLI",
		AmzTarget: "AmazonQDeveloperStreamingService.SendMessage",
		Name:      "AmazonQ",
	},
}
// kiroHttpClient is the shared HTTP client for all Kiro API calls. It is shared
// so that connections are pooled and reused across requests.
//
// Client.Timeout bounds the whole exchange, including reading the streamed
// response body, so a single response stream cannot last longer than 5 minutes.
//
// A hand-built http.Transport does not inherit http.DefaultTransport's
// settings. Proxy-from-environment support and a TLS handshake timeout are
// therefore set explicitly here; without them HTTP(S)_PROXY would be silently
// ignored and a stalled handshake would only be caught by the 5-minute limit.
var kiroHttpClient = &http.Client{
	Timeout: 5 * time.Minute,
	Transport: &http.Transport{
		Proxy:               http.ProxyFromEnvironment, // honor HTTP_PROXY / HTTPS_PROXY / NO_PROXY
		MaxIdleConns:        100,                       // maximum idle connections overall
		MaxIdleConnsPerHost: 20,                        // maximum idle connections per host
		IdleConnTimeout:     90 * time.Second,          // how long an idle connection is kept
		TLSHandshakeTimeout: 10 * time.Second,          // fail fast on a stalled TLS handshake
		DisableCompression:  false,                     // allow transparent gzip
		ForceAttemptHTTP2:   true,                      // try HTTP/2 despite the custom transport
	},
}
// ==================== Request types ====================

// KiroPayload is the JSON request body posted to the Kiro
// generateAssistantResponse endpoints.
type KiroPayload struct {
	ConversationState struct {
		ChatTriggerType string `json:"chatTriggerType"`
		ConversationID  string `json:"conversationId"`
		// CurrentMessage is the user turn being answered. CallKiroAPI
		// overwrites its Origin for each endpoint it tries.
		CurrentMessage struct {
			UserInputMessage KiroUserInputMessage `json:"userInputMessage"`
		} `json:"currentMessage"`
		// History holds the earlier turns; it is omitted from the JSON when empty.
		History []KiroHistoryMessage `json:"history,omitempty"`
	} `json:"conversationState"`
	ProfileArn      string           `json:"profileArn,omitempty"`
	InferenceConfig *InferenceConfig `json:"inferenceConfig,omitempty"`
}
// KiroUserInputMessage is a user turn: the message text plus an optional model
// ID, inline images and tool context.
//
// Origin has no omitempty, so it is always serialized. For the current message,
// CallKiroAPI sets it to the origin of the endpoint being tried.
type KiroUserInputMessage struct {
	Content                 string                   `json:"content"`
	ModelID                 string                   `json:"modelId,omitempty"`
	Origin                  string                   `json:"origin"`
	Images                  []KiroImage              `json:"images,omitempty"`
	UserInputMessageContext *UserInputMessageContext `json:"userInputMessageContext,omitempty"`
}
// UserInputMessageContext attaches tool information to a user turn: the tool
// definitions offered to the model and results of earlier tool calls. Both
// lists are omitted from the JSON when empty.
type UserInputMessageContext struct {
	Tools       []KiroToolWrapper `json:"tools,omitempty"`
	ToolResults []KiroToolResult  `json:"toolResults,omitempty"`
}
// KiroToolWrapper wraps a single tool definition (name, description and input
// schema) in the "toolSpecification" envelope.
type KiroToolWrapper struct {
	ToolSpecification struct {
		Name        string      `json:"name"`
		Description string      `json:"description"`
		InputSchema InputSchema `json:"inputSchema"`
	} `json:"toolSpecification"`
}
// InputSchema holds a tool's parameter schema under the "json" key. The value
// is serialized as-is; presumably it is a JSON Schema object (confirm with
// the callers that build it).
type InputSchema struct {
	JSON interface{} `json:"json"`
}
// KiroToolResult reports the outcome of the tool call identified by ToolUseID.
// NOTE(review): the allowed Status values are not defined in this file.
type KiroToolResult struct {
	ToolUseID string              `json:"toolUseId"`
	Content   []KiroResultContent `json:"content"`
	Status    string              `json:"status"`
}
// KiroResultContent is one text item inside a tool result.
type KiroResultContent struct {
	Text string `json:"text"`
}
// KiroImage is an inline image attached to a user turn.
// NOTE(review): Source.Bytes is a string, presumably base64-encoded image data,
// and Format presumably an image format name such as "png". Confirm against the
// code that builds these.
type KiroImage struct {
	Format string `json:"format"`
	Source struct {
		Bytes string `json:"bytes"`
	} `json:"source"`
}
// KiroHistoryMessage is one entry of the conversation history. Presumably
// exactly one of the two pointers is set per entry (TODO confirm); an unset
// pointer is omitted from the JSON.
type KiroHistoryMessage struct {
	UserInputMessage         *KiroUserInputMessage         `json:"userInputMessage,omitempty"`
	AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"`
}
// KiroAssistantResponseMessage is an assistant turn in the history: its text
// and any tool calls it made (omitted from the JSON when there are none).
type KiroAssistantResponseMessage struct {
	Content  string        `json:"content"`
	ToolUses []KiroToolUse `json:"toolUses,omitempty"`
}
// KiroToolUse is a single tool call: its ID, the tool name and the decoded JSON
// arguments. It is used in history messages and is also the value handed to
// KiroStreamCallback.OnToolUse.
type KiroToolUse struct {
	ToolUseID string                 `json:"toolUseId"`
	Name      string                 `json:"name"`
	Input     map[string]interface{} `json:"input"`
}
// InferenceConfig holds optional sampling parameters for the request.
//
// NOTE(review): because every field is omitempty, a zero value is dropped from
// the JSON. An explicit Temperature or TopP of 0 therefore cannot be sent, and
// presumably the upstream default applies instead. Fixing this needs pointer
// fields, which would change the type for callers.
type InferenceConfig struct {
	MaxTokens   int     `json:"maxTokens,omitempty"`
	Temperature float64 `json:"temperature,omitempty"`
	TopP        float64 `json:"topP,omitempty"`
}
// ==================== Streaming callbacks ====================

// KiroStreamCallback receives the decoded content of one streamed Kiro
// response from parseEventStream.
type KiroStreamCallback struct {
	// OnText receives each new piece of text. isThinking is true for text from
	// "reasoningContentEvent" and false for "assistantResponseEvent".
	OnText func(text string, isThinking bool)
	// OnToolUse receives each completed tool call. finishToolUse calls it
	// without a nil check, so it must be set.
	OnToolUse func(toolUse KiroToolUse)
	// OnComplete is called once, with the token counts gathered from the
	// stream, after the stream ends cleanly.
	OnComplete func(inputTokens, outputTokens int)
	// OnError is not invoked by the code in this file.
	OnError func(err error)
	// OnCredits, when non-nil, receives the summed "usage" of all
	// "meteringEvent"s if that sum is positive, just before OnComplete.
	OnCredits func(credits float64)
}
// ==================== API calls ====================

// getSortedEndpoints returns both Kiro endpoints in the order they should be
// tried. "amazonq" puts Amazon Q first. "codewhisperer", "auto", the empty
// string and any other value keep the default CodeWhisperer-first order.
// A new slice is built on each call, so callers may modify the result.
func getSortedEndpoints(preferred string) []kiroEndpoint {
	first, second := kiroEndpoints[0], kiroEndpoints[1]
	if preferred == "amazonq" {
		// Amazon Q first, CodeWhisperer as the fallback.
		first, second = second, first
	}
	return []kiroEndpoint{first, second}
}
// CallKiroAPI sends payload to the Kiro streaming API on behalf of account and
// decodes the AWS Event Stream response into callback.
//
// Endpoints are tried in the order getSortedEndpoints returns for the
// configured preference:
//   - A transport error, a 429, or any other non-200 status moves on to the
//     next endpoint.
//   - 401 and 403 are returned immediately without trying the remaining
//     endpoints.
//   - The first 200 response is parsed to completion, and its result is
//     returned with no further fallback.
//   - If every endpoint fails, the last error is returned.
//
// Side effect: the current message's Origin in payload is overwritten for
// each endpoint tried. After the call it holds the Origin of the last attempt.
func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroStreamCallback) error {
	endpoints := getSortedEndpoints(config.GetPreferredEndpoint())

	var lastErr error
	for _, ep := range endpoints {
		// Each endpoint expects its own origin value in the body.
		payload.ConversationState.CurrentMessage.UserInputMessage.Origin = ep.Origin

		// Marshal once per attempt; the body can be large (history, inline
		// images). Only Origin differs between attempts, so a marshal failure
		// cannot be cured by another endpoint and is returned directly.
		reqBody, err := json.Marshal(payload)
		if err != nil {
			return err
		}

		req, err := http.NewRequest(http.MethodPost, ep.URL, bytes.NewReader(reqBody))
		if err != nil {
			lastErr = err
			continue
		}

		host := ""
		if parsedURL, parseErr := url.Parse(ep.URL); parseErr == nil {
			host = parsedURL.Host
		}
		headerValues := buildStreamingHeaderValues(account, host)
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept", "*/*")
		req.Header.Set("X-Amz-Target", ep.AmzTarget)
		applyKiroBaseHeaders(req, account, headerValues)
		req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
		req.Header.Set("x-amzn-codewhisperer-optout", "true")
		req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
		req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())

		resp, err := kiroHttpClient.Do(req)
		if err != nil {
			lastErr = err
			fmt.Printf("[KiroAPI] Endpoint %s failed: %v\n", ep.Name, err)
			continue
		}

		if resp.StatusCode == http.StatusTooManyRequests {
			// Drain a bounded amount before closing so the transport can reuse
			// the connection; errors here are irrelevant, we are discarding it.
			_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 64<<10))
			resp.Body.Close()
			fmt.Printf("[KiroAPI] Endpoint %s quota exhausted (429), trying next...\n", ep.Name)
			lastErr = fmt.Errorf("quota exhausted on %s", ep.Name)
			continue
		}

		if resp.StatusCode != http.StatusOK {
			// Best effort: the body is only used to enrich the error message.
			errBody, _ := io.ReadAll(resp.Body)
			resp.Body.Close()
			lastErr = fmt.Errorf("HTTP %d from %s: %s", resp.StatusCode, ep.Name, string(errBody))
			// Authentication failures would fail the same way elsewhere; stop.
			if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
				return lastErr
			}
			fmt.Printf("[KiroAPI] Endpoint %s error: %v\n", ep.Name, lastErr)
			continue
		}

		// This iteration always returns, so the deferred Close runs right after
		// parsing, and also if a callback panics mid-stream.
		defer resp.Body.Close()
		return parseEventStream(resp.Body, callback)
	}

	if lastErr != nil {
		return lastErr
	}
	return fmt.Errorf("all endpoints failed")
}
// ==================== Event Stream parsing ====================

// parseEventStream decodes an AWS Event Stream (binary framing) from body and
// dispatches each event to callback.
//
// Each message is laid out as follows:
//   - a 12-byte prelude: 4-byte big-endian total length, 4-byte big-endian
//     headers length, 4-byte prelude CRC;
//   - the headers;
//   - the payload;
//   - a trailing 4-byte message CRC.
//
// CRCs are not verified. The ":event-type" header selects how the JSON payload
// is handled. A message is skipped when its headers do not fit, its payload is
// empty, or the payload is not a JSON object.
//
// At a clean EOF on a message boundary it reports credits (when positive) and
// then the token counts, and returns nil. It returns an error in these cases:
//   - a read fails or a message is truncated;
//   - a total length is below 16 or above a 16 MiB sanity cap.
//
// Such a length means the framing is corrupt. Continuing would either lose
// message boundaries or allocate an unbounded buffer.
func parseEventStream(body io.Reader, callback *KiroStreamCallback) error {
	const (
		preludeLength    = 12
		minMessageLength = preludeLength + 4 // prelude + message CRC
		maxMessageLength = 16 << 20          // sanity cap chosen here, not a protocol constant
	)

	var (
		inputTokens, outputTokens int
		totalCredits              float64
		currentToolUse            *toolUseState
		lastAssistantContent      string
		lastReasoningContent      string
		prelude                   [preludeLength]byte // reused for every message
	)

	// Read straight from body (no bufio) so each event is forwarded as soon as
	// it arrives instead of waiting for a buffer to fill.
	for {
		_, err := io.ReadFull(body, prelude[:])
		if err == io.EOF {
			break // clean end of stream at a message boundary
		}
		if err != nil {
			return err
		}

		totalLength := int(prelude[0])<<24 | int(prelude[1])<<16 | int(prelude[2])<<8 | int(prelude[3])
		headersLength := int(prelude[4])<<24 | int(prelude[5])<<16 | int(prelude[6])<<8 | int(prelude[7])
		if totalLength < minMessageLength || totalLength > maxMessageLength {
			return fmt.Errorf("event stream: invalid message length %d", totalLength)
		}

		msgBuf := make([]byte, totalLength-preludeLength)
		if _, err := io.ReadFull(body, msgBuf); err != nil {
			return err
		}

		// The whole message has been consumed, so skipping a malformed one
		// keeps the framing in sync. headersLength can only be negative on
		// 32-bit builds, where the shifted value overflows int.
		if headersLength < 0 || headersLength > len(msgBuf)-4 {
			continue
		}
		eventType := extractEventType(msgBuf[:headersLength])
		payloadBytes := msgBuf[headersLength : len(msgBuf)-4]
		if len(payloadBytes) == 0 {
			continue
		}

		var event map[string]interface{}
		if err := json.Unmarshal(payloadBytes, &event); err != nil {
			continue
		}

		inputTokens, outputTokens = updateTokensFromEvent(event, inputTokens, outputTokens)

		switch eventType {
		case "assistantResponseEvent":
			if content, ok := event["content"].(string); ok && content != "" {
				normalized := normalizeChunk(content, &lastAssistantContent)
				if normalized != "" && callback.OnText != nil {
					callback.OnText(normalized, false)
				}
			}
		case "reasoningContentEvent":
			if text, ok := event["text"].(string); ok && text != "" {
				normalized := normalizeChunk(text, &lastReasoningContent)
				if normalized != "" && callback.OnText != nil {
					callback.OnText(normalized, true)
				}
			}
		case "toolUseEvent":
			currentToolUse = handleToolUseEvent(event, currentToolUse, callback)
		case "meteringEvent":
			if usage, ok := event["usage"].(float64); ok {
				totalCredits += usage
			}
		}
	}

	// NOTE(review): a tool call still open here (the stream ended without its
	// "stop" event) is never reported. Confirm whether it should be flushed.
	if callback.OnCredits != nil && totalCredits > 0 {
		callback.OnCredits(totalCredits)
	}
	if callback.OnComplete != nil {
		callback.OnComplete(inputTokens, outputTokens)
	}
	return nil
}
// updateTokensFromEvent returns the input/output token counts after applying
// any usage figures found in event.
//
// The event itself is inspected first. Then every nested map stored under a
// "usage", "tokenUsage" or "token_usage" key (any letter case, found
// recursively) is inspected, and later candidates override earlier ones.
// For each candidate:
//   - the output count comes from the first recognized output-count key;
//   - the input count comes from an explicit input-count key; failing that,
//     from uncached + cache-read + cache-write when that sum is positive;
//     failing that, from totalTokens minus the output count when positive.
//
// Counts not found anywhere keep their current values.
//
// NOTE(review): nested usage maps are collected by ranging over Go maps, so if
// several disagree, which one wins is not deterministic.
func updateTokensFromEvent(event map[string]interface{}, currentInputTokens, currentOutputTokens int) (int, int) {
	usages := []map[string]interface{}{event}
	collectUsageMaps(event, &usages)

	in, out := currentInputTokens, currentOutputTokens
	for _, u := range usages {
		if u == nil {
			continue
		}

		if n, found := readTokenNumber(u,
			"outputTokens", "completionTokens", "totalOutputTokens",
			"output_tokens", "completion_tokens", "total_output_tokens",
		); found {
			out = n
		}

		if n, found := readTokenNumber(u,
			"inputTokens", "promptTokens", "totalInputTokens",
			"input_tokens", "prompt_tokens", "total_input_tokens",
		); found {
			in = n
			continue
		}

		uncached, _ := readTokenNumber(u, "uncachedInputTokens", "uncached_input_tokens")
		cacheRead, _ := readTokenNumber(u, "cacheReadInputTokens", "cache_read_input_tokens")
		cacheWrite, _ := readTokenNumber(u, "cacheWriteInputTokens", "cache_write_input_tokens", "cacheCreationInputTokens", "cache_creation_input_tokens")
		if sum := uncached + cacheRead + cacheWrite; sum > 0 {
			in = sum
			continue
		}

		// `out` already holds this candidate's output count when it has one,
		// so the input share of the total is simply total - out.
		if total, found := readTokenNumber(u, "totalTokens", "total_tokens"); found && total > 0 && total-out > 0 {
			in = total - out
		}
	}
	return in, out
}
// collectUsageMaps walks a decoded JSON value and appends to *out every object
// stored under a key equal to "usage", "tokenusage" or "token_usage" after
// lower-casing. Maps and arrays are searched recursively, including the usage
// objects themselves; any other value is ignored.
func collectUsageMaps(v interface{}, out *[]map[string]interface{}) {
	if arr, isArr := v.([]interface{}); isArr {
		for i := range arr {
			collectUsageMaps(arr[i], out)
		}
		return
	}

	obj, isObj := v.(map[string]interface{})
	if !isObj {
		return
	}
	for key, val := range obj {
		switch strings.ToLower(key) {
		case "usage", "tokenusage", "token_usage":
			if usage, ok := val.(map[string]interface{}); ok {
				*out = append(*out, usage)
			}
		}
		collectUsageMaps(val, out)
	}
}
// normalizeChunk returns the part of chunk that has not already been emitted.
// *previous holds the last chunk accepted for this stream. The rules, in order:
//   - an empty chunk yields "" (previous is not read);
//   - with no previous chunk, chunk is returned whole;
//   - a chunk equal to, or a prefix of, the previous one yields "" and leaves
//     *previous unchanged;
//   - a chunk that extends the previous one yields only the appended suffix;
//   - otherwise the longest prefix of chunk that is also a suffix of the
//     previous chunk is dropped, and the rest is returned.
//
// Whenever text is returned, *previous is set to chunk.
//
// NOTE(review): the overlap rule treats even a 1-byte overlap as a repeat, so
// plain deltas such as "Hel" followed by "lo" yield "Hel" then "o". Confirm
// that the upstream stream really re-sends overlapping text.
func normalizeChunk(chunk string, previous *string) string {
	if chunk == "" {
		return ""
	}

	last := *previous
	switch {
	case last == "":
		*previous = chunk
		return chunk
	case strings.HasPrefix(last, chunk):
		// The same text again, or a shorter re-send of it: nothing new.
		return ""
	case strings.HasPrefix(chunk, last):
		// Cumulative snapshot: emit only what was appended.
		*previous = chunk
		return chunk[len(last):]
	}

	*previous = chunk
	longest := len(chunk)
	if len(last) < longest {
		longest = len(last)
	}
	for n := longest; n > 0; n-- {
		if strings.HasSuffix(last, chunk[:n]) {
			return chunk[n:]
		}
	}
	return chunk
}
// readTokenNumber looks up keys in m in order and returns the first value that
// converts to an int.
//
// Accepted values:
//   - float64, int and int64 (floats are truncated toward zero);
//   - json.Number holding an integer;
//   - a string that parses as an integer or as a float (the float is
//     truncated).
//
// Missing keys and unconvertible values are skipped. ok is false when no key
// yields a number.
func readTokenNumber(m map[string]interface{}, keys ...string) (int, bool) {
	asInt := func(raw interface{}) (int, bool) {
		switch n := raw.(type) {
		case float64:
			return int(n), true
		case int:
			return n, true
		case int64:
			return int(n), true
		case json.Number:
			i, err := n.Int64()
			return int(i), err == nil
		case string:
			if i, err := strconv.Atoi(n); err == nil {
				return i, true
			}
			f, err := strconv.ParseFloat(n, 64)
			return int(f), err == nil
		}
		return 0, false
	}

	for _, key := range keys {
		raw, present := m[key]
		if !present {
			continue
		}
		if n, ok := asInt(raw); ok {
			return n, true
		}
	}
	return 0, false
}
// ==================== Tool use handling ====================

// toolUseState accumulates one streamed tool call until it is finished: the
// call's ID and tool name, plus the raw JSON text of its input received so far.
// It contains a strings.Builder, which must not be copied once written to; in
// this file it is only handled through *toolUseState.
type toolUseState struct {
	ToolUseID, Name string
	InputBuffer     strings.Builder
}
// handleToolUseEvent folds one "toolUseEvent" payload into the tool call in
// progress. It returns the state to carry into the next event, or nil when no
// call is open.
//
// Handling rules:
//   - An event carrying both "toolUseId" and "name" identifies a call. If a
//     different call is open, that call is finished first via finishToolUse.
//   - A string "input" is appended to the open call's buffer.
//   - An object "input" replaces the buffer with the object's JSON encoding.
//   - When "stop" is true, the call is finished and nil is returned.
//   - Events that arrive while no call is open, and that do not carry both an
//     id and a name, are ignored.
func handleToolUseEvent(event map[string]interface{}, current *toolUseState, callback *KiroStreamCallback) *toolUseState {
	id, _ := event["toolUseId"].(string)
	toolName, _ := event["name"].(string)
	stop, _ := event["stop"].(bool)

	if id != "" && toolName != "" {
		switch {
		case current == nil:
			current = &toolUseState{ToolUseID: id, Name: toolName}
		case current.ToolUseID != id:
			// A new call started before the old one sent "stop": close it out.
			finishToolUse(current, callback)
			current = &toolUseState{ToolUseID: id, Name: toolName}
		}
	}

	if current == nil {
		return nil
	}

	switch in := event["input"].(type) {
	case string:
		current.InputBuffer.WriteString(in)
	case map[string]interface{}:
		// A complete object supersedes any partial text received so far.
		encoded, _ := json.Marshal(in)
		current.InputBuffer.Reset()
		current.InputBuffer.Write(encoded)
	}

	if stop {
		finishToolUse(current, callback)
		return nil
	}
	return current
}
// finishToolUse reports the tool call held in state through
// callback.OnToolUse, which must be non-nil. The buffered input is decoded as a
// JSON object. If the buffer is empty, malformed, JSON null or not an object,
// an empty (non-nil) input map is reported instead.
func finishToolUse(state *toolUseState, callback *KiroStreamCallback) {
	var args map[string]interface{}
	if raw := state.InputBuffer.String(); raw != "" {
		// Best effort: on a decode failure args stays nil and is replaced below.
		_ = json.Unmarshal([]byte(raw), &args)
	}
	if args == nil {
		args = map[string]interface{}{}
	}

	callback.OnToolUse(KiroToolUse{ToolUseID: state.ToolUseID, Name: state.Name, Input: args})
}
// extractEventType returns the value of the ":event-type" header found in the
// encoded headers section of an AWS Event Stream message. It returns "" when
// that header is absent, is not a string, or the section is malformed.
//
// Each header is encoded as:
//   - a 1-byte name length;
//   - the name;
//   - a 1-byte value type;
//   - the value.
//
// String (7) and byte-array (6) values carry a 2-byte big-endian length prefix.
// The other known types have fixed sizes, listed in the switch below. Parsing
// stops at the first unknown type or truncated field.
//
// The switch replaces a per-header map literal. Header names are compared
// without being converted to a new string, so no allocation happens unless the
// event type is found.
func extractEventType(headers []byte) string {
	offset := 0
	for offset < len(headers) {
		nameLen := int(headers[offset])
		offset++
		if offset+nameLen > len(headers) {
			break
		}
		name := headers[offset : offset+nameLen]
		offset += nameLen
		if offset >= len(headers) {
			break
		}
		valueType := headers[offset]
		offset++

		switch valueType {
		case 7: // string: 2-byte length, then the bytes
			if offset+2 > len(headers) {
				return ""
			}
			valueLen := int(headers[offset])<<8 | int(headers[offset+1])
			offset += 2
			if offset+valueLen > len(headers) {
				return ""
			}
			if string(name) == ":event-type" {
				return string(headers[offset : offset+valueLen])
			}
			offset += valueLen
		case 6: // byte array: 2-byte length, then the bytes
			if offset+2 > len(headers) {
				return ""
			}
			offset += 2 + (int(headers[offset])<<8 | int(headers[offset+1]))
		case 0, 1: // boolean true / false: no value bytes
		case 2: // byte
			offset++
		case 3: // int16
			offset += 2
		case 4: // int32
			offset += 4
		case 5, 8: // int64, timestamp
			offset += 8
		case 9: // UUID
			offset += 16
		default:
			// Unknown type: its size is unknown, so later headers cannot be located.
			return ""
		}
	}
	return ""
}