From 0f8035d90e024ae67c0dceee1314020b23ff18eb Mon Sep 17 00:00:00 2001 From: Quorinex Date: Wed, 13 May 2026 13:59:52 +0800 Subject: [PATCH] feat: improve logging, tool compatibility, and endpoint configuration --- auth/oidc.go | 27 +++++---- config/config.go | 47 ++++++++++++++- logger/logger.go | 141 ++++++++++++++++++++++++++++++++++++++++++++ main.go | 14 +++-- proxy/handler.go | 55 ++++++++++++----- proxy/kiro.go | 133 +++++++++++++++++++++++++++++------------ proxy/kiro_api.go | 52 +++++++++++----- proxy/translator.go | 115 +++++++++++++++++++++++++++++++++--- web/index.html | 16 ++++- 9 files changed, 506 insertions(+), 94 deletions(-) create mode 100644 logger/logger.go diff --git a/auth/oidc.go b/auth/oidc.go index 7dcb494..1354470 100644 --- a/auth/oidc.go +++ b/auth/oidc.go @@ -11,7 +11,8 @@ import ( ) // RefreshToken 刷新 access token -func RefreshToken(account *config.Account) (string, string, int64, error) { +// Returns: accessToken, refreshToken, expiresAt, profileArn, error +func RefreshToken(account *config.Account) (string, string, int64, string, error) { if account.AuthMethod == "social" { return refreshSocialToken(account.RefreshToken) } @@ -19,9 +20,9 @@ func RefreshToken(account *config.Account) (string, string, int64, error) { } // refreshOIDCToken IdC/Builder ID token 刷新 -func refreshOIDCToken(refreshToken, clientID, clientSecret, region string) (string, string, int64, error) { +func refreshOIDCToken(refreshToken, clientID, clientSecret, region string) (string, string, int64, string, error) { if clientID == "" || clientSecret == "" { - return "", "", 0, fmt.Errorf("OIDC refresh requires clientId and clientSecret") + return "", "", 0, "", fmt.Errorf("OIDC refresh requires clientId and clientSecret") } if region == "" { region = "us-east-1" @@ -42,31 +43,32 @@ func refreshOIDCToken(refreshToken, clientID, clientSecret, region string) (stri resp, err := httpClient().Do(req) if err != nil { - return "", "", 0, err + return "", "", 0, "", err } defer resp.Body.Close() if resp.StatusCode != 200 { respBody, _ := io.ReadAll(resp.Body) - return "", "", 0, fmt.Errorf("refresh failed: %d %s", resp.StatusCode, string(respBody)) + return "", "", 0, "", fmt.Errorf("refresh failed: %d %s", resp.StatusCode, string(respBody)) } var result struct { AccessToken string `json:"accessToken"` RefreshToken string `json:"refreshToken"` ExpiresIn int `json:"expiresIn"` + ProfileArn string `json:"profileArn"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return "", "", 0, err + return "", "", 0, "", err } expiresAt := time.Now().Unix() + int64(result.ExpiresIn) - return result.AccessToken, result.RefreshToken, expiresAt, nil + return result.AccessToken, result.RefreshToken, expiresAt, result.ProfileArn, nil } // refreshSocialToken Social (GitHub/Google) token 刷新 -func refreshSocialToken(refreshToken string) (string, string, int64, error) { +func refreshSocialToken(refreshToken string) (string, string, int64, string, error) { url := "https://prod.us-east-1.auth.desktop.kiro.dev/refreshToken" payload := map[string]string{ @@ -79,25 +81,26 @@ func refreshSocialToken(refreshToken string) (string, string, int64, error) { resp, err := httpClient().Do(req) if err != nil { - return "", "", 0, err + return "", "", 0, "", err } defer resp.Body.Close() if resp.StatusCode != 200 { respBody, _ := io.ReadAll(resp.Body) - return "", "", 0, fmt.Errorf("refresh failed: %d %s", resp.StatusCode, string(respBody)) + return "", "", 0, "", fmt.Errorf("refresh failed: %d %s", resp.StatusCode, string(respBody)) } var result struct { AccessToken string `json:"accessToken"` RefreshToken string `json:"refreshToken"` ExpiresIn int `json:"expiresIn"` + ProfileArn string `json:"profileArn"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return "", "", 0, err + return "", "", 0, "", err } expiresAt := time.Now().Unix() + int64(result.ExpiresIn) - return result.AccessToken, result.RefreshToken, expiresAt, nil + return result.AccessToken, result.RefreshToken, expiresAt, result.ProfileArn, nil } diff --git a/config/config.go b/config/config.go index ac0aad4..ef5f641 100644 --- a/config/config.go +++ b/config/config.go @@ -106,15 +106,24 @@ type Config struct { OpenAIThinkingFormat string `json:"openaiThinkingFormat,omitempty"` // OpenAI output format: "reasoning_content", "thinking", or "think" ClaudeThinkingFormat string `json:"claudeThinkingFormat,omitempty"` // Claude output format: "reasoning_content", "thinking", or "think" - // Endpoint configuration: "auto", "codewhisperer", or "amazonq" + // Endpoint configuration: "auto", "kiro", "codewhisperer", or "amazonq" PreferredEndpoint string `json:"preferredEndpoint,omitempty"` + // EndpointFallback controls whether to try other endpoints when the preferred one fails. + // Defaults to true. Set to false to only use the preferred endpoint. + EndpointFallback *bool `json:"endpointFallback,omitempty"` + // Proxy configuration: optional outbound proxy for Kiro API requests // Format: "socks5://host:port", "socks5://user:pass@host:port", // "http://host:port", "http://user:pass@host:port" // Leave empty to connect directly. ProxyURL string `json:"proxyURL,omitempty"` + // LogLevel controls verbosity of application logs. + // Accepted values: "debug", "info", "warn", "error". Defaults to "info". + // Can be overridden by the LOG_LEVEL environment variable. + LogLevel string `json:"logLevel,omitempty"` + // Global statistics (persisted across restarts) TotalRequests int `json:"totalRequests,omitempty"` // Total API requests received SuccessRequests int `json:"successRequests,omitempty"` // Successful requests count @@ -464,6 +473,24 @@ func UpdatePreferredEndpoint(endpoint string) error { return Save() } +// GetEndpointFallback returns whether endpoint fallback is enabled. Defaults to true. +func GetEndpointFallback() bool { + cfgLock.RLock() + defer cfgLock.RUnlock() + if cfg.EndpointFallback == nil { + return true + } + return *cfg.EndpointFallback +} + +// UpdateEndpointFallback sets the endpoint fallback switch and persists the change. +func UpdateEndpointFallback(enabled bool) error { + cfgLock.Lock() + defer cfgLock.Unlock() + cfg.EndpointFallback = &enabled + return Save() +} + // GetProxyURL 获取出站代理地址 func GetProxyURL() string { cfgLock.RLock() @@ -479,6 +506,24 @@ func UpdateProxySettings(proxyURL string) error { return Save() } +// GetLogLevel returns the configured log level (debug/info/warn/error). Defaults to "info". +func GetLogLevel() string { + cfgLock.RLock() + defer cfgLock.RUnlock() + if cfg == nil || cfg.LogLevel == "" { + return "info" + } + return cfg.LogLevel +} + +// UpdateLogLevel updates the log level setting and persists the change. +func UpdateLogLevel(level string) error { + cfgLock.Lock() + defer cfgLock.Unlock() + cfg.LogLevel = level + return Save() +} + type KiroClientConfig struct { KiroVersion string SystemVersion string diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 0000000..0316d10 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,141 @@ +// Package logger provides a lightweight leveled logger for Kiro-Go. +// +// Levels (from most to least verbose): +// +// DEBUG < INFO < WARN < ERROR +// +// The active level is configured via logger.Init at startup. +// Priority: LOG_LEVEL environment variable > provided fallback (usually +// taken from config.json "logLevel"). If neither is set or the value is +// unrecognized, the level defaults to INFO. +package logger + +import ( + "io" + "log" + "os" + "strings" + "sync/atomic" +) + +// Level represents a log severity. +type Level int32 + +const ( + LevelDebug Level = iota + LevelInfo + LevelWarn + LevelError +) + +var ( + currentLevel atomic.Int32 + + debugLog = log.New(os.Stdout, "DEBUG ", log.LstdFlags) + infoLog = log.New(os.Stdout, "INFO ", log.LstdFlags) + warnLog = log.New(os.Stderr, "WARN ", log.LstdFlags) + errorLog = log.New(os.Stderr, "ERROR ", log.LstdFlags) +) + +func init() { + currentLevel.Store(int32(LevelInfo)) +} + +// ParseLevel converts a textual level ("debug", "info", "warn", "error") +// to a Level. The ok flag is false when the input is empty or unknown. +func ParseLevel(s string) (Level, bool) { + switch strings.ToLower(strings.TrimSpace(s)) { + case "debug", "trace": + return LevelDebug, true + case "info": + return LevelInfo, true + case "warn", "warning": + return LevelWarn, true + case "error", "err": + return LevelError, true + } + return LevelInfo, false +} + +// LevelName returns the canonical lowercase name of a Level. +func LevelName(l Level) string { + switch l { + case LevelDebug: + return "debug" + case LevelInfo: + return "info" + case LevelWarn: + return "warn" + case LevelError: + return "error" + } + return "info" +} + +// SetLevel sets the active log level. +func SetLevel(l Level) { + currentLevel.Store(int32(l)) +} + +// GetLevel returns the active log level. +func GetLevel() Level { + return Level(currentLevel.Load()) +} + +// SetOutput redirects all level outputs to w. Useful for tests. +func SetOutput(w io.Writer) { + debugLog.SetOutput(w) + infoLog.SetOutput(w) + warnLog.SetOutput(w) + errorLog.SetOutput(w) +} + +// Init configures the logger. The LOG_LEVEL environment variable, if set, +// overrides the supplied fallback (typically config.GetLogLevel()). +func Init(fallback string) { + value := fallback + if env := os.Getenv("LOG_LEVEL"); env != "" { + value = env + } + if l, ok := ParseLevel(value); ok { + SetLevel(l) + } +} + +func enabled(l Level) bool { + return Level(currentLevel.Load()) <= l +} + +// Debugf logs a formatted message at DEBUG level. +func Debugf(format string, v ...interface{}) { + if enabled(LevelDebug) { + debugLog.Printf(format, v...) + } +} + +// Infof logs a formatted message at INFO level. +func Infof(format string, v ...interface{}) { + if enabled(LevelInfo) { + infoLog.Printf(format, v...) + } +} + +// Warnf logs a formatted message at WARN level. +func Warnf(format string, v ...interface{}) { + if enabled(LevelWarn) { + warnLog.Printf(format, v...) + } +} + +// Errorf logs a formatted message at ERROR level. +func Errorf(format string, v ...interface{}) { + if enabled(LevelError) { + errorLog.Printf(format, v...) + } +} + +// Fatalf logs a formatted message at ERROR level and terminates the process. +func Fatalf(format string, v ...interface{}) { + errorLog.Printf(format, v...) + os.Exit(1) +} diff --git a/main.go b/main.go index 99de1c3..52defff 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ package main import ( "fmt" "kiro-go/config" + "kiro-go/logger" "kiro-go/pool" "kiro-go/proxy" "log" @@ -41,6 +42,9 @@ func main() { log.Fatalf("Failed to load config: %v", err) } + // Initialize log level: LOG_LEVEL env var takes priority over config, defaulting to "info". + logger.Init(config.GetLogLevel()) + // 环境变量覆盖密码 if envPassword := os.Getenv("ADMIN_PASSWORD"); envPassword != "" { config.SetPassword(envPassword) @@ -54,12 +58,12 @@ func main() { // 启动服务器 addr := fmt.Sprintf("%s:%d", config.GetHost(), config.GetPort()) - log.Printf("Kiro-Go starting on http://%s", addr) - log.Printf("Admin panel: http://%s/admin", addr) - log.Printf("Claude API: http://%s/v1/messages", addr) - log.Printf("OpenAI API: http://%s/v1/chat/completions", addr) + logger.Infof("Kiro-Go starting on http://%s (log level: %s)", addr, logger.LevelName(logger.GetLevel())) + logger.Infof("Admin panel: http://%s/admin", addr) + logger.Infof("Claude API: http://%s/v1/messages", addr) + logger.Infof("OpenAI API: http://%s/v1/chat/completions", addr) if err := http.ListenAndServe(addr, handler); err != nil { - log.Fatalf("Server failed: %v", err) + logger.Fatalf("Server failed: %v", err) } } diff --git a/proxy/handler.go b/proxy/handler.go index 8b1c783..aeee7b6 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -6,6 +6,7 @@ import ( "io" "kiro-go/auth" "kiro-go/config" + "kiro-go/logger" "kiro-go/pool" "net/http" "strings" @@ -261,9 +262,9 @@ func (h *Handler) refreshAllAccounts() { // 检查 token 是否需要刷新 if account.ExpiresAt > 0 && time.Now().Unix() > account.ExpiresAt-300 { - newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(account) + newAccessToken, newRefreshToken, newExpiresAt, profileArn, err := auth.RefreshToken(account) if err != nil { - fmt.Printf("[BackgroundRefresh] Token refresh failed for %s: %v\n", account.Email, err) + logger.Warnf("[BackgroundRefresh] Token refresh failed for %s: %v", account.Email, err) continue } account.AccessToken = newAccessToken @@ -273,17 +274,21 @@ func (h *Handler) refreshAllAccounts() { account.ExpiresAt = newExpiresAt config.UpdateAccountToken(account.ID, newAccessToken, newRefreshToken, newExpiresAt) h.pool.UpdateToken(account.ID, newAccessToken, newRefreshToken, newExpiresAt) + if profileArn != "" { + account.ProfileArn = profileArn + config.UpdateAccountProfileArn(account.ID, profileArn) + } } // 刷新账户信息 info, err := RefreshAccountInfo(account) if err != nil { - fmt.Printf("[BackgroundRefresh] Failed to refresh %s: %v\n", account.Email, err) + logger.Warnf("[BackgroundRefresh] Failed to refresh %s: %v", account.Email, err) continue } config.UpdateAccountInfo(account.ID, *info) - fmt.Printf("[BackgroundRefresh] Refreshed %s: %s %.1f/%.1f\n", account.Email, info.SubscriptionType, info.UsageCurrent, info.UsageLimit) + logger.Infof("[BackgroundRefresh] Refreshed %s: %s %.1f/%.1f", account.Email, info.SubscriptionType, info.UsageCurrent, info.UsageLimit) } h.pool.Reload() } @@ -317,6 +322,9 @@ func (h *Handler) validateApiKey(r *http.Request) bool { func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { path := r.URL.Path + // Debug-level request trace for fine-grained visibility + logger.Debugf("[HTTP] %s %s from %s", r.Method, path, r.RemoteAddr) + // CORS - 完整的头部支持 w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") @@ -535,13 +543,13 @@ func (h *Handler) refreshModelsCache() { for i := range accounts { account := &accounts[i] if err := h.ensureValidToken(account); err != nil { - fmt.Printf("[ModelsCache] Skip %s token refresh failed: %v\n", account.Email, err) + logger.Warnf("[ModelsCache] Skip %s token refresh failed: %v", account.Email, err) continue } models, err := ListAvailableModels(account) if err != nil { - fmt.Printf("[ModelsCache] Failed to refresh for %s: %v\n", account.Email, err) + logger.Warnf("[ModelsCache] Failed to refresh for %s: %v", account.Email, err) continue } aggregated = mergeUniqueModels(aggregated, models) @@ -552,7 +560,7 @@ func (h *Handler) refreshModelsCache() { h.cachedModels = aggregated h.modelsCacheTime = time.Now().Unix() h.modelsCacheMu.Unlock() - fmt.Printf("[ModelsCache] Cached %d models\n", len(aggregated)) + logger.Infof("[ModelsCache] Cached %d models", len(aggregated)) } } @@ -1819,7 +1827,7 @@ func (h *Handler) ensureValidToken(account *config.Account) error { return nil } - accessToken, refreshToken, expiresAt, err := auth.RefreshToken(account) + accessToken, refreshToken, expiresAt, profileArn, err := auth.RefreshToken(account) if err != nil { return err } @@ -1831,6 +1839,10 @@ func (h *Handler) ensureValidToken(account *config.Account) error { account.RefreshToken = refreshToken } account.ExpiresAt = expiresAt + if profileArn != "" { + account.ProfileArn = profileArn + config.UpdateAccountProfileArn(account.ID, profileArn) + } // 持久化 config.UpdateAccountToken(account.ID, accessToken, refreshToken, expiresAt) @@ -2119,13 +2131,17 @@ func (h *Handler) apiBatchAccounts(w http.ResponseWriter, r *http.Request) { } // 刷新 token if account.RefreshToken != "" { - if newAccess, newRefresh, newExpires, err := auth.RefreshToken(account); err == nil { + if newAccess, newRefresh, newExpires, profileArn, err := auth.RefreshToken(account); err == nil { account.AccessToken = newAccess if newRefresh != "" { account.RefreshToken = newRefresh } account.ExpiresAt = newExpires config.UpdateAccountToken(id, newAccess, newRefresh, newExpires) + if profileArn != "" { + account.ProfileArn = profileArn + config.UpdateAccountProfileArn(id, profileArn) + } h.pool.UpdateToken(id, newAccess, newRefresh, newExpires) } } @@ -2464,7 +2480,7 @@ func (h *Handler) apiImportCredentials(w http.ResponseWriter, r *http.Request) { AuthMethod: req.AuthMethod, Region: req.Region, } - newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(tempAccount) + newAccessToken, newRefreshToken, newExpiresAt, newProfileArn, err := auth.RefreshToken(tempAccount) if err != nil { // 刷新失败,如果有传入的 accessToken 则尝试使用 if req.AccessToken != "" { @@ -2500,6 +2516,7 @@ func (h *Handler) apiImportCredentials(w http.ResponseWriter, r *http.Request) { ExpiresAt: expiresAt, Enabled: true, MachineId: config.GenerateMachineId(), + ProfileArn: newProfileArn, } if err := config.AddAccount(account); err != nil { @@ -2612,7 +2629,7 @@ func (h *Handler) apiRefreshAccount(w http.ResponseWriter, r *http.Request, id s if account.RefreshToken == "" { return nil } - newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(account) + newAccessToken, newRefreshToken, newExpiresAt, profileArn, err := auth.RefreshToken(account) if err != nil { return err } @@ -2623,6 +2640,10 @@ func (h *Handler) apiRefreshAccount(w http.ResponseWriter, r *http.Request, id s account.ExpiresAt = newExpiresAt config.UpdateAccountToken(id, newAccessToken, newRefreshToken, newExpiresAt) h.pool.UpdateToken(id, newAccessToken, newRefreshToken, newExpiresAt) + if profileArn != "" { + account.ProfileArn = profileArn + config.UpdateAccountProfileArn(id, profileArn) + } return nil } @@ -2847,8 +2868,9 @@ func (h *Handler) apiUpdateThinkingConfig(w http.ResponseWriter, r *http.Request // apiGetEndpointConfig 获取端点配置 func (h *Handler) apiGetEndpointConfig(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(map[string]string{ + json.NewEncoder(w).Encode(map[string]interface{}{ "preferredEndpoint": config.GetPreferredEndpoint(), + "endpointFallback": config.GetEndpointFallback(), }) } @@ -2856,6 +2878,7 @@ func (h *Handler) apiGetEndpointConfig(w http.ResponseWriter, r *http.Request) { func (h *Handler) apiUpdateEndpointConfig(w http.ResponseWriter, r *http.Request) { var req struct { PreferredEndpoint string `json:"preferredEndpoint"` + EndpointFallback *bool `json:"endpointFallback"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { w.WriteHeader(400) @@ -2863,10 +2886,10 @@ func (h *Handler) apiUpdateEndpointConfig(w http.ResponseWriter, r *http.Request return } - valid := map[string]bool{"auto": true, "codewhisperer": true, "amazonq": true} + valid := map[string]bool{"auto": true, "kiro": true, "codewhisperer": true, "amazonq": true} if !valid[req.PreferredEndpoint] { w.WriteHeader(400) - json.NewEncoder(w).Encode(map[string]string{"error": "Invalid endpoint, must be: auto, codewhisperer, or amazonq"}) + json.NewEncoder(w).Encode(map[string]string{"error": "Invalid endpoint, must be: auto, kiro, codewhisperer, or amazonq"}) return } @@ -2876,6 +2899,10 @@ func (h *Handler) apiUpdateEndpointConfig(w http.ResponseWriter, r *http.Request return } + if req.EndpointFallback != nil { + config.UpdateEndpointFallback(*req.EndpointFallback) + } + json.NewEncoder(w).Encode(map[string]bool{"success": true}) } diff --git a/proxy/kiro.go b/proxy/kiro.go index 974650a..5ce2467 100644 --- a/proxy/kiro.go +++ b/proxy/kiro.go @@ -1,5 +1,5 @@ -// Package proxy Kiro API 代理核心 -// 负责调用 Kiro API 并解析 AWS Event Stream 响应 +// Package proxy is the core proxy layer for the Kiro API. +// It handles streaming API calls to the Kiro backend and parses AWS Event Stream responses. package proxy import ( @@ -8,6 +8,7 @@ import ( "fmt" "io" "kiro-go/config" + "kiro-go/logger" "net/http" "net/url" "strconv" @@ -18,7 +19,7 @@ import ( "github.com/google/uuid" ) -// 双端点配置(429 时自动 fallback) +// Endpoint configuration (auto-fallback on quota exhaustion). type kiroEndpoint struct { URL string Origin string @@ -27,6 +28,12 @@ type kiroEndpoint struct { } var kiroEndpoints = []kiroEndpoint{ + { + URL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse", + Origin: "AI_EDITOR", + AmzTarget: "", + Name: "Kiro IDE", + }, { URL: "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse", Origin: "AI_EDITOR", @@ -35,13 +42,13 @@ var kiroEndpoints = []kiroEndpoint{ }, { URL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse", - Origin: "CLI", + Origin: "AI_EDITOR", AmzTarget: "AmazonQDeveloperStreamingService.SendMessage", Name: "AmazonQ", }, } -// 全局 HTTP 客户端,支持运行时更换(代理重配置) +// Global HTTP clients, swappable at runtime to apply proxy reconfiguration without restart. var kiroHttpStore atomic.Pointer[http.Client] var kiroRestHttpStore atomic.Pointer[http.Client] @@ -49,7 +56,7 @@ func init() { InitKiroHttpClient("") } -// buildKiroTransport 构建带可选代理的 Transport +// buildKiroTransport constructs an HTTP Transport with optional outbound proxy support. func buildKiroTransport(proxyURL string) *http.Transport { t := &http.Transport{ MaxIdleConns: 100, @@ -61,7 +68,7 @@ func buildKiroTransport(proxyURL string) *http.Transport { if proxyURL != "" { if u, err := url.Parse(proxyURL); err == nil { t.Proxy = http.ProxyURL(u) - // 代理不支持 HTTP/2 协议升级 + // Proxied connections cannot negotiate HTTP/2. t.ForceAttemptHTTP2 = false } } else { @@ -70,7 +77,7 @@ func buildKiroTransport(proxyURL string) *http.Transport { return t } -// InitKiroHttpClient 初始化(或重新初始化)Kiro API 的 HTTP 客户端 +// InitKiroHttpClient initializes (or reinitializes) the HTTP clients used for Kiro API requests. func InitKiroHttpClient(proxyURL string) { client := &http.Client{ Timeout: 5 * time.Minute, @@ -85,20 +92,28 @@ func InitKiroHttpClient(proxyURL string) { kiroRestHttpStore.Store(restClient) } -// ==================== 请求结构 ==================== +// ==================== Request Structs ==================== -// KiroPayload Kiro API 请求体 +// KiroPayload is the top-level request body sent to the Kiro API. type KiroPayload struct { ConversationState struct { - ChatTriggerType string `json:"chatTriggerType"` - ConversationID string `json:"conversationId"` - CurrentMessage struct { + AgentContinuationId string `json:"agentContinuationId,omitempty"` + AgentTaskType string `json:"agentTaskType,omitempty"` + ChatTriggerType string `json:"chatTriggerType"` + ConversationID string `json:"conversationId"` + CurrentMessage struct { UserInputMessage KiroUserInputMessage `json:"userInputMessage"` } `json:"currentMessage"` History []KiroHistoryMessage `json:"history,omitempty"` } `json:"conversationState"` ProfileArn string `json:"profileArn,omitempty"` InferenceConfig *InferenceConfig `json:"inferenceConfig,omitempty"` + + // ToolNameMap maps sanitized tool names (sent to Kiro) back to the + // original names supplied by the client. Used to restore original names + // in tool_use responses so the client can match them to its tool registry. + // Not serialized to the Kiro API request body. + ToolNameMap map[string]string `json:"-"` } type KiroUserInputMessage struct { @@ -177,25 +192,65 @@ type KiroStreamCallback struct { OnContextUsage func(percentage float64) } -// ==================== API 调用 ==================== +// ==================== API Call ==================== -// getSortedEndpoints 根据首选端点配置排序端点列表 +// getSortedEndpoints returns endpoints ordered by user preference, with optional fallback. func getSortedEndpoints(preferred string) []kiroEndpoint { - if preferred == "amazonq" { - return []kiroEndpoint{kiroEndpoints[1], kiroEndpoints[0]} + fallback := config.GetEndpointFallback() + + var primary int + switch preferred { + case "kiro": + primary = 0 + case "codewhisperer": + primary = 1 + case "amazonq": + primary = 2 + default: + // "auto": Kiro first, then fallback to others + return []kiroEndpoint{kiroEndpoints[0], kiroEndpoints[1], kiroEndpoints[2]} } - if preferred == "codewhisperer" { - return []kiroEndpoint{kiroEndpoints[0], kiroEndpoints[1]} + + if !fallback { + // No fallback: only use the selected endpoint + return []kiroEndpoint{kiroEndpoints[primary]} } - // "auto" 或空值:默认顺序 - return []kiroEndpoint{kiroEndpoints[0], kiroEndpoints[1]} + + // With fallback: selected first, then others in order + result := []kiroEndpoint{kiroEndpoints[primary]} + for i, ep := range kiroEndpoints { + if i != primary { + result = append(result, ep) + } + } + return result } -// CallKiroAPI 调用 Kiro API(流式),双端点自动 fallback +// CallKiroAPI calls the Kiro streaming API, trying each configured endpoint with automatic fallback. func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroStreamCallback) error { if _, err := json.Marshal(payload); err != nil { return err } + + // Debug: dump full payload for troubleshooting upstream rejections + if payloadJSON, err := json.Marshal(payload); err == nil { + logger.Debugf("[KiroAPI] Request payload: %s", string(payloadJSON)) + } + + // Wrap OnToolUse to restore original tool names for the client. + if callback != nil && callback.OnToolUse != nil && len(payload.ToolNameMap) > 0 { + originalOnToolUse := callback.OnToolUse + nameMap := payload.ToolNameMap + wrapped := *callback + wrapped.OnToolUse = func(tu KiroToolUse) { + if original, ok := nameMap[tu.Name]; ok { + tu.Name = original + } + originalOnToolUse(tu) + } + callback = &wrapped + } + if payload != nil && strings.TrimSpace(payload.ProfileArn) == "" { if profileArn, err := ResolveProfileArn(account); err == nil { payload.ProfileArn = profileArn @@ -204,16 +259,16 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt if account != nil { accountEmail = account.Email } - fmt.Printf("[ProfileArn] Failed to resolve profile ARN for %s: %v\n", accountEmail, err) + logger.Warnf("[ProfileArn] Failed to resolve profile ARN for %s: %v", accountEmail, err) } } - // 根据配置排序端点 + // Build endpoint list ordered by configuration. endpoints := getSortedEndpoints(config.GetPreferredEndpoint()) var lastErr error for _, ep := range endpoints { - // 更新 payload 中的 origin + // Update the origin field for the selected endpoint. payload.ConversationState.CurrentMessage.UserInputMessage.Origin = ep.Origin reqBody, _ := json.Marshal(payload) @@ -231,7 +286,9 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "*/*") - req.Header.Set("X-Amz-Target", ep.AmzTarget) + if ep.AmzTarget != "" { + req.Header.Set("X-Amz-Target", ep.AmzTarget) + } applyKiroBaseHeaders(req, account, headerValues) req.Header.Set("x-amzn-kiro-agent-mode", "vibe") req.Header.Set("x-amzn-codewhisperer-optout", "true") @@ -241,13 +298,13 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt resp, err := kiroHttpStore.Load().Do(req) if err != nil { lastErr = err - fmt.Printf("[KiroAPI] Endpoint %s failed: %v\n", ep.Name, err) + logger.Warnf("[KiroAPI] Endpoint %s failed: %v", ep.Name, err) continue } if resp.StatusCode == 429 { resp.Body.Close() - fmt.Printf("[KiroAPI] Endpoint %s quota exhausted (429), trying next...\n", ep.Name) + logger.Warnf("[KiroAPI] Endpoint %s quota exhausted (429), trying next...", ep.Name) lastErr = fmt.Errorf("quota exhausted on %s", ep.Name) continue } @@ -256,11 +313,11 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt errBody, _ := io.ReadAll(resp.Body) resp.Body.Close() lastErr = fmt.Errorf("HTTP %d from %s: %s", resp.StatusCode, ep.Name, string(errBody)) - // 认证错误不继续尝试 + // Authentication errors are not retried across endpoints. if resp.StatusCode == 401 || resp.StatusCode == 403 { return lastErr } - fmt.Printf("[KiroAPI] Endpoint %s error: %v\n", ep.Name, lastErr) + logger.Warnf("[KiroAPI] Endpoint %s error: %v", ep.Name, lastErr) continue } @@ -275,11 +332,11 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt return fmt.Errorf("all endpoints failed") } -// ==================== Event Stream 解析 ==================== +// ==================== Event Stream Parsing ==================== -// parseEventStream 解析 AWS Event Stream 二进制格式 +// parseEventStream decodes an AWS binary Event Stream response body. func parseEventStream(body io.Reader, callback *KiroStreamCallback) error { - // 不使用 bufio,直接读取避免缓冲延迟 + // Read directly without bufio to avoid buffering latency in streaming responses. var inputTokens, outputTokens int var totalCredits float64 var currentToolUse *toolUseState @@ -304,7 +361,7 @@ func parseEventStream(body io.Reader, callback *KiroStreamCallback) error { continue } - // 读取剩余部分 + // Read the remaining message bytes. remaining := totalLength - 12 msgBuf := make([]byte, remaining) _, err = io.ReadFull(body, msgBuf) @@ -329,7 +386,7 @@ func parseEventStream(body io.Reader, callback *KiroStreamCallback) error { inputTokens, outputTokens = updateTokensFromEvent(event, inputTokens, outputTokens) - // 处理事件 + // Dispatch by event type. switch eventType { case "assistantResponseEvent": if content, ok := event["content"].(string); ok && content != "" { @@ -525,7 +582,7 @@ func readTokenNumber(m map[string]interface{}, keys ...string) (int, bool) { return 0, false } -// ==================== Tool Use 处理 ==================== +// ==================== Tool Use Handling ==================== type toolUseState struct { ToolUseID string @@ -580,7 +637,7 @@ func finishToolUse(state *toolUseState, callback *KiroStreamCallback) { }) } -// extractEventType 从 headers 中提取事件类型 +// extractEventType extracts the event type string from AWS Event Stream message headers. func extractEventType(headers []byte) string { offset := 0 for offset < len(headers) { @@ -617,7 +674,7 @@ func extractEventType(headers []byte) string { continue } - // 跳过其他类型 + // Skip other value types by their fixed byte widths. skipSizes := map[byte]int{0: 0, 1: 0, 2: 1, 3: 2, 4: 4, 5: 8, 8: 8, 9: 16} if valueType == 6 { if offset+2 > len(headers) { diff --git a/proxy/kiro_api.go b/proxy/kiro_api.go index 91c27f8..13367ab 100644 --- a/proxy/kiro_api.go +++ b/proxy/kiro_api.go @@ -4,7 +4,9 @@ import ( "encoding/json" "fmt" "io" + "kiro-go/auth" "kiro-go/config" + "kiro-go/logger" "net/http" neturl "net/url" "strings" @@ -109,8 +111,8 @@ func ListAvailableModels(account *config.Account) ([]ModelInfo, error) { } // ResolveProfileArn returns the account profile ARN, fetching and caching it -// when it is missing. Some Kiro generation requests require this profile for -// model authorization even when model listing works without it. +// when it is missing. First tries ListAvailableProfiles; if that returns empty, +// falls back to refreshing the token (which returns profileArn in the response). func ResolveProfileArn(account *config.Account) (string, error) { if account == nil { return "", fmt.Errorf("account is nil") @@ -119,6 +121,32 @@ func ResolveProfileArn(account *config.Account) (string, error) { return profileArn, nil } + // Try ListAvailableProfiles first + profileArn, err := listAvailableProfiles(account) + if err == nil && profileArn != "" { + if updateErr := config.UpdateAccountProfileArn(account.ID, profileArn); updateErr != nil { + logger.Warnf("[ProfileArn] Failed to cache profile ARN for %s: %v", account.Email, updateErr) + } + account.ProfileArn = profileArn + return profileArn, nil + } + + // Fallback: refresh token to get profileArn from auth response + if account.RefreshToken != "" { + _, _, _, refreshedArn, refreshErr := auth.RefreshToken(account) + if refreshErr == nil && refreshedArn != "" { + if updateErr := config.UpdateAccountProfileArn(account.ID, refreshedArn); updateErr != nil { + logger.Warnf("[ProfileArn] Failed to cache profile ARN for %s: %v", account.Email, updateErr) + } + account.ProfileArn = refreshedArn + return refreshedArn, nil + } + } + + return "", fmt.Errorf("no available Kiro profile") +} + +func listAvailableProfiles(account *config.Account) (string, error) { req, err := http.NewRequest("POST", fmt.Sprintf("%s/ListAvailableProfiles", kiroRestAPIBase), strings.NewReader(`{"maxResults":10}`)) if err != nil { return "", err @@ -147,14 +175,10 @@ func ResolveProfileArn(account *config.Account) (string, error) { } for _, profile := range result.Profiles { if profileArn := strings.TrimSpace(profile.Arn); profileArn != "" { - if updateErr := config.UpdateAccountProfileArn(account.ID, profileArn); updateErr != nil { - fmt.Printf("[ProfileArn] Failed to cache profile ARN for %s: %v\n", account.Email, updateErr) - } - account.ProfileArn = profileArn return profileArn, nil } } - return "", fmt.Errorf("no available Kiro profile") + return "", fmt.Errorf("empty profile list") } func withProfileArnQuery(rawURL string, account *config.Account) string { @@ -192,7 +216,7 @@ func RefreshAccountInfo(account *config.Account) (*config.AccountInfo, error) { errMsg := err.Error() if strings.Contains(errMsg, "TEMPORARILY_SUSPENDED") { // 账户被暂时封禁,自动禁用并标记封禁状态 - fmt.Printf("[RefreshAccountInfo] Account %s is temporarily suspended: %v\n", account.Email, err) + logger.Warnf("[RefreshAccountInfo] Account %s is temporarily suspended: %v", account.Email, err) // 更新账户封禁状态并自动禁用 updatedAccount := *account @@ -203,14 +227,14 @@ func RefreshAccountInfo(account *config.Account) (*config.AccountInfo, error) { // 保存更新后的账户状态 if updateErr := config.UpdateAccount(account.ID, updatedAccount); updateErr != nil { - fmt.Printf("[RefreshAccountInfo] Failed to update account ban status: %v\n", updateErr) + logger.Errorf("[RefreshAccountInfo] Failed to update account ban status: %v", updateErr) } return nil, fmt.Errorf("Account suspended: %w", err) } else if strings.Contains(errMsg, "403") || strings.Contains(errMsg, "401") || strings.Contains(errMsg, "invalid") || strings.Contains(errMsg, "expired") { // Token 相关错误,可能需要重新认证 - fmt.Printf("[RefreshAccountInfo] Authentication error for %s: %v\n", account.Email, err) + logger.Warnf("[RefreshAccountInfo] Authentication error for %s: %v", account.Email, err) // 更新账户封禁状态为认证失败并自动禁用 updatedAccount := *account @@ -221,7 +245,7 @@ func RefreshAccountInfo(account *config.Account) (*config.AccountInfo, error) { // 保存更新后的账户状态 if updateErr := config.UpdateAccount(account.ID, updatedAccount); updateErr != nil { - fmt.Printf("[RefreshAccountInfo] Failed to update account ban status: %v\n", updateErr) + logger.Errorf("[RefreshAccountInfo] Failed to update account ban status: %v", updateErr) } } @@ -230,7 +254,7 @@ func RefreshAccountInfo(account *config.Account) (*config.AccountInfo, error) { // 如果成功获取信息,清除封禁状态(如果之前被标记) if account.BanStatus != "" && account.BanStatus != "ACTIVE" { - fmt.Printf("[RefreshAccountInfo] Account %s is now active, clearing ban status\n", account.Email) + logger.Infof("[RefreshAccountInfo] Account %s is now active, clearing ban status", account.Email) updatedAccount := *account updatedAccount.BanStatus = "ACTIVE" @@ -239,7 +263,7 @@ func RefreshAccountInfo(account *config.Account) (*config.AccountInfo, error) { // 保存更新后的账户状态 if updateErr := config.UpdateAccount(account.ID, updatedAccount); updateErr != nil { - fmt.Printf("[RefreshAccountInfo] Failed to clear account ban status: %v\n", updateErr) + logger.Errorf("[RefreshAccountInfo] Failed to clear account ban status: %v", updateErr) } } @@ -264,7 +288,7 @@ func RefreshAccountInfo(account *config.Account) (*config.AccountInfo, error) { if info.SubscriptionTitle == "" { info.SubscriptionTitle = usage.SubscriptionInfo.SubscriptionName } - fmt.Printf("[RefreshAccountInfo] Subscription: type=%s, title=%s, name=%s, parsed=%s\n", + logger.Debugf("[RefreshAccountInfo] Subscription: type=%s, title=%s, name=%s, parsed=%s", usage.SubscriptionInfo.SubscriptionType, usage.SubscriptionInfo.SubscriptionTitle, usage.SubscriptionInfo.SubscriptionName, diff --git a/proxy/translator.go b/proxy/translator.go index 38b562e..e2ba190 100644 --- a/proxy/translator.go +++ b/proxy/translator.go @@ -244,11 +244,14 @@ func ClaudeToKiro(req *ClaudeRequest, thinking bool) *KiroPayload { } // 转换工具 - kiroTools := convertClaudeTools(req.Tools) + kiroTools, toolNameMap := convertClaudeTools(req.Tools) // 构建 payload payload := &KiroPayload{} + payload.ToolNameMap = toolNameMap payload.ConversationState.ChatTriggerType = "MANUAL" + payload.ConversationState.AgentTaskType = "vibe" + payload.ConversationState.AgentContinuationId = uuid.New().String() payload.ConversationState.ConversationID = buildConversationID(modelID, systemPrompt, firstClaudeConversationAnchor(req.Messages)) payload.ConversationState.CurrentMessage.UserInputMessage = KiroUserInputMessage{ Content: finalContent, @@ -519,21 +522,115 @@ func extractClaudeAssistantContent(content interface{}) (string, []KiroToolUse) return text, toolUses } -func convertClaudeTools(tools []ClaudeTool) []KiroToolWrapper { +func convertClaudeTools(tools []ClaudeTool) ([]KiroToolWrapper, map[string]string) { if len(tools) == 0 { - return nil + return nil, nil } - result := make([]KiroToolWrapper, len(tools)) - for i, tool := range tools { + result := make([]KiroToolWrapper, 0, len(tools)) + nameMap := make(map[string]string) + for _, tool := range tools { desc := tool.Description if len(desc) > maxToolDescLen { desc = desc[:maxToolDescLen] + "..." } - result[i] = KiroToolWrapper{} - result[i].ToolSpecification.Name = shortenToolName(tool.Name) - result[i].ToolSpecification.Description = desc - result[i].ToolSpecification.InputSchema = InputSchema{JSON: tool.InputSchema} + sanitized := shortenToolName(sanitizeToolName(tool.Name)) + if sanitized != tool.Name { + nameMap[sanitized] = tool.Name + } + w := KiroToolWrapper{} + w.ToolSpecification.Name = sanitized + w.ToolSpecification.Description = desc + w.ToolSpecification.InputSchema = InputSchema{JSON: ensureObjectSchema(tool.InputSchema)} + result = append(result, w) + } + return result, nameMap +} + +// ensureObjectSchema ensures the JSON schema has "type": "object" at the top level +// and removes invalid null values from "required" fields (recursively). +// Kiro API rejects tool schemas with "required": null. +func ensureObjectSchema(schema interface{}) interface{} { + m, ok := schema.(map[string]interface{}) + if !ok { + return map[string]interface{}{"type": "object"} + } + cleanSchema(m) + if _, hasType := m["type"]; !hasType { + m["type"] = "object" + } + return m +} + +// cleanSchema recursively removes or fixes invalid "required": null entries +// in a JSON Schema tree. +func cleanSchema(m map[string]interface{}) { + // Fix "required" field: must be array or absent + if req, exists := m["required"]; exists { + if req == nil { + delete(m, "required") + } else if arr, ok := req.([]interface{}); ok && len(arr) == 0 { + delete(m, "required") + } + } + + // Recurse into "properties" + if props, ok := m["properties"].(map[string]interface{}); ok { + for _, v := range props { + if sub, ok := v.(map[string]interface{}); ok { + cleanSchema(sub) + } + } + } + + // Recurse into "items" + if items, ok := m["items"].(map[string]interface{}); ok { + cleanSchema(items) + } + + // Recurse into nested object schemas (e.g., additionalProperties, allOf, oneOf, anyOf) + for _, key := range []string{"additionalProperties"} { + if sub, ok := m[key].(map[string]interface{}); ok { + cleanSchema(sub) + } + } + for _, key := range []string{"allOf", "oneOf", "anyOf"} { + if arr, ok := m[key].([]interface{}); ok { + for _, item := range arr { + if sub, ok := item.(map[string]interface{}); ok { + cleanSchema(sub) + } + } + } + } +} + +// sanitizeToolName normalizes a tool name to characters the Kiro API accepts. +// Kiro tool names must be pure camelCase (no underscores or dashes). +// Separators (_, -, and multi-underscore namespace prefixes) are converted to camelCase boundaries. +func sanitizeToolName(name string) string { + // Split on underscores and dashes, including multi-underscore namespace prefixes. + parts := strings.FieldsFunc(name, func(r rune) bool { + return r == '_' || r == '-' + }) + if len(parts) == 0 { + return "tool" + } + // Build camelCase: first part lowercase start, rest capitalize first letter + var b strings.Builder + for i, part := range parts { + if part == "" { + continue + } + if i == 0 { + b.WriteString(strings.ToLower(part[:1]) + part[1:]) + } else { + b.WriteString(strings.ToUpper(part[:1]) + part[1:]) + } + } + result := b.String() + if result == "" { + return "tool" } return result } diff --git a/web/index.html b/web/index.html index 0d6985b..5af669e 100644 --- a/web/index.html +++ b/web/index.html @@ -1002,12 +1002,21 @@ +
+ + +
@@ -1163,6 +1172,8 @@ 'settings.preferredEndpoint': '首选端点', 'settings.endpointAuto': '自动选择', 'settings.endpointHint': '选择首选端点,自动选择模式下会根据可用性自动选择端点', + 'settings.endpointFallback': '端点不可用时自动切换', + 'settings.endpointFallbackHint': '关闭后,仅使用选定的端点,不会自动切换到其他端点', 'settings.saveEndpoint': '保存端点设置', 'settings.endpointSaved': '端点设置已保存', 'settings.adminPassword': '管理密码', @@ -1379,6 +1390,8 @@ 'settings.preferredEndpoint': 'Preferred Endpoint', 'settings.endpointAuto': 'Auto', 'settings.endpointHint': 'Select preferred endpoint. In auto-select mode, the endpoint is automatically selected based on availability.', + 'settings.endpointFallback': 'Fallback to other endpoints', + 'settings.endpointFallbackHint': 'When disabled, only the selected endpoint is used without automatic fallback', 'settings.saveEndpoint': 'Save Endpoint Settings', 'settings.endpointSaved': 'Endpoint settings saved', 'settings.adminPassword': 'Admin Password', @@ -2060,11 +2073,12 @@ const res = await fetch('/admin/api/endpoint', { headers: { 'X-Admin-Password': password } }); const d = await res.json(); document.getElementById('preferredEndpoint').value = d.preferredEndpoint || 'auto'; + document.getElementById('endpointFallback').checked = d.endpointFallback !== false; } async function saveEndpointConfig() { const res = await fetch('/admin/api/endpoint', { method: 'POST', headers: { 'Content-Type': 'application/json', 'X-Admin-Password': password }, - body: JSON.stringify({ preferredEndpoint: document.getElementById('preferredEndpoint').value }) + body: JSON.stringify({ preferredEndpoint: document.getElementById('preferredEndpoint').value, endpointFallback: document.getElementById('endpointFallback').checked }) }); const d = await res.json(); if (d.success) { alert(t('settings.endpointSaved')); } else { alert(t('common.saveFailed') + ': ' + d.error); }