feat: Kiro API Proxy - OpenAI/Anthropic compatible API service
- Multi-account pool with round-robin load balancing - Auto token refresh for IAM IdC and Social auth - Streaming support (SSE) - Web admin panel with account management - Docker support with GitHub Actions CI/CD - Machine ID management per account - Usage tracking (requests, tokens, credits)
This commit is contained in:
370
proxy/kiro.go
Normal file
370
proxy/kiro.go
Normal file
@@ -0,0 +1,370 @@
|
||||
// Package proxy Kiro API 代理核心
|
||||
// 负责调用 Kiro API 并解析 AWS Event Stream 响应
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"kiro-api-proxy/config"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
KiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse"
|
||||
KiroVersion = "0.6.18"
|
||||
)
|
||||
|
||||
// ==================== 请求结构 ====================
|
||||
|
||||
// KiroPayload Kiro API 请求体
|
||||
type KiroPayload struct {
|
||||
ConversationState struct {
|
||||
ChatTriggerType string `json:"chatTriggerType"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
CurrentMessage struct {
|
||||
UserInputMessage KiroUserInputMessage `json:"userInputMessage"`
|
||||
} `json:"currentMessage"`
|
||||
History []KiroHistoryMessage `json:"history,omitempty"`
|
||||
} `json:"conversationState"`
|
||||
ProfileArn string `json:"profileArn,omitempty"`
|
||||
InferenceConfig *InferenceConfig `json:"inferenceConfig,omitempty"`
|
||||
}
|
||||
|
||||
type KiroUserInputMessage struct {
|
||||
Content string `json:"content"`
|
||||
ModelID string `json:"modelId,omitempty"`
|
||||
Origin string `json:"origin"`
|
||||
Images []KiroImage `json:"images,omitempty"`
|
||||
UserInputMessageContext *UserInputMessageContext `json:"userInputMessageContext,omitempty"`
|
||||
}
|
||||
|
||||
type UserInputMessageContext struct {
|
||||
Tools []KiroToolWrapper `json:"tools,omitempty"`
|
||||
ToolResults []KiroToolResult `json:"toolResults,omitempty"`
|
||||
}
|
||||
|
||||
type KiroToolWrapper struct {
|
||||
ToolSpecification struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
InputSchema InputSchema `json:"inputSchema"`
|
||||
} `json:"toolSpecification"`
|
||||
}
|
||||
|
||||
type InputSchema struct {
|
||||
JSON interface{} `json:"json"`
|
||||
}
|
||||
|
||||
type KiroToolResult struct {
|
||||
ToolUseID string `json:"toolUseId"`
|
||||
Content []KiroResultContent `json:"content"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type KiroResultContent struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type KiroImage struct {
|
||||
Format string `json:"format"`
|
||||
Source struct {
|
||||
Bytes string `json:"bytes"`
|
||||
} `json:"source"`
|
||||
}
|
||||
|
||||
type KiroHistoryMessage struct {
|
||||
UserInputMessage *KiroUserInputMessage `json:"userInputMessage,omitempty"`
|
||||
AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"`
|
||||
}
|
||||
|
||||
type KiroAssistantResponseMessage struct {
|
||||
Content string `json:"content"`
|
||||
ToolUses []KiroToolUse `json:"toolUses,omitempty"`
|
||||
}
|
||||
|
||||
type KiroToolUse struct {
|
||||
ToolUseID string `json:"toolUseId"`
|
||||
Name string `json:"name"`
|
||||
Input map[string]interface{} `json:"input"`
|
||||
}
|
||||
|
||||
type InferenceConfig struct {
|
||||
MaxTokens int `json:"maxTokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
}
|
||||
|
||||
// ==================== 流式回调 ====================
|
||||
|
||||
// KiroStreamCallback 流式响应回调
|
||||
type KiroStreamCallback struct {
|
||||
OnText func(text string, isThinking bool)
|
||||
OnToolUse func(toolUse KiroToolUse)
|
||||
OnComplete func(inputTokens, outputTokens int)
|
||||
OnError func(err error)
|
||||
OnCredits func(credits float64)
|
||||
}
|
||||
|
||||
// ==================== API 调用 ====================
|
||||
|
||||
// CallKiroAPI 调用 Kiro API(流式)
|
||||
func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroStreamCallback) error {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", KiroEndpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
req.Header.Set("X-Amz-Target", "AmazonCodeWhispererStreamingService.GenerateAssistantResponse")
|
||||
|
||||
// User-Agent 包含机器码
|
||||
machineId := account.MachineId
|
||||
var userAgent, amzUserAgent string
|
||||
if machineId != "" {
|
||||
userAgent = fmt.Sprintf("aws-sdk-js/1.0.18 ua/2.1 os/linux lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-%s-%s", KiroVersion, machineId)
|
||||
amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.18 KiroIDE %s %s", KiroVersion, machineId)
|
||||
} else {
|
||||
userAgent = fmt.Sprintf("aws-sdk-js/1.0.18 ua/2.1 os/linux lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-%s", KiroVersion)
|
||||
amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.18 KiroIDE %s", KiroVersion)
|
||||
}
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
req.Header.Set("X-Amz-User-Agent", amzUserAgent)
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", "spec")
|
||||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||
req.Header.Set("Authorization", "Bearer "+account.AccessToken)
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Minute}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return parseEventStream(resp.Body, callback)
|
||||
}
|
||||
|
||||
// ==================== Event Stream 解析 ====================
|
||||
|
||||
// parseEventStream 解析 AWS Event Stream 二进制格式
|
||||
func parseEventStream(body io.Reader, callback *KiroStreamCallback) error {
|
||||
reader := bufio.NewReader(body)
|
||||
|
||||
var inputTokens, outputTokens int
|
||||
var totalOutputChars int
|
||||
var totalCredits float64
|
||||
var currentToolUse *toolUseState
|
||||
|
||||
for {
|
||||
// Prelude: 12 bytes (total_len + headers_len + crc)
|
||||
prelude := make([]byte, 12)
|
||||
_, err := io.ReadFull(reader, prelude)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
totalLength := int(prelude[0])<<24 | int(prelude[1])<<16 | int(prelude[2])<<8 | int(prelude[3])
|
||||
headersLength := int(prelude[4])<<24 | int(prelude[5])<<16 | int(prelude[6])<<8 | int(prelude[7])
|
||||
|
||||
if totalLength < 16 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 读取剩余部分
|
||||
remaining := totalLength - 12
|
||||
msgBuf := make([]byte, remaining)
|
||||
_, err = io.ReadFull(reader, msgBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if headersLength > len(msgBuf)-4 {
|
||||
continue
|
||||
}
|
||||
|
||||
eventType := extractEventType(msgBuf[0:headersLength])
|
||||
payloadBytes := msgBuf[headersLength : len(msgBuf)-4]
|
||||
if len(payloadBytes) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal(payloadBytes, &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理事件
|
||||
switch eventType {
|
||||
case "assistantResponseEvent":
|
||||
if content, ok := event["content"].(string); ok && content != "" {
|
||||
callback.OnText(content, false)
|
||||
totalOutputChars += len(content)
|
||||
}
|
||||
case "reasoningContentEvent":
|
||||
if text, ok := event["text"].(string); ok && text != "" {
|
||||
callback.OnText(text, true)
|
||||
totalOutputChars += len(text)
|
||||
}
|
||||
case "toolUseEvent":
|
||||
currentToolUse = handleToolUseEvent(event, currentToolUse, callback)
|
||||
case "messageMetadataEvent", "metadataEvent":
|
||||
if tokenUsage, ok := event["tokenUsage"].(map[string]interface{}); ok {
|
||||
if v, ok := tokenUsage["outputTokens"].(float64); ok {
|
||||
outputTokens = int(v)
|
||||
}
|
||||
uncached, _ := tokenUsage["uncachedInputTokens"].(float64)
|
||||
cacheRead, _ := tokenUsage["cacheReadInputTokens"].(float64)
|
||||
cacheWrite, _ := tokenUsage["cacheWriteInputTokens"].(float64)
|
||||
inputTokens = int(uncached + cacheRead + cacheWrite)
|
||||
}
|
||||
case "meteringEvent":
|
||||
if usage, ok := event["usage"].(float64); ok {
|
||||
totalCredits += usage
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 估算 token(约 3 字符 = 1 token)
|
||||
if outputTokens == 0 && totalOutputChars > 0 {
|
||||
outputTokens = max(1, totalOutputChars/3)
|
||||
}
|
||||
|
||||
if callback.OnCredits != nil && totalCredits > 0 {
|
||||
callback.OnCredits(totalCredits)
|
||||
}
|
||||
|
||||
callback.OnComplete(inputTokens, outputTokens)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Tool Use 处理 ====================
|
||||
|
||||
type toolUseState struct {
|
||||
ToolUseID string
|
||||
Name string
|
||||
InputBuffer strings.Builder
|
||||
}
|
||||
|
||||
func handleToolUseEvent(event map[string]interface{}, current *toolUseState, callback *KiroStreamCallback) *toolUseState {
|
||||
toolUseID, _ := event["toolUseId"].(string)
|
||||
name, _ := event["name"].(string)
|
||||
isStop, _ := event["stop"].(bool)
|
||||
|
||||
if toolUseID != "" && name != "" {
|
||||
if current == nil {
|
||||
current = &toolUseState{ToolUseID: toolUseID, Name: name}
|
||||
} else if current.ToolUseID != toolUseID {
|
||||
finishToolUse(current, callback)
|
||||
current = &toolUseState{ToolUseID: toolUseID, Name: name}
|
||||
}
|
||||
}
|
||||
|
||||
if current != nil {
|
||||
if input, ok := event["input"].(string); ok {
|
||||
current.InputBuffer.WriteString(input)
|
||||
} else if inputObj, ok := event["input"].(map[string]interface{}); ok {
|
||||
data, _ := json.Marshal(inputObj)
|
||||
current.InputBuffer.Reset()
|
||||
current.InputBuffer.Write(data)
|
||||
}
|
||||
}
|
||||
|
||||
if isStop && current != nil {
|
||||
finishToolUse(current, callback)
|
||||
return nil
|
||||
}
|
||||
|
||||
return current
|
||||
}
|
||||
|
||||
func finishToolUse(state *toolUseState, callback *KiroStreamCallback) {
|
||||
var input map[string]interface{}
|
||||
if state.InputBuffer.Len() > 0 {
|
||||
json.Unmarshal([]byte(state.InputBuffer.String()), &input)
|
||||
}
|
||||
if input == nil {
|
||||
input = make(map[string]interface{})
|
||||
}
|
||||
callback.OnToolUse(KiroToolUse{
|
||||
ToolUseID: state.ToolUseID,
|
||||
Name: state.Name,
|
||||
Input: input,
|
||||
})
|
||||
}
|
||||
|
||||
// extractEventType 从 headers 中提取事件类型
|
||||
func extractEventType(headers []byte) string {
|
||||
offset := 0
|
||||
for offset < len(headers) {
|
||||
if offset >= len(headers) {
|
||||
break
|
||||
}
|
||||
nameLen := int(headers[offset])
|
||||
offset++
|
||||
if offset+nameLen > len(headers) {
|
||||
break
|
||||
}
|
||||
name := string(headers[offset : offset+nameLen])
|
||||
offset += nameLen
|
||||
if offset >= len(headers) {
|
||||
break
|
||||
}
|
||||
valueType := headers[offset]
|
||||
offset++
|
||||
|
||||
if valueType == 7 { // String
|
||||
if offset+2 > len(headers) {
|
||||
break
|
||||
}
|
||||
valueLen := int(headers[offset])<<8 | int(headers[offset+1])
|
||||
offset += 2
|
||||
if offset+valueLen > len(headers) {
|
||||
break
|
||||
}
|
||||
value := string(headers[offset : offset+valueLen])
|
||||
offset += valueLen
|
||||
if name == ":event-type" {
|
||||
return value
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 跳过其他类型
|
||||
skipSizes := map[byte]int{0: 0, 1: 0, 2: 1, 3: 2, 4: 4, 5: 8, 8: 8, 9: 16}
|
||||
if valueType == 6 {
|
||||
if offset+2 > len(headers) {
|
||||
break
|
||||
}
|
||||
l := int(headers[offset])<<8 | int(headers[offset+1])
|
||||
offset += 2 + l
|
||||
} else if skip, ok := skipSizes[valueType]; ok {
|
||||
offset += skip
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
Reference in New Issue
Block a user