diff --git a/auth/http_client.go b/auth/http_client.go index fa5443e..4604d70 100644 --- a/auth/http_client.go +++ b/auth/http_client.go @@ -34,6 +34,8 @@ func buildAuthTransport(proxyURL string) *http.Transport { t.Proxy = http.ProxyURL(u) t.ForceAttemptHTTP2 = false } + } else { + t.Proxy = http.ProxyFromEnvironment } return t } diff --git a/auth/http_client_test.go b/auth/http_client_test.go new file mode 100644 index 0000000..3f5d505 --- /dev/null +++ b/auth/http_client_test.go @@ -0,0 +1,52 @@ +package auth + +import ( + "net/http" + "net/url" + "testing" +) + +func TestBuildAuthTransportUsesExplicitProxyURL(t *testing.T) { + transport := buildAuthTransport("http://proxy.local:8080") + req := &http.Request{URL: mustParseURL(t, "https://oidc.us-east-1.amazonaws.com")} + + got, err := transport.Proxy(req) + if err != nil { + t.Fatalf("unexpected proxy error: %v", err) + } + assertProxyURL(t, got, "http://proxy.local:8080") +} + +func TestBuildAuthTransportFallsBackToEnvironmentProxy(t *testing.T) { + t.Setenv("HTTPS_PROXY", "http://env-proxy.local:2323") + t.Setenv("NO_PROXY", "") + t.Setenv("no_proxy", "") + + transport := buildAuthTransport("") + req := &http.Request{URL: mustParseURL(t, "https://oidc.us-east-1.amazonaws.com")} + + got, err := transport.Proxy(req) + if err != nil { + t.Fatalf("unexpected proxy error: %v", err) + } + assertProxyURL(t, got, "http://env-proxy.local:2323") +} + +func mustParseURL(t *testing.T, raw string) *url.URL { + t.Helper() + parsed, err := url.Parse(raw) + if err != nil { + t.Fatalf("invalid test URL: %v", err) + } + return parsed +} + +func assertProxyURL(t *testing.T, got *url.URL, want string) { + t.Helper() + if got == nil { + t.Fatalf("expected proxy URL %q, got nil", want) + } + if got.String() != want { + t.Fatalf("expected proxy URL %q, got %q", want, got.String()) + } +} diff --git a/auth/oidc.go b/auth/oidc.go index 7dcb494..1354470 100644 --- a/auth/oidc.go +++ b/auth/oidc.go @@ -11,7 +11,8 @@ import ( ) // RefreshToken 刷新 access token -func RefreshToken(account *config.Account) (string, string, int64, error) { +// Returns: accessToken, refreshToken, expiresAt, profileArn, error +func RefreshToken(account *config.Account) (string, string, int64, string, error) { if account.AuthMethod == "social" { return refreshSocialToken(account.RefreshToken) } @@ -19,9 +20,9 @@ func RefreshToken(account *config.Account) (string, string, int64, error) { } // refreshOIDCToken IdC/Builder ID token 刷新 -func refreshOIDCToken(refreshToken, clientID, clientSecret, region string) (string, string, int64, error) { +func refreshOIDCToken(refreshToken, clientID, clientSecret, region string) (string, string, int64, string, error) { if clientID == "" || clientSecret == "" { - return "", "", 0, fmt.Errorf("OIDC refresh requires clientId and clientSecret") + return "", "", 0, "", fmt.Errorf("OIDC refresh requires clientId and clientSecret") } if region == "" { region = "us-east-1" @@ -42,31 +43,32 @@ func refreshOIDCToken(refreshToken, clientID, clientSecret, region string) (stri resp, err := httpClient().Do(req) if err != nil { - return "", "", 0, err + return "", "", 0, "", err } defer resp.Body.Close() if resp.StatusCode != 200 { respBody, _ := io.ReadAll(resp.Body) - return "", "", 0, fmt.Errorf("refresh failed: %d %s", resp.StatusCode, string(respBody)) + return "", "", 0, "", fmt.Errorf("refresh failed: %d %s", resp.StatusCode, string(respBody)) } var result struct { AccessToken string `json:"accessToken"` RefreshToken string `json:"refreshToken"` ExpiresIn int `json:"expiresIn"` + ProfileArn string `json:"profileArn"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return "", "", 0, err + return "", "", 0, "", err } expiresAt := time.Now().Unix() + int64(result.ExpiresIn) - return result.AccessToken, result.RefreshToken, expiresAt, nil + return result.AccessToken, result.RefreshToken, expiresAt, result.ProfileArn, nil } // refreshSocialToken Social (GitHub/Google) token 刷新 -func refreshSocialToken(refreshToken string) (string, string, int64, error) { +func refreshSocialToken(refreshToken string) (string, string, int64, string, error) { url := "https://prod.us-east-1.auth.desktop.kiro.dev/refreshToken" payload := map[string]string{ @@ -79,25 +81,26 @@ func refreshSocialToken(refreshToken string) (string, string, int64, error) { resp, err := httpClient().Do(req) if err != nil { - return "", "", 0, err + return "", "", 0, "", err } defer resp.Body.Close() if resp.StatusCode != 200 { respBody, _ := io.ReadAll(resp.Body) - return "", "", 0, fmt.Errorf("refresh failed: %d %s", resp.StatusCode, string(respBody)) + return "", "", 0, "", fmt.Errorf("refresh failed: %d %s", resp.StatusCode, string(respBody)) } var result struct { AccessToken string `json:"accessToken"` RefreshToken string `json:"refreshToken"` ExpiresIn int `json:"expiresIn"` + ProfileArn string `json:"profileArn"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return "", "", 0, err + return "", "", 0, "", err } expiresAt := time.Now().Unix() + int64(result.ExpiresIn) - return result.AccessToken, result.RefreshToken, expiresAt, nil + return result.AccessToken, result.RefreshToken, expiresAt, result.ProfileArn, nil } diff --git a/config/config.go b/config/config.go index 1107ebc..848f659 100644 --- a/config/config.go +++ b/config/config.go @@ -50,10 +50,15 @@ type Account struct { StartUrl string `json:"startUrl,omitempty"` // AWS SSO start URL ExpiresAt int64 `json:"expiresAt,omitempty"` // Token expiration timestamp (Unix seconds) MachineId string `json:"machineId,omitempty"` // UUID machine identifier for request tracking + ProfileArn string `json:"profileArn,omitempty"` // CodeWhisperer/Kiro profile ARN for generation requests // Priority weight for load balancing (higher = more requests) Weight int `json:"weight,omitempty"` // 0 or 1 = normal, 2+ = higher priority + // Overage behavior after the main usage limit is reached. + AllowOverage bool `json:"allowOverage,omitempty"` // Whether to keep using the account after UsageLimit is reached + OverageWeight int `json:"overageWeight,omitempty"` // 1-10, lower values reduce overage request frequency + // Account status Enabled bool `json:"enabled"` // Whether account is active in the pool BanStatus string `json:"banStatus,omitempty"` // Ban status: "ACTIVE", "BANNED", "SUSPENDED" @@ -105,9 +110,13 @@ type Config struct { OpenAIThinkingFormat string `json:"openaiThinkingFormat,omitempty"` // OpenAI output format: "reasoning_content", "thinking", or "think" ClaudeThinkingFormat string `json:"claudeThinkingFormat,omitempty"` // Claude output format: "reasoning_content", "thinking", or "think" - // Endpoint configuration: "auto", "codewhisperer", or "amazonq" + // Endpoint configuration: "auto", "kiro", "codewhisperer", or "amazonq" PreferredEndpoint string `json:"preferredEndpoint,omitempty"` + // EndpointFallback controls whether to try other endpoints when the preferred one fails. + // Defaults to true. Set to false to only use the preferred endpoint. + EndpointFallback *bool `json:"endpointFallback,omitempty"` + // Proxy configuration: optional outbound proxy for Kiro API requests // Format: "socks5://host:port", "socks5://user:pass@host:port", // "http://host:port", "http://user:pass@host:port" @@ -119,6 +128,11 @@ type Config struct { FirstByteTimeoutSec int `json:"firstByteTimeoutSec,omitempty"` // First-byte timeout in seconds (default: 10, 0=disabled) FirstByteRetries int `json:"firstByteRetries,omitempty"` // Same-endpoint retry count on first-byte timeout (default: 1) + // LogLevel controls verbosity of application logs. + // Accepted values: "debug", "info", "warn", "error". Defaults to "info". + // Can be overridden by the LOG_LEVEL environment variable. + LogLevel string `json:"logLevel,omitempty"` + // Global statistics (persisted across restarts) TotalRequests int `json:"totalRequests,omitempty"` // Total API requests received SuccessRequests int `json:"successRequests,omitempty"` // Successful requests count @@ -148,7 +162,7 @@ type AccountInfo struct { } // Version current version -const Version = "1.0.6" +const Version = "1.0.7" var ( cfg *Config @@ -279,6 +293,18 @@ func UpdateAccount(id string, account Account) error { return nil } +func UpdateAccountProfileArn(id, profileArn string) error { + cfgLock.Lock() + defer cfgLock.Unlock() + for i, a := range cfg.Accounts { + if a.ID == id { + cfg.Accounts[i].ProfileArn = profileArn + return Save() + } + } + return nil +} + func DeleteAccount(id string) error { cfgLock.Lock() defer cfgLock.Unlock() @@ -456,6 +482,24 @@ func UpdatePreferredEndpoint(endpoint string) error { return Save() } +// GetEndpointFallback returns whether endpoint fallback is enabled. Defaults to true. +func GetEndpointFallback() bool { + cfgLock.RLock() + defer cfgLock.RUnlock() + if cfg.EndpointFallback == nil { + return true + } + return *cfg.EndpointFallback +} + +// UpdateEndpointFallback sets the endpoint fallback switch and persists the change. +func UpdateEndpointFallback(enabled bool) error { + cfgLock.Lock() + defer cfgLock.Unlock() + cfg.EndpointFallback = &enabled + return Save() +} + // GetProxyURL 获取出站代理地址 func GetProxyURL() string { cfgLock.RLock() @@ -543,6 +587,24 @@ func UpdateFirstByteRetries(n int) error { return Save() } +// GetLogLevel returns the configured log level (debug/info/warn/error). Defaults to "info". +func GetLogLevel() string { + cfgLock.RLock() + defer cfgLock.RUnlock() + if cfg == nil || cfg.LogLevel == "" { + return "info" + } + return cfg.LogLevel +} + +// UpdateLogLevel updates the log level setting and persists the change. +func UpdateLogLevel(level string) error { + cfgLock.Lock() + defer cfgLock.Unlock() + cfg.LogLevel = level + return Save() +} + type KiroClientConfig struct { KiroVersion string SystemVersion string diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 0000000..0316d10 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,141 @@ +// Package logger provides a lightweight leveled logger for Kiro-Go. +// +// Levels (from most to least verbose): +// +// DEBUG < INFO < WARN < ERROR +// +// The active level is configured via logger.Init at startup. +// Priority: LOG_LEVEL environment variable > provided fallback (usually +// taken from config.json "logLevel"). If neither is set or the value is +// unrecognized, the level defaults to INFO. +package logger + +import ( + "io" + "log" + "os" + "strings" + "sync/atomic" +) + +// Level represents a log severity. +type Level int32 + +const ( + LevelDebug Level = iota + LevelInfo + LevelWarn + LevelError +) + +var ( + currentLevel atomic.Int32 + + debugLog = log.New(os.Stdout, "DEBUG ", log.LstdFlags) + infoLog = log.New(os.Stdout, "INFO ", log.LstdFlags) + warnLog = log.New(os.Stderr, "WARN ", log.LstdFlags) + errorLog = log.New(os.Stderr, "ERROR ", log.LstdFlags) +) + +func init() { + currentLevel.Store(int32(LevelInfo)) +} + +// ParseLevel converts a textual level ("debug", "info", "warn", "error") +// to a Level. The ok flag is false when the input is empty or unknown. +func ParseLevel(s string) (Level, bool) { + switch strings.ToLower(strings.TrimSpace(s)) { + case "debug", "trace": + return LevelDebug, true + case "info": + return LevelInfo, true + case "warn", "warning": + return LevelWarn, true + case "error", "err": + return LevelError, true + } + return LevelInfo, false +} + +// LevelName returns the canonical lowercase name of a Level. +func LevelName(l Level) string { + switch l { + case LevelDebug: + return "debug" + case LevelInfo: + return "info" + case LevelWarn: + return "warn" + case LevelError: + return "error" + } + return "info" +} + +// SetLevel sets the active log level. +func SetLevel(l Level) { + currentLevel.Store(int32(l)) +} + +// GetLevel returns the active log level. +func GetLevel() Level { + return Level(currentLevel.Load()) +} + +// SetOutput redirects all level outputs to w. Useful for tests. +func SetOutput(w io.Writer) { + debugLog.SetOutput(w) + infoLog.SetOutput(w) + warnLog.SetOutput(w) + errorLog.SetOutput(w) +} + +// Init configures the logger. The LOG_LEVEL environment variable, if set, +// overrides the supplied fallback (typically config.GetLogLevel()). +func Init(fallback string) { + value := fallback + if env := os.Getenv("LOG_LEVEL"); env != "" { + value = env + } + if l, ok := ParseLevel(value); ok { + SetLevel(l) + } +} + +func enabled(l Level) bool { + return Level(currentLevel.Load()) <= l +} + +// Debugf logs a formatted message at DEBUG level. +func Debugf(format string, v ...interface{}) { + if enabled(LevelDebug) { + debugLog.Printf(format, v...) + } +} + +// Infof logs a formatted message at INFO level. +func Infof(format string, v ...interface{}) { + if enabled(LevelInfo) { + infoLog.Printf(format, v...) + } +} + +// Warnf logs a formatted message at WARN level. +func Warnf(format string, v ...interface{}) { + if enabled(LevelWarn) { + warnLog.Printf(format, v...) + } +} + +// Errorf logs a formatted message at ERROR level. +func Errorf(format string, v ...interface{}) { + if enabled(LevelError) { + errorLog.Printf(format, v...) + } +} + +// Fatalf logs a formatted message at ERROR level and terminates the process. +func Fatalf(format string, v ...interface{}) { + errorLog.Printf(format, v...) + os.Exit(1) +} diff --git a/main.go b/main.go index 99de1c3..52defff 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ package main import ( "fmt" "kiro-go/config" + "kiro-go/logger" "kiro-go/pool" "kiro-go/proxy" "log" @@ -41,6 +42,9 @@ func main() { log.Fatalf("Failed to load config: %v", err) } + // Initialize log level: LOG_LEVEL env var takes priority over config, defaulting to "info". + logger.Init(config.GetLogLevel()) + // 环境变量覆盖密码 if envPassword := os.Getenv("ADMIN_PASSWORD"); envPassword != "" { config.SetPassword(envPassword) @@ -54,12 +58,12 @@ func main() { // 启动服务器 addr := fmt.Sprintf("%s:%d", config.GetHost(), config.GetPort()) - log.Printf("Kiro-Go starting on http://%s", addr) - log.Printf("Admin panel: http://%s/admin", addr) - log.Printf("Claude API: http://%s/v1/messages", addr) - log.Printf("OpenAI API: http://%s/v1/chat/completions", addr) + logger.Infof("Kiro-Go starting on http://%s (log level: %s)", addr, logger.LevelName(logger.GetLevel())) + logger.Infof("Admin panel: http://%s/admin", addr) + logger.Infof("Claude API: http://%s/v1/messages", addr) + logger.Infof("OpenAI API: http://%s/v1/chat/completions", addr) if err := http.ListenAndServe(addr, handler); err != nil { - log.Fatalf("Server failed: %v", err) + logger.Fatalf("Server failed: %v", err) } } diff --git a/pool/account.go b/pool/account.go index 2f0ab4d..bba2fad 100644 --- a/pool/account.go +++ b/pool/account.go @@ -9,10 +9,13 @@ import ( "time" ) +const overageFrequencyScale = 10 + // AccountPool 账号池 type AccountPool struct { mu sync.RWMutex accounts []config.Account + totalAccounts int currentIndex uint64 cooldowns map[string]time.Time // 账号冷却时间 errorCounts map[string]int // 连续错误计数 @@ -43,15 +46,19 @@ func (p *AccountPool) Reload() { enabled := config.GetEnabledAccounts() var weighted []config.Account for _, a := range enabled { - w := a.Weight - if w < 1 { - w = 1 + w := effectiveWeight(a.Weight) * overageFrequencyScale + if isOverUsageLimit(a) { + if !a.AllowOverage { + continue + } + w = effectiveOverageWeight(a.OverageWeight) } for j := 0; j < w; j++ { weighted = append(weighted, a) } } p.accounts = weighted + p.totalAccounts = len(enabled) } // GetNext 获取下一个可用账号(加权轮询) @@ -89,7 +96,7 @@ func (p *AccountPool) GetNext() *config.Account { } // 跳过额度已用尽的账号(适用于所有订阅类型) - if acc.UsageLimit > 0 && acc.UsageCurrent >= acc.UsageLimit { + if isOverUsageLimit(*acc) && !acc.AllowOverage { seen[acc.ID] = true continue } @@ -103,7 +110,7 @@ func (p *AccountPool) GetNext() *config.Account { for i := range p.accounts { acc := &p.accounts[i] // 额度用尽的账号不作为 fallback - if acc.UsageLimit > 0 && acc.UsageCurrent >= acc.UsageLimit { + if isOverUsageLimit(*acc) && !acc.AllowOverage { continue } if cooldown, ok := p.cooldowns[acc.ID]; ok { @@ -165,7 +172,6 @@ func (p *AccountPool) UpdateToken(id, accessToken, refreshToken string, expiresA p.accounts[i].RefreshToken = refreshToken } p.accounts[i].ExpiresAt = expiresAt - break } } } @@ -174,7 +180,15 @@ func (p *AccountPool) UpdateToken(id, accessToken, refreshToken string, expiresA func (p *AccountPool) Count() int { p.mu.RLock() defer p.mu.RUnlock() - return len(p.accounts) + if p.totalAccounts > 0 { + return p.totalAccounts + } + + seen := make(map[string]bool) + for _, acc := range p.accounts { + seen[acc.ID] = true + } + return len(seen) } // AvailableCount 返回可用账号数 @@ -183,7 +197,12 @@ func (p *AccountPool) AvailableCount() int { defer p.mu.RUnlock() now := time.Now() count := 0 + seen := make(map[string]bool) for _, acc := range p.accounts { + if seen[acc.ID] { + continue + } + seen[acc.ID] = true if cooldown, ok := p.cooldowns[acc.ID]; ok && now.Before(cooldown) { continue } @@ -196,16 +215,36 @@ func (p *AccountPool) AvailableCount() int { func (p *AccountPool) UpdateStats(id string, tokens int, credits float64) { p.mu.Lock() defer p.mu.Unlock() + var updated bool + var requestCount, errorCount, totalTokens int + var totalCredits float64 + var lastUsed int64 for i := range p.accounts { if p.accounts[i].ID == id { - p.accounts[i].RequestCount++ - p.accounts[i].TotalTokens += tokens - p.accounts[i].TotalCredits += credits - p.accounts[i].LastUsed = time.Now().Unix() - go config.UpdateAccountStats(id, p.accounts[i].RequestCount, p.accounts[i].ErrorCount, p.accounts[i].TotalTokens, p.accounts[i].TotalCredits, p.accounts[i].LastUsed) - break + if !updated { + p.accounts[i].RequestCount++ + p.accounts[i].TotalTokens += tokens + p.accounts[i].TotalCredits += credits + p.accounts[i].LastUsed = time.Now().Unix() + + requestCount = p.accounts[i].RequestCount + errorCount = p.accounts[i].ErrorCount + totalTokens = p.accounts[i].TotalTokens + totalCredits = p.accounts[i].TotalCredits + lastUsed = p.accounts[i].LastUsed + updated = true + continue + } + p.accounts[i].RequestCount = requestCount + p.accounts[i].ErrorCount = errorCount + p.accounts[i].TotalTokens = totalTokens + p.accounts[i].TotalCredits = totalCredits + p.accounts[i].LastUsed = lastUsed } } + if updated { + go config.UpdateAccountStats(id, requestCount, errorCount, totalTokens, totalCredits, lastUsed) + } } // GetAllAccounts 获取所有账号副本 @@ -216,3 +255,24 @@ func (p *AccountPool) GetAllAccounts() []config.Account { copy(result, p.accounts) return result } + +func isOverUsageLimit(acc config.Account) bool { + return acc.UsageLimit > 0 && acc.UsageCurrent >= acc.UsageLimit +} + +func effectiveWeight(weight int) int { + if weight < 1 { + return 1 + } + return weight +} + +func effectiveOverageWeight(weight int) int { + if weight < 1 { + return 1 + } + if weight > overageFrequencyScale { + return overageFrequencyScale + } + return weight +} diff --git a/pool/account_test.go b/pool/account_test.go new file mode 100644 index 0000000..82f5513 --- /dev/null +++ b/pool/account_test.go @@ -0,0 +1,54 @@ +package pool + +import ( + "kiro-api-proxy/config" + "testing" +) + +func TestOverageAccountsAreSkippedByDefault(t *testing.T) { + p := &AccountPool{} + normal := config.Account{ID: "normal"} + overLimit := config.Account{ID: "over", UsageCurrent: 10, UsageLimit: 10} + + p.accounts = []config.Account{normal, overLimit} + + for i := 0; i < 5; i++ { + acc := p.GetNext() + if acc == nil { + t.Fatalf("expected an account") + } + if acc.ID == "over" { + t.Fatalf("expected over-limit account to be skipped by default") + } + } +} + +func TestOverageAccountsCanBeSelectedWhenAllowed(t *testing.T) { + p := &AccountPool{} + overLimit := config.Account{ + ID: "over", + UsageCurrent: 10, + UsageLimit: 10, + AllowOverage: true, + OverageWeight: 1, + } + + p.accounts = []config.Account{overLimit} + + acc := p.GetNext() + if acc == nil { + t.Fatalf("expected allowed overage account") + } + if acc.ID != "over" { + t.Fatalf("expected overage account, got %q", acc.ID) + } +} + +func TestOverageWeightIsLowerThanNormalWeight(t *testing.T) { + normalWeight := effectiveWeight(1) * overageFrequencyScale + overageWeight := effectiveOverageWeight(1) + + if overageWeight >= normalWeight { + t.Fatalf("expected overage weight %d to be lower than normal weight %d", overageWeight, normalWeight) + } +} diff --git a/proxy/cache_tracker.go b/proxy/cache_tracker.go index 2c07d5a..682b8c4 100644 --- a/proxy/cache_tracker.go +++ b/proxy/cache_tracker.go @@ -254,9 +254,10 @@ func flattenClaudeCacheBlocks(req *ClaudeRequest) []cacheablePromptBlock { "description": tool.Description, "input_schema": tool.InputSchema, } + fingerprintValue := stripCachePositionKeys(toolValue) blocks = append(blocks, cacheablePromptBlock{ - Value: toolValue, - Tokens: estimateApproxTokens(canonicalizeCacheValue(toolValue)), + Value: fingerprintValue, + Tokens: estimateApproxTokens(canonicalizeCacheValue(fingerprintValue)), TTL: normalizePromptCacheTTL(extractPromptCacheTTL(tool)), }) } @@ -357,59 +358,52 @@ func appendPromptBlock(blocks *[]cacheablePromptBlock, wrapper map[string]interf blockValue := wrapper["block"] ttl := normalizePromptCacheTTL(extractPromptCacheTTL(blockValue)) - // Normalize volatile text (e.g. Claude Code's x-anthropic-billing-header - // which drifts on every request) so that fingerprints remain stable across - // requests within the same conversation. - if normalized, changed := normalizeCacheBlockContent(blockValue); changed { - cloned := make(map[string]interface{}, len(wrapper)) - for k, v := range wrapper { - cloned[k] = v - } - cloned["block"] = normalized - wrapper = cloned + // Drop volatile billing metadata from the cache fingerprint. Claude Code's + // x-anthropic-billing-header can drift, appear, or disappear across + // otherwise identical requests, and it does not change model semantics. + if isAnthropicBillingHeaderBlock(blockValue) { + return } - canonical := canonicalizeCacheValue(wrapper) + fingerprintValue := stripCachePositionKeys(wrapper) + canonical := canonicalizeCacheValue(fingerprintValue) *blocks = append(*blocks, cacheablePromptBlock{ - Value: wrapper, + Value: fingerprintValue, Tokens: estimateApproxTokens(canonical), TTL: ttl, IsMessageEnd: isMessageEnd, }) } -// normalizeCacheBlockContent replaces volatile but semantically irrelevant -// fields with a placeholder so that the cumulative fingerprint stays stable -// across requests in the same session. Currently handles: -// - Claude Code's "x-anthropic-billing-header: ..." system text block -// whose content drifts on every request (version, telemetry hash, etc.) -func normalizeCacheBlockContent(value interface{}) (interface{}, bool) { +func stripCachePositionKeys(value map[string]interface{}) map[string]interface{} { + cloned := make(map[string]interface{}, len(value)) + for key, item := range value { + if isCachePositionKey(key) { + continue + } + cloned[key] = item + } + return cloned +} + +func isAnthropicBillingHeaderBlock(value interface{}) bool { blockMap, ok := value.(map[string]interface{}) if !ok { - return value, false + return false } // Only normalize text blocks (or blocks without an explicit type but containing text). if t, ok := blockMap["type"].(string); ok && t != "" && t != "text" { - return value, false + return false } text, ok := blockMap["text"].(string) if !ok { - return value, false + return false } trimmed := strings.TrimLeft(text, " \t\r\n") - if !strings.HasPrefix(strings.ToLower(trimmed), "x-anthropic-billing-header:") { - return value, false - } - - cloned := make(map[string]interface{}, len(blockMap)) - for k, v := range blockMap { - cloned[k] = v - } - cloned["text"] = "__anthropic_billing_header__" - return cloned, true + return strings.HasPrefix(strings.ToLower(trimmed), "x-anthropic-billing-header:") } func extractPromptCacheTTL(value interface{}) time.Duration { @@ -586,6 +580,15 @@ func writeCanonicalJSON(buf *bytes.Buffer, value interface{}) { } } +func isCachePositionKey(key string) bool { + switch key { + case "tool_index", "system_index", "message_index", "block_index": + return true + default: + return false + } +} + func writeHashChunk(hasher hashWriter, chunk string) { length := strconv.Itoa(len(chunk)) hasher.Write([]byte(length)) diff --git a/proxy/cache_tracker_test.go b/proxy/cache_tracker_test.go index f0130f3..2e3a1d8 100644 --- a/proxy/cache_tracker_test.go +++ b/proxy/cache_tracker_test.go @@ -77,7 +77,7 @@ func TestBuildClaudeUsageMapIncludesCacheFields(t *testing.T) { // TestPromptCacheStableAcrossBillingHeaderDrift verifies that Claude Code's // per-request "x-anthropic-billing-header: cc_version=...; cch=...;" system // block (whose content drifts on every request) does not break cache hits. -// The normalization logic should ensure the same conversation still matches. +// The tracker should ignore that metadata when fingerprinting cached prefixes. func TestPromptCacheStableAcrossBillingHeaderDrift(t *testing.T) { tracker := newPromptCacheTracker(time.Hour) mainSystem := strings.Repeat("You are a helpful coding assistant with deep knowledge of Go, Rust, Python, and TypeScript. ", 80) @@ -124,6 +124,92 @@ func TestPromptCacheStableAcrossBillingHeaderDrift(t *testing.T) { } } +func TestPromptCacheStableWhenBillingHeaderAppearsOrDisappears(t *testing.T) { + tracker := newPromptCacheTracker(time.Hour) + mainSystem := strings.Repeat("You are a helpful coding assistant with deep knowledge of Go, Rust, Python, and TypeScript. ", 80) + + build := func(includeBilling bool) *ClaudeRequest { + system := []interface{}{} + if includeBilling { + system = append(system, map[string]interface{}{ + "type": "text", + "text": "x-anthropic-billing-header: cc_version=2.1.87.1; cch=aaaa;", + }) + } + system = append(system, map[string]interface{}{ + "type": "text", + "text": mainSystem, + "cache_control": map[string]interface{}{ + "type": "ephemeral", + }, + }) + return &ClaudeRequest{ + Model: "claude-sonnet-4.5", + System: system, + Messages: []ClaudeMessage{{Role: "user", Content: "hello world"}}, + } + } + + withBilling := tracker.BuildClaudeProfile(build(true), 2048) + if withBilling == nil { + t.Fatalf("profile with billing header should be built") + } + tracker.Update("acct-1", withBilling) + + withoutBilling := tracker.BuildClaudeProfile(build(false), 2048) + if withoutBilling == nil { + t.Fatalf("profile without billing header should be built") + } + result := tracker.Compute("acct-1", withoutBilling) + if result.CacheReadInputTokens == 0 { + t.Fatalf("expected cache read when billing header disappears, got %+v", result) + } +} + +func TestCanonicalCacheValueIgnoresPositionKeys(t *testing.T) { + first := canonicalizeCacheValue(stripCachePositionKeys(map[string]interface{}{ + "kind": "system", + "system_index": 0, + "block": map[string]interface{}{ + "type": "text", + "text": "stable", + }, + })) + second := canonicalizeCacheValue(stripCachePositionKeys(map[string]interface{}{ + "kind": "system", + "system_index": 1, + "block": map[string]interface{}{ + "type": "text", + "text": "stable", + }, + })) + if first != second { + t.Fatalf("expected position keys to be ignored, got %q vs %q", first, second) + } +} + +func TestCanonicalCacheValuePreservesSemanticPositionKeys(t *testing.T) { + first := canonicalizeCacheValue(map[string]interface{}{ + "kind": "system", + "block": map[string]interface{}{ + "type": "text", + "text": "stable", + "block_index": 1, + }, + }) + second := canonicalizeCacheValue(map[string]interface{}{ + "kind": "system", + "block": map[string]interface{}{ + "type": "text", + "text": "stable", + "block_index": 2, + }, + }) + if first == second { + t.Fatalf("expected semantic block_index fields to remain fingerprinted") + } +} + // TestPromptCacheImplicitBreakpointAtMessageEnd verifies that once any // explicit cache_control breakpoint has been seen, subsequent message-end // boundaries act as implicit breakpoints. This allows multi-turn conversations diff --git a/proxy/handler.go b/proxy/handler.go index a45158b..f89076c 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -6,6 +6,7 @@ import ( "io" "kiro-go/auth" "kiro-go/config" + "kiro-go/logger" "kiro-go/pool" "net/http" "strings" @@ -261,9 +262,9 @@ func (h *Handler) refreshAllAccounts() { // 检查 token 是否需要刷新 if account.ExpiresAt > 0 && time.Now().Unix() > account.ExpiresAt-300 { - newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(account) + newAccessToken, newRefreshToken, newExpiresAt, profileArn, err := auth.RefreshToken(account) if err != nil { - fmt.Printf("[BackgroundRefresh] Token refresh failed for %s: %v\n", account.Email, err) + logger.Warnf("[BackgroundRefresh] Token refresh failed for %s: %v", account.Email, err) continue } account.AccessToken = newAccessToken @@ -273,17 +274,21 @@ func (h *Handler) refreshAllAccounts() { account.ExpiresAt = newExpiresAt config.UpdateAccountToken(account.ID, newAccessToken, newRefreshToken, newExpiresAt) h.pool.UpdateToken(account.ID, newAccessToken, newRefreshToken, newExpiresAt) + if profileArn != "" { + account.ProfileArn = profileArn + config.UpdateAccountProfileArn(account.ID, profileArn) + } } // 刷新账户信息 info, err := RefreshAccountInfo(account) if err != nil { - fmt.Printf("[BackgroundRefresh] Failed to refresh %s: %v\n", account.Email, err) + logger.Warnf("[BackgroundRefresh] Failed to refresh %s: %v", account.Email, err) continue } config.UpdateAccountInfo(account.ID, *info) - fmt.Printf("[BackgroundRefresh] Refreshed %s: %s %.1f/%.1f\n", account.Email, info.SubscriptionType, info.UsageCurrent, info.UsageLimit) + logger.Infof("[BackgroundRefresh] Refreshed %s: %s %.1f/%.1f", account.Email, info.SubscriptionType, info.UsageCurrent, info.UsageLimit) } h.pool.Reload() } @@ -317,6 +322,9 @@ func (h *Handler) validateApiKey(r *http.Request) bool { func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { path := r.URL.Path + // Debug-level request trace for fine-grained visibility + logger.Debugf("[HTTP] %s %s from %s", r.Method, path, r.RemoteAddr) + // CORS - 完整的头部支持 w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") @@ -535,13 +543,13 @@ func (h *Handler) refreshModelsCache() { for i := range accounts { account := &accounts[i] if err := h.ensureValidToken(account); err != nil { - fmt.Printf("[ModelsCache] Skip %s token refresh failed: %v\n", account.Email, err) + logger.Warnf("[ModelsCache] Skip %s token refresh failed: %v", account.Email, err) continue } models, err := ListAvailableModels(account) if err != nil { - fmt.Printf("[ModelsCache] Failed to refresh for %s: %v\n", account.Email, err) + logger.Warnf("[ModelsCache] Failed to refresh for %s: %v", account.Email, err) continue } aggregated = mergeUniqueModels(aggregated, models) @@ -552,7 +560,7 @@ func (h *Handler) refreshModelsCache() { h.cachedModels = aggregated h.modelsCacheTime = time.Now().Unix() h.modelsCacheMu.Unlock() - fmt.Printf("[ModelsCache] Cached %d models\n", len(aggregated)) + logger.Infof("[ModelsCache] Cached %d models", len(aggregated)) } } @@ -1849,7 +1857,7 @@ func (h *Handler) ensureValidToken(account *config.Account) error { return nil } - accessToken, refreshToken, expiresAt, err := auth.RefreshToken(account) + accessToken, refreshToken, expiresAt, profileArn, err := auth.RefreshToken(account) if err != nil { return err } @@ -1861,6 +1869,10 @@ func (h *Handler) ensureValidToken(account *config.Account) error { account.RefreshToken = refreshToken } account.ExpiresAt = expiresAt + if profileArn != "" { + account.ProfileArn = profileArn + config.UpdateAccountProfileArn(account.ID, profileArn) + } // 持久化 config.UpdateAccountToken(account.ID, accessToken, refreshToken, expiresAt) @@ -1991,6 +2003,8 @@ func (h *Handler) apiGetAccounts(w http.ResponseWriter, r *http.Request) { "hasToken": a.AccessToken != "", "machineId": a.MachineId, "weight": a.Weight, + "allowOverage": a.AllowOverage, + "overageWeight": a.OverageWeight, "subscriptionType": a.SubscriptionType, "subscriptionTitle": a.SubscriptionTitle, "daysRemaining": a.DaysRemaining, @@ -2085,6 +2099,12 @@ func (h *Handler) apiUpdateAccount(w http.ResponseWriter, r *http.Request, id st if v, ok := updates["weight"].(float64); ok { existing.Weight = int(v) } + if v, ok := updates["allowOverage"].(bool); ok { + existing.AllowOverage = v + } + if v, ok := updates["overageWeight"].(float64); ok { + existing.OverageWeight = clampInt(int(v), 1, 10) + } if err := config.UpdateAccount(id, *existing); err != nil { w.WriteHeader(500) @@ -2153,13 +2173,17 @@ func (h *Handler) apiBatchAccounts(w http.ResponseWriter, r *http.Request) { } // 刷新 token if account.RefreshToken != "" { - if newAccess, newRefresh, newExpires, err := auth.RefreshToken(account); err == nil { + if newAccess, newRefresh, newExpires, profileArn, err := auth.RefreshToken(account); err == nil { account.AccessToken = newAccess if newRefresh != "" { account.RefreshToken = newRefresh } account.ExpiresAt = newExpires config.UpdateAccountToken(id, newAccess, newRefresh, newExpires) + if profileArn != "" { + account.ProfileArn = profileArn + config.UpdateAccountProfileArn(id, profileArn) + } h.pool.UpdateToken(id, newAccess, newRefresh, newExpires) } } @@ -2498,7 +2522,7 @@ func (h *Handler) apiImportCredentials(w http.ResponseWriter, r *http.Request) { AuthMethod: req.AuthMethod, Region: req.Region, } - newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(tempAccount) + newAccessToken, newRefreshToken, newExpiresAt, newProfileArn, err := auth.RefreshToken(tempAccount) if err != nil { // 刷新失败,如果有传入的 accessToken 则尝试使用 if req.AccessToken != "" { @@ -2534,6 +2558,7 @@ func (h *Handler) apiImportCredentials(w http.ResponseWriter, r *http.Request) { ExpiresAt: expiresAt, Enabled: true, MachineId: config.GenerateMachineId(), + ProfileArn: newProfileArn, } if err := config.AddAccount(account); err != nil { @@ -2646,7 +2671,7 @@ func (h *Handler) apiRefreshAccount(w http.ResponseWriter, r *http.Request, id s if account.RefreshToken == "" { return nil } - newAccessToken, newRefreshToken, newExpiresAt, err := auth.RefreshToken(account) + newAccessToken, newRefreshToken, newExpiresAt, profileArn, err := auth.RefreshToken(account) if err != nil { return err } @@ -2657,6 +2682,10 @@ func (h *Handler) apiRefreshAccount(w http.ResponseWriter, r *http.Request, id s account.ExpiresAt = newExpiresAt config.UpdateAccountToken(id, newAccessToken, newRefreshToken, newExpiresAt) h.pool.UpdateToken(id, newAccessToken, newRefreshToken, newExpiresAt) + if profileArn != "" { + account.ProfileArn = profileArn + config.UpdateAccountProfileArn(id, profileArn) + } return nil } @@ -2766,6 +2795,9 @@ func (h *Handler) apiGetAccountFull(w http.ResponseWriter, r *http.Request, id s "region": account.Region, "expiresAt": account.ExpiresAt, "machineId": account.MachineId, + "weight": account.Weight, + "allowOverage": account.AllowOverage, + "overageWeight": account.OverageWeight, "enabled": account.Enabled, "banStatus": account.BanStatus, "banReason": account.BanReason, @@ -2881,8 +2913,9 @@ func (h *Handler) apiUpdateThinkingConfig(w http.ResponseWriter, r *http.Request // apiGetEndpointConfig 获取端点配置 func (h *Handler) apiGetEndpointConfig(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(map[string]string{ + json.NewEncoder(w).Encode(map[string]interface{}{ "preferredEndpoint": config.GetPreferredEndpoint(), + "endpointFallback": config.GetEndpointFallback(), }) } @@ -2890,6 +2923,7 @@ func (h *Handler) apiGetEndpointConfig(w http.ResponseWriter, r *http.Request) { func (h *Handler) apiUpdateEndpointConfig(w http.ResponseWriter, r *http.Request) { var req struct { PreferredEndpoint string `json:"preferredEndpoint"` + EndpointFallback *bool `json:"endpointFallback"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { w.WriteHeader(400) @@ -2897,10 +2931,10 @@ func (h *Handler) apiUpdateEndpointConfig(w http.ResponseWriter, r *http.Request return } - valid := map[string]bool{"auto": true, "codewhisperer": true, "amazonq": true} + valid := map[string]bool{"auto": true, "kiro": true, "codewhisperer": true, "amazonq": true} if !valid[req.PreferredEndpoint] { w.WriteHeader(400) - json.NewEncoder(w).Encode(map[string]string{"error": "Invalid endpoint, must be: auto, codewhisperer, or amazonq"}) + json.NewEncoder(w).Encode(map[string]string{"error": "Invalid endpoint, must be: auto, kiro, codewhisperer, or amazonq"}) return } @@ -2910,6 +2944,10 @@ func (h *Handler) apiUpdateEndpointConfig(w http.ResponseWriter, r *http.Request return } + if req.EndpointFallback != nil { + config.UpdateEndpointFallback(*req.EndpointFallback) + } + json.NewEncoder(w).Encode(map[string]bool{"success": true}) } @@ -3185,3 +3223,13 @@ func (h *Handler) apiExportAccounts(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(data) } + +func clampInt(v, min, max int) int { + if v < min { + return min + } + if v > max { + return max + } + return v +} diff --git a/proxy/kiro.go b/proxy/kiro.go index 5f4aa10..c957db5 100644 --- a/proxy/kiro.go +++ b/proxy/kiro.go @@ -1,5 +1,5 @@ -// Package proxy Kiro API 代理核心 -// 负责调用 Kiro API 并解析 AWS Event Stream 响应 +// Package proxy is the core proxy layer for the Kiro API. +// It handles streaming API calls to the Kiro backend and parses AWS Event Stream responses. package proxy import ( @@ -10,7 +10,7 @@ import ( "fmt" "io" "kiro-go/config" - "log" + "kiro-go/logger" "net/http" "net/url" "strconv" @@ -21,7 +21,7 @@ import ( "github.com/google/uuid" ) -// 双端点配置(429 时自动 fallback) +// Endpoint configuration (auto-fallback on quota exhaustion). type kiroEndpoint struct { URL string Origin string @@ -30,6 +30,12 @@ type kiroEndpoint struct { } var kiroEndpoints = []kiroEndpoint{ + { + URL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse", + Origin: "AI_EDITOR", + AmzTarget: "", + Name: "Kiro IDE", + }, { URL: "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse", Origin: "AI_EDITOR", @@ -44,14 +50,15 @@ var kiroEndpoints = []kiroEndpoint{ }, } -// 全局 HTTP 客户端,支持运行时更换(代理重配置) +// Global HTTP clients, swappable at runtime to apply proxy reconfiguration without restart. var kiroHttpStore atomic.Pointer[http.Client] +var kiroRestHttpStore atomic.Pointer[http.Client] func init() { InitKiroHttpClient("") } -// buildKiroTransport 构建带可选代理的 Transport +// buildKiroTransport constructs an HTTP Transport with optional outbound proxy support. func buildKiroTransport(proxyURL string) *http.Transport { t := &http.Transport{ MaxIdleConns: 100, @@ -63,36 +70,52 @@ func buildKiroTransport(proxyURL string) *http.Transport { if proxyURL != "" { if u, err := url.Parse(proxyURL); err == nil { t.Proxy = http.ProxyURL(u) - // 代理不支持 HTTP/2 协议升级 + // Proxied connections cannot negotiate HTTP/2. t.ForceAttemptHTTP2 = false } + } else { + t.Proxy = http.ProxyFromEnvironment } return t } -// InitKiroHttpClient 初始化(或重新初始化)Kiro API 的 HTTP 客户端 +// InitKiroHttpClient initializes (or reinitializes) the HTTP clients used for Kiro API requests. func InitKiroHttpClient(proxyURL string) { client := &http.Client{ Timeout: 5 * time.Minute, Transport: buildKiroTransport(proxyURL), } kiroHttpStore.Store(client) + + restClient := &http.Client{ + Timeout: 30 * time.Second, + Transport: buildKiroTransport(proxyURL), + } + kiroRestHttpStore.Store(restClient) } -// ==================== 请求结构 ==================== +// ==================== Request Structs ==================== -// KiroPayload Kiro API 请求体 +// KiroPayload is the top-level request body sent to the Kiro API. type KiroPayload struct { ConversationState struct { - ChatTriggerType string `json:"chatTriggerType"` - ConversationID string `json:"conversationId"` - CurrentMessage struct { + AgentContinuationId string `json:"agentContinuationId,omitempty"` + AgentTaskType string `json:"agentTaskType,omitempty"` + ChatTriggerType string `json:"chatTriggerType"` + ConversationID string `json:"conversationId"` + CurrentMessage struct { UserInputMessage KiroUserInputMessage `json:"userInputMessage"` } `json:"currentMessage"` History []KiroHistoryMessage `json:"history,omitempty"` } `json:"conversationState"` ProfileArn string `json:"profileArn,omitempty"` InferenceConfig *InferenceConfig `json:"inferenceConfig,omitempty"` + + // ToolNameMap maps sanitized tool names (sent to Kiro) back to the + // original names supplied by the client. Used to restore original names + // in tool_use responses so the client can match them to its tool registry. + // Not serialized to the Kiro API request body. + ToolNameMap map[string]string `json:"-"` } type KiroUserInputMessage struct { @@ -171,27 +194,78 @@ type KiroStreamCallback struct { OnContextUsage func(percentage float64) } -// ==================== API 调用 ==================== +// ==================== API Call ==================== -// getSortedEndpoints 根据首选端点配置排序端点列表 +// getSortedEndpoints returns endpoints ordered by user preference, with optional fallback. func getSortedEndpoints(preferred string) []kiroEndpoint { - if preferred == "amazonq" { - return []kiroEndpoint{kiroEndpoints[1], kiroEndpoints[0]} + fallback := config.GetEndpointFallback() + + var primary int + switch preferred { + case "kiro": + primary = 0 + case "codewhisperer": + primary = 1 + case "amazonq": + primary = 2 + default: + // "auto": Kiro first, then fallback to others + return []kiroEndpoint{kiroEndpoints[0], kiroEndpoints[1], kiroEndpoints[2]} } - if preferred == "codewhisperer" { - return []kiroEndpoint{kiroEndpoints[0], kiroEndpoints[1]} + + if !fallback { + // No fallback: only use the selected endpoint + return []kiroEndpoint{kiroEndpoints[primary]} } - // "auto" 或空值:默认顺序 - return []kiroEndpoint{kiroEndpoints[0], kiroEndpoints[1]} + + // With fallback: selected first, then others in order + result := []kiroEndpoint{kiroEndpoints[primary]} + for i, ep := range kiroEndpoints { + if i != primary { + result = append(result, ep) + } + } + return result } -// CallKiroAPI 调用 Kiro API(流式),双端点自动 fallback +// CallKiroAPI calls the Kiro streaming API, trying each configured endpoint with automatic fallback. func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroStreamCallback) error { if _, err := json.Marshal(payload); err != nil { return err } - // 根据配置排序端点 + // Debug: dump full payload for troubleshooting upstream rejections + if payloadJSON, err := json.Marshal(payload); err == nil { + logger.Debugf("[KiroAPI] Request payload: %s", string(payloadJSON)) + } + + // Wrap OnToolUse to restore original tool names for the client. + if callback != nil && callback.OnToolUse != nil && len(payload.ToolNameMap) > 0 { + originalOnToolUse := callback.OnToolUse + nameMap := payload.ToolNameMap + wrapped := *callback + wrapped.OnToolUse = func(tu KiroToolUse) { + if original, ok := nameMap[tu.Name]; ok { + tu.Name = original + } + originalOnToolUse(tu) + } + callback = &wrapped + } + + if payload != nil && strings.TrimSpace(payload.ProfileArn) == "" { + if profileArn, err := ResolveProfileArn(account); err == nil { + payload.ProfileArn = profileArn + } else { + accountEmail := "" + if account != nil { + accountEmail = account.Email + } + logger.Warnf("[ProfileArn] Failed to resolve profile ARN for %s: %v", accountEmail, err) + } + } + + // Build endpoint list ordered by configuration. endpoints := getSortedEndpoints(config.GetPreferredEndpoint()) invalidModelRetries := config.GetInvalidModelRetries() firstByteTimeoutSec := config.GetFirstByteTimeoutSec() @@ -208,7 +282,7 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt for _, ep := range endpoints { epNames = append(epNames, shortEndpoint(ep.Name)) } - log.Printf("[KiroAPI] REQ %s model=%s account=%s endpoints=%s", reqID, shortModel(modelID), accountLabel, strings.Join(epNames, ",")) + logger.Infof("[KiroAPI] REQ %s model=%s account=%s endpoints=%s", reqID, shortModel(modelID), accountLabel, strings.Join(epNames, ",")) requestStart := time.Now() @@ -235,7 +309,7 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt cancel() lastErr = err lastStatus = "ERR" - log.Printf("[KiroAPI] ERR %s %s/a%d new_request %v", reqID, epShort, attempt, err) + logger.Warnf("[KiroAPI] ERR %s %s/a%d new_request %v", reqID, epShort, attempt, err) shouldFallback = true break } @@ -262,7 +336,7 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt cancel() lastErr = err lastStatus = "ERR" - log.Printf("[KiroAPI] ERR %s %s/a%d transport %s %v", reqID, epShort, attempt, fmtMs(time.Since(attemptStart)), err) + logger.Warnf("[KiroAPI] ERR %s %s/a%d transport %s %v", reqID, epShort, attempt, fmtMs(time.Since(attemptStart)), err) shouldFallback = true break } @@ -272,7 +346,7 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt cancel() lastErr = fmt.Errorf("quota exhausted on %s", ep.Name) lastStatus = "429" - log.Printf("[KiroAPI] 429 %s %s/a%d quota_exhausted %s", reqID, epShort, attempt, fmtMs(time.Since(attemptStart))) + logger.Infof("[KiroAPI] 429 %s %s/a%d quota_exhausted %s", reqID, epShort, attempt, fmtMs(time.Since(attemptStart))) shouldFallback = true break } @@ -291,23 +365,23 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt } if resp.StatusCode == 401 || resp.StatusCode == 403 { - log.Printf("[KiroAPI] %d %s %s/a%d auth_error %s %s", resp.StatusCode, reqID, epShort, attempt, fmtMs(time.Since(attemptStart)), truncateForLog(bodyStr, 200)) - log.Printf("[KiroAPI] FAIL %s all endpoints failed %s last=%s", reqID, fmtMs(time.Since(requestStart)), lastStatus) + logger.Warnf("[KiroAPI] %d %s %s/a%d auth_error %s %s", resp.StatusCode, reqID, epShort, attempt, fmtMs(time.Since(attemptStart)), truncateForLog(bodyStr, 200)) + logger.Warnf("[KiroAPI] FAIL %s all endpoints failed %s last=%s", reqID, fmtMs(time.Since(requestStart)), lastStatus) return lastErr } if resp.StatusCode == 400 && strings.Contains(bodyStr, "INVALID_MODEL_ID") { if invalidModelUsed < invalidModelRetries { invalidModelUsed++ - log.Printf("[KiroAPI] 400 %s %s/a%d INVALID_MODEL_ID %s retry %d/%d", reqID, epShort, attempt, fmtMs(time.Since(attemptStart)), invalidModelUsed, invalidModelRetries) + logger.Infof("[KiroAPI] 400 %s %s/a%d INVALID_MODEL_ID %s retry %d/%d", reqID, epShort, attempt, fmtMs(time.Since(attemptStart)), invalidModelUsed, invalidModelRetries) continue } - log.Printf("[KiroAPI] 400 %s %s/a%d INVALID_MODEL_ID %s exhausted → fallback", reqID, epShort, attempt, fmtMs(time.Since(attemptStart))) + logger.Warnf("[KiroAPI] 400 %s %s/a%d INVALID_MODEL_ID %s exhausted → fallback", reqID, epShort, attempt, fmtMs(time.Since(attemptStart))) shouldFallback = true break } - log.Printf("[KiroAPI] %d %s %s/a%d %s %s", resp.StatusCode, reqID, epShort, attempt, fmtMs(time.Since(attemptStart)), truncateForLog(bodyStr, 200)) + logger.Warnf("[KiroAPI] %d %s %s/a%d %s %s", resp.StatusCode, reqID, epShort, attempt, fmtMs(time.Since(attemptStart)), truncateForLog(bodyStr, 200)) shouldFallback = true break } @@ -346,11 +420,11 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt if firstByteUsed < firstByteRetries { firstByteUsed++ lastErr = fmt.Errorf("first-byte timeout after %ds", firstByteTimeoutSec) - log.Printf("[KiroAPI] TIMEOUT %s %s/a%d first_byte>%ds retry %d/%d", reqID, epShort, attempt, firstByteTimeoutSec, firstByteUsed, firstByteRetries) + logger.Infof("[KiroAPI] TIMEOUT %s %s/a%d first_byte>%ds retry %d/%d", reqID, epShort, attempt, firstByteTimeoutSec, firstByteUsed, firstByteRetries) continue } lastErr = fmt.Errorf("first-byte timeout after %ds on %s", firstByteTimeoutSec, ep.Name) - log.Printf("[KiroAPI] TIMEOUT %s %s/a%d first_byte>%ds exhausted → fallback", reqID, epShort, attempt, firstByteTimeoutSec) + logger.Warnf("[KiroAPI] TIMEOUT %s %s/a%d first_byte>%ds exhausted → fallback", reqID, epShort, attempt, firstByteTimeoutSec) shouldFallback = true break } @@ -359,7 +433,7 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt if err != nil { status = "ERR" } - log.Printf("[KiroAPI] %s %s %s/a%d first_byte=%s total=%s", status, reqID, epShort, attempt, fmtMs(firstByteAt), fmtMs(time.Since(requestStart))) + logger.Infof("[KiroAPI] %s %s %s/a%d first_byte=%s total=%s", status, reqID, epShort, attempt, fmtMs(firstByteAt), fmtMs(time.Since(requestStart))) return err } @@ -368,14 +442,14 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt } } - log.Printf("[KiroAPI] FAIL %s all endpoints failed %s last=%s", reqID, fmtMs(time.Since(requestStart)), lastStatus) + logger.Warnf("[KiroAPI] FAIL %s all endpoints failed %s last=%s", reqID, fmtMs(time.Since(requestStart)), lastStatus) if lastErr != nil { return lastErr } return fmt.Errorf("all endpoints failed") } -// shortReqID 生成 6 字符请求标识(base36) +// shortReqID generates a 6-character request identifier (hex). func shortReqID() string { var buf [3]byte if _, err := cryptoRand.Read(buf[:]); err != nil { @@ -384,7 +458,7 @@ func shortReqID() string { return fmt.Sprintf("%02x%02x%02x", buf[0], buf[1], buf[2]) } -// shortEndpoint 把端点名缩短到 2 字符便于视觉对齐 +// shortEndpoint abbreviates endpoint names to 2 chars for log alignment. func shortEndpoint(name string) string { switch name { case "CodeWhisperer": @@ -399,7 +473,7 @@ func shortEndpoint(name string) string { } } -// shortModel 把长模型名截短:claude-opus-4.7 → opus-4.7 +// shortModel trims the "claude-" prefix: claude-opus-4.7 → opus-4.7 func shortModel(m string) string { if strings.HasPrefix(m, "claude-") { return m[len("claude-"):] @@ -410,7 +484,7 @@ func shortModel(m string) string { return m } -// fmtMs 把耗时格式化成紧凑字符串:<1s 用 ms,>=1s 用 1 位小数 s +// fmtMs formats a duration compactly: <1s uses ms, >=1s uses 1 decimal place. func fmtMs(d time.Duration) string { if d <= 0 { return "0ms" @@ -429,13 +503,10 @@ func truncateForLog(s string, max int) string { return s[:max] + "...(truncated)" } -// ==================== Event Stream 解析 ==================== - -// parseEventStream 解析 AWS Event Stream 二进制格式 -// onFirstByte 会在读完第一个完整 event-stream 包 prelude 时触发一次(只一次), -// 供外层判断「首字节是否已收到」,以决定首字节超时时是否应该重试。 +// parseEventStream decodes an AWS binary Event Stream response body. +// onFirstByte fires once when the first complete prelude is read, used for first-byte timeout detection. func parseEventStream(body io.Reader, callback *KiroStreamCallback, onFirstByte func()) error { - // 不使用 bufio,直接读取避免缓冲延迟 + // Read directly without bufio to avoid buffering latency in streaming responses. var inputTokens, outputTokens int var totalCredits float64 var currentToolUse *toolUseState @@ -468,7 +539,7 @@ func parseEventStream(body io.Reader, callback *KiroStreamCallback, onFirstByte continue } - // 读取剩余部分 + // Read the remaining message bytes. remaining := totalLength - 12 msgBuf := make([]byte, remaining) _, err = io.ReadFull(body, msgBuf) @@ -493,7 +564,7 @@ func parseEventStream(body io.Reader, callback *KiroStreamCallback, onFirstByte inputTokens, outputTokens = updateTokensFromEvent(event, inputTokens, outputTokens) - // 处理事件 + // Dispatch by event type. switch eventType { case "assistantResponseEvent": if content, ok := event["content"].(string); ok && content != "" { @@ -689,7 +760,7 @@ func readTokenNumber(m map[string]interface{}, keys ...string) (int, bool) { return 0, false } -// ==================== Tool Use 处理 ==================== +// ==================== Tool Use Handling ==================== type toolUseState struct { ToolUseID string @@ -744,7 +815,7 @@ func finishToolUse(state *toolUseState, callback *KiroStreamCallback) { }) } -// extractEventType 从 headers 中提取事件类型 +// extractEventType extracts the event type string from AWS Event Stream message headers. func extractEventType(headers []byte) string { offset := 0 for offset < len(headers) { @@ -781,7 +852,7 @@ func extractEventType(headers []byte) string { continue } - // 跳过其他类型 + // Skip other value types by their fixed byte widths. skipSizes := map[byte]int{0: 0, 1: 0, 2: 1, 3: 2, 4: 4, 5: 8, 8: 8, 9: 16} if valueType == 6 { if offset+2 > len(headers) { diff --git a/proxy/kiro_api.go b/proxy/kiro_api.go index 948336e..13367ab 100644 --- a/proxy/kiro_api.go +++ b/proxy/kiro_api.go @@ -4,8 +4,11 @@ import ( "encoding/json" "fmt" "io" + "kiro-go/auth" "kiro-go/config" + "kiro-go/logger" "net/http" + neturl "net/url" "strings" "time" ) @@ -17,6 +20,7 @@ const ( // GetUsageLimits 获取账户使用量和订阅信息 func GetUsageLimits(account *config.Account) (*UsageLimitsResponse, error) { url := fmt.Sprintf("%s/getUsageLimits?origin=AI_EDITOR&resourceType=AGENTIC_REQUEST&isEmailRequired=true", kiroRestAPIBase) + url = withProfileArnQuery(url, account) req, err := http.NewRequest("GET", url, nil) if err != nil { @@ -25,8 +29,7 @@ func GetUsageLimits(account *config.Account) (*UsageLimitsResponse, error) { setKiroHeaders(req, account) - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := kiroRestHttpStore.Load().Do(req) if err != nil { return nil, err } @@ -57,8 +60,7 @@ func GetUserInfo(account *config.Account) (*UserInfoResponse, error) { setKiroHeaders(req, account) req.Header.Set("Content-Type", "application/json") - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := kiroRestHttpStore.Load().Do(req) if err != nil { return nil, err } @@ -79,6 +81,7 @@ func GetUserInfo(account *config.Account) (*UserInfoResponse, error) { // ListAvailableModels 获取可用模型列表 func ListAvailableModels(account *config.Account) ([]ModelInfo, error) { url := fmt.Sprintf("%s/ListAvailableModels?origin=AI_EDITOR&maxResults=50", kiroRestAPIBase) + url = withProfileArnQuery(url, account) req, err := http.NewRequest("GET", url, nil) if err != nil { @@ -87,8 +90,7 @@ func ListAvailableModels(account *config.Account) ([]ModelInfo, error) { setKiroHeaders(req, account) - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := kiroRestHttpStore.Load().Do(req) if err != nil { return nil, err } @@ -108,6 +110,88 @@ func ListAvailableModels(account *config.Account) ([]ModelInfo, error) { return result.Models, nil } +// ResolveProfileArn returns the account profile ARN, fetching and caching it +// when it is missing. First tries ListAvailableProfiles; if that returns empty, +// falls back to refreshing the token (which returns profileArn in the response). +func ResolveProfileArn(account *config.Account) (string, error) { + if account == nil { + return "", fmt.Errorf("account is nil") + } + if profileArn := strings.TrimSpace(account.ProfileArn); profileArn != "" { + return profileArn, nil + } + + // Try ListAvailableProfiles first + profileArn, err := listAvailableProfiles(account) + if err == nil && profileArn != "" { + if updateErr := config.UpdateAccountProfileArn(account.ID, profileArn); updateErr != nil { + logger.Warnf("[ProfileArn] Failed to cache profile ARN for %s: %v", account.Email, updateErr) + } + account.ProfileArn = profileArn + return profileArn, nil + } + + // Fallback: refresh token to get profileArn from auth response + if account.RefreshToken != "" { + _, _, _, refreshedArn, refreshErr := auth.RefreshToken(account) + if refreshErr == nil && refreshedArn != "" { + if updateErr := config.UpdateAccountProfileArn(account.ID, refreshedArn); updateErr != nil { + logger.Warnf("[ProfileArn] Failed to cache profile ARN for %s: %v", account.Email, updateErr) + } + account.ProfileArn = refreshedArn + return refreshedArn, nil + } + } + + return "", fmt.Errorf("no available Kiro profile") +} + +func listAvailableProfiles(account *config.Account) (string, error) { + req, err := http.NewRequest("POST", fmt.Sprintf("%s/ListAvailableProfiles", kiroRestAPIBase), strings.NewReader(`{"maxResults":10}`)) + if err != nil { + return "", err + } + setKiroHeaders(req, account) + req.Header.Set("Content-Type", "application/json") + + resp, err := kiroRestHttpStore.Load().Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + var result struct { + Profiles []struct { + Arn string `json:"arn"` + } `json:"profiles"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", err + } + for _, profile := range result.Profiles { + if profileArn := strings.TrimSpace(profile.Arn); profileArn != "" { + return profileArn, nil + } + } + return "", fmt.Errorf("empty profile list") +} + +func withProfileArnQuery(rawURL string, account *config.Account) string { + if account == nil { + return rawURL + } + profileArn := strings.TrimSpace(account.ProfileArn) + if profileArn == "" { + return rawURL + } + return rawURL + "&profileArn=" + neturl.QueryEscape(profileArn) +} + func setKiroHeaders(req *http.Request, account *config.Account) { host := "" if req.URL != nil { @@ -132,7 +216,7 @@ func RefreshAccountInfo(account *config.Account) (*config.AccountInfo, error) { errMsg := err.Error() if strings.Contains(errMsg, "TEMPORARILY_SUSPENDED") { // 账户被暂时封禁,自动禁用并标记封禁状态 - fmt.Printf("[RefreshAccountInfo] Account %s is temporarily suspended: %v\n", account.Email, err) + logger.Warnf("[RefreshAccountInfo] Account %s is temporarily suspended: %v", account.Email, err) // 更新账户封禁状态并自动禁用 updatedAccount := *account @@ -143,14 +227,14 @@ func RefreshAccountInfo(account *config.Account) (*config.AccountInfo, error) { // 保存更新后的账户状态 if updateErr := config.UpdateAccount(account.ID, updatedAccount); updateErr != nil { - fmt.Printf("[RefreshAccountInfo] Failed to update account ban status: %v\n", updateErr) + logger.Errorf("[RefreshAccountInfo] Failed to update account ban status: %v", updateErr) } return nil, fmt.Errorf("Account suspended: %w", err) } else if strings.Contains(errMsg, "403") || strings.Contains(errMsg, "401") || strings.Contains(errMsg, "invalid") || strings.Contains(errMsg, "expired") { // Token 相关错误,可能需要重新认证 - fmt.Printf("[RefreshAccountInfo] Authentication error for %s: %v\n", account.Email, err) + logger.Warnf("[RefreshAccountInfo] Authentication error for %s: %v", account.Email, err) // 更新账户封禁状态为认证失败并自动禁用 updatedAccount := *account @@ -161,7 +245,7 @@ func RefreshAccountInfo(account *config.Account) (*config.AccountInfo, error) { // 保存更新后的账户状态 if updateErr := config.UpdateAccount(account.ID, updatedAccount); updateErr != nil { - fmt.Printf("[RefreshAccountInfo] Failed to update account ban status: %v\n", updateErr) + logger.Errorf("[RefreshAccountInfo] Failed to update account ban status: %v", updateErr) } } @@ -170,7 +254,7 @@ func RefreshAccountInfo(account *config.Account) (*config.AccountInfo, error) { // 如果成功获取信息,清除封禁状态(如果之前被标记) if account.BanStatus != "" && account.BanStatus != "ACTIVE" { - fmt.Printf("[RefreshAccountInfo] Account %s is now active, clearing ban status\n", account.Email) + logger.Infof("[RefreshAccountInfo] Account %s is now active, clearing ban status", account.Email) updatedAccount := *account updatedAccount.BanStatus = "ACTIVE" @@ -179,7 +263,7 @@ func RefreshAccountInfo(account *config.Account) (*config.AccountInfo, error) { // 保存更新后的账户状态 if updateErr := config.UpdateAccount(account.ID, updatedAccount); updateErr != nil { - fmt.Printf("[RefreshAccountInfo] Failed to clear account ban status: %v\n", updateErr) + logger.Errorf("[RefreshAccountInfo] Failed to clear account ban status: %v", updateErr) } } @@ -204,7 +288,7 @@ func RefreshAccountInfo(account *config.Account) (*config.AccountInfo, error) { if info.SubscriptionTitle == "" { info.SubscriptionTitle = usage.SubscriptionInfo.SubscriptionName } - fmt.Printf("[RefreshAccountInfo] Subscription: type=%s, title=%s, name=%s, parsed=%s\n", + logger.Debugf("[RefreshAccountInfo] Subscription: type=%s, title=%s, name=%s, parsed=%s", usage.SubscriptionInfo.SubscriptionType, usage.SubscriptionInfo.SubscriptionTitle, usage.SubscriptionInfo.SubscriptionName, diff --git a/proxy/kiro_api_test.go b/proxy/kiro_api_test.go new file mode 100644 index 0000000..4fce7cd --- /dev/null +++ b/proxy/kiro_api_test.go @@ -0,0 +1,96 @@ +package proxy + +import ( + "io" + "kiro-go/config" + "net/http" + "path/filepath" + "strings" + "testing" +) + +func TestResolveProfileArnReturnsCachedValueWithoutRequest(t *testing.T) { + kiroRestHttpStore.Store(&http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + t.Fatal("unexpected HTTP request for cached profile ARN") + return nil, nil + }), + }) + t.Cleanup(func() { InitKiroHttpClient("") }) + + account := &config.Account{ProfileArn: " arn:aws:codewhisperer:profile/test "} + got, err := ResolveProfileArn(account) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "arn:aws:codewhisperer:profile/test" { + t.Fatalf("expected trimmed cached ARN, got %q", got) + } +} + +func TestResolveProfileArnFetchesAndCachesProfile(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + if err := config.Init(configPath); err != nil { + t.Fatalf("init config: %v", err) + } + account := config.Account{ + ID: "acct-1", + Email: "user@example.com", + AccessToken: "access-token", + Region: "us-east-1", + UsageCurrent: 7, + } + if err := config.AddAccount(account); err != nil { + t.Fatalf("add account: %v", err) + } + + kiroRestHttpStore.Store(&http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + if req.Method != http.MethodPost { + t.Fatalf("expected POST, got %s", req.Method) + } + if req.URL.Path != "/ListAvailableProfiles" { + t.Fatalf("expected ListAvailableProfiles path, got %s", req.URL.Path) + } + if got := req.Header.Get("Content-Type"); got != "application/json" { + t.Fatalf("expected JSON content type, got %q", got) + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"profiles":[{"arn":" arn:aws:codewhisperer:profile/fetched "}]} `)), + Header: make(http.Header), + }, nil + }), + }) + t.Cleanup(func() { InitKiroHttpClient("") }) + + requestAccount := account + requestAccount.UsageCurrent = 0 + got, err := ResolveProfileArn(&requestAccount) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "arn:aws:codewhisperer:profile/fetched" { + t.Fatalf("expected fetched ARN, got %q", got) + } + if requestAccount.ProfileArn != got { + t.Fatalf("expected account to be updated with fetched ARN, got %q", requestAccount.ProfileArn) + } + + accounts := config.GetAccounts() + if len(accounts) != 1 { + t.Fatalf("expected one persisted account, got %d", len(accounts)) + } + if accounts[0].ProfileArn != got { + t.Fatalf("expected persisted account profile ARN %q, got %q", got, accounts[0].ProfileArn) + } + if accounts[0].UsageCurrent != 7 { + t.Fatalf("expected profile cache update to preserve usage fields, got usageCurrent=%v", accounts[0].UsageCurrent) + } +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} diff --git a/proxy/kiro_test.go b/proxy/kiro_test.go index f32190b..003e544 100644 --- a/proxy/kiro_test.go +++ b/proxy/kiro_test.go @@ -1,6 +1,11 @@ package proxy -import "testing" +import ( + "net/http" + "net/url" + "testing" + "time" +) func TestNormalizeChunkBasicProgression(t *testing.T) { prev := "" @@ -35,3 +40,63 @@ func TestNormalizeChunkOverlapDelta(t *testing.T) { t.Fatalf("expected overlap suffix delta, got %q", got) } } + +func TestBuildKiroTransportUsesExplicitProxyURL(t *testing.T) { + transport := buildKiroTransport("http://proxy.local:8080") + req := &http.Request{URL: mustParseURL(t, "https://q.us-east-1.amazonaws.com")} + + got, err := transport.Proxy(req) + if err != nil { + t.Fatalf("unexpected proxy error: %v", err) + } + assertProxyURL(t, got, "http://proxy.local:8080") +} + +func TestBuildKiroTransportFallsBackToEnvironmentProxy(t *testing.T) { + t.Setenv("HTTPS_PROXY", "http://env-proxy.local:2323") + t.Setenv("NO_PROXY", "") + t.Setenv("no_proxy", "") + + transport := buildKiroTransport("") + req := &http.Request{URL: mustParseURL(t, "https://q.us-east-1.amazonaws.com")} + + got, err := transport.Proxy(req) + if err != nil { + t.Fatalf("unexpected proxy error: %v", err) + } + assertProxyURL(t, got, "http://env-proxy.local:2323") +} + +func TestInitKiroHttpClientKeepsShortRestTimeout(t *testing.T) { + InitKiroHttpClient("") + t.Cleanup(func() { InitKiroHttpClient("") }) + + streamClient := kiroHttpStore.Load() + restClient := kiroRestHttpStore.Load() + + if streamClient.Timeout != 5*time.Minute { + t.Fatalf("expected streaming timeout to be 5m, got %s", streamClient.Timeout) + } + if restClient.Timeout != 30*time.Second { + t.Fatalf("expected REST timeout to stay 30s, got %s", restClient.Timeout) + } +} + +func mustParseURL(t *testing.T, raw string) *url.URL { + t.Helper() + parsed, err := url.Parse(raw) + if err != nil { + t.Fatalf("invalid test URL: %v", err) + } + return parsed +} + +func assertProxyURL(t *testing.T, got *url.URL, want string) { + t.Helper() + if got == nil { + t.Fatalf("expected proxy URL %q, got nil", want) + } + if got.String() != want { + t.Fatalf("expected proxy URL %q, got %q", want, got.String()) + } +} diff --git a/proxy/translator.go b/proxy/translator.go index 1c11fd7..3e8f5c4 100644 --- a/proxy/translator.go +++ b/proxy/translator.go @@ -296,11 +296,14 @@ func ClaudeToKiro(req *ClaudeRequest, thinking bool) *KiroPayload { } // 转换工具 - kiroTools := convertClaudeTools(req.Tools) + kiroTools, toolNameMap := convertClaudeTools(req.Tools) // 构建 payload payload := &KiroPayload{} + payload.ToolNameMap = toolNameMap payload.ConversationState.ChatTriggerType = "MANUAL" + payload.ConversationState.AgentTaskType = "vibe" + payload.ConversationState.AgentContinuationId = uuid.New().String() payload.ConversationState.ConversationID = buildConversationID(modelID, systemPrompt, firstClaudeConversationAnchor(req.Messages)) payload.ConversationState.CurrentMessage.UserInputMessage = KiroUserInputMessage{ Content: finalContent, @@ -571,21 +574,115 @@ func extractClaudeAssistantContent(content interface{}) (string, []KiroToolUse) return text, toolUses } -func convertClaudeTools(tools []ClaudeTool) []KiroToolWrapper { +func convertClaudeTools(tools []ClaudeTool) ([]KiroToolWrapper, map[string]string) { if len(tools) == 0 { - return nil + return nil, nil } - result := make([]KiroToolWrapper, len(tools)) - for i, tool := range tools { + result := make([]KiroToolWrapper, 0, len(tools)) + nameMap := make(map[string]string) + for _, tool := range tools { desc := tool.Description if len(desc) > maxToolDescLen { desc = desc[:maxToolDescLen] + "..." } - result[i] = KiroToolWrapper{} - result[i].ToolSpecification.Name = shortenToolName(tool.Name) - result[i].ToolSpecification.Description = desc - result[i].ToolSpecification.InputSchema = InputSchema{JSON: tool.InputSchema} + sanitized := shortenToolName(sanitizeToolName(tool.Name)) + if sanitized != tool.Name { + nameMap[sanitized] = tool.Name + } + w := KiroToolWrapper{} + w.ToolSpecification.Name = sanitized + w.ToolSpecification.Description = desc + w.ToolSpecification.InputSchema = InputSchema{JSON: ensureObjectSchema(tool.InputSchema)} + result = append(result, w) + } + return result, nameMap +} + +// ensureObjectSchema ensures the JSON schema has "type": "object" at the top level +// and removes invalid null values from "required" fields (recursively). +// Kiro API rejects tool schemas with "required": null. +func ensureObjectSchema(schema interface{}) interface{} { + m, ok := schema.(map[string]interface{}) + if !ok { + return map[string]interface{}{"type": "object"} + } + cleanSchema(m) + if _, hasType := m["type"]; !hasType { + m["type"] = "object" + } + return m +} + +// cleanSchema recursively removes or fixes invalid "required": null entries +// in a JSON Schema tree. +func cleanSchema(m map[string]interface{}) { + // Fix "required" field: must be array or absent + if req, exists := m["required"]; exists { + if req == nil { + delete(m, "required") + } else if arr, ok := req.([]interface{}); ok && len(arr) == 0 { + delete(m, "required") + } + } + + // Recurse into "properties" + if props, ok := m["properties"].(map[string]interface{}); ok { + for _, v := range props { + if sub, ok := v.(map[string]interface{}); ok { + cleanSchema(sub) + } + } + } + + // Recurse into "items" + if items, ok := m["items"].(map[string]interface{}); ok { + cleanSchema(items) + } + + // Recurse into nested object schemas (e.g., additionalProperties, allOf, oneOf, anyOf) + for _, key := range []string{"additionalProperties"} { + if sub, ok := m[key].(map[string]interface{}); ok { + cleanSchema(sub) + } + } + for _, key := range []string{"allOf", "oneOf", "anyOf"} { + if arr, ok := m[key].([]interface{}); ok { + for _, item := range arr { + if sub, ok := item.(map[string]interface{}); ok { + cleanSchema(sub) + } + } + } + } +} + +// sanitizeToolName normalizes a tool name to characters the Kiro API accepts. +// Kiro tool names must be pure camelCase (no underscores or dashes). +// Separators (_, -, and multi-underscore namespace prefixes) are converted to camelCase boundaries. +func sanitizeToolName(name string) string { + // Split on underscores and dashes, including multi-underscore namespace prefixes. + parts := strings.FieldsFunc(name, func(r rune) bool { + return r == '_' || r == '-' + }) + if len(parts) == 0 { + return "tool" + } + // Build camelCase: first part lowercase start, rest capitalize first letter + var b strings.Builder + for i, part := range parts { + if part == "" { + continue + } + if i == 0 { + b.WriteString(strings.ToLower(part[:1]) + part[1:]) + } else { + b.WriteString(strings.ToUpper(part[:1]) + part[1:]) + } + } + result := b.String() + if result == "" { + return "tool" } return result } diff --git a/version.json b/version.json index 14e14d7..ce7d713 100644 --- a/version.json +++ b/version.json @@ -1,5 +1,5 @@ { - "version": "1.0.6", + "version": "1.0.7", "changelog": "✨ Added and fixed several improvements across the project.\n✨ 新增并修复了一些内容,包含若干功能改进与问题修复。", "download": "https://github.com/Quorinex/Kiro-Go" } diff --git a/web/index.html b/web/index.html index 26accb1..8c57aac 100644 --- a/web/index.html +++ b/web/index.html @@ -1025,12 +1025,21 @@ +
+ + +
@@ -1166,6 +1175,7 @@ 'accounts.copyJSON': '复制 JSON', 'accounts.copyJSONSuccess': 'JSON 已复制到剪贴板', 'accounts.trialDays': '天后到期', + 'accounts.overage': '超额调用', 'time.expired': '已过期', 'time.minutes': '分钟', 'time.hours': '小时', @@ -1195,6 +1205,8 @@ 'settings.preferredEndpoint': '首选端点', 'settings.endpointAuto': '自动选择', 'settings.endpointHint': '选择首选端点,自动选择模式下会根据可用性自动选择端点', + 'settings.endpointFallback': '端点不可用时自动切换', + 'settings.endpointFallbackHint': '关闭后,仅使用选定的端点,不会自动切换到其他端点', 'settings.saveEndpoint': '保存端点设置', 'settings.endpointSaved': '端点设置已保存', 'settings.adminPassword': '管理密码', @@ -1262,6 +1274,20 @@ 'modal.localDesc': '通过 Kiro IDE 本地缓存文件添加账号', 'modal.credentialsTitle': '凭证 JSON', 'modal.credentialsDesc': '通过 Kiro Account Manager 导出的凭证添加账号', + 'modal.cookieTitle': 'Kiro 网页 Cookie', + 'modal.cookieDesc': '通过从 Kiro 网页 Cookie 中获取的 RefreshToken 添加账号', + 'cookie.howToGet': '如何获取 RefreshToken?', + 'cookie.step1': '打开浏览器,访问并登录', + 'cookie.step2': '在页面空白处右键 → 检查 → Application → Cookies → https://app.kiro.dev', + 'cookie.step3': '在列表中找到 RefreshToken,双击其 Value 列并复制', + 'cookie.refreshToken': 'RefreshToken', + 'cookie.provider': '登录方式', + 'cookie.github': 'GitHub', + 'cookie.google': 'Google', + 'cookie.refreshTokenPlaceholder': '粘贴 RefreshToken Cookie 的值', + 'cookie.refreshTokenMissing': '请填写 RefreshToken', + 'cookie.importSuccess': '账号添加成功', + 'cookie.link': 'https://app.kiro.dev/account/usage', 'builderid.startLogin': '开始登录', 'builderid.verifyCode': '请在浏览器中输入上方验证码', 'builderid.verifyUrl': '验证链接', @@ -1351,7 +1377,10 @@ 'filter.banned': '已封禁', 'accounts.weight': '权重', 'detail.weight': '请求权重', - 'detail.weightHint': '0-1=普通, 2+=高优先级' + 'detail.weightHint': '0-1=普通, 2+=高优先级', + 'detail.overage': '超额调用设置', + 'detail.allowOverage': '配额耗尽后继续调用', + 'detail.overageHint': '超额后调用频率权重(1~10),值越大频率越高' }, en: { 'login.subtitle': 'Enter admin password to login', @@ -1391,6 +1420,7 @@ 'accounts.confirmDelete': 'Confirm delete?', 'accounts.copyJSON': 'Copy JSON', 'accounts.copyJSONSuccess': 'JSON copied to clipboard', + 'accounts.overage': 'Overage', 'time.expired': 'Expired', 'time.minutes': 'min', 'time.hours': 'hr', @@ -1420,6 +1450,8 @@ 'settings.preferredEndpoint': 'Preferred Endpoint', 'settings.endpointAuto': 'Auto', 'settings.endpointHint': 'Select preferred endpoint. In auto-select mode, the endpoint is automatically selected based on availability.', + 'settings.endpointFallback': 'Fallback to other endpoints', + 'settings.endpointFallbackHint': 'When disabled, only the selected endpoint is used without automatic fallback', 'settings.saveEndpoint': 'Save Endpoint Settings', 'settings.endpointSaved': 'Endpoint settings saved', 'settings.adminPassword': 'Admin Password', @@ -1483,6 +1515,20 @@ 'modal.localDesc': 'Add account via Kiro IDE local cache files', 'modal.credentialsTitle': 'Credentials JSON', 'modal.credentialsDesc': 'Add account via Kiro Account Manager exported credentials', + 'modal.cookieTitle': 'Kiro Web Cookie', + 'modal.cookieDesc': 'Add account via RefreshToken obtained from Kiro Web Cookie', + 'cookie.howToGet': 'How to get the RefreshToken?', + 'cookie.step1': 'Open your browser and sign in at', + 'cookie.step2': 'Right-click on blank area → Inspect → Application → Cookies → https://app.kiro.dev', + 'cookie.step3': 'Find RefreshToken in the list, double-click its Value and copy it', + 'cookie.refreshToken': 'RefreshToken', + 'cookie.provider': 'Login Provider', + 'cookie.github': 'GitHub', + 'cookie.google': 'Google', + 'cookie.refreshTokenPlaceholder': 'Paste the RefreshToken cookie value here', + 'cookie.refreshTokenMissing': 'RefreshToken is required', + 'cookie.importSuccess': 'Account added successfully', + 'cookie.link': 'https://app.kiro.dev/account/usage', 'builderid.startLogin': 'Start Login', 'builderid.verifyCode': 'Enter the code above in your browser', 'builderid.verifyUrl': 'Verification URL', @@ -1572,7 +1618,10 @@ 'filter.banned': 'Banned', 'accounts.weight': 'Weight', 'detail.weight': 'Request Weight', - 'detail.weightHint': '0-1=normal, 2+=higher priority' + 'detail.weightHint': '0-1=normal, 2+=higher priority', + 'detail.overage': 'Overage Settings', + 'detail.allowOverage': 'Continue calling after quota is exhausted', + 'detail.overageHint': 'Call frequency weight when over quota (1–10); higher = more frequent' } }; let currentLang = localStorage.getItem('kiro_lang') || 'zh'; @@ -1838,6 +1887,7 @@ const isSelected = selectedAccounts.has(a.id); const weightVal = a.weight || 0; const weightBadge = weightVal >= 2 ? 'W:' + weightVal + '' : ''; + const overageBadge = a.allowOverage ? '' + t('accounts.overage') + ':' + (a.overageWeight || 1) + '' : ''; return '
' + '
' + + '

