antigravity: 区分切换后重试次数

This commit is contained in:
song
2026-01-28 00:01:03 +08:00
parent 877c17251d
commit f761afb1ef
6 changed files with 87 additions and 11 deletions

View File

@@ -276,10 +276,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
var result *service.ForwardResult var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity { if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body) result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body)
} else { } else {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) result, err = h.geminiCompatService.Forward(requestCtx, c, account, body)
} }
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
@@ -419,10 +423,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
var result *service.ForwardResult var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity { if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body) result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body)
} else { } else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq) result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq)
} }
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()

View File

@@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
@@ -288,10 +289,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 5) forward (根据平台分流) // 5) forward (根据平台分流)
var result *service.ForwardResult var result *service.ForwardResult
requestCtx := c.Request.Context()
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity { if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, modelName, action, stream, body) result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body)
} else { } else {
result, err = h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body) result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body)
} }
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()

View File

@@ -14,6 +14,9 @@ const (
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。 // RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
RetryCount Key = "ctx_retry_count" RetryCount Key = "ctx_retry_count"
// AccountSwitchCount 表示请求过程中发生的账号切换次数
AccountSwitchCount Key = "ctx_account_switch_count"
// IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端 // IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端
IsClaudeCodeClient Key = "ctx_is_claude_code_client" IsClaudeCodeClient Key = "ctx_is_claude_code_client"
// Group 认证后的分组信息,由 API Key 认证中间件设置 // Group 认证后的分组信息,由 API Key 认证中间件设置

View File

@@ -33,6 +33,7 @@ const (
const ( const (
antigravityMaxRetriesEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES" antigravityMaxRetriesEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES"
antigravityMaxRetriesAfterSwitchEnv = "GATEWAY_ANTIGRAVITY_AFTER_SWITCHMAX_RETRIES"
antigravityMaxRetriesClaudeEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_CLAUDE" antigravityMaxRetriesClaudeEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_CLAUDE"
antigravityMaxRetriesGeminiTextEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_TEXT" antigravityMaxRetriesGeminiTextEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_TEXT"
antigravityMaxRetriesGeminiImageEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_IMAGE" antigravityMaxRetriesGeminiImageEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_IMAGE"
@@ -745,6 +746,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" { if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" {
billingModel = mappedModel billingModel = mappedModel
} }
afterSwitch := antigravityHasAccountSwitch(ctx)
maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch)
// 获取 access_token // 获取 access_token
if s.tokenProvider == nil { if s.tokenProvider == nil {
@@ -793,7 +796,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
httpUpstream: s.httpUpstream, httpUpstream: s.httpUpstream,
settingService: s.settingService, settingService: s.settingService,
handleError: s.handleUpstreamError, handleError: s.handleUpstreamError,
maxRetries: antigravityMaxRetriesForModel(originalModel), maxRetries: maxRetries,
}) })
if err != nil { if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
@@ -870,7 +873,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
httpUpstream: s.httpUpstream, httpUpstream: s.httpUpstream,
settingService: s.settingService, settingService: s.settingService,
handleError: s.handleUpstreamError, handleError: s.handleUpstreamError,
maxRetries: antigravityMaxRetriesForModel(originalModel), maxRetries: maxRetries,
}) })
if retryErr != nil { if retryErr != nil {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
@@ -1387,6 +1390,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" { if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" {
billingModel = mappedModel billingModel = mappedModel
} }
afterSwitch := antigravityHasAccountSwitch(ctx)
maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch)
// 获取 access_token // 获取 access_token
if s.tokenProvider == nil { if s.tokenProvider == nil {
@@ -1444,7 +1449,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
httpUpstream: s.httpUpstream, httpUpstream: s.httpUpstream,
settingService: s.settingService, settingService: s.settingService,
handleError: s.handleUpstreamError, handleError: s.handleUpstreamError,
maxRetries: antigravityMaxRetriesForModel(originalModel), maxRetries: maxRetries,
}) })
if err != nil { if err != nil {
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
@@ -1641,6 +1646,16 @@ func antigravityUseScopeRateLimit() bool {
return v == "1" || v == "true" || v == "yes" || v == "on" return v == "1" || v == "true" || v == "yes" || v == "on"
} }
func antigravityHasAccountSwitch(ctx context.Context) bool {
if ctx == nil {
return false
}
if v, ok := ctx.Value(ctxkey.AccountSwitchCount).(int); ok {
return v > 0
}
return false
}
func antigravityMaxRetries() int { func antigravityMaxRetries() int {
raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesEnv)) raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesEnv))
if raw == "" { if raw == "" {
@@ -1653,9 +1668,21 @@ func antigravityMaxRetries() int {
return value return value
} }
func antigravityMaxRetriesAfterSwitch() int {
raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesAfterSwitchEnv))
if raw == "" {
return antigravityMaxRetries()
}
value, err := strconv.Atoi(raw)
if err != nil || value <= 0 {
return antigravityMaxRetries()
}
return value
}
// antigravityMaxRetriesForModel 根据模型类型获取重试次数 // antigravityMaxRetriesForModel 根据模型类型获取重试次数
// 优先使用模型细分配置,未设置则回退到平台级配置 // 优先使用模型细分配置,未设置则回退到平台级配置
func antigravityMaxRetriesForModel(model string) int { func antigravityMaxRetriesForModel(model string, afterSwitch bool) int {
var envKey string var envKey string
if strings.HasPrefix(model, "claude-") { if strings.HasPrefix(model, "claude-") {
envKey = antigravityMaxRetriesClaudeEnv envKey = antigravityMaxRetriesClaudeEnv
@@ -1672,6 +1699,9 @@ func antigravityMaxRetriesForModel(model string) int {
} }
} }
} }
if afterSwitch {
return antigravityMaxRetriesAfterSwitch()
}
return antigravityMaxRetries() return antigravityMaxRetries()
} }

