diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index c9dc57bb..f8a7d313 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -8,6 +8,7 @@ import ( "errors" "flag" "log" + "log/slog" "net/http" "os" "os/signal" @@ -44,7 +45,25 @@ func init() { } } +// initLogger configures the default slog handler based on gin.Mode(). +// In non-release mode, Debug level logs are enabled. +func initLogger() { + var level slog.Level + if gin.Mode() == gin.ReleaseMode { + level = slog.LevelInfo + } else { + level = slog.LevelDebug + } + handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: level, + }) + slog.SetDefault(slog.New(handler)) +} + func main() { + // Initialize slog logger based on gin mode + initLogger() + // Parse command line flags setupMode := flag.Bool("setup", false, "Run setup wizard in CLI mode") showVersion := flag.Bool("version", false, "Show version information") diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 42084b37..00a78480 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -258,8 +258,43 @@ type GatewayConfig struct { // 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义) FailoverOn400 bool `mapstructure:"failover_on_400"` + // 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限) + MaxAccountSwitches int `mapstructure:"max_account_switches"` + // Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格) + MaxAccountSwitchesGemini int `mapstructure:"max_account_switches_gemini"` + + // Antigravity 429 fallback 限流时间(分钟),解析重置时间失败时使用 + AntigravityFallbackCooldownMinutes int `mapstructure:"antigravity_fallback_cooldown_minutes"` + // Scheduling: 账号调度相关配置 Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"` + + // TLSFingerprint: TLS指纹伪装配置 + TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"` +} + +// TLSFingerprintConfig TLS指纹伪装配置 +// 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端 +type TLSFingerprintConfig struct { + // Enabled: 是否全局启用TLS指纹功能 + Enabled bool `mapstructure:"enabled"` + // Profiles: 预定义的TLS指纹配置模板 + // key 为模板名称,如 "claude_cli_v2", "chrome_120" 等 + Profiles map[string]TLSProfileConfig `mapstructure:"profiles"` +} + +// TLSProfileConfig 单个TLS指纹模板的配置 +type TLSProfileConfig struct { + // Name: 模板显示名称 + Name string `mapstructure:"name"` + // EnableGREASE: 是否启用GREASE扩展(Chrome使用,Node.js不使用) + EnableGREASE bool `mapstructure:"enable_grease"` + // CipherSuites: TLS加密套件列表(空则使用内置默认值) + CipherSuites []uint16 `mapstructure:"cipher_suites"` + // Curves: 椭圆曲线列表(空则使用内置默认值) + Curves []uint16 `mapstructure:"curves"` + // PointFormats: 点格式列表(空则使用内置默认值) + PointFormats []uint8 `mapstructure:"point_formats"` } // GatewaySchedulingConfig accounts scheduling configuration. @@ -272,6 +307,9 @@ type GatewaySchedulingConfig struct { FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"` FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"` + // 兜底层账户选择策略: "last_used"(按最后使用时间排序,默认) 或 "random"(随机) + FallbackSelectionMode string `mapstructure:"fallback_selection_mode"` + // 负载计算 LoadBatchEnabled bool `mapstructure:"load_batch_enabled"` @@ -781,6 +819,9 @@ func setDefaults() { viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048) viper.SetDefault("gateway.inject_beta_for_apikey", false) viper.SetDefault("gateway.failover_on_400", false) + viper.SetDefault("gateway.max_account_switches", 10) + viper.SetDefault("gateway.max_account_switches_gemini", 3) + viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) // HTTP 上游连接池配置(针对 5000+ 并发用户优化) @@ -793,11 +834,12 @@ func setDefaults() { viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求) viper.SetDefault("gateway.stream_data_interval_timeout", 180) viper.SetDefault("gateway.stream_keepalive_interval", 10) - viper.SetDefault("gateway.max_line_size", 10*1024*1024) + viper.SetDefault("gateway.max_line_size", 40*1024*1024) viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second) viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second) viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) + viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used") viper.SetDefault("gateway.scheduling.load_batch_enabled", true) viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) viper.SetDefault("gateway.scheduling.db_fallback_enabled", true) @@ -809,6 +851,8 @@ func setDefaults() { viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3) viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000) viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300) + // TLS指纹伪装配置(默认关闭,需要账号级别单独启用) + viper.SetDefault("gateway.tls_fingerprint.enabled", true) viper.SetDefault("concurrency.ping_interval", 10) // TokenRefresh diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 33c91dae..10a53f56 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -173,6 +173,7 @@ func (h *AccountHandler) List(c *gin.Context) { // 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能) windowCostAccountIDs := make([]int64, 0) sessionLimitAccountIDs := make([]int64, 0) + sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置 for i := range accounts { acc := &accounts[i] if acc.IsAnthropicOAuthOrSetupToken() { @@ -181,6 +182,7 @@ func (h *AccountHandler) List(c *gin.Context) { } if acc.GetMaxSessions() > 0 { sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID) + sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute } } } @@ -189,9 +191,9 @@ func (h *AccountHandler) List(c *gin.Context) { var windowCosts map[int64]float64 var activeSessions map[int64]int - // 获取活跃会话数(批量查询) + // 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置) if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil { - activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs) + activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts) if activeSessions == nil { activeSessions = make(map[int64]int) } @@ -211,12 +213,8 @@ func (h *AccountHandler) List(c *gin.Context) { } accCopy := acc // 闭包捕获 g.Go(func() error { - var startTime time.Time - if accCopy.SessionWindowStart != nil { - startTime = *accCopy.SessionWindowStart - } else { - startTime = time.Now().Add(-5 * time.Hour) - } + // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况) + startTime := accCopy.GetCurrentWindowStartTime() stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime) if err == nil && stats != nil { mu.Lock() @@ -545,6 +543,36 @@ func (h *AccountHandler) Refresh(c *gin.Context) { newCredentials[k] = v } } + + // 如果 project_id 获取失败,先更新凭证,再标记账户为 error + if tokenInfo.ProjectIDMissing { + // 先更新凭证 + _, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{ + Credentials: newCredentials, + }) + if updateErr != nil { + response.InternalError(c, "Failed to update credentials: "+updateErr.Error()) + return + } + // 标记账户为 error + if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id,可能无法使用Antigravity"); setErr != nil { + response.InternalError(c, "Failed to set account error: "+setErr.Error()) + return + } + response.Success(c, gin.H{ + "message": "Token refreshed but project_id is missing, account marked as error", + "warning": "missing_project_id", + }) + return + } + + // 成功获取到 project_id,如果之前是 missing_project_id 错误则清除 + if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") { + if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil { + response.InternalError(c, "Failed to clear account error: "+clearErr.Error()) + return + } + } } else { // Use Anthropic/Claude OAuth service to refresh token tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account) diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 457d52fc..b820a3fb 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -200,6 +200,10 @@ func (s *stubAdminService) ClearAccountError(ctx context.Context, id int64) (*se return &account, nil } +func (s *stubAdminService) SetAccountError(ctx context.Context, id int64, errorMsg string) error { + return nil +} + func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*service.Account, error) { account := service.Account{ID: id, Name: "account", Status: service.StatusActive, Schedulable: schedulable} return &account, nil diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index d8f10e6c..66b86ea0 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -161,6 +161,16 @@ func AccountFromServiceShallow(a *service.Account) *Account { if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 { out.SessionIdleTimeoutMin = &idleTimeout } + // TLS指纹伪装开关 + if a.IsTLSFingerprintEnabled() { + enabled := true + out.EnableTLSFingerprint = &enabled + } + // 会话ID伪装开关 + if a.IsSessionIDMaskingEnabled() { + enabled := true + out.EnableSessionIDMasking = &enabled + } } return out diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index ae9da254..4247dcbf 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -112,6 +112,15 @@ type Account struct { MaxSessions *int `json:"max_sessions,omitempty"` SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"` + // TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效) + // 从 extra 字段提取,方便前端显示和编辑 + EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"` + + // 会话ID伪装(仅 Anthropic OAuth/SetupToken 账号有效) + // 启用后将在15分钟内固定 metadata.user_id 中的 session ID + // 从 extra 字段提取,方便前端显示和编辑 + EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"` + Proxy *Proxy `json:"proxy,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 8c32be21..6c8d9ebe 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -31,6 +31,8 @@ type GatewayHandler struct { userService *service.UserService billingCacheService *service.BillingCacheService concurrencyHelper *ConcurrencyHelper + maxAccountSwitches int + maxAccountSwitchesGemini int } // NewGatewayHandler creates a new GatewayHandler @@ -44,8 +46,16 @@ func NewGatewayHandler( cfg *config.Config, ) *GatewayHandler { pingInterval := time.Duration(0) + maxAccountSwitches := 10 + maxAccountSwitchesGemini := 3 if cfg != nil { pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second + if cfg.Gateway.MaxAccountSwitches > 0 { + maxAccountSwitches = cfg.Gateway.MaxAccountSwitches + } + if cfg.Gateway.MaxAccountSwitchesGemini > 0 { + maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini + } } return &GatewayHandler{ gatewayService: gatewayService, @@ -54,6 +64,8 @@ func NewGatewayHandler( userService: userService, billingCacheService: billingCacheService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), + maxAccountSwitches: maxAccountSwitches, + maxAccountSwitchesGemini: maxAccountSwitchesGemini, } } @@ -179,7 +191,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } if platform == service.PlatformGemini { - const maxAccountSwitches = 3 + maxAccountSwitches := h.maxAccountSwitchesGemini switchCount := 0 failedAccountIDs := make(map[int64]struct{}) lastFailoverStatus := 0 @@ -313,7 +325,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } } - const maxAccountSwitches = 10 + maxAccountSwitches := h.maxAccountSwitches switchCount := 0 failedAccountIDs := make(map[int64]struct{}) lastFailoverStatus := 0 diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index ec943e61..c7646b38 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -220,7 +220,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if sessionHash != "" { sessionKey = "gemini:" + sessionHash } - const maxAccountSwitches = 3 + maxAccountSwitches := h.maxAccountSwitchesGemini switchCount := 0 failedAccountIDs := make(map[int64]struct{}) lastFailoverStatus := 0 diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 68e67656..4c9dd8b9 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -25,6 +25,7 @@ type OpenAIGatewayHandler struct { gatewayService *service.OpenAIGatewayService billingCacheService *service.BillingCacheService concurrencyHelper *ConcurrencyHelper + maxAccountSwitches int } // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler @@ -35,13 +36,18 @@ func NewOpenAIGatewayHandler( cfg *config.Config, ) *OpenAIGatewayHandler { pingInterval := time.Duration(0) + maxAccountSwitches := 3 if cfg != nil { pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second + if cfg.Gateway.MaxAccountSwitches > 0 { + maxAccountSwitches = cfg.Gateway.MaxAccountSwitches + } } return &OpenAIGatewayHandler{ gatewayService: gatewayService, billingCacheService: billingCacheService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + maxAccountSwitches: maxAccountSwitches, } } @@ -189,7 +195,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // Generate session hash (header first; fallback to prompt_cache_key) sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody) - const maxAccountSwitches = 3 + maxAccountSwitches := h.maxAccountSwitches switchCount := 0 failedAccountIDs := make(map[int64]struct{}) lastFailoverStatus := 0 diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index 1248be95..a6279b11 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -16,15 +16,6 @@ import ( "time" ) -// resolveHost 从 URL 解析 host -func resolveHost(urlStr string) string { - parsed, err := url.Parse(urlStr) - if err != nil { - return "" - } - return parsed.Host -} - // NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点) func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) { // 构建 URL,流式请求添加 ?alt=sse 参数 @@ -39,23 +30,11 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri return nil, err } - // 基础 Headers + // 基础 Headers(与 Antigravity-Manager 保持一致,只设置这 3 个) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("User-Agent", UserAgent) - // Accept Header 根据请求类型设置 - if isStream { - req.Header.Set("Accept", "text/event-stream") - } else { - req.Header.Set("Accept", "application/json") - } - - // 显式设置 Host Header - if host := resolveHost(apiURL); host != "" { - req.Host = host - } - return req, nil } @@ -195,12 +174,15 @@ func isConnectionError(err error) bool { } // shouldFallbackToNextURL 判断是否应切换到下一个 URL -// 仅连接错误和 HTTP 429 触发 URL 降级 +// 与 Antigravity-Manager 保持一致:连接错误、429、408、404、5xx 触发 URL 降级 func shouldFallbackToNextURL(err error, statusCode int) bool { if isConnectionError(err) { return true } - return statusCode == http.StatusTooManyRequests + return statusCode == http.StatusTooManyRequests || + statusCode == http.StatusRequestTimeout || + statusCode == http.StatusNotFound || + statusCode >= 500 } // ExchangeCode 用 authorization code 交换 token @@ -321,11 +303,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC return nil, nil, fmt.Errorf("序列化请求失败: %w", err) } - // 获取可用的 URL 列表 - availableURLs := DefaultURLAvailability.GetAvailableURLs() - if len(availableURLs) == 0 { - availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有 - } + // 固定顺序:prod -> daily + availableURLs := BaseURLs var lastErr error for urlIdx, baseURL := range availableURLs { @@ -343,7 +322,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC if err != nil { lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err) if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { - DefaultURLAvailability.MarkUnavailable(baseURL) log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) continue } @@ -358,7 +336,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC // 检查是否需要 URL 降级 if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { - DefaultURLAvailability.MarkUnavailable(baseURL) log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) continue } @@ -376,6 +353,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC var rawResp map[string]any _ = json.Unmarshal(respBodyBytes, &rawResp) + // 标记成功的 URL,下次优先使用 + DefaultURLAvailability.MarkSuccess(baseURL) return &loadResp, rawResp, nil } @@ -412,11 +391,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI return nil, nil, fmt.Errorf("序列化请求失败: %w", err) } - // 获取可用的 URL 列表 - availableURLs := DefaultURLAvailability.GetAvailableURLs() - if len(availableURLs) == 0 { - availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有 - } + // 固定顺序:prod -> daily + availableURLs := BaseURLs var lastErr error for urlIdx, baseURL := range availableURLs { @@ -434,7 +410,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI if err != nil { lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err) if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { - DefaultURLAvailability.MarkUnavailable(baseURL) log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) continue } @@ -449,7 +424,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI // 检查是否需要 URL 降级 if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { - DefaultURLAvailability.MarkUnavailable(baseURL) log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) continue } @@ -467,6 +441,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI var rawResp map[string]any _ = json.Unmarshal(respBodyBytes, &rawResp) + // 标记成功的 URL,下次优先使用 + DefaultURLAvailability.MarkSuccess(baseURL) return &modelsResp, rawResp, nil } diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go index f688332f..c1cc998c 100644 --- a/backend/internal/pkg/antigravity/gemini_types.go +++ b/backend/internal/pkg/antigravity/gemini_types.go @@ -143,9 +143,10 @@ type GeminiResponse struct { // GeminiCandidate Gemini 候选响应 type GeminiCandidate struct { - Content *GeminiContent `json:"content,omitempty"` - FinishReason string `json:"finishReason,omitempty"` - Index int `json:"index,omitempty"` + Content *GeminiContent `json:"content,omitempty"` + FinishReason string `json:"finishReason,omitempty"` + Index int `json:"index,omitempty"` + GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"` } // GeminiUsageMetadata Gemini 用量元数据 @@ -156,6 +157,23 @@ type GeminiUsageMetadata struct { TotalTokenCount int `json:"totalTokenCount,omitempty"` } +// GeminiGroundingMetadata Gemini grounding 元数据(Web Search) +type GeminiGroundingMetadata struct { + WebSearchQueries []string `json:"webSearchQueries,omitempty"` + GroundingChunks []GeminiGroundingChunk `json:"groundingChunks,omitempty"` +} + +// GeminiGroundingChunk Gemini grounding chunk +type GeminiGroundingChunk struct { + Web *GeminiGroundingWeb `json:"web,omitempty"` +} + +// GeminiGroundingWeb Gemini grounding web 信息 +type GeminiGroundingWeb struct { + Title string `json:"title,omitempty"` + URI string `json:"uri,omitempty"` +} + // DefaultSafetySettings 默认安全设置(关闭所有过滤) var DefaultSafetySettings = []GeminiSafetySetting{ {Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"}, diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index 736c45df..ee2a6c1a 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -32,8 +32,8 @@ const ( "https://www.googleapis.com/auth/cclog " + "https://www.googleapis.com/auth/experimentsandconfigs" - // User-Agent(模拟官方客户端) - UserAgent = "antigravity/1.104.0 darwin/arm64" + // User-Agent(与 Antigravity-Manager 保持一致) + UserAgent = "antigravity/1.11.9 windows/amd64" // Session 过期时间 SessionTTL = 30 * time.Minute @@ -42,22 +42,21 @@ const ( URLAvailabilityTTL = 5 * time.Minute ) -// BaseURLs 定义 Antigravity API 端点,按优先级排序 -// fallback 顺序: sandbox → daily → prod +// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致) var BaseURLs = []string{ - "https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox - "https://daily-cloudcode-pa.googleapis.com", // daily - "https://cloudcode-pa.googleapis.com", // prod + "https://cloudcode-pa.googleapis.com", // prod (优先) + "https://daily-cloudcode-pa.sandbox.googleapis.com", // daily sandbox (备用) } // BaseURL 默认 URL(保持向后兼容) var BaseURL = BaseURLs[0] -// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复) +// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级) type URLAvailability struct { mu sync.RWMutex unavailable map[string]time.Time // URL -> 恢复时间 ttl time.Duration + lastSuccess string // 最近成功请求的 URL,优先使用 } // DefaultURLAvailability 全局 URL 可用性管理器 @@ -78,6 +77,15 @@ func (u *URLAvailability) MarkUnavailable(url string) { u.unavailable[url] = time.Now().Add(u.ttl) } +// MarkSuccess 标记 URL 请求成功,将其设为优先使用 +func (u *URLAvailability) MarkSuccess(url string) { + u.mu.Lock() + defer u.mu.Unlock() + u.lastSuccess = url + // 成功后清除该 URL 的不可用标记 + delete(u.unavailable, url) +} + // IsAvailable 检查 URL 是否可用 func (u *URLAvailability) IsAvailable(url string) bool { u.mu.RLock() @@ -89,14 +97,29 @@ func (u *URLAvailability) IsAvailable(url string) bool { return time.Now().After(expiry) } -// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序) +// GetAvailableURLs 返回可用的 URL 列表 +// 最近成功的 URL 优先,其他按默认顺序 func (u *URLAvailability) GetAvailableURLs() []string { u.mu.RLock() defer u.mu.RUnlock() now := time.Now() result := make([]string, 0, len(BaseURLs)) + + // 如果有最近成功的 URL 且可用,放在最前面 + if u.lastSuccess != "" { + expiry, exists := u.unavailable[u.lastSuccess] + if !exists || now.After(expiry) { + result = append(result, u.lastSuccess) + } + } + + // 添加其他可用的 URL(按默认顺序) for _, url := range BaseURLs { + // 跳过已添加的 lastSuccess + if url == u.lastSuccess { + continue + } expiry, exists := u.unavailable[url] if !exists || now.After(expiry) { result = append(result, url) @@ -240,24 +263,3 @@ func BuildAuthorizationURL(state, codeChallenge string) string { return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()) } - -// GenerateMockProjectID 生成随机 project_id(当 API 不返回时使用) -// 格式:{形容词}-{名词}-{5位随机字符} -func GenerateMockProjectID() string { - adjectives := []string{"useful", "bright", "swift", "calm", "bold"} - nouns := []string{"fuze", "wave", "spark", "flow", "core"} - - randBytes, _ := GenerateRandomBytes(7) - - adj := adjectives[int(randBytes[0])%len(adjectives)] - noun := nouns[int(randBytes[1])%len(nouns)] - - // 生成 5 位随机字符(a-z0-9) - const charset = "abcdefghijklmnopqrstuvwxyz0123456789" - suffix := make([]byte, 5) - for i := 0; i < 5; i++ { - suffix[i] = charset[int(randBytes[i+2])%len(charset)] - } - - return fmt.Sprintf("%s-%s-%s", adj, noun, string(suffix)) -} diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index a8474576..637a4ea8 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -54,6 +54,9 @@ func DefaultTransformOptions() TransformOptions { } } +// webSearchFallbackModel web_search 请求使用的降级模型 +const webSearchFallbackModel = "gemini-2.5-flash" + // TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式 func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) { return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions()) @@ -64,12 +67,23 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map // 用于存储 tool_use id -> name 映射 toolIDToName := make(map[string]string) + // 检测是否有 web_search 工具 + hasWebSearchTool := hasWebSearchTool(claudeReq.Tools) + requestType := "agent" + targetModel := mappedModel + if hasWebSearchTool { + requestType = "web_search" + if targetModel != webSearchFallbackModel { + targetModel = webSearchFallbackModel + } + } + // 检测是否启用 thinking isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" // 只有 Gemini 模型支持 dummy thought workaround // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures - allowDummyThought := strings.HasPrefix(mappedModel, "gemini-") + allowDummyThought := strings.HasPrefix(targetModel, "gemini-") // 1. 构建 contents contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought) @@ -78,7 +92,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map } // 2. 构建 systemInstruction - systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts) + systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts, claudeReq.Tools) // 3. 构建 generationConfig reqForConfig := claudeReq @@ -89,6 +103,11 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map reqCopy.Thinking = nil reqForConfig = &reqCopy } + if targetModel != "" && targetModel != reqForConfig.Model { + reqCopy := *reqForConfig + reqCopy.Model = targetModel + reqForConfig = &reqCopy + } generationConfig := buildGenerationConfig(reqForConfig) // 4. 构建 tools @@ -127,8 +146,8 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map Project: projectID, RequestID: "agent-" + uuid.New().String(), UserAgent: "antigravity", // 固定值,与官方客户端一致 - RequestType: "agent", - Model: mappedModel, + RequestType: requestType, + Model: targetModel, Request: innerRequest, } @@ -154,8 +173,40 @@ func GetDefaultIdentityPatch() string { return antigravityIdentity } -// buildSystemInstruction 构建 systemInstruction -func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions) *GeminiContent { +// mcpXMLProtocol MCP XML 工具调用协议(与 Antigravity-Manager 保持一致) +const mcpXMLProtocol = ` +==== MCP XML 工具调用协议 (Workaround) ==== +当你需要调用名称以 ` + "`mcp__`" + ` 开头的 MCP 工具时: +1) 优先尝试 XML 格式调用:输出 ` + "`{\"arg\":\"value\"}`" + `。 +2) 必须直接输出 XML 块,无需 markdown 包装,内容为 JSON 格式的入参。 +3) 这种方式具有更高的连通性和容错性,适用于大型结果返回场景。 +===========================================` + +// hasMCPTools 检测是否有 mcp__ 前缀的工具 +func hasMCPTools(tools []ClaudeTool) bool { + for _, tool := range tools { + if strings.HasPrefix(tool.Name, "mcp__") { + return true + } + } + return false +} + +// filterOpenCodePrompt 过滤 OpenCode 默认提示词,只保留用户自定义指令 +func filterOpenCodePrompt(text string) string { + if !strings.Contains(text, "You are an interactive CLI tool") { + return text + } + // 提取 "Instructions from:" 及之后的部分 + if idx := strings.Index(text, "Instructions from:"); idx >= 0 { + return text[idx:] + } + // 如果没有自定义指令,返回空 + return "" +} + +// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致) +func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent { var parts []GeminiPart // 先解析用户的 system prompt,检测是否已包含 Antigravity identity @@ -167,10 +218,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans var sysStr string if err := json.Unmarshal(system, &sysStr); err == nil { if strings.TrimSpace(sysStr) != "" { - userSystemParts = append(userSystemParts, GeminiPart{Text: sysStr}) if strings.Contains(sysStr, "You are Antigravity") { userHasAntigravityIdentity = true } + // 过滤 OpenCode 默认提示词 + filtered := filterOpenCodePrompt(sysStr) + if filtered != "" { + userSystemParts = append(userSystemParts, GeminiPart{Text: filtered}) + } } } else { // 尝试解析为数组 @@ -178,10 +233,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans if err := json.Unmarshal(system, &sysBlocks); err == nil { for _, block := range sysBlocks { if block.Type == "text" && strings.TrimSpace(block.Text) != "" { - userSystemParts = append(userSystemParts, GeminiPart{Text: block.Text}) if strings.Contains(block.Text, "You are Antigravity") { userHasAntigravityIdentity = true } + // 过滤 OpenCode 默认提示词 + filtered := filterOpenCodePrompt(block.Text) + if filtered != "" { + userSystemParts = append(userSystemParts, GeminiPart{Text: filtered}) + } } } } @@ -200,6 +259,16 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans // 添加用户的 system prompt parts = append(parts, userSystemParts...) + // 检测是否有 MCP 工具,如有则注入 XML 调用协议 + if hasMCPTools(tools) { + parts = append(parts, GeminiPart{Text: mcpXMLProtocol}) + } + + // 如果用户没有提供 Antigravity 身份,添加结束标记 + if !userHasAntigravityIdentity { + parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"}) + } + if len(parts) == 0 { return nil } @@ -429,6 +498,11 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { StopSequences: DefaultStopSequences, } + // 如果请求中指定了 MaxTokens,使用请求值 + if req.MaxTokens > 0 { + config.MaxOutputTokens = req.MaxTokens + } + // Thinking 配置 if req.Thinking != nil && req.Thinking.Type == "enabled" { config.ThinkingConfig = &GeminiThinkingConfig{ @@ -458,37 +532,43 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { return config } +func hasWebSearchTool(tools []ClaudeTool) bool { + for _, tool := range tools { + if isWebSearchTool(tool) { + return true + } + } + return false +} + +func isWebSearchTool(tool ClaudeTool) bool { + if strings.HasPrefix(tool.Type, "web_search") || tool.Type == "google_search" { + return true + } + + name := strings.TrimSpace(tool.Name) + switch name { + case "web_search", "google_search", "web_search_20250305": + return true + default: + return false + } +} + // buildTools 构建 tools func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { if len(tools) == 0 { return nil } - // 检查是否有 web_search 工具 - hasWebSearch := false - for _, tool := range tools { - if tool.Name == "web_search" { - hasWebSearch = true - break - } - } - - if hasWebSearch { - // Web Search 工具映射 - return []GeminiToolDeclaration{{ - GoogleSearch: &GeminiGoogleSearch{ - EnhancedContent: &GeminiEnhancedContent{ - ImageSearch: &GeminiImageSearch{ - MaxResultCount: 5, - }, - }, - }, - }} - } + hasWebSearch := hasWebSearchTool(tools) // 普通工具 var funcDecls []GeminiFunctionDecl for _, tool := range tools { + if isWebSearchTool(tool) { + continue + } // 跳过无效工具名称 if strings.TrimSpace(tool.Name) == "" { log.Printf("Warning: skipping tool with empty name") @@ -531,7 +611,20 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { } if len(funcDecls) == 0 { - return nil + if !hasWebSearch { + return nil + } + + // Web Search 工具映射 + return []GeminiToolDeclaration{{ + GoogleSearch: &GeminiGoogleSearch{ + EnhancedContent: &GeminiEnhancedContent{ + ImageSearch: &GeminiImageSearch{ + MaxResultCount: 5, + }, + }, + }, + }} } return []GeminiToolDeclaration{{ diff --git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go index cd7f5f80..04424c03 100644 --- a/backend/internal/pkg/antigravity/response_transformer.go +++ b/backend/internal/pkg/antigravity/response_transformer.go @@ -3,6 +3,7 @@ package antigravity import ( "encoding/json" "fmt" + "strings" ) // TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式) @@ -63,6 +64,12 @@ func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID, p.processPart(&part) } + if len(geminiResp.Candidates) > 0 { + if grounding := geminiResp.Candidates[0].GroundingMetadata; grounding != nil { + p.processGrounding(grounding) + } + } + // 刷新剩余内容 p.flushThinking() p.flushText() @@ -190,6 +197,18 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) { } } +func (p *NonStreamingProcessor) processGrounding(grounding *GeminiGroundingMetadata) { + groundingText := buildGroundingText(grounding) + if groundingText == "" { + return + } + + p.flushThinking() + p.flushText() + p.textBuilder += groundingText + p.flushText() +} + // flushText 刷新 text builder func (p *NonStreamingProcessor) flushText() { if p.textBuilder == "" { @@ -262,6 +281,44 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon } } +func buildGroundingText(grounding *GeminiGroundingMetadata) string { + if grounding == nil { + return "" + } + + var builder strings.Builder + + if len(grounding.WebSearchQueries) > 0 { + _, _ = builder.WriteString("\n\n---\nWeb search queries: ") + _, _ = builder.WriteString(strings.Join(grounding.WebSearchQueries, ", ")) + } + + if len(grounding.GroundingChunks) > 0 { + var links []string + for i, chunk := range grounding.GroundingChunks { + if chunk.Web == nil { + continue + } + title := strings.TrimSpace(chunk.Web.Title) + if title == "" { + title = "Source" + } + uri := strings.TrimSpace(chunk.Web.URI) + if uri == "" { + uri = "#" + } + links = append(links, fmt.Sprintf("[%d] [%s](%s)", i+1, title, uri)) + } + + if len(links) > 0 { + _, _ = builder.WriteString("\n\nSources:\n") + _, _ = builder.WriteString(strings.Join(links, "\n")) + } + } + + return builder.String() +} + // generateRandomID 生成随机 ID func generateRandomID() string { const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go index 9fe68a11..da0c6f97 100644 --- a/backend/internal/pkg/antigravity/stream_transformer.go +++ b/backend/internal/pkg/antigravity/stream_transformer.go @@ -27,6 +27,8 @@ type StreamingProcessor struct { pendingSignature string trailingSignature string originalModel string + webSearchQueries []string + groundingChunks []GeminiGroundingChunk // 累计 usage inputTokens int @@ -93,6 +95,10 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { } } + if len(geminiResp.Candidates) > 0 { + p.captureGrounding(geminiResp.Candidates[0].GroundingMetadata) + } + // 检查是否结束 if len(geminiResp.Candidates) > 0 { finishReason := geminiResp.Candidates[0].FinishReason @@ -200,6 +206,20 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte { return result.Bytes() } +func (p *StreamingProcessor) captureGrounding(grounding *GeminiGroundingMetadata) { + if grounding == nil { + return + } + + if len(grounding.WebSearchQueries) > 0 && len(p.webSearchQueries) == 0 { + p.webSearchQueries = append([]string(nil), grounding.WebSearchQueries...) + } + + if len(grounding.GroundingChunks) > 0 && len(p.groundingChunks) == 0 { + p.groundingChunks = append([]GeminiGroundingChunk(nil), grounding.GroundingChunks...) + } +} + // processThinking 处理 thinking func (p *StreamingProcessor) processThinking(text, signature string) []byte { var result bytes.Buffer @@ -417,6 +437,23 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte { p.trailingSignature = "" } + if len(p.webSearchQueries) > 0 || len(p.groundingChunks) > 0 { + groundingText := buildGroundingText(&GeminiGroundingMetadata{ + WebSearchQueries: p.webSearchQueries, + GroundingChunks: p.groundingChunks, + }) + if groundingText != "" { + _, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{ + "type": "text", + "text": "", + })) + _, _ = result.Write(p.emitDelta("text_delta", map[string]any{ + "text": groundingText, + })) + _, _ = result.Write(p.endBlock()) + } + } + // 确定 stop_reason stopReason := "end_turn" if p.usedTool { diff --git a/backend/internal/pkg/response/response.go b/backend/internal/pkg/response/response.go index a92ff9e8..43fe12d4 100644 --- a/backend/internal/pkg/response/response.go +++ b/backend/internal/pkg/response/response.go @@ -162,11 +162,11 @@ func ParsePagination(c *gin.Context) (page, pageSize int) { // 支持 page_size 和 limit 两种参数名 if ps := c.Query("page_size"); ps != "" { - if val, err := parseInt(ps); err == nil && val > 0 && val <= 100 { + if val, err := parseInt(ps); err == nil && val > 0 && val <= 1000 { pageSize = val } } else if l := c.Query("limit"); l != "" { - if val, err := parseInt(l); err == nil && val > 0 && val <= 100 { + if val, err := parseInt(l); err == nil && val > 0 && val <= 1000 { pageSize = val } } diff --git a/backend/internal/pkg/tlsfingerprint/dialer.go b/backend/internal/pkg/tlsfingerprint/dialer.go new file mode 100644 index 00000000..42510986 --- /dev/null +++ b/backend/internal/pkg/tlsfingerprint/dialer.go @@ -0,0 +1,568 @@ +// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients. +// It uses the utls library to create TLS connections that mimic Node.js/Claude Code clients. +package tlsfingerprint + +import ( + "bufio" + "context" + "encoding/base64" + "fmt" + "log/slog" + "net" + "net/http" + "net/url" + + utls "github.com/refraction-networking/utls" + "golang.org/x/net/proxy" +) + +// Profile contains TLS fingerprint configuration. +type Profile struct { + Name string // Profile name for identification + CipherSuites []uint16 + Curves []uint16 + PointFormats []uint8 + EnableGREASE bool +} + +// Dialer creates TLS connections with custom fingerprints. +type Dialer struct { + profile *Profile + baseDialer func(ctx context.Context, network, addr string) (net.Conn, error) +} + +// HTTPProxyDialer creates TLS connections through HTTP/HTTPS proxies with custom fingerprints. +// It handles the CONNECT tunnel establishment before performing TLS handshake. +type HTTPProxyDialer struct { + profile *Profile + proxyURL *url.URL +} + +// SOCKS5ProxyDialer creates TLS connections through SOCKS5 proxies with custom fingerprints. +// It uses golang.org/x/net/proxy to establish the SOCKS5 tunnel. +type SOCKS5ProxyDialer struct { + profile *Profile + proxyURL *url.URL +} + +// Default TLS fingerprint values captured from Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x) +// Captured using: tshark -i lo -f "tcp port 8443" -Y "tls.handshake.type == 1" -V +// JA3 Hash: 1a28e69016765d92e3b381168d68922c +// +// Note: JA3/JA4 may have slight variations due to: +// - Session ticket presence/absence +// - Extension negotiation state +var ( + // defaultCipherSuites contains all 59 cipher suites from Claude CLI + // Order is critical for JA3 fingerprint matching + defaultCipherSuites = []uint16{ + // TLS 1.3 cipher suites (MUST be first) + 0x1302, // TLS_AES_256_GCM_SHA384 + 0x1303, // TLS_CHACHA20_POLY1305_SHA256 + 0x1301, // TLS_AES_128_GCM_SHA256 + + // ECDHE + AES-GCM + 0xc02f, // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + 0xc02b, // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 + 0xc030, // TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 + 0xc02c, // TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 + + // DHE + AES-GCM + 0x009e, // TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 + + // ECDHE/DHE + AES-CBC-SHA256/384 + 0xc027, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 + 0x0067, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 + 0xc028, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 + 0x006b, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 + + // DHE-DSS/RSA + AES-GCM + 0x00a3, // TLS_DHE_DSS_WITH_AES_256_GCM_SHA384 + 0x009f, // TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 + + // ChaCha20-Poly1305 + 0xcca9, // TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 + 0xcca8, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + 0xccaa, // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + + // AES-CCM (256-bit) + 0xc0af, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8 + 0xc0ad, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM + 0xc0a3, // TLS_DHE_RSA_WITH_AES_256_CCM_8 + 0xc09f, // TLS_DHE_RSA_WITH_AES_256_CCM + + // ARIA (256-bit) + 0xc05d, // TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384 + 0xc061, // TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384 + 0xc057, // TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384 + 0xc053, // TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384 + + // DHE-DSS + AES-GCM (128-bit) + 0x00a2, // TLS_DHE_DSS_WITH_AES_128_GCM_SHA256 + + // AES-CCM (128-bit) + 0xc0ae, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 + 0xc0ac, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM + 0xc0a2, // TLS_DHE_RSA_WITH_AES_128_CCM_8 + 0xc09e, // TLS_DHE_RSA_WITH_AES_128_CCM + + // ARIA (128-bit) + 0xc05c, // TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256 + 0xc060, // TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256 + 0xc056, // TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256 + 0xc052, // TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256 + + // ECDHE/DHE + AES-CBC-SHA384/256 (more) + 0xc024, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 + 0x006a, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 + 0xc023, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 + 0x0040, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 + + // ECDHE/DHE + AES-CBC-SHA (legacy) + 0xc00a, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA + 0xc014, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA + 0x0039, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA + 0x0038, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA + 0xc009, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA + 0xc013, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA + 0x0033, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA + 0x0032, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA + + // RSA + AES-GCM/CCM/ARIA (non-PFS, 256-bit) + 0x009d, // TLS_RSA_WITH_AES_256_GCM_SHA384 + 0xc0a1, // TLS_RSA_WITH_AES_256_CCM_8 + 0xc09d, // TLS_RSA_WITH_AES_256_CCM + 0xc051, // TLS_RSA_WITH_ARIA_256_GCM_SHA384 + + // RSA + AES-GCM/CCM/ARIA (non-PFS, 128-bit) + 0x009c, // TLS_RSA_WITH_AES_128_GCM_SHA256 + 0xc0a0, // TLS_RSA_WITH_AES_128_CCM_8 + 0xc09c, // TLS_RSA_WITH_AES_128_CCM + 0xc050, // TLS_RSA_WITH_ARIA_128_GCM_SHA256 + + // RSA + AES-CBC (non-PFS, legacy) + 0x003d, // TLS_RSA_WITH_AES_256_CBC_SHA256 + 0x003c, // TLS_RSA_WITH_AES_128_CBC_SHA256 + 0x0035, // TLS_RSA_WITH_AES_256_CBC_SHA + 0x002f, // TLS_RSA_WITH_AES_128_CBC_SHA + + // Renegotiation indication + 0x00ff, // TLS_EMPTY_RENEGOTIATION_INFO_SCSV + } + + // defaultCurves contains the 10 supported groups from Claude CLI (including FFDHE) + defaultCurves = []utls.CurveID{ + utls.X25519, // 0x001d + utls.CurveP256, // 0x0017 (secp256r1) + utls.CurveID(0x001e), // x448 + utls.CurveP521, // 0x0019 (secp521r1) + utls.CurveP384, // 0x0018 (secp384r1) + utls.CurveID(0x0100), // ffdhe2048 + utls.CurveID(0x0101), // ffdhe3072 + utls.CurveID(0x0102), // ffdhe4096 + utls.CurveID(0x0103), // ffdhe6144 + utls.CurveID(0x0104), // ffdhe8192 + } + + // defaultPointFormats contains all 3 point formats from Claude CLI + defaultPointFormats = []uint8{ + 0, // uncompressed + 1, // ansiX962_compressed_prime + 2, // ansiX962_compressed_char2 + } + + // defaultSignatureAlgorithms contains the 20 signature algorithms from Claude CLI + defaultSignatureAlgorithms = []utls.SignatureScheme{ + 0x0403, // ecdsa_secp256r1_sha256 + 0x0503, // ecdsa_secp384r1_sha384 + 0x0603, // ecdsa_secp521r1_sha512 + 0x0807, // ed25519 + 0x0808, // ed448 + 0x0809, // rsa_pss_pss_sha256 + 0x080a, // rsa_pss_pss_sha384 + 0x080b, // rsa_pss_pss_sha512 + 0x0804, // rsa_pss_rsae_sha256 + 0x0805, // rsa_pss_rsae_sha384 + 0x0806, // rsa_pss_rsae_sha512 + 0x0401, // rsa_pkcs1_sha256 + 0x0501, // rsa_pkcs1_sha384 + 0x0601, // rsa_pkcs1_sha512 + 0x0303, // ecdsa_sha224 + 0x0301, // rsa_pkcs1_sha224 + 0x0302, // dsa_sha224 + 0x0402, // dsa_sha256 + 0x0502, // dsa_sha384 + 0x0602, // dsa_sha512 + } +) + +// NewDialer creates a new TLS fingerprint dialer. +// baseDialer is used for TCP connection establishment (supports proxy scenarios). +// If baseDialer is nil, direct TCP dial is used. +func NewDialer(profile *Profile, baseDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *Dialer { + if baseDialer == nil { + baseDialer = (&net.Dialer{}).DialContext + } + return &Dialer{profile: profile, baseDialer: baseDialer} +} + +// NewHTTPProxyDialer creates a new TLS fingerprint dialer that works through HTTP/HTTPS proxies. +// It establishes a CONNECT tunnel before performing TLS handshake with custom fingerprint. +func NewHTTPProxyDialer(profile *Profile, proxyURL *url.URL) *HTTPProxyDialer { + return &HTTPProxyDialer{profile: profile, proxyURL: proxyURL} +} + +// NewSOCKS5ProxyDialer creates a new TLS fingerprint dialer that works through SOCKS5 proxies. +// It establishes a SOCKS5 tunnel before performing TLS handshake with custom fingerprint. +func NewSOCKS5ProxyDialer(profile *Profile, proxyURL *url.URL) *SOCKS5ProxyDialer { + return &SOCKS5ProxyDialer{profile: profile, proxyURL: proxyURL} +} + +// DialTLSContext establishes a TLS connection through SOCKS5 proxy with the configured fingerprint. +// Flow: SOCKS5 CONNECT to target -> TLS handshake with utls on the tunnel +func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) { + slog.Debug("tls_fingerprint_socks5_connecting", "proxy", d.proxyURL.Host, "target", addr) + + // Step 1: Create SOCKS5 dialer + var auth *proxy.Auth + if d.proxyURL.User != nil { + username := d.proxyURL.User.Username() + password, _ := d.proxyURL.User.Password() + auth = &proxy.Auth{ + User: username, + Password: password, + } + } + + // Determine proxy address + proxyAddr := d.proxyURL.Host + if d.proxyURL.Port() == "" { + proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "1080") // Default SOCKS5 port + } + + socksDialer, err := proxy.SOCKS5("tcp", proxyAddr, auth, proxy.Direct) + if err != nil { + slog.Debug("tls_fingerprint_socks5_dialer_failed", "error", err) + return nil, fmt.Errorf("create SOCKS5 dialer: %w", err) + } + + // Step 2: Establish SOCKS5 tunnel to target + slog.Debug("tls_fingerprint_socks5_establishing_tunnel", "target", addr) + conn, err := socksDialer.Dial("tcp", addr) + if err != nil { + slog.Debug("tls_fingerprint_socks5_connect_failed", "error", err) + return nil, fmt.Errorf("SOCKS5 connect: %w", err) + } + slog.Debug("tls_fingerprint_socks5_tunnel_established") + + // Step 3: Perform TLS handshake on the tunnel with utls fingerprint + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + slog.Debug("tls_fingerprint_socks5_starting_handshake", "host", host) + + // Build ClientHello specification from profile (Node.js/Claude CLI fingerprint) + spec := buildClientHelloSpecFromProfile(d.profile) + slog.Debug("tls_fingerprint_socks5_clienthello_spec", + "cipher_suites", len(spec.CipherSuites), + "extensions", len(spec.Extensions), + "compression_methods", spec.CompressionMethods, + "tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax), + "tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin)) + + if d.profile != nil { + slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE) + } + + // Create uTLS connection on the tunnel + tlsConn := utls.UClient(conn, &utls.Config{ + ServerName: host, + }, utls.HelloCustom) + + if err := tlsConn.ApplyPreset(spec); err != nil { + slog.Debug("tls_fingerprint_socks5_apply_preset_failed", "error", err) + _ = conn.Close() + return nil, fmt.Errorf("apply TLS preset: %w", err) + } + + if err := tlsConn.Handshake(); err != nil { + slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err) + _ = conn.Close() + return nil, fmt.Errorf("TLS handshake failed: %w", err) + } + + state := tlsConn.ConnectionState() + slog.Debug("tls_fingerprint_socks5_handshake_success", + "version", fmt.Sprintf("0x%04x", state.Version), + "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "alpn", state.NegotiatedProtocol) + + return tlsConn, nil +} + +// DialTLSContext establishes a TLS connection through HTTP proxy with the configured fingerprint. +// Flow: TCP connect to proxy -> CONNECT tunnel -> TLS handshake with utls +func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) { + slog.Debug("tls_fingerprint_http_proxy_connecting", "proxy", d.proxyURL.Host, "target", addr) + + // Step 1: TCP connect to proxy server + var proxyAddr string + if d.proxyURL.Port() != "" { + proxyAddr = d.proxyURL.Host + } else { + // Default ports + if d.proxyURL.Scheme == "https" { + proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "443") + } else { + proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "80") + } + } + + dialer := &net.Dialer{} + conn, err := dialer.DialContext(ctx, "tcp", proxyAddr) + if err != nil { + slog.Debug("tls_fingerprint_http_proxy_connect_failed", "error", err) + return nil, fmt.Errorf("connect to proxy: %w", err) + } + slog.Debug("tls_fingerprint_http_proxy_connected", "proxy_addr", proxyAddr) + + // Step 2: Send CONNECT request to establish tunnel + req := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: addr}, + Host: addr, + Header: make(http.Header), + } + + // Add proxy authentication if present + if d.proxyURL.User != nil { + username := d.proxyURL.User.Username() + password, _ := d.proxyURL.User.Password() + auth := base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) + req.Header.Set("Proxy-Authorization", "Basic "+auth) + } + + slog.Debug("tls_fingerprint_http_proxy_sending_connect", "target", addr) + if err := req.Write(conn); err != nil { + _ = conn.Close() + slog.Debug("tls_fingerprint_http_proxy_write_failed", "error", err) + return nil, fmt.Errorf("write CONNECT request: %w", err) + } + + // Step 3: Read CONNECT response + br := bufio.NewReader(conn) + resp, err := http.ReadResponse(br, req) + if err != nil { + _ = conn.Close() + slog.Debug("tls_fingerprint_http_proxy_read_response_failed", "error", err) + return nil, fmt.Errorf("read CONNECT response: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + _ = conn.Close() + slog.Debug("tls_fingerprint_http_proxy_connect_failed_status", "status_code", resp.StatusCode, "status", resp.Status) + return nil, fmt.Errorf("proxy CONNECT failed: %s", resp.Status) + } + slog.Debug("tls_fingerprint_http_proxy_tunnel_established") + + // Step 4: Perform TLS handshake on the tunnel with utls fingerprint + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + slog.Debug("tls_fingerprint_http_proxy_starting_handshake", "host", host) + + // Build ClientHello specification (reuse the shared method) + spec := buildClientHelloSpecFromProfile(d.profile) + slog.Debug("tls_fingerprint_http_proxy_clienthello_spec", + "cipher_suites", len(spec.CipherSuites), + "extensions", len(spec.Extensions)) + + if d.profile != nil { + slog.Debug("tls_fingerprint_http_proxy_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE) + } + + // Create uTLS connection on the tunnel + // Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions + tlsConn := utls.UClient(conn, &utls.Config{ + ServerName: host, + }, utls.HelloCustom) + + if err := tlsConn.ApplyPreset(spec); err != nil { + slog.Debug("tls_fingerprint_http_proxy_apply_preset_failed", "error", err) + _ = conn.Close() + return nil, fmt.Errorf("apply TLS preset: %w", err) + } + + if err := tlsConn.HandshakeContext(ctx); err != nil { + slog.Debug("tls_fingerprint_http_proxy_handshake_failed", "error", err) + _ = conn.Close() + return nil, fmt.Errorf("TLS handshake failed: %w", err) + } + + state := tlsConn.ConnectionState() + slog.Debug("tls_fingerprint_http_proxy_handshake_success", + "version", fmt.Sprintf("0x%04x", state.Version), + "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "alpn", state.NegotiatedProtocol) + + return tlsConn, nil +} + +// DialTLSContext establishes a TLS connection with the configured fingerprint. +// This method is designed to be used as http.Transport.DialTLSContext. +func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) { + // Establish TCP connection using base dialer (supports proxy) + slog.Debug("tls_fingerprint_dialing_tcp", "addr", addr) + conn, err := d.baseDialer(ctx, network, addr) + if err != nil { + slog.Debug("tls_fingerprint_tcp_dial_failed", "error", err) + return nil, err + } + slog.Debug("tls_fingerprint_tcp_connected", "addr", addr) + + // Extract hostname for SNI + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + slog.Debug("tls_fingerprint_sni_hostname", "host", host) + + // Build ClientHello specification + spec := d.buildClientHelloSpec() + slog.Debug("tls_fingerprint_clienthello_spec", + "cipher_suites", len(spec.CipherSuites), + "extensions", len(spec.Extensions)) + + // Log profile info + if d.profile != nil { + slog.Debug("tls_fingerprint_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE) + } else { + slog.Debug("tls_fingerprint_using_default_profile") + } + + // Create uTLS connection + // Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions + tlsConn := utls.UClient(conn, &utls.Config{ + ServerName: host, + }, utls.HelloCustom) + + // Apply fingerprint + if err := tlsConn.ApplyPreset(spec); err != nil { + slog.Debug("tls_fingerprint_apply_preset_failed", "error", err) + _ = conn.Close() + return nil, err + } + slog.Debug("tls_fingerprint_preset_applied") + + // Perform TLS handshake + if err := tlsConn.HandshakeContext(ctx); err != nil { + slog.Debug("tls_fingerprint_handshake_failed", + "error", err, + "local_addr", conn.LocalAddr(), + "remote_addr", conn.RemoteAddr()) + _ = conn.Close() + return nil, fmt.Errorf("TLS handshake failed: %w", err) + } + + // Log successful handshake details + state := tlsConn.ConnectionState() + slog.Debug("tls_fingerprint_handshake_success", + "version", fmt.Sprintf("0x%04x", state.Version), + "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "alpn", state.NegotiatedProtocol) + + return tlsConn, nil +} + +// buildClientHelloSpec constructs the ClientHello specification based on the profile. +func (d *Dialer) buildClientHelloSpec() *utls.ClientHelloSpec { + return buildClientHelloSpecFromProfile(d.profile) +} + +// toUTLSCurves converts uint16 slice to utls.CurveID slice. +func toUTLSCurves(curves []uint16) []utls.CurveID { + result := make([]utls.CurveID, len(curves)) + for i, c := range curves { + result[i] = utls.CurveID(c) + } + return result +} + +// buildClientHelloSpecFromProfile constructs ClientHelloSpec from a Profile. +// This is a standalone function that can be used by both Dialer and HTTPProxyDialer. +func buildClientHelloSpecFromProfile(profile *Profile) *utls.ClientHelloSpec { + // Get cipher suites + var cipherSuites []uint16 + if profile != nil && len(profile.CipherSuites) > 0 { + cipherSuites = profile.CipherSuites + } else { + cipherSuites = defaultCipherSuites + } + + // Get curves + var curves []utls.CurveID + if profile != nil && len(profile.Curves) > 0 { + curves = toUTLSCurves(profile.Curves) + } else { + curves = defaultCurves + } + + // Get point formats + var pointFormats []uint8 + if profile != nil && len(profile.PointFormats) > 0 { + pointFormats = profile.PointFormats + } else { + pointFormats = defaultPointFormats + } + + // Check if GREASE is enabled + enableGREASE := profile != nil && profile.EnableGREASE + + extensions := make([]utls.TLSExtension, 0, 16) + + if enableGREASE { + extensions = append(extensions, &utls.UtlsGREASEExtension{}) + } + + // SNI extension - MUST be explicitly added for HelloCustom mode + // utls will populate the server name from Config.ServerName + extensions = append(extensions, &utls.SNIExtension{}) + + // Claude CLI extension order (captured from tshark): + // server_name(0), ec_point_formats(11), supported_groups(10), session_ticket(35), + // alpn(16), encrypt_then_mac(22), extended_master_secret(23), + // signature_algorithms(13), supported_versions(43), + // psk_key_exchange_modes(45), key_share(51) + extensions = append(extensions, + &utls.SupportedPointsExtension{SupportedPoints: pointFormats}, + &utls.SupportedCurvesExtension{Curves: curves}, + &utls.SessionTicketExtension{}, + &utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}}, + &utls.GenericExtension{Id: 22}, + &utls.ExtendedMasterSecretExtension{}, + &utls.SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: defaultSignatureAlgorithms}, + &utls.SupportedVersionsExtension{Versions: []uint16{ + utls.VersionTLS13, + utls.VersionTLS12, + }}, + &utls.PSKKeyExchangeModesExtension{Modes: []uint8{utls.PskModeDHE}}, + &utls.KeyShareExtension{KeyShares: []utls.KeyShare{ + {Group: utls.X25519}, + }}, + ) + + if enableGREASE { + extensions = append(extensions, &utls.UtlsGREASEExtension{}) + } + + return &utls.ClientHelloSpec{ + CipherSuites: cipherSuites, + CompressionMethods: []uint8{0}, // null compression only (standard) + Extensions: extensions, + TLSVersMax: utls.VersionTLS13, + TLSVersMin: utls.VersionTLS10, + } +} diff --git a/backend/internal/pkg/tlsfingerprint/dialer_test.go b/backend/internal/pkg/tlsfingerprint/dialer_test.go new file mode 100644 index 00000000..2aed1287 --- /dev/null +++ b/backend/internal/pkg/tlsfingerprint/dialer_test.go @@ -0,0 +1,307 @@ +// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients. +// +// Integration tests for verifying TLS fingerprint correctness. +// These tests make actual network requests and should be run manually. +// +// Run with: go test -v ./internal/pkg/tlsfingerprint/... +// Run integration tests: go test -v -run TestJA3 ./internal/pkg/tlsfingerprint/... +package tlsfingerprint + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/url" + "strings" + "testing" + "time" +) + +// FingerprintResponse represents the response from tls.peet.ws/api/all. +type FingerprintResponse struct { + IP string `json:"ip"` + TLS TLSInfo `json:"tls"` + HTTP2 any `json:"http2"` +} + +// TLSInfo contains TLS fingerprint details. +type TLSInfo struct { + JA3 string `json:"ja3"` + JA3Hash string `json:"ja3_hash"` + JA4 string `json:"ja4"` + PeetPrint string `json:"peetprint"` + PeetPrintHash string `json:"peetprint_hash"` + ClientRandom string `json:"client_random"` + SessionID string `json:"session_id"` +} + +// TestDialerBasicConnection tests that the dialer can establish TLS connections. +func TestDialerBasicConnection(t *testing.T) { + if testing.Short() { + t.Skip("skipping network test in short mode") + } + + // Create a dialer with default profile + profile := &Profile{ + Name: "Test Profile", + EnableGREASE: false, + } + dialer := NewDialer(profile, nil) + + // Create HTTP client with custom TLS dialer + client := &http.Client{ + Transport: &http.Transport{ + DialTLSContext: dialer.DialTLSContext, + }, + Timeout: 30 * time.Second, + } + + // Make a request to a known HTTPS endpoint + resp, err := client.Get("https://www.google.com") + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value. +// This test uses tls.peet.ws to verify the fingerprint. +// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x) +// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP) +func TestJA3Fingerprint(t *testing.T) { + // Skip if network is unavailable or if running in short mode + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + profile := &Profile{ + Name: "Claude CLI Test", + EnableGREASE: false, + } + dialer := NewDialer(profile, nil) + + client := &http.Client{ + Transport: &http.Transport{ + DialTLSContext: dialer.DialTLSContext, + }, + Timeout: 30 * time.Second, + } + + // Use tls.peet.ws fingerprint detection API + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("failed to get fingerprint: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response: %v", err) + } + + var fpResp FingerprintResponse + if err := json.Unmarshal(body, &fpResp); err != nil { + t.Logf("Response body: %s", string(body)) + t.Fatalf("failed to parse fingerprint response: %v", err) + } + + // Log all fingerprint information + t.Logf("JA3: %s", fpResp.TLS.JA3) + t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash) + t.Logf("JA4: %s", fpResp.TLS.JA4) + t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint) + t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash) + + // Verify JA3 hash matches expected value + expectedJA3Hash := "1a28e69016765d92e3b381168d68922c" + if fpResp.TLS.JA3Hash == expectedJA3Hash { + t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash) + } else { + t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash) + } + + // Verify JA4 fingerprint + // JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash] + // Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP) + // The suffix _a33745022dd6_1f22a2ca17c4 should match + expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4" + if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) { + t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix) + } else { + t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix) + } + + // Verify JA4 prefix (t13d5911h1 or t13i5911h1) + // d = domain (SNI present), i = IP (no SNI) + // Since we connect to tls.peet.ws (domain), we expect 'd' + expectedJA4Prefix := "t13d5911h1" + if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) { + t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix) + } else { + // Also accept 'i' variant for IP connections + altPrefix := "t13i5911h1" + if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) { + t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix) + } else { + t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix) + } + } + + // Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning) + if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") { + t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites") + } else { + t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites") + } + + // Verify extension list (should be 11 extensions including SNI) + // Expected: 0-11-10-35-16-22-23-13-43-45-51 + expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51" + if strings.Contains(fpResp.TLS.JA3, expectedExtensions) { + t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions) + } else { + t.Logf("Warning: JA3 extension list may differ") + } +} + +// TestDialerWithProfile tests that different profiles produce different fingerprints. +func TestDialerWithProfile(t *testing.T) { + // Create two dialers with different profiles + profile1 := &Profile{ + Name: "Profile 1 - No GREASE", + EnableGREASE: false, + } + profile2 := &Profile{ + Name: "Profile 2 - With GREASE", + EnableGREASE: true, + } + + dialer1 := NewDialer(profile1, nil) + dialer2 := NewDialer(profile2, nil) + + // Build specs and compare + // Note: We can't directly compare JA3 without making network requests + // but we can verify the specs are different + spec1 := dialer1.buildClientHelloSpec() + spec2 := dialer2.buildClientHelloSpec() + + // Profile with GREASE should have more extensions + if len(spec2.Extensions) <= len(spec1.Extensions) { + t.Error("expected GREASE profile to have more extensions") + } +} + +// TestHTTPProxyDialerBasic tests HTTP proxy dialer creation. +// Note: This is a unit test - actual proxy testing requires a proxy server. +func TestHTTPProxyDialerBasic(t *testing.T) { + profile := &Profile{ + Name: "Test Profile", + EnableGREASE: false, + } + + // Test that dialer is created without panic + proxyURL := mustParseURL("http://proxy.example.com:8080") + dialer := NewHTTPProxyDialer(profile, proxyURL) + + if dialer == nil { + t.Fatal("expected dialer to be created") + } + if dialer.profile != profile { + t.Error("expected profile to be set") + } + if dialer.proxyURL != proxyURL { + t.Error("expected proxyURL to be set") + } +} + +// TestSOCKS5ProxyDialerBasic tests SOCKS5 proxy dialer creation. +// Note: This is a unit test - actual proxy testing requires a proxy server. +func TestSOCKS5ProxyDialerBasic(t *testing.T) { + profile := &Profile{ + Name: "Test Profile", + EnableGREASE: false, + } + + // Test that dialer is created without panic + proxyURL := mustParseURL("socks5://proxy.example.com:1080") + dialer := NewSOCKS5ProxyDialer(profile, proxyURL) + + if dialer == nil { + t.Fatal("expected dialer to be created") + } + if dialer.profile != profile { + t.Error("expected profile to be set") + } + if dialer.proxyURL != proxyURL { + t.Error("expected proxyURL to be set") + } +} + +// TestBuildClientHelloSpec tests ClientHello spec construction. +func TestBuildClientHelloSpec(t *testing.T) { + // Test with nil profile (should use defaults) + spec := buildClientHelloSpecFromProfile(nil) + + if len(spec.CipherSuites) == 0 { + t.Error("expected cipher suites to be set") + } + if len(spec.Extensions) == 0 { + t.Error("expected extensions to be set") + } + + // Verify default cipher suites are used + if len(spec.CipherSuites) != len(defaultCipherSuites) { + t.Errorf("expected %d cipher suites, got %d", len(defaultCipherSuites), len(spec.CipherSuites)) + } + + // Test with custom profile + customProfile := &Profile{ + Name: "Custom", + EnableGREASE: false, + CipherSuites: []uint16{0x1301, 0x1302}, + } + spec = buildClientHelloSpecFromProfile(customProfile) + + if len(spec.CipherSuites) != 2 { + t.Errorf("expected 2 cipher suites, got %d", len(spec.CipherSuites)) + } +} + +// TestToUTLSCurves tests curve ID conversion. +func TestToUTLSCurves(t *testing.T) { + input := []uint16{0x001d, 0x0017, 0x0018} + result := toUTLSCurves(input) + + if len(result) != len(input) { + t.Errorf("expected %d curves, got %d", len(input), len(result)) + } + + for i, curve := range result { + if uint16(curve) != input[i] { + t.Errorf("curve %d: expected 0x%04x, got 0x%04x", i, input[i], uint16(curve)) + } + } +} + +// Helper function to parse URL without error handling. +func mustParseURL(rawURL string) *url.URL { + u, err := url.Parse(rawURL) + if err != nil { + panic(err) + } + return u +} diff --git a/backend/internal/pkg/tlsfingerprint/registry.go b/backend/internal/pkg/tlsfingerprint/registry.go new file mode 100644 index 00000000..6e9dc539 --- /dev/null +++ b/backend/internal/pkg/tlsfingerprint/registry.go @@ -0,0 +1,171 @@ +// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients. +package tlsfingerprint + +import ( + "log/slog" + "sort" + "sync" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +// DefaultProfileName is the name of the built-in Claude CLI profile. +const DefaultProfileName = "claude_cli_v2" + +// Registry manages TLS fingerprint profiles. +// It holds a collection of profiles that can be used for TLS fingerprint simulation. +// Profiles are selected based on account ID using modulo operation. +type Registry struct { + mu sync.RWMutex + profiles map[string]*Profile + profileNames []string // Sorted list of profile names for deterministic selection +} + +// NewRegistry creates a new TLS fingerprint profile registry. +// It initializes with the built-in default profile. +func NewRegistry() *Registry { + r := &Registry{ + profiles: make(map[string]*Profile), + profileNames: make([]string, 0), + } + + // Register the built-in default profile + r.registerBuiltinProfile() + + return r +} + +// NewRegistryFromConfig creates a new registry and loads profiles from config. +// If the config has custom profiles defined, they will be merged with the built-in default. +func NewRegistryFromConfig(cfg *config.TLSFingerprintConfig) *Registry { + r := NewRegistry() + + if cfg == nil || !cfg.Enabled { + slog.Debug("tls_registry_disabled", "reason", "disabled or no config") + return r + } + + // Load custom profiles from config + for name, profileCfg := range cfg.Profiles { + profile := &Profile{ + Name: profileCfg.Name, + EnableGREASE: profileCfg.EnableGREASE, + CipherSuites: profileCfg.CipherSuites, + Curves: profileCfg.Curves, + PointFormats: profileCfg.PointFormats, + } + + // If the profile has empty values, they will use defaults in dialer + r.RegisterProfile(name, profile) + slog.Debug("tls_registry_loaded_profile", "key", name, "name", profileCfg.Name) + } + + slog.Debug("tls_registry_initialized", "profile_count", len(r.profileNames), "profiles", r.profileNames) + return r +} + +// registerBuiltinProfile adds the default Claude CLI profile to the registry. +func (r *Registry) registerBuiltinProfile() { + defaultProfile := &Profile{ + Name: "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)", + EnableGREASE: false, // Node.js does not use GREASE + // Empty slices will cause dialer to use built-in defaults + CipherSuites: nil, + Curves: nil, + PointFormats: nil, + } + r.RegisterProfile(DefaultProfileName, defaultProfile) +} + +// RegisterProfile adds or updates a profile in the registry. +func (r *Registry) RegisterProfile(name string, profile *Profile) { + r.mu.Lock() + defer r.mu.Unlock() + + // Check if this is a new profile + _, exists := r.profiles[name] + r.profiles[name] = profile + + if !exists { + r.profileNames = append(r.profileNames, name) + // Keep names sorted for deterministic selection + sort.Strings(r.profileNames) + } +} + +// GetProfile returns a profile by name. +// Returns nil if the profile does not exist. +func (r *Registry) GetProfile(name string) *Profile { + r.mu.RLock() + defer r.mu.RUnlock() + return r.profiles[name] +} + +// GetDefaultProfile returns the built-in default profile. +func (r *Registry) GetDefaultProfile() *Profile { + return r.GetProfile(DefaultProfileName) +} + +// GetProfileByAccountID returns a profile for the given account ID. +// The profile is selected using: profileNames[accountID % len(profiles)] +// This ensures deterministic profile assignment for each account. +func (r *Registry) GetProfileByAccountID(accountID int64) *Profile { + r.mu.RLock() + defer r.mu.RUnlock() + + if len(r.profileNames) == 0 { + return nil + } + + // Use modulo to select profile index + // Use absolute value to handle negative IDs (though unlikely) + idx := accountID + if idx < 0 { + idx = -idx + } + selectedIndex := int(idx % int64(len(r.profileNames))) + selectedName := r.profileNames[selectedIndex] + + return r.profiles[selectedName] +} + +// ProfileCount returns the number of registered profiles. +func (r *Registry) ProfileCount() int { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.profiles) +} + +// ProfileNames returns a sorted list of all registered profile names. +func (r *Registry) ProfileNames() []string { + r.mu.RLock() + defer r.mu.RUnlock() + + // Return a copy to prevent modification + names := make([]string, len(r.profileNames)) + copy(names, r.profileNames) + return names +} + +// Global registry instance for convenience +var globalRegistry *Registry +var globalRegistryOnce sync.Once + +// GlobalRegistry returns the global TLS fingerprint registry. +// The registry is lazily initialized with the default profile. +func GlobalRegistry() *Registry { + globalRegistryOnce.Do(func() { + globalRegistry = NewRegistry() + }) + return globalRegistry +} + +// InitGlobalRegistry initializes the global registry with configuration. +// This should be called during application startup. +// It is safe to call multiple times; subsequent calls will update the registry. +func InitGlobalRegistry(cfg *config.TLSFingerprintConfig) *Registry { + globalRegistryOnce.Do(func() { + globalRegistry = NewRegistryFromConfig(cfg) + }) + return globalRegistry +} diff --git a/backend/internal/pkg/tlsfingerprint/registry_test.go b/backend/internal/pkg/tlsfingerprint/registry_test.go new file mode 100644 index 00000000..752ba0cc --- /dev/null +++ b/backend/internal/pkg/tlsfingerprint/registry_test.go @@ -0,0 +1,243 @@ +package tlsfingerprint + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +func TestNewRegistry(t *testing.T) { + r := NewRegistry() + + // Should have exactly one profile (the default) + if r.ProfileCount() != 1 { + t.Errorf("expected 1 profile, got %d", r.ProfileCount()) + } + + // Should have the default profile + profile := r.GetDefaultProfile() + if profile == nil { + t.Error("expected default profile to exist") + } + + // Default profile name should be in the list + names := r.ProfileNames() + if len(names) != 1 || names[0] != DefaultProfileName { + t.Errorf("expected profile names to be [%s], got %v", DefaultProfileName, names) + } +} + +func TestRegisterProfile(t *testing.T) { + r := NewRegistry() + + // Register a new profile + customProfile := &Profile{ + Name: "Custom Profile", + EnableGREASE: true, + } + r.RegisterProfile("custom", customProfile) + + // Should now have 2 profiles + if r.ProfileCount() != 2 { + t.Errorf("expected 2 profiles, got %d", r.ProfileCount()) + } + + // Should be able to retrieve the custom profile + retrieved := r.GetProfile("custom") + if retrieved == nil { + t.Fatal("expected custom profile to exist") + } + if retrieved.Name != "Custom Profile" { + t.Errorf("expected profile name 'Custom Profile', got '%s'", retrieved.Name) + } + if !retrieved.EnableGREASE { + t.Error("expected EnableGREASE to be true") + } +} + +func TestGetProfile(t *testing.T) { + r := NewRegistry() + + // Get existing profile + profile := r.GetProfile(DefaultProfileName) + if profile == nil { + t.Error("expected default profile to exist") + } + + // Get non-existing profile + nonExistent := r.GetProfile("nonexistent") + if nonExistent != nil { + t.Error("expected nil for non-existent profile") + } +} + +func TestGetProfileByAccountID(t *testing.T) { + r := NewRegistry() + + // With only default profile, all account IDs should return the same profile + for i := int64(0); i < 10; i++ { + profile := r.GetProfileByAccountID(i) + if profile == nil { + t.Errorf("expected profile for account %d, got nil", i) + } + } + + // Add more profiles + r.RegisterProfile("profile_a", &Profile{Name: "Profile A"}) + r.RegisterProfile("profile_b", &Profile{Name: "Profile B"}) + + // Now we have 3 profiles: claude_cli_v2, profile_a, profile_b + // Names are sorted, so order is: claude_cli_v2, profile_a, profile_b + expectedOrder := []string{DefaultProfileName, "profile_a", "profile_b"} + names := r.ProfileNames() + for i, name := range expectedOrder { + if names[i] != name { + t.Errorf("expected name at index %d to be %s, got %s", i, name, names[i]) + } + } + + // Test modulo selection + // Account ID 0 % 3 = 0 -> claude_cli_v2 + // Account ID 1 % 3 = 1 -> profile_a + // Account ID 2 % 3 = 2 -> profile_b + // Account ID 3 % 3 = 0 -> claude_cli_v2 + testCases := []struct { + accountID int64 + expectedName string + }{ + {0, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"}, + {1, "Profile A"}, + {2, "Profile B"}, + {3, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"}, + {4, "Profile A"}, + {5, "Profile B"}, + {100, "Profile A"}, // 100 % 3 = 1 + {-1, "Profile A"}, // |-1| % 3 = 1 + {-3, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"}, // |-3| % 3 = 0 + } + + for _, tc := range testCases { + profile := r.GetProfileByAccountID(tc.accountID) + if profile == nil { + t.Errorf("expected profile for account %d, got nil", tc.accountID) + continue + } + if profile.Name != tc.expectedName { + t.Errorf("account %d: expected profile name '%s', got '%s'", tc.accountID, tc.expectedName, profile.Name) + } + } +} + +func TestNewRegistryFromConfig(t *testing.T) { + // Test with nil config + r := NewRegistryFromConfig(nil) + if r.ProfileCount() != 1 { + t.Errorf("expected 1 profile with nil config, got %d", r.ProfileCount()) + } + + // Test with disabled config + disabledCfg := &config.TLSFingerprintConfig{ + Enabled: false, + } + r = NewRegistryFromConfig(disabledCfg) + if r.ProfileCount() != 1 { + t.Errorf("expected 1 profile with disabled config, got %d", r.ProfileCount()) + } + + // Test with enabled config and custom profiles + enabledCfg := &config.TLSFingerprintConfig{ + Enabled: true, + Profiles: map[string]config.TLSProfileConfig{ + "custom1": { + Name: "Custom Profile 1", + EnableGREASE: true, + }, + "custom2": { + Name: "Custom Profile 2", + EnableGREASE: false, + }, + }, + } + r = NewRegistryFromConfig(enabledCfg) + + // Should have 3 profiles: default + 2 custom + if r.ProfileCount() != 3 { + t.Errorf("expected 3 profiles, got %d", r.ProfileCount()) + } + + // Check custom profiles exist + custom1 := r.GetProfile("custom1") + if custom1 == nil || custom1.Name != "Custom Profile 1" { + t.Error("expected custom1 profile to exist with correct name") + } + custom2 := r.GetProfile("custom2") + if custom2 == nil || custom2.Name != "Custom Profile 2" { + t.Error("expected custom2 profile to exist with correct name") + } +} + +func TestProfileNames(t *testing.T) { + r := NewRegistry() + + // Add profiles in non-alphabetical order + r.RegisterProfile("zebra", &Profile{Name: "Zebra"}) + r.RegisterProfile("alpha", &Profile{Name: "Alpha"}) + r.RegisterProfile("beta", &Profile{Name: "Beta"}) + + names := r.ProfileNames() + + // Should be sorted alphabetically + expected := []string{"alpha", "beta", DefaultProfileName, "zebra"} + if len(names) != len(expected) { + t.Errorf("expected %d names, got %d", len(expected), len(names)) + } + for i, name := range expected { + if names[i] != name { + t.Errorf("expected name at index %d to be %s, got %s", i, name, names[i]) + } + } + + // Test that returned slice is a copy (modifying it shouldn't affect registry) + names[0] = "modified" + originalNames := r.ProfileNames() + if originalNames[0] == "modified" { + t.Error("modifying returned slice should not affect registry") + } +} + +func TestConcurrentAccess(t *testing.T) { + r := NewRegistry() + + // Run concurrent reads and writes + done := make(chan bool) + + // Writers + for i := 0; i < 10; i++ { + go func(id int) { + for j := 0; j < 100; j++ { + r.RegisterProfile("concurrent"+string(rune('0'+id)), &Profile{Name: "Concurrent"}) + } + done <- true + }(i) + } + + // Readers + for i := 0; i < 10; i++ { + go func(id int) { + for j := 0; j < 100; j++ { + _ = r.ProfileCount() + _ = r.ProfileNames() + _ = r.GetProfileByAccountID(int64(id * j)) + _ = r.GetProfile(DefaultProfileName) + } + done <- true + }(i) + } + + // Wait for all goroutines + for i := 0; i < 20; i++ { + <-done + } + + // Test should pass without data races (run with -race flag) +} diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 84bd7b9e..c11c079b 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -575,6 +575,15 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac } } +func (r *accountRepository) ClearError(ctx context.Context, id int64) error { + _, err := r.client.Account.Update(). + Where(dbaccount.IDEQ(id)). + SetStatus(service.StatusActive). + SetErrorMessage(""). + Save(ctx) + return err +} + func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { _, err := r.client.AccountGroup.Create(). SetAccountID(accountID). @@ -993,7 +1002,16 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s builder.SetSessionWindowEnd(*end) } _, err := builder.Save(ctx) - return err + if err != nil { + return err + } + // 触发调度器缓存更新(仅当窗口时间有变化时) + if start != nil || end != nil { + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { + log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err) + } + } + return nil } func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { diff --git a/backend/internal/repository/api_key_cache.go b/backend/internal/repository/api_key_cache.go index 6d834b40..a1072057 100644 --- a/backend/internal/repository/api_key_cache.go +++ b/backend/internal/repository/api_key_cache.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "time" "github.com/Wei-Shaw/sub2api/internal/service" @@ -12,9 +13,10 @@ import ( ) const ( - apiKeyRateLimitKeyPrefix = "apikey:ratelimit:" - apiKeyRateLimitDuration = 24 * time.Hour - apiKeyAuthCachePrefix = "apikey:auth:" + apiKeyRateLimitKeyPrefix = "apikey:ratelimit:" + apiKeyRateLimitDuration = 24 * time.Hour + apiKeyAuthCachePrefix = "apikey:auth:" + authCacheInvalidateChannel = "auth:cache:invalidate" ) // apiKeyRateLimitKey generates the Redis key for API key creation rate limiting. @@ -91,3 +93,45 @@ func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *servi func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error { return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err() } + +// PublishAuthCacheInvalidation publishes a cache invalidation message to all instances +func (c *apiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error { + return c.rdb.Publish(ctx, authCacheInvalidateChannel, cacheKey).Err() +} + +// SubscribeAuthCacheInvalidation subscribes to cache invalidation messages +func (c *apiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error { + pubsub := c.rdb.Subscribe(ctx, authCacheInvalidateChannel) + + // Verify subscription is working + _, err := pubsub.Receive(ctx) + if err != nil { + _ = pubsub.Close() + return fmt.Errorf("subscribe to auth cache invalidation: %w", err) + } + + go func() { + defer func() { + if err := pubsub.Close(); err != nil { + log.Printf("Warning: failed to close auth cache invalidation pubsub: %v", err) + } + }() + + ch := pubsub.Channel() + for { + select { + case <-ctx.Done(): + return + case msg, ok := <-ch: + if !ok { + return + } + if msg != nil { + handler(msg.Payload) + } + } + } + }() + + return nil +} diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index feb32541..b0f15f19 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "net/http" "net/url" @@ -14,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" ) @@ -150,6 +152,172 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i return resp, nil } +// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求 +// 根据 enableTLSFingerprint 参数决定是否使用 TLS 指纹 +// +// 参数: +// - req: HTTP 请求对象 +// - proxyURL: 代理地址,空字符串表示直连 +// - accountID: 账户 ID,用于账户级隔离和 TLS 指纹模板选择 +// - accountConcurrency: 账户并发限制,用于动态调整连接池大小 +// - enableTLSFingerprint: 是否启用 TLS 指纹伪装 +// +// TLS 指纹说明: +// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹 +// - 指纹模板根据 accountID % len(profiles) 自动选择 +// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景 +func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + // 如果未启用 TLS 指纹,直接使用标准请求路径 + if !enableTLSFingerprint { + return s.Do(req, proxyURL, accountID, accountConcurrency) + } + + // TLS 指纹已启用,记录调试日志 + targetHost := "" + if req != nil && req.URL != nil { + targetHost = req.URL.Host + } + proxyInfo := "direct" + if proxyURL != "" { + proxyInfo = proxyURL + } + slog.Debug("tls_fingerprint_enabled", "account_id", accountID, "target", targetHost, "proxy", proxyInfo) + + if err := s.validateRequestHost(req); err != nil { + return nil, err + } + + // 获取 TLS 指纹 Profile + registry := tlsfingerprint.GlobalRegistry() + profile := registry.GetProfileByAccountID(accountID) + if profile == nil { + // 如果获取不到 profile,回退到普通请求 + slog.Debug("tls_fingerprint_no_profile", "account_id", accountID, "fallback", "standard_request") + return s.Do(req, proxyURL, accountID, accountConcurrency) + } + + slog.Debug("tls_fingerprint_using_profile", "account_id", accountID, "profile", profile.Name, "grease", profile.EnableGREASE) + + // 获取或创建带 TLS 指纹的客户端 + entry, err := s.acquireClientWithTLS(proxyURL, accountID, accountConcurrency, profile) + if err != nil { + slog.Debug("tls_fingerprint_acquire_client_failed", "account_id", accountID, "error", err) + return nil, err + } + + // 执行请求 + resp, err := entry.client.Do(req) + if err != nil { + // 请求失败,立即减少计数 + atomic.AddInt64(&entry.inFlight, -1) + atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano()) + slog.Debug("tls_fingerprint_request_failed", "account_id", accountID, "error", err) + return nil, err + } + + slog.Debug("tls_fingerprint_request_success", "account_id", accountID, "status", resp.StatusCode) + + // 包装响应体,在关闭时自动减少计数并更新时间戳 + resp.Body = wrapTrackedBody(resp.Body, func() { + atomic.AddInt64(&entry.inFlight, -1) + atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano()) + }) + + return resp, nil +} + +// acquireClientWithTLS 获取或创建带 TLS 指纹的客户端 +func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*upstreamClientEntry, error) { + return s.getClientEntryWithTLS(proxyURL, accountID, accountConcurrency, profile, true, true) +} + +// getClientEntryWithTLS 获取或创建带 TLS 指纹的客户端条目 +// TLS 指纹客户端使用独立的缓存键,与普通客户端隔离 +func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { + isolation := s.getIsolationMode() + proxyKey, parsedProxy := normalizeProxyURL(proxyURL) + // TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀 + cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID) + poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls" + + now := time.Now() + nowUnix := now.UnixNano() + + // 读锁快速路径 + s.mu.RLock() + if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) { + atomic.StoreInt64(&entry.lastUsed, nowUnix) + if markInFlight { + atomic.AddInt64(&entry.inFlight, 1) + } + s.mu.RUnlock() + slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey) + return entry, nil + } + s.mu.RUnlock() + + // 写锁慢路径 + s.mu.Lock() + if entry, ok := s.clients[cacheKey]; ok { + if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) { + atomic.StoreInt64(&entry.lastUsed, nowUnix) + if markInFlight { + atomic.AddInt64(&entry.inFlight, 1) + } + s.mu.Unlock() + slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey) + return entry, nil + } + slog.Debug("tls_fingerprint_evicting_stale_client", + "account_id", accountID, + "cache_key", cacheKey, + "proxy_changed", entry.proxyKey != proxyKey, + "pool_changed", entry.poolKey != poolKey) + s.removeClientLocked(cacheKey, entry) + } + + // 超出缓存上限时尝试淘汰 + if enforceLimit && s.maxUpstreamClients() > 0 { + s.evictIdleLocked(now) + if len(s.clients) >= s.maxUpstreamClients() { + if !s.evictOldestIdleLocked() { + s.mu.Unlock() + return nil, errUpstreamClientLimitReached + } + } + } + + // 创建带 TLS 指纹的 Transport + slog.Debug("tls_fingerprint_creating_new_client", "account_id", accountID, "cache_key", cacheKey, "proxy", proxyKey) + settings := s.resolvePoolSettings(isolation, accountConcurrency) + transport, err := buildUpstreamTransportWithTLSFingerprint(settings, parsedProxy, profile) + if err != nil { + s.mu.Unlock() + return nil, fmt.Errorf("build TLS fingerprint transport: %w", err) + } + + client := &http.Client{Transport: transport} + if s.shouldValidateResolvedIP() { + client.CheckRedirect = s.redirectChecker + } + + entry := &upstreamClientEntry{ + client: client, + proxyKey: proxyKey, + poolKey: poolKey, + } + atomic.StoreInt64(&entry.lastUsed, nowUnix) + if markInFlight { + atomic.StoreInt64(&entry.inFlight, 1) + } + s.clients[cacheKey] = entry + + s.evictIdleLocked(now) + s.evictOverLimitLocked() + s.mu.Unlock() + return entry, nil +} + func (s *httpUpstreamService) shouldValidateResolvedIP() bool { if s.cfg == nil { return false @@ -618,6 +786,64 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Tra return transport, nil } +// buildUpstreamTransportWithTLSFingerprint 构建带 TLS 指纹伪装的 Transport +// 使用 utls 库模拟 Claude CLI 的 TLS 指纹 +// +// 参数: +// - settings: 连接池配置 +// - proxyURL: 代理 URL(nil 表示直连) +// - profile: TLS 指纹配置 +// +// 返回: +// - *http.Transport: 配置好的 Transport 实例 +// - error: 配置错误 +// +// 代理类型处理: +// - nil/空: 直连,使用 TLSFingerprintDialer +// - http/https: HTTP 代理,使用 HTTPProxyDialer(CONNECT 隧道 + utls 握手) +// - socks5: SOCKS5 代理,使用 SOCKS5ProxyDialer(SOCKS5 隧道 + utls 握手) +func buildUpstreamTransportWithTLSFingerprint(settings poolSettings, proxyURL *url.URL, profile *tlsfingerprint.Profile) (*http.Transport, error) { + transport := &http.Transport{ + MaxIdleConns: settings.maxIdleConns, + MaxIdleConnsPerHost: settings.maxIdleConnsPerHost, + MaxConnsPerHost: settings.maxConnsPerHost, + IdleConnTimeout: settings.idleConnTimeout, + ResponseHeaderTimeout: settings.responseHeaderTimeout, + // 禁用默认的 TLS,我们使用自定义的 DialTLSContext + ForceAttemptHTTP2: false, + } + + // 根据代理类型选择合适的 TLS 指纹 Dialer + if proxyURL == nil { + // 直连:使用 TLSFingerprintDialer + slog.Debug("tls_fingerprint_transport_direct") + dialer := tlsfingerprint.NewDialer(profile, nil) + transport.DialTLSContext = dialer.DialTLSContext + } else { + scheme := strings.ToLower(proxyURL.Scheme) + switch scheme { + case "socks5", "socks5h": + // SOCKS5 代理:使用 SOCKS5ProxyDialer + slog.Debug("tls_fingerprint_transport_socks5", "proxy", proxyURL.Host) + socks5Dialer := tlsfingerprint.NewSOCKS5ProxyDialer(profile, proxyURL) + transport.DialTLSContext = socks5Dialer.DialTLSContext + case "http", "https": + // HTTP/HTTPS 代理:使用 HTTPProxyDialer(CONNECT 隧道) + slog.Debug("tls_fingerprint_transport_http_connect", "proxy", proxyURL.Host) + httpDialer := tlsfingerprint.NewHTTPProxyDialer(profile, proxyURL) + transport.DialTLSContext = httpDialer.DialTLSContext + default: + // 未知代理类型,回退到普通代理配置(无 TLS 指纹) + slog.Debug("tls_fingerprint_transport_unknown_scheme_fallback", "scheme", scheme) + if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil { + return nil, err + } + } + } + + return transport, nil +} + // trackedBody 带跟踪功能的响应体包装器 // 在 Close 时执行回调,用于更新请求计数 type trackedBody struct { diff --git a/backend/internal/repository/identity_cache.go b/backend/internal/repository/identity_cache.go index d28477b7..c4986547 100644 --- a/backend/internal/repository/identity_cache.go +++ b/backend/internal/repository/identity_cache.go @@ -11,8 +11,10 @@ import ( ) const ( - fingerprintKeyPrefix = "fingerprint:" - fingerprintTTL = 24 * time.Hour + fingerprintKeyPrefix = "fingerprint:" + fingerprintTTL = 24 * time.Hour + maskedSessionKeyPrefix = "masked_session:" + maskedSessionTTL = 15 * time.Minute ) // fingerprintKey generates the Redis key for account fingerprint cache. @@ -20,6 +22,11 @@ func fingerprintKey(accountID int64) string { return fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID) } +// maskedSessionKey generates the Redis key for masked session ID cache. +func maskedSessionKey(accountID int64) string { + return fmt.Sprintf("%s%d", maskedSessionKeyPrefix, accountID) +} + type identityCache struct { rdb *redis.Client } @@ -49,3 +56,20 @@ func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp } return c.rdb.Set(ctx, key, val, fingerprintTTL).Err() } + +func (c *identityCache) GetMaskedSessionID(ctx context.Context, accountID int64) (string, error) { + key := maskedSessionKey(accountID) + val, err := c.rdb.Get(ctx, key).Result() + if err != nil { + if err == redis.Nil { + return "", nil + } + return "", err + } + return val, nil +} + +func (c *identityCache) SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error { + key := maskedSessionKey(accountID) + return c.rdb.Set(ctx, key, sessionID, maskedSessionTTL).Err() +} diff --git a/backend/internal/repository/session_limit_cache.go b/backend/internal/repository/session_limit_cache.go index 16f2a69c..3dc89f87 100644 --- a/backend/internal/repository/session_limit_cache.go +++ b/backend/internal/repository/session_limit_cache.go @@ -217,7 +217,7 @@ func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID } // GetActiveSessionCountBatch 批量获取多个账号的活跃会话数 -func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { +func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error) { if len(accountIDs) == 0 { return make(map[int64]int), nil } @@ -226,11 +226,18 @@ func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, acco // 使用 pipeline 批量执行 pipe := c.rdb.Pipeline() - idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds()) cmds := make(map[int64]*redis.Cmd, len(accountIDs)) for _, accountID := range accountIDs { key := sessionLimitKey(accountID) + // 使用各账号自己的 idleTimeout,如果没有则用默认值 + idleTimeout := c.defaultIdleTimeout + if idleTimeouts != nil { + if t, ok := idleTimeouts[accountID]; ok && t > 0 { + idleTimeout = t + } + } + idleTimeoutSeconds := int(idleTimeout.Seconds()) cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds) } diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 0a549b19..be8a8df8 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -618,6 +618,14 @@ func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error { return nil } +func (stubApiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error { + return nil +} + +func (stubApiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error { + return nil +} + type stubGroupRepo struct{} func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error { @@ -736,6 +744,10 @@ func (s *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg strin return errors.New("not implemented") } +func (s *stubAccountRepo) ClearError(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + func (s *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { return errors.New("not implemented") } diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 36ba0bcc..27f693d6 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -576,6 +576,44 @@ func (a *Account) IsAnthropicOAuthOrSetupToken() bool { return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken) } +// IsTLSFingerprintEnabled 检查是否启用 TLS 指纹伪装 +// 仅适用于 Anthropic OAuth/SetupToken 类型账号 +// 启用后将模拟 Claude Code (Node.js) 客户端的 TLS 握手特征 +func (a *Account) IsTLSFingerprintEnabled() bool { + // 仅支持 Anthropic OAuth/SetupToken 账号 + if !a.IsAnthropicOAuthOrSetupToken() { + return false + } + if a.Extra == nil { + return false + } + if v, ok := a.Extra["enable_tls_fingerprint"]; ok { + if enabled, ok := v.(bool); ok { + return enabled + } + } + return false +} + +// IsSessionIDMaskingEnabled 检查是否启用会话ID伪装 +// 仅适用于 Anthropic OAuth/SetupToken 类型账号 +// 启用后将在一段时间内(15分钟)固定 metadata.user_id 中的 session ID, +// 使上游认为请求来自同一个会话 +func (a *Account) IsSessionIDMaskingEnabled() bool { + if !a.IsAnthropicOAuthOrSetupToken() { + return false + } + if a.Extra == nil { + return false + } + if v, ok := a.Extra["session_id_masking_enabled"]; ok { + if enabled, ok := v.(bool); ok { + return enabled + } + } + return false +} + // GetWindowCostLimit 获取 5h 窗口费用阈值(美元) // 返回 0 表示未启用 func (a *Account) GetWindowCostLimit() float64 { @@ -652,6 +690,23 @@ func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) Windo return WindowCostNotSchedulable } +// GetCurrentWindowStartTime 获取当前有效的窗口开始时间 +// 逻辑: +// 1. 如果窗口未过期(SessionWindowEnd 存在且在当前时间之后),使用记录的 SessionWindowStart +// 2. 否则(窗口过期或未设置),使用新的预测窗口开始时间(从当前整点开始) +func (a *Account) GetCurrentWindowStartTime() time.Time { + now := time.Now() + + // 窗口未过期,使用记录的窗口开始时间 + if a.SessionWindowStart != nil && a.SessionWindowEnd != nil && now.Before(*a.SessionWindowEnd) { + return *a.SessionWindowStart + } + + // 窗口已过期或未设置,预测新的窗口开始时间(从当前整点开始) + // 与 ratelimit_service.go 中 UpdateSessionWindow 的预测逻辑保持一致 + return time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location()) +} + // parseExtraFloat64 从 extra 字段解析 float64 值 func parseExtraFloat64(value any) float64 { switch v := value.(type) { diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index ede5b12f..90365d2f 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -37,6 +37,7 @@ type AccountRepository interface { UpdateLastUsed(ctx context.Context, id int64) error BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error SetError(ctx context.Context, id int64, errorMsg string) error + ClearError(ctx context.Context, id int64) error SetSchedulable(ctx context.Context, id int64, schedulable bool) error AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index 36af719c..e5eabfc6 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -99,6 +99,10 @@ func (s *accountRepoStub) SetError(ctx context.Context, id int64, errorMsg strin panic("unexpected SetError call") } +func (s *accountRepoStub) ClearError(ctx context.Context, id int64) error { + panic("unexpected ClearError call") +} + func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { panic("unexpected SetSchedulable call") } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 8419c2b4..46376c69 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -265,7 +265,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account proxyURL = account.Proxy.URL() } - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) } @@ -375,7 +375,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account proxyURL = account.Proxy.URL() } - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) } @@ -446,7 +446,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account proxyURL = account.Proxy.URL() } - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) } diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index e3c0974e..6c617e27 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -369,12 +369,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou // 如果没有缓存,从数据库查询 if windowStats == nil { - var startTime time.Time - if account.SessionWindowStart != nil { - startTime = *account.SessionWindowStart - } else { - startTime = time.Now().Add(-5 * time.Hour) - } + // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况) + startTime := account.GetCurrentWindowStartTime() stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) if err != nil { diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index c0694e4e..0afa0716 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -42,6 +42,7 @@ type AdminService interface { DeleteAccount(ctx context.Context, id int64) error RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) ClearAccountError(ctx context.Context, id int64) (*Account, error) + SetAccountError(ctx context.Context, id int64, errorMsg string) error SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) @@ -1101,6 +1102,10 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac return account, nil } +func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error { + return s.accountRepo.SetError(ctx, id, errorMsg) +} + func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) { if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { return nil, err diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 7f3e97a2..043f338d 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -12,6 +12,7 @@ import ( mathrand "math/rand" "net" "net/http" + "os" "strings" "sync/atomic" "time" @@ -28,6 +29,207 @@ const ( antigravityRetryMaxDelay = 16 * time.Second ) +const antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT" + +// antigravityRetryLoopParams 重试循环的参数 +type antigravityRetryLoopParams struct { + ctx context.Context + prefix string + account *Account + proxyURL string + accessToken string + action string + body []byte + quotaScope AntigravityQuotaScope + c *gin.Context + httpUpstream HTTPUpstream + settingService *SettingService + handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) +} + +// antigravityRetryLoopResult 重试循环的结果 +type antigravityRetryLoopResult struct { + resp *http.Response +} + +// antigravityRetryLoop 执行带 URL fallback 的重试循环 +func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) { + availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = antigravity.BaseURLs + } + + var resp *http.Response + var usedBaseURL string + logBody := p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBody + maxBytes := 2048 + if p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { + maxBytes = p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + } + getUpstreamDetail := func(body []byte) string { + if !logBody { + return "" + } + return truncateString(string(body), maxBytes) + } + +urlFallbackLoop: + for urlIdx, baseURL := range availableURLs { + usedBaseURL = baseURL + for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { + select { + case <-p.ctx.Done(): + log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err()) + return nil, p.ctx.Err() + default: + } + + upstreamReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body) + if err != nil { + return nil, err + } + + // Capture upstream request body for ops retry of this attempt. + if p.c != nil && len(p.body) > 0 { + p.c.Set(OpsUpstreamRequestBodyKey, string(p.body)) + } + + resp, err = p.httpUpstream.Do(upstreamReq, p.proxyURL, p.account.ID, p.account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ + Platform: p.account.Platform, + AccountID: p.account.ID, + AccountName: p.account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) + continue urlFallbackLoop + } + if attempt < antigravityMaxRetries { + log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err) + if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", p.prefix) + return nil, p.ctx.Err() + } + continue + } + log.Printf("%s status=request_failed retries_exhausted error=%v", p.prefix, err) + setOpsUpstreamError(p.c, 0, safeErr, "") + return nil, fmt.Errorf("upstream request failed after retries: %w", err) + } + + // 429 限流处理:区分 URL 级别限流和账户配额限流 + if resp.StatusCode == http.StatusTooManyRequests { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + // "Resource has been exhausted" 是 URL 级别限流,切换 URL + if isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 { + log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) + continue urlFallbackLoop + } + + // 账户/模型配额限流,重试 3 次(指数退避) + if attempt < antigravityMaxRetries { + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ + Platform: p.account.Platform, + AccountID: p.account.ID, + AccountName: p.account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "retry", + Message: upstreamMsg, + Detail: getUpstreamDetail(respBody), + }) + log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) + if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", p.prefix) + return nil, p.ctx.Err() + } + continue + } + + // 重试用尽,标记账户限流 + p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope) + log.Printf("%s status=429 rate_limited base_url=%s body=%s", p.prefix, baseURL, truncateForLog(respBody, 200)) + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break urlFallbackLoop + } + + // 其他可重试错误 + if resp.StatusCode >= 400 && shouldRetryAntigravityError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + if attempt < antigravityMaxRetries { + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ + Platform: p.account.Platform, + AccountID: p.account.ID, + AccountName: p.account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "retry", + Message: upstreamMsg, + Detail: getUpstreamDetail(respBody), + }) + log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) + if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", p.prefix) + return nil, p.ctx.Err() + } + continue + } + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break urlFallbackLoop + } + + break urlFallbackLoop + } + } + + if resp != nil && resp.StatusCode < 400 && usedBaseURL != "" { + antigravity.DefaultURLAvailability.MarkSuccess(usedBaseURL) + } + + return &antigravityRetryLoopResult{resp: resp}, nil +} + +// shouldRetryAntigravityError 判断是否应该重试 +func shouldRetryAntigravityError(statusCode int) bool { + switch statusCode { + case 429, 500, 502, 503, 504, 529: + return true + default: + return false + } +} + +// isURLLevelRateLimit 判断是否为 URL 级别的限流(应切换 URL 重试) +// "Resource has been exhausted" 是 URL/节点级别限流,切换 URL 可能成功 +// "exhausted your capacity on this model" 是账户/模型配额限流,切换 URL 无效 +func isURLLevelRateLimit(body []byte) bool { + // 快速检查:包含 "Resource has been exhausted" 且不包含 "capacity on this model" + bodyStr := string(body) + return strings.Contains(bodyStr, "Resource has been exhausted") && + !strings.Contains(bodyStr, "capacity on this model") +} + // isAntigravityConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) func isAntigravityConnectionError(err error) bool { if err == nil { @@ -238,7 +440,6 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account if err != nil { lastErr = fmt.Errorf("请求失败: %w", err) if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { - antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) continue } @@ -254,7 +455,6 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account // 检查是否需要 URL 降级 if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { - antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) continue } @@ -266,6 +466,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account // 解析流式响应,提取文本 text := extractTextFromSSEResponse(respBody) + // 标记成功的 URL,下次优先使用 + antigravity.DefaultURLAvailability.MarkSuccess(baseURL) return &TestConnectionResult{ Text: text, MappedModel: mappedModel, @@ -276,13 +478,14 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account } // buildGeminiTestRequest 构建 Gemini 格式测试请求 +// 使用最小 token 消耗:输入 "." + maxOutputTokens: 1 func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) { payload := map[string]any{ "contents": []map[string]any{ { "role": "user", "parts": []map[string]any{ - {"text": "hi"}, + {"text": "."}, }, }, }, @@ -292,22 +495,26 @@ func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model stri {"text": antigravity.GetDefaultIdentityPatch()}, }, }, + "generationConfig": map[string]any{ + "maxOutputTokens": 1, + }, } payloadBytes, _ := json.Marshal(payload) return s.wrapV1InternalRequest(projectID, model, payloadBytes) } // buildClaudeTestRequest 构建 Claude 格式测试请求并转换为 Gemini 格式 +// 使用最小 token 消耗:输入 "." + MaxTokens: 1 func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedModel string) ([]byte, error) { claudeReq := &antigravity.ClaudeRequest{ Model: mappedModel, Messages: []antigravity.ClaudeMessage{ { Role: "user", - Content: json.RawMessage(`"hi"`), + Content: json.RawMessage(`"."`), }, }, - MaxTokens: 1024, + MaxTokens: 1, Stream: false, } return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel) @@ -523,9 +730,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, proxyURL = account.Proxy.URL() } - // Sanitize thinking blocks (clean cache_control and flatten history thinking) - sanitizeThinkingBlocks(&claudeReq) - // 获取转换选项 // Antigravity 上游要求必须包含身份提示词,否则会返回 429 transformOpts := s.getClaudeTransformOptions(ctx) @@ -537,150 +741,29 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return nil, fmt.Errorf("transform request: %w", err) } - // Safety net: ensure no cache_control leaked into Gemini request - geminiBody = cleanCacheControlFromGeminiJSON(geminiBody) - // Antigravity 上游只支持流式请求,统一使用 streamGenerateContent // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回 action := "streamGenerateContent" - // URL fallback 循环 - availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() - if len(availableURLs) == 0 { - availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 - } - - // 重试循环 - var resp *http.Response -urlFallbackLoop: - for urlIdx, baseURL := range availableURLs { - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - // 检查 context 是否已取消(客户端断开连接) - select { - case <-ctx.Done(): - log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) - return nil, ctx.Err() - default: - } - - upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody) - // Capture upstream request body for ops retry of this attempt. - if c != nil { - c.Set(OpsUpstreamRequestBodyKey, string(geminiBody)) - } - if err != nil { - return nil, err - } - - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - if err != nil { - safeErr := sanitizeUpstreamErrorMessage(err.Error()) - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: 0, - Kind: "request_error", - Message: safeErr, - }) - // 检查是否应触发 URL 降级 - if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { - antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) - log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1]) - continue urlFallbackLoop - } - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() - } - continue - } - log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) - setOpsUpstreamError(c, 0, safeErr, "") - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") - } - - // 检查是否应触发 URL 降级(仅 429) - if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody - maxBytes := 2048 - if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { - maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - } - upstreamDetail := "" - if logBody { - upstreamDetail = truncateString(string(respBody), maxBytes) - } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "retry", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) - log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200)) - continue urlFallbackLoop - } - - if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - if attempt < antigravityMaxRetries { - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody - maxBytes := 2048 - if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { - maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - } - upstreamDetail := "" - if logBody { - upstreamDetail = truncateString(string(respBody), maxBytes) - } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "retry", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() - } - continue - } - // 所有重试都失败,标记限流状态 - if resp.StatusCode == 429 { - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) - } - // 最后一次尝试也失败 - resp = &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: io.NopCloser(bytes.NewReader(respBody)), - } - break urlFallbackLoop - } - - break urlFallbackLoop - } + // 执行带重试的请求 + result, err := antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctx, + prefix: prefix, + account: account, + proxyURL: proxyURL, + accessToken: accessToken, + action: action, + body: geminiBody, + quotaScope: quotaScope, + c: c, + httpUpstream: s.httpUpstream, + settingService: s.settingService, + handleError: s.handleUpstreamError, + }) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") } + resp := result.resp defer func() { _ = resp.Body.Close() }() if resp.StatusCode >= 400 { @@ -739,11 +822,20 @@ urlFallbackLoop: if txErr != nil { continue } - retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody) - if buildErr != nil { - continue - } - retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) + retryResult, retryErr := antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctx, + prefix: prefix, + account: account, + proxyURL: proxyURL, + accessToken: accessToken, + action: action, + body: retryGeminiBody, + quotaScope: quotaScope, + c: c, + httpUpstream: s.httpUpstream, + settingService: s.settingService, + handleError: s.handleUpstreamError, + }) if retryErr != nil { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, @@ -757,6 +849,7 @@ urlFallbackLoop: continue } + retryResp := retryResult.resp if retryResp.StatusCode < 400 { _ = resp.Body.Close() resp = retryResp @@ -766,6 +859,13 @@ urlFallbackLoop: retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) _ = retryResp.Body.Close() + if retryResp.StatusCode == http.StatusTooManyRequests { + retryBaseURL := "" + if retryResp.Request != nil && retryResp.Request.URL != nil { + retryBaseURL = retryResp.Request.URL.Scheme + "://" + retryResp.Request.URL.Host + } + log.Printf("%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 200)) + } kind := "signature_retry" if strings.TrimSpace(stage.name) != "" { kind = "signature_retry_" + strings.ReplaceAll(stage.name, "+", "_") @@ -920,143 +1020,6 @@ func extractAntigravityErrorMessage(body []byte) string { return "" } -// cleanCacheControlFromGeminiJSON removes cache_control from Gemini JSON (emergency fix) -// This should not be needed if transformation is correct, but serves as a safety net -func cleanCacheControlFromGeminiJSON(body []byte) []byte { - // Try a more robust approach: parse and clean - var data map[string]any - if err := json.Unmarshal(body, &data); err != nil { - log.Printf("[Antigravity] Failed to parse Gemini JSON for cache_control cleaning: %v", err) - return body - } - - cleaned := removeCacheControlFromAny(data) - if !cleaned { - return body - } - - if result, err := json.Marshal(data); err == nil { - log.Printf("[Antigravity] Successfully cleaned cache_control from Gemini JSON") - return result - } - - return body -} - -// removeCacheControlFromAny recursively removes cache_control fields -func removeCacheControlFromAny(v any) bool { - cleaned := false - - switch val := v.(type) { - case map[string]any: - for k, child := range val { - if k == "cache_control" { - delete(val, k) - cleaned = true - } else if removeCacheControlFromAny(child) { - cleaned = true - } - } - case []any: - for _, item := range val { - if removeCacheControlFromAny(item) { - cleaned = true - } - } - } - - return cleaned -} - -// sanitizeThinkingBlocks cleans cache_control and flattens history thinking blocks -// Thinking blocks do NOT support cache_control field (Anthropic API/Vertex AI requirement) -// Additionally, history thinking blocks are flattened to text to avoid upstream validation errors -func sanitizeThinkingBlocks(req *antigravity.ClaudeRequest) { - if req == nil { - return - } - - log.Printf("[Antigravity] sanitizeThinkingBlocks: processing request with %d messages", len(req.Messages)) - - // Clean system blocks - if len(req.System) > 0 { - var systemBlocks []map[string]any - if err := json.Unmarshal(req.System, &systemBlocks); err == nil { - for i := range systemBlocks { - if blockType, _ := systemBlocks[i]["type"].(string); blockType == "thinking" || systemBlocks[i]["thinking"] != nil { - if removeCacheControlFromAny(systemBlocks[i]) { - log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in system[%d]", i) - } - } - } - // Marshal back - if cleaned, err := json.Marshal(systemBlocks); err == nil { - req.System = cleaned - } - } - } - - // Clean message content blocks and flatten history - lastMsgIdx := len(req.Messages) - 1 - for msgIdx := range req.Messages { - raw := req.Messages[msgIdx].Content - if len(raw) == 0 { - continue - } - - // Try to parse as blocks array - var blocks []map[string]any - if err := json.Unmarshal(raw, &blocks); err != nil { - continue - } - - cleaned := false - for blockIdx := range blocks { - blockType, _ := blocks[blockIdx]["type"].(string) - - // Check for thinking blocks (typed or untyped) - if blockType == "thinking" || blocks[blockIdx]["thinking"] != nil { - // 1. Clean cache_control - if removeCacheControlFromAny(blocks[blockIdx]) { - log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in messages[%d].content[%d]", msgIdx, blockIdx) - cleaned = true - } - - // 2. Flatten to text if it's a history message (not the last one) - if msgIdx < lastMsgIdx { - log.Printf("[Antigravity] Flattening history thinking block to text at messages[%d].content[%d]", msgIdx, blockIdx) - - // Extract thinking content - var textContent string - if t, ok := blocks[blockIdx]["thinking"].(string); ok { - textContent = t - } else { - // Fallback for non-string content (marshal it) - if b, err := json.Marshal(blocks[blockIdx]["thinking"]); err == nil { - textContent = string(b) - } - } - - // Convert to text block - blocks[blockIdx]["type"] = "text" - blocks[blockIdx]["text"] = textContent - delete(blocks[blockIdx], "thinking") - delete(blocks[blockIdx], "signature") - delete(blocks[blockIdx], "cache_control") // Ensure it's gone - cleaned = true - } - } - } - - // Marshal back if modified - if cleaned { - if marshaled, err := json.Marshal(blocks); err == nil { - req.Messages[msgIdx].Content = marshaled - } - } - } -} - // stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request. // This preserves the thinking content while avoiding signature validation errors. // Note: redacted_thinking blocks are removed because they cannot be converted to text. @@ -1352,138 +1315,25 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回 upstreamAction := "streamGenerateContent" - // URL fallback 循环 - availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() - if len(availableURLs) == 0 { - availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 - } - - // 重试循环 - var resp *http.Response -urlFallbackLoop: - for urlIdx, baseURL := range availableURLs { - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - // 检查 context 是否已取消(客户端断开连接) - select { - case <-ctx.Done(): - log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) - return nil, ctx.Err() - default: - } - - upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, upstreamAction, accessToken, wrappedBody) - if err != nil { - return nil, err - } - - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - if err != nil { - safeErr := sanitizeUpstreamErrorMessage(err.Error()) - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: 0, - Kind: "request_error", - Message: safeErr, - }) - // 检查是否应触发 URL 降级 - if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { - antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) - log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1]) - continue urlFallbackLoop - } - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() - } - continue - } - log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) - setOpsUpstreamError(c, 0, safeErr, "") - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") - } - - // 检查是否应触发 URL 降级(仅 429) - if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody - maxBytes := 2048 - if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { - maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - } - upstreamDetail := "" - if logBody { - upstreamDetail = truncateString(string(respBody), maxBytes) - } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "retry", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) - log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200)) - continue urlFallbackLoop - } - - if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - if attempt < antigravityMaxRetries { - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody - maxBytes := 2048 - if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { - maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - } - upstreamDetail := "" - if logBody { - upstreamDetail = truncateString(string(respBody), maxBytes) - } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "retry", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() - } - continue - } - // 所有重试都失败,标记限流状态 - if resp.StatusCode == 429 { - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) - } - resp = &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: io.NopCloser(bytes.NewReader(respBody)), - } - break urlFallbackLoop - } - - break urlFallbackLoop - } + // 执行带重试的请求 + result, err := antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctx, + prefix: prefix, + account: account, + proxyURL: proxyURL, + accessToken: accessToken, + action: upstreamAction, + body: wrappedBody, + quotaScope: quotaScope, + c: c, + httpUpstream: s.httpUpstream, + settingService: s.settingService, + handleError: s.handleUpstreamError, + }) + if err != nil { + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") } + resp := result.resp defer func() { if resp != nil && resp.Body != nil { _ = resp.Body.Close() @@ -1525,8 +1375,6 @@ urlFallbackLoop: goto handleSuccess } - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) - requestID := resp.Header.Get("x-request-id") if requestID != "" { c.Header("x-request-id", requestID) @@ -1537,6 +1385,7 @@ urlFallbackLoop: if unwrapErr != nil || len(unwrappedForOps) == 0 { unwrappedForOps = respBody } + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) @@ -1581,6 +1430,7 @@ urlFallbackLoop: Message: upstreamMsg, Detail: upstreamDetail, }) + log.Printf("[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500)) c.Data(resp.StatusCode, contentType, unwrappedForOps) return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) } @@ -1637,15 +1487,6 @@ handleSuccess: }, nil } -func (s *AntigravityGatewayService) shouldRetryUpstreamError(statusCode int) bool { - switch statusCode { - case 429, 500, 502, 503, 504, 529: - return true - default: - return false - } -} - func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) bool { switch statusCode { case 401, 403, 429, 529: @@ -1679,33 +1520,48 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { } } +func antigravityUseScopeRateLimit() bool { + v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityScopeRateLimitEnv))) + return v == "1" || v == "true" || v == "yes" || v == "on" +} + func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) { // 429 使用 Gemini 格式解析(从 body 解析重置时间) if statusCode == 429 { + useScopeLimit := antigravityUseScopeRateLimit() && quotaScope != "" resetAt := ParseGeminiRateLimitResetTime(body) if resetAt == nil { - // 解析失败:Gemini 有重试时间用 5 分钟,Claude 没有用 1 分钟 - defaultDur := 1 * time.Minute - if bytes.Contains(body, []byte("Please retry in")) || bytes.Contains(body, []byte("retryDelay")) { - defaultDur = 5 * time.Minute + // 解析失败:使用配置的 fallback 时间,直接限流整个账户 + fallbackMinutes := 5 + if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 { + fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes } + defaultDur := time.Duration(fallbackMinutes) * time.Minute ra := time.Now().Add(defaultDur) - log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur) - if quotaScope == "" { - return - } - if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil { - log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err) + if useScopeLimit { + log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur) + if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil { + log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err) + } + } else { + log.Printf("%s status=429 rate_limited account=%d reset_in=%v (fallback)", prefix, account.ID, defaultDur) + if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil { + log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) + } } return } resetTime := time.Unix(*resetAt, 0) - log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) - if quotaScope == "" { - return - } - if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil { - log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err) + if useScopeLimit { + log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) + if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil { + log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err) + } + } else { + log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v", prefix, account.ID, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) + if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil { + log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) + } } return } @@ -1884,7 +1740,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context } // handleGeminiStreamToNonStreaming 读取上游流式响应,合并为非流式响应返回给客户端 -// Gemini 流式响应中每个 chunk 都包含累积的完整文本,只需保留最后一个有效响应 +// Gemini 流式响应是增量的,需要累积所有 chunk 的内容 func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) { scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -1897,6 +1753,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont var firstTokenMs *int var last map[string]any var lastWithParts map[string]any + var collectedImageParts []map[string]any // 收集所有包含图片的 parts + var collectedTextParts []string // 收集所有文本片段 type scanEvent struct { line string @@ -1999,6 +1857,16 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont // 保留最后一个有 parts 的响应 if parts := extractGeminiParts(parsed); len(parts) > 0 { lastWithParts = parsed + // 收集包含图片和文本的 parts + for _, part := range parts { + if inlineData, ok := part["inlineData"].(map[string]any); ok { + collectedImageParts = append(collectedImageParts, part) + _ = inlineData // 避免 unused 警告 + } + if text, ok := part["text"].(string); ok && text != "" { + collectedTextParts = append(collectedTextParts, text) + } + } } case <-intervalCh: @@ -2020,6 +1888,16 @@ returnResponse: log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received") } + // 如果收集到了图片 parts,需要合并到最终响应中 + if len(collectedImageParts) > 0 { + finalResponse = mergeImagePartsToResponse(finalResponse, collectedImageParts) + } + + // 如果收集到了文本,需要合并到最终响应中 + if len(collectedTextParts) > 0 { + finalResponse = mergeTextPartsToResponse(finalResponse, collectedTextParts) + } + respBody, err := json.Marshal(finalResponse) if err != nil { return nil, fmt.Errorf("failed to marshal response: %w", err) @@ -2029,6 +1907,115 @@ returnResponse: return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil } +// getOrCreateGeminiParts 获取 Gemini 响应的 parts 结构,返回深拷贝和更新回调 +func getOrCreateGeminiParts(response map[string]any) (result map[string]any, existingParts []any, setParts func([]any)) { + // 深拷贝 response + result = make(map[string]any) + for k, v := range response { + result[k] = v + } + + // 获取或创建 candidates + candidates, ok := result["candidates"].([]any) + if !ok || len(candidates) == 0 { + candidates = []any{map[string]any{}} + } + + // 获取第一个 candidate + candidate, ok := candidates[0].(map[string]any) + if !ok { + candidate = make(map[string]any) + candidates[0] = candidate + } + + // 获取或创建 content + content, ok := candidate["content"].(map[string]any) + if !ok { + content = map[string]any{"role": "model"} + candidate["content"] = content + } + + // 获取现有 parts + existingParts, ok = content["parts"].([]any) + if !ok { + existingParts = []any{} + } + + // 返回更新回调 + setParts = func(newParts []any) { + content["parts"] = newParts + result["candidates"] = candidates + } + + return result, existingParts, setParts +} + +// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中 +func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any { + if len(imageParts) == 0 { + return response + } + + result, existingParts, setParts := getOrCreateGeminiParts(response) + + // 检查现有 parts 中是否已经有图片 + for _, p := range existingParts { + if pm, ok := p.(map[string]any); ok { + if _, hasInline := pm["inlineData"]; hasInline { + return result // 已有图片,不重复添加 + } + } + } + + // 添加收集到的图片 parts + for _, imgPart := range imageParts { + existingParts = append(existingParts, imgPart) + } + setParts(existingParts) + return result +} + +// mergeTextPartsToResponse 将收集到的文本合并到 Gemini 响应中 +func mergeTextPartsToResponse(response map[string]any, textParts []string) map[string]any { + if len(textParts) == 0 { + return response + } + + mergedText := strings.Join(textParts, "") + result, existingParts, setParts := getOrCreateGeminiParts(response) + + // 查找并更新第一个 text part,或创建新的 + newParts := make([]any, 0, len(existingParts)+1) + textUpdated := false + + for _, p := range existingParts { + pm, ok := p.(map[string]any) + if !ok { + newParts = append(newParts, p) + continue + } + if _, hasText := pm["text"]; hasText && !textUpdated { + // 用累积的文本替换 + newPart := make(map[string]any) + for k, v := range pm { + newPart[k] = v + } + newPart["text"] = mergedText + newParts = append(newParts, newPart) + textUpdated = true + } else { + newParts = append(newParts, pm) + } + } + + if !textUpdated { + newParts = append([]any{map[string]any{"text": mergedText}}, newParts...) + } + + setParts(newParts) + return result +} + func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, errType, message string) error { c.JSON(status, gin.H{ "type": "error", diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go index ecf0a553..52293cd5 100644 --- a/backend/internal/service/antigravity_oauth_service.go +++ b/backend/internal/service/antigravity_oauth_service.go @@ -82,13 +82,14 @@ type AntigravityExchangeCodeInput struct { // AntigravityTokenInfo token 信息 type AntigravityTokenInfo struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int64 `json:"expires_in"` - ExpiresAt int64 `json:"expires_at"` - TokenType string `json:"token_type"` - Email string `json:"email,omitempty"` - ProjectID string `json:"project_id,omitempty"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + ExpiresAt int64 `json:"expires_at"` + TokenType string `json:"token_type"` + Email string `json:"email,omitempty"` + ProjectID string `json:"project_id,omitempty"` + ProjectIDMissing bool `json:"-"` // LoadCodeAssist 未返回 project_id } // ExchangeCode 用 authorization code 交换 token @@ -149,12 +150,6 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig result.ProjectID = loadResp.CloudAICompanionProject } - // 兜底:随机生成 project_id - if result.ProjectID == "" { - result.ProjectID = antigravity.GenerateMockProjectID() - fmt.Printf("[AntigravityOAuth] 使用随机生成的 project_id: %s\n", result.ProjectID) - } - return result, nil } @@ -236,16 +231,24 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou return nil, err } - // 保留原有的 project_id 和 email - existingProjectID := strings.TrimSpace(account.GetCredential("project_id")) - if existingProjectID != "" { - tokenInfo.ProjectID = existingProjectID - } + // 保留原有的 email existingEmail := strings.TrimSpace(account.GetCredential("email")) if existingEmail != "" { tokenInfo.Email = existingEmail } + // 每次刷新都调用 LoadCodeAssist 获取 project_id + client := antigravity.NewClient(proxyURL) + loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken) + if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" { + // LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失 + existingProjectID := strings.TrimSpace(account.GetCredential("project_id")) + tokenInfo.ProjectID = existingProjectID + tokenInfo.ProjectIDMissing = true + } else { + tokenInfo.ProjectID = loadResp.CloudAICompanionProject + } + return tokenInfo, nil } diff --git a/backend/internal/service/antigravity_quota_fetcher.go b/backend/internal/service/antigravity_quota_fetcher.go index c9024e33..07eb563d 100644 --- a/backend/internal/service/antigravity_quota_fetcher.go +++ b/backend/internal/service/antigravity_quota_fetcher.go @@ -31,11 +31,6 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou accessToken := account.GetCredential("access_token") projectID := account.GetCredential("project_id") - // 如果没有 project_id,生成一个随机的 - if projectID == "" { - projectID = antigravity.GenerateMockProjectID() - } - client := antigravity.NewClient(proxyURL) // 调用 API 获取配额 diff --git a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go new file mode 100644 index 00000000..53ec6fdf --- /dev/null +++ b/backend/internal/service/antigravity_rate_limit_test.go @@ -0,0 +1,190 @@ +//go:build unit + +package service + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/stretchr/testify/require" +) + +type stubAntigravityUpstream struct { + firstBase string + secondBase string + calls []string +} + +func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + url := req.URL.String() + s.calls = append(s.calls, url) + if strings.HasPrefix(url, s.firstBase) { + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Resource has been exhausted"}}`)), + }, nil + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader("ok")), + }, nil +} + +func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + return s.Do(req, proxyURL, accountID, accountConcurrency) +} + +type scopeLimitCall struct { + accountID int64 + scope AntigravityQuotaScope + resetAt time.Time +} + +type rateLimitCall struct { + accountID int64 + resetAt time.Time +} + +type stubAntigravityAccountRepo struct { + AccountRepository + scopeCalls []scopeLimitCall + rateCalls []rateLimitCall +} + +func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { + s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt}) + return nil +} + +func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { + s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt}) + return nil +} + +func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { + oldBaseURLs := append([]string(nil), antigravity.BaseURLs...) + oldAvailability := antigravity.DefaultURLAvailability + defer func() { + antigravity.BaseURLs = oldBaseURLs + antigravity.DefaultURLAvailability = oldAvailability + }() + + base1 := "https://ag-1.test" + base2 := "https://ag-2.test" + antigravity.BaseURLs = []string{base1, base2} + antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute) + + upstream := &stubAntigravityUpstream{firstBase: base1, secondBase: base2} + account := &Account{ + ID: 1, + Name: "acc-1", + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + } + + var handleErrorCalled bool + result, err := antigravityRetryLoop(antigravityRetryLoopParams{ + prefix: "[test]", + ctx: context.Background(), + account: account, + proxyURL: "", + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + quotaScope: AntigravityQuotaScopeClaude, + httpUpstream: upstream, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) { + handleErrorCalled = true + }, + }) + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.resp) + defer func() { _ = result.resp.Body.Close() }() + require.Equal(t, http.StatusOK, result.resp.StatusCode) + require.False(t, handleErrorCalled) + require.Len(t, upstream.calls, 2) + require.True(t, strings.HasPrefix(upstream.calls[0], base1)) + require.True(t, strings.HasPrefix(upstream.calls[1], base2)) + + available := antigravity.DefaultURLAvailability.GetAvailableURLs() + require.NotEmpty(t, available) + require.Equal(t, base2, available[0]) +} + +func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T) { + t.Setenv(antigravityScopeRateLimitEnv, "true") + repo := &stubAntigravityAccountRepo{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity} + + body := buildGeminiRateLimitBody("3s") + svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude) + + require.Len(t, repo.scopeCalls, 1) + require.Empty(t, repo.rateCalls) + call := repo.scopeCalls[0] + require.Equal(t, account.ID, call.accountID) + require.Equal(t, AntigravityQuotaScopeClaude, call.scope) + require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second) +} + +func TestAntigravityHandleUpstreamError_UsesAccountLimitWhenScopeDisabled(t *testing.T) { + t.Setenv(antigravityScopeRateLimitEnv, "false") + repo := &stubAntigravityAccountRepo{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 10, Name: "acc-10", Platform: PlatformAntigravity} + + body := buildGeminiRateLimitBody("2s") + svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude) + + require.Len(t, repo.rateCalls, 1) + require.Empty(t, repo.scopeCalls) + call := repo.rateCalls[0] + require.Equal(t, account.ID, call.accountID) + require.WithinDuration(t, time.Now().Add(2*time.Second), call.resetAt, 2*time.Second) +} + +func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) { + now := time.Now() + future := now.Add(10 * time.Minute) + + account := &Account{ + ID: 1, + Name: "acc", + Platform: PlatformAntigravity, + Status: StatusActive, + Schedulable: true, + } + + account.RateLimitResetAt = &future + require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5")) + require.False(t, account.IsSchedulableForModel("gemini-3-flash")) + + account.RateLimitResetAt = nil + account.Extra = map[string]any{ + antigravityQuotaScopesKey: map[string]any{ + "claude": map[string]any{ + "rate_limit_reset_at": future.Format(time.RFC3339), + }, + }, + } + + require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5")) + require.True(t, account.IsSchedulableForModel("gemini-3-flash")) +} + +func buildGeminiRateLimitBody(delay string) []byte { + return []byte(fmt.Sprintf(`{"error":{"message":"too many requests","details":[{"metadata":{"quotaResetDelay":%q}}]}}`, delay)) +} diff --git a/backend/internal/service/antigravity_token_refresher.go b/backend/internal/service/antigravity_token_refresher.go index 9dd4463f..a07c86e6 100644 --- a/backend/internal/service/antigravity_token_refresher.go +++ b/backend/internal/service/antigravity_token_refresher.go @@ -61,5 +61,10 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun } } + // 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记 + if tokenInfo.ProjectIDMissing { + return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id,可能无法使用Antigravity") + } + return newCredentials, nil } diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 521f1da5..eb5c7534 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -94,6 +94,20 @@ func (s *APIKeyService) initAuthCache(cfg *config.Config) { s.authCacheL1 = cache } +// StartAuthCacheInvalidationSubscriber starts the Pub/Sub subscriber for L1 cache invalidation. +// This should be called after the service is fully initialized. +func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context) { + if s.cache == nil || s.authCacheL1 == nil { + return + } + if err := s.cache.SubscribeAuthCacheInvalidation(ctx, func(cacheKey string) { + s.authCacheL1.Del(cacheKey) + }); err != nil { + // Log but don't fail - L1 cache will still work, just without cross-instance invalidation + println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error()) + } +} + func (s *APIKeyService) authCacheKey(key string) string { sum := sha256.Sum256([]byte(key)) return hex.EncodeToString(sum[:]) @@ -149,6 +163,8 @@ func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) { return } _ = s.cache.DeleteAuthCache(ctx, cacheKey) + // Publish invalidation message to other instances + _ = s.cache.PublishAuthCacheInvalidation(ctx, cacheKey) } func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) { diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index ecc570c7..ef1ff990 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -65,6 +65,10 @@ type APIKeyCache interface { GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error DeleteAuthCache(ctx context.Context, key string) error + + // Pub/Sub for L1 cache invalidation across instances + PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error + SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error } // APIKeyAuthCacheInvalidator 提供认证缓存失效能力 diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index 5f2d69c4..c5e9cd47 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -142,6 +142,14 @@ func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error { return nil } +func (s *authCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error { + return nil +} + +func (s *authCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error { + return nil +} + func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { cache := &authCacheStub{} repo := &authRepoStub{ diff --git a/backend/internal/service/api_key_service_delete_test.go b/backend/internal/service/api_key_service_delete_test.go index 32ae884e..092b7fce 100644 --- a/backend/internal/service/api_key_service_delete_test.go +++ b/backend/internal/service/api_key_service_delete_test.go @@ -168,6 +168,14 @@ func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error return nil } +func (s *apiKeyCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error { + return nil +} + +func (s *apiKeyCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error { + return nil +} + // TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。 // 预期行为: // - GetKeyAndOwnerID 返回所有者 ID 为 1 diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index a38c34cd..9850cbf0 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) BatchUpdateLastUsed(ctx context.Context, up func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, errorMsg string) error { return nil } +func (m *mockAccountRepoForPlatform) ClearError(ctx context.Context, id int64) error { + return nil +} func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { return nil } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index a08f3e48..9565da29 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -11,6 +11,8 @@ import ( "fmt" "io" "log" + "log/slog" + mathrand "math/rand" "net/http" "os" "regexp" @@ -445,11 +447,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context } // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. -// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制) +// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) { + // 调试日志:记录调度入口参数 + excludedIDsList := make([]int64, 0, len(excludedIDs)) + for id := range excludedIDs { + excludedIDsList = append(excludedIDsList, id) + } + slog.Debug("account_scheduling_starting", + "group_id", derefGroupID(groupID), + "model", requestedModel, + "session", shortSessionHash(sessionHash), + "excluded_ids", excludedIDsList) + cfg := s.schedulingConfig() - // 提取会话 UUID(用于会话数量限制) - sessionUUID := extractSessionUUID(metadataUserID) var stickyAccountID int64 if sessionHash != "" && s.cache != nil { @@ -475,41 +486,63 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } if s.concurrencyService == nil || !cfg.LoadBatchEnabled { - account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) - if err != nil { - return nil, err + // 复制排除列表,用于会话限制拒绝时的重试 + localExcluded := make(map[int64]struct{}) + for k, v := range excludedIDs { + localExcluded[k] = v } - result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) - if err == nil && result.Acquired { - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) - if waitingCount < cfg.StickySessionMaxWaiting { + + for { + account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, localExcluded) + if err != nil { + return nil, err + } + + result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) + if err == nil && result.Acquired { + // 获取槽位后检查会话限制(使用 sessionHash 作为会话标识符) + if !s.checkAndRegisterSession(ctx, account, sessionHash) { + result.ReleaseFunc() // 释放槽位 + localExcluded[account.ID] = struct{}{} // 排除此账号 + continue // 重新选择 + } return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, }, nil } + + // 对于等待计划的情况,也需要先检查会话限制 + if !s.checkAndRegisterSession(ctx, account, sessionHash) { + localExcluded[account.ID] = struct{}{} + continue + } + + if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + } + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil } - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil } platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group) @@ -625,7 +658,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) { + if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { result.ReleaseFunc() // 释放槽位 // 继续到负载感知选择 } else { @@ -643,15 +676,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: stickyAccount, - WaitPlan: &AccountWaitPlan{ - AccountID: stickyAccountID, - MaxConcurrency: stickyAccount.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + // 会话数量限制检查(等待计划也需要占用会话配额) + if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { + // 会话限制已满,继续到负载感知选择 + } else { + return &AccountSelectionResult{ + Account: stickyAccount, + WaitPlan: &AccountWaitPlan{ + AccountID: stickyAccountID, + MaxConcurrency: stickyAccount.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } } // 粘性账号槽位满且等待队列已满,继续使用负载感知选择 } @@ -714,7 +752,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) { + if !s.checkAndRegisterSession(ctx, item.account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 continue } @@ -732,20 +770,26 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - // 5. 所有路由账号槽位满,返回等待计划(选择负载最低的) - acc := routingAvailable[0].account - if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), acc.ID) + // 5. 所有路由账号槽位满,尝试返回等待计划(选择负载最低的) + // 遍历找到第一个满足会话限制的账号 + for _, item := range routingAvailable { + if !s.checkAndRegisterSession(ctx, item.account, sessionHash) { + continue // 会话限制已满,尝试下一个 + } + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) + } + return &AccountSelectionResult{ + Account: item.account, + WaitPlan: &AccountWaitPlan{ + AccountID: item.account.ID, + MaxConcurrency: item.account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil } - return &AccountSelectionResult{ - Account: acc, - WaitPlan: &AccountWaitPlan{ - AccountID: acc.ID, - MaxConcurrency: acc.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + // 所有路由账号会话限制都已满,继续到 Layer 2 回退 } // 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退 log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel) @@ -773,7 +817,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if err == nil && result.Acquired { // 会话数量限制检查 // Session count limit check - if !s.checkAndRegisterSession(ctx, account, sessionUUID) { + if !s.checkAndRegisterSession(ctx, account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续到 Layer 2 } else { _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) @@ -787,15 +831,22 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + // 会话数量限制检查(等待计划也需要占用会话配额) + // Session count limit check (wait plan also requires session quota) + if !s.checkAndRegisterSession(ctx, account, sessionHash) { + // 会话限制已满,继续到 Layer 2 + // Session limit full, continue to Layer 2 + } else { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } } } } @@ -845,7 +896,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) if err != nil { - if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok { + if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { return result, nil } } else { @@ -895,7 +946,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) { + if !s.checkAndRegisterSession(ctx, item.account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 continue } @@ -913,8 +964,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } // ============ Layer 3: 兜底排队 ============ - sortAccountsByPriorityAndLastUsed(candidates, preferOAuth) + s.sortCandidatesForFallback(candidates, preferOAuth, cfg.FallbackSelectionMode) for _, acc := range candidates { + // 会话数量限制检查(等待计划也需要占用会话配额) + if !s.checkAndRegisterSession(ctx, acc, sessionHash) { + continue // 会话限制已满,尝试下一个账号 + } return &AccountSelectionResult{ Account: acc, WaitPlan: &AccountWaitPlan{ @@ -928,7 +983,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro return nil, errors.New("no available accounts") } -func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) { +func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { ordered := append([]*Account(nil), candidates...) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) @@ -936,7 +991,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, acc, sessionUUID) { + if !s.checkAndRegisterSession(ctx, acc, sessionHash) { result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 continue } @@ -1093,7 +1148,24 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { if s.schedulerSnapshot != nil { - return s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + if err == nil { + slog.Debug("account_scheduling_list_snapshot", + "group_id", derefGroupID(groupID), + "platform", platform, + "use_mixed", useMixed, + "count", len(accounts)) + for _, acc := range accounts { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } + } + return accounts, useMixed, err } useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform if useMixed { @@ -1106,6 +1178,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) } if err != nil { + slog.Debug("account_scheduling_list_failed", + "group_id", derefGroupID(groupID), + "platform", platform, + "error", err) return nil, useMixed, err } filtered := make([]Account, 0, len(accounts)) @@ -1115,6 +1191,20 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i } filtered = append(filtered, acc) } + slog.Debug("account_scheduling_list_mixed", + "group_id", derefGroupID(groupID), + "platform", platform, + "raw_count", len(accounts), + "filtered_count", len(filtered)) + for _, acc := range filtered { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } return filtered, useMixed, nil } @@ -1129,8 +1219,25 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) } if err != nil { + slog.Debug("account_scheduling_list_failed", + "group_id", derefGroupID(groupID), + "platform", platform, + "error", err) return nil, useMixed, err } + slog.Debug("account_scheduling_list_single", + "group_id", derefGroupID(groupID), + "platform", platform, + "count", len(accounts)) + for _, acc := range accounts { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } return accounts, useMixed, nil } @@ -1196,12 +1303,8 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, // 缓存未命中,从数据库查询 { - var startTime time.Time - if account.SessionWindowStart != nil { - startTime = *account.SessionWindowStart - } else { - startTime = time.Now().Add(-5 * time.Hour) - } + // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况) + startTime := account.GetCurrentWindowStartTime() stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) if err != nil { @@ -1234,15 +1337,16 @@ checkSchedulability: // checkAndRegisterSession 检查并注册会话,用于会话数量限制 // 仅适用于 Anthropic OAuth/SetupToken 账号 +// sessionID: 会话标识符(使用粘性会话的 hash) // 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话) -func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool { +func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionID string) bool { // 只检查 Anthropic OAuth/SetupToken 账号 if !account.IsAnthropicOAuthOrSetupToken() { return true } maxSessions := account.GetMaxSessions() - if maxSessions <= 0 || sessionUUID == "" { + if maxSessions <= 0 || sessionID == "" { return true // 未启用会话限制或无会话ID } @@ -1252,7 +1356,7 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute - allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout) + allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionID, maxSessions, idleTimeout) if err != nil { // 失败开放:缓存错误时允许通过 return true @@ -1260,18 +1364,6 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A return allowed } -// extractSessionUUID 从 metadata.user_id 中提取会话 UUID -// 格式: user_{64位hex}_account__session_{uuid} -func extractSessionUUID(metadataUserID string) string { - if metadataUserID == "" { - return "" - } - if match := sessionIDRegex.FindStringSubmatch(metadataUserID); len(match) > 1 { - return match[1] - } - return "" -} - func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { if s.schedulerSnapshot != nil { return s.schedulerSnapshot.GetAccount(ctx, accountID) @@ -1301,6 +1393,56 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { }) } +// sortCandidatesForFallback 根据配置选择排序策略 +// mode: "last_used"(按最后使用时间) 或 "random"(随机) +func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) { + if mode == "random" { + // 先按优先级排序,然后在同优先级内随机打乱 + sortAccountsByPriorityOnly(accounts, preferOAuth) + shuffleWithinPriority(accounts) + } else { + // 默认按最后使用时间排序 + sortAccountsByPriorityAndLastUsed(accounts, preferOAuth) + } +} + +// sortAccountsByPriorityOnly 仅按优先级排序 +func sortAccountsByPriorityOnly(accounts []*Account, preferOAuth bool) { + sort.SliceStable(accounts, func(i, j int) bool { + a, b := accounts[i], accounts[j] + if a.Priority != b.Priority { + return a.Priority < b.Priority + } + if preferOAuth && a.Type != b.Type { + return a.Type == AccountTypeOAuth + } + return false + }) +} + +// shuffleWithinPriority 在同优先级内随机打乱顺序 +func shuffleWithinPriority(accounts []*Account) { + if len(accounts) <= 1 { + return + } + r := mathrand.New(mathrand.NewSource(time.Now().UnixNano())) + start := 0 + for start < len(accounts) { + priority := accounts[start].Priority + end := start + 1 + for end < len(accounts) && accounts[end].Priority == priority { + end++ + } + // 对 [start, end) 范围内的账户随机打乱 + if end-start > 1 { + r.Shuffle(end-start, func(i, j int) { + accounts[start+i], accounts[start+j] = accounts[start+j], accounts[start+i] + }) + } + start = end + } +} + // selectAccountForModelWithPlatform 选择单平台账户(完全隔离) func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { preferOAuth := platform == PlatformGemini @@ -2158,6 +2300,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A proxyURL = account.Proxy.URL() } + // 调试日志:记录即将转发的账号信息 + log.Printf("[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s", + account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL) + // 重试循环 var resp *http.Response retryStart := time.Now() @@ -2172,7 +2318,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 发送请求 - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if err != nil { if resp != nil && resp.Body != nil { _ = resp.Body.Close() @@ -2246,7 +2392,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A filteredBody := FilterThinkingBlocksForRetry(body) retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel) if buildErr == nil { - retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) + retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { if retryResp.StatusCode < 400 { log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID) @@ -2278,7 +2424,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel) if buildErr2 == nil { - retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency) + retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr2 == nil { resp = retryResp2 break @@ -2393,6 +2539,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A _ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody)) + // 调试日志:打印重试耗尽后的错误响应 + log.Printf("[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) + s.handleRetryExhaustedSideEffects(ctx, resp, account) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, @@ -2420,6 +2570,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A _ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody)) + // 调试日志:打印上游错误响应 + log.Printf("[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) + s.handleFailoverSideEffects(ctx, resp, account) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, @@ -2549,9 +2703,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex fingerprint = fp // 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid) + // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 accountUUID := account.GetExtraString("account_uuid") if accountUUID != "" && fp.ClientID != "" { - if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { + if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { body = newBody } } @@ -2770,6 +2925,10 @@ func extractUpstreamErrorMessage(body []byte) string { func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + // 调试日志:打印上游错误响应 + log.Printf("[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000)) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) @@ -3478,7 +3637,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 发送请求 - resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if err != nil { setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "") s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") @@ -3500,7 +3659,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, filteredBody := FilterThinkingBlocksForRetry(body) retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel) if buildErr == nil { - retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) + retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { resp = retryResp respBody, err = io.ReadAll(resp.Body) @@ -3578,12 +3737,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } // OAuth 账号:应用统一指纹和重写 userID + // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 if account.IsOAuth() && s.identityService != nil { fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) if err == nil { accountUUID := account.GetExtraString("account_uuid") if accountUUID != "" && fp.ClientID != "" { - if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { + if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { body = newBody } } diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 0ddd72b8..c63a020c 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -90,6 +90,9 @@ func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, upda func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error { return nil } +func (m *mockAccountRepoForGemini) ClearError(ctx context.Context, id int64) error { + return nil +} func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { return nil } diff --git a/backend/internal/service/http_upstream_port.go b/backend/internal/service/http_upstream_port.go index 9357f763..0e4cfbec 100644 --- a/backend/internal/service/http_upstream_port.go +++ b/backend/internal/service/http_upstream_port.go @@ -10,6 +10,7 @@ import "net/http" // - 支持可选代理配置 // - 支持账户级连接池隔离 // - 实现类负责连接池管理和复用 +// - 支持可选的 TLS 指纹伪装 type HTTPUpstream interface { // Do 执行 HTTP 请求 // @@ -27,4 +28,28 @@ type HTTPUpstream interface { // - 调用方必须关闭 resp.Body,否则会导致连接泄漏 // - 响应体可能已被包装以跟踪请求生命周期 Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) + + // DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求 + // + // 参数: + // - req: HTTP 请求对象,由调用方构建 + // - proxyURL: 代理服务器地址,空字符串表示直连 + // - accountID: 账户 ID,用于连接池隔离和 TLS 指纹模板选择 + // - accountConcurrency: 账户并发限制,用于动态调整连接池大小 + // - enableTLSFingerprint: 是否启用 TLS 指纹伪装 + // + // 返回: + // - *http.Response: HTTP 响应,调用方必须关闭 Body + // - error: 请求错误(网络错误、超时等) + // + // TLS 指纹说明: + // - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹 + // - TLS 指纹模板根据 accountID % len(profiles) 自动选择 + // - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景 + // - 如果 enableTLSFingerprint=false,行为与 Do 方法相同 + // + // 注意: + // - 调用方必须关闭 resp.Body,否则会导致连接泄漏 + // - TLS 指纹客户端与普通客户端使用不同的缓存键,互不影响 + DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) } diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index 1ffa8057..e2e723b0 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -8,9 +8,11 @@ import ( "encoding/json" "fmt" "log" + "log/slog" "net/http" "regexp" "strconv" + "strings" "time" ) @@ -49,6 +51,13 @@ type Fingerprint struct { type IdentityCache interface { GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error) SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error + // GetMaskedSessionID 获取固定的会话ID(用于会话ID伪装功能) + // 返回的 sessionID 是一个 UUID 格式的字符串 + // 如果不存在或已过期(15分钟无请求),返回空字符串 + GetMaskedSessionID(ctx context.Context, accountID int64) (string, error) + // SetMaskedSessionID 设置固定的会话ID,TTL 为 15 分钟 + // 每次调用都会刷新 TTL + SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error } // IdentityService 管理OAuth账号的请求身份指纹 @@ -203,6 +212,94 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI return json.Marshal(reqMap) } +// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装 +// 如果账号启用了会话ID伪装(session_id_masking_enabled), +// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变) +func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) { + // 先执行常规的 RewriteUserID 逻辑 + newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID) + if err != nil { + return newBody, err + } + + // 检查是否启用会话ID伪装 + if !account.IsSessionIDMaskingEnabled() { + return newBody, nil + } + + // 解析重写后的 body,提取 user_id + var reqMap map[string]any + if err := json.Unmarshal(newBody, &reqMap); err != nil { + return newBody, nil + } + + metadata, ok := reqMap["metadata"].(map[string]any) + if !ok { + return newBody, nil + } + + userID, ok := metadata["user_id"].(string) + if !ok || userID == "" { + return newBody, nil + } + + // 查找 _session_ 的位置,替换其后的内容 + const sessionMarker = "_session_" + idx := strings.LastIndex(userID, sessionMarker) + if idx == -1 { + return newBody, nil + } + + // 获取或生成固定的伪装 session ID + maskedSessionID, err := s.cache.GetMaskedSessionID(ctx, account.ID) + if err != nil { + log.Printf("Warning: failed to get masked session ID for account %d: %v", account.ID, err) + return newBody, nil + } + + if maskedSessionID == "" { + // 首次或已过期,生成新的伪装 session ID + maskedSessionID = generateRandomUUID() + log.Printf("Generated new masked session ID for account %d: %s", account.ID, maskedSessionID) + } + + // 刷新 TTL(每次请求都刷新,保持 15 分钟有效期) + if err := s.cache.SetMaskedSessionID(ctx, account.ID, maskedSessionID); err != nil { + log.Printf("Warning: failed to set masked session ID for account %d: %v", account.ID, err) + } + + // 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容 + newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID + + slog.Debug("session_id_masking_applied", + "account_id", account.ID, + "before", userID, + "after", newUserID, + ) + + metadata["user_id"] = newUserID + reqMap["metadata"] = metadata + + return json.Marshal(reqMap) +} + +// generateRandomUUID 生成随机 UUID v4 格式字符串 +func generateRandomUUID() string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + // fallback: 使用时间戳生成 + h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano()))) + b = h[:16] + } + + // 设置 UUID v4 版本和变体位 + b[6] = (b[6] & 0x0f) | 0x40 + b[8] = (b[8] & 0x3f) | 0x80 + + return fmt.Sprintf("%x-%x-%x-%x-%x", + b[0:4], b[4:6], b[6:8], b[8:10], b[10:16]) +} + // generateClientID 生成64位十六进制客户端ID(32字节随机数) func generateClientID() string { b := make([]byte, 32) diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 2d75dd5a..41bd253c 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -73,10 +73,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc return false } - tempMatched := false + // 先尝试临时不可调度规则(401除外) + // 如果匹配成功,直接返回,不执行后续禁用逻辑 if statusCode != 401 { - tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody) + if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) { + return true + } } + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) if upstreamMsg != "" { @@ -84,6 +88,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc } switch statusCode { + case 400: + // 只有当错误信息包含 "organization has been disabled" 时才禁用 + if strings.Contains(strings.ToLower(upstreamMsg), "organization has been disabled") { + msg := "Organization disabled (400): " + upstreamMsg + s.handleAuthError(ctx, account, msg) + shouldDisable = true + } + // 其他 400 错误(如参数问题)不处理,不禁用账号 case 401: // 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新 if account.Type == AccountTypeOAuth { @@ -148,9 +160,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc } } - if tempMatched { - return true - } return shouldDisable } diff --git a/backend/internal/service/session_limit_cache.go b/backend/internal/service/session_limit_cache.go index f6f0c26a..5482d610 100644 --- a/backend/internal/service/session_limit_cache.go +++ b/backend/internal/service/session_limit_cache.go @@ -38,8 +38,9 @@ type SessionLimitCache interface { GetActiveSessionCount(ctx context.Context, accountID int64) (int, error) // GetActiveSessionCountBatch 批量获取多个账号的活跃会话数 + // idleTimeouts: 每个账号的空闲超时时间配置,key 为 accountID;若为 nil 或某账号不在其中,则使用默认超时 // 返回 map[accountID]count,查询失败的账号不在 map 中 - GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) + GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error) // IsSessionActive 检查特定会话是否活跃(未过期) IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error) diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 26cfd97d..02e7d445 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -166,11 +166,25 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ { newCredentials, err := refresher.Refresh(ctx, account) - if err == nil { - // 刷新成功,更新账号credentials + + // 如果有新凭证,先更新(即使有错误也要保存 token) + if newCredentials != nil { account.Credentials = newCredentials - if err := s.accountRepo.Update(ctx, account); err != nil { - return fmt.Errorf("failed to save credentials: %w", err) + if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil { + return fmt.Errorf("failed to save credentials: %w", saveErr) + } + } + + if err == nil { + // Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态 + if account.Platform == PlatformAntigravity && + account.Status == StatusError && + strings.Contains(account.ErrorMessage, "missing_project_id:") { + if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil { + log.Printf("[TokenRefresh] Failed to clear error status for account %d: %v", account.ID, clearErr) + } else { + log.Printf("[TokenRefresh] Account %d: cleared missing_project_id error", account.ID) + } } // 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理) if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth { @@ -230,6 +244,7 @@ func isNonRetryableRefreshError(err error) bool { "invalid_client", // 客户端配置错误 "unauthorized_client", // 客户端未授权 "access_denied", // 访问被拒绝 + "missing_project_id", // 缺少 project_id } for _, needle := range nonRetryable { if strings.Contains(msg, needle) { diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 0b9bc20c..b210286d 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -1,6 +1,7 @@ package service import ( + "context" "database/sql" "time" @@ -196,6 +197,8 @@ func ProvideOpsScheduledReportService( // ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力 func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator { + // Start Pub/Sub subscriber for L1 cache invalidation across instances + apiKeyService.StartAuthCacheInvalidationSubscriber(context.Background()) return apiKeyService } diff --git a/deploy/README.md b/deploy/README.md index f697247d..c42e7552 100644 --- a/deploy/README.md +++ b/deploy/README.md @@ -401,3 +401,58 @@ sudo systemctl status redis 2. **Database connection failed**: Check PostgreSQL is running and credentials are correct 3. **Redis connection failed**: Check Redis is running and password is correct 4. **Permission denied**: Ensure proper file ownership for binary install + +--- + +## TLS Fingerprint Configuration + +Sub2API supports TLS fingerprint simulation to make requests appear as if they come from the official Claude CLI (Node.js client). + +### Default Behavior + +- Built-in `claude_cli_v2` profile simulates Node.js 20.x + OpenSSL 3.x +- JA3 Hash: `1a28e69016765d92e3b381168d68922c` +- JA4: `t13d5911h1_a33745022dd6_1f22a2ca17c4` +- Profile selection: `accountID % profileCount` + +### Configuration + +```yaml +gateway: + tls_fingerprint: + enabled: true # Global switch + profiles: + # Simple profile (uses default cipher suites) + profile_1: + name: "Profile 1" + + # Profile with custom cipher suites (use compact array format) + profile_2: + name: "Profile 2" + cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196] + curves: [29, 23, 24] + point_formats: [0] + + # Another custom profile + profile_3: + name: "Profile 3" + cipher_suites: [4865, 4866, 4867, 49199, 49200] + curves: [29, 23, 24, 25] +``` + +### Profile Fields + +| Field | Type | Description | +|-------|------|-------------| +| `name` | string | Display name (required) | +| `cipher_suites` | []uint16 | Cipher suites in decimal. Empty = default | +| `curves` | []uint16 | Elliptic curves in decimal. Empty = default | +| `point_formats` | []uint8 | EC point formats. Empty = default | + +### Common Values Reference + +**Cipher Suites (TLS 1.3):** `4865` (AES_128_GCM), `4866` (AES_256_GCM), `4867` (CHACHA20) + +**Cipher Suites (TLS 1.2):** `49195`, `49196`, `49199`, `49200` (ECDHE variants) + +**Curves:** `29` (X25519), `23` (P-256), `24` (P-384), `25` (P-521) diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 1f4aa266..558b8ef0 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -210,6 +210,19 @@ gateway: outbox_backlog_rebuild_rows: 10000 # 全量重建周期(秒),0 表示禁用 full_rebuild_interval_seconds: 300 + # TLS fingerprint simulation / TLS 指纹伪装 + # Default profile "claude_cli_v2" simulates Node.js 20.x + # 默认模板 "claude_cli_v2" 模拟 Node.js 20.x 指纹 + tls_fingerprint: + enabled: true + # profiles: + # profile_1: + # name: "Custom Profile 1" + # profile_2: + # name: "Custom Profile 2" + # cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196] + # curves: [29, 23, 24] + # point_formats: [0] # ============================================================================= # API Key Auth Cache Configuration diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index c81de00e..7906cd6b 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1191,6 +1191,190 @@ + +
+
+