' + t('detail.overage') + '

' + + '' + + '' + + '' + t('detail.overageHint') + '' + + '' + + '
' + '

' + t('detail.subscription') + '

' + '
' + t('detail.subscriptionType') + '
' + (a.subscriptionTitle || a.subscriptionType || '-') + '
' + '
' + t('detail.tokenExpiry') + '
' + (a.expiresAt ? new Date(a.expiresAt * 1000).toLocaleString() : '-') + '
' + @@ -2062,6 +2119,20 @@ if (d.success) { alert(t('detail.saved')); loadAccounts(); } else { alert(t('detail.saveFailed') + ': ' + d.error); } } catch (e) { alert(t('detail.saveFailed')); } } + async function saveOverageSettings(id) { + const allowOverage = document.getElementById('allowOverageInput').checked; + let overageWeight = parseInt(document.getElementById('overageWeightInput').value) || 1; + overageWeight = Math.max(1, Math.min(10, overageWeight)); + document.getElementById('overageWeightInput').value = overageWeight; + try { + const res = await fetch('/admin/api/accounts/' + id, { + method: 'PUT', headers: { 'Content-Type': 'application/json', 'X-Admin-Password': password }, + body: JSON.stringify({ allowOverage, overageWeight }) + }); + const d = await res.json(); + if (d.success) { alert(t('detail.saved')); loadAccounts(); } else { alert(t('detail.saveFailed') + ': ' + d.error); } + } catch (e) { alert(t('detail.saveFailed')); } + } async function quickSetWeight(id, value) { const weight = parseInt(value) || 0; try { @@ -2102,11 +2173,12 @@ const res = await fetch('/admin/api/endpoint', { headers: { 'X-Admin-Password': password } }); const d = await res.json(); document.getElementById('preferredEndpoint').value = d.preferredEndpoint || 'auto'; + document.getElementById('endpointFallback').checked = d.endpointFallback !== false; } async function saveEndpointConfig() { const res = await fetch('/admin/api/endpoint', { method: 'POST', headers: { 'Content-Type': 'application/json', 'X-Admin-Password': password }, - body: JSON.stringify({ preferredEndpoint: document.getElementById('preferredEndpoint').value }) + body: JSON.stringify({ preferredEndpoint: document.getElementById('preferredEndpoint').value, endpointFallback: document.getElementById('endpointFallback').checked }) }); const d = await res.json(); if (d.success) { alert(t('settings.endpointSaved')); } else { alert(t('common.saveFailed') + ': ' + d.error); } @@ -2317,6 +2389,7 @@ '
' + t('modal.ssoTitle') + '
' + t('modal.ssoDesc') + '
' + '
' + t('modal.localTitle') + '
' + t('modal.localDesc') + '
' + '
' + t('modal.credentialsTitle') + '
' + t('modal.credentialsDesc') + '
' + + '
' + t('modal.cookieTitle') + '
' + t('modal.cookieDesc') + '
' + '
' + ''; } else if (type === 'builderid') { @@ -2348,6 +2421,13 @@ '
' + t('credentials.batchHint') + '
' + '
' + ''; + } else if (type === 'cookie') { + title.textContent = t('modal.cookieTitle'); + body.innerHTML = + '

