diff --git a/auth/oidc.go b/auth/oidc.go index 329ef23..40d3456 100644 --- a/auth/oidc.go +++ b/auth/oidc.go @@ -20,6 +20,9 @@ func RefreshToken(account *config.Account) (string, string, int64, error) { // refreshOIDCToken IdC/Builder ID token 刷新 func refreshOIDCToken(refreshToken, clientID, clientSecret, region string) (string, string, int64, error) { + if clientID == "" || clientSecret == "" { + return "", "", 0, fmt.Errorf("OIDC refresh requires clientId and clientSecret") + } if region == "" { region = "us-east-1" } diff --git a/config/config.go b/config/config.go index d62dcb8..e2a7c58 100644 --- a/config/config.go +++ b/config/config.go @@ -50,6 +50,9 @@ type Account struct { ExpiresAt int64 `json:"expiresAt,omitempty"` // Token expiration timestamp (Unix seconds) MachineId string `json:"machineId,omitempty"` // UUID machine identifier for request tracking + // Priority weight for load balancing (higher = more requests) + Weight int `json:"weight,omitempty"` // 0 or 1 = normal, 2+ = higher priority + // Account status Enabled bool `json:"enabled"` // Whether account is active in the pool BanStatus string `json:"banStatus,omitempty"` // Ban status: "ACTIVE", "BANNED", "SUSPENDED" diff --git a/pool/account.go b/pool/account.go index f259a25..0f1a2f1 100644 --- a/pool/account.go +++ b/pool/account.go @@ -36,13 +36,25 @@ func GetPool() *AccountPool { } // Reload 从配置重新加载账号 +// 构建加权列表:weight<=1 出现 1 次,weight>=2 出现 weight 次 func (p *AccountPool) Reload() { p.mu.Lock() defer p.mu.Unlock() - p.accounts = config.GetEnabledAccounts() + enabled := config.GetEnabledAccounts() + var weighted []config.Account + for _, a := range enabled { + w := a.Weight + if w < 1 { + w = 1 + } + for j := 0; j < w; j++ { + weighted = append(weighted, a) + } + } + p.accounts = weighted } -// GetNext 获取下一个可用账号(轮询) +// GetNext 获取下一个可用账号(加权轮询) func (p *AccountPool) GetNext() *config.Account { p.mu.RLock() defer p.mu.RUnlock() @@ -53,30 +65,47 @@ func (p *AccountPool) GetNext() *config.Account { now := time.Now() n := len(p.accounts) + seen := make(map[string]bool) - // 轮询查找可用账号 + // 加权轮询查找可用账号 for i := 0; i < n; i++ { idx := atomic.AddUint64(&p.currentIndex, 1) % uint64(n) acc := &p.accounts[idx] + if seen[acc.ID] { + continue + } + // 跳过冷却中的账号 if cooldown, ok := p.cooldowns[acc.ID]; ok && now.Before(cooldown) { + seen[acc.ID] = true continue } // 跳过即将过期的 Token if acc.ExpiresAt > 0 && time.Now().Unix() > acc.ExpiresAt-300 { + seen[acc.ID] = true + continue + } + + // 跳过额度已用尽的账号(适用于所有订阅类型) + if acc.UsageLimit > 0 && acc.UsageCurrent >= acc.UsageLimit { + seen[acc.ID] = true continue } return acc } - // 无可用账号,返回冷却时间最短的 + // 无可用账号,返回冷却时间最短的(排除额度用尽的) var best *config.Account var earliest time.Time for i := range p.accounts { acc := &p.accounts[i] + // 额度用尽的账号不作为 fallback + if acc.UsageLimit > 0 && acc.UsageCurrent >= acc.UsageLimit { + continue + } if cooldown, ok := p.cooldowns[acc.ID]; ok { if best == nil || cooldown.Before(earliest) { best = acc diff --git a/proxy/handler.go b/proxy/handler.go index 2559406..7e15e61 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -1540,6 +1540,8 @@ func (h *Handler) handleAdminAPI(w http.ResponseWriter, r *http.Request) { h.apiGetAccounts(w, r) case path == "/accounts" && r.Method == "POST": h.apiAddAccount(w, r) + case path == "/accounts/batch" && r.Method == "POST": + h.apiBatchAccounts(w, r) case strings.HasPrefix(path, "/accounts/") && strings.HasSuffix(path, "/refresh") && r.Method == "POST": id := strings.TrimSuffix(strings.TrimPrefix(path, "/accounts/"), "/refresh") h.apiRefreshAccount(w, r, id) @@ -1626,6 +1628,7 @@ func (h *Handler) apiGetAccounts(w http.ResponseWriter, r *http.Request) { "expiresAt": a.ExpiresAt, "hasToken": a.AccessToken != "", "machineId": a.MachineId, + "weight": a.Weight, "subscriptionType": a.SubscriptionType, "subscriptionTitle": a.SubscriptionTitle, "daysRemaining": a.DaysRemaining, @@ -1717,6 +1720,9 @@ func (h *Handler) apiUpdateAccount(w http.ResponseWriter, r *http.Request, id st if v, ok := updates["machineId"].(string); ok { existing.MachineId = v } + if v, ok := updates["weight"].(float64); ok { + existing.Weight = int(v) + } if err := config.UpdateAccount(id, *existing); err != nil { w.WriteHeader(500) @@ -1728,6 +1734,95 @@ func (h *Handler) apiUpdateAccount(w http.ResponseWriter, r *http.Request, id st json.NewEncoder(w).Encode(map[string]bool{"success": true}) } +// apiBatchAccounts 批量操作账号(启用/禁用/刷新) +func (h *Handler) apiBatchAccounts(w http.ResponseWriter, r *http.Request) { + var req struct { + IDs []string `json:"ids"` + Action string `json:"action"` // "enable", "disable", "refresh" + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "Invalid JSON"}) + return + } + if len(req.IDs) == 0 { + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "No account IDs provided"}) + return + } + + switch req.Action { + case "enable", "disable": + enabled := req.Action == "enable" + accounts := config.GetAccounts() + idSet := make(map[string]bool) + for _, id := range req.IDs { + idSet[id] = true + } + for _, a := range accounts { + if idSet[a.ID] { + a.Enabled = enabled + if enabled && a.BanStatus != "" && a.BanStatus != "ACTIVE" { + a.BanStatus = "ACTIVE" + a.BanReason = "" + a.BanTime = 0 + } + config.UpdateAccount(a.ID, a) + } + } + h.pool.Reload() + json.NewEncoder(w).Encode(map[string]interface{}{"success": true, "count": len(req.IDs)}) + + case "refresh": + successCount := 0 + failCount := 0 + for _, id := range req.IDs { + accounts := config.GetAccounts() + var account *config.Account + for i := range accounts { + if accounts[i].ID == id { + account = &accounts[i] + break + } + } + if account == nil { + failCount++ + continue + } + // 刷新 token + if account.RefreshToken != "" { + if newAccess, newRefresh, newExpires, err := auth.RefreshToken(account); err == nil { + account.AccessToken = newAccess + if newRefresh != "" { + account.RefreshToken = newRefresh + } + account.ExpiresAt = newExpires + config.UpdateAccountToken(id, newAccess, newRefresh, newExpires) + h.pool.UpdateToken(id, newAccess, newRefresh, newExpires) + } + } + // 刷新账户信息 + info, err := RefreshAccountInfo(account) + if err != nil { + failCount++ + continue + } + config.UpdateAccountInfo(id, *info) + successCount++ + } + h.pool.Reload() + json.NewEncoder(w).Encode(map[string]interface{}{ + "success": true, + "refreshed": successCount, + "failed": failCount, + }) + + default: + w.WriteHeader(400) + json.NewEncoder(w).Encode(map[string]string{"error": "Invalid action: " + req.Action}) + } +} + func (h *Handler) apiStartIamSso(w http.ResponseWriter, r *http.Request) { var req struct { StartUrl string `json:"startUrl"` diff --git a/proxy/kiro.go b/proxy/kiro.go index 1a6f53a..a58eff8 100644 --- a/proxy/kiro.go +++ b/proxy/kiro.go @@ -16,7 +16,7 @@ import ( "github.com/google/uuid" ) -const KiroVersion = "0.6.18" +const KiroVersion = "0.7.45" // 双端点配置(429 时自动 fallback) type kiroEndpoint struct { @@ -168,11 +168,11 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt machineId := account.MachineId var userAgent, amzUserAgent string if machineId != "" { - userAgent = fmt.Sprintf("aws-sdk-js/1.0.18 ua/2.1 os/linux lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-%s-%s", KiroVersion, machineId) - amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.18 KiroIDE %s %s", KiroVersion, machineId) + userAgent = fmt.Sprintf("aws-sdk-js/1.0.27 ua/2.1 os/linux lang/js md/nodejs#22.21.1 api/codewhispererstreaming#1.0.27 m/E KiroIDE-%s-%s", KiroVersion, machineId) + amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.27 KiroIDE %s %s", KiroVersion, machineId) } else { - userAgent = fmt.Sprintf("aws-sdk-js/1.0.18 ua/2.1 os/linux lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-%s", KiroVersion) - amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.18 KiroIDE %s", KiroVersion) + userAgent = fmt.Sprintf("aws-sdk-js/1.0.27 ua/2.1 os/linux lang/js md/nodejs#22.21.1 api/codewhispererstreaming#1.0.27 m/E KiroIDE-%s", KiroVersion) + amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.27 KiroIDE %s", KiroVersion) } // 根据配置排序端点 @@ -195,7 +195,7 @@ func CallKiroAPI(account *config.Account, payload *KiroPayload, callback *KiroSt req.Header.Set("X-Amz-Target", ep.AmzTarget) req.Header.Set("User-Agent", userAgent) req.Header.Set("X-Amz-User-Agent", amzUserAgent) - req.Header.Set("x-amzn-kiro-agent-mode", "spec") + req.Header.Set("x-amzn-kiro-agent-mode", "vibe") req.Header.Set("x-amzn-codewhisperer-optout", "true") req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) diff --git a/proxy/kiro_api.go b/proxy/kiro_api.go index fde124d..7252182 100644 --- a/proxy/kiro_api.go +++ b/proxy/kiro_api.go @@ -12,7 +12,7 @@ import ( const ( kiroRestAPIBase = "https://codewhisperer.us-east-1.amazonaws.com" - kiroVersion = "0.6.18" + kiroVersion = "0.7.45" ) // GetUsageLimits 获取账户使用量和订阅信息 @@ -113,11 +113,11 @@ func setKiroHeaders(req *http.Request, account *config.Account) { machineId := account.MachineId var userAgent, amzUserAgent string if machineId != "" { - userAgent = fmt.Sprintf("aws-sdk-js/1.0.18 ua/2.1 os/windows lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-%s-%s", kiroVersion, machineId) - amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.18 KiroIDE %s %s", kiroVersion, machineId) + userAgent = fmt.Sprintf("aws-sdk-js/1.0.27 ua/2.1 os/linux lang/js md/nodejs#22.21.1 api/codewhispererstreaming#1.0.27 m/E KiroIDE-%s-%s", kiroVersion, machineId) + amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.27 KiroIDE %s %s", kiroVersion, machineId) } else { - userAgent = fmt.Sprintf("aws-sdk-js/1.0.18 ua/2.1 os/windows lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-%s", kiroVersion) - amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.18 KiroIDE-%s", kiroVersion) + userAgent = fmt.Sprintf("aws-sdk-js/1.0.27 ua/2.1 os/linux lang/js md/nodejs#22.21.1 api/codewhispererstreaming#1.0.27 m/E KiroIDE-%s", kiroVersion) + amzUserAgent = fmt.Sprintf("aws-sdk-js/1.0.27 KiroIDE %s", kiroVersion) } req.Header.Set("Authorization", "Bearer "+account.AccessToken) diff --git a/web/index.html b/web/index.html index a46823c..bda7341 100644 --- a/web/index.html +++ b/web/index.html @@ -532,7 +532,7 @@ .account-stats { display: grid; - grid-template-columns: repeat(4, 1fr); + grid-template-columns: repeat(5, 1fr); gap: 6px; margin-top: 10px; padding-top: 10px; @@ -932,6 +932,26 @@ +
' + t('accounts.empty') + '
'; return; } - container.innerHTML = accountsData.map(a => { + container.innerHTML = filtered.map(a => { const usagePercent = (a.usagePercent || 0) * 100; const usageClass = usagePercent > 90 ? 'critical' : usagePercent > 70 ? 'high' : ''; const trialUsagePercent = (a.trialUsagePercent || 0) * 100; const trialUsageClass = trialUsagePercent > 90 ? 'critical' : trialUsagePercent > 70 ? 'high' : ''; - return '