View File

@@ -161,3 +161,28 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
require.Len(t, events, 1) require.Len(t, events, 1)
require.Equal(t, "prompt_too_long", events[0].Kind) require.Equal(t, "prompt_too_long", events[0].Kind)
} }
func TestAntigravityMaxRetriesForModel_AfterSwitch(t *testing.T) {
t.Setenv(antigravityMaxRetriesEnv, "4")
t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "7")
t.Setenv(antigravityMaxRetriesClaudeEnv, "")
t.Setenv(antigravityMaxRetriesGeminiTextEnv, "")
t.Setenv(antigravityMaxRetriesGeminiImageEnv, "")
got := antigravityMaxRetriesForModel("claude-sonnet-4-5", false)
require.Equal(t, 4, got)
got = antigravityMaxRetriesForModel("claude-sonnet-4-5", true)
require.Equal(t, 7, got)
}
func TestAntigravityMaxRetriesForModel_AfterSwitchFallback(t *testing.T) {
t.Setenv(antigravityMaxRetriesEnv, "5")
t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "")
t.Setenv(antigravityMaxRetriesClaudeEnv, "")
t.Setenv(antigravityMaxRetriesGeminiTextEnv, "")
t.Setenv(antigravityMaxRetriesGeminiImageEnv, "")
got := antigravityMaxRetriesForModel("gemini-2.5-flash", true)
require.Equal(t, 5, got)
}

View File

@@ -12,6 +12,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/lib/pq" "github.com/lib/pq"
@@ -476,9 +477,13 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq
continue continue
} }
attemptCtx := ctx
if switches > 0 {
attemptCtx = context.WithValue(attemptCtx, ctxkey.AccountSwitchCount, switches)
}
exec := func() *opsRetryExecution { exec := func() *opsRetryExecution {
defer selection.ReleaseFunc() defer selection.ReleaseFunc()
return s.executeWithAccount(ctx, reqType, errorLog, body, account) return s.executeWithAccount(attemptCtx, reqType, errorLog, body, account)
}() }()
if exec != nil { if exec != nil {