{{ t('admin.accounts.quotaControl.title') }}

+

+ {{ t('admin.accounts.quotaControl.hint') }} +

+
+ + +
+
+
+ +

+ {{ t('admin.accounts.quotaControl.windowCost.hint') }} +

+
+ +
+ +
+
+ +
+ $ + +
+

{{ t('admin.accounts.quotaControl.windowCost.limitHint') }}

+
+
+ +
+ $ + +
+

{{ t('admin.accounts.quotaControl.windowCost.stickyReserveHint') }}

+
+
+
+ + +
+
+
+ +

+ {{ t('admin.accounts.quotaControl.sessionLimit.hint') }} +

+
+ +
+ +
+
+ + +

{{ t('admin.accounts.quotaControl.sessionLimit.maxSessionsHint') }}

+
+
+ +
+ + {{ t('common.minutes') }} +
+

{{ t('admin.accounts.quotaControl.sessionLimit.idleTimeoutHint') }}

+
+
+
+ + +
+
+
+ +

+ {{ t('admin.accounts.quotaControl.tlsFingerprint.hint') }} +

+
+ +
+
+ + +
+
+
+ +

+ {{ t('admin.accounts.quotaControl.sessionIdMasking.hint') }} +

+
+ +
+
+
+
@@ -1214,7 +1398,7 @@
- +

