diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index cd622a3b..fdb6411c 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -276,10 +276,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 转发请求 - 根据账号平台分流 var result *service.ForwardResult + requestCtx := c.Request.Context() + if switchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + } if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body) + result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body) } else { - result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) + result, err = h.geminiCompatService.Forward(requestCtx, c, account, body) } if accountReleaseFunc != nil { accountReleaseFunc() @@ -419,10 +423,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 转发请求 - 根据账号平台分流 var result *service.ForwardResult + requestCtx := c.Request.Context() + if switchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + } if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body) + result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body) } else { - result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq) + result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq) } if accountReleaseFunc != nil { accountReleaseFunc() diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index c7646b38..1946aeb2 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -10,6 +10,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" @@ -288,10 +289,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 5) forward (根据平台分流) var result *service.ForwardResult + requestCtx := c.Request.Context() + if switchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + } if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, modelName, action, stream, body) + result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body) } else { - result, err = h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body) + result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body) } if accountReleaseFunc != nil { accountReleaseFunc() diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go index 27bb5ac5..fd7512f7 100644 --- a/backend/internal/pkg/ctxkey/ctxkey.go +++ b/backend/internal/pkg/ctxkey/ctxkey.go @@ -14,6 +14,9 @@ const ( // RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。 RetryCount Key = "ctx_retry_count" + // AccountSwitchCount 表示请求过程中发生的账号切换次数 + AccountSwitchCount Key = "ctx_account_switch_count" + // IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端 IsClaudeCodeClient Key = "ctx_is_claude_code_client" // Group 认证后的分组信息,由 API Key 认证中间件设置 diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index dbdfd374..db988565 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -33,6 +33,7 @@ const ( const ( antigravityMaxRetriesEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES" + antigravityMaxRetriesAfterSwitchEnv = "GATEWAY_ANTIGRAVITY_AFTER_SWITCHMAX_RETRIES" antigravityMaxRetriesClaudeEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_CLAUDE" antigravityMaxRetriesGeminiTextEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_TEXT" antigravityMaxRetriesGeminiImageEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_IMAGE" @@ -745,6 +746,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" { billingModel = mappedModel } + afterSwitch := antigravityHasAccountSwitch(ctx) + maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch) // 获取 access_token if s.tokenProvider == nil { @@ -793,7 +796,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, httpUpstream: s.httpUpstream, settingService: s.settingService, handleError: s.handleUpstreamError, - maxRetries: antigravityMaxRetriesForModel(originalModel), + maxRetries: maxRetries, }) if err != nil { return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") @@ -870,7 +873,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, httpUpstream: s.httpUpstream, settingService: s.settingService, handleError: s.handleUpstreamError, - maxRetries: antigravityMaxRetriesForModel(originalModel), + maxRetries: maxRetries, }) if retryErr != nil { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ @@ -1387,6 +1390,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" { billingModel = mappedModel } + afterSwitch := antigravityHasAccountSwitch(ctx) + maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch) // 获取 access_token if s.tokenProvider == nil { @@ -1444,7 +1449,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co httpUpstream: s.httpUpstream, settingService: s.settingService, handleError: s.handleUpstreamError, - maxRetries: antigravityMaxRetriesForModel(originalModel), + maxRetries: maxRetries, }) if err != nil { return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") @@ -1641,6 +1646,16 @@ func antigravityUseScopeRateLimit() bool { return v == "1" || v == "true" || v == "yes" || v == "on" } +func antigravityHasAccountSwitch(ctx context.Context) bool { + if ctx == nil { + return false + } + if v, ok := ctx.Value(ctxkey.AccountSwitchCount).(int); ok { + return v > 0 + } + return false +} + func antigravityMaxRetries() int { raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesEnv)) if raw == "" { @@ -1653,9 +1668,21 @@ func antigravityMaxRetries() int { return value } +func antigravityMaxRetriesAfterSwitch() int { + raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesAfterSwitchEnv)) + if raw == "" { + return antigravityMaxRetries() + } + value, err := strconv.Atoi(raw) + if err != nil || value <= 0 { + return antigravityMaxRetries() + } + return value +} + // antigravityMaxRetriesForModel 根据模型类型获取重试次数 // 优先使用模型细分配置,未设置则回退到平台级配置 -func antigravityMaxRetriesForModel(model string) int { +func antigravityMaxRetriesForModel(model string, afterSwitch bool) int { var envKey string if strings.HasPrefix(model, "claude-") { envKey = antigravityMaxRetriesClaudeEnv @@ -1672,6 +1699,9 @@ func antigravityMaxRetriesForModel(model string) int { } } } + if afterSwitch { + return antigravityMaxRetriesAfterSwitch() + } return antigravityMaxRetries() } diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 9c1fb415..ffdcdc73 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -161,3 +161,28 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { require.Len(t, events, 1) require.Equal(t, "prompt_too_long", events[0].Kind) } + +func TestAntigravityMaxRetriesForModel_AfterSwitch(t *testing.T) { + t.Setenv(antigravityMaxRetriesEnv, "4") + t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "7") + t.Setenv(antigravityMaxRetriesClaudeEnv, "") + t.Setenv(antigravityMaxRetriesGeminiTextEnv, "") + t.Setenv(antigravityMaxRetriesGeminiImageEnv, "") + + got := antigravityMaxRetriesForModel("claude-sonnet-4-5", false) + require.Equal(t, 4, got) + + got = antigravityMaxRetriesForModel("claude-sonnet-4-5", true) + require.Equal(t, 7, got) +} + +func TestAntigravityMaxRetriesForModel_AfterSwitchFallback(t *testing.T) { + t.Setenv(antigravityMaxRetriesEnv, "5") + t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "") + t.Setenv(antigravityMaxRetriesClaudeEnv, "") + t.Setenv(antigravityMaxRetriesGeminiTextEnv, "") + t.Setenv(antigravityMaxRetriesGeminiImageEnv, "") + + got := antigravityMaxRetriesForModel("gemini-2.5-flash", true) + require.Equal(t, 5, got) +} diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index 8d98e43f..ffe4c934 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/gin-gonic/gin" "github.com/lib/pq" @@ -476,9 +477,13 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq continue } + attemptCtx := ctx + if switches > 0 { + attemptCtx = context.WithValue(attemptCtx, ctxkey.AccountSwitchCount, switches) + } exec := func() *opsRetryExecution { defer selection.ReleaseFunc() - return s.executeWithAccount(ctx, reqType, errorLog, body, account) + return s.executeWithAccount(attemptCtx, reqType, errorLog, body, account) }() if exec != nil {