diff --git a/config/config.go b/config/config.go index 22cfd48..134aa3c 100644 --- a/config/config.go +++ b/config/config.go @@ -108,6 +108,9 @@ type Config struct { // Endpoint configuration: "auto", "codewhisperer", or "amazonq" PreferredEndpoint string `json:"preferredEndpoint,omitempty"` + // General behavior settings + InvalidModelRetries int `json:"invalidModelRetries,omitempty"` // Same-endpoint retry count on INVALID_MODEL_ID (default: 3) + // Global statistics (persisted across restarts) TotalRequests int `json:"totalRequests,omitempty"` // Total API requests received SuccessRequests int `json:"successRequests,omitempty"` // Successful requests count @@ -445,6 +448,30 @@ func UpdatePreferredEndpoint(endpoint string) error { return Save() } +// GetInvalidModelRetries 返回 INVALID_MODEL_ID 同端点重试次数(默认 3) +func GetInvalidModelRetries() int { + cfgLock.RLock() + defer cfgLock.RUnlock() + if cfg.InvalidModelRetries < 0 { + return 0 + } + if cfg.InvalidModelRetries == 0 { + return 3 + } + return cfg.InvalidModelRetries +} + +// UpdateInvalidModelRetries 更新 INVALID_MODEL_ID 同端点重试次数 +func UpdateInvalidModelRetries(n int) error { + cfgLock.Lock() + defer cfgLock.Unlock() + if n < 0 { + n = 0 + } + cfg.InvalidModelRetries = n + return Save() +} + type KiroClientConfig struct { KiroVersion string SystemVersion string diff --git a/proxy/handler.go b/proxy/handler.go index 85afc5e..3332b08 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -1781,6 +1781,10 @@ func (h *Handler) handleAdminAPI(w http.ResponseWriter, r *http.Request) { h.apiGetEndpointConfig(w, r) case path == "/endpoint" && r.Method == "POST": h.apiUpdateEndpointConfig(w, r) + case path == "/general" && r.Method == "GET": + h.apiGetGeneralConfig(w, r) + case path == "/general" && r.Method == "POST": + h.apiUpdateGeneralConfig(w, r) case path == "/version" && r.Method == "GET": h.apiGetVersion(w, r) case path == "/export" && r.Method == "POST": @@ -2745,6 +2749,41 @@ func (h *Handler) apiUpdateEndpointConfig(w http.ResponseWriter, r *http.Request json.NewEncoder(w).Encode(map[string]bool{"success": true}) } +// apiGetGeneralConfig 获取通用设置 +func (h *Handler) apiGetGeneralConfig(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]interface{}{ + "invalidModelRetries": config.GetInvalidModelRetries(), + }) +} + +// apiUpdateGeneralConfig 更新通用设置 +func (h *Handler) apiUpdateGeneralConfig(w http.ResponseWriter, r *http.Request) { + var req struct { + InvalidModelRetries *int `json:"invalidModelRetries"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) + return + } + + if req.InvalidModelRetries != nil { + n := *req.InvalidModelRetries + if n < 0 || n > 20 { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "invalidModelRetries must be 0-20"}) + return + } + if err := config.UpdateInvalidModelRetries(n); err != nil { + w.WriteHeader(500) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + } + + json.NewEncoder(w).Encode(map[string]bool{"success": true}) +} + // apiGetVersion 获取版本信息 func (h *Handler) apiGetVersion(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(map[string]string{ diff --git a/proxy/kiro.go b/proxy/kiro.go index 7fcaa64..8b00589 100644 --- a/proxy/kiro.go +++ b/proxy/kiro.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "kiro-go/config" + "log" "net/http" "net/url" "strconv" @@ -163,73 +164,130 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt return err } + modelID := payload.ConversationState.CurrentMessage.UserInputMessage.ModelID + accountLabel := account.Email + if accountLabel == "" { + accountLabel = account.ID + } + // 根据配置排序端点 endpoints := getSortedEndpoints(config.GetPreferredEndpoint()) + invalidModelRetries := config.GetInvalidModelRetries() + + endpointNames := make([]string, 0, len(endpoints)) + for _, ep := range endpoints { + endpointNames = append(endpointNames, ep.Name) + } + log.Printf("[KiroAPI] request start account=%s model=%q endpoints=[%s] invalid_model_retries=%d", accountLabel, modelID, strings.Join(endpointNames, ","), invalidModelRetries) + + requestStart := time.Now() var lastErr error for _, ep := range endpoints { // 更新 payload 中的 origin payload.ConversationState.CurrentMessage.UserInputMessage.Origin = ep.Origin - reqBody, _ := json.Marshal(payload) - req, err := http.NewRequest("POST", ep.URL, bytes.NewReader(reqBody)) - if err != nil { - lastErr = err - continue - } - - host := "" - if parsedURL, parseErr := url.Parse(ep.URL); parseErr == nil { - host = parsedURL.Host - } - headerValues := buildStreamingHeaderValues(account, host) - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "*/*") - req.Header.Set("X-Amz-Target", ep.AmzTarget) - applyKiroBaseHeaders(req, account, headerValues) - req.Header.Set("x-amzn-kiro-agent-mode", "vibe") - req.Header.Set("x-amzn-codewhisperer-optout", "true") - req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") - req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) - - resp, err := kiroHttpClient.Do(req) - if err != nil { - lastErr = err - fmt.Printf("[KiroAPI] Endpoint %s failed: %v\n", ep.Name, err) - continue - } - - if resp.StatusCode == 429 { - resp.Body.Close() - fmt.Printf("[KiroAPI] Endpoint %s quota exhausted (429), trying next...\n", ep.Name) - lastErr = fmt.Errorf("quota exhausted on %s", ep.Name) - continue - } - - if resp.StatusCode != 200 { - errBody, _ := io.ReadAll(resp.Body) - resp.Body.Close() - lastErr = fmt.Errorf("HTTP %d from %s: %s", resp.StatusCode, ep.Name, string(errBody)) - // 认证错误不继续尝试 - if resp.StatusCode == 401 || resp.StatusCode == 403 { - return lastErr + // 单端点内重试循环:INVALID_MODEL_ID 时同端点重试 + maxAttempts := invalidModelRetries + 1 + shouldFallback := false + for attempt := 1; attempt <= maxAttempts; attempt++ { + reqBody, _ := json.Marshal(payload) + req, err := http.NewRequest("POST", ep.URL, bytes.NewReader(reqBody)) + if err != nil { + lastErr = err + shouldFallback = true + break } - fmt.Printf("[KiroAPI] Endpoint %s error: %v\n", ep.Name, lastErr) - continue + + host := "" + if parsedURL, parseErr := url.Parse(ep.URL); parseErr == nil { + host = parsedURL.Host + } + headerValues := buildStreamingHeaderValues(account, host) + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "*/*") + req.Header.Set("X-Amz-Target", ep.AmzTarget) + applyKiroBaseHeaders(req, account, headerValues) + req.Header.Set("x-amzn-kiro-agent-mode", "vibe") + req.Header.Set("x-amzn-codewhisperer-optout", "true") + req.Header.Set("Amz-Sdk-Request", fmt.Sprintf("attempt=%d; max=%d", attempt, maxAttempts)) + req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) + + attemptStart := time.Now() + log.Printf("[KiroAPI] try endpoint=%s attempt=%d/%d account=%s model=%q origin=%s", ep.Name, attempt, maxAttempts, accountLabel, modelID, ep.Origin) + + resp, err := kiroHttpClient.Do(req) + if err != nil { + lastErr = err + log.Printf("[KiroAPI] endpoint=%s attempt=%d/%d account=%s model=%q transport_error elapsed=%s err=%v", ep.Name, attempt, maxAttempts, accountLabel, modelID, time.Since(attemptStart), err) + shouldFallback = true + break + } + + if resp.StatusCode == 429 { + resp.Body.Close() + log.Printf("[KiroAPI] endpoint=%s attempt=%d/%d account=%s model=%q status=429 quota_exhausted elapsed=%s", ep.Name, attempt, maxAttempts, accountLabel, modelID, time.Since(attemptStart)) + lastErr = fmt.Errorf("quota exhausted on %s", ep.Name) + shouldFallback = true + break + } + + if resp.StatusCode != 200 { + errBody, _ := io.ReadAll(resp.Body) + resp.Body.Close() + lastErr = fmt.Errorf("HTTP %d from %s: %s", resp.StatusCode, ep.Name, string(errBody)) + bodyStr := string(errBody) + + // 认证错误不继续尝试 + if resp.StatusCode == 401 || resp.StatusCode == 403 { + log.Printf("[KiroAPI] endpoint=%s attempt=%d/%d account=%s model=%q status=%d auth_error elapsed=%s body=%s", ep.Name, attempt, maxAttempts, accountLabel, modelID, resp.StatusCode, time.Since(attemptStart), truncateForLog(bodyStr, 300)) + return lastErr + } + + // INVALID_MODEL_ID: 同端点再试 + if resp.StatusCode == 400 && strings.Contains(bodyStr, "INVALID_MODEL_ID") { + if attempt < maxAttempts { + log.Printf("[KiroAPI] endpoint=%s attempt=%d/%d account=%s model=%q status=400 INVALID_MODEL_ID retrying elapsed=%s", ep.Name, attempt, maxAttempts, accountLabel, modelID, time.Since(attemptStart)) + continue + } + log.Printf("[KiroAPI] endpoint=%s attempt=%d/%d account=%s model=%q status=400 INVALID_MODEL_ID exhausted, fallback elapsed=%s", ep.Name, attempt, maxAttempts, accountLabel, modelID, time.Since(attemptStart)) + shouldFallback = true + break + } + + log.Printf("[KiroAPI] endpoint=%s attempt=%d/%d account=%s model=%q status=%d elapsed=%s body=%s", ep.Name, attempt, maxAttempts, accountLabel, modelID, resp.StatusCode, time.Since(attemptStart), truncateForLog(bodyStr, 300)) + shouldFallback = true + break + } + + log.Printf("[KiroAPI] endpoint=%s attempt=%d/%d account=%s model=%q status=200 headers_elapsed=%s, streaming...", ep.Name, attempt, maxAttempts, accountLabel, modelID, time.Since(attemptStart)) + err = parseEventStream(resp.Body, callback) + resp.Body.Close() + log.Printf("[KiroAPI] endpoint=%s account=%s model=%q done total_elapsed=%s err=%v", ep.Name, accountLabel, modelID, time.Since(requestStart), err) + return err } - err = parseEventStream(resp.Body, callback) - resp.Body.Close() - return err + if !shouldFallback { + break + } } + log.Printf("[KiroAPI] all endpoints failed account=%s model=%q total_elapsed=%s last_err=%v", accountLabel, modelID, time.Since(requestStart), lastErr) if lastErr != nil { return lastErr } return fmt.Errorf("all endpoints failed") } +func truncateForLog(s string, max int) string { + s = strings.ReplaceAll(s, "\n", " ") + if len(s) <= max { + return s + } + return s[:max] + "...(truncated)" +} + // ==================== Event Stream 解析 ==================== // parseEventStream 解析 AWS Event Stream 二进制格式 diff --git a/web/index.html b/web/index.html index bda7341..4c08fb4 100644 --- a/web/index.html +++ b/web/index.html @@ -969,6 +969,17 @@ data-i18n-placeholder="settings.apiKeyPlaceholder"> +
+
+
+ + + +
+ +
@@ -1123,6 +1134,11 @@ 'settings.enableApiKey': '启用 API Key 验证', 'settings.apiKeyPlaceholder': '留空则不验证', 'settings.generateApiKey': '随机生成', + 'settings.generalSettings': '通用设置', + 'settings.invalidModelRetries': 'INVALID_MODEL_ID 同端点重试次数', + 'settings.invalidModelRetriesHint': '当上游返回 INVALID_MODEL_ID(HTTP 400)时,先在当前端点重试 N 次后再 fallback 到下一个端点。默认 3,范围 0-20', + 'settings.saveGeneral': '保存通用设置', + 'settings.generalSaved': '通用设置已保存', 'settings.thinkingSettings': 'Thinking 模式设置', 'settings.thinkingSuffix': '触发后缀', 'settings.thinkingSuffixHint': '模型名称加此后缀即启用思考模式,如 claude-sonnet-4.5-thinking', @@ -1329,6 +1345,11 @@ 'settings.enableApiKey': 'Enable API Key Verification', 'settings.apiKeyPlaceholder': 'Leave empty to disable', 'settings.generateApiKey': 'Generate', + 'settings.generalSettings': 'General Settings', + 'settings.invalidModelRetries': 'INVALID_MODEL_ID same-endpoint retries', + 'settings.invalidModelRetriesHint': 'When upstream returns INVALID_MODEL_ID (HTTP 400), retry the current endpoint N times before falling back. Default 3, range 0-20', + 'settings.saveGeneral': 'Save General Settings', + 'settings.generalSaved': 'General settings saved', 'settings.thinkingSettings': 'Thinking Mode Settings', 'settings.thinkingSuffix': 'Trigger Suffix', 'settings.thinkingSuffixHint': 'Add this suffix to model name to enable thinking mode, e.g. claude-sonnet-4.5-thinking', @@ -1991,6 +2012,7 @@ document.getElementById('apiKeyInput').value = d.apiKey || ''; loadThinkingConfig(); loadEndpointConfig(); + loadGeneralConfig(); } async function loadThinkingConfig() { const res = await fetch('/admin/api/thinking', { headers: { 'X-Admin-Password': password } }); @@ -2020,6 +2042,26 @@ const d = await res.json(); if (d.success) { alert(t('settings.endpointSaved')); } else { alert(t('common.saveFailed') + ': ' + d.error); } } + async function loadGeneralConfig() { + const res = await fetch('/admin/api/general', { headers: { 'X-Admin-Password': password } }); + const d = await res.json(); + const v = (d && typeof d.invalidModelRetries === 'number') ? d.invalidModelRetries : 3; + document.getElementById('invalidModelRetries').value = v; + } + async function saveGeneralConfig() { + const raw = document.getElementById('invalidModelRetries').value; + const n = parseInt(raw, 10); + if (isNaN(n) || n < 0 || n > 20) { + alert(t('common.saveFailed') + ': 0-20'); + return; + } + const res = await fetch('/admin/api/general', { + method: 'POST', headers: { 'Content-Type': 'application/json', 'X-Admin-Password': password }, + body: JSON.stringify({ invalidModelRetries: n }) + }); + const d = await res.json(); + if (d.success) { alert(t('settings.generalSaved')); } else { alert(t('common.saveFailed') + ': ' + d.error); } + } function generateApiKey() { const chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'; let key = 'sk-';