diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go
index c9dc57bb..f8a7d313 100644
--- a/backend/cmd/server/main.go
+++ b/backend/cmd/server/main.go
@@ -8,6 +8,7 @@ import (
"errors"
"flag"
"log"
+ "log/slog"
"net/http"
"os"
"os/signal"
@@ -44,7 +45,25 @@ func init() {
}
}
+// initLogger configures the default slog handler based on gin.Mode().
+// In non-release mode, Debug level logs are enabled.
+func initLogger() {
+ var level slog.Level
+ if gin.Mode() == gin.ReleaseMode {
+ level = slog.LevelInfo
+ } else {
+ level = slog.LevelDebug
+ }
+ handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
+ Level: level,
+ })
+ slog.SetDefault(slog.New(handler))
+}
+
func main() {
+ // Initialize slog logger based on gin mode
+ initLogger()
+
// Parse command line flags
setupMode := flag.Bool("setup", false, "Run setup wizard in CLI mode")
showVersion := flag.Bool("version", false, "Show version information")
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index 42084b37..00a78480 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -258,8 +258,43 @@ type GatewayConfig struct {
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
FailoverOn400 bool `mapstructure:"failover_on_400"`
+ // 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
+ MaxAccountSwitches int `mapstructure:"max_account_switches"`
+ // Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格)
+ MaxAccountSwitchesGemini int `mapstructure:"max_account_switches_gemini"`
+
+ // Antigravity 429 fallback 限流时间(分钟),解析重置时间失败时使用
+ AntigravityFallbackCooldownMinutes int `mapstructure:"antigravity_fallback_cooldown_minutes"`
+
// Scheduling: 账号调度相关配置
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
+
+ // TLSFingerprint: TLS指纹伪装配置
+ TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"`
+}
+
+// TLSFingerprintConfig TLS指纹伪装配置
+// 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端
+type TLSFingerprintConfig struct {
+ // Enabled: 是否全局启用TLS指纹功能
+ Enabled bool `mapstructure:"enabled"`
+ // Profiles: 预定义的TLS指纹配置模板
+ // key 为模板名称,如 "claude_cli_v2", "chrome_120" 等
+ Profiles map[string]TLSProfileConfig `mapstructure:"profiles"`
+}
+
+// TLSProfileConfig 单个TLS指纹模板的配置
+type TLSProfileConfig struct {
+ // Name: 模板显示名称
+ Name string `mapstructure:"name"`
+ // EnableGREASE: 是否启用GREASE扩展(Chrome使用,Node.js不使用)
+ EnableGREASE bool `mapstructure:"enable_grease"`
+ // CipherSuites: TLS加密套件列表(空则使用内置默认值)
+ CipherSuites []uint16 `mapstructure:"cipher_suites"`
+ // Curves: 椭圆曲线列表(空则使用内置默认值)
+ Curves []uint16 `mapstructure:"curves"`
+ // PointFormats: 点格式列表(空则使用内置默认值)
+ PointFormats []uint8 `mapstructure:"point_formats"`
}
// GatewaySchedulingConfig accounts scheduling configuration.
@@ -272,6 +307,9 @@ type GatewaySchedulingConfig struct {
FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"`
+ // 兜底层账户选择策略: "last_used"(按最后使用时间排序,默认) 或 "random"(随机)
+ FallbackSelectionMode string `mapstructure:"fallback_selection_mode"`
+
// 负载计算
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
@@ -781,6 +819,9 @@ func setDefaults() {
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false)
viper.SetDefault("gateway.failover_on_400", false)
+ viper.SetDefault("gateway.max_account_switches", 10)
+ viper.SetDefault("gateway.max_account_switches_gemini", 3)
+ viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
@@ -793,11 +834,12 @@ func setDefaults() {
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10)
- viper.SetDefault("gateway.max_line_size", 10*1024*1024)
+ viper.SetDefault("gateway.max_line_size", 40*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
+ viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
viper.SetDefault("gateway.scheduling.db_fallback_enabled", true)
@@ -809,6 +851,8 @@ func setDefaults() {
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
+ // TLS指纹伪装配置(默认关闭,需要账号级别单独启用)
+ viper.SetDefault("gateway.tls_fingerprint.enabled", true)
viper.SetDefault("concurrency.ping_interval", 10)
// TokenRefresh
diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go
index 33c91dae..10a53f56 100644
--- a/backend/internal/handler/admin/account_handler.go
+++ b/backend/internal/handler/admin/account_handler.go
@@ -173,6 +173,7 @@ func (h *AccountHandler) List(c *gin.Context) {
// 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs := make([]int64, 0)
sessionLimitAccountIDs := make([]int64, 0)
+ sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
for i := range accounts {
acc := &accounts[i]
if acc.IsAnthropicOAuthOrSetupToken() {
@@ -181,6 +182,7 @@ func (h *AccountHandler) List(c *gin.Context) {
}
if acc.GetMaxSessions() > 0 {
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
+ sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
}
}
}
@@ -189,9 +191,9 @@ func (h *AccountHandler) List(c *gin.Context) {
var windowCosts map[int64]float64
var activeSessions map[int64]int
- // 获取活跃会话数(批量查询)
+ // 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
- activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs)
+ activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
if activeSessions == nil {
activeSessions = make(map[int64]int)
}
@@ -211,12 +213,8 @@ func (h *AccountHandler) List(c *gin.Context) {
}
accCopy := acc // 闭包捕获
g.Go(func() error {
- var startTime time.Time
- if accCopy.SessionWindowStart != nil {
- startTime = *accCopy.SessionWindowStart
- } else {
- startTime = time.Now().Add(-5 * time.Hour)
- }
+ // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
+ startTime := accCopy.GetCurrentWindowStartTime()
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
if err == nil && stats != nil {
mu.Lock()
@@ -545,6 +543,36 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
newCredentials[k] = v
}
}
+
+ // 如果 project_id 获取失败,先更新凭证,再标记账户为 error
+ if tokenInfo.ProjectIDMissing {
+ // 先更新凭证
+ _, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
+ Credentials: newCredentials,
+ })
+ if updateErr != nil {
+ response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
+ return
+ }
+ // 标记账户为 error
+ if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id,可能无法使用Antigravity"); setErr != nil {
+ response.InternalError(c, "Failed to set account error: "+setErr.Error())
+ return
+ }
+ response.Success(c, gin.H{
+ "message": "Token refreshed but project_id is missing, account marked as error",
+ "warning": "missing_project_id",
+ })
+ return
+ }
+
+ // 成功获取到 project_id,如果之前是 missing_project_id 错误则清除
+ if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
+ if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
+ response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
+ return
+ }
+ }
} else {
// Use Anthropic/Claude OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
index 457d52fc..b820a3fb 100644
--- a/backend/internal/handler/admin/admin_service_stub_test.go
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -200,6 +200,10 @@ func (s *stubAdminService) ClearAccountError(ctx context.Context, id int64) (*se
return &account, nil
}
+func (s *stubAdminService) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
+ return nil
+}
+
func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive, Schedulable: schedulable}
return &account, nil
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index d8f10e6c..66b86ea0 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -161,6 +161,16 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 {
out.SessionIdleTimeoutMin = &idleTimeout
}
+ // TLS指纹伪装开关
+ if a.IsTLSFingerprintEnabled() {
+ enabled := true
+ out.EnableTLSFingerprint = &enabled
+ }
+ // 会话ID伪装开关
+ if a.IsSessionIDMaskingEnabled() {
+ enabled := true
+ out.EnableSessionIDMasking = &enabled
+ }
}
return out
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index ae9da254..4247dcbf 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -112,6 +112,15 @@ type Account struct {
MaxSessions *int `json:"max_sessions,omitempty"`
SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
+ // TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效)
+ // 从 extra 字段提取,方便前端显示和编辑
+ EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"`
+
+ // 会话ID伪装(仅 Anthropic OAuth/SetupToken 账号有效)
+ // 启用后将在15分钟内固定 metadata.user_id 中的 session ID
+ // 从 extra 字段提取,方便前端显示和编辑
+ EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
+
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index 8c32be21..6c8d9ebe 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -31,6 +31,8 @@ type GatewayHandler struct {
userService *service.UserService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
+ maxAccountSwitches int
+ maxAccountSwitchesGemini int
}
// NewGatewayHandler creates a new GatewayHandler
@@ -44,8 +46,16 @@ func NewGatewayHandler(
cfg *config.Config,
) *GatewayHandler {
pingInterval := time.Duration(0)
+ maxAccountSwitches := 10
+ maxAccountSwitchesGemini := 3
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
+ if cfg.Gateway.MaxAccountSwitches > 0 {
+ maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
+ }
+ if cfg.Gateway.MaxAccountSwitchesGemini > 0 {
+ maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini
+ }
}
return &GatewayHandler{
gatewayService: gatewayService,
@@ -54,6 +64,8 @@ func NewGatewayHandler(
userService: userService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
+ maxAccountSwitches: maxAccountSwitches,
+ maxAccountSwitchesGemini: maxAccountSwitchesGemini,
}
}
@@ -179,7 +191,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
if platform == service.PlatformGemini {
- const maxAccountSwitches = 3
+ maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
@@ -313,7 +325,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
- const maxAccountSwitches = 10
+ maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go
index ec943e61..c7646b38 100644
--- a/backend/internal/handler/gemini_v1beta_handler.go
+++ b/backend/internal/handler/gemini_v1beta_handler.go
@@ -220,7 +220,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
- const maxAccountSwitches = 3
+ maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index 68e67656..4c9dd8b9 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -25,6 +25,7 @@ type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
+ maxAccountSwitches int
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
@@ -35,13 +36,18 @@ func NewOpenAIGatewayHandler(
cfg *config.Config,
) *OpenAIGatewayHandler {
pingInterval := time.Duration(0)
+ maxAccountSwitches := 3
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
+ if cfg.Gateway.MaxAccountSwitches > 0 {
+ maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
+ }
}
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
+ maxAccountSwitches: maxAccountSwitches,
}
}
@@ -189,7 +195,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Generate session hash (header first; fallback to prompt_cache_key)
sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
- const maxAccountSwitches = 3
+ maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go
index 1248be95..a6279b11 100644
--- a/backend/internal/pkg/antigravity/client.go
+++ b/backend/internal/pkg/antigravity/client.go
@@ -16,15 +16,6 @@ import (
"time"
)
-// resolveHost 从 URL 解析 host
-func resolveHost(urlStr string) string {
- parsed, err := url.Parse(urlStr)
- if err != nil {
- return ""
- }
- return parsed.Host
-}
-
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
// 构建 URL,流式请求添加 ?alt=sse 参数
@@ -39,23 +30,11 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri
return nil, err
}
- // 基础 Headers
+ // 基础 Headers(与 Antigravity-Manager 保持一致,只设置这 3 个)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", UserAgent)
- // Accept Header 根据请求类型设置
- if isStream {
- req.Header.Set("Accept", "text/event-stream")
- } else {
- req.Header.Set("Accept", "application/json")
- }
-
- // 显式设置 Host Header
- if host := resolveHost(apiURL); host != "" {
- req.Host = host
- }
-
return req, nil
}
@@ -195,12 +174,15 @@ func isConnectionError(err error) bool {
}
// shouldFallbackToNextURL 判断是否应切换到下一个 URL
-// 仅连接错误和 HTTP 429 触发 URL 降级
+// 与 Antigravity-Manager 保持一致:连接错误、429、408、404、5xx 触发 URL 降级
func shouldFallbackToNextURL(err error, statusCode int) bool {
if isConnectionError(err) {
return true
}
- return statusCode == http.StatusTooManyRequests
+ return statusCode == http.StatusTooManyRequests ||
+ statusCode == http.StatusRequestTimeout ||
+ statusCode == http.StatusNotFound ||
+ statusCode >= 500
}
// ExchangeCode 用 authorization code 交换 token
@@ -321,11 +303,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
}
- // 获取可用的 URL 列表
- availableURLs := DefaultURLAvailability.GetAvailableURLs()
- if len(availableURLs) == 0 {
- availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
- }
+ // 固定顺序:prod -> daily
+ availableURLs := BaseURLs
var lastErr error
for urlIdx, baseURL := range availableURLs {
@@ -343,7 +322,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
if err != nil {
lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
- DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue
}
@@ -358,7 +336,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
// 检查是否需要 URL 降级
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
- DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue
}
@@ -376,6 +353,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
var rawResp map[string]any
_ = json.Unmarshal(respBodyBytes, &rawResp)
+ // 标记成功的 URL,下次优先使用
+ DefaultURLAvailability.MarkSuccess(baseURL)
return &loadResp, rawResp, nil
}
@@ -412,11 +391,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
}
- // 获取可用的 URL 列表
- availableURLs := DefaultURLAvailability.GetAvailableURLs()
- if len(availableURLs) == 0 {
- availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
- }
+ // 固定顺序:prod -> daily
+ availableURLs := BaseURLs
var lastErr error
for urlIdx, baseURL := range availableURLs {
@@ -434,7 +410,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
if err != nil {
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
- DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue
}
@@ -449,7 +424,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
// 检查是否需要 URL 降级
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
- DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue
}
@@ -467,6 +441,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
var rawResp map[string]any
_ = json.Unmarshal(respBodyBytes, &rawResp)
+ // 标记成功的 URL,下次优先使用
+ DefaultURLAvailability.MarkSuccess(baseURL)
return &modelsResp, rawResp, nil
}
diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go
index f688332f..c1cc998c 100644
--- a/backend/internal/pkg/antigravity/gemini_types.go
+++ b/backend/internal/pkg/antigravity/gemini_types.go
@@ -143,9 +143,10 @@ type GeminiResponse struct {
// GeminiCandidate Gemini 候选响应
type GeminiCandidate struct {
- Content *GeminiContent `json:"content,omitempty"`
- FinishReason string `json:"finishReason,omitempty"`
- Index int `json:"index,omitempty"`
+ Content *GeminiContent `json:"content,omitempty"`
+ FinishReason string `json:"finishReason,omitempty"`
+ Index int `json:"index,omitempty"`
+ GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
}
// GeminiUsageMetadata Gemini 用量元数据
@@ -156,6 +157,23 @@ type GeminiUsageMetadata struct {
TotalTokenCount int `json:"totalTokenCount,omitempty"`
}
+// GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
+type GeminiGroundingMetadata struct {
+ WebSearchQueries []string `json:"webSearchQueries,omitempty"`
+ GroundingChunks []GeminiGroundingChunk `json:"groundingChunks,omitempty"`
+}
+
+// GeminiGroundingChunk Gemini grounding chunk
+type GeminiGroundingChunk struct {
+ Web *GeminiGroundingWeb `json:"web,omitempty"`
+}
+
+// GeminiGroundingWeb Gemini grounding web 信息
+type GeminiGroundingWeb struct {
+ Title string `json:"title,omitempty"`
+ URI string `json:"uri,omitempty"`
+}
+
// DefaultSafetySettings 默认安全设置(关闭所有过滤)
var DefaultSafetySettings = []GeminiSafetySetting{
{Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"},
diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go
index 736c45df..ee2a6c1a 100644
--- a/backend/internal/pkg/antigravity/oauth.go
+++ b/backend/internal/pkg/antigravity/oauth.go
@@ -32,8 +32,8 @@ const (
"https://www.googleapis.com/auth/cclog " +
"https://www.googleapis.com/auth/experimentsandconfigs"
- // User-Agent(模拟官方客户端)
- UserAgent = "antigravity/1.104.0 darwin/arm64"
+ // User-Agent(与 Antigravity-Manager 保持一致)
+ UserAgent = "antigravity/1.11.9 windows/amd64"
// Session 过期时间
SessionTTL = 30 * time.Minute
@@ -42,22 +42,21 @@ const (
URLAvailabilityTTL = 5 * time.Minute
)
-// BaseURLs 定义 Antigravity API 端点,按优先级排序
-// fallback 顺序: sandbox → daily → prod
+// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
var BaseURLs = []string{
- "https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox
- "https://daily-cloudcode-pa.googleapis.com", // daily
- "https://cloudcode-pa.googleapis.com", // prod
+ "https://cloudcode-pa.googleapis.com", // prod (优先)
+ "https://daily-cloudcode-pa.sandbox.googleapis.com", // daily sandbox (备用)
}
// BaseURL 默认 URL(保持向后兼容)
var BaseURL = BaseURLs[0]
-// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复)
+// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级)
type URLAvailability struct {
mu sync.RWMutex
unavailable map[string]time.Time // URL -> 恢复时间
ttl time.Duration
+ lastSuccess string // 最近成功请求的 URL,优先使用
}
// DefaultURLAvailability 全局 URL 可用性管理器
@@ -78,6 +77,15 @@ func (u *URLAvailability) MarkUnavailable(url string) {
u.unavailable[url] = time.Now().Add(u.ttl)
}
+// MarkSuccess 标记 URL 请求成功,将其设为优先使用
+func (u *URLAvailability) MarkSuccess(url string) {
+ u.mu.Lock()
+ defer u.mu.Unlock()
+ u.lastSuccess = url
+ // 成功后清除该 URL 的不可用标记
+ delete(u.unavailable, url)
+}
+
// IsAvailable 检查 URL 是否可用
func (u *URLAvailability) IsAvailable(url string) bool {
u.mu.RLock()
@@ -89,14 +97,29 @@ func (u *URLAvailability) IsAvailable(url string) bool {
return time.Now().After(expiry)
}
-// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序)
+// GetAvailableURLs 返回可用的 URL 列表
+// 最近成功的 URL 优先,其他按默认顺序
func (u *URLAvailability) GetAvailableURLs() []string {
u.mu.RLock()
defer u.mu.RUnlock()
now := time.Now()
result := make([]string, 0, len(BaseURLs))
+
+ // 如果有最近成功的 URL 且可用,放在最前面
+ if u.lastSuccess != "" {
+ expiry, exists := u.unavailable[u.lastSuccess]
+ if !exists || now.After(expiry) {
+ result = append(result, u.lastSuccess)
+ }
+ }
+
+ // 添加其他可用的 URL(按默认顺序)
for _, url := range BaseURLs {
+ // 跳过已添加的 lastSuccess
+ if url == u.lastSuccess {
+ continue
+ }
expiry, exists := u.unavailable[url]
if !exists || now.After(expiry) {
result = append(result, url)
@@ -240,24 +263,3 @@ func BuildAuthorizationURL(state, codeChallenge string) string {
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
-
-// GenerateMockProjectID 生成随机 project_id(当 API 不返回时使用)
-// 格式:{形容词}-{名词}-{5位随机字符}
-func GenerateMockProjectID() string {
- adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
- nouns := []string{"fuze", "wave", "spark", "flow", "core"}
-
- randBytes, _ := GenerateRandomBytes(7)
-
- adj := adjectives[int(randBytes[0])%len(adjectives)]
- noun := nouns[int(randBytes[1])%len(nouns)]
-
- // 生成 5 位随机字符(a-z0-9)
- const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
- suffix := make([]byte, 5)
- for i := 0; i < 5; i++ {
- suffix[i] = charset[int(randBytes[i+2])%len(charset)]
- }
-
- return fmt.Sprintf("%s-%s-%s", adj, noun, string(suffix))
-}
diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go
index a8474576..637a4ea8 100644
--- a/backend/internal/pkg/antigravity/request_transformer.go
+++ b/backend/internal/pkg/antigravity/request_transformer.go
@@ -54,6 +54,9 @@ func DefaultTransformOptions() TransformOptions {
}
}
+// webSearchFallbackModel web_search 请求使用的降级模型
+const webSearchFallbackModel = "gemini-2.5-flash"
+
// TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式
func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) {
return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions())
@@ -64,12 +67,23 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
// 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string)
+ // 检测是否有 web_search 工具
+ hasWebSearchTool := hasWebSearchTool(claudeReq.Tools)
+ requestType := "agent"
+ targetModel := mappedModel
+ if hasWebSearchTool {
+ requestType = "web_search"
+ if targetModel != webSearchFallbackModel {
+ targetModel = webSearchFallbackModel
+ }
+ }
+
// 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
- allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
+ allowDummyThought := strings.HasPrefix(targetModel, "gemini-")
// 1. 构建 contents
contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
@@ -78,7 +92,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
}
// 2. 构建 systemInstruction
- systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts)
+ systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts, claudeReq.Tools)
// 3. 构建 generationConfig
reqForConfig := claudeReq
@@ -89,6 +103,11 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
reqCopy.Thinking = nil
reqForConfig = &reqCopy
}
+ if targetModel != "" && targetModel != reqForConfig.Model {
+ reqCopy := *reqForConfig
+ reqCopy.Model = targetModel
+ reqForConfig = &reqCopy
+ }
generationConfig := buildGenerationConfig(reqForConfig)
// 4. 构建 tools
@@ -127,8 +146,8 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map
Project: projectID,
RequestID: "agent-" + uuid.New().String(),
UserAgent: "antigravity", // 固定值,与官方客户端一致
- RequestType: "agent",
- Model: mappedModel,
+ RequestType: requestType,
+ Model: targetModel,
Request: innerRequest,
}
@@ -154,8 +173,40 @@ func GetDefaultIdentityPatch() string {
return antigravityIdentity
}
-// buildSystemInstruction 构建 systemInstruction
-func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions) *GeminiContent {
+// mcpXMLProtocol MCP XML 工具调用协议(与 Antigravity-Manager 保持一致)
+const mcpXMLProtocol = `
+==== MCP XML 工具调用协议 (Workaround) ====
+当你需要调用名称以 ` + "`mcp__`" + ` 开头的 MCP 工具时:
+1) 优先尝试 XML 格式调用:输出 ` + "`{\"arg\":\"value\"}`" + `。
+2) 必须直接输出 XML 块,无需 markdown 包装,内容为 JSON 格式的入参。
+3) 这种方式具有更高的连通性和容错性,适用于大型结果返回场景。
+===========================================`
+
+// hasMCPTools 检测是否有 mcp__ 前缀的工具
+func hasMCPTools(tools []ClaudeTool) bool {
+ for _, tool := range tools {
+ if strings.HasPrefix(tool.Name, "mcp__") {
+ return true
+ }
+ }
+ return false
+}
+
+// filterOpenCodePrompt 过滤 OpenCode 默认提示词,只保留用户自定义指令
+func filterOpenCodePrompt(text string) string {
+ if !strings.Contains(text, "You are an interactive CLI tool") {
+ return text
+ }
+ // 提取 "Instructions from:" 及之后的部分
+ if idx := strings.Index(text, "Instructions from:"); idx >= 0 {
+ return text[idx:]
+ }
+ // 如果没有自定义指令,返回空
+ return ""
+}
+
+// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致)
+func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent {
var parts []GeminiPart
// 先解析用户的 system prompt,检测是否已包含 Antigravity identity
@@ -167,10 +218,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
var sysStr string
if err := json.Unmarshal(system, &sysStr); err == nil {
if strings.TrimSpace(sysStr) != "" {
- userSystemParts = append(userSystemParts, GeminiPart{Text: sysStr})
if strings.Contains(sysStr, "You are Antigravity") {
userHasAntigravityIdentity = true
}
+ // 过滤 OpenCode 默认提示词
+ filtered := filterOpenCodePrompt(sysStr)
+ if filtered != "" {
+ userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
+ }
}
} else {
// 尝试解析为数组
@@ -178,10 +233,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
if err := json.Unmarshal(system, &sysBlocks); err == nil {
for _, block := range sysBlocks {
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
- userSystemParts = append(userSystemParts, GeminiPart{Text: block.Text})
if strings.Contains(block.Text, "You are Antigravity") {
userHasAntigravityIdentity = true
}
+ // 过滤 OpenCode 默认提示词
+ filtered := filterOpenCodePrompt(block.Text)
+ if filtered != "" {
+ userSystemParts = append(userSystemParts, GeminiPart{Text: filtered})
+ }
}
}
}
@@ -200,6 +259,16 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans
// 添加用户的 system prompt
parts = append(parts, userSystemParts...)
+ // 检测是否有 MCP 工具,如有则注入 XML 调用协议
+ if hasMCPTools(tools) {
+ parts = append(parts, GeminiPart{Text: mcpXMLProtocol})
+ }
+
+ // 如果用户没有提供 Antigravity 身份,添加结束标记
+ if !userHasAntigravityIdentity {
+ parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"})
+ }
+
if len(parts) == 0 {
return nil
}
@@ -429,6 +498,11 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
StopSequences: DefaultStopSequences,
}
+ // 如果请求中指定了 MaxTokens,使用请求值
+ if req.MaxTokens > 0 {
+ config.MaxOutputTokens = req.MaxTokens
+ }
+
// Thinking 配置
if req.Thinking != nil && req.Thinking.Type == "enabled" {
config.ThinkingConfig = &GeminiThinkingConfig{
@@ -458,37 +532,43 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
return config
}
+func hasWebSearchTool(tools []ClaudeTool) bool {
+ for _, tool := range tools {
+ if isWebSearchTool(tool) {
+ return true
+ }
+ }
+ return false
+}
+
+func isWebSearchTool(tool ClaudeTool) bool {
+ if strings.HasPrefix(tool.Type, "web_search") || tool.Type == "google_search" {
+ return true
+ }
+
+ name := strings.TrimSpace(tool.Name)
+ switch name {
+ case "web_search", "google_search", "web_search_20250305":
+ return true
+ default:
+ return false
+ }
+}
+
// buildTools 构建 tools
func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
if len(tools) == 0 {
return nil
}
- // 检查是否有 web_search 工具
- hasWebSearch := false
- for _, tool := range tools {
- if tool.Name == "web_search" {
- hasWebSearch = true
- break
- }
- }
-
- if hasWebSearch {
- // Web Search 工具映射
- return []GeminiToolDeclaration{{
- GoogleSearch: &GeminiGoogleSearch{
- EnhancedContent: &GeminiEnhancedContent{
- ImageSearch: &GeminiImageSearch{
- MaxResultCount: 5,
- },
- },
- },
- }}
- }
+ hasWebSearch := hasWebSearchTool(tools)
// 普通工具
var funcDecls []GeminiFunctionDecl
for _, tool := range tools {
+ if isWebSearchTool(tool) {
+ continue
+ }
// 跳过无效工具名称
if strings.TrimSpace(tool.Name) == "" {
log.Printf("Warning: skipping tool with empty name")
@@ -531,7 +611,20 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
}
if len(funcDecls) == 0 {
- return nil
+ if !hasWebSearch {
+ return nil
+ }
+
+ // Web Search 工具映射
+ return []GeminiToolDeclaration{{
+ GoogleSearch: &GeminiGoogleSearch{
+ EnhancedContent: &GeminiEnhancedContent{
+ ImageSearch: &GeminiImageSearch{
+ MaxResultCount: 5,
+ },
+ },
+ },
+ }}
}
return []GeminiToolDeclaration{{
diff --git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go
index cd7f5f80..04424c03 100644
--- a/backend/internal/pkg/antigravity/response_transformer.go
+++ b/backend/internal/pkg/antigravity/response_transformer.go
@@ -3,6 +3,7 @@ package antigravity
import (
"encoding/json"
"fmt"
+ "strings"
)
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
@@ -63,6 +64,12 @@ func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID,
p.processPart(&part)
}
+ if len(geminiResp.Candidates) > 0 {
+ if grounding := geminiResp.Candidates[0].GroundingMetadata; grounding != nil {
+ p.processGrounding(grounding)
+ }
+ }
+
// 刷新剩余内容
p.flushThinking()
p.flushText()
@@ -190,6 +197,18 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
}
}
+func (p *NonStreamingProcessor) processGrounding(grounding *GeminiGroundingMetadata) {
+ groundingText := buildGroundingText(grounding)
+ if groundingText == "" {
+ return
+ }
+
+ p.flushThinking()
+ p.flushText()
+ p.textBuilder += groundingText
+ p.flushText()
+}
+
// flushText 刷新 text builder
func (p *NonStreamingProcessor) flushText() {
if p.textBuilder == "" {
@@ -262,6 +281,44 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
}
}
+func buildGroundingText(grounding *GeminiGroundingMetadata) string {
+ if grounding == nil {
+ return ""
+ }
+
+ var builder strings.Builder
+
+ if len(grounding.WebSearchQueries) > 0 {
+ _, _ = builder.WriteString("\n\n---\nWeb search queries: ")
+ _, _ = builder.WriteString(strings.Join(grounding.WebSearchQueries, ", "))
+ }
+
+ if len(grounding.GroundingChunks) > 0 {
+ var links []string
+ for i, chunk := range grounding.GroundingChunks {
+ if chunk.Web == nil {
+ continue
+ }
+ title := strings.TrimSpace(chunk.Web.Title)
+ if title == "" {
+ title = "Source"
+ }
+ uri := strings.TrimSpace(chunk.Web.URI)
+ if uri == "" {
+ uri = "#"
+ }
+ links = append(links, fmt.Sprintf("[%d] [%s](%s)", i+1, title, uri))
+ }
+
+ if len(links) > 0 {
+ _, _ = builder.WriteString("\n\nSources:\n")
+ _, _ = builder.WriteString(strings.Join(links, "\n"))
+ }
+ }
+
+ return builder.String()
+}
+
// generateRandomID 生成随机 ID
func generateRandomID() string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go
index 9fe68a11..da0c6f97 100644
--- a/backend/internal/pkg/antigravity/stream_transformer.go
+++ b/backend/internal/pkg/antigravity/stream_transformer.go
@@ -27,6 +27,8 @@ type StreamingProcessor struct {
pendingSignature string
trailingSignature string
originalModel string
+ webSearchQueries []string
+ groundingChunks []GeminiGroundingChunk
// 累计 usage
inputTokens int
@@ -93,6 +95,10 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
}
}
+ if len(geminiResp.Candidates) > 0 {
+ p.captureGrounding(geminiResp.Candidates[0].GroundingMetadata)
+ }
+
// 检查是否结束
if len(geminiResp.Candidates) > 0 {
finishReason := geminiResp.Candidates[0].FinishReason
@@ -200,6 +206,20 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
return result.Bytes()
}
+func (p *StreamingProcessor) captureGrounding(grounding *GeminiGroundingMetadata) {
+ if grounding == nil {
+ return
+ }
+
+ if len(grounding.WebSearchQueries) > 0 && len(p.webSearchQueries) == 0 {
+ p.webSearchQueries = append([]string(nil), grounding.WebSearchQueries...)
+ }
+
+ if len(grounding.GroundingChunks) > 0 && len(p.groundingChunks) == 0 {
+ p.groundingChunks = append([]GeminiGroundingChunk(nil), grounding.GroundingChunks...)
+ }
+}
+
// processThinking 处理 thinking
func (p *StreamingProcessor) processThinking(text, signature string) []byte {
var result bytes.Buffer
@@ -417,6 +437,23 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
p.trailingSignature = ""
}
+ if len(p.webSearchQueries) > 0 || len(p.groundingChunks) > 0 {
+ groundingText := buildGroundingText(&GeminiGroundingMetadata{
+ WebSearchQueries: p.webSearchQueries,
+ GroundingChunks: p.groundingChunks,
+ })
+ if groundingText != "" {
+ _, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
+ "type": "text",
+ "text": "",
+ }))
+ _, _ = result.Write(p.emitDelta("text_delta", map[string]any{
+ "text": groundingText,
+ }))
+ _, _ = result.Write(p.endBlock())
+ }
+ }
+
// 确定 stop_reason
stopReason := "end_turn"
if p.usedTool {
diff --git a/backend/internal/pkg/response/response.go b/backend/internal/pkg/response/response.go
index a92ff9e8..43fe12d4 100644
--- a/backend/internal/pkg/response/response.go
+++ b/backend/internal/pkg/response/response.go
@@ -162,11 +162,11 @@ func ParsePagination(c *gin.Context) (page, pageSize int) {
// 支持 page_size 和 limit 两种参数名
if ps := c.Query("page_size"); ps != "" {
- if val, err := parseInt(ps); err == nil && val > 0 && val <= 100 {
+ if val, err := parseInt(ps); err == nil && val > 0 && val <= 1000 {
pageSize = val
}
} else if l := c.Query("limit"); l != "" {
- if val, err := parseInt(l); err == nil && val > 0 && val <= 100 {
+ if val, err := parseInt(l); err == nil && val > 0 && val <= 1000 {
pageSize = val
}
}
diff --git a/backend/internal/pkg/tlsfingerprint/dialer.go b/backend/internal/pkg/tlsfingerprint/dialer.go
new file mode 100644
index 00000000..42510986
--- /dev/null
+++ b/backend/internal/pkg/tlsfingerprint/dialer.go
@@ -0,0 +1,568 @@
+// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
+// It uses the utls library to create TLS connections that mimic Node.js/Claude Code clients.
+package tlsfingerprint
+
+import (
+ "bufio"
+ "context"
+ "encoding/base64"
+ "fmt"
+ "log/slog"
+ "net"
+ "net/http"
+ "net/url"
+
+ utls "github.com/refraction-networking/utls"
+ "golang.org/x/net/proxy"
+)
+
+// Profile contains TLS fingerprint configuration.
+type Profile struct {
+ Name string // Profile name for identification
+ CipherSuites []uint16
+ Curves []uint16
+ PointFormats []uint8
+ EnableGREASE bool
+}
+
+// Dialer creates TLS connections with custom fingerprints.
+type Dialer struct {
+ profile *Profile
+ baseDialer func(ctx context.Context, network, addr string) (net.Conn, error)
+}
+
+// HTTPProxyDialer creates TLS connections through HTTP/HTTPS proxies with custom fingerprints.
+// It handles the CONNECT tunnel establishment before performing TLS handshake.
+type HTTPProxyDialer struct {
+ profile *Profile
+ proxyURL *url.URL
+}
+
+// SOCKS5ProxyDialer creates TLS connections through SOCKS5 proxies with custom fingerprints.
+// It uses golang.org/x/net/proxy to establish the SOCKS5 tunnel.
+type SOCKS5ProxyDialer struct {
+ profile *Profile
+ proxyURL *url.URL
+}
+
+// Default TLS fingerprint values captured from Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)
+// Captured using: tshark -i lo -f "tcp port 8443" -Y "tls.handshake.type == 1" -V
+// JA3 Hash: 1a28e69016765d92e3b381168d68922c
+//
+// Note: JA3/JA4 may have slight variations due to:
+// - Session ticket presence/absence
+// - Extension negotiation state
+var (
+ // defaultCipherSuites contains all 59 cipher suites from Claude CLI
+ // Order is critical for JA3 fingerprint matching
+ defaultCipherSuites = []uint16{
+ // TLS 1.3 cipher suites (MUST be first)
+ 0x1302, // TLS_AES_256_GCM_SHA384
+ 0x1303, // TLS_CHACHA20_POLY1305_SHA256
+ 0x1301, // TLS_AES_128_GCM_SHA256
+
+ // ECDHE + AES-GCM
+ 0xc02f, // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
+ 0xc02b, // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
+ 0xc030, // TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
+ 0xc02c, // TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
+
+ // DHE + AES-GCM
+ 0x009e, // TLS_DHE_RSA_WITH_AES_128_GCM_SHA256
+
+ // ECDHE/DHE + AES-CBC-SHA256/384
+ 0xc027, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256
+ 0x0067, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA256
+ 0xc028, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384
+ 0x006b, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA256
+
+ // DHE-DSS/RSA + AES-GCM
+ 0x00a3, // TLS_DHE_DSS_WITH_AES_256_GCM_SHA384
+ 0x009f, // TLS_DHE_RSA_WITH_AES_256_GCM_SHA384
+
+ // ChaCha20-Poly1305
+ 0xcca9, // TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
+ 0xcca8, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
+ 0xccaa, // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256
+
+ // AES-CCM (256-bit)
+ 0xc0af, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8
+ 0xc0ad, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM
+ 0xc0a3, // TLS_DHE_RSA_WITH_AES_256_CCM_8
+ 0xc09f, // TLS_DHE_RSA_WITH_AES_256_CCM
+
+ // ARIA (256-bit)
+ 0xc05d, // TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384
+ 0xc061, // TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384
+ 0xc057, // TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384
+ 0xc053, // TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384
+
+ // DHE-DSS + AES-GCM (128-bit)
+ 0x00a2, // TLS_DHE_DSS_WITH_AES_128_GCM_SHA256
+
+ // AES-CCM (128-bit)
+ 0xc0ae, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8
+ 0xc0ac, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM
+ 0xc0a2, // TLS_DHE_RSA_WITH_AES_128_CCM_8
+ 0xc09e, // TLS_DHE_RSA_WITH_AES_128_CCM
+
+ // ARIA (128-bit)
+ 0xc05c, // TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256
+ 0xc060, // TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256
+ 0xc056, // TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256
+ 0xc052, // TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256
+
+ // ECDHE/DHE + AES-CBC-SHA384/256 (more)
+ 0xc024, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384
+ 0x006a, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA256
+ 0xc023, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256
+ 0x0040, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA256
+
+ // ECDHE/DHE + AES-CBC-SHA (legacy)
+ 0xc00a, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA
+ 0xc014, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA
+ 0x0039, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA
+ 0x0038, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA
+ 0xc009, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA
+ 0xc013, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA
+ 0x0033, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA
+ 0x0032, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA
+
+ // RSA + AES-GCM/CCM/ARIA (non-PFS, 256-bit)
+ 0x009d, // TLS_RSA_WITH_AES_256_GCM_SHA384
+ 0xc0a1, // TLS_RSA_WITH_AES_256_CCM_8
+ 0xc09d, // TLS_RSA_WITH_AES_256_CCM
+ 0xc051, // TLS_RSA_WITH_ARIA_256_GCM_SHA384
+
+ // RSA + AES-GCM/CCM/ARIA (non-PFS, 128-bit)
+ 0x009c, // TLS_RSA_WITH_AES_128_GCM_SHA256
+ 0xc0a0, // TLS_RSA_WITH_AES_128_CCM_8
+ 0xc09c, // TLS_RSA_WITH_AES_128_CCM
+ 0xc050, // TLS_RSA_WITH_ARIA_128_GCM_SHA256
+
+ // RSA + AES-CBC (non-PFS, legacy)
+ 0x003d, // TLS_RSA_WITH_AES_256_CBC_SHA256
+ 0x003c, // TLS_RSA_WITH_AES_128_CBC_SHA256
+ 0x0035, // TLS_RSA_WITH_AES_256_CBC_SHA
+ 0x002f, // TLS_RSA_WITH_AES_128_CBC_SHA
+
+ // Renegotiation indication
+ 0x00ff, // TLS_EMPTY_RENEGOTIATION_INFO_SCSV
+ }
+
+ // defaultCurves contains the 10 supported groups from Claude CLI (including FFDHE)
+ defaultCurves = []utls.CurveID{
+ utls.X25519, // 0x001d
+ utls.CurveP256, // 0x0017 (secp256r1)
+ utls.CurveID(0x001e), // x448
+ utls.CurveP521, // 0x0019 (secp521r1)
+ utls.CurveP384, // 0x0018 (secp384r1)
+ utls.CurveID(0x0100), // ffdhe2048
+ utls.CurveID(0x0101), // ffdhe3072
+ utls.CurveID(0x0102), // ffdhe4096
+ utls.CurveID(0x0103), // ffdhe6144
+ utls.CurveID(0x0104), // ffdhe8192
+ }
+
+ // defaultPointFormats contains all 3 point formats from Claude CLI
+ defaultPointFormats = []uint8{
+ 0, // uncompressed
+ 1, // ansiX962_compressed_prime
+ 2, // ansiX962_compressed_char2
+ }
+
+ // defaultSignatureAlgorithms contains the 20 signature algorithms from Claude CLI
+ defaultSignatureAlgorithms = []utls.SignatureScheme{
+ 0x0403, // ecdsa_secp256r1_sha256
+ 0x0503, // ecdsa_secp384r1_sha384
+ 0x0603, // ecdsa_secp521r1_sha512
+ 0x0807, // ed25519
+ 0x0808, // ed448
+ 0x0809, // rsa_pss_pss_sha256
+ 0x080a, // rsa_pss_pss_sha384
+ 0x080b, // rsa_pss_pss_sha512
+ 0x0804, // rsa_pss_rsae_sha256
+ 0x0805, // rsa_pss_rsae_sha384
+ 0x0806, // rsa_pss_rsae_sha512
+ 0x0401, // rsa_pkcs1_sha256
+ 0x0501, // rsa_pkcs1_sha384
+ 0x0601, // rsa_pkcs1_sha512
+ 0x0303, // ecdsa_sha224
+ 0x0301, // rsa_pkcs1_sha224
+ 0x0302, // dsa_sha224
+ 0x0402, // dsa_sha256
+ 0x0502, // dsa_sha384
+ 0x0602, // dsa_sha512
+ }
+)
+
+// NewDialer creates a new TLS fingerprint dialer.
+// baseDialer is used for TCP connection establishment (supports proxy scenarios).
+// If baseDialer is nil, direct TCP dial is used.
+func NewDialer(profile *Profile, baseDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *Dialer {
+ if baseDialer == nil {
+ baseDialer = (&net.Dialer{}).DialContext
+ }
+ return &Dialer{profile: profile, baseDialer: baseDialer}
+}
+
+// NewHTTPProxyDialer creates a new TLS fingerprint dialer that works through HTTP/HTTPS proxies.
+// It establishes a CONNECT tunnel before performing TLS handshake with custom fingerprint.
+func NewHTTPProxyDialer(profile *Profile, proxyURL *url.URL) *HTTPProxyDialer {
+ return &HTTPProxyDialer{profile: profile, proxyURL: proxyURL}
+}
+
+// NewSOCKS5ProxyDialer creates a new TLS fingerprint dialer that works through SOCKS5 proxies.
+// It establishes a SOCKS5 tunnel before performing TLS handshake with custom fingerprint.
+func NewSOCKS5ProxyDialer(profile *Profile, proxyURL *url.URL) *SOCKS5ProxyDialer {
+ return &SOCKS5ProxyDialer{profile: profile, proxyURL: proxyURL}
+}
+
+// DialTLSContext establishes a TLS connection through SOCKS5 proxy with the configured fingerprint.
+// Flow: SOCKS5 CONNECT to target -> TLS handshake with utls on the tunnel
+func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
+ slog.Debug("tls_fingerprint_socks5_connecting", "proxy", d.proxyURL.Host, "target", addr)
+
+ // Step 1: Create SOCKS5 dialer
+ var auth *proxy.Auth
+ if d.proxyURL.User != nil {
+ username := d.proxyURL.User.Username()
+ password, _ := d.proxyURL.User.Password()
+ auth = &proxy.Auth{
+ User: username,
+ Password: password,
+ }
+ }
+
+ // Determine proxy address
+ proxyAddr := d.proxyURL.Host
+ if d.proxyURL.Port() == "" {
+ proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "1080") // Default SOCKS5 port
+ }
+
+ socksDialer, err := proxy.SOCKS5("tcp", proxyAddr, auth, proxy.Direct)
+ if err != nil {
+ slog.Debug("tls_fingerprint_socks5_dialer_failed", "error", err)
+ return nil, fmt.Errorf("create SOCKS5 dialer: %w", err)
+ }
+
+ // Step 2: Establish SOCKS5 tunnel to target
+ slog.Debug("tls_fingerprint_socks5_establishing_tunnel", "target", addr)
+ conn, err := socksDialer.Dial("tcp", addr)
+ if err != nil {
+ slog.Debug("tls_fingerprint_socks5_connect_failed", "error", err)
+ return nil, fmt.Errorf("SOCKS5 connect: %w", err)
+ }
+ slog.Debug("tls_fingerprint_socks5_tunnel_established")
+
+ // Step 3: Perform TLS handshake on the tunnel with utls fingerprint
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ host = addr
+ }
+ slog.Debug("tls_fingerprint_socks5_starting_handshake", "host", host)
+
+ // Build ClientHello specification from profile (Node.js/Claude CLI fingerprint)
+ spec := buildClientHelloSpecFromProfile(d.profile)
+ slog.Debug("tls_fingerprint_socks5_clienthello_spec",
+ "cipher_suites", len(spec.CipherSuites),
+ "extensions", len(spec.Extensions),
+ "compression_methods", spec.CompressionMethods,
+ "tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax),
+ "tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin))
+
+ if d.profile != nil {
+ slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
+ }
+
+ // Create uTLS connection on the tunnel
+ tlsConn := utls.UClient(conn, &utls.Config{
+ ServerName: host,
+ }, utls.HelloCustom)
+
+ if err := tlsConn.ApplyPreset(spec); err != nil {
+ slog.Debug("tls_fingerprint_socks5_apply_preset_failed", "error", err)
+ _ = conn.Close()
+ return nil, fmt.Errorf("apply TLS preset: %w", err)
+ }
+
+ if err := tlsConn.Handshake(); err != nil {
+ slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err)
+ _ = conn.Close()
+ return nil, fmt.Errorf("TLS handshake failed: %w", err)
+ }
+
+ state := tlsConn.ConnectionState()
+ slog.Debug("tls_fingerprint_socks5_handshake_success",
+ "version", fmt.Sprintf("0x%04x", state.Version),
+ "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
+ "alpn", state.NegotiatedProtocol)
+
+ return tlsConn, nil
+}
+
+// DialTLSContext establishes a TLS connection through HTTP proxy with the configured fingerprint.
+// Flow: TCP connect to proxy -> CONNECT tunnel -> TLS handshake with utls
+func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
+ slog.Debug("tls_fingerprint_http_proxy_connecting", "proxy", d.proxyURL.Host, "target", addr)
+
+ // Step 1: TCP connect to proxy server
+ var proxyAddr string
+ if d.proxyURL.Port() != "" {
+ proxyAddr = d.proxyURL.Host
+ } else {
+ // Default ports
+ if d.proxyURL.Scheme == "https" {
+ proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "443")
+ } else {
+ proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "80")
+ }
+ }
+
+ dialer := &net.Dialer{}
+ conn, err := dialer.DialContext(ctx, "tcp", proxyAddr)
+ if err != nil {
+ slog.Debug("tls_fingerprint_http_proxy_connect_failed", "error", err)
+ return nil, fmt.Errorf("connect to proxy: %w", err)
+ }
+ slog.Debug("tls_fingerprint_http_proxy_connected", "proxy_addr", proxyAddr)
+
+ // Step 2: Send CONNECT request to establish tunnel
+ req := &http.Request{
+ Method: "CONNECT",
+ URL: &url.URL{Opaque: addr},
+ Host: addr,
+ Header: make(http.Header),
+ }
+
+ // Add proxy authentication if present
+ if d.proxyURL.User != nil {
+ username := d.proxyURL.User.Username()
+ password, _ := d.proxyURL.User.Password()
+ auth := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
+ req.Header.Set("Proxy-Authorization", "Basic "+auth)
+ }
+
+ slog.Debug("tls_fingerprint_http_proxy_sending_connect", "target", addr)
+ if err := req.Write(conn); err != nil {
+ _ = conn.Close()
+ slog.Debug("tls_fingerprint_http_proxy_write_failed", "error", err)
+ return nil, fmt.Errorf("write CONNECT request: %w", err)
+ }
+
+ // Step 3: Read CONNECT response
+ br := bufio.NewReader(conn)
+ resp, err := http.ReadResponse(br, req)
+ if err != nil {
+ _ = conn.Close()
+ slog.Debug("tls_fingerprint_http_proxy_read_response_failed", "error", err)
+ return nil, fmt.Errorf("read CONNECT response: %w", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ if resp.StatusCode != http.StatusOK {
+ _ = conn.Close()
+ slog.Debug("tls_fingerprint_http_proxy_connect_failed_status", "status_code", resp.StatusCode, "status", resp.Status)
+ return nil, fmt.Errorf("proxy CONNECT failed: %s", resp.Status)
+ }
+ slog.Debug("tls_fingerprint_http_proxy_tunnel_established")
+
+ // Step 4: Perform TLS handshake on the tunnel with utls fingerprint
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ host = addr
+ }
+ slog.Debug("tls_fingerprint_http_proxy_starting_handshake", "host", host)
+
+ // Build ClientHello specification (reuse the shared method)
+ spec := buildClientHelloSpecFromProfile(d.profile)
+ slog.Debug("tls_fingerprint_http_proxy_clienthello_spec",
+ "cipher_suites", len(spec.CipherSuites),
+ "extensions", len(spec.Extensions))
+
+ if d.profile != nil {
+ slog.Debug("tls_fingerprint_http_proxy_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
+ }
+
+ // Create uTLS connection on the tunnel
+ // Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions
+ tlsConn := utls.UClient(conn, &utls.Config{
+ ServerName: host,
+ }, utls.HelloCustom)
+
+ if err := tlsConn.ApplyPreset(spec); err != nil {
+ slog.Debug("tls_fingerprint_http_proxy_apply_preset_failed", "error", err)
+ _ = conn.Close()
+ return nil, fmt.Errorf("apply TLS preset: %w", err)
+ }
+
+ if err := tlsConn.HandshakeContext(ctx); err != nil {
+ slog.Debug("tls_fingerprint_http_proxy_handshake_failed", "error", err)
+ _ = conn.Close()
+ return nil, fmt.Errorf("TLS handshake failed: %w", err)
+ }
+
+ state := tlsConn.ConnectionState()
+ slog.Debug("tls_fingerprint_http_proxy_handshake_success",
+ "version", fmt.Sprintf("0x%04x", state.Version),
+ "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
+ "alpn", state.NegotiatedProtocol)
+
+ return tlsConn, nil
+}
+
+// DialTLSContext establishes a TLS connection with the configured fingerprint.
+// This method is designed to be used as http.Transport.DialTLSContext.
+func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
+ // Establish TCP connection using base dialer (supports proxy)
+ slog.Debug("tls_fingerprint_dialing_tcp", "addr", addr)
+ conn, err := d.baseDialer(ctx, network, addr)
+ if err != nil {
+ slog.Debug("tls_fingerprint_tcp_dial_failed", "error", err)
+ return nil, err
+ }
+ slog.Debug("tls_fingerprint_tcp_connected", "addr", addr)
+
+ // Extract hostname for SNI
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ host = addr
+ }
+ slog.Debug("tls_fingerprint_sni_hostname", "host", host)
+
+ // Build ClientHello specification
+ spec := d.buildClientHelloSpec()
+ slog.Debug("tls_fingerprint_clienthello_spec",
+ "cipher_suites", len(spec.CipherSuites),
+ "extensions", len(spec.Extensions))
+
+ // Log profile info
+ if d.profile != nil {
+ slog.Debug("tls_fingerprint_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
+ } else {
+ slog.Debug("tls_fingerprint_using_default_profile")
+ }
+
+ // Create uTLS connection
+ // Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions
+ tlsConn := utls.UClient(conn, &utls.Config{
+ ServerName: host,
+ }, utls.HelloCustom)
+
+ // Apply fingerprint
+ if err := tlsConn.ApplyPreset(spec); err != nil {
+ slog.Debug("tls_fingerprint_apply_preset_failed", "error", err)
+ _ = conn.Close()
+ return nil, err
+ }
+ slog.Debug("tls_fingerprint_preset_applied")
+
+ // Perform TLS handshake
+ if err := tlsConn.HandshakeContext(ctx); err != nil {
+ slog.Debug("tls_fingerprint_handshake_failed",
+ "error", err,
+ "local_addr", conn.LocalAddr(),
+ "remote_addr", conn.RemoteAddr())
+ _ = conn.Close()
+ return nil, fmt.Errorf("TLS handshake failed: %w", err)
+ }
+
+ // Log successful handshake details
+ state := tlsConn.ConnectionState()
+ slog.Debug("tls_fingerprint_handshake_success",
+ "version", fmt.Sprintf("0x%04x", state.Version),
+ "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
+ "alpn", state.NegotiatedProtocol)
+
+ return tlsConn, nil
+}
+
+// buildClientHelloSpec constructs the ClientHello specification based on the profile.
+func (d *Dialer) buildClientHelloSpec() *utls.ClientHelloSpec {
+ return buildClientHelloSpecFromProfile(d.profile)
+}
+
+// toUTLSCurves converts uint16 slice to utls.CurveID slice.
+func toUTLSCurves(curves []uint16) []utls.CurveID {
+ result := make([]utls.CurveID, len(curves))
+ for i, c := range curves {
+ result[i] = utls.CurveID(c)
+ }
+ return result
+}
+
+// buildClientHelloSpecFromProfile constructs ClientHelloSpec from a Profile.
+// This is a standalone function that can be used by both Dialer and HTTPProxyDialer.
+func buildClientHelloSpecFromProfile(profile *Profile) *utls.ClientHelloSpec {
+ // Get cipher suites
+ var cipherSuites []uint16
+ if profile != nil && len(profile.CipherSuites) > 0 {
+ cipherSuites = profile.CipherSuites
+ } else {
+ cipherSuites = defaultCipherSuites
+ }
+
+ // Get curves
+ var curves []utls.CurveID
+ if profile != nil && len(profile.Curves) > 0 {
+ curves = toUTLSCurves(profile.Curves)
+ } else {
+ curves = defaultCurves
+ }
+
+ // Get point formats
+ var pointFormats []uint8
+ if profile != nil && len(profile.PointFormats) > 0 {
+ pointFormats = profile.PointFormats
+ } else {
+ pointFormats = defaultPointFormats
+ }
+
+ // Check if GREASE is enabled
+ enableGREASE := profile != nil && profile.EnableGREASE
+
+ extensions := make([]utls.TLSExtension, 0, 16)
+
+ if enableGREASE {
+ extensions = append(extensions, &utls.UtlsGREASEExtension{})
+ }
+
+ // SNI extension - MUST be explicitly added for HelloCustom mode
+ // utls will populate the server name from Config.ServerName
+ extensions = append(extensions, &utls.SNIExtension{})
+
+ // Claude CLI extension order (captured from tshark):
+ // server_name(0), ec_point_formats(11), supported_groups(10), session_ticket(35),
+ // alpn(16), encrypt_then_mac(22), extended_master_secret(23),
+ // signature_algorithms(13), supported_versions(43),
+ // psk_key_exchange_modes(45), key_share(51)
+ extensions = append(extensions,
+ &utls.SupportedPointsExtension{SupportedPoints: pointFormats},
+ &utls.SupportedCurvesExtension{Curves: curves},
+ &utls.SessionTicketExtension{},
+ &utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}},
+ &utls.GenericExtension{Id: 22},
+ &utls.ExtendedMasterSecretExtension{},
+ &utls.SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: defaultSignatureAlgorithms},
+ &utls.SupportedVersionsExtension{Versions: []uint16{
+ utls.VersionTLS13,
+ utls.VersionTLS12,
+ }},
+ &utls.PSKKeyExchangeModesExtension{Modes: []uint8{utls.PskModeDHE}},
+ &utls.KeyShareExtension{KeyShares: []utls.KeyShare{
+ {Group: utls.X25519},
+ }},
+ )
+
+ if enableGREASE {
+ extensions = append(extensions, &utls.UtlsGREASEExtension{})
+ }
+
+ return &utls.ClientHelloSpec{
+ CipherSuites: cipherSuites,
+ CompressionMethods: []uint8{0}, // null compression only (standard)
+ Extensions: extensions,
+ TLSVersMax: utls.VersionTLS13,
+ TLSVersMin: utls.VersionTLS10,
+ }
+}
diff --git a/backend/internal/pkg/tlsfingerprint/dialer_test.go b/backend/internal/pkg/tlsfingerprint/dialer_test.go
new file mode 100644
index 00000000..2aed1287
--- /dev/null
+++ b/backend/internal/pkg/tlsfingerprint/dialer_test.go
@@ -0,0 +1,307 @@
+// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
+//
+// Integration tests for verifying TLS fingerprint correctness.
+// These tests make actual network requests and should be run manually.
+//
+// Run with: go test -v ./internal/pkg/tlsfingerprint/...
+// Run integration tests: go test -v -run TestJA3 ./internal/pkg/tlsfingerprint/...
+package tlsfingerprint
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "testing"
+ "time"
+)
+
+// FingerprintResponse represents the response from tls.peet.ws/api/all.
+type FingerprintResponse struct {
+ IP string `json:"ip"`
+ TLS TLSInfo `json:"tls"`
+ HTTP2 any `json:"http2"`
+}
+
+// TLSInfo contains TLS fingerprint details.
+type TLSInfo struct {
+ JA3 string `json:"ja3"`
+ JA3Hash string `json:"ja3_hash"`
+ JA4 string `json:"ja4"`
+ PeetPrint string `json:"peetprint"`
+ PeetPrintHash string `json:"peetprint_hash"`
+ ClientRandom string `json:"client_random"`
+ SessionID string `json:"session_id"`
+}
+
+// TestDialerBasicConnection tests that the dialer can establish TLS connections.
+func TestDialerBasicConnection(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping network test in short mode")
+ }
+
+ // Create a dialer with default profile
+ profile := &Profile{
+ Name: "Test Profile",
+ EnableGREASE: false,
+ }
+ dialer := NewDialer(profile, nil)
+
+ // Create HTTP client with custom TLS dialer
+ client := &http.Client{
+ Transport: &http.Transport{
+ DialTLSContext: dialer.DialTLSContext,
+ },
+ Timeout: 30 * time.Second,
+ }
+
+ // Make a request to a known HTTPS endpoint
+ resp, err := client.Get("https://www.google.com")
+ if err != nil {
+ t.Fatalf("failed to connect: %v", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ if resp.StatusCode != http.StatusOK {
+ t.Errorf("expected status 200, got %d", resp.StatusCode)
+ }
+}
+
+// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
+// This test uses tls.peet.ws to verify the fingerprint.
+// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
+// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
+func TestJA3Fingerprint(t *testing.T) {
+ // Skip if network is unavailable or if running in short mode
+ if testing.Short() {
+ t.Skip("skipping integration test in short mode")
+ }
+
+ profile := &Profile{
+ Name: "Claude CLI Test",
+ EnableGREASE: false,
+ }
+ dialer := NewDialer(profile, nil)
+
+ client := &http.Client{
+ Transport: &http.Transport{
+ DialTLSContext: dialer.DialTLSContext,
+ },
+ Timeout: 30 * time.Second,
+ }
+
+ // Use tls.peet.ws fingerprint detection API
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
+ if err != nil {
+ t.Fatalf("failed to create request: %v", err)
+ }
+ req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
+
+ resp, err := client.Do(req)
+ if err != nil {
+ t.Fatalf("failed to get fingerprint: %v", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("failed to read response: %v", err)
+ }
+
+ var fpResp FingerprintResponse
+ if err := json.Unmarshal(body, &fpResp); err != nil {
+ t.Logf("Response body: %s", string(body))
+ t.Fatalf("failed to parse fingerprint response: %v", err)
+ }
+
+ // Log all fingerprint information
+ t.Logf("JA3: %s", fpResp.TLS.JA3)
+ t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
+ t.Logf("JA4: %s", fpResp.TLS.JA4)
+ t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
+ t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
+
+ // Verify JA3 hash matches expected value
+ expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
+ if fpResp.TLS.JA3Hash == expectedJA3Hash {
+ t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
+ } else {
+ t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
+ }
+
+ // Verify JA4 fingerprint
+ // JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
+ // Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
+ // The suffix _a33745022dd6_1f22a2ca17c4 should match
+ expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
+ if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
+ t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
+ } else {
+ t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
+ }
+
+ // Verify JA4 prefix (t13d5911h1 or t13i5911h1)
+ // d = domain (SNI present), i = IP (no SNI)
+ // Since we connect to tls.peet.ws (domain), we expect 'd'
+ expectedJA4Prefix := "t13d5911h1"
+ if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
+ t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
+ } else {
+ // Also accept 'i' variant for IP connections
+ altPrefix := "t13i5911h1"
+ if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
+ t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
+ } else {
+ t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
+ }
+ }
+
+ // Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
+ if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
+ t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
+ } else {
+ t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
+ }
+
+ // Verify extension list (should be 11 extensions including SNI)
+ // Expected: 0-11-10-35-16-22-23-13-43-45-51
+ expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
+ if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
+ t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
+ } else {
+ t.Logf("Warning: JA3 extension list may differ")
+ }
+}
+
+// TestDialerWithProfile tests that different profiles produce different fingerprints.
+func TestDialerWithProfile(t *testing.T) {
+ // Create two dialers with different profiles
+ profile1 := &Profile{
+ Name: "Profile 1 - No GREASE",
+ EnableGREASE: false,
+ }
+ profile2 := &Profile{
+ Name: "Profile 2 - With GREASE",
+ EnableGREASE: true,
+ }
+
+ dialer1 := NewDialer(profile1, nil)
+ dialer2 := NewDialer(profile2, nil)
+
+ // Build specs and compare
+ // Note: We can't directly compare JA3 without making network requests
+ // but we can verify the specs are different
+ spec1 := dialer1.buildClientHelloSpec()
+ spec2 := dialer2.buildClientHelloSpec()
+
+ // Profile with GREASE should have more extensions
+ if len(spec2.Extensions) <= len(spec1.Extensions) {
+ t.Error("expected GREASE profile to have more extensions")
+ }
+}
+
+// TestHTTPProxyDialerBasic tests HTTP proxy dialer creation.
+// Note: This is a unit test - actual proxy testing requires a proxy server.
+func TestHTTPProxyDialerBasic(t *testing.T) {
+ profile := &Profile{
+ Name: "Test Profile",
+ EnableGREASE: false,
+ }
+
+ // Test that dialer is created without panic
+ proxyURL := mustParseURL("http://proxy.example.com:8080")
+ dialer := NewHTTPProxyDialer(profile, proxyURL)
+
+ if dialer == nil {
+ t.Fatal("expected dialer to be created")
+ }
+ if dialer.profile != profile {
+ t.Error("expected profile to be set")
+ }
+ if dialer.proxyURL != proxyURL {
+ t.Error("expected proxyURL to be set")
+ }
+}
+
+// TestSOCKS5ProxyDialerBasic tests SOCKS5 proxy dialer creation.
+// Note: This is a unit test - actual proxy testing requires a proxy server.
+func TestSOCKS5ProxyDialerBasic(t *testing.T) {
+ profile := &Profile{
+ Name: "Test Profile",
+ EnableGREASE: false,
+ }
+
+ // Test that dialer is created without panic
+ proxyURL := mustParseURL("socks5://proxy.example.com:1080")
+ dialer := NewSOCKS5ProxyDialer(profile, proxyURL)
+
+ if dialer == nil {
+ t.Fatal("expected dialer to be created")
+ }
+ if dialer.profile != profile {
+ t.Error("expected profile to be set")
+ }
+ if dialer.proxyURL != proxyURL {
+ t.Error("expected proxyURL to be set")
+ }
+}
+
+// TestBuildClientHelloSpec tests ClientHello spec construction.
+func TestBuildClientHelloSpec(t *testing.T) {
+ // Test with nil profile (should use defaults)
+ spec := buildClientHelloSpecFromProfile(nil)
+
+ if len(spec.CipherSuites) == 0 {
+ t.Error("expected cipher suites to be set")
+ }
+ if len(spec.Extensions) == 0 {
+ t.Error("expected extensions to be set")
+ }
+
+ // Verify default cipher suites are used
+ if len(spec.CipherSuites) != len(defaultCipherSuites) {
+ t.Errorf("expected %d cipher suites, got %d", len(defaultCipherSuites), len(spec.CipherSuites))
+ }
+
+ // Test with custom profile
+ customProfile := &Profile{
+ Name: "Custom",
+ EnableGREASE: false,
+ CipherSuites: []uint16{0x1301, 0x1302},
+ }
+ spec = buildClientHelloSpecFromProfile(customProfile)
+
+ if len(spec.CipherSuites) != 2 {
+ t.Errorf("expected 2 cipher suites, got %d", len(spec.CipherSuites))
+ }
+}
+
+// TestToUTLSCurves tests curve ID conversion.
+func TestToUTLSCurves(t *testing.T) {
+ input := []uint16{0x001d, 0x0017, 0x0018}
+ result := toUTLSCurves(input)
+
+ if len(result) != len(input) {
+ t.Errorf("expected %d curves, got %d", len(input), len(result))
+ }
+
+ for i, curve := range result {
+ if uint16(curve) != input[i] {
+ t.Errorf("curve %d: expected 0x%04x, got 0x%04x", i, input[i], uint16(curve))
+ }
+ }
+}
+
+// Helper function to parse URL without error handling.
+func mustParseURL(rawURL string) *url.URL {
+ u, err := url.Parse(rawURL)
+ if err != nil {
+ panic(err)
+ }
+ return u
+}
diff --git a/backend/internal/pkg/tlsfingerprint/registry.go b/backend/internal/pkg/tlsfingerprint/registry.go
new file mode 100644
index 00000000..6e9dc539
--- /dev/null
+++ b/backend/internal/pkg/tlsfingerprint/registry.go
@@ -0,0 +1,171 @@
+// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
+package tlsfingerprint
+
+import (
+ "log/slog"
+ "sort"
+ "sync"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+)
+
+// DefaultProfileName is the name of the built-in Claude CLI profile.
+const DefaultProfileName = "claude_cli_v2"
+
+// Registry manages TLS fingerprint profiles.
+// It holds a collection of profiles that can be used for TLS fingerprint simulation.
+// Profiles are selected based on account ID using modulo operation.
+type Registry struct {
+ mu sync.RWMutex
+ profiles map[string]*Profile
+ profileNames []string // Sorted list of profile names for deterministic selection
+}
+
+// NewRegistry creates a new TLS fingerprint profile registry.
+// It initializes with the built-in default profile.
+func NewRegistry() *Registry {
+ r := &Registry{
+ profiles: make(map[string]*Profile),
+ profileNames: make([]string, 0),
+ }
+
+ // Register the built-in default profile
+ r.registerBuiltinProfile()
+
+ return r
+}
+
+// NewRegistryFromConfig creates a new registry and loads profiles from config.
+// If the config has custom profiles defined, they will be merged with the built-in default.
+func NewRegistryFromConfig(cfg *config.TLSFingerprintConfig) *Registry {
+ r := NewRegistry()
+
+ if cfg == nil || !cfg.Enabled {
+ slog.Debug("tls_registry_disabled", "reason", "disabled or no config")
+ return r
+ }
+
+ // Load custom profiles from config
+ for name, profileCfg := range cfg.Profiles {
+ profile := &Profile{
+ Name: profileCfg.Name,
+ EnableGREASE: profileCfg.EnableGREASE,
+ CipherSuites: profileCfg.CipherSuites,
+ Curves: profileCfg.Curves,
+ PointFormats: profileCfg.PointFormats,
+ }
+
+ // If the profile has empty values, they will use defaults in dialer
+ r.RegisterProfile(name, profile)
+ slog.Debug("tls_registry_loaded_profile", "key", name, "name", profileCfg.Name)
+ }
+
+ slog.Debug("tls_registry_initialized", "profile_count", len(r.profileNames), "profiles", r.profileNames)
+ return r
+}
+
+// registerBuiltinProfile adds the default Claude CLI profile to the registry.
+func (r *Registry) registerBuiltinProfile() {
+ defaultProfile := &Profile{
+ Name: "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)",
+ EnableGREASE: false, // Node.js does not use GREASE
+ // Empty slices will cause dialer to use built-in defaults
+ CipherSuites: nil,
+ Curves: nil,
+ PointFormats: nil,
+ }
+ r.RegisterProfile(DefaultProfileName, defaultProfile)
+}
+
+// RegisterProfile adds or updates a profile in the registry.
+func (r *Registry) RegisterProfile(name string, profile *Profile) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ // Check if this is a new profile
+ _, exists := r.profiles[name]
+ r.profiles[name] = profile
+
+ if !exists {
+ r.profileNames = append(r.profileNames, name)
+ // Keep names sorted for deterministic selection
+ sort.Strings(r.profileNames)
+ }
+}
+
+// GetProfile returns a profile by name.
+// Returns nil if the profile does not exist.
+func (r *Registry) GetProfile(name string) *Profile {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.profiles[name]
+}
+
+// GetDefaultProfile returns the built-in default profile.
+func (r *Registry) GetDefaultProfile() *Profile {
+ return r.GetProfile(DefaultProfileName)
+}
+
+// GetProfileByAccountID returns a profile for the given account ID.
+// The profile is selected using: profileNames[accountID % len(profiles)]
+// This ensures deterministic profile assignment for each account.
+func (r *Registry) GetProfileByAccountID(accountID int64) *Profile {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+
+ if len(r.profileNames) == 0 {
+ return nil
+ }
+
+ // Use modulo to select profile index
+ // Use absolute value to handle negative IDs (though unlikely)
+ idx := accountID
+ if idx < 0 {
+ idx = -idx
+ }
+ selectedIndex := int(idx % int64(len(r.profileNames)))
+ selectedName := r.profileNames[selectedIndex]
+
+ return r.profiles[selectedName]
+}
+
+// ProfileCount returns the number of registered profiles.
+func (r *Registry) ProfileCount() int {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return len(r.profiles)
+}
+
+// ProfileNames returns a sorted list of all registered profile names.
+func (r *Registry) ProfileNames() []string {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+
+ // Return a copy to prevent modification
+ names := make([]string, len(r.profileNames))
+ copy(names, r.profileNames)
+ return names
+}
+
+// Global registry instance for convenience
+var globalRegistry *Registry
+var globalRegistryOnce sync.Once
+
+// GlobalRegistry returns the global TLS fingerprint registry.
+// The registry is lazily initialized with the default profile.
+func GlobalRegistry() *Registry {
+ globalRegistryOnce.Do(func() {
+ globalRegistry = NewRegistry()
+ })
+ return globalRegistry
+}
+
+// InitGlobalRegistry initializes the global registry with configuration.
+// This should be called during application startup.
+// It is safe to call multiple times; subsequent calls will update the registry.
+func InitGlobalRegistry(cfg *config.TLSFingerprintConfig) *Registry {
+ globalRegistryOnce.Do(func() {
+ globalRegistry = NewRegistryFromConfig(cfg)
+ })
+ return globalRegistry
+}
diff --git a/backend/internal/pkg/tlsfingerprint/registry_test.go b/backend/internal/pkg/tlsfingerprint/registry_test.go
new file mode 100644
index 00000000..752ba0cc
--- /dev/null
+++ b/backend/internal/pkg/tlsfingerprint/registry_test.go
@@ -0,0 +1,243 @@
+package tlsfingerprint
+
+import (
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+)
+
+func TestNewRegistry(t *testing.T) {
+ r := NewRegistry()
+
+ // Should have exactly one profile (the default)
+ if r.ProfileCount() != 1 {
+ t.Errorf("expected 1 profile, got %d", r.ProfileCount())
+ }
+
+ // Should have the default profile
+ profile := r.GetDefaultProfile()
+ if profile == nil {
+ t.Error("expected default profile to exist")
+ }
+
+ // Default profile name should be in the list
+ names := r.ProfileNames()
+ if len(names) != 1 || names[0] != DefaultProfileName {
+ t.Errorf("expected profile names to be [%s], got %v", DefaultProfileName, names)
+ }
+}
+
+func TestRegisterProfile(t *testing.T) {
+ r := NewRegistry()
+
+ // Register a new profile
+ customProfile := &Profile{
+ Name: "Custom Profile",
+ EnableGREASE: true,
+ }
+ r.RegisterProfile("custom", customProfile)
+
+ // Should now have 2 profiles
+ if r.ProfileCount() != 2 {
+ t.Errorf("expected 2 profiles, got %d", r.ProfileCount())
+ }
+
+ // Should be able to retrieve the custom profile
+ retrieved := r.GetProfile("custom")
+ if retrieved == nil {
+ t.Fatal("expected custom profile to exist")
+ }
+ if retrieved.Name != "Custom Profile" {
+ t.Errorf("expected profile name 'Custom Profile', got '%s'", retrieved.Name)
+ }
+ if !retrieved.EnableGREASE {
+ t.Error("expected EnableGREASE to be true")
+ }
+}
+
+func TestGetProfile(t *testing.T) {
+ r := NewRegistry()
+
+ // Get existing profile
+ profile := r.GetProfile(DefaultProfileName)
+ if profile == nil {
+ t.Error("expected default profile to exist")
+ }
+
+ // Get non-existing profile
+ nonExistent := r.GetProfile("nonexistent")
+ if nonExistent != nil {
+ t.Error("expected nil for non-existent profile")
+ }
+}
+
+func TestGetProfileByAccountID(t *testing.T) {
+ r := NewRegistry()
+
+ // With only default profile, all account IDs should return the same profile
+ for i := int64(0); i < 10; i++ {
+ profile := r.GetProfileByAccountID(i)
+ if profile == nil {
+ t.Errorf("expected profile for account %d, got nil", i)
+ }
+ }
+
+ // Add more profiles
+ r.RegisterProfile("profile_a", &Profile{Name: "Profile A"})
+ r.RegisterProfile("profile_b", &Profile{Name: "Profile B"})
+
+ // Now we have 3 profiles: claude_cli_v2, profile_a, profile_b
+ // Names are sorted, so order is: claude_cli_v2, profile_a, profile_b
+ expectedOrder := []string{DefaultProfileName, "profile_a", "profile_b"}
+ names := r.ProfileNames()
+ for i, name := range expectedOrder {
+ if names[i] != name {
+ t.Errorf("expected name at index %d to be %s, got %s", i, name, names[i])
+ }
+ }
+
+ // Test modulo selection
+ // Account ID 0 % 3 = 0 -> claude_cli_v2
+ // Account ID 1 % 3 = 1 -> profile_a
+ // Account ID 2 % 3 = 2 -> profile_b
+ // Account ID 3 % 3 = 0 -> claude_cli_v2
+ testCases := []struct {
+ accountID int64
+ expectedName string
+ }{
+ {0, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"},
+ {1, "Profile A"},
+ {2, "Profile B"},
+ {3, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"},
+ {4, "Profile A"},
+ {5, "Profile B"},
+ {100, "Profile A"}, // 100 % 3 = 1
+ {-1, "Profile A"}, // |-1| % 3 = 1
+ {-3, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"}, // |-3| % 3 = 0
+ }
+
+ for _, tc := range testCases {
+ profile := r.GetProfileByAccountID(tc.accountID)
+ if profile == nil {
+ t.Errorf("expected profile for account %d, got nil", tc.accountID)
+ continue
+ }
+ if profile.Name != tc.expectedName {
+ t.Errorf("account %d: expected profile name '%s', got '%s'", tc.accountID, tc.expectedName, profile.Name)
+ }
+ }
+}
+
+func TestNewRegistryFromConfig(t *testing.T) {
+ // Test with nil config
+ r := NewRegistryFromConfig(nil)
+ if r.ProfileCount() != 1 {
+ t.Errorf("expected 1 profile with nil config, got %d", r.ProfileCount())
+ }
+
+ // Test with disabled config
+ disabledCfg := &config.TLSFingerprintConfig{
+ Enabled: false,
+ }
+ r = NewRegistryFromConfig(disabledCfg)
+ if r.ProfileCount() != 1 {
+ t.Errorf("expected 1 profile with disabled config, got %d", r.ProfileCount())
+ }
+
+ // Test with enabled config and custom profiles
+ enabledCfg := &config.TLSFingerprintConfig{
+ Enabled: true,
+ Profiles: map[string]config.TLSProfileConfig{
+ "custom1": {
+ Name: "Custom Profile 1",
+ EnableGREASE: true,
+ },
+ "custom2": {
+ Name: "Custom Profile 2",
+ EnableGREASE: false,
+ },
+ },
+ }
+ r = NewRegistryFromConfig(enabledCfg)
+
+ // Should have 3 profiles: default + 2 custom
+ if r.ProfileCount() != 3 {
+ t.Errorf("expected 3 profiles, got %d", r.ProfileCount())
+ }
+
+ // Check custom profiles exist
+ custom1 := r.GetProfile("custom1")
+ if custom1 == nil || custom1.Name != "Custom Profile 1" {
+ t.Error("expected custom1 profile to exist with correct name")
+ }
+ custom2 := r.GetProfile("custom2")
+ if custom2 == nil || custom2.Name != "Custom Profile 2" {
+ t.Error("expected custom2 profile to exist with correct name")
+ }
+}
+
+func TestProfileNames(t *testing.T) {
+ r := NewRegistry()
+
+ // Add profiles in non-alphabetical order
+ r.RegisterProfile("zebra", &Profile{Name: "Zebra"})
+ r.RegisterProfile("alpha", &Profile{Name: "Alpha"})
+ r.RegisterProfile("beta", &Profile{Name: "Beta"})
+
+ names := r.ProfileNames()
+
+ // Should be sorted alphabetically
+ expected := []string{"alpha", "beta", DefaultProfileName, "zebra"}
+ if len(names) != len(expected) {
+ t.Errorf("expected %d names, got %d", len(expected), len(names))
+ }
+ for i, name := range expected {
+ if names[i] != name {
+ t.Errorf("expected name at index %d to be %s, got %s", i, name, names[i])
+ }
+ }
+
+ // Test that returned slice is a copy (modifying it shouldn't affect registry)
+ names[0] = "modified"
+ originalNames := r.ProfileNames()
+ if originalNames[0] == "modified" {
+ t.Error("modifying returned slice should not affect registry")
+ }
+}
+
+func TestConcurrentAccess(t *testing.T) {
+ r := NewRegistry()
+
+ // Run concurrent reads and writes
+ done := make(chan bool)
+
+ // Writers
+ for i := 0; i < 10; i++ {
+ go func(id int) {
+ for j := 0; j < 100; j++ {
+ r.RegisterProfile("concurrent"+string(rune('0'+id)), &Profile{Name: "Concurrent"})
+ }
+ done <- true
+ }(i)
+ }
+
+ // Readers
+ for i := 0; i < 10; i++ {
+ go func(id int) {
+ for j := 0; j < 100; j++ {
+ _ = r.ProfileCount()
+ _ = r.ProfileNames()
+ _ = r.GetProfileByAccountID(int64(id * j))
+ _ = r.GetProfile(DefaultProfileName)
+ }
+ done <- true
+ }(i)
+ }
+
+ // Wait for all goroutines
+ for i := 0; i < 20; i++ {
+ <-done
+ }
+
+ // Test should pass without data races (run with -race flag)
+}
diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go
index 84bd7b9e..c11c079b 100644
--- a/backend/internal/repository/account_repo.go
+++ b/backend/internal/repository/account_repo.go
@@ -575,6 +575,15 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
}
}
+func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
+ _, err := r.client.Account.Update().
+ Where(dbaccount.IDEQ(id)).
+ SetStatus(service.StatusActive).
+ SetErrorMessage("").
+ Save(ctx)
+ return err
+}
+
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
_, err := r.client.AccountGroup.Create().
SetAccountID(accountID).
@@ -993,7 +1002,16 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
builder.SetSessionWindowEnd(*end)
}
_, err := builder.Save(ctx)
- return err
+ if err != nil {
+ return err
+ }
+ // 触发调度器缓存更新(仅当窗口时间有变化时)
+ if start != nil || end != nil {
+ if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
+ log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
+ }
+ }
+ return nil
}
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
diff --git a/backend/internal/repository/api_key_cache.go b/backend/internal/repository/api_key_cache.go
index 6d834b40..a1072057 100644
--- a/backend/internal/repository/api_key_cache.go
+++ b/backend/internal/repository/api_key_cache.go
@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "log"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -12,9 +13,10 @@ import (
)
const (
- apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
- apiKeyRateLimitDuration = 24 * time.Hour
- apiKeyAuthCachePrefix = "apikey:auth:"
+ apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
+ apiKeyRateLimitDuration = 24 * time.Hour
+ apiKeyAuthCachePrefix = "apikey:auth:"
+ authCacheInvalidateChannel = "auth:cache:invalidate"
)
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
@@ -91,3 +93,45 @@ func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *servi
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
}
+
+// PublishAuthCacheInvalidation publishes a cache invalidation message to all instances
+func (c *apiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
+ return c.rdb.Publish(ctx, authCacheInvalidateChannel, cacheKey).Err()
+}
+
+// SubscribeAuthCacheInvalidation subscribes to cache invalidation messages
+func (c *apiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
+ pubsub := c.rdb.Subscribe(ctx, authCacheInvalidateChannel)
+
+ // Verify subscription is working
+ _, err := pubsub.Receive(ctx)
+ if err != nil {
+ _ = pubsub.Close()
+ return fmt.Errorf("subscribe to auth cache invalidation: %w", err)
+ }
+
+ go func() {
+ defer func() {
+ if err := pubsub.Close(); err != nil {
+ log.Printf("Warning: failed to close auth cache invalidation pubsub: %v", err)
+ }
+ }()
+
+ ch := pubsub.Channel()
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case msg, ok := <-ch:
+ if !ok {
+ return
+ }
+ if msg != nil {
+ handler(msg.Payload)
+ }
+ }
+ }
+ }()
+
+ return nil
+}
diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go
index feb32541..b0f15f19 100644
--- a/backend/internal/repository/http_upstream.go
+++ b/backend/internal/repository/http_upstream.go
@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"io"
+ "log/slog"
"net"
"net/http"
"net/url"
@@ -14,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
@@ -150,6 +152,172 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
return resp, nil
}
+// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
+// 根据 enableTLSFingerprint 参数决定是否使用 TLS 指纹
+//
+// 参数:
+// - req: HTTP 请求对象
+// - proxyURL: 代理地址,空字符串表示直连
+// - accountID: 账户 ID,用于账户级隔离和 TLS 指纹模板选择
+// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
+// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
+//
+// TLS 指纹说明:
+// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
+// - 指纹模板根据 accountID % len(profiles) 自动选择
+// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
+func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
+ // 如果未启用 TLS 指纹,直接使用标准请求路径
+ if !enableTLSFingerprint {
+ return s.Do(req, proxyURL, accountID, accountConcurrency)
+ }
+
+ // TLS 指纹已启用,记录调试日志
+ targetHost := ""
+ if req != nil && req.URL != nil {
+ targetHost = req.URL.Host
+ }
+ proxyInfo := "direct"
+ if proxyURL != "" {
+ proxyInfo = proxyURL
+ }
+ slog.Debug("tls_fingerprint_enabled", "account_id", accountID, "target", targetHost, "proxy", proxyInfo)
+
+ if err := s.validateRequestHost(req); err != nil {
+ return nil, err
+ }
+
+ // 获取 TLS 指纹 Profile
+ registry := tlsfingerprint.GlobalRegistry()
+ profile := registry.GetProfileByAccountID(accountID)
+ if profile == nil {
+ // 如果获取不到 profile,回退到普通请求
+ slog.Debug("tls_fingerprint_no_profile", "account_id", accountID, "fallback", "standard_request")
+ return s.Do(req, proxyURL, accountID, accountConcurrency)
+ }
+
+ slog.Debug("tls_fingerprint_using_profile", "account_id", accountID, "profile", profile.Name, "grease", profile.EnableGREASE)
+
+ // 获取或创建带 TLS 指纹的客户端
+ entry, err := s.acquireClientWithTLS(proxyURL, accountID, accountConcurrency, profile)
+ if err != nil {
+ slog.Debug("tls_fingerprint_acquire_client_failed", "account_id", accountID, "error", err)
+ return nil, err
+ }
+
+ // 执行请求
+ resp, err := entry.client.Do(req)
+ if err != nil {
+ // 请求失败,立即减少计数
+ atomic.AddInt64(&entry.inFlight, -1)
+ atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
+ slog.Debug("tls_fingerprint_request_failed", "account_id", accountID, "error", err)
+ return nil, err
+ }
+
+ slog.Debug("tls_fingerprint_request_success", "account_id", accountID, "status", resp.StatusCode)
+
+ // 包装响应体,在关闭时自动减少计数并更新时间戳
+ resp.Body = wrapTrackedBody(resp.Body, func() {
+ atomic.AddInt64(&entry.inFlight, -1)
+ atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
+ })
+
+ return resp, nil
+}
+
+// acquireClientWithTLS 获取或创建带 TLS 指纹的客户端
+func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*upstreamClientEntry, error) {
+ return s.getClientEntryWithTLS(proxyURL, accountID, accountConcurrency, profile, true, true)
+}
+
+// getClientEntryWithTLS 获取或创建带 TLS 指纹的客户端条目
+// TLS 指纹客户端使用独立的缓存键,与普通客户端隔离
+func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
+ isolation := s.getIsolationMode()
+ proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
+ // TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀
+ cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID)
+ poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls"
+
+ now := time.Now()
+ nowUnix := now.UnixNano()
+
+ // 读锁快速路径
+ s.mu.RLock()
+ if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
+ atomic.StoreInt64(&entry.lastUsed, nowUnix)
+ if markInFlight {
+ atomic.AddInt64(&entry.inFlight, 1)
+ }
+ s.mu.RUnlock()
+ slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey)
+ return entry, nil
+ }
+ s.mu.RUnlock()
+
+ // 写锁慢路径
+ s.mu.Lock()
+ if entry, ok := s.clients[cacheKey]; ok {
+ if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
+ atomic.StoreInt64(&entry.lastUsed, nowUnix)
+ if markInFlight {
+ atomic.AddInt64(&entry.inFlight, 1)
+ }
+ s.mu.Unlock()
+ slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey)
+ return entry, nil
+ }
+ slog.Debug("tls_fingerprint_evicting_stale_client",
+ "account_id", accountID,
+ "cache_key", cacheKey,
+ "proxy_changed", entry.proxyKey != proxyKey,
+ "pool_changed", entry.poolKey != poolKey)
+ s.removeClientLocked(cacheKey, entry)
+ }
+
+ // 超出缓存上限时尝试淘汰
+ if enforceLimit && s.maxUpstreamClients() > 0 {
+ s.evictIdleLocked(now)
+ if len(s.clients) >= s.maxUpstreamClients() {
+ if !s.evictOldestIdleLocked() {
+ s.mu.Unlock()
+ return nil, errUpstreamClientLimitReached
+ }
+ }
+ }
+
+ // 创建带 TLS 指纹的 Transport
+ slog.Debug("tls_fingerprint_creating_new_client", "account_id", accountID, "cache_key", cacheKey, "proxy", proxyKey)
+ settings := s.resolvePoolSettings(isolation, accountConcurrency)
+ transport, err := buildUpstreamTransportWithTLSFingerprint(settings, parsedProxy, profile)
+ if err != nil {
+ s.mu.Unlock()
+ return nil, fmt.Errorf("build TLS fingerprint transport: %w", err)
+ }
+
+ client := &http.Client{Transport: transport}
+ if s.shouldValidateResolvedIP() {
+ client.CheckRedirect = s.redirectChecker
+ }
+
+ entry := &upstreamClientEntry{
+ client: client,
+ proxyKey: proxyKey,
+ poolKey: poolKey,
+ }
+ atomic.StoreInt64(&entry.lastUsed, nowUnix)
+ if markInFlight {
+ atomic.StoreInt64(&entry.inFlight, 1)
+ }
+ s.clients[cacheKey] = entry
+
+ s.evictIdleLocked(now)
+ s.evictOverLimitLocked()
+ s.mu.Unlock()
+ return entry, nil
+}
+
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
if s.cfg == nil {
return false
@@ -618,6 +786,64 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Tra
return transport, nil
}
+// buildUpstreamTransportWithTLSFingerprint 构建带 TLS 指纹伪装的 Transport
+// 使用 utls 库模拟 Claude CLI 的 TLS 指纹
+//
+// 参数:
+// - settings: 连接池配置
+// - proxyURL: 代理 URL(nil 表示直连)
+// - profile: TLS 指纹配置
+//
+// 返回:
+// - *http.Transport: 配置好的 Transport 实例
+// - error: 配置错误
+//
+// 代理类型处理:
+// - nil/空: 直连,使用 TLSFingerprintDialer
+// - http/https: HTTP 代理,使用 HTTPProxyDialer(CONNECT 隧道 + utls 握手)
+// - socks5: SOCKS5 代理,使用 SOCKS5ProxyDialer(SOCKS5 隧道 + utls 握手)
+func buildUpstreamTransportWithTLSFingerprint(settings poolSettings, proxyURL *url.URL, profile *tlsfingerprint.Profile) (*http.Transport, error) {
+ transport := &http.Transport{
+ MaxIdleConns: settings.maxIdleConns,
+ MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
+ MaxConnsPerHost: settings.maxConnsPerHost,
+ IdleConnTimeout: settings.idleConnTimeout,
+ ResponseHeaderTimeout: settings.responseHeaderTimeout,
+ // 禁用默认的 TLS,我们使用自定义的 DialTLSContext
+ ForceAttemptHTTP2: false,
+ }
+
+ // 根据代理类型选择合适的 TLS 指纹 Dialer
+ if proxyURL == nil {
+ // 直连:使用 TLSFingerprintDialer
+ slog.Debug("tls_fingerprint_transport_direct")
+ dialer := tlsfingerprint.NewDialer(profile, nil)
+ transport.DialTLSContext = dialer.DialTLSContext
+ } else {
+ scheme := strings.ToLower(proxyURL.Scheme)
+ switch scheme {
+ case "socks5", "socks5h":
+ // SOCKS5 代理:使用 SOCKS5ProxyDialer
+ slog.Debug("tls_fingerprint_transport_socks5", "proxy", proxyURL.Host)
+ socks5Dialer := tlsfingerprint.NewSOCKS5ProxyDialer(profile, proxyURL)
+ transport.DialTLSContext = socks5Dialer.DialTLSContext
+ case "http", "https":
+ // HTTP/HTTPS 代理:使用 HTTPProxyDialer(CONNECT 隧道)
+ slog.Debug("tls_fingerprint_transport_http_connect", "proxy", proxyURL.Host)
+ httpDialer := tlsfingerprint.NewHTTPProxyDialer(profile, proxyURL)
+ transport.DialTLSContext = httpDialer.DialTLSContext
+ default:
+ // 未知代理类型,回退到普通代理配置(无 TLS 指纹)
+ slog.Debug("tls_fingerprint_transport_unknown_scheme_fallback", "scheme", scheme)
+ if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
+ return nil, err
+ }
+ }
+ }
+
+ return transport, nil
+}
+
// trackedBody 带跟踪功能的响应体包装器
// 在 Close 时执行回调,用于更新请求计数
type trackedBody struct {
diff --git a/backend/internal/repository/identity_cache.go b/backend/internal/repository/identity_cache.go
index d28477b7..c4986547 100644
--- a/backend/internal/repository/identity_cache.go
+++ b/backend/internal/repository/identity_cache.go
@@ -11,8 +11,10 @@ import (
)
const (
- fingerprintKeyPrefix = "fingerprint:"
- fingerprintTTL = 24 * time.Hour
+ fingerprintKeyPrefix = "fingerprint:"
+ fingerprintTTL = 24 * time.Hour
+ maskedSessionKeyPrefix = "masked_session:"
+ maskedSessionTTL = 15 * time.Minute
)
// fingerprintKey generates the Redis key for account fingerprint cache.
@@ -20,6 +22,11 @@ func fingerprintKey(accountID int64) string {
return fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
}
+// maskedSessionKey generates the Redis key for masked session ID cache.
+func maskedSessionKey(accountID int64) string {
+ return fmt.Sprintf("%s%d", maskedSessionKeyPrefix, accountID)
+}
+
type identityCache struct {
rdb *redis.Client
}
@@ -49,3 +56,20 @@ func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp
}
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
}
+
+func (c *identityCache) GetMaskedSessionID(ctx context.Context, accountID int64) (string, error) {
+ key := maskedSessionKey(accountID)
+ val, err := c.rdb.Get(ctx, key).Result()
+ if err != nil {
+ if err == redis.Nil {
+ return "", nil
+ }
+ return "", err
+ }
+ return val, nil
+}
+
+func (c *identityCache) SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error {
+ key := maskedSessionKey(accountID)
+ return c.rdb.Set(ctx, key, sessionID, maskedSessionTTL).Err()
+}
diff --git a/backend/internal/repository/session_limit_cache.go b/backend/internal/repository/session_limit_cache.go
index 16f2a69c..3dc89f87 100644
--- a/backend/internal/repository/session_limit_cache.go
+++ b/backend/internal/repository/session_limit_cache.go
@@ -217,7 +217,7 @@ func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID
}
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
-func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
+func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error) {
if len(accountIDs) == 0 {
return make(map[int64]int), nil
}
@@ -226,11 +226,18 @@ func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, acco
// 使用 pipeline 批量执行
pipe := c.rdb.Pipeline()
- idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
cmds := make(map[int64]*redis.Cmd, len(accountIDs))
for _, accountID := range accountIDs {
key := sessionLimitKey(accountID)
+ // 使用各账号自己的 idleTimeout,如果没有则用默认值
+ idleTimeout := c.defaultIdleTimeout
+ if idleTimeouts != nil {
+ if t, ok := idleTimeouts[accountID]; ok && t > 0 {
+ idleTimeout = t
+ }
+ }
+ idleTimeoutSeconds := int(idleTimeout.Seconds())
cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
}
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index 0a549b19..be8a8df8 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -618,6 +618,14 @@ func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
return nil
}
+func (stubApiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
+ return nil
+}
+
+func (stubApiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
+ return nil
+}
+
type stubGroupRepo struct{}
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
@@ -736,6 +744,10 @@ func (s *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg strin
return errors.New("not implemented")
}
+func (s *stubAccountRepo) ClearError(ctx context.Context, id int64) error {
+ return errors.New("not implemented")
+}
+
func (s *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return errors.New("not implemented")
}
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 36ba0bcc..27f693d6 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -576,6 +576,44 @@ func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken)
}
+// IsTLSFingerprintEnabled 检查是否启用 TLS 指纹伪装
+// 仅适用于 Anthropic OAuth/SetupToken 类型账号
+// 启用后将模拟 Claude Code (Node.js) 客户端的 TLS 握手特征
+func (a *Account) IsTLSFingerprintEnabled() bool {
+ // 仅支持 Anthropic OAuth/SetupToken 账号
+ if !a.IsAnthropicOAuthOrSetupToken() {
+ return false
+ }
+ if a.Extra == nil {
+ return false
+ }
+ if v, ok := a.Extra["enable_tls_fingerprint"]; ok {
+ if enabled, ok := v.(bool); ok {
+ return enabled
+ }
+ }
+ return false
+}
+
+// IsSessionIDMaskingEnabled 检查是否启用会话ID伪装
+// 仅适用于 Anthropic OAuth/SetupToken 类型账号
+// 启用后将在一段时间内(15分钟)固定 metadata.user_id 中的 session ID,
+// 使上游认为请求来自同一个会话
+func (a *Account) IsSessionIDMaskingEnabled() bool {
+ if !a.IsAnthropicOAuthOrSetupToken() {
+ return false
+ }
+ if a.Extra == nil {
+ return false
+ }
+ if v, ok := a.Extra["session_id_masking_enabled"]; ok {
+ if enabled, ok := v.(bool); ok {
+ return enabled
+ }
+ }
+ return false
+}
+
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func (a *Account) GetWindowCostLimit() float64 {
@@ -652,6 +690,23 @@ func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) Windo
return WindowCostNotSchedulable
}
+// GetCurrentWindowStartTime 获取当前有效的窗口开始时间
+// 逻辑:
+// 1. 如果窗口未过期(SessionWindowEnd 存在且在当前时间之后),使用记录的 SessionWindowStart
+// 2. 否则(窗口过期或未设置),使用新的预测窗口开始时间(从当前整点开始)
+func (a *Account) GetCurrentWindowStartTime() time.Time {
+ now := time.Now()
+
+ // 窗口未过期,使用记录的窗口开始时间
+ if a.SessionWindowStart != nil && a.SessionWindowEnd != nil && now.Before(*a.SessionWindowEnd) {
+ return *a.SessionWindowStart
+ }
+
+ // 窗口已过期或未设置,预测新的窗口开始时间(从当前整点开始)
+ // 与 ratelimit_service.go 中 UpdateSessionWindow 的预测逻辑保持一致
+ return time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())
+}
+
// parseExtraFloat64 从 extra 字段解析 float64 值
func parseExtraFloat64(value any) float64 {
switch v := value.(type) {
diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go
index ede5b12f..90365d2f 100644
--- a/backend/internal/service/account_service.go
+++ b/backend/internal/service/account_service.go
@@ -37,6 +37,7 @@ type AccountRepository interface {
UpdateLastUsed(ctx context.Context, id int64) error
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
SetError(ctx context.Context, id int64, errorMsg string) error
+ ClearError(ctx context.Context, id int64) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go
index 36af719c..e5eabfc6 100644
--- a/backend/internal/service/account_service_delete_test.go
+++ b/backend/internal/service/account_service_delete_test.go
@@ -99,6 +99,10 @@ func (s *accountRepoStub) SetError(ctx context.Context, id int64, errorMsg strin
panic("unexpected SetError call")
}
+func (s *accountRepoStub) ClearError(ctx context.Context, id int64) error {
+ panic("unexpected ClearError call")
+}
+
func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
panic("unexpected SetSchedulable call")
}
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go
index 8419c2b4..46376c69 100644
--- a/backend/internal/service/account_test_service.go
+++ b/backend/internal/service/account_test_service.go
@@ -265,7 +265,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL()
}
- resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
+ resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
@@ -375,7 +375,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL()
}
- resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
+ resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
@@ -446,7 +446,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL()
}
- resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
+ resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go
index e3c0974e..6c617e27 100644
--- a/backend/internal/service/account_usage_service.go
+++ b/backend/internal/service/account_usage_service.go
@@ -369,12 +369,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
// 如果没有缓存,从数据库查询
if windowStats == nil {
- var startTime time.Time
- if account.SessionWindowStart != nil {
- startTime = *account.SessionWindowStart
- } else {
- startTime = time.Now().Add(-5 * time.Hour)
- }
+ // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
+ startTime := account.GetCurrentWindowStartTime()
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil {
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index c0694e4e..0afa0716 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -42,6 +42,7 @@ type AdminService interface {
DeleteAccount(ctx context.Context, id int64) error
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
ClearAccountError(ctx context.Context, id int64) (*Account, error)
+ SetAccountError(ctx context.Context, id int64, errorMsg string) error
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
@@ -1101,6 +1102,10 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac
return account, nil
}
+func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
+ return s.accountRepo.SetError(ctx, id, errorMsg)
+}
+
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
return nil, err
diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go
index 7f3e97a2..043f338d 100644
--- a/backend/internal/service/antigravity_gateway_service.go
+++ b/backend/internal/service/antigravity_gateway_service.go
@@ -12,6 +12,7 @@ import (
mathrand "math/rand"
"net"
"net/http"
+ "os"
"strings"
"sync/atomic"
"time"
@@ -28,6 +29,207 @@ const (
antigravityRetryMaxDelay = 16 * time.Second
)
+const antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
+
+// antigravityRetryLoopParams 重试循环的参数
+type antigravityRetryLoopParams struct {
+ ctx context.Context
+ prefix string
+ account *Account
+ proxyURL string
+ accessToken string
+ action string
+ body []byte
+ quotaScope AntigravityQuotaScope
+ c *gin.Context
+ httpUpstream HTTPUpstream
+ settingService *SettingService
+ handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope)
+}
+
+// antigravityRetryLoopResult 重试循环的结果
+type antigravityRetryLoopResult struct {
+ resp *http.Response
+}
+
+// antigravityRetryLoop 执行带 URL fallback 的重试循环
+func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
+ availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
+ if len(availableURLs) == 0 {
+ availableURLs = antigravity.BaseURLs
+ }
+
+ var resp *http.Response
+ var usedBaseURL string
+ logBody := p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBody
+ maxBytes := 2048
+ if p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
+ maxBytes = p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
+ }
+ getUpstreamDetail := func(body []byte) string {
+ if !logBody {
+ return ""
+ }
+ return truncateString(string(body), maxBytes)
+ }
+
+urlFallbackLoop:
+ for urlIdx, baseURL := range availableURLs {
+ usedBaseURL = baseURL
+ for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
+ select {
+ case <-p.ctx.Done():
+ log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err())
+ return nil, p.ctx.Err()
+ default:
+ }
+
+ upstreamReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
+ if err != nil {
+ return nil, err
+ }
+
+ // Capture upstream request body for ops retry of this attempt.
+ if p.c != nil && len(p.body) > 0 {
+ p.c.Set(OpsUpstreamRequestBodyKey, string(p.body))
+ }
+
+ resp, err = p.httpUpstream.Do(upstreamReq, p.proxyURL, p.account.ID, p.account.Concurrency)
+ if err != nil {
+ safeErr := sanitizeUpstreamErrorMessage(err.Error())
+ appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
+ Platform: p.account.Platform,
+ AccountID: p.account.ID,
+ AccountName: p.account.Name,
+ UpstreamStatusCode: 0,
+ Kind: "request_error",
+ Message: safeErr,
+ })
+ if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
+ log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
+ continue urlFallbackLoop
+ }
+ if attempt < antigravityMaxRetries {
+ log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err)
+ if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
+ log.Printf("%s status=context_canceled_during_backoff", p.prefix)
+ return nil, p.ctx.Err()
+ }
+ continue
+ }
+ log.Printf("%s status=request_failed retries_exhausted error=%v", p.prefix, err)
+ setOpsUpstreamError(p.c, 0, safeErr, "")
+ return nil, fmt.Errorf("upstream request failed after retries: %w", err)
+ }
+
+ // 429 限流处理:区分 URL 级别限流和账户配额限流
+ if resp.StatusCode == http.StatusTooManyRequests {
+ respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+ _ = resp.Body.Close()
+
+ // "Resource has been exhausted" 是 URL 级别限流,切换 URL
+ if isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 {
+ log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
+ continue urlFallbackLoop
+ }
+
+ // 账户/模型配额限流,重试 3 次(指数退避)
+ if attempt < antigravityMaxRetries {
+ upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
+ upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
+ appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
+ Platform: p.account.Platform,
+ AccountID: p.account.ID,
+ AccountName: p.account.Name,
+ UpstreamStatusCode: resp.StatusCode,
+ UpstreamRequestID: resp.Header.Get("x-request-id"),
+ Kind: "retry",
+ Message: upstreamMsg,
+ Detail: getUpstreamDetail(respBody),
+ })
+ log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
+ if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
+ log.Printf("%s status=context_canceled_during_backoff", p.prefix)
+ return nil, p.ctx.Err()
+ }
+ continue
+ }
+
+ // 重试用尽,标记账户限流
+ p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope)
+ log.Printf("%s status=429 rate_limited base_url=%s body=%s", p.prefix, baseURL, truncateForLog(respBody, 200))
+ resp = &http.Response{
+ StatusCode: resp.StatusCode,
+ Header: resp.Header.Clone(),
+ Body: io.NopCloser(bytes.NewReader(respBody)),
+ }
+ break urlFallbackLoop
+ }
+
+ // 其他可重试错误
+ if resp.StatusCode >= 400 && shouldRetryAntigravityError(resp.StatusCode) {
+ respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+ _ = resp.Body.Close()
+
+ if attempt < antigravityMaxRetries {
+ upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
+ upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
+ appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
+ Platform: p.account.Platform,
+ AccountID: p.account.ID,
+ AccountName: p.account.Name,
+ UpstreamStatusCode: resp.StatusCode,
+ UpstreamRequestID: resp.Header.Get("x-request-id"),
+ Kind: "retry",
+ Message: upstreamMsg,
+ Detail: getUpstreamDetail(respBody),
+ })
+ log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
+ if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
+ log.Printf("%s status=context_canceled_during_backoff", p.prefix)
+ return nil, p.ctx.Err()
+ }
+ continue
+ }
+ resp = &http.Response{
+ StatusCode: resp.StatusCode,
+ Header: resp.Header.Clone(),
+ Body: io.NopCloser(bytes.NewReader(respBody)),
+ }
+ break urlFallbackLoop
+ }
+
+ break urlFallbackLoop
+ }
+ }
+
+ if resp != nil && resp.StatusCode < 400 && usedBaseURL != "" {
+ antigravity.DefaultURLAvailability.MarkSuccess(usedBaseURL)
+ }
+
+ return &antigravityRetryLoopResult{resp: resp}, nil
+}
+
+// shouldRetryAntigravityError 判断是否应该重试
+func shouldRetryAntigravityError(statusCode int) bool {
+ switch statusCode {
+ case 429, 500, 502, 503, 504, 529:
+ return true
+ default:
+ return false
+ }
+}
+
+// isURLLevelRateLimit 判断是否为 URL 级别的限流(应切换 URL 重试)
+// "Resource has been exhausted" 是 URL/节点级别限流,切换 URL 可能成功
+// "exhausted your capacity on this model" 是账户/模型配额限流,切换 URL 无效
+func isURLLevelRateLimit(body []byte) bool {
+ // 快速检查:包含 "Resource has been exhausted" 且不包含 "capacity on this model"
+ bodyStr := string(body)
+ return strings.Contains(bodyStr, "Resource has been exhausted") &&
+ !strings.Contains(bodyStr, "capacity on this model")
+}
+
// isAntigravityConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
func isAntigravityConnectionError(err error) bool {
if err == nil {
@@ -238,7 +440,6 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
if err != nil {
lastErr = fmt.Errorf("请求失败: %w", err)
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
- antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
continue
}
@@ -254,7 +455,6 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
// 检查是否需要 URL 降级
if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
- antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
continue
}
@@ -266,6 +466,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
// 解析流式响应,提取文本
text := extractTextFromSSEResponse(respBody)
+ // 标记成功的 URL,下次优先使用
+ antigravity.DefaultURLAvailability.MarkSuccess(baseURL)
return &TestConnectionResult{
Text: text,
MappedModel: mappedModel,
@@ -276,13 +478,14 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
}
// buildGeminiTestRequest 构建 Gemini 格式测试请求
+// 使用最小 token 消耗:输入 "." + maxOutputTokens: 1
func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) {
payload := map[string]any{
"contents": []map[string]any{
{
"role": "user",
"parts": []map[string]any{
- {"text": "hi"},
+ {"text": "."},
},
},
},
@@ -292,22 +495,26 @@ func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model stri
{"text": antigravity.GetDefaultIdentityPatch()},
},
},
+ "generationConfig": map[string]any{
+ "maxOutputTokens": 1,
+ },
}
payloadBytes, _ := json.Marshal(payload)
return s.wrapV1InternalRequest(projectID, model, payloadBytes)
}
// buildClaudeTestRequest 构建 Claude 格式测试请求并转换为 Gemini 格式
+// 使用最小 token 消耗:输入 "." + MaxTokens: 1
func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedModel string) ([]byte, error) {
claudeReq := &antigravity.ClaudeRequest{
Model: mappedModel,
Messages: []antigravity.ClaudeMessage{
{
Role: "user",
- Content: json.RawMessage(`"hi"`),
+ Content: json.RawMessage(`"."`),
},
},
- MaxTokens: 1024,
+ MaxTokens: 1,
Stream: false,
}
return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel)
@@ -523,9 +730,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
proxyURL = account.Proxy.URL()
}
- // Sanitize thinking blocks (clean cache_control and flatten history thinking)
- sanitizeThinkingBlocks(&claudeReq)
-
// 获取转换选项
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
transformOpts := s.getClaudeTransformOptions(ctx)
@@ -537,150 +741,29 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err)
}
- // Safety net: ensure no cache_control leaked into Gemini request
- geminiBody = cleanCacheControlFromGeminiJSON(geminiBody)
-
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
action := "streamGenerateContent"
- // URL fallback 循环
- availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
- if len(availableURLs) == 0 {
- availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
- }
-
- // 重试循环
- var resp *http.Response
-urlFallbackLoop:
- for urlIdx, baseURL := range availableURLs {
- for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
- // 检查 context 是否已取消(客户端断开连接)
- select {
- case <-ctx.Done():
- log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
- return nil, ctx.Err()
- default:
- }
-
- upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody)
- // Capture upstream request body for ops retry of this attempt.
- if c != nil {
- c.Set(OpsUpstreamRequestBodyKey, string(geminiBody))
- }
- if err != nil {
- return nil, err
- }
-
- resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
- if err != nil {
- safeErr := sanitizeUpstreamErrorMessage(err.Error())
- appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
- Platform: account.Platform,
- AccountID: account.ID,
- AccountName: account.Name,
- UpstreamStatusCode: 0,
- Kind: "request_error",
- Message: safeErr,
- })
- // 检查是否应触发 URL 降级
- if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
- antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
- log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1])
- continue urlFallbackLoop
- }
- if attempt < antigravityMaxRetries {
- log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
- if !sleepAntigravityBackoffWithContext(ctx, attempt) {
- log.Printf("%s status=context_canceled_during_backoff", prefix)
- return nil, ctx.Err()
- }
- continue
- }
- log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
- setOpsUpstreamError(c, 0, safeErr, "")
- return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
- }
-
- // 检查是否应触发 URL 降级(仅 429)
- if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
- respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
- _ = resp.Body.Close()
- upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
- upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
- logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
- maxBytes := 2048
- if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
- maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
- }
- upstreamDetail := ""
- if logBody {
- upstreamDetail = truncateString(string(respBody), maxBytes)
- }
- appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
- Platform: account.Platform,
- AccountID: account.ID,
- AccountName: account.Name,
- UpstreamStatusCode: resp.StatusCode,
- UpstreamRequestID: resp.Header.Get("x-request-id"),
- Kind: "retry",
- Message: upstreamMsg,
- Detail: upstreamDetail,
- })
- antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
- log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
- continue urlFallbackLoop
- }
-
- if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
- respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
- _ = resp.Body.Close()
-
- if attempt < antigravityMaxRetries {
- upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
- upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
- logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
- maxBytes := 2048
- if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
- maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
- }
- upstreamDetail := ""
- if logBody {
- upstreamDetail = truncateString(string(respBody), maxBytes)
- }
- appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
- Platform: account.Platform,
- AccountID: account.ID,
- AccountName: account.Name,
- UpstreamStatusCode: resp.StatusCode,
- UpstreamRequestID: resp.Header.Get("x-request-id"),
- Kind: "retry",
- Message: upstreamMsg,
- Detail: upstreamDetail,
- })
- log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
- if !sleepAntigravityBackoffWithContext(ctx, attempt) {
- log.Printf("%s status=context_canceled_during_backoff", prefix)
- return nil, ctx.Err()
- }
- continue
- }
- // 所有重试都失败,标记限流状态
- if resp.StatusCode == 429 {
- s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
- }
- // 最后一次尝试也失败
- resp = &http.Response{
- StatusCode: resp.StatusCode,
- Header: resp.Header.Clone(),
- Body: io.NopCloser(bytes.NewReader(respBody)),
- }
- break urlFallbackLoop
- }
-
- break urlFallbackLoop
- }
+ // 执行带重试的请求
+ result, err := antigravityRetryLoop(antigravityRetryLoopParams{
+ ctx: ctx,
+ prefix: prefix,
+ account: account,
+ proxyURL: proxyURL,
+ accessToken: accessToken,
+ action: action,
+ body: geminiBody,
+ quotaScope: quotaScope,
+ c: c,
+ httpUpstream: s.httpUpstream,
+ settingService: s.settingService,
+ handleError: s.handleUpstreamError,
+ })
+ if err != nil {
+ return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
}
+ resp := result.resp
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
@@ -739,11 +822,20 @@ urlFallbackLoop:
if txErr != nil {
continue
}
- retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody)
- if buildErr != nil {
- continue
- }
- retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
+ retryResult, retryErr := antigravityRetryLoop(antigravityRetryLoopParams{
+ ctx: ctx,
+ prefix: prefix,
+ account: account,
+ proxyURL: proxyURL,
+ accessToken: accessToken,
+ action: action,
+ body: retryGeminiBody,
+ quotaScope: quotaScope,
+ c: c,
+ httpUpstream: s.httpUpstream,
+ settingService: s.settingService,
+ handleError: s.handleUpstreamError,
+ })
if retryErr != nil {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
@@ -757,6 +849,7 @@ urlFallbackLoop:
continue
}
+ retryResp := retryResult.resp
if retryResp.StatusCode < 400 {
_ = resp.Body.Close()
resp = retryResp
@@ -766,6 +859,13 @@ urlFallbackLoop:
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
+ if retryResp.StatusCode == http.StatusTooManyRequests {
+ retryBaseURL := ""
+ if retryResp.Request != nil && retryResp.Request.URL != nil {
+ retryBaseURL = retryResp.Request.URL.Scheme + "://" + retryResp.Request.URL.Host
+ }
+ log.Printf("%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 200))
+ }
kind := "signature_retry"
if strings.TrimSpace(stage.name) != "" {
kind = "signature_retry_" + strings.ReplaceAll(stage.name, "+", "_")
@@ -920,143 +1020,6 @@ func extractAntigravityErrorMessage(body []byte) string {
return ""
}
-// cleanCacheControlFromGeminiJSON removes cache_control from Gemini JSON (emergency fix)
-// This should not be needed if transformation is correct, but serves as a safety net
-func cleanCacheControlFromGeminiJSON(body []byte) []byte {
- // Try a more robust approach: parse and clean
- var data map[string]any
- if err := json.Unmarshal(body, &data); err != nil {
- log.Printf("[Antigravity] Failed to parse Gemini JSON for cache_control cleaning: %v", err)
- return body
- }
-
- cleaned := removeCacheControlFromAny(data)
- if !cleaned {
- return body
- }
-
- if result, err := json.Marshal(data); err == nil {
- log.Printf("[Antigravity] Successfully cleaned cache_control from Gemini JSON")
- return result
- }
-
- return body
-}
-
-// removeCacheControlFromAny recursively removes cache_control fields
-func removeCacheControlFromAny(v any) bool {
- cleaned := false
-
- switch val := v.(type) {
- case map[string]any:
- for k, child := range val {
- if k == "cache_control" {
- delete(val, k)
- cleaned = true
- } else if removeCacheControlFromAny(child) {
- cleaned = true
- }
- }
- case []any:
- for _, item := range val {
- if removeCacheControlFromAny(item) {
- cleaned = true
- }
- }
- }
-
- return cleaned
-}
-
-// sanitizeThinkingBlocks cleans cache_control and flattens history thinking blocks
-// Thinking blocks do NOT support cache_control field (Anthropic API/Vertex AI requirement)
-// Additionally, history thinking blocks are flattened to text to avoid upstream validation errors
-func sanitizeThinkingBlocks(req *antigravity.ClaudeRequest) {
- if req == nil {
- return
- }
-
- log.Printf("[Antigravity] sanitizeThinkingBlocks: processing request with %d messages", len(req.Messages))
-
- // Clean system blocks
- if len(req.System) > 0 {
- var systemBlocks []map[string]any
- if err := json.Unmarshal(req.System, &systemBlocks); err == nil {
- for i := range systemBlocks {
- if blockType, _ := systemBlocks[i]["type"].(string); blockType == "thinking" || systemBlocks[i]["thinking"] != nil {
- if removeCacheControlFromAny(systemBlocks[i]) {
- log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in system[%d]", i)
- }
- }
- }
- // Marshal back
- if cleaned, err := json.Marshal(systemBlocks); err == nil {
- req.System = cleaned
- }
- }
- }
-
- // Clean message content blocks and flatten history
- lastMsgIdx := len(req.Messages) - 1
- for msgIdx := range req.Messages {
- raw := req.Messages[msgIdx].Content
- if len(raw) == 0 {
- continue
- }
-
- // Try to parse as blocks array
- var blocks []map[string]any
- if err := json.Unmarshal(raw, &blocks); err != nil {
- continue
- }
-
- cleaned := false
- for blockIdx := range blocks {
- blockType, _ := blocks[blockIdx]["type"].(string)
-
- // Check for thinking blocks (typed or untyped)
- if blockType == "thinking" || blocks[blockIdx]["thinking"] != nil {
- // 1. Clean cache_control
- if removeCacheControlFromAny(blocks[blockIdx]) {
- log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in messages[%d].content[%d]", msgIdx, blockIdx)
- cleaned = true
- }
-
- // 2. Flatten to text if it's a history message (not the last one)
- if msgIdx < lastMsgIdx {
- log.Printf("[Antigravity] Flattening history thinking block to text at messages[%d].content[%d]", msgIdx, blockIdx)
-
- // Extract thinking content
- var textContent string
- if t, ok := blocks[blockIdx]["thinking"].(string); ok {
- textContent = t
- } else {
- // Fallback for non-string content (marshal it)
- if b, err := json.Marshal(blocks[blockIdx]["thinking"]); err == nil {
- textContent = string(b)
- }
- }
-
- // Convert to text block
- blocks[blockIdx]["type"] = "text"
- blocks[blockIdx]["text"] = textContent
- delete(blocks[blockIdx], "thinking")
- delete(blocks[blockIdx], "signature")
- delete(blocks[blockIdx], "cache_control") // Ensure it's gone
- cleaned = true
- }
- }
- }
-
- // Marshal back if modified
- if cleaned {
- if marshaled, err := json.Marshal(blocks); err == nil {
- req.Messages[msgIdx].Content = marshaled
- }
- }
- }
-}
-
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
// This preserves the thinking content while avoiding signature validation errors.
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
@@ -1352,138 +1315,25 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回
upstreamAction := "streamGenerateContent"
- // URL fallback 循环
- availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
- if len(availableURLs) == 0 {
- availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
- }
-
- // 重试循环
- var resp *http.Response
-urlFallbackLoop:
- for urlIdx, baseURL := range availableURLs {
- for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
- // 检查 context 是否已取消(客户端断开连接)
- select {
- case <-ctx.Done():
- log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
- return nil, ctx.Err()
- default:
- }
-
- upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, upstreamAction, accessToken, wrappedBody)
- if err != nil {
- return nil, err
- }
-
- resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
- if err != nil {
- safeErr := sanitizeUpstreamErrorMessage(err.Error())
- appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
- Platform: account.Platform,
- AccountID: account.ID,
- AccountName: account.Name,
- UpstreamStatusCode: 0,
- Kind: "request_error",
- Message: safeErr,
- })
- // 检查是否应触发 URL 降级
- if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
- antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
- log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1])
- continue urlFallbackLoop
- }
- if attempt < antigravityMaxRetries {
- log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
- if !sleepAntigravityBackoffWithContext(ctx, attempt) {
- log.Printf("%s status=context_canceled_during_backoff", prefix)
- return nil, ctx.Err()
- }
- continue
- }
- log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
- setOpsUpstreamError(c, 0, safeErr, "")
- return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
- }
-
- // 检查是否应触发 URL 降级(仅 429)
- if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 {
- respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
- _ = resp.Body.Close()
- upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
- upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
- logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
- maxBytes := 2048
- if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
- maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
- }
- upstreamDetail := ""
- if logBody {
- upstreamDetail = truncateString(string(respBody), maxBytes)
- }
- appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
- Platform: account.Platform,
- AccountID: account.ID,
- AccountName: account.Name,
- UpstreamStatusCode: resp.StatusCode,
- UpstreamRequestID: resp.Header.Get("x-request-id"),
- Kind: "retry",
- Message: upstreamMsg,
- Detail: upstreamDetail,
- })
- antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
- log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
- continue urlFallbackLoop
- }
-
- if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
- respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
- _ = resp.Body.Close()
-
- if attempt < antigravityMaxRetries {
- upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
- upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
- logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
- maxBytes := 2048
- if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
- maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
- }
- upstreamDetail := ""
- if logBody {
- upstreamDetail = truncateString(string(respBody), maxBytes)
- }
- appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
- Platform: account.Platform,
- AccountID: account.ID,
- AccountName: account.Name,
- UpstreamStatusCode: resp.StatusCode,
- UpstreamRequestID: resp.Header.Get("x-request-id"),
- Kind: "retry",
- Message: upstreamMsg,
- Detail: upstreamDetail,
- })
- log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
- if !sleepAntigravityBackoffWithContext(ctx, attempt) {
- log.Printf("%s status=context_canceled_during_backoff", prefix)
- return nil, ctx.Err()
- }
- continue
- }
- // 所有重试都失败,标记限流状态
- if resp.StatusCode == 429 {
- s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
- }
- resp = &http.Response{
- StatusCode: resp.StatusCode,
- Header: resp.Header.Clone(),
- Body: io.NopCloser(bytes.NewReader(respBody)),
- }
- break urlFallbackLoop
- }
-
- break urlFallbackLoop
- }
+ // 执行带重试的请求
+ result, err := antigravityRetryLoop(antigravityRetryLoopParams{
+ ctx: ctx,
+ prefix: prefix,
+ account: account,
+ proxyURL: proxyURL,
+ accessToken: accessToken,
+ action: upstreamAction,
+ body: wrappedBody,
+ quotaScope: quotaScope,
+ c: c,
+ httpUpstream: s.httpUpstream,
+ settingService: s.settingService,
+ handleError: s.handleUpstreamError,
+ })
+ if err != nil {
+ return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
}
+ resp := result.resp
defer func() {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
@@ -1525,8 +1375,6 @@ urlFallbackLoop:
goto handleSuccess
}
- s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
-
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
@@ -1537,6 +1385,7 @@ urlFallbackLoop:
if unwrapErr != nil || len(unwrappedForOps) == 0 {
unwrappedForOps = respBody
}
+ s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
@@ -1581,6 +1430,7 @@ urlFallbackLoop:
Message: upstreamMsg,
Detail: upstreamDetail,
})
+ log.Printf("[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500))
c.Data(resp.StatusCode, contentType, unwrappedForOps)
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
}
@@ -1637,15 +1487,6 @@ handleSuccess:
}, nil
}
-func (s *AntigravityGatewayService) shouldRetryUpstreamError(statusCode int) bool {
- switch statusCode {
- case 429, 500, 502, 503, 504, 529:
- return true
- default:
- return false
- }
-}
-
func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 403, 429, 529:
@@ -1679,33 +1520,48 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
}
}
+func antigravityUseScopeRateLimit() bool {
+ v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityScopeRateLimitEnv)))
+ return v == "1" || v == "true" || v == "yes" || v == "on"
+}
+
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
if statusCode == 429 {
+ useScopeLimit := antigravityUseScopeRateLimit() && quotaScope != ""
resetAt := ParseGeminiRateLimitResetTime(body)
if resetAt == nil {
- // 解析失败:Gemini 有重试时间用 5 分钟,Claude 没有用 1 分钟
- defaultDur := 1 * time.Minute
- if bytes.Contains(body, []byte("Please retry in")) || bytes.Contains(body, []byte("retryDelay")) {
- defaultDur = 5 * time.Minute
+ // 解析失败:使用配置的 fallback 时间,直接限流整个账户
+ fallbackMinutes := 5
+ if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 {
+ fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes
}
+ defaultDur := time.Duration(fallbackMinutes) * time.Minute
ra := time.Now().Add(defaultDur)
- log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
- if quotaScope == "" {
- return
- }
- if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil {
- log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
+ if useScopeLimit {
+ log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
+ if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil {
+ log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
+ }
+ } else {
+ log.Printf("%s status=429 rate_limited account=%d reset_in=%v (fallback)", prefix, account.ID, defaultDur)
+ if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil {
+ log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
+ }
}
return
}
resetTime := time.Unix(*resetAt, 0)
- log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
- if quotaScope == "" {
- return
- }
- if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
- log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
+ if useScopeLimit {
+ log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
+ if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
+ log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
+ }
+ } else {
+ log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v", prefix, account.ID, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
+ if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
+ log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
+ }
}
return
}
@@ -1884,7 +1740,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
}
// handleGeminiStreamToNonStreaming 读取上游流式响应,合并为非流式响应返回给客户端
-// Gemini 流式响应中每个 chunk 都包含累积的完整文本,只需保留最后一个有效响应
+// Gemini 流式响应是增量的,需要累积所有 chunk 的内容
func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
@@ -1897,6 +1753,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
var firstTokenMs *int
var last map[string]any
var lastWithParts map[string]any
+ var collectedImageParts []map[string]any // 收集所有包含图片的 parts
+ var collectedTextParts []string // 收集所有文本片段
type scanEvent struct {
line string
@@ -1999,6 +1857,16 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
// 保留最后一个有 parts 的响应
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
+ // 收集包含图片和文本的 parts
+ for _, part := range parts {
+ if inlineData, ok := part["inlineData"].(map[string]any); ok {
+ collectedImageParts = append(collectedImageParts, part)
+ _ = inlineData // 避免 unused 警告
+ }
+ if text, ok := part["text"].(string); ok && text != "" {
+ collectedTextParts = append(collectedTextParts, text)
+ }
+ }
}
case <-intervalCh:
@@ -2020,6 +1888,16 @@ returnResponse:
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
}
+ // 如果收集到了图片 parts,需要合并到最终响应中
+ if len(collectedImageParts) > 0 {
+ finalResponse = mergeImagePartsToResponse(finalResponse, collectedImageParts)
+ }
+
+ // 如果收集到了文本,需要合并到最终响应中
+ if len(collectedTextParts) > 0 {
+ finalResponse = mergeTextPartsToResponse(finalResponse, collectedTextParts)
+ }
+
respBody, err := json.Marshal(finalResponse)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
@@ -2029,6 +1907,115 @@ returnResponse:
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
+// getOrCreateGeminiParts 获取 Gemini 响应的 parts 结构,返回深拷贝和更新回调
+func getOrCreateGeminiParts(response map[string]any) (result map[string]any, existingParts []any, setParts func([]any)) {
+ // 深拷贝 response
+ result = make(map[string]any)
+ for k, v := range response {
+ result[k] = v
+ }
+
+ // 获取或创建 candidates
+ candidates, ok := result["candidates"].([]any)
+ if !ok || len(candidates) == 0 {
+ candidates = []any{map[string]any{}}
+ }
+
+ // 获取第一个 candidate
+ candidate, ok := candidates[0].(map[string]any)
+ if !ok {
+ candidate = make(map[string]any)
+ candidates[0] = candidate
+ }
+
+ // 获取或创建 content
+ content, ok := candidate["content"].(map[string]any)
+ if !ok {
+ content = map[string]any{"role": "model"}
+ candidate["content"] = content
+ }
+
+ // 获取现有 parts
+ existingParts, ok = content["parts"].([]any)
+ if !ok {
+ existingParts = []any{}
+ }
+
+ // 返回更新回调
+ setParts = func(newParts []any) {
+ content["parts"] = newParts
+ result["candidates"] = candidates
+ }
+
+ return result, existingParts, setParts
+}
+
+// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
+func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any {
+ if len(imageParts) == 0 {
+ return response
+ }
+
+ result, existingParts, setParts := getOrCreateGeminiParts(response)
+
+ // 检查现有 parts 中是否已经有图片
+ for _, p := range existingParts {
+ if pm, ok := p.(map[string]any); ok {
+ if _, hasInline := pm["inlineData"]; hasInline {
+ return result // 已有图片,不重复添加
+ }
+ }
+ }
+
+ // 添加收集到的图片 parts
+ for _, imgPart := range imageParts {
+ existingParts = append(existingParts, imgPart)
+ }
+ setParts(existingParts)
+ return result
+}
+
+// mergeTextPartsToResponse 将收集到的文本合并到 Gemini 响应中
+func mergeTextPartsToResponse(response map[string]any, textParts []string) map[string]any {
+ if len(textParts) == 0 {
+ return response
+ }
+
+ mergedText := strings.Join(textParts, "")
+ result, existingParts, setParts := getOrCreateGeminiParts(response)
+
+ // 查找并更新第一个 text part,或创建新的
+ newParts := make([]any, 0, len(existingParts)+1)
+ textUpdated := false
+
+ for _, p := range existingParts {
+ pm, ok := p.(map[string]any)
+ if !ok {
+ newParts = append(newParts, p)
+ continue
+ }
+ if _, hasText := pm["text"]; hasText && !textUpdated {
+ // 用累积的文本替换
+ newPart := make(map[string]any)
+ for k, v := range pm {
+ newPart[k] = v
+ }
+ newPart["text"] = mergedText
+ newParts = append(newParts, newPart)
+ textUpdated = true
+ } else {
+ newParts = append(newParts, pm)
+ }
+ }
+
+ if !textUpdated {
+ newParts = append([]any{map[string]any{"text": mergedText}}, newParts...)
+ }
+
+ setParts(newParts)
+ return result
+}
+
func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, errType, message string) error {
c.JSON(status, gin.H{
"type": "error",
diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go
index ecf0a553..52293cd5 100644
--- a/backend/internal/service/antigravity_oauth_service.go
+++ b/backend/internal/service/antigravity_oauth_service.go
@@ -82,13 +82,14 @@ type AntigravityExchangeCodeInput struct {
// AntigravityTokenInfo token 信息
type AntigravityTokenInfo struct {
- AccessToken string `json:"access_token"`
- RefreshToken string `json:"refresh_token"`
- ExpiresIn int64 `json:"expires_in"`
- ExpiresAt int64 `json:"expires_at"`
- TokenType string `json:"token_type"`
- Email string `json:"email,omitempty"`
- ProjectID string `json:"project_id,omitempty"`
+ AccessToken string `json:"access_token"`
+ RefreshToken string `json:"refresh_token"`
+ ExpiresIn int64 `json:"expires_in"`
+ ExpiresAt int64 `json:"expires_at"`
+ TokenType string `json:"token_type"`
+ Email string `json:"email,omitempty"`
+ ProjectID string `json:"project_id,omitempty"`
+ ProjectIDMissing bool `json:"-"` // LoadCodeAssist 未返回 project_id
}
// ExchangeCode 用 authorization code 交换 token
@@ -149,12 +150,6 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result.ProjectID = loadResp.CloudAICompanionProject
}
- // 兜底:随机生成 project_id
- if result.ProjectID == "" {
- result.ProjectID = antigravity.GenerateMockProjectID()
- fmt.Printf("[AntigravityOAuth] 使用随机生成的 project_id: %s\n", result.ProjectID)
- }
-
return result, nil
}
@@ -236,16 +231,24 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
return nil, err
}
- // 保留原有的 project_id 和 email
- existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
- if existingProjectID != "" {
- tokenInfo.ProjectID = existingProjectID
- }
+ // 保留原有的 email
existingEmail := strings.TrimSpace(account.GetCredential("email"))
if existingEmail != "" {
tokenInfo.Email = existingEmail
}
+ // 每次刷新都调用 LoadCodeAssist 获取 project_id
+ client := antigravity.NewClient(proxyURL)
+ loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken)
+ if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
+ // LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失
+ existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
+ tokenInfo.ProjectID = existingProjectID
+ tokenInfo.ProjectIDMissing = true
+ } else {
+ tokenInfo.ProjectID = loadResp.CloudAICompanionProject
+ }
+
return tokenInfo, nil
}
diff --git a/backend/internal/service/antigravity_quota_fetcher.go b/backend/internal/service/antigravity_quota_fetcher.go
index c9024e33..07eb563d 100644
--- a/backend/internal/service/antigravity_quota_fetcher.go
+++ b/backend/internal/service/antigravity_quota_fetcher.go
@@ -31,11 +31,6 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
accessToken := account.GetCredential("access_token")
projectID := account.GetCredential("project_id")
- // 如果没有 project_id,生成一个随机的
- if projectID == "" {
- projectID = antigravity.GenerateMockProjectID()
- }
-
client := antigravity.NewClient(proxyURL)
// 调用 API 获取配额
diff --git a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go
new file mode 100644
index 00000000..53ec6fdf
--- /dev/null
+++ b/backend/internal/service/antigravity_rate_limit_test.go
@@ -0,0 +1,190 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
+ "github.com/stretchr/testify/require"
+)
+
+type stubAntigravityUpstream struct {
+ firstBase string
+ secondBase string
+ calls []string
+}
+
+func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
+ url := req.URL.String()
+ s.calls = append(s.calls, url)
+ if strings.HasPrefix(url, s.firstBase) {
+ return &http.Response{
+ StatusCode: http.StatusTooManyRequests,
+ Header: http.Header{},
+ Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Resource has been exhausted"}}`)),
+ }, nil
+ }
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{},
+ Body: io.NopCloser(strings.NewReader("ok")),
+ }, nil
+}
+
+func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
+ return s.Do(req, proxyURL, accountID, accountConcurrency)
+}
+
+type scopeLimitCall struct {
+ accountID int64
+ scope AntigravityQuotaScope
+ resetAt time.Time
+}
+
+type rateLimitCall struct {
+ accountID int64
+ resetAt time.Time
+}
+
+type stubAntigravityAccountRepo struct {
+ AccountRepository
+ scopeCalls []scopeLimitCall
+ rateCalls []rateLimitCall
+}
+
+func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
+ s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt})
+ return nil
+}
+
+func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
+ s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt})
+ return nil
+}
+
+func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
+ oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
+ oldAvailability := antigravity.DefaultURLAvailability
+ defer func() {
+ antigravity.BaseURLs = oldBaseURLs
+ antigravity.DefaultURLAvailability = oldAvailability
+ }()
+
+ base1 := "https://ag-1.test"
+ base2 := "https://ag-2.test"
+ antigravity.BaseURLs = []string{base1, base2}
+ antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
+
+ upstream := &stubAntigravityUpstream{firstBase: base1, secondBase: base2}
+ account := &Account{
+ ID: 1,
+ Name: "acc-1",
+ Platform: PlatformAntigravity,
+ Schedulable: true,
+ Status: StatusActive,
+ Concurrency: 1,
+ }
+
+ var handleErrorCalled bool
+ result, err := antigravityRetryLoop(antigravityRetryLoopParams{
+ prefix: "[test]",
+ ctx: context.Background(),
+ account: account,
+ proxyURL: "",
+ accessToken: "token",
+ action: "generateContent",
+ body: []byte(`{"input":"test"}`),
+ quotaScope: AntigravityQuotaScopeClaude,
+ httpUpstream: upstream,
+ handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
+ handleErrorCalled = true
+ },
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.resp)
+ defer func() { _ = result.resp.Body.Close() }()
+ require.Equal(t, http.StatusOK, result.resp.StatusCode)
+ require.False(t, handleErrorCalled)
+ require.Len(t, upstream.calls, 2)
+ require.True(t, strings.HasPrefix(upstream.calls[0], base1))
+ require.True(t, strings.HasPrefix(upstream.calls[1], base2))
+
+ available := antigravity.DefaultURLAvailability.GetAvailableURLs()
+ require.NotEmpty(t, available)
+ require.Equal(t, base2, available[0])
+}
+
+func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T) {
+ t.Setenv(antigravityScopeRateLimitEnv, "true")
+ repo := &stubAntigravityAccountRepo{}
+ svc := &AntigravityGatewayService{accountRepo: repo}
+ account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
+
+ body := buildGeminiRateLimitBody("3s")
+ svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
+
+ require.Len(t, repo.scopeCalls, 1)
+ require.Empty(t, repo.rateCalls)
+ call := repo.scopeCalls[0]
+ require.Equal(t, account.ID, call.accountID)
+ require.Equal(t, AntigravityQuotaScopeClaude, call.scope)
+ require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
+}
+
+func TestAntigravityHandleUpstreamError_UsesAccountLimitWhenScopeDisabled(t *testing.T) {
+ t.Setenv(antigravityScopeRateLimitEnv, "false")
+ repo := &stubAntigravityAccountRepo{}
+ svc := &AntigravityGatewayService{accountRepo: repo}
+ account := &Account{ID: 10, Name: "acc-10", Platform: PlatformAntigravity}
+
+ body := buildGeminiRateLimitBody("2s")
+ svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
+
+ require.Len(t, repo.rateCalls, 1)
+ require.Empty(t, repo.scopeCalls)
+ call := repo.rateCalls[0]
+ require.Equal(t, account.ID, call.accountID)
+ require.WithinDuration(t, time.Now().Add(2*time.Second), call.resetAt, 2*time.Second)
+}
+
+func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
+ now := time.Now()
+ future := now.Add(10 * time.Minute)
+
+ account := &Account{
+ ID: 1,
+ Name: "acc",
+ Platform: PlatformAntigravity,
+ Status: StatusActive,
+ Schedulable: true,
+ }
+
+ account.RateLimitResetAt = &future
+ require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
+ require.False(t, account.IsSchedulableForModel("gemini-3-flash"))
+
+ account.RateLimitResetAt = nil
+ account.Extra = map[string]any{
+ antigravityQuotaScopesKey: map[string]any{
+ "claude": map[string]any{
+ "rate_limit_reset_at": future.Format(time.RFC3339),
+ },
+ },
+ }
+
+ require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
+ require.True(t, account.IsSchedulableForModel("gemini-3-flash"))
+}
+
+func buildGeminiRateLimitBody(delay string) []byte {
+ return []byte(fmt.Sprintf(`{"error":{"message":"too many requests","details":[{"metadata":{"quotaResetDelay":%q}}]}}`, delay))
+}
diff --git a/backend/internal/service/antigravity_token_refresher.go b/backend/internal/service/antigravity_token_refresher.go
index 9dd4463f..a07c86e6 100644
--- a/backend/internal/service/antigravity_token_refresher.go
+++ b/backend/internal/service/antigravity_token_refresher.go
@@ -61,5 +61,10 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
}
}
+ // 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
+ if tokenInfo.ProjectIDMissing {
+ return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id,可能无法使用Antigravity")
+ }
+
return newCredentials, nil
}
diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go
index 521f1da5..eb5c7534 100644
--- a/backend/internal/service/api_key_auth_cache_impl.go
+++ b/backend/internal/service/api_key_auth_cache_impl.go
@@ -94,6 +94,20 @@ func (s *APIKeyService) initAuthCache(cfg *config.Config) {
s.authCacheL1 = cache
}
+// StartAuthCacheInvalidationSubscriber starts the Pub/Sub subscriber for L1 cache invalidation.
+// This should be called after the service is fully initialized.
+func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context) {
+ if s.cache == nil || s.authCacheL1 == nil {
+ return
+ }
+ if err := s.cache.SubscribeAuthCacheInvalidation(ctx, func(cacheKey string) {
+ s.authCacheL1.Del(cacheKey)
+ }); err != nil {
+ // Log but don't fail - L1 cache will still work, just without cross-instance invalidation
+ println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error())
+ }
+}
+
func (s *APIKeyService) authCacheKey(key string) string {
sum := sha256.Sum256([]byte(key))
return hex.EncodeToString(sum[:])
@@ -149,6 +163,8 @@ func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
return
}
_ = s.cache.DeleteAuthCache(ctx, cacheKey)
+ // Publish invalidation message to other instances
+ _ = s.cache.PublishAuthCacheInvalidation(ctx, cacheKey)
}
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go
index ecc570c7..ef1ff990 100644
--- a/backend/internal/service/api_key_service.go
+++ b/backend/internal/service/api_key_service.go
@@ -65,6 +65,10 @@ type APIKeyCache interface {
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
DeleteAuthCache(ctx context.Context, key string) error
+
+ // Pub/Sub for L1 cache invalidation across instances
+ PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error
+ SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error
}
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go
index 5f2d69c4..c5e9cd47 100644
--- a/backend/internal/service/api_key_service_cache_test.go
+++ b/backend/internal/service/api_key_service_cache_test.go
@@ -142,6 +142,14 @@ func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
return nil
}
+func (s *authCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
+ return nil
+}
+
+func (s *authCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
+ return nil
+}
+
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
diff --git a/backend/internal/service/api_key_service_delete_test.go b/backend/internal/service/api_key_service_delete_test.go
index 32ae884e..092b7fce 100644
--- a/backend/internal/service/api_key_service_delete_test.go
+++ b/backend/internal/service/api_key_service_delete_test.go
@@ -168,6 +168,14 @@ func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error
return nil
}
+func (s *apiKeyCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
+ return nil
+}
+
+func (s *apiKeyCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
+ return nil
+}
+
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为:
// - GetKeyAndOwnerID 返回所有者 ID 为 1
diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go
index a38c34cd..9850cbf0 100644
--- a/backend/internal/service/gateway_multiplatform_test.go
+++ b/backend/internal/service/gateway_multiplatform_test.go
@@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) BatchUpdateLastUsed(ctx context.Context, up
func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, errorMsg string) error {
return nil
}
+func (m *mockAccountRepoForPlatform) ClearError(ctx context.Context, id int64) error {
+ return nil
+}
func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index a08f3e48..9565da29 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -11,6 +11,8 @@ import (
"fmt"
"io"
"log"
+ "log/slog"
+ mathrand "math/rand"
"net/http"
"os"
"regexp"
@@ -445,11 +447,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
-// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
+// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
+ // 调试日志:记录调度入口参数
+ excludedIDsList := make([]int64, 0, len(excludedIDs))
+ for id := range excludedIDs {
+ excludedIDsList = append(excludedIDsList, id)
+ }
+ slog.Debug("account_scheduling_starting",
+ "group_id", derefGroupID(groupID),
+ "model", requestedModel,
+ "session", shortSessionHash(sessionHash),
+ "excluded_ids", excludedIDsList)
+
cfg := s.schedulingConfig()
- // 提取会话 UUID(用于会话数量限制)
- sessionUUID := extractSessionUUID(metadataUserID)
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
@@ -475,41 +486,63 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
- account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
- if err != nil {
- return nil, err
+ // 复制排除列表,用于会话限制拒绝时的重试
+ localExcluded := make(map[int64]struct{})
+ for k, v := range excludedIDs {
+ localExcluded[k] = v
}
- result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
- if err == nil && result.Acquired {
- return &AccountSelectionResult{
- Account: account,
- Acquired: true,
- ReleaseFunc: result.ReleaseFunc,
- }, nil
- }
- if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
- waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
- if waitingCount < cfg.StickySessionMaxWaiting {
+
+ for {
+ account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, localExcluded)
+ if err != nil {
+ return nil, err
+ }
+
+ result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
+ if err == nil && result.Acquired {
+ // 获取槽位后检查会话限制(使用 sessionHash 作为会话标识符)
+ if !s.checkAndRegisterSession(ctx, account, sessionHash) {
+ result.ReleaseFunc() // 释放槽位
+ localExcluded[account.ID] = struct{}{} // 排除此账号
+ continue // 重新选择
+ }
return &AccountSelectionResult{
- Account: account,
- WaitPlan: &AccountWaitPlan{
- AccountID: account.ID,
- MaxConcurrency: account.Concurrency,
- Timeout: cfg.StickySessionWaitTimeout,
- MaxWaiting: cfg.StickySessionMaxWaiting,
- },
+ Account: account,
+ Acquired: true,
+ ReleaseFunc: result.ReleaseFunc,
}, nil
}
+
+ // 对于等待计划的情况,也需要先检查会话限制
+ if !s.checkAndRegisterSession(ctx, account, sessionHash) {
+ localExcluded[account.ID] = struct{}{}
+ continue
+ }
+
+ if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
+ waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
+ if waitingCount < cfg.StickySessionMaxWaiting {
+ return &AccountSelectionResult{
+ Account: account,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: account.ID,
+ MaxConcurrency: account.Concurrency,
+ Timeout: cfg.StickySessionWaitTimeout,
+ MaxWaiting: cfg.StickySessionMaxWaiting,
+ },
+ }, nil
+ }
+ }
+ return &AccountSelectionResult{
+ Account: account,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: account.ID,
+ MaxConcurrency: account.Concurrency,
+ Timeout: cfg.FallbackWaitTimeout,
+ MaxWaiting: cfg.FallbackMaxWaiting,
+ },
+ }, nil
}
- return &AccountSelectionResult{
- Account: account,
- WaitPlan: &AccountWaitPlan{
- AccountID: account.ID,
- MaxConcurrency: account.Concurrency,
- Timeout: cfg.FallbackWaitTimeout,
- MaxWaiting: cfg.FallbackMaxWaiting,
- },
- }, nil
}
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group)
@@ -625,7 +658,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
- if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) {
+ if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
result.ReleaseFunc() // 释放槽位
// 继续到负载感知选择
} else {
@@ -643,15 +676,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
if waitingCount < cfg.StickySessionMaxWaiting {
- return &AccountSelectionResult{
- Account: stickyAccount,
- WaitPlan: &AccountWaitPlan{
- AccountID: stickyAccountID,
- MaxConcurrency: stickyAccount.Concurrency,
- Timeout: cfg.StickySessionWaitTimeout,
- MaxWaiting: cfg.StickySessionMaxWaiting,
- },
- }, nil
+ // 会话数量限制检查(等待计划也需要占用会话配额)
+ if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
+ // 会话限制已满,继续到负载感知选择
+ } else {
+ return &AccountSelectionResult{
+ Account: stickyAccount,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: stickyAccountID,
+ MaxConcurrency: stickyAccount.Concurrency,
+ Timeout: cfg.StickySessionWaitTimeout,
+ MaxWaiting: cfg.StickySessionMaxWaiting,
+ },
+ }, nil
+ }
}
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
}
@@ -714,7 +752,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
- if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
+ if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
@@ -732,20 +770,26 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
- // 5. 所有路由账号槽位满,返回等待计划(选择负载最低的)
- acc := routingAvailable[0].account
- if s.debugModelRoutingEnabled() {
- log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), acc.ID)
+ // 5. 所有路由账号槽位满,尝试返回等待计划(选择负载最低的)
+ // 遍历找到第一个满足会话限制的账号
+ for _, item := range routingAvailable {
+ if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
+ continue // 会话限制已满,尝试下一个
+ }
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
+ }
+ return &AccountSelectionResult{
+ Account: item.account,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: item.account.ID,
+ MaxConcurrency: item.account.Concurrency,
+ Timeout: cfg.StickySessionWaitTimeout,
+ MaxWaiting: cfg.StickySessionMaxWaiting,
+ },
+ }, nil
}
- return &AccountSelectionResult{
- Account: acc,
- WaitPlan: &AccountWaitPlan{
- AccountID: acc.ID,
- MaxConcurrency: acc.Concurrency,
- Timeout: cfg.StickySessionWaitTimeout,
- MaxWaiting: cfg.StickySessionMaxWaiting,
- },
- }, nil
+ // 所有路由账号会话限制都已满,继续到 Layer 2 回退
}
// 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退
log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel)
@@ -773,7 +817,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if err == nil && result.Acquired {
// 会话数量限制检查
// Session count limit check
- if !s.checkAndRegisterSession(ctx, account, sessionUUID) {
+ if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
@@ -787,15 +831,22 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
- return &AccountSelectionResult{
- Account: account,
- WaitPlan: &AccountWaitPlan{
- AccountID: accountID,
- MaxConcurrency: account.Concurrency,
- Timeout: cfg.StickySessionWaitTimeout,
- MaxWaiting: cfg.StickySessionMaxWaiting,
- },
- }, nil
+ // 会话数量限制检查(等待计划也需要占用会话配额)
+ // Session count limit check (wait plan also requires session quota)
+ if !s.checkAndRegisterSession(ctx, account, sessionHash) {
+ // 会话限制已满,继续到 Layer 2
+ // Session limit full, continue to Layer 2
+ } else {
+ return &AccountSelectionResult{
+ Account: account,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: accountID,
+ MaxConcurrency: account.Concurrency,
+ Timeout: cfg.StickySessionWaitTimeout,
+ MaxWaiting: cfg.StickySessionMaxWaiting,
+ },
+ }, nil
+ }
}
}
}
@@ -845,7 +896,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
- if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok {
+ if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
return result, nil
}
} else {
@@ -895,7 +946,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
- if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
+ if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
@@ -913,8 +964,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
// ============ Layer 3: 兜底排队 ============
- sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
+ s.sortCandidatesForFallback(candidates, preferOAuth, cfg.FallbackSelectionMode)
for _, acc := range candidates {
+ // 会话数量限制检查(等待计划也需要占用会话配额)
+ if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
+ continue // 会话限制已满,尝试下一个账号
+ }
return &AccountSelectionResult{
Account: acc,
WaitPlan: &AccountWaitPlan{
@@ -928,7 +983,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, errors.New("no available accounts")
}
-func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) {
+func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
@@ -936,7 +991,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
- if !s.checkAndRegisterSession(ctx, acc, sessionUUID) {
+ if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
@@ -1093,7 +1148,24 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
if s.schedulerSnapshot != nil {
- return s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
+ accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
+ if err == nil {
+ slog.Debug("account_scheduling_list_snapshot",
+ "group_id", derefGroupID(groupID),
+ "platform", platform,
+ "use_mixed", useMixed,
+ "count", len(accounts))
+ for _, acc := range accounts {
+ slog.Debug("account_scheduling_account_detail",
+ "account_id", acc.ID,
+ "name", acc.Name,
+ "platform", acc.Platform,
+ "type", acc.Type,
+ "status", acc.Status,
+ "tls_fingerprint", acc.IsTLSFingerprintEnabled())
+ }
+ }
+ return accounts, useMixed, err
}
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
if useMixed {
@@ -1106,6 +1178,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
}
if err != nil {
+ slog.Debug("account_scheduling_list_failed",
+ "group_id", derefGroupID(groupID),
+ "platform", platform,
+ "error", err)
return nil, useMixed, err
}
filtered := make([]Account, 0, len(accounts))
@@ -1115,6 +1191,20 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
}
filtered = append(filtered, acc)
}
+ slog.Debug("account_scheduling_list_mixed",
+ "group_id", derefGroupID(groupID),
+ "platform", platform,
+ "raw_count", len(accounts),
+ "filtered_count", len(filtered))
+ for _, acc := range filtered {
+ slog.Debug("account_scheduling_account_detail",
+ "account_id", acc.ID,
+ "name", acc.Name,
+ "platform", acc.Platform,
+ "type", acc.Type,
+ "status", acc.Status,
+ "tls_fingerprint", acc.IsTLSFingerprintEnabled())
+ }
return filtered, useMixed, nil
}
@@ -1129,8 +1219,25 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
if err != nil {
+ slog.Debug("account_scheduling_list_failed",
+ "group_id", derefGroupID(groupID),
+ "platform", platform,
+ "error", err)
return nil, useMixed, err
}
+ slog.Debug("account_scheduling_list_single",
+ "group_id", derefGroupID(groupID),
+ "platform", platform,
+ "count", len(accounts))
+ for _, acc := range accounts {
+ slog.Debug("account_scheduling_account_detail",
+ "account_id", acc.ID,
+ "name", acc.Name,
+ "platform", acc.Platform,
+ "type", acc.Type,
+ "status", acc.Status,
+ "tls_fingerprint", acc.IsTLSFingerprintEnabled())
+ }
return accounts, useMixed, nil
}
@@ -1196,12 +1303,8 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context,
// 缓存未命中,从数据库查询
{
- var startTime time.Time
- if account.SessionWindowStart != nil {
- startTime = *account.SessionWindowStart
- } else {
- startTime = time.Now().Add(-5 * time.Hour)
- }
+ // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
+ startTime := account.GetCurrentWindowStartTime()
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil {
@@ -1234,15 +1337,16 @@ checkSchedulability:
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号
+// sessionID: 会话标识符(使用粘性会话的 hash)
// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
-func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool {
+func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionID string) bool {
// 只检查 Anthropic OAuth/SetupToken 账号
if !account.IsAnthropicOAuthOrSetupToken() {
return true
}
maxSessions := account.GetMaxSessions()
- if maxSessions <= 0 || sessionUUID == "" {
+ if maxSessions <= 0 || sessionID == "" {
return true // 未启用会话限制或无会话ID
}
@@ -1252,7 +1356,7 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
- allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout)
+ allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionID, maxSessions, idleTimeout)
if err != nil {
// 失败开放:缓存错误时允许通过
return true
@@ -1260,18 +1364,6 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
return allowed
}
-// extractSessionUUID 从 metadata.user_id 中提取会话 UUID
-// 格式: user_{64位hex}_account__session_{uuid}
-func extractSessionUUID(metadataUserID string) string {
- if metadataUserID == "" {
- return ""
- }
- if match := sessionIDRegex.FindStringSubmatch(metadataUserID); len(match) > 1 {
- return match[1]
- }
- return ""
-}
-
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
if s.schedulerSnapshot != nil {
return s.schedulerSnapshot.GetAccount(ctx, accountID)
@@ -1301,6 +1393,56 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
})
}
+// sortCandidatesForFallback 根据配置选择排序策略
+// mode: "last_used"(按最后使用时间) 或 "random"(随机)
+func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) {
+ if mode == "random" {
+ // 先按优先级排序,然后在同优先级内随机打乱
+ sortAccountsByPriorityOnly(accounts, preferOAuth)
+ shuffleWithinPriority(accounts)
+ } else {
+ // 默认按最后使用时间排序
+ sortAccountsByPriorityAndLastUsed(accounts, preferOAuth)
+ }
+}
+
+// sortAccountsByPriorityOnly 仅按优先级排序
+func sortAccountsByPriorityOnly(accounts []*Account, preferOAuth bool) {
+ sort.SliceStable(accounts, func(i, j int) bool {
+ a, b := accounts[i], accounts[j]
+ if a.Priority != b.Priority {
+ return a.Priority < b.Priority
+ }
+ if preferOAuth && a.Type != b.Type {
+ return a.Type == AccountTypeOAuth
+ }
+ return false
+ })
+}
+
+// shuffleWithinPriority 在同优先级内随机打乱顺序
+func shuffleWithinPriority(accounts []*Account) {
+ if len(accounts) <= 1 {
+ return
+ }
+ r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
+ start := 0
+ for start < len(accounts) {
+ priority := accounts[start].Priority
+ end := start + 1
+ for end < len(accounts) && accounts[end].Priority == priority {
+ end++
+ }
+ // 对 [start, end) 范围内的账户随机打乱
+ if end-start > 1 {
+ r.Shuffle(end-start, func(i, j int) {
+ accounts[start+i], accounts[start+j] = accounts[start+j], accounts[start+i]
+ })
+ }
+ start = end
+ }
+}
+
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini
@@ -2158,6 +2300,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
proxyURL = account.Proxy.URL()
}
+ // 调试日志:记录即将转发的账号信息
+ log.Printf("[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
+ account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL)
+
// 重试循环
var resp *http.Response
retryStart := time.Now()
@@ -2172,7 +2318,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// 发送请求
- resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
+ resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
@@ -2246,7 +2392,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
- retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
+ retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil {
if retryResp.StatusCode < 400 {
log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID)
@@ -2278,7 +2424,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel)
if buildErr2 == nil {
- retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
+ retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr2 == nil {
resp = retryResp2
break
@@ -2393,6 +2539,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
+ // 调试日志:打印重试耗尽后的错误响应
+ log.Printf("[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
+ account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
+
s.handleRetryExhaustedSideEffects(ctx, resp, account)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
@@ -2420,6 +2570,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
+ // 调试日志:打印上游错误响应
+ log.Printf("[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
+ account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
+
s.handleFailoverSideEffects(ctx, resp, account)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
@@ -2549,9 +2703,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
fingerprint = fp
// 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid)
+ // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
- if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
+ if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
body = newBody
}
}
@@ -2770,6 +2925,10 @@ func extractUpstreamErrorMessage(body []byte) string {
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+ // 调试日志:打印上游错误响应
+ log.Printf("[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s",
+ account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000))
+
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
@@ -3478,7 +3637,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 发送请求
- resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
+ resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "")
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
@@ -3500,7 +3659,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
- retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
+ retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil {
resp = retryResp
respBody, err = io.ReadAll(resp.Body)
@@ -3578,12 +3737,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// OAuth 账号:应用统一指纹和重写 userID
+ // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
if account.IsOAuth() && s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil {
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
- if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
+ if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
body = newBody
}
}
diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go
index 0ddd72b8..c63a020c 100644
--- a/backend/internal/service/gemini_multiplatform_test.go
+++ b/backend/internal/service/gemini_multiplatform_test.go
@@ -90,6 +90,9 @@ func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, upda
func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error {
return nil
}
+func (m *mockAccountRepoForGemini) ClearError(ctx context.Context, id int64) error {
+ return nil
+}
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
diff --git a/backend/internal/service/http_upstream_port.go b/backend/internal/service/http_upstream_port.go
index 9357f763..0e4cfbec 100644
--- a/backend/internal/service/http_upstream_port.go
+++ b/backend/internal/service/http_upstream_port.go
@@ -10,6 +10,7 @@ import "net/http"
// - 支持可选代理配置
// - 支持账户级连接池隔离
// - 实现类负责连接池管理和复用
+// - 支持可选的 TLS 指纹伪装
type HTTPUpstream interface {
// Do 执行 HTTP 请求
//
@@ -27,4 +28,28 @@ type HTTPUpstream interface {
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
// - 响应体可能已被包装以跟踪请求生命周期
Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error)
+
+ // DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
+ //
+ // 参数:
+ // - req: HTTP 请求对象,由调用方构建
+ // - proxyURL: 代理服务器地址,空字符串表示直连
+ // - accountID: 账户 ID,用于连接池隔离和 TLS 指纹模板选择
+ // - accountConcurrency: 账户并发限制,用于动态调整连接池大小
+ // - enableTLSFingerprint: 是否启用 TLS 指纹伪装
+ //
+ // 返回:
+ // - *http.Response: HTTP 响应,调用方必须关闭 Body
+ // - error: 请求错误(网络错误、超时等)
+ //
+ // TLS 指纹说明:
+ // - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
+ // - TLS 指纹模板根据 accountID % len(profiles) 自动选择
+ // - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
+ // - 如果 enableTLSFingerprint=false,行为与 Do 方法相同
+ //
+ // 注意:
+ // - 调用方必须关闭 resp.Body,否则会导致连接泄漏
+ // - TLS 指纹客户端与普通客户端使用不同的缓存键,互不影响
+ DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error)
}
diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go
index 1ffa8057..e2e723b0 100644
--- a/backend/internal/service/identity_service.go
+++ b/backend/internal/service/identity_service.go
@@ -8,9 +8,11 @@ import (
"encoding/json"
"fmt"
"log"
+ "log/slog"
"net/http"
"regexp"
"strconv"
+ "strings"
"time"
)
@@ -49,6 +51,13 @@ type Fingerprint struct {
type IdentityCache interface {
GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error
+ // GetMaskedSessionID 获取固定的会话ID(用于会话ID伪装功能)
+ // 返回的 sessionID 是一个 UUID 格式的字符串
+ // 如果不存在或已过期(15分钟无请求),返回空字符串
+ GetMaskedSessionID(ctx context.Context, accountID int64) (string, error)
+ // SetMaskedSessionID 设置固定的会话ID,TTL 为 15 分钟
+ // 每次调用都会刷新 TTL
+ SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error
}
// IdentityService 管理OAuth账号的请求身份指纹
@@ -203,6 +212,94 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
return json.Marshal(reqMap)
}
+// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
+// 如果账号启用了会话ID伪装(session_id_masking_enabled),
+// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变)
+func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) {
+ // 先执行常规的 RewriteUserID 逻辑
+ newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID)
+ if err != nil {
+ return newBody, err
+ }
+
+ // 检查是否启用会话ID伪装
+ if !account.IsSessionIDMaskingEnabled() {
+ return newBody, nil
+ }
+
+ // 解析重写后的 body,提取 user_id
+ var reqMap map[string]any
+ if err := json.Unmarshal(newBody, &reqMap); err != nil {
+ return newBody, nil
+ }
+
+ metadata, ok := reqMap["metadata"].(map[string]any)
+ if !ok {
+ return newBody, nil
+ }
+
+ userID, ok := metadata["user_id"].(string)
+ if !ok || userID == "" {
+ return newBody, nil
+ }
+
+ // 查找 _session_ 的位置,替换其后的内容
+ const sessionMarker = "_session_"
+ idx := strings.LastIndex(userID, sessionMarker)
+ if idx == -1 {
+ return newBody, nil
+ }
+
+ // 获取或生成固定的伪装 session ID
+ maskedSessionID, err := s.cache.GetMaskedSessionID(ctx, account.ID)
+ if err != nil {
+ log.Printf("Warning: failed to get masked session ID for account %d: %v", account.ID, err)
+ return newBody, nil
+ }
+
+ if maskedSessionID == "" {
+ // 首次或已过期,生成新的伪装 session ID
+ maskedSessionID = generateRandomUUID()
+ log.Printf("Generated new masked session ID for account %d: %s", account.ID, maskedSessionID)
+ }
+
+ // 刷新 TTL(每次请求都刷新,保持 15 分钟有效期)
+ if err := s.cache.SetMaskedSessionID(ctx, account.ID, maskedSessionID); err != nil {
+ log.Printf("Warning: failed to set masked session ID for account %d: %v", account.ID, err)
+ }
+
+ // 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容
+ newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID
+
+ slog.Debug("session_id_masking_applied",
+ "account_id", account.ID,
+ "before", userID,
+ "after", newUserID,
+ )
+
+ metadata["user_id"] = newUserID
+ reqMap["metadata"] = metadata
+
+ return json.Marshal(reqMap)
+}
+
+// generateRandomUUID 生成随机 UUID v4 格式字符串
+func generateRandomUUID() string {
+ b := make([]byte, 16)
+ if _, err := rand.Read(b); err != nil {
+ // fallback: 使用时间戳生成
+ h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano())))
+ b = h[:16]
+ }
+
+ // 设置 UUID v4 版本和变体位
+ b[6] = (b[6] & 0x0f) | 0x40
+ b[8] = (b[8] & 0x3f) | 0x80
+
+ return fmt.Sprintf("%x-%x-%x-%x-%x",
+ b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
+}
+
// generateClientID 生成64位十六进制客户端ID(32字节随机数)
func generateClientID() string {
b := make([]byte, 32)
diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go
index 2d75dd5a..41bd253c 100644
--- a/backend/internal/service/ratelimit_service.go
+++ b/backend/internal/service/ratelimit_service.go
@@ -73,10 +73,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
return false
}
- tempMatched := false
+ // 先尝试临时不可调度规则(401除外)
+ // 如果匹配成功,直接返回,不执行后续禁用逻辑
if statusCode != 401 {
- tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
+ if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) {
+ return true
+ }
}
+
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if upstreamMsg != "" {
@@ -84,6 +88,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
}
switch statusCode {
+ case 400:
+ // 只有当错误信息包含 "organization has been disabled" 时才禁用
+ if strings.Contains(strings.ToLower(upstreamMsg), "organization has been disabled") {
+ msg := "Organization disabled (400): " + upstreamMsg
+ s.handleAuthError(ctx, account, msg)
+ shouldDisable = true
+ }
+ // 其他 400 错误(如参数问题)不处理,不禁用账号
case 401:
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
if account.Type == AccountTypeOAuth {
@@ -148,9 +160,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
}
}
- if tempMatched {
- return true
- }
return shouldDisable
}
diff --git a/backend/internal/service/session_limit_cache.go b/backend/internal/service/session_limit_cache.go
index f6f0c26a..5482d610 100644
--- a/backend/internal/service/session_limit_cache.go
+++ b/backend/internal/service/session_limit_cache.go
@@ -38,8 +38,9 @@ type SessionLimitCache interface {
GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
+ // idleTimeouts: 每个账号的空闲超时时间配置,key 为 accountID;若为 nil 或某账号不在其中,则使用默认超时
// 返回 map[accountID]count,查询失败的账号不在 map 中
- GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
+ GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error)
// IsSessionActive 检查特定会话是否活跃(未过期)
IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)
diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go
index 26cfd97d..02e7d445 100644
--- a/backend/internal/service/token_refresh_service.go
+++ b/backend/internal/service/token_refresh_service.go
@@ -166,11 +166,25 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
newCredentials, err := refresher.Refresh(ctx, account)
- if err == nil {
- // 刷新成功,更新账号credentials
+
+ // 如果有新凭证,先更新(即使有错误也要保存 token)
+ if newCredentials != nil {
account.Credentials = newCredentials
- if err := s.accountRepo.Update(ctx, account); err != nil {
- return fmt.Errorf("failed to save credentials: %w", err)
+ if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
+ return fmt.Errorf("failed to save credentials: %w", saveErr)
+ }
+ }
+
+ if err == nil {
+ // Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态
+ if account.Platform == PlatformAntigravity &&
+ account.Status == StatusError &&
+ strings.Contains(account.ErrorMessage, "missing_project_id:") {
+ if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
+ log.Printf("[TokenRefresh] Failed to clear error status for account %d: %v", account.ID, clearErr)
+ } else {
+ log.Printf("[TokenRefresh] Account %d: cleared missing_project_id error", account.ID)
+ }
}
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
@@ -230,6 +244,7 @@ func isNonRetryableRefreshError(err error) bool {
"invalid_client", // 客户端配置错误
"unauthorized_client", // 客户端未授权
"access_denied", // 访问被拒绝
+ "missing_project_id", // 缺少 project_id
}
for _, needle := range nonRetryable {
if strings.Contains(msg, needle) {
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index 0b9bc20c..b210286d 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -1,6 +1,7 @@
package service
import (
+ "context"
"database/sql"
"time"
@@ -196,6 +197,8 @@ func ProvideOpsScheduledReportService(
// ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力
func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator {
+ // Start Pub/Sub subscriber for L1 cache invalidation across instances
+ apiKeyService.StartAuthCacheInvalidationSubscriber(context.Background())
return apiKeyService
}
diff --git a/deploy/README.md b/deploy/README.md
index f697247d..c42e7552 100644
--- a/deploy/README.md
+++ b/deploy/README.md
@@ -401,3 +401,58 @@ sudo systemctl status redis
2. **Database connection failed**: Check PostgreSQL is running and credentials are correct
3. **Redis connection failed**: Check Redis is running and password is correct
4. **Permission denied**: Ensure proper file ownership for binary install
+
+---
+
+## TLS Fingerprint Configuration
+
+Sub2API supports TLS fingerprint simulation to make requests appear as if they come from the official Claude CLI (Node.js client).
+
+### Default Behavior
+
+- Built-in `claude_cli_v2` profile simulates Node.js 20.x + OpenSSL 3.x
+- JA3 Hash: `1a28e69016765d92e3b381168d68922c`
+- JA4: `t13d5911h1_a33745022dd6_1f22a2ca17c4`
+- Profile selection: `accountID % profileCount`
+
+### Configuration
+
+```yaml
+gateway:
+ tls_fingerprint:
+ enabled: true # Global switch
+ profiles:
+ # Simple profile (uses default cipher suites)
+ profile_1:
+ name: "Profile 1"
+
+ # Profile with custom cipher suites (use compact array format)
+ profile_2:
+ name: "Profile 2"
+ cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196]
+ curves: [29, 23, 24]
+ point_formats: [0]
+
+ # Another custom profile
+ profile_3:
+ name: "Profile 3"
+ cipher_suites: [4865, 4866, 4867, 49199, 49200]
+ curves: [29, 23, 24, 25]
+```
+
+### Profile Fields
+
+| Field | Type | Description |
+|-------|------|-------------|
+| `name` | string | Display name (required) |
+| `cipher_suites` | []uint16 | Cipher suites in decimal. Empty = default |
+| `curves` | []uint16 | Elliptic curves in decimal. Empty = default |
+| `point_formats` | []uint8 | EC point formats. Empty = default |
+
+### Common Values Reference
+
+**Cipher Suites (TLS 1.3):** `4865` (AES_128_GCM), `4866` (AES_256_GCM), `4867` (CHACHA20)
+
+**Cipher Suites (TLS 1.2):** `49195`, `49196`, `49199`, `49200` (ECDHE variants)
+
+**Curves:** `29` (X25519), `23` (P-256), `24` (P-384), `25` (P-521)
diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml
index 1f4aa266..558b8ef0 100644
--- a/deploy/config.example.yaml
+++ b/deploy/config.example.yaml
@@ -210,6 +210,19 @@ gateway:
outbox_backlog_rebuild_rows: 10000
# 全量重建周期(秒),0 表示禁用
full_rebuild_interval_seconds: 300
+ # TLS fingerprint simulation / TLS 指纹伪装
+ # Default profile "claude_cli_v2" simulates Node.js 20.x
+ # 默认模板 "claude_cli_v2" 模拟 Node.js 20.x 指纹
+ tls_fingerprint:
+ enabled: true
+ # profiles:
+ # profile_1:
+ # name: "Custom Profile 1"
+ # profile_2:
+ # name: "Custom Profile 2"
+ # cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196]
+ # curves: [29, 23, 24]
+ # point_formats: [0]
# =============================================================================
# API Key Auth Cache Configuration
diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue
index c81de00e..7906cd6b 100644
--- a/frontend/src/components/account/CreateAccountModal.vue
+++ b/frontend/src/components/account/CreateAccountModal.vue
@@ -1191,6 +1191,190 @@
+
+
+
+
{{ t('admin.accounts.quotaControl.title') }}
+
+ {{ t('admin.accounts.quotaControl.hint') }}
+
+
+
+
+
+
+
+
+
+ {{ t('admin.accounts.quotaControl.windowCost.hint') }}
+
+
+
+
+
+
+
+
+
+ $
+
+
+
{{ t('admin.accounts.quotaControl.windowCost.limitHint') }}
+
+
+
+
+ $
+
+
+
{{ t('admin.accounts.quotaControl.windowCost.stickyReserveHint') }}
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.accounts.quotaControl.sessionLimit.hint') }}
+
+
+
+
+
+
+
+
+
+
{{ t('admin.accounts.quotaControl.sessionLimit.maxSessionsHint') }}
+
+
+
+
+
+ {{ t('common.minutes') }}
+
+
{{ t('admin.accounts.quotaControl.sessionLimit.idleTimeoutHint') }}
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.accounts.quotaControl.tlsFingerprint.hint') }}
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.accounts.quotaControl.sessionIdMasking.hint') }}
+
+
+
+
+
+
+
@@ -1214,7 +1398,7 @@
@@ -1763,6 +1947,16 @@ const geminiAIStudioOAuthEnabled = ref(false)
const showAdvancedOAuth = ref(false)
const showGeminiHelpDialog = ref(false)
+// Quota control state (Anthropic OAuth/SetupToken only)
+const windowCostEnabled = ref(false)
+const windowCostLimit = ref(null)
+const windowCostStickyReserve = ref(null)
+const sessionLimitEnabled = ref(false)
+const maxSessions = ref(null)
+const sessionIdleTimeout = ref(null)
+const tlsFingerprintEnabled = ref(false)
+const sessionIdMaskingEnabled = ref(false)
+
// Gemini tier selection (used as fallback when auto-detection is unavailable/fails)
const geminiTierGoogleOne = ref<'google_one_free' | 'google_ai_pro' | 'google_ai_ultra'>('google_one_free')
const geminiTierGcp = ref<'gcp_standard' | 'gcp_enterprise'>('gcp_standard')
@@ -2140,6 +2334,15 @@ const resetForm = () => {
customErrorCodeInput.value = null
interceptWarmupRequests.value = false
autoPauseOnExpired.value = true
+ // Reset quota control state
+ windowCostEnabled.value = false
+ windowCostLimit.value = null
+ windowCostStickyReserve.value = null
+ sessionLimitEnabled.value = false
+ maxSessions.value = null
+ sessionIdleTimeout.value = null
+ tlsFingerprintEnabled.value = false
+ sessionIdMaskingEnabled.value = false
tempUnschedEnabled.value = false
tempUnschedRules.value = []
geminiOAuthType.value = 'code_assist'
@@ -2407,7 +2610,32 @@ const handleAnthropicExchange = async (authCode: string) => {
...proxyConfig
})
- const extra = oauth.buildExtraInfo(tokenInfo)
+ // Build extra with quota control settings
+ const baseExtra = oauth.buildExtraInfo(tokenInfo) || {}
+ const extra: Record = { ...baseExtra }
+
+ // Add window cost limit settings
+ if (windowCostEnabled.value && windowCostLimit.value != null && windowCostLimit.value > 0) {
+ extra.window_cost_limit = windowCostLimit.value
+ extra.window_cost_sticky_reserve = windowCostStickyReserve.value ?? 10
+ }
+
+ // Add session limit settings
+ if (sessionLimitEnabled.value && maxSessions.value != null && maxSessions.value > 0) {
+ extra.max_sessions = maxSessions.value
+ extra.session_idle_timeout_minutes = sessionIdleTimeout.value ?? 5
+ }
+
+ // Add TLS fingerprint settings
+ if (tlsFingerprintEnabled.value) {
+ extra.enable_tls_fingerprint = true
+ }
+
+ // Add session ID masking settings
+ if (sessionIdMaskingEnabled.value) {
+ extra.session_id_masking_enabled = true
+ }
+
const credentials = {
...tokenInfo,
...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {})
@@ -2475,7 +2703,32 @@ const handleCookieAuth = async (sessionKey: string) => {
...proxyConfig
})
- const extra = oauth.buildExtraInfo(tokenInfo)
+ // Build extra with quota control settings
+ const baseExtra = oauth.buildExtraInfo(tokenInfo) || {}
+ const extra: Record = { ...baseExtra }
+
+ // Add window cost limit settings
+ if (windowCostEnabled.value && windowCostLimit.value != null && windowCostLimit.value > 0) {
+ extra.window_cost_limit = windowCostLimit.value
+ extra.window_cost_sticky_reserve = windowCostStickyReserve.value ?? 10
+ }
+
+ // Add session limit settings
+ if (sessionLimitEnabled.value && maxSessions.value != null && maxSessions.value > 0) {
+ extra.max_sessions = maxSessions.value
+ extra.session_idle_timeout_minutes = sessionIdleTimeout.value ?? 5
+ }
+
+ // Add TLS fingerprint settings
+ if (tlsFingerprintEnabled.value) {
+ extra.enable_tls_fingerprint = true
+ }
+
+ // Add session ID masking settings
+ if (sessionIdMaskingEnabled.value) {
+ extra.session_id_masking_enabled = true
+ }
+
const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name
// Merge interceptWarmupRequests into credentials
diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue
index d27364f1..81d10932 100644
--- a/frontend/src/components/account/EditAccountModal.vue
+++ b/frontend/src/components/account/EditAccountModal.vue
@@ -566,7 +566,7 @@
@@ -732,6 +732,60 @@
+
+
+
+
+
+
+
+ {{ t('admin.accounts.quotaControl.tlsFingerprint.hint') }}
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.accounts.quotaControl.sessionIdMasking.hint') }}
+
+
+
+
+
@@ -904,6 +958,8 @@ const windowCostStickyReserve = ref(null)
const sessionLimitEnabled = ref(false)
const maxSessions = ref(null)
const sessionIdleTimeout = ref(null)
+const tlsFingerprintEnabled = ref(false)
+const sessionIdMaskingEnabled = ref(false)
// Computed: current preset mappings based on platform
const presetMappings = computed(() => getPresetMappingsByPlatform(props.account?.platform || 'anthropic'))
@@ -1237,6 +1293,8 @@ function loadQuotaControlSettings(account: Account) {
sessionLimitEnabled.value = false
maxSessions.value = null
sessionIdleTimeout.value = null
+ tlsFingerprintEnabled.value = false
+ sessionIdMaskingEnabled.value = false
// Only applies to Anthropic OAuth/SetupToken accounts
if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) {
@@ -1255,6 +1313,16 @@ function loadQuotaControlSettings(account: Account) {
maxSessions.value = account.max_sessions
sessionIdleTimeout.value = account.session_idle_timeout_minutes ?? 5
}
+
+ // Load TLS fingerprint setting
+ if (account.enable_tls_fingerprint === true) {
+ tlsFingerprintEnabled.value = true
+ }
+
+ // Load session ID masking setting
+ if (account.session_id_masking_enabled === true) {
+ sessionIdMaskingEnabled.value = true
+ }
}
function formatTempUnschedKeywords(value: unknown) {
@@ -1407,6 +1475,20 @@ const handleSubmit = async () => {
delete newExtra.session_idle_timeout_minutes
}
+ // TLS fingerprint setting
+ if (tlsFingerprintEnabled.value) {
+ newExtra.enable_tls_fingerprint = true
+ } else {
+ delete newExtra.enable_tls_fingerprint
+ }
+
+ // Session ID masking setting
+ if (sessionIdMaskingEnabled.value) {
+ newExtra.session_id_masking_enabled = true
+ } else {
+ delete newExtra.session_id_masking_enabled
+ }
+
updatePayload.extra = newExtra
}
diff --git a/frontend/src/components/layout/AppHeader.vue b/frontend/src/components/layout/AppHeader.vue
index fd8742c3..9d2b40fb 100644
--- a/frontend/src/components/layout/AppHeader.vue
+++ b/frontend/src/components/layout/AppHeader.vue
@@ -21,8 +21,20 @@
-
+
+
+
+
+ {{ t('nav.docs') }}
+
+
@@ -211,6 +223,7 @@ const user = computed(() => authStore.user)
const dropdownOpen = ref(false)
const dropdownRef = ref
(null)
const contactInfo = computed(() => appStore.contactInfo)
+const docUrl = computed(() => appStore.docUrl)
// 只在标准模式的管理员下显示新手引导按钮
const showOnboardingButton = computed(() => {
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index a237ad53..362a1349 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -196,7 +196,8 @@ export default {
expand: 'Expand',
logout: 'Logout',
github: 'GitHub',
- mySubscriptions: 'My Subscriptions'
+ mySubscriptions: 'My Subscriptions',
+ docs: 'Docs'
},
// Auth
@@ -1288,6 +1289,14 @@ export default {
idleTimeout: 'Idle Timeout',
idleTimeoutPlaceholder: '5',
idleTimeoutHint: 'Sessions will be released after idle timeout'
+ },
+ tlsFingerprint: {
+ label: 'TLS Fingerprint Simulation',
+ hint: 'Simulate Node.js/Claude Code client TLS fingerprint'
+ },
+ sessionIdMasking: {
+ label: 'Session ID Masking',
+ hint: 'When enabled, fixes the session ID in metadata.user_id for 15 minutes, making upstream think requests come from the same session'
}
},
expired: 'Expired',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index bfe36c1f..1efd3867 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -193,7 +193,8 @@ export default {
expand: '展开',
logout: '退出登录',
github: 'GitHub',
- mySubscriptions: '我的订阅'
+ mySubscriptions: '我的订阅',
+ docs: '文档'
},
// Auth
@@ -1420,6 +1421,14 @@ export default {
idleTimeout: '空闲超时',
idleTimeoutPlaceholder: '5',
idleTimeoutHint: '会话空闲超时后自动释放'
+ },
+ tlsFingerprint: {
+ label: 'TLS 指纹模拟',
+ hint: '模拟 Node.js/Claude Code 客户端的 TLS 指纹'
+ },
+ sessionIdMasking: {
+ label: '会话 ID 伪装',
+ hint: '启用后将在 15 分钟内固定 metadata.user_id 中的 session ID,使上游认为请求来自同一会话'
}
},
expired: '已过期',
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index 353ccb83..35e256e6 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -480,6 +480,13 @@ export interface Account {
max_sessions?: number | null
session_idle_timeout_minutes?: number | null
+ // TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效)
+ enable_tls_fingerprint?: boolean | null
+
+ // 会话ID伪装(仅 Anthropic OAuth/SetupToken 账号有效)
+ // 启用后将在15分钟内固定 metadata.user_id 中的 session ID
+ session_id_masking_enabled?: boolean | null
+
// 运行时状态(仅当启用对应限制时返回)
current_window_cost?: number | null // 当前窗口费用
active_sessions?: number | null // 当前活跃会话数
diff --git a/frontend/src/views/admin/GroupsView.vue b/frontend/src/views/admin/GroupsView.vue
index 96457172..47a15084 100644
--- a/frontend/src/views/admin/GroupsView.vue
+++ b/frontend/src/views/admin/GroupsView.vue
@@ -243,7 +243,7 @@
/>
{{ t('admin.groups.platformHint') }}
-
+
-
+
{
}
}
-// 监听 subscription_type 变化,订阅模式时重置 rate_multiplier 为 1,is_exclusive 为 true
+// 监听 subscription_type 变化,订阅模式时 is_exclusive 默认为 true
watch(
() => createForm.subscription_type,
(newVal) => {
if (newVal === 'subscription') {
- createForm.rate_multiplier = 1.0
createForm.is_exclusive = true
}
}
diff --git a/frontend/tsconfig.json b/frontend/tsconfig.json
index a1731cfb..82ae3f9f 100644
--- a/frontend/tsconfig.json
+++ b/frontend/tsconfig.json
@@ -21,5 +21,6 @@
"types": ["vite/client"]
},
"include": ["src/**/*.ts", "src/**/*.tsx", "src/**/*.vue"],
+ "exclude": ["src/**/__tests__/**", "src/**/*.spec.ts", "src/**/*.test.ts"],
"references": [{ "path": "./tsconfig.node.json" }]
}