{{ t('admin.accounts.billingRateMultiplierHint') }}

@@ -1763,6 +1947,16 @@ const geminiAIStudioOAuthEnabled = ref(false) const showAdvancedOAuth = ref(false) const showGeminiHelpDialog = ref(false) +// Quota control state (Anthropic OAuth/SetupToken only) +const windowCostEnabled = ref(false) +const windowCostLimit = ref(null) +const windowCostStickyReserve = ref(null) +const sessionLimitEnabled = ref(false) +const maxSessions = ref(null) +const sessionIdleTimeout = ref(null) +const tlsFingerprintEnabled = ref(false) +const sessionIdMaskingEnabled = ref(false) + // Gemini tier selection (used as fallback when auto-detection is unavailable/fails) const geminiTierGoogleOne = ref<'google_one_free' | 'google_ai_pro' | 'google_ai_ultra'>('google_one_free') const geminiTierGcp = ref<'gcp_standard' | 'gcp_enterprise'>('gcp_standard') @@ -2140,6 +2334,15 @@ const resetForm = () => { customErrorCodeInput.value = null interceptWarmupRequests.value = false autoPauseOnExpired.value = true + // Reset quota control state + windowCostEnabled.value = false + windowCostLimit.value = null + windowCostStickyReserve.value = null + sessionLimitEnabled.value = false + maxSessions.value = null + sessionIdleTimeout.value = null + tlsFingerprintEnabled.value = false + sessionIdMaskingEnabled.value = false tempUnschedEnabled.value = false tempUnschedRules.value = [] geminiOAuthType.value = 'code_assist' @@ -2407,7 +2610,32 @@ const handleAnthropicExchange = async (authCode: string) => { ...proxyConfig }) - const extra = oauth.buildExtraInfo(tokenInfo) + // Build extra with quota control settings + const baseExtra = oauth.buildExtraInfo(tokenInfo) || {} + const extra: Record = { ...baseExtra } + + // Add window cost limit settings + if (windowCostEnabled.value && windowCostLimit.value != null && windowCostLimit.value > 0) { + extra.window_cost_limit = windowCostLimit.value + extra.window_cost_sticky_reserve = windowCostStickyReserve.value ?? 10 + } + + // Add session limit settings + if (sessionLimitEnabled.value && maxSessions.value != null && maxSessions.value > 0) { + extra.max_sessions = maxSessions.value + extra.session_idle_timeout_minutes = sessionIdleTimeout.value ?? 5 + } + + // Add TLS fingerprint settings + if (tlsFingerprintEnabled.value) { + extra.enable_tls_fingerprint = true + } + + // Add session ID masking settings + if (sessionIdMaskingEnabled.value) { + extra.session_id_masking_enabled = true + } + const credentials = { ...tokenInfo, ...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {}) @@ -2475,7 +2703,32 @@ const handleCookieAuth = async (sessionKey: string) => { ...proxyConfig }) - const extra = oauth.buildExtraInfo(tokenInfo) + // Build extra with quota control settings + const baseExtra = oauth.buildExtraInfo(tokenInfo) || {} + const extra: Record = { ...baseExtra } + + // Add window cost limit settings + if (windowCostEnabled.value && windowCostLimit.value != null && windowCostLimit.value > 0) { + extra.window_cost_limit = windowCostLimit.value + extra.window_cost_sticky_reserve = windowCostStickyReserve.value ?? 10 + } + + // Add session limit settings + if (sessionLimitEnabled.value && maxSessions.value != null && maxSessions.value > 0) { + extra.max_sessions = maxSessions.value + extra.session_idle_timeout_minutes = sessionIdleTimeout.value ?? 5 + } + + // Add TLS fingerprint settings + if (tlsFingerprintEnabled.value) { + extra.enable_tls_fingerprint = true + } + + // Add session ID masking settings + if (sessionIdMaskingEnabled.value) { + extra.session_id_masking_enabled = true + } + const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name // Merge interceptWarmupRequests into credentials diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index d27364f1..81d10932 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -566,7 +566,7 @@
- +

