Files
kirogo/proxy/kiro.go

376 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Package proxy implements the core of the Kiro API proxy.
// It calls the Kiro API and parses the AWS Event Stream response.
package proxy
import (
"bytes"
"encoding/json"
"fmt"
"io"
"kiro-api-proxy/config"
"net/http"
"strings"
"time"
"github.com/google/uuid"
)
// Constants describing the upstream Kiro service.
const (
// KiroEndpoint is the URL that CallKiroAPI sends its POST request to.
KiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse"
// KiroVersion is the version string embedded in the User-Agent and
// X-Amz-User-Agent request headers.
KiroVersion = "0.6.18"
)
// ==================== Request structures ====================
// KiroPayload is the JSON request body sent to the Kiro API.
type KiroPayload struct {
// ConversationState groups the trigger type, the conversation ID, the
// message being sent now, and any earlier turns.
ConversationState struct {
ChatTriggerType string `json:"chatTriggerType"`
ConversationID string `json:"conversationId"`
// CurrentMessage wraps the user input for this request.
CurrentMessage struct {
UserInputMessage KiroUserInputMessage `json:"userInputMessage"`
} `json:"currentMessage"`
// History is omitted from the JSON when empty.
History []KiroHistoryMessage `json:"history,omitempty"`
} `json:"conversationState"`
// ProfileArn is omitted from the JSON when empty.
ProfileArn string `json:"profileArn,omitempty"`
// InferenceConfig is optional; a nil pointer is omitted from the JSON.
InferenceConfig *InferenceConfig `json:"inferenceConfig,omitempty"`
}
// KiroUserInputMessage is a single user message. It is used both as the
// current message of a KiroPayload and inside KiroHistoryMessage entries.
type KiroUserInputMessage struct {
Content string `json:"content"`
// ModelID is omitted from the JSON when empty.
ModelID string `json:"modelId,omitempty"`
Origin string `json:"origin"`
// Images is omitted from the JSON when empty.
Images []KiroImage `json:"images,omitempty"`
// UserInputMessageContext optionally carries tool definitions and tool
// results; a nil pointer is omitted from the JSON.
UserInputMessageContext *UserInputMessageContext `json:"userInputMessageContext,omitempty"`
}
// UserInputMessageContext holds the tool definitions and tool results attached
// to a user input message. Both fields are omitted from the JSON when empty.
type UserInputMessageContext struct {
Tools []KiroToolWrapper `json:"tools,omitempty"`
ToolResults []KiroToolResult `json:"toolResults,omitempty"`
}
// KiroToolWrapper nests a single tool definition under the
// "toolSpecification" JSON key.
type KiroToolWrapper struct {
// ToolSpecification gives the tool's name, description and input schema.
ToolSpecification struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema InputSchema `json:"inputSchema"`
} `json:"toolSpecification"`
}
// InputSchema wraps a tool's input schema under the "json" key.
type InputSchema struct {
// JSON accepts any JSON-marshalable value.
// NOTE(review): presumably a JSON Schema object — confirm against callers.
JSON interface{} `json:"json"`
}
// KiroToolResult is one entry of UserInputMessageContext.ToolResults.
// NOTE(review): ToolUseID presumably refers back to a KiroToolUse.ToolUseID — confirm.
type KiroToolResult struct {
ToolUseID string `json:"toolUseId"`
Content []KiroResultContent `json:"content"`
// NOTE(review): the accepted Status values are not visible in this file.
Status string `json:"status"`
}
// KiroResultContent is one text item of a KiroToolResult's Content list.
type KiroResultContent struct {
Text string `json:"text"`
}
// KiroImage is one entry of KiroUserInputMessage.Images.
type KiroImage struct {
// NOTE(review): presumably a short image type name such as "png" — confirm.
Format string `json:"format"`
Source struct {
// NOTE(review): declared as a string, so presumably encoded (e.g. base64)
// image data — confirm against callers.
Bytes string `json:"bytes"`
} `json:"source"`
}
// KiroHistoryMessage is one entry of the conversation history. Each pointer
// field is omitted from the JSON when nil.
// NOTE(review): presumably exactly one of the two fields is set per entry — confirm.
type KiroHistoryMessage struct {
UserInputMessage *KiroUserInputMessage `json:"userInputMessage,omitempty"`
AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"`
}
// KiroAssistantResponseMessage is an assistant turn stored in the history: its
// text content plus any tool uses (omitted from the JSON when empty).
type KiroAssistantResponseMessage struct {
Content string `json:"content"`
ToolUses []KiroToolUse `json:"toolUses,omitempty"`
}
// KiroToolUse describes one tool invocation. It is serialized inside
// KiroAssistantResponseMessage and is also the value handed to
// KiroStreamCallback.OnToolUse.
type KiroToolUse struct {
ToolUseID string `json:"toolUseId"`
Name string `json:"name"`
// Input holds the tool arguments as a decoded JSON object.
Input map[string]interface{} `json:"input"`
}
// InferenceConfig holds optional generation parameters. Every field uses
// omitempty, so a zero value (including an explicit 0) is left out of the JSON.
type InferenceConfig struct {
MaxTokens int `json:"maxTokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
}
// ==================== Stream callbacks ====================
// KiroStreamCallback is the set of handlers invoked while a streamed Kiro
// response is parsed.
type KiroStreamCallback struct {
// OnText receives each non-empty text chunk; isThinking is true for
// reasoningContentEvent chunks and false for assistantResponseEvent chunks.
OnText func(text string, isThinking bool)
// OnToolUse receives each finished tool use.
OnToolUse func(toolUse KiroToolUse)
// OnComplete receives the final input and output token counts once the
// stream has ended.
OnComplete func(inputTokens, outputTokens int)
// NOTE(review): OnError is not invoked by any of the code reviewed here.
OnError func(err error)
// OnCredits is optional (it is nil-checked before use); it receives the
// summed meteringEvent usage when that sum is greater than zero.
OnCredits func(credits float64)
}
// ==================== API call ====================
// CallKiroAPI sends payload to the Kiro endpoint (streaming) and feeds the
// decoded response events to callback.
//
// The request is authorized with account.AccessToken; when account.MachineId
// is non-empty it is appended to the User-Agent headers.
//
// It returns an error when account is nil, when the payload cannot be
// marshaled, when the request cannot be built or sent, when the server answers
// with a status other than 200 (formatted as "HTTP <status>: <body>"), or when
// parsing the event stream fails.
func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroStreamCallback) error {
	// Upper bound on how much of a non-200 response body is copied into the
	// returned error, so a misbehaving upstream cannot exhaust memory.
	const maxErrorBodyBytes = 64 << 10

	if account == nil {
		return fmt.Errorf("kiro account is nil")
	}

	body, err := json.Marshal(payload)
	if err != nil {
		return err
	}

	// Rough input-token estimate (about 3 bytes of request JSON per token),
	// used as a fallback when the response carries no input token count.
	estimatedInputTokens := max(1, len(body)/3)

	req, err := http.NewRequest(http.MethodPost, KiroEndpoint, bytes.NewReader(body))
	if err != nil {
		return err
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "*/*")
	req.Header.Set("X-Amz-Target", "AmazonCodeWhispererStreamingService.GenerateAssistantResponse")

	// The User-Agent headers carry the machine ID when one is configured.
	machineID := account.MachineId
	var userAgent, amzUserAgent string
	if machineID != "" {
		userAgent = fmt.Sprintf("aws-sdk-js/1.0.18 ua/2.1 os/linux lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-%s-%s", KiroVersion, machineID)
		amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.18 KiroIDE %s %s", KiroVersion, machineID)
	} else {
		userAgent = fmt.Sprintf("aws-sdk-js/1.0.18 ua/2.1 os/linux lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-%s", KiroVersion)
		amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.18 KiroIDE %s", KiroVersion)
	}
	req.Header.Set("User-Agent", userAgent)
	req.Header.Set("X-Amz-User-Agent", amzUserAgent)
	req.Header.Set("x-amzn-kiro-agent-mode", "spec")
	req.Header.Set("x-amzn-codewhisperer-optout", "true")
	req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
	req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
	req.Header.Set("Authorization", "Bearer "+account.AccessToken)

	// The client timeout covers the whole exchange, including reading the
	// streamed response body.
	client := &http.Client{Timeout: 5 * time.Minute}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: whatever part of the body could be read is included in
		// the error, so a read failure here is deliberately ignored.
		errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(errBody))
	}

	return parseEventStream(resp.Body, callback, estimatedInputTokens)
}
// ==================== Event Stream parsing ====================
// parseEventStream decodes an AWS Event Stream (binary framing) from body and
// dispatches each event to callback.
//
// Every frame is a 12-byte prelude (total length, headers length, prelude
// CRC), followed by the headers, a JSON payload and a 4-byte message CRC. The
// CRCs are not verified. Frames whose declared length is below the minimum,
// whose headers length is inconsistent, or whose payload is empty or not valid
// JSON are skipped.
//
// The reader is consumed directly (no bufio) so events are delivered as soon
// as they arrive.
//
// After a clean end of stream:
//   - a tool use that never received its stop event is still delivered,
//   - the output token count falls back to about one token per 3 bytes of
//     emitted text when the stream reported none,
//   - the input token count falls back to estimatedInputTokens when the stream
//     reported none,
//   - OnCredits (if total credits > 0) and then OnComplete are invoked.
//
// Nil OnText, OnCredits and OnComplete handlers, and a nil callback, are
// tolerated. A read error, a truncated frame, or a frame larger than the event
// stream format allows is returned as an error without calling OnComplete.
func parseEventStream(body io.Reader, callback *KiroStreamCallback, estimatedInputTokens int) error {
	const (
		preludeSize    = 12 // total length + headers length + prelude CRC
		crcSize        = 4  // trailing message CRC
		minMessageSize = preludeSize + crcSize
		// Largest frame the event stream format allows: 16 MiB of payload plus
		// 128 KiB of headers plus framing. Anything bigger is treated as a
		// corrupt stream instead of being used as an allocation size.
		maxMessageSize = 16<<20 + 128<<10 + minMessageSize
	)

	if callback == nil {
		callback = &KiroStreamCallback{}
	}

	var (
		inputTokens, outputTokens int
		totalOutputChars          int
		totalCredits              float64
		currentToolUse            *toolUseState
		prelude                   [preludeSize]byte
	)

	for {
		_, err := io.ReadFull(body, prelude[:])
		if err == io.EOF {
			// Clean end of stream: no bytes of a new frame were read.
			break
		}
		if err != nil {
			return err
		}

		// Lengths are big-endian; uint32 keeps the arithmetic well-defined on
		// platforms where int is 32 bits wide.
		totalLength := uint32(prelude[0])<<24 | uint32(prelude[1])<<16 | uint32(prelude[2])<<8 | uint32(prelude[3])
		headersLength := uint32(prelude[4])<<24 | uint32(prelude[5])<<16 | uint32(prelude[6])<<8 | uint32(prelude[7])

		if totalLength < minMessageSize {
			// Too small to be a real frame; skip it and try the next prelude.
			continue
		}
		if totalLength > maxMessageSize {
			return fmt.Errorf("event stream message length %d exceeds limit %d", totalLength, maxMessageSize)
		}

		// Read the rest of the frame: headers, payload and message CRC.
		msgBuf := make([]byte, totalLength-preludeSize)
		if _, err := io.ReadFull(body, msgBuf); err != nil {
			return err
		}

		payloadEnd := len(msgBuf) - crcSize
		if headersLength > uint32(payloadEnd) {
			continue
		}
		headersEnd := int(headersLength)

		eventType := extractEventType(msgBuf[:headersEnd])
		payloadBytes := msgBuf[headersEnd:payloadEnd]
		if len(payloadBytes) == 0 {
			continue
		}

		var event map[string]interface{}
		if err := json.Unmarshal(payloadBytes, &event); err != nil {
			continue
		}

		switch eventType {
		case "assistantResponseEvent":
			if content, ok := event["content"].(string); ok && content != "" {
				if callback.OnText != nil {
					callback.OnText(content, false)
				}
				totalOutputChars += len(content)
			}
		case "reasoningContentEvent":
			if text, ok := event["text"].(string); ok && text != "" {
				if callback.OnText != nil {
					callback.OnText(text, true)
				}
				totalOutputChars += len(text)
			}
		case "toolUseEvent":
			currentToolUse = handleToolUseEvent(event, currentToolUse, callback)
		case "messageMetadataEvent", "metadataEvent":
			if tokenUsage, ok := event["tokenUsage"].(map[string]interface{}); ok {
				if v, ok := tokenUsage["outputTokens"].(float64); ok {
					outputTokens = int(v)
				}
				// Missing counters decode as 0, so the sum stays well-defined.
				uncached, _ := tokenUsage["uncachedInputTokens"].(float64)
				cacheRead, _ := tokenUsage["cacheReadInputTokens"].(float64)
				cacheWrite, _ := tokenUsage["cacheWriteInputTokens"].(float64)
				inputTokens = int(uncached + cacheRead + cacheWrite)
			}
		case "meteringEvent":
			if usage, ok := event["usage"].(float64); ok {
				totalCredits += usage
			}
		}
	}

	// A tool use whose stop event never arrived would otherwise be lost; flush
	// it the same way a switch to a different tool use ID does.
	if currentToolUse != nil {
		finishToolUse(currentToolUse, callback)
	}

	// Estimate output tokens (about 3 bytes per token) when none were reported.
	if outputTokens == 0 && totalOutputChars > 0 {
		outputTokens = max(1, totalOutputChars/3)
	}
	// Fall back to the caller's estimate when no input tokens were reported.
	if inputTokens == 0 {
		inputTokens = estimatedInputTokens
	}

	if callback.OnCredits != nil && totalCredits > 0 {
		callback.OnCredits(totalCredits)
	}
	if callback.OnComplete != nil {
		callback.OnComplete(inputTokens, outputTokens)
	}
	return nil
}
// ==================== Tool use handling ====================
// toolUseState accumulates one in-progress tool use while its toolUseEvent
// chunks are being received.
type toolUseState struct {
ToolUseID string
Name string
// InputBuffer collects the tool input as JSON text until the tool use is
// finished.
InputBuffer strings.Builder
}
// handleToolUseEvent folds one toolUseEvent into the in-progress tool use and
// returns the state to carry into the next event (nil when nothing is pending).
//
// An event that names both a toolUseId and a name starts a new tool use when
// none is pending, or, when its ID differs from the pending one, finishes the
// pending tool use first and then starts the new one. A string "input" is
// appended to the pending input buffer; an object "input" replaces the buffer
// with its JSON encoding. An event with "stop" set to true finishes the
// pending tool use.
func handleToolUseEvent(event map[string]interface{}, current *toolUseState, callback *KiroStreamCallback) *toolUseState {
	id, _ := event["toolUseId"].(string)
	toolName, _ := event["name"].(string)
	stop, _ := event["stop"].(bool)

	if id != "" && toolName != "" {
		switch {
		case current == nil:
			current = &toolUseState{ToolUseID: id, Name: toolName}
		case current.ToolUseID != id:
			// A different tool use begins: flush the previous one first.
			finishToolUse(current, callback)
			current = &toolUseState{ToolUseID: id, Name: toolName}
		}
	}

	// Without a pending tool use there is nothing to append to or to stop.
	if current == nil {
		return nil
	}

	switch chunk := event["input"].(type) {
	case string:
		current.InputBuffer.WriteString(chunk)
	case map[string]interface{}:
		// An already-decoded object supersedes any partial text collected so far.
		encoded, _ := json.Marshal(chunk)
		current.InputBuffer.Reset()
		current.InputBuffer.Write(encoded)
	}

	if stop {
		finishToolUse(current, callback)
		return nil
	}
	return current
}
// finishToolUse delivers a completed tool use to callback.OnToolUse.
//
// The accumulated input buffer is decoded as a JSON object. When the buffer is
// empty, is not valid JSON, or does not decode to an object, the tool use is
// still delivered, with an empty (non-nil) Input map.
//
// A nil state, a nil callback or a nil OnToolUse handler makes this a no-op.
func finishToolUse(state *toolUseState, callback *KiroStreamCallback) {
	if state == nil || callback == nil || callback.OnToolUse == nil {
		return
	}

	var input map[string]interface{}
	if raw := state.InputBuffer.String(); raw != "" {
		// Best effort: malformed or truncated input must not prevent the tool
		// use from being reported, so the decode error is deliberately ignored.
		_ = json.Unmarshal([]byte(raw), &input)
	}
	if input == nil {
		input = map[string]interface{}{}
	}

	callback.OnToolUse(KiroToolUse{
		ToolUseID: state.ToolUseID,
		Name:      state.Name,
		Input:     input,
	})
}
// extractEventType scans an encoded event stream headers block and returns the
// value of the ":event-type" string header.
//
// Each header is encoded as a 1-byte name length, the name, a 1-byte value
// type and a value whose size depends on that type. Headers other than
// ":event-type" are skipped. It returns "" when the header is absent, when the
// block is truncated, or when an unknown value type is encountered.
func extractEventType(headers []byte) string {
	// Header value type codes.
	const (
		typeBoolTrue  = 0 // no value bytes
		typeBoolFalse = 1 // no value bytes
		typeByte      = 2
		typeShort     = 3
		typeInt       = 4
		typeLong      = 5
		typeBytes     = 6 // 2-byte big-endian length prefix
		typeString    = 7 // 2-byte big-endian length prefix
		typeTimestamp = 8
		typeUUID      = 9
	)

	for pos := 0; pos < len(headers); {
		nameLen := int(headers[pos])
		pos++
		if pos+nameLen > len(headers) {
			return ""
		}
		name := headers[pos : pos+nameLen]
		pos += nameLen

		if pos >= len(headers) {
			return ""
		}
		valueType := headers[pos]
		pos++

		switch valueType {
		case typeBoolTrue, typeBoolFalse:
			// The value is carried by the type byte itself.
		case typeByte:
			pos++
		case typeShort:
			pos += 2
		case typeInt:
			pos += 4
		case typeLong, typeTimestamp:
			pos += 8
		case typeUUID:
			pos += 16
		case typeBytes, typeString:
			if pos+2 > len(headers) {
				return ""
			}
			valueLen := int(headers[pos])<<8 | int(headers[pos+1])
			pos += 2
			if valueType == typeString {
				if pos+valueLen > len(headers) {
					return ""
				}
				if string(name) == ":event-type" {
					return string(headers[pos : pos+valueLen])
				}
			}
			// Skipping past the end simply terminates the loop.
			pos += valueLen
		default:
			// Unknown type: its size is unknown, so scanning cannot continue.
			return ""
		}
	}
	return ""
}