' + t('cookie.howToGet') + '

  1. ' + t('cookie.step1') + ' ' + t('cookie.link') + '
  2. ' + t('cookie.step2') + '
  3. ' + t('cookie.step3') + '
' + + '
' + + '
' + + ''; } else if (type === 'sso') { title.textContent = t('modal.ssoTitle'); body.innerHTML = @@ -2458,6 +2538,16 @@ newIds.forEach(id => autoRefreshNewAccount(id)); } catch (e) { alert(t('credentials.jsonError')); } } + async function importFromCookie() { + const refreshToken = document.getElementById('cookieRefreshToken').value.trim(); + if (!refreshToken) { alert(t('cookie.refreshTokenMissing')); return; } + const provider = document.getElementById('cookieProvider').value; + const payload = { refreshToken, accessToken: '', clientId: '', clientSecret: '', authMethod: 'social', provider }; + const res = await fetch('/admin/api/auth/credentials', { method: 'POST', headers: { 'Content-Type': 'application/json', 'X-Admin-Password': password }, body: JSON.stringify(payload) }); + const d = await res.json(); + if (d.success) { closeModal(); loadAccounts(); loadStats(); alert(t('cookie.importSuccess') + ': ' + (d.account?.email || d.account?.id)); autoRefreshNewAccount(d.account?.id); } + else alert(t('common.failed') + ': ' + d.error); + } async function importSsoToken() { const res = await fetch('/admin/api/auth/sso-token', { method: 'POST', headers: { 'Content-Type': 'application/json', 'X-Admin-Password': password }, body: JSON.stringify({ bearerToken: document.getElementById('ssoToken').value, region: document.getElementById('ssoRegion').value }) }); const d = await res.json(); @@ -2700,4 +2790,4 @@ - \ No newline at end of file +