feat: add i18n support and batch JSON credentials import
This commit is contained in:
@@ -57,7 +57,7 @@ func StartBuilderIdLogin(region string) (*BuilderIdSession, error) {
|
||||
regReq, _ := http.NewRequest("POST", oidcBase+"/client/register", bytes.NewReader(regBody))
|
||||
regReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
client := httpClient
|
||||
regResp, err := client.Do(regReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register client failed: %v", err)
|
||||
@@ -175,7 +175,7 @@ func PollBuilderIdAuth(sessionID string) (accessToken, refreshToken, clientID, c
|
||||
tokenReq, _ := http.NewRequest("POST", oidcBase+"/token", bytes.NewReader(tokenBody))
|
||||
tokenReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
client := httpClient
|
||||
tokenResp, err := client.Do(tokenReq)
|
||||
if err != nil {
|
||||
return "", "", "", "", "", 0, "", fmt.Errorf("token request failed: %v", err)
|
||||
|
||||
20
auth/http_client.go
Normal file
20
auth/http_client.go
Normal file
@@ -0,0 +1,20 @@
|
||||
// Package auth 提供认证相关功能的 HTTP 客户端
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 全局 HTTP 客户端,复用连接池
|
||||
// 用于所有 auth 模块的 HTTP 请求
|
||||
var httpClient = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 50, // 最大空闲连接数
|
||||
MaxIdleConnsPerHost: 10, // 每个 Host 最大空闲连接数
|
||||
IdleConnTimeout: 90 * time.Second, // 空闲连接超时
|
||||
DisableCompression: false, // 启用压缩
|
||||
ForceAttemptHTTP2: true, // 尝试使用 HTTP/2
|
||||
},
|
||||
}
|
||||
@@ -170,8 +170,7 @@ func registerOIDCClient(oidcBase, startUrl, redirectUri string) (clientID, clien
|
||||
req, _ := http.NewRequest("POST", oidcBase+"/client/register", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
@@ -208,8 +207,7 @@ func exchangeToken(oidcBase, clientID, clientSecret, code, codeVerifier, redirec
|
||||
req, _ := http.NewRequest("POST", oidcBase+"/token", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
|
||||
@@ -37,8 +37,7 @@ func refreshOIDCToken(refreshToken, clientID, clientSecret, region string) (stri
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
@@ -75,8 +74,7 @@ func refreshSocialToken(refreshToken string) (string, string, int64, error) {
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
|
||||
@@ -79,7 +79,7 @@ func registerDeviceClient(oidcBase, startUrl string) (clientID, clientSecret str
|
||||
req, _ := http.NewRequest("POST", oidcBase+"/client/register", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
client := httpClient
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
@@ -110,7 +110,7 @@ func startDeviceAuth(oidcBase, clientID, clientSecret, startUrl string) (deviceC
|
||||
req, _ := http.NewRequest("POST", oidcBase+"/device_authorization", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
client := httpClient
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
@@ -139,7 +139,7 @@ func verifyBearerToken(portalBase, bearerToken string) error {
|
||||
req.Header.Set("Authorization", "Bearer "+bearerToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
client := httpClient
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -157,7 +157,7 @@ func getDeviceSessionToken(portalBase, bearerToken string) (string, error) {
|
||||
req.Header.Set("Authorization", "Bearer "+bearerToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
client := httpClient
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -193,7 +193,7 @@ func acceptUserCode(oidcBase, userCode, deviceSessionToken string) (*deviceConte
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Referer", "https://view.awsapps.com/")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
client := httpClient
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -227,7 +227,7 @@ func approveAuth(oidcBase string, deviceContext *deviceContextInfo, deviceSessio
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Referer", "https://view.awsapps.com/")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
client := httpClient
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -262,7 +262,7 @@ func pollForToken(oidcBase, clientID, clientSecret, deviceCode string, interval
|
||||
req, _ := http.NewRequest("POST", oidcBase+"/token", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
client := httpClient
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
continue
|
||||
@@ -311,7 +311,7 @@ func GetUserInfo(accessToken string) (email, userID string, err error) {
|
||||
req.Header.Set("User-Agent", "aws-sdk-js/1.0.18 KiroAPIProxy")
|
||||
req.Header.Set("x-amz-user-agent", "aws-sdk-js/1.0.18 KiroAPIProxy")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
client := httpClient
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
|
||||
@@ -95,6 +95,9 @@ type Config struct {
|
||||
OpenAIThinkingFormat string `json:"openaiThinkingFormat,omitempty"` // OpenAI output format: "reasoning_content", "thinking", or "think"
|
||||
ClaudeThinkingFormat string `json:"claudeThinkingFormat,omitempty"` // Claude output format: "reasoning_content", "thinking", or "think"
|
||||
|
||||
// Endpoint configuration: "auto", "codewhisperer", or "amazonq"
|
||||
PreferredEndpoint string `json:"preferredEndpoint,omitempty"`
|
||||
|
||||
// Global statistics (persisted across restarts)
|
||||
TotalRequests int `json:"totalRequests,omitempty"` // Total API requests received
|
||||
SuccessRequests int `json:"successRequests,omitempty"` // Successful requests count
|
||||
@@ -410,3 +413,21 @@ func UpdateThinkingConfig(suffix, openaiFormat, claudeFormat string) error {
|
||||
cfg.ClaudeThinkingFormat = claudeFormat
|
||||
return Save()
|
||||
}
|
||||
|
||||
// GetPreferredEndpoint 获取首选端点配置
|
||||
func GetPreferredEndpoint() string {
|
||||
cfgLock.RLock()
|
||||
defer cfgLock.RUnlock()
|
||||
if cfg.PreferredEndpoint == "" {
|
||||
return "auto"
|
||||
}
|
||||
return cfg.PreferredEndpoint
|
||||
}
|
||||
|
||||
// UpdatePreferredEndpoint 更新首选端点配置
|
||||
func UpdatePreferredEndpoint(endpoint string) error {
|
||||
cfgLock.Lock()
|
||||
defer cfgLock.Unlock()
|
||||
cfg.PreferredEndpoint = endpoint
|
||||
return Save()
|
||||
}
|
||||
|
||||
171
proxy/handler.go
171
proxy/handler.go
@@ -9,6 +9,8 @@ import (
|
||||
"kiro-api-proxy/pool"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -17,30 +19,35 @@ import (
|
||||
// Handler HTTP 处理器
|
||||
type Handler struct {
|
||||
pool *pool.AccountPool
|
||||
// 运行时统计
|
||||
totalRequests int
|
||||
successRequests int
|
||||
failedRequests int
|
||||
totalTokens int
|
||||
totalCredits float64
|
||||
// 运行时统计 (使用原子操作)
|
||||
totalRequests int64
|
||||
successRequests int64
|
||||
failedRequests int64
|
||||
totalTokens int64
|
||||
totalCredits float64 // float64 需要用锁保护
|
||||
creditsMu sync.RWMutex
|
||||
startTime int64
|
||||
stopRefresh chan struct{}
|
||||
stopStatsSaver chan struct{}
|
||||
}
|
||||
|
||||
func NewHandler() *Handler {
|
||||
totalReq, successReq, failedReq, totalTokens, totalCredits := config.GetStats()
|
||||
h := &Handler{
|
||||
pool: pool.GetPool(),
|
||||
totalRequests: totalReq,
|
||||
successRequests: successReq,
|
||||
failedRequests: failedReq,
|
||||
totalTokens: totalTokens,
|
||||
totalRequests: int64(totalReq),
|
||||
successRequests: int64(successReq),
|
||||
failedRequests: int64(failedReq),
|
||||
totalTokens: int64(totalTokens),
|
||||
totalCredits: totalCredits,
|
||||
startTime: time.Now().Unix(),
|
||||
stopRefresh: make(chan struct{}),
|
||||
stopStatsSaver: make(chan struct{}),
|
||||
}
|
||||
// 启动后台刷新
|
||||
go h.backgroundRefresh()
|
||||
// 启动后台统计保存 (每30秒保存一次)
|
||||
go h.backgroundStatsSaver()
|
||||
return h
|
||||
}
|
||||
|
||||
@@ -106,7 +113,7 @@ func (h *Handler) validateApiKey(r *http.Request) bool {
|
||||
if !config.IsApiKeyRequired() {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
expectedKey := config.GetApiKey()
|
||||
if expectedKey == "" {
|
||||
return true
|
||||
@@ -193,11 +200,11 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
"status": "ok",
|
||||
"accounts": h.pool.Count(),
|
||||
"available": h.pool.AvailableCount(),
|
||||
"totalRequests": h.totalRequests,
|
||||
"successRequests": h.successRequests,
|
||||
"failedRequests": h.failedRequests,
|
||||
"totalTokens": h.totalTokens,
|
||||
"totalCredits": h.totalCredits,
|
||||
"totalRequests": atomic.LoadInt64(&h.totalRequests),
|
||||
"successRequests": atomic.LoadInt64(&h.successRequests),
|
||||
"failedRequests": atomic.LoadInt64(&h.failedRequests),
|
||||
"totalTokens": atomic.LoadInt64(&h.totalTokens),
|
||||
"totalCredits": h.getCredits(),
|
||||
"uptime": time.Now().Unix() - h.startTime,
|
||||
})
|
||||
}
|
||||
@@ -376,7 +383,7 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
|
||||
})
|
||||
contentStarted = true
|
||||
}
|
||||
|
||||
|
||||
if thinkingState == 0 {
|
||||
// 普通内容
|
||||
if text == "" {
|
||||
@@ -426,7 +433,7 @@ func (h *Handler) handleClaudeStream(w http.ResponseWriter, account *config.Acco
|
||||
|
||||
// 处理文本,解析 <thinking> 标签
|
||||
var thinkingStarted bool
|
||||
|
||||
|
||||
processClaudeText := func(text string, isThinking bool, forceFlush bool) {
|
||||
// 如果是 reasoningContentEvent,直接输出
|
||||
if isThinking {
|
||||
@@ -642,20 +649,58 @@ func (h *Handler) sendSSE(w http.ResponseWriter, flusher http.Flusher, event str
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// 统计记录
|
||||
func (h *Handler) recordSuccess(inputTokens, outputTokens int, credits float64) {
|
||||
h.totalRequests++
|
||||
h.successRequests++
|
||||
h.totalTokens += inputTokens + outputTokens
|
||||
// backgroundStatsSaver 后台定时保存统计数据
|
||||
func (h *Handler) backgroundStatsSaver() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
h.saveStats()
|
||||
case <-h.stopStatsSaver:
|
||||
h.saveStats() // 退出前保存一次
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// saveStats 保存统计到配置文件
|
||||
func (h *Handler) saveStats() {
|
||||
config.UpdateStats(
|
||||
int(atomic.LoadInt64(&h.totalRequests)),
|
||||
int(atomic.LoadInt64(&h.successRequests)),
|
||||
int(atomic.LoadInt64(&h.failedRequests)),
|
||||
int(atomic.LoadInt64(&h.totalTokens)),
|
||||
h.getCredits(),
|
||||
)
|
||||
}
|
||||
|
||||
// getCredits 线程安全获取 credits
|
||||
func (h *Handler) getCredits() float64 {
|
||||
h.creditsMu.RLock()
|
||||
defer h.creditsMu.RUnlock()
|
||||
return h.totalCredits
|
||||
}
|
||||
|
||||
// addCredits 线程安全增加 credits
|
||||
func (h *Handler) addCredits(credits float64) {
|
||||
h.creditsMu.Lock()
|
||||
h.totalCredits += credits
|
||||
// 异步保存
|
||||
go config.UpdateStats(h.totalRequests, h.successRequests, h.failedRequests, h.totalTokens, h.totalCredits)
|
||||
h.creditsMu.Unlock()
|
||||
}
|
||||
|
||||
// 统计记录 (使用原子操作)
|
||||
func (h *Handler) recordSuccess(inputTokens, outputTokens int, credits float64) {
|
||||
atomic.AddInt64(&h.totalRequests, 1)
|
||||
atomic.AddInt64(&h.successRequests, 1)
|
||||
atomic.AddInt64(&h.totalTokens, int64(inputTokens+outputTokens))
|
||||
h.addCredits(credits)
|
||||
}
|
||||
|
||||
func (h *Handler) recordFailure() {
|
||||
h.totalRequests++
|
||||
h.failedRequests++
|
||||
go config.UpdateStats(h.totalRequests, h.successRequests, h.failedRequests, h.totalTokens, h.totalCredits)
|
||||
atomic.AddInt64(&h.totalRequests, 1)
|
||||
atomic.AddInt64(&h.failedRequests, 1)
|
||||
}
|
||||
|
||||
// handleClaudeNonStream Claude 非流式响应
|
||||
@@ -807,9 +852,9 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
|
||||
if content == "" && thinkingState == 2 {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
var chunk map[string]interface{}
|
||||
|
||||
|
||||
if thinkingState > 0 {
|
||||
// thinking 内容
|
||||
switch thinkingFormat {
|
||||
@@ -903,7 +948,7 @@ func (h *Handler) handleOpenAIStream(w http.ResponseWriter, account *config.Acco
|
||||
// 处理文本,解析 <thinking> 标签
|
||||
// thinkingStarted 用于跟踪是否已发送开始标签
|
||||
var thinkingStarted bool
|
||||
|
||||
|
||||
processText := func(text string, isThinking bool, forceFlush bool) {
|
||||
// 如果是 reasoningContentEvent,直接输出
|
||||
if isThinking {
|
||||
@@ -1233,6 +1278,10 @@ func (h *Handler) handleAdminAPI(w http.ResponseWriter, r *http.Request) {
|
||||
h.apiGetThinkingConfig(w, r)
|
||||
case path == "/thinking" && r.Method == "POST":
|
||||
h.apiUpdateThinkingConfig(w, r)
|
||||
case path == "/endpoint" && r.Method == "GET":
|
||||
h.apiGetEndpointConfig(w, r)
|
||||
case path == "/endpoint" && r.Method == "POST":
|
||||
h.apiUpdateEndpointConfig(w, r)
|
||||
default:
|
||||
w.WriteHeader(404)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Not Found"})
|
||||
@@ -1242,19 +1291,19 @@ func (h *Handler) handleAdminAPI(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) apiGetAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
accounts := config.GetAccounts()
|
||||
poolAccounts := h.pool.GetAllAccounts()
|
||||
|
||||
|
||||
// 合并运行时统计
|
||||
statsMap := make(map[string]config.Account)
|
||||
for _, a := range poolAccounts {
|
||||
statsMap[a.ID] = a
|
||||
}
|
||||
|
||||
|
||||
// 隐藏敏感信息
|
||||
result := make([]map[string]interface{}, len(accounts))
|
||||
for i, a := range accounts {
|
||||
// 获取运行时统计
|
||||
stats := statsMap[a.ID]
|
||||
|
||||
|
||||
result[i] = map[string]interface{}{
|
||||
"id": a.ID,
|
||||
"email": a.Email,
|
||||
@@ -1765,21 +1814,23 @@ func (h *Handler) apiUpdateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (h *Handler) apiGetStats(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"totalRequests": h.totalRequests,
|
||||
"successRequests": h.successRequests,
|
||||
"failedRequests": h.failedRequests,
|
||||
"totalTokens": h.totalTokens,
|
||||
"totalCredits": h.totalCredits,
|
||||
"totalRequests": atomic.LoadInt64(&h.totalRequests),
|
||||
"successRequests": atomic.LoadInt64(&h.successRequests),
|
||||
"failedRequests": atomic.LoadInt64(&h.failedRequests),
|
||||
"totalTokens": atomic.LoadInt64(&h.totalTokens),
|
||||
"totalCredits": h.getCredits(),
|
||||
"uptime": time.Now().Unix() - h.startTime,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) apiResetStats(w http.ResponseWriter, r *http.Request) {
|
||||
h.totalRequests = 0
|
||||
h.successRequests = 0
|
||||
h.failedRequests = 0
|
||||
h.totalTokens = 0
|
||||
atomic.StoreInt64(&h.totalRequests, 0)
|
||||
atomic.StoreInt64(&h.successRequests, 0)
|
||||
atomic.StoreInt64(&h.failedRequests, 0)
|
||||
atomic.StoreInt64(&h.totalTokens, 0)
|
||||
h.creditsMu.Lock()
|
||||
h.totalCredits = 0
|
||||
h.creditsMu.Unlock()
|
||||
config.UpdateStats(0, 0, 0, 0, 0)
|
||||
json.NewEncoder(w).Encode(map[string]bool{"success": true})
|
||||
}
|
||||
@@ -1930,3 +1981,37 @@ func (h *Handler) apiUpdateThinkingConfig(w http.ResponseWriter, r *http.Request
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]bool{"success": true})
|
||||
}
|
||||
|
||||
// apiGetEndpointConfig 获取端点配置
|
||||
func (h *Handler) apiGetEndpointConfig(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"preferredEndpoint": config.GetPreferredEndpoint(),
|
||||
})
|
||||
}
|
||||
|
||||
// apiUpdateEndpointConfig 更新端点配置
|
||||
func (h *Handler) apiUpdateEndpointConfig(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
PreferredEndpoint string `json:"preferredEndpoint"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
w.WriteHeader(400)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"})
|
||||
return
|
||||
}
|
||||
|
||||
valid := map[string]bool{"auto": true, "codewhisperer": true, "amazonq": true}
|
||||
if !valid[req.PreferredEndpoint] {
|
||||
w.WriteHeader(400)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "Invalid endpoint, must be: auto, codewhisperer, or amazonq"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := config.UpdatePreferredEndpoint(req.PreferredEndpoint); err != nil {
|
||||
w.WriteHeader(500)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]bool{"success": true})
|
||||
}
|
||||
|
||||
139
proxy/kiro.go
139
proxy/kiro.go
@@ -15,10 +15,42 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
KiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse"
|
||||
KiroVersion = "0.6.18"
|
||||
)
|
||||
const KiroVersion = "0.6.18"
|
||||
|
||||
// 双端点配置(429 时自动 fallback)
|
||||
type kiroEndpoint struct {
|
||||
URL string
|
||||
Origin string
|
||||
AmzTarget string
|
||||
Name string
|
||||
}
|
||||
|
||||
var kiroEndpoints = []kiroEndpoint{
|
||||
{
|
||||
URL: "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse",
|
||||
Origin: "AI_EDITOR",
|
||||
AmzTarget: "AmazonCodeWhispererStreamingService.GenerateAssistantResponse",
|
||||
Name: "CodeWhisperer",
|
||||
},
|
||||
{
|
||||
URL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
|
||||
Origin: "CLI",
|
||||
AmzTarget: "AmazonQDeveloperStreamingService.SendMessage",
|
||||
Name: "AmazonQ",
|
||||
},
|
||||
}
|
||||
|
||||
// 全局 HTTP 客户端,复用连接池
|
||||
var kiroHttpClient = &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 100, // 最大空闲连接数
|
||||
MaxIdleConnsPerHost: 20, // 每个 Host 最大空闲连接数
|
||||
IdleConnTimeout: 90 * time.Second, // 空闲连接超时
|
||||
DisableCompression: false, // 启用压缩
|
||||
ForceAttemptHTTP2: true, // 尝试使用 HTTP/2
|
||||
},
|
||||
}
|
||||
|
||||
// ==================== 请求结构 ====================
|
||||
|
||||
@@ -113,7 +145,19 @@ type KiroStreamCallback struct {
|
||||
|
||||
// ==================== API 调用 ====================
|
||||
|
||||
// CallKiroAPI 调用 Kiro API(流式)
|
||||
// getSortedEndpoints 根据首选端点配置排序端点列表
|
||||
func getSortedEndpoints(preferred string) []kiroEndpoint {
|
||||
if preferred == "amazonq" {
|
||||
return []kiroEndpoint{kiroEndpoints[1], kiroEndpoints[0]}
|
||||
}
|
||||
if preferred == "codewhisperer" {
|
||||
return []kiroEndpoint{kiroEndpoints[0], kiroEndpoints[1]}
|
||||
}
|
||||
// "auto" 或空值:默认顺序
|
||||
return []kiroEndpoint{kiroEndpoints[0], kiroEndpoints[1]}
|
||||
}
|
||||
|
||||
// CallKiroAPI 调用 Kiro API(流式),双端点自动 fallback
|
||||
func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroStreamCallback) error {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
@@ -123,17 +167,7 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt
|
||||
// 预估输入 token(约 3 字符 = 1 token)
|
||||
estimatedInputTokens := max(1, len(body)/3)
|
||||
|
||||
req, err := http.NewRequest("POST", KiroEndpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
req.Header.Set("X-Amz-Target", "AmazonCodeWhispererStreamingService.GenerateAssistantResponse")
|
||||
|
||||
// User-Agent 包含机器码
|
||||
// User-Agent
|
||||
machineId := account.MachineId
|
||||
var userAgent, amzUserAgent string
|
||||
if machineId != "" {
|
||||
@@ -143,27 +177,68 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt
|
||||
userAgent = fmt.Sprintf("aws-sdk-js/1.0.18 ua/2.1 os/linux lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-%s", KiroVersion)
|
||||
amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.18 KiroIDE %s", KiroVersion)
|
||||
}
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
req.Header.Set("X-Amz-User-Agent", amzUserAgent)
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", "spec")
|
||||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||
req.Header.Set("Authorization", "Bearer "+account.AccessToken)
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Minute}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
// 根据配置排序端点
|
||||
endpoints := getSortedEndpoints(config.GetPreferredEndpoint())
|
||||
|
||||
var lastErr error
|
||||
for _, ep := range endpoints {
|
||||
// 更新 payload 中的 origin
|
||||
payload.ConversationState.CurrentMessage.UserInputMessage.Origin = ep.Origin
|
||||
|
||||
reqBody, _ := json.Marshal(payload)
|
||||
req, err := http.NewRequest("POST", ep.URL, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
req.Header.Set("X-Amz-Target", ep.AmzTarget)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
req.Header.Set("X-Amz-User-Agent", amzUserAgent)
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", "spec")
|
||||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||
req.Header.Set("Authorization", "Bearer "+account.AccessToken)
|
||||
|
||||
resp, err := kiroHttpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
fmt.Printf("[KiroAPI] Endpoint %s failed: %v\n", ep.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode == 429 {
|
||||
resp.Body.Close()
|
||||
fmt.Printf("[KiroAPI] Endpoint %s quota exhausted (429), trying next...\n", ep.Name)
|
||||
lastErr = fmt.Errorf("quota exhausted on %s", ep.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
lastErr = fmt.Errorf("HTTP %d from %s: %s", resp.StatusCode, ep.Name, string(errBody))
|
||||
// 认证错误不继续尝试
|
||||
if resp.StatusCode == 401 || resp.StatusCode == 403 {
|
||||
return lastErr
|
||||
}
|
||||
fmt.Printf("[KiroAPI] Endpoint %s error: %v\n", ep.Name, lastErr)
|
||||
continue
|
||||
}
|
||||
|
||||
err = parseEventStream(resp.Body, callback, estimatedInputTokens)
|
||||
resp.Body.Close()
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
|
||||
if lastErr != nil {
|
||||
return lastErr
|
||||
}
|
||||
|
||||
return parseEventStream(resp.Body, callback, estimatedInputTokens)
|
||||
return fmt.Errorf("all endpoints failed")
|
||||
}
|
||||
|
||||
// ==================== Event Stream 解析 ====================
|
||||
|
||||
2020
web/index.html
2020
web/index.html
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user