{{ t('admin.accounts.billingRateMultiplierHint') }}

@@ -732,6 +732,60 @@ + + +
+
+
+ +

+ {{ t('admin.accounts.quotaControl.tlsFingerprint.hint') }} +

+
+ +
+
+ + +
+
+
+ +

+ {{ t('admin.accounts.quotaControl.sessionIdMasking.hint') }} +

+
+ +
+
@@ -904,6 +958,8 @@ const windowCostStickyReserve = ref(null) const sessionLimitEnabled = ref(false) const maxSessions = ref(null) const sessionIdleTimeout = ref(null) +const tlsFingerprintEnabled = ref(false) +const sessionIdMaskingEnabled = ref(false) // Computed: current preset mappings based on platform const presetMappings = computed(() => getPresetMappingsByPlatform(props.account?.platform || 'anthropic')) @@ -1237,6 +1293,8 @@ function loadQuotaControlSettings(account: Account) { sessionLimitEnabled.value = false maxSessions.value = null sessionIdleTimeout.value = null + tlsFingerprintEnabled.value = false + sessionIdMaskingEnabled.value = false // Only applies to Anthropic OAuth/SetupToken accounts if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) { @@ -1255,6 +1313,16 @@ function loadQuotaControlSettings(account: Account) { maxSessions.value = account.max_sessions sessionIdleTimeout.value = account.session_idle_timeout_minutes ?? 5 } + + // Load TLS fingerprint setting + if (account.enable_tls_fingerprint === true) { + tlsFingerprintEnabled.value = true + } + + // Load session ID masking setting + if (account.session_id_masking_enabled === true) { + sessionIdMaskingEnabled.value = true + } } function formatTempUnschedKeywords(value: unknown) { @@ -1407,6 +1475,20 @@ const handleSubmit = async () => { delete newExtra.session_idle_timeout_minutes } + // TLS fingerprint setting + if (tlsFingerprintEnabled.value) { + newExtra.enable_tls_fingerprint = true + } else { + delete newExtra.enable_tls_fingerprint + } + + // Session ID masking setting + if (sessionIdMaskingEnabled.value) { + newExtra.session_id_masking_enabled = true + } else { + delete newExtra.session_id_masking_enabled + } + updatePayload.extra = newExtra } diff --git a/frontend/src/components/layout/AppHeader.vue b/frontend/src/components/layout/AppHeader.vue index fd8742c3..9d2b40fb 100644 --- a/frontend/src/components/layout/AppHeader.vue +++ b/frontend/src/components/layout/AppHeader.vue @@ -21,8 +21,20 @@
- +
+ + + + + + @@ -211,6 +223,7 @@ const user = computed(() => authStore.user) const dropdownOpen = ref(false) const dropdownRef = ref(null) const contactInfo = computed(() => appStore.contactInfo) +const docUrl = computed(() => appStore.docUrl) // 只在标准模式的管理员下显示新手引导按钮 const showOnboardingButton = computed(() => { diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index a237ad53..362a1349 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -196,7 +196,8 @@ export default { expand: 'Expand', logout: 'Logout', github: 'GitHub', - mySubscriptions: 'My Subscriptions' + mySubscriptions: 'My Subscriptions', + docs: 'Docs' }, // Auth @@ -1288,6 +1289,14 @@ export default { idleTimeout: 'Idle Timeout', idleTimeoutPlaceholder: '5', idleTimeoutHint: 'Sessions will be released after idle timeout' + }, + tlsFingerprint: { + label: 'TLS Fingerprint Simulation', + hint: 'Simulate Node.js/Claude Code client TLS fingerprint' + }, + sessionIdMasking: { + label: 'Session ID Masking', + hint: 'When enabled, fixes the session ID in metadata.user_id for 15 minutes, making upstream think requests come from the same session' } }, expired: 'Expired', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index bfe36c1f..1efd3867 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -193,7 +193,8 @@ export default { expand: '展开', logout: '退出登录', github: 'GitHub', - mySubscriptions: '我的订阅' + mySubscriptions: '我的订阅', + docs: '文档' }, // Auth @@ -1420,6 +1421,14 @@ export default { idleTimeout: '空闲超时', idleTimeoutPlaceholder: '5', idleTimeoutHint: '会话空闲超时后自动释放' + }, + tlsFingerprint: { + label: 'TLS 指纹模拟', + hint: '模拟 Node.js/Claude Code 客户端的 TLS 指纹' + }, + sessionIdMasking: { + label: '会话 ID 伪装', + hint: '启用后将在 15 分钟内固定 metadata.user_id 中的 session ID,使上游认为请求来自同一会话' } }, expired: '已过期', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 353ccb83..35e256e6 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -480,6 +480,13 @@ export interface Account { max_sessions?: number | null session_idle_timeout_minutes?: number | null + // TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效) + enable_tls_fingerprint?: boolean | null + + // 会话ID伪装(仅 Anthropic OAuth/SetupToken 账号有效) + // 启用后将在15分钟内固定 metadata.user_id 中的 session ID + session_id_masking_enabled?: boolean | null + // 运行时状态(仅当启用对应限制时返回) current_window_cost?: number | null // 当前窗口费用 active_sessions?: number | null // 当前活跃会话数 diff --git a/frontend/src/views/admin/GroupsView.vue b/frontend/src/views/admin/GroupsView.vue index 96457172..47a15084 100644 --- a/frontend/src/views/admin/GroupsView.vue +++ b/frontend/src/views/admin/GroupsView.vue @@ -243,7 +243,7 @@ />

{{ t('admin.groups.platformHint') }}

-
+

{{ t('admin.groups.platformNotEditable') }}

-
+
{ } } -// 监听 subscription_type 变化,订阅模式时重置 rate_multiplier 为 1,is_exclusive 为 true +// 监听 subscription_type 变化,订阅模式时 is_exclusive 默认为 true watch( () => createForm.subscription_type, (newVal) => { if (newVal === 'subscription') { - createForm.rate_multiplier = 1.0 createForm.is_exclusive = true } } diff --git a/frontend/tsconfig.json b/frontend/tsconfig.json index a1731cfb..82ae3f9f 100644 --- a/frontend/tsconfig.json +++ b/frontend/tsconfig.json @@ -21,5 +21,6 @@ "types": ["vite/client"] }, "include": ["src/**/*.ts", "src/**/*.tsx", "src/**/*.vue"], + "exclude": ["src/**/__tests__/**", "src/**/*.spec.ts", "src/**/*.test.ts"], "references": [{ "path": "./tsconfig.node.json" }] }