Merge branch 'Wei-Shaw:main' into main
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -40,6 +41,12 @@ const (
|
||||
antigravitySmartRetryMaxAttempts = 1 // 智能重试最大次数(仅重试 1 次,防止重复限流/长期等待)
|
||||
antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用)
|
||||
|
||||
// MODEL_CAPACITY_EXHAUSTED 专用重试参数
|
||||
// 模型容量不足时,所有账号共享同一容量池,切换账号无意义
|
||||
// 使用固定 1s 间隔重试,最多重试 60 次
|
||||
antigravityModelCapacityRetryMaxAttempts = 60
|
||||
antigravityModelCapacityRetryWait = 1 * time.Second
|
||||
|
||||
// Google RPC 状态和类型常量
|
||||
googleRPCStatusResourceExhausted = "RESOURCE_EXHAUSTED"
|
||||
googleRPCStatusUnavailable = "UNAVAILABLE"
|
||||
@@ -60,6 +67,9 @@ const (
|
||||
// 单账号 503 退避重试:原地重试的总累计等待时间上限
|
||||
// 超过此上限将不再重试,直接返回 503
|
||||
antigravitySingleAccountSmartRetryTotalMaxWait = 30 * time.Second
|
||||
|
||||
// MODEL_CAPACITY_EXHAUSTED 全局去重:重试全部失败后的 cooldown 时间
|
||||
antigravityModelCapacityCooldown = 10 * time.Second
|
||||
)
|
||||
|
||||
// antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写)
|
||||
@@ -68,8 +78,15 @@ var antigravityPassthroughErrorMessages = []string{
|
||||
"prompt is too long",
|
||||
}
|
||||
|
||||
// Global de-duplication state for MODEL_CAPACITY_EXHAUSTED: it keeps several
// concurrent requests from all running the capacity-exhausted retry loop for
// the same model at the same time.
var (
	// modelCapacityExhaustedMu guards modelCapacityExhaustedUntil.
	modelCapacityExhaustedMu sync.RWMutex
	// modelCapacityExhaustedUntil maps a model name to the instant until which
	// that model is considered to be in cooldown.
	modelCapacityExhaustedUntil = make(map[string]time.Time)
)
|
||||
|
||||
// Names of the environment variables that tune the Antigravity gateway.
const (
	// NOTE(review): presumably toggles billing by the mapped model name — confirm at the read site.
	antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
	// Selects which forward base URL is used; the value "prod" picks the
	// second configured URL, anything else the first.
	antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL"
	// NOTE(review): presumably a fallback cooldown expressed in seconds — confirm at the read site.
	antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
)
|
||||
|
||||
@@ -131,6 +148,20 @@ type antigravityRetryLoopResult struct {
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
// resolveAntigravityForwardBaseURL 解析转发用 base URL。
|
||||
// 默认使用 daily(ForwardBaseURLs 的首个地址);当环境变量为 prod 时使用第二个地址。
|
||||
func resolveAntigravityForwardBaseURL() string {
|
||||
baseURLs := antigravity.ForwardBaseURLs()
|
||||
if len(baseURLs) == 0 {
|
||||
return ""
|
||||
}
|
||||
mode := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityForwardBaseURLEnv)))
|
||||
if mode == "prod" && len(baseURLs) > 1 {
|
||||
return baseURLs[1]
|
||||
}
|
||||
return baseURLs[0]
|
||||
}
|
||||
|
||||
// smartRetryAction 智能重试的处理结果
|
||||
type smartRetryAction int
|
||||
|
||||
@@ -158,7 +189,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
}
|
||||
|
||||
// 判断是否触发智能重试
|
||||
shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName := shouldTriggerAntigravitySmartRetry(p.account, respBody)
|
||||
shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(p.account, respBody)
|
||||
|
||||
// 情况1: retryDelay >= 阈值,限流模型并切换账号
|
||||
if shouldRateLimitModel {
|
||||
@@ -195,20 +226,48 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
}
|
||||
}
|
||||
|
||||
// 情况2: retryDelay < 阈值,智能重试(最多 antigravitySmartRetryMaxAttempts 次)
|
||||
// 情况2: retryDelay < 阈值(或 MODEL_CAPACITY_EXHAUSTED),智能重试
|
||||
if shouldSmartRetry {
|
||||
var lastRetryResp *http.Response
|
||||
var lastRetryBody []byte
|
||||
|
||||
for attempt := 1; attempt <= antigravitySmartRetryMaxAttempts; attempt++ {
|
||||
log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d",
|
||||
p.prefix, resp.StatusCode, attempt, antigravitySmartRetryMaxAttempts, waitDuration, modelName, p.account.ID)
|
||||
// MODEL_CAPACITY_EXHAUSTED 使用独立的重试参数(60 次,固定 1s 间隔)
|
||||
maxAttempts := antigravitySmartRetryMaxAttempts
|
||||
if isModelCapacityExhausted {
|
||||
maxAttempts = antigravityModelCapacityRetryMaxAttempts
|
||||
waitDuration = antigravityModelCapacityRetryWait
|
||||
|
||||
// 全局去重:如果其他 goroutine 已在重试同一模型且尚在 cooldown 中,直接返回 503
|
||||
if modelName != "" {
|
||||
modelCapacityExhaustedMu.RLock()
|
||||
cooldownUntil, exists := modelCapacityExhaustedUntil[modelName]
|
||||
modelCapacityExhaustedMu.RUnlock()
|
||||
if exists && time.Now().Before(cooldownUntil) {
|
||||
log.Printf("%s status=%d model_capacity_exhausted_dedup model=%s account=%d cooldown_until=%v (skip retry)",
|
||||
p.prefix, resp.StatusCode, modelName, p.account.ID, cooldownUntil.Format("15:04:05"))
|
||||
return &smartRetryResult{
|
||||
action: smartRetryActionBreakWithResp,
|
||||
resp: &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d",
|
||||
p.prefix, resp.StatusCode, attempt, maxAttempts, waitDuration, modelName, p.account.ID)
|
||||
|
||||
timer := time.NewTimer(waitDuration)
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
timer.Stop()
|
||||
log.Printf("%s status=context_canceled_during_smart_retry", p.prefix)
|
||||
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
|
||||
case <-time.After(waitDuration):
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
// 智能重试:创建新请求
|
||||
@@ -228,13 +287,19 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
|
||||
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||||
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
|
||||
log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, antigravitySmartRetryMaxAttempts)
|
||||
log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, maxAttempts)
|
||||
// 重试成功,清除 MODEL_CAPACITY_EXHAUSTED cooldown
|
||||
if isModelCapacityExhausted && modelName != "" {
|
||||
modelCapacityExhaustedMu.Lock()
|
||||
delete(modelCapacityExhaustedUntil, modelName)
|
||||
modelCapacityExhaustedMu.Unlock()
|
||||
}
|
||||
return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp}
|
||||
}
|
||||
|
||||
// 网络错误时,继续重试
|
||||
if retryErr != nil || retryResp == nil {
|
||||
log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, antigravitySmartRetryMaxAttempts, retryErr)
|
||||
log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, maxAttempts, retryErr)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -244,13 +309,13 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
}
|
||||
lastRetryResp = retryResp
|
||||
if retryResp != nil {
|
||||
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
|
||||
_ = retryResp.Body.Close()
|
||||
}
|
||||
|
||||
// 解析新的重试信息,用于下次重试的等待时间
|
||||
if attempt < antigravitySmartRetryMaxAttempts && lastRetryBody != nil {
|
||||
newShouldRetry, _, newWaitDuration, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody)
|
||||
// 解析新的重试信息,用于下次重试的等待时间(MODEL_CAPACITY_EXHAUSTED 使用固定循环,跳过)
|
||||
if !isModelCapacityExhausted && attempt < maxAttempts && lastRetryBody != nil {
|
||||
newShouldRetry, _, newWaitDuration, _, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody)
|
||||
if newShouldRetry && newWaitDuration > 0 {
|
||||
waitDuration = newWaitDuration
|
||||
}
|
||||
@@ -267,6 +332,27 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
retryBody = respBody
|
||||
}
|
||||
|
||||
// MODEL_CAPACITY_EXHAUSTED:模型容量不足,切换账号无意义
|
||||
// 直接返回上游错误响应,不设置模型限流,不切换账号
|
||||
if isModelCapacityExhausted {
|
||||
// 设置 cooldown,让后续请求快速失败,避免重复重试
|
||||
if modelName != "" {
|
||||
modelCapacityExhaustedMu.Lock()
|
||||
modelCapacityExhaustedUntil[modelName] = time.Now().Add(antigravityModelCapacityCooldown)
|
||||
modelCapacityExhaustedMu.Unlock()
|
||||
}
|
||||
log.Printf("%s status=%d smart_retry_exhausted_model_capacity attempts=%d model=%s account=%d body=%s (model capacity exhausted, not switching account)",
|
||||
p.prefix, resp.StatusCode, maxAttempts, modelName, p.account.ID, truncateForLog(retryBody, 200))
|
||||
return &smartRetryResult{
|
||||
action: smartRetryActionBreakWithResp,
|
||||
resp: &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(retryBody)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 单账号 503 退避重试模式:智能重试耗尽后不设限流、不切换账号,
|
||||
// 直接返回 503 让 Handler 层的单账号退避循环做最终处理。
|
||||
if resp.StatusCode == http.StatusServiceUnavailable && isSingleAccountRetry(p.ctx) {
|
||||
@@ -283,7 +369,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
}
|
||||
|
||||
log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d upstream_retry_delay=%v body=%s (switch account)",
|
||||
p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, rateLimitDuration, truncateForLog(retryBody, 200))
|
||||
p.prefix, resp.StatusCode, maxAttempts, modelName, p.account.ID, rateLimitDuration, truncateForLog(retryBody, 200))
|
||||
|
||||
resetAt := time.Now().Add(rateLimitDuration)
|
||||
if p.accountRepo != nil && modelName != "" {
|
||||
@@ -367,11 +453,13 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
|
||||
log.Printf("%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d",
|
||||
p.prefix, resp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, waitDuration, totalWaited, modelName, p.account.ID)
|
||||
|
||||
timer := time.NewTimer(waitDuration)
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
timer.Stop()
|
||||
log.Printf("%s status=context_canceled_during_single_account_retry", p.prefix)
|
||||
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
|
||||
case <-time.After(waitDuration):
|
||||
case <-timer.C:
|
||||
}
|
||||
totalWaited += waitDuration
|
||||
|
||||
@@ -405,12 +493,12 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
|
||||
_ = lastRetryResp.Body.Close()
|
||||
}
|
||||
lastRetryResp = retryResp
|
||||
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
|
||||
_ = retryResp.Body.Close()
|
||||
|
||||
// 解析新的重试信息,更新下次等待时间
|
||||
if attempt < antigravitySingleAccountSmartRetryMaxAttempts && lastRetryBody != nil {
|
||||
_, _, newWaitDuration, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody)
|
||||
_, _, newWaitDuration, _, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody)
|
||||
if newWaitDuration > 0 {
|
||||
waitDuration = newWaitDuration
|
||||
if waitDuration > antigravitySingleAccountSmartRetryMaxWait {
|
||||
@@ -466,10 +554,11 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP
|
||||
}
|
||||
}
|
||||
|
||||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = antigravity.BaseURLs
|
||||
baseURL := resolveAntigravityForwardBaseURL()
|
||||
if baseURL == "" {
|
||||
return nil, errors.New("no antigravity forward base url configured")
|
||||
}
|
||||
availableURLs := []string{baseURL}
|
||||
|
||||
var resp *http.Response
|
||||
var usedBaseURL string
|
||||
@@ -907,11 +996,11 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// URL fallback 循环
|
||||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有
|
||||
baseURL := resolveAntigravityForwardBaseURL()
|
||||
if baseURL == "" {
|
||||
return nil, errors.New("no antigravity forward base url configured")
|
||||
}
|
||||
availableURLs := []string{baseURL}
|
||||
|
||||
var lastErr error
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
@@ -1376,7 +1465,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
break
|
||||
}
|
||||
|
||||
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
|
||||
_ = retryResp.Body.Close()
|
||||
if retryResp.StatusCode == http.StatusTooManyRequests {
|
||||
retryBaseURL := ""
|
||||
@@ -1457,6 +1546,27 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession)
|
||||
|
||||
// 精确匹配服务端配置类 400 错误,触发同账号重试 + failover
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||||
if isGoogleProjectConfigError(msg) {
|
||||
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||||
upstreamDetail := s.getUpstreamErrorDetail(respBody)
|
||||
log.Printf("%s status=400 google_config_error failover=true upstream_message=%q account=%d", prefix, upstreamMsg, account.ID)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: true}
|
||||
}
|
||||
}
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
@@ -1997,6 +2107,22 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
// Always record upstream context for Ops error logs, even when we will failover.
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
|
||||
// 精确匹配服务端配置类 400 错误,触发同账号重试 + failover
|
||||
if resp.StatusCode == http.StatusBadRequest && isGoogleProjectConfigError(strings.ToLower(upstreamMsg)) {
|
||||
log.Printf("%s status=400 google_config_error failover=true upstream_message=%q account=%d", prefix, upstreamMsg, account.ID)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: requestID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps, RetryableOnSameAccount: true}
|
||||
}
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
@@ -2092,6 +2218,44 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int)
|
||||
}
|
||||
}
|
||||
|
||||
// isGoogleProjectConfigError reports whether an already-extracted error
// message belongs to the known class of Google server-side configuration
// failures.
//
// lowerMsg must already be lower-cased by the caller; the match is a plain
// case-sensitive substring test. Only the exact known phrase is matched so
// that ordinary client-side request errors are not retried pointlessly. The
// check applies to every platform that goes through the Google backend
// (Antigravity, Gemini).
func isGoogleProjectConfigError(lowerMsg string) bool {
	// Intermittent Google bug: a valid project ID temporarily fails to be recognised.
	return strings.Contains(lowerMsg, "invalid project resource name")
}
|
||||
|
||||
// googleConfigErrorCooldown 服务端配置类 400 错误的临时封禁时长
|
||||
const googleConfigErrorCooldown = 1 * time.Minute
|
||||
|
||||
// tempUnscheduleGoogleConfigError 对服务端配置类 400 错误触发临时封禁,
|
||||
// 避免短时间内反复调度到同一个有问题的账号。
|
||||
func tempUnscheduleGoogleConfigError(ctx context.Context, repo AccountRepository, accountID int64, logPrefix string) {
|
||||
until := time.Now().Add(googleConfigErrorCooldown)
|
||||
reason := "400: invalid project resource name (auto temp-unschedule 1m)"
|
||||
if err := repo.SetTempUnschedulable(ctx, accountID, until, reason); err != nil {
|
||||
log.Printf("%s temp_unschedule_failed account=%d error=%v", logPrefix, accountID, err)
|
||||
} else {
|
||||
log.Printf("%s temp_unscheduled account=%d until=%v reason=%q", logPrefix, accountID, until.Format("15:04:05"), reason)
|
||||
}
|
||||
}
|
||||
|
||||
// emptyResponseCooldown 空流式响应的临时封禁时长
|
||||
const emptyResponseCooldown = 1 * time.Minute
|
||||
|
||||
// tempUnscheduleEmptyResponse 对空流式响应触发临时封禁,
|
||||
// 避免短时间内反复调度到同一个返回空响应的账号。
|
||||
func tempUnscheduleEmptyResponse(ctx context.Context, repo AccountRepository, accountID int64, logPrefix string) {
|
||||
until := time.Now().Add(emptyResponseCooldown)
|
||||
reason := "empty stream response (auto temp-unschedule 1m)"
|
||||
if err := repo.SetTempUnschedulable(ctx, accountID, until, reason); err != nil {
|
||||
log.Printf("%s temp_unschedule_failed account=%d error=%v", logPrefix, accountID, err)
|
||||
} else {
|
||||
log.Printf("%s temp_unscheduled account=%d until=%v reason=%q", logPrefix, accountID, until.Format("15:04:05"), reason)
|
||||
}
|
||||
}
|
||||
|
||||
// sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
|
||||
// 返回 true 表示正常完成等待,false 表示 context 已取消
|
||||
func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
|
||||
@@ -2108,10 +2272,12 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
|
||||
sleepFor = 0
|
||||
}
|
||||
|
||||
timer := time.NewTimer(sleepFor)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return false
|
||||
case <-time.After(sleepFor):
|
||||
case <-timer.C:
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -2156,8 +2322,9 @@ func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
|
||||
|
||||
// antigravitySmartRetryInfo carries the information needed to decide on a
// smart retry after an upstream rate-limit or capacity error.
type antigravitySmartRetryInfo struct {
	RetryDelay               time.Duration // delay to wait before retrying
	ModelName                string        // name of the limited model (e.g. "claude-sonnet-4-5")
	IsModelCapacityExhausted bool          // true when the error is MODEL_CAPACITY_EXHAUSTED
}
|
||||
|
||||
// parseAntigravitySmartRetryInfo 解析 Google RPC RetryInfo 和 ErrorInfo 信息
|
||||
@@ -2272,31 +2439,40 @@ func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo {
|
||||
}
|
||||
|
||||
return &antigravitySmartRetryInfo{
|
||||
RetryDelay: retryDelay,
|
||||
ModelName: modelName,
|
||||
RetryDelay: retryDelay,
|
||||
ModelName: modelName,
|
||||
IsModelCapacityExhausted: hasModelCapacityExhausted,
|
||||
}
|
||||
}
|
||||
|
||||
// shouldTriggerAntigravitySmartRetry 判断是否应该触发智能重试
|
||||
// 返回:
|
||||
// - shouldRetry: 是否应该智能重试(retryDelay < antigravityRateLimitThreshold)
|
||||
// - shouldRateLimitModel: 是否应该限流模型(retryDelay >= antigravityRateLimitThreshold)
|
||||
// - waitDuration: 等待时间(智能重试时使用,shouldRateLimitModel=true 时为 0)
|
||||
// - shouldRetry: 是否应该智能重试(retryDelay < antigravityRateLimitThreshold,或 MODEL_CAPACITY_EXHAUSTED)
|
||||
// - shouldRateLimitModel: 是否应该限流模型并切换账号(仅 RATE_LIMIT_EXCEEDED 且 retryDelay >= 阈值)
|
||||
// - waitDuration: 等待时间
|
||||
// - modelName: 限流的模型名称
|
||||
func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string) {
|
||||
// - isModelCapacityExhausted: 是否为模型容量不足(MODEL_CAPACITY_EXHAUSTED)
|
||||
func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string, isModelCapacityExhausted bool) {
|
||||
if account.Platform != PlatformAntigravity {
|
||||
return false, false, 0, ""
|
||||
return false, false, 0, "", false
|
||||
}
|
||||
|
||||
info := parseAntigravitySmartRetryInfo(respBody)
|
||||
if info == nil {
|
||||
return false, false, 0, ""
|
||||
return false, false, 0, "", false
|
||||
}
|
||||
|
||||
// MODEL_CAPACITY_EXHAUSTED(模型容量不足):所有账号共享同一模型容量池
|
||||
// 切换账号无意义,使用固定 1s 间隔重试
|
||||
if info.IsModelCapacityExhausted {
|
||||
return true, false, antigravityModelCapacityRetryWait, info.ModelName, true
|
||||
}
|
||||
|
||||
// RATE_LIMIT_EXCEEDED(账号级限流):
|
||||
// retryDelay >= 阈值:直接限流模型,不重试
|
||||
// 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 30s
|
||||
if info.RetryDelay >= antigravityRateLimitThreshold {
|
||||
return false, true, info.RetryDelay, info.ModelName
|
||||
return false, true, info.RetryDelay, info.ModelName, false
|
||||
}
|
||||
|
||||
// retryDelay < 阈值:智能重试
|
||||
@@ -2305,7 +2481,7 @@ func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shou
|
||||
waitDuration = antigravitySmartRetryMinWait
|
||||
}
|
||||
|
||||
return true, false, waitDuration, info.ModelName
|
||||
return true, false, waitDuration, info.ModelName, false
|
||||
}
|
||||
|
||||
// handleModelRateLimitParams 模型级限流处理参数
|
||||
@@ -2331,8 +2507,9 @@ type handleModelRateLimitResult struct {
|
||||
|
||||
// handleModelRateLimit 处理模型级限流(在原有逻辑之前调用)
|
||||
// 仅处理 429/503,解析模型名和 retryDelay
|
||||
// - retryDelay < antigravityRateLimitThreshold: 返回 ShouldRetry=true,由调用方等待后重试
|
||||
// - retryDelay >= antigravityRateLimitThreshold: 设置模型限流 + 清除粘性会话 + 返回 SwitchError
|
||||
// - MODEL_CAPACITY_EXHAUSTED: 返回 Handled=true(实际重试由 handleSmartRetry 处理)
|
||||
// - RATE_LIMIT_EXCEEDED + retryDelay < 阈值: 返回 ShouldRetry=true,由调用方等待后重试
|
||||
// - RATE_LIMIT_EXCEEDED + retryDelay >= 阈值: 设置模型限流 + 清除粘性会话 + 返回 SwitchError
|
||||
func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimitParams) *handleModelRateLimitResult {
|
||||
if p.statusCode != 429 && p.statusCode != 503 {
|
||||
return &handleModelRateLimitResult{Handled: false}
|
||||
@@ -2343,7 +2520,17 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit
|
||||
return &handleModelRateLimitResult{Handled: false}
|
||||
}
|
||||
|
||||
// < antigravityRateLimitThreshold: 等待后重试
|
||||
// MODEL_CAPACITY_EXHAUSTED:模型容量不足,所有账号共享同一容量池
|
||||
// 切换账号无意义,不设置模型限流(实际重试由 handleSmartRetry 处理)
|
||||
if info.IsModelCapacityExhausted {
|
||||
log.Printf("%s status=%d model_capacity_exhausted model=%s (not switching account, retry handled by smart retry)",
|
||||
p.prefix, p.statusCode, info.ModelName)
|
||||
return &handleModelRateLimitResult{
|
||||
Handled: true,
|
||||
}
|
||||
}
|
||||
|
||||
// RATE_LIMIT_EXCEEDED: < antigravityRateLimitThreshold: 等待后重试
|
||||
if info.RetryDelay < antigravityRateLimitThreshold {
|
||||
log.Printf("%s status=%d model_rate_limit_wait model=%s wait=%v",
|
||||
p.prefix, p.statusCode, info.ModelName, info.RetryDelay)
|
||||
@@ -2354,7 +2541,7 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit
|
||||
}
|
||||
}
|
||||
|
||||
// >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号
|
||||
// RATE_LIMIT_EXCEEDED: >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号
|
||||
s.setModelRateLimitAndClearSession(p, info)
|
||||
|
||||
return &handleModelRateLimitResult{
|
||||
@@ -2906,9 +3093,14 @@ returnResponse:
|
||||
// 选择最后一个有效响应
|
||||
finalResponse := pickGeminiCollectResult(last, lastWithParts)
|
||||
|
||||
// 处理空响应情况
|
||||
// 处理空响应情况 — 触发同账号重试 + failover 切换账号
|
||||
if last == nil && lastWithParts == nil {
|
||||
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
|
||||
log.Printf("[antigravity-Forward] warning: empty stream response (gemini non-stream), triggering failover")
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
ResponseBody: []byte(`{"error":"empty stream response from upstream"}`),
|
||||
RetryableOnSameAccount: true,
|
||||
}
|
||||
}
|
||||
|
||||
// 如果收集到了图片 parts,需要合并到最终响应中
|
||||
@@ -3126,6 +3318,21 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
|
||||
log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes))
|
||||
}
|
||||
|
||||
// 检查错误透传规则
|
||||
if ptStatus, ptErrType, ptErrMsg, matched := applyErrorPassthroughRule(
|
||||
c, account.Platform, upstreamStatus, body,
|
||||
0, "", "",
|
||||
); matched {
|
||||
c.JSON(ptStatus, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{"type": ptErrType, "message": ptErrMsg},
|
||||
})
|
||||
if upstreamMsg == "" {
|
||||
return fmt.Errorf("upstream error: %d", upstreamStatus)
|
||||
}
|
||||
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
|
||||
}
|
||||
|
||||
var statusCode int
|
||||
var errType, errMsg string
|
||||
|
||||
@@ -3323,10 +3530,14 @@ returnResponse:
|
||||
// 选择最后一个有效响应
|
||||
finalResponse := pickGeminiCollectResult(last, lastWithParts)
|
||||
|
||||
// 处理空响应情况
|
||||
// 处理空响应情况 — 触发同账号重试 + failover 切换账号
|
||||
if last == nil && lastWithParts == nil {
|
||||
log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received")
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
|
||||
log.Printf("[antigravity-Forward] warning: empty stream response (claude non-stream), triggering failover")
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
ResponseBody: []byte(`{"error":"empty stream response from upstream"}`),
|
||||
RetryableOnSameAccount: true,
|
||||
}
|
||||
}
|
||||
|
||||
// 将收集的所有 parts 合并到最终响应中
|
||||
|
||||
@@ -273,12 +273,21 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
|
||||
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||
return loadResp.CloudAICompanionProject, nil
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" {
|
||||
return projectID, nil
|
||||
} else if onboardErr != nil {
|
||||
lastErr = onboardErr
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 记录错误
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
@@ -292,6 +301,65 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
|
||||
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
func tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) {
|
||||
tierID := resolveDefaultTierID(loadRaw)
|
||||
if tierID == "" {
|
||||
return "", fmt.Errorf("loadCodeAssist 未返回可用的默认 tier")
|
||||
}
|
||||
|
||||
projectID, err := client.OnboardUser(ctx, accessToken, tierID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("onboardUser 失败 (tier=%s): %w", tierID, err)
|
||||
}
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
// resolveDefaultTierID extracts the ID of the default tier from a decoded
// loadCodeAssist response.
//
// It looks at loadRaw["allowedTiers"], which must be a []any of
// map[string]any entries. The first entry whose "isDefault" is the boolean
// true and whose "id" is a non-blank string wins; the ID is returned with
// surrounding whitespace trimmed. Entries of the wrong shape, non-default
// entries, and default entries without a usable ID are skipped. It returns ""
// when loadRaw is nil or empty, when "allowedTiers" is missing or of the wrong
// type, or when no entry qualifies.
func resolveDefaultTierID(loadRaw map[string]any) string {
	if len(loadRaw) == 0 {
		return ""
	}

	rawTiers, ok := loadRaw["allowedTiers"]
	if !ok {
		return ""
	}

	tiers, ok := rawTiers.([]any)
	if !ok {
		return ""
	}

	for _, rawTier := range tiers {
		tier, ok := rawTier.(map[string]any)
		if !ok {
			continue
		}
		// A missing or non-bool "isDefault" counts as false.
		if isDefault, _ := tier["isDefault"].(bool); !isDefault {
			continue
		}
		if id, ok := tier["id"].(string); ok {
			id = strings.TrimSpace(id)
			if id != "" {
				return id
			}
		}
	}

	return ""
}
|
||||
|
||||
// FillProjectID 仅获取 project_id,不刷新 OAuth token
|
||||
func (s *AntigravityOAuthService) FillProjectID(ctx context.Context, account *Account, accessToken string) (string, error) {
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
return s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3)
|
||||
}
|
||||
|
||||
// BuildAccountCredentials 构建账户凭证
|
||||
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
|
||||
creds := map[string]any{
|
||||
|
||||
82
backend/internal/service/antigravity_oauth_service_test.go
Normal file
82
backend/internal/service/antigravity_oauth_service_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResolveDefaultTierID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
loadRaw map[string]any
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil loadRaw",
|
||||
loadRaw: nil,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "missing allowedTiers",
|
||||
loadRaw: map[string]any{
|
||||
"paidTier": map[string]any{"id": "g1-pro-tier"},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty allowedTiers",
|
||||
loadRaw: map[string]any{"allowedTiers": []any{}},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "tier missing id field",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"isDefault": true},
|
||||
},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "allowedTiers but no default",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"id": "free-tier", "isDefault": false},
|
||||
map[string]any{"id": "standard-tier", "isDefault": false},
|
||||
},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "default tier found",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"id": "free-tier", "isDefault": true},
|
||||
map[string]any{"id": "standard-tier", "isDefault": false},
|
||||
},
|
||||
},
|
||||
want: "free-tier",
|
||||
},
|
||||
{
|
||||
name: "default tier id with spaces",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"id": " standard-tier ", "isDefault": true},
|
||||
},
|
||||
},
|
||||
want: "standard-tier",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := resolveDefaultTierID(tc.loadRaw)
|
||||
if got != tc.want {
|
||||
t.Fatalf("resolveDefaultTierID() = %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -86,7 +86,9 @@ func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id i
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
||||
func TestAntigravityRetryLoop_NoURLFallback_UsesConfiguredBaseURL(t *testing.T) {
|
||||
t.Setenv(antigravityForwardBaseURLEnv, "")
|
||||
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
oldAvailability := antigravity.DefaultURLAvailability
|
||||
defer func() {
|
||||
@@ -131,15 +133,16 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.resp)
|
||||
defer func() { _ = result.resp.Body.Close() }()
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
require.False(t, handleErrorCalled)
|
||||
require.Len(t, upstream.calls, 2)
|
||||
require.True(t, strings.HasPrefix(upstream.calls[0], base1))
|
||||
require.True(t, strings.HasPrefix(upstream.calls[1], base2))
|
||||
require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode)
|
||||
require.True(t, handleErrorCalled)
|
||||
require.Len(t, upstream.calls, antigravityMaxRetries)
|
||||
for _, callURL := range upstream.calls {
|
||||
require.True(t, strings.HasPrefix(callURL, base1))
|
||||
}
|
||||
|
||||
available := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||
require.NotEmpty(t, available)
|
||||
require.Equal(t, base2, available[0])
|
||||
require.Equal(t, base1, available[0])
|
||||
}
|
||||
|
||||
// TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景
|
||||
@@ -188,13 +191,14 @@ func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
|
||||
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景
|
||||
func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
|
||||
// TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景
|
||||
// MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号
|
||||
func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 3, Name: "acc-3", Platform: PlatformAntigravity}
|
||||
|
||||
// 503 + MODEL_CAPACITY_EXHAUSTED → 模型限流
|
||||
// 503 + MODEL_CAPACITY_EXHAUSTED → 等待重试,不切换账号
|
||||
body := []byte(`{
|
||||
"error": {
|
||||
"status": "UNAVAILABLE",
|
||||
@@ -207,13 +211,13 @@ func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) {
|
||||
|
||||
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false)
|
||||
|
||||
// 应该触发模型限流
|
||||
// MODEL_CAPACITY_EXHAUSTED 应该标记为已处理,不切换账号,不设置模型限流
|
||||
// 实际重试由 handleSmartRetry 处理
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Handled)
|
||||
require.NotNil(t, result.SwitchError)
|
||||
require.Equal(t, "gemini-3-pro-high", result.SwitchError.RateLimitedModel)
|
||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||
require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey)
|
||||
require.False(t, result.ShouldRetry, "MODEL_CAPACITY_EXHAUSTED should not trigger retry from handleModelRateLimit path")
|
||||
require.Nil(t, result.SwitchError, "MODEL_CAPACITY_EXHAUSTED should not trigger account switch")
|
||||
require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit")
|
||||
}
|
||||
|
||||
// TestHandleUpstreamError_503_NonModelRateLimit 测试 503 非模型限流场景(不处理)
|
||||
@@ -301,11 +305,12 @@ func TestParseGeminiRateLimitResetTime_QuotaResetDelay_RoundsUp(t *testing.T) {
|
||||
|
||||
func TestParseAntigravitySmartRetryInfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expectedDelay time.Duration
|
||||
expectedModel string
|
||||
expectedNil bool
|
||||
name string
|
||||
body string
|
||||
expectedDelay time.Duration
|
||||
expectedModel string
|
||||
expectedNil bool
|
||||
expectedIsModelCapacityExhausted bool
|
||||
}{
|
||||
{
|
||||
name: "valid complete response with RATE_LIMIT_EXCEEDED",
|
||||
@@ -368,8 +373,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) {
|
||||
"message": "No capacity available for model gemini-3-pro-high on the server"
|
||||
}
|
||||
}`,
|
||||
expectedDelay: 39 * time.Second,
|
||||
expectedModel: "gemini-3-pro-high",
|
||||
expectedDelay: 39 * time.Second,
|
||||
expectedModel: "gemini-3-pro-high",
|
||||
expectedIsModelCapacityExhausted: true,
|
||||
},
|
||||
{
|
||||
name: "503 UNAVAILABLE without MODEL_CAPACITY_EXHAUSTED - should return nil",
|
||||
@@ -480,6 +486,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) {
|
||||
if result.ModelName != tt.expectedModel {
|
||||
t.Errorf("ModelName = %q, want %q", result.ModelName, tt.expectedModel)
|
||||
}
|
||||
if result.IsModelCapacityExhausted != tt.expectedIsModelCapacityExhausted {
|
||||
t.Errorf("IsModelCapacityExhausted = %v, want %v", result.IsModelCapacityExhausted, tt.expectedIsModelCapacityExhausted)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -491,13 +500,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
||||
apiKeyAccount := &Account{Type: AccountTypeAPIKey}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
body string
|
||||
expectedShouldRetry bool
|
||||
expectedShouldRateLimit bool
|
||||
minWait time.Duration
|
||||
modelName string
|
||||
name string
|
||||
account *Account
|
||||
body string
|
||||
expectedShouldRetry bool
|
||||
expectedShouldRateLimit bool
|
||||
expectedIsModelCapacityExhausted bool
|
||||
minWait time.Duration
|
||||
modelName string
|
||||
}{
|
||||
{
|
||||
name: "OAuth account with short delay (< 7s) - smart retry",
|
||||
@@ -611,13 +621,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
||||
]
|
||||
}
|
||||
}`,
|
||||
expectedShouldRetry: false,
|
||||
expectedShouldRateLimit: true,
|
||||
minWait: 39 * time.Second,
|
||||
modelName: "gemini-3-pro-high",
|
||||
expectedShouldRetry: true,
|
||||
expectedShouldRateLimit: false,
|
||||
expectedIsModelCapacityExhausted: true,
|
||||
minWait: 1 * time.Second,
|
||||
modelName: "gemini-3-pro-high",
|
||||
},
|
||||
{
|
||||
name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use default rate limit",
|
||||
name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use fixed wait",
|
||||
account: oauthAccount,
|
||||
body: `{
|
||||
"error": {
|
||||
@@ -629,10 +640,11 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
||||
"message": "No capacity available for model gemini-2.5-flash on the server"
|
||||
}
|
||||
}`,
|
||||
expectedShouldRetry: false,
|
||||
expectedShouldRateLimit: true,
|
||||
minWait: 30 * time.Second,
|
||||
modelName: "gemini-2.5-flash",
|
||||
expectedShouldRetry: true,
|
||||
expectedShouldRateLimit: false,
|
||||
expectedIsModelCapacityExhausted: true,
|
||||
minWait: 1 * time.Second,
|
||||
modelName: "gemini-2.5-flash",
|
||||
},
|
||||
{
|
||||
name: "429 RESOURCE_EXHAUSTED with RATE_LIMIT_EXCEEDED - no retryDelay - use default rate limit",
|
||||
@@ -656,13 +668,16 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldRetry, shouldRateLimit, wait, model := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body))
|
||||
shouldRetry, shouldRateLimit, wait, model, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body))
|
||||
if shouldRetry != tt.expectedShouldRetry {
|
||||
t.Errorf("shouldRetry = %v, want %v", shouldRetry, tt.expectedShouldRetry)
|
||||
}
|
||||
if shouldRateLimit != tt.expectedShouldRateLimit {
|
||||
t.Errorf("shouldRateLimit = %v, want %v", shouldRateLimit, tt.expectedShouldRateLimit)
|
||||
}
|
||||
if isModelCapacityExhausted != tt.expectedIsModelCapacityExhausted {
|
||||
t.Errorf("isModelCapacityExhausted = %v, want %v", isModelCapacityExhausted, tt.expectedIsModelCapacityExhausted)
|
||||
}
|
||||
if shouldRetry {
|
||||
if wait < tt.minWait {
|
||||
t.Errorf("wait = %v, want >= %v", wait, tt.minWait)
|
||||
@@ -915,6 +930,22 @@ func TestIsAntigravityAccountSwitchError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAntigravityForwardBaseURL_DefaultDaily(t *testing.T) {
|
||||
t.Setenv(antigravityForwardBaseURLEnv, "")
|
||||
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
defer func() {
|
||||
antigravity.BaseURLs = oldBaseURLs
|
||||
}()
|
||||
|
||||
prodURL := "https://prod.test"
|
||||
dailyURL := "https://daily.test"
|
||||
antigravity.BaseURLs = []string{dailyURL, prodURL}
|
||||
|
||||
resolved := resolveAntigravityForwardBaseURL()
|
||||
require.Equal(t, dailyURL, resolved)
|
||||
}
|
||||
|
||||
func TestAntigravityAccountSwitchError_Error(t *testing.T) {
|
||||
err := &AntigravityAccountSwitchError{
|
||||
OriginalAccountID: 789,
|
||||
|
||||
@@ -153,13 +153,14 @@ func TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches(t *te
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
// 503 + 39s >= 7s 阈值
|
||||
// 503 + 39s >= 7s 阈值(使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED,
|
||||
// 因为 MODEL_CAPACITY_EXHAUSTED 走独立的重试路径,不触发 shouldRateLimitModel)
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
|
||||
]
|
||||
}
|
||||
@@ -339,13 +340,14 @@ func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testi
|
||||
|
||||
// TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit
|
||||
// 对照组:503 + retryDelay < 7s + 无 SingleAccountRetry → 智能重试耗尽后照常设限流
|
||||
// 使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED,因为后者走独立的 60 次重试路径
|
||||
func TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit(t *testing.T) {
|
||||
failRespBody := `{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
@@ -371,9 +373,9 @@ func TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit(t *t
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -294,8 +294,9 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test
|
||||
require.Len(t, upstream.calls, 1, "should have made one retry call (max attempts)")
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError 测试 503 MODEL_CAPACITY_EXHAUSTED 返回 switchError
|
||||
func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testing.T) {
|
||||
// TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess 测试 503 MODEL_CAPACITY_EXHAUSTED 重试成功
|
||||
// MODEL_CAPACITY_EXHAUSTED 使用固定 1s 间隔重试,不切换账号
|
||||
func TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 3,
|
||||
@@ -304,7 +305,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
// 503 + MODEL_CAPACITY_EXHAUSTED + 39s >= 7s 阈值
|
||||
// 503 + MODEL_CAPACITY_EXHAUSTED + 39s(上游 retryDelay 应被忽略,使用固定 1s)
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
@@ -322,6 +323,14 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
// mock: 第 1 次重试返回 200 成功
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{
|
||||
{StatusCode: http.StatusOK, Header: http.Header{}, Body: io.NopCloser(strings.NewReader(`{"ok":true}`))},
|
||||
},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
@@ -330,6 +339,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
accountRepo: repo,
|
||||
httpUpstream: upstream,
|
||||
isStickySession: true,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
@@ -343,16 +353,67 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.Nil(t, result.resp)
|
||||
require.NotNil(t, result.resp, "should return successful response")
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
require.Nil(t, result.err)
|
||||
require.NotNil(t, result.switchError, "should return switchError for 503 model capacity exhausted")
|
||||
require.Equal(t, account.ID, result.switchError.OriginalAccountID)
|
||||
require.Equal(t, "gemini-3-pro-high", result.switchError.RateLimitedModel)
|
||||
require.True(t, result.switchError.IsStickySession)
|
||||
require.Nil(t, result.switchError, "MODEL_CAPACITY_EXHAUSTED should not return switchError")
|
||||
|
||||
// 验证模型限流已设置
|
||||
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||
require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey)
|
||||
// 不应设置模型限流
|
||||
require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit")
|
||||
require.Len(t, upstream.calls, 1, "should have made one retry call before success")
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel 测试 MODEL_CAPACITY_EXHAUSTED 上下文取消
|
||||
func TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 3,
|
||||
Name: "acc-3",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
// 立即取消上下文,验证重试循环能正确退出
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctx,
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.Error(t, result.err, "should return context error")
|
||||
require.Nil(t, result.switchError, "should not return switchError on context cancel")
|
||||
require.Empty(t, repo.modelRateLimitCalls, "should not set model rate limit on context cancel")
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic 测试非 Antigravity 平台账号走默认逻辑
|
||||
@@ -1129,20 +1190,20 @@ func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
|
||||
// 503 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定
|
||||
// 429 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定
|
||||
func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession(t *testing.T) {
|
||||
failRespBody := `{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"code": 429,
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
|
||||
]
|
||||
}
|
||||
}`
|
||||
failResp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
||||
}
|
||||
@@ -1162,16 +1223,16 @@ func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession
|
||||
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"code": 429,
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
@@ -7,12 +7,14 @@ import (
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityTokenRefreshSkew = 3 * time.Minute
|
||||
antigravityTokenCacheSkew = 5 * time.Minute
|
||||
antigravityBackfillCooldown = 5 * time.Minute
|
||||
)
|
||||
|
||||
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||
@@ -23,6 +25,7 @@ type AntigravityTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache AntigravityTokenCache
|
||||
antigravityOAuthService *AntigravityOAuthService
|
||||
backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time
|
||||
}
|
||||
|
||||
func NewAntigravityTokenProvider(
|
||||
@@ -93,13 +96,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
p.mergeCredentials(account, tokenInfo)
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
|
||||
}
|
||||
@@ -113,6 +110,21 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 如果账号还没有 project_id,尝试在线补齐,避免请求 daily/sandbox 时出现
|
||||
// "Invalid project resource name projects/"。
|
||||
// 仅调用 loadProjectIDWithRetry,不刷新 OAuth token;带冷却机制防止频繁重试。
|
||||
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
|
||||
if p.shouldAttemptBackfill(account.ID) {
|
||||
p.markBackfillAttempted(account.ID)
|
||||
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
|
||||
account.Credentials["project_id"] = projectID
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
log.Printf("[AntigravityTokenProvider] project_id 补齐持久化失败: %v", updateErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
if p.tokenCache != nil {
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
@@ -144,6 +156,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段
|
||||
func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) {
|
||||
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
}
|
||||
|
||||
// shouldAttemptBackfill 检查是否应该尝试补齐 project_id(冷却期内不重复尝试)
|
||||
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
|
||||
if v, ok := p.backfillCooldown.Load(accountID); ok {
|
||||
if lastAttempt, ok := v.(time.Time); ok {
|
||||
return time.Since(lastAttempt) > antigravityBackfillCooldown
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *AntigravityTokenProvider) markBackfillAttempted(accountID int64) {
|
||||
p.backfillCooldown.Store(accountID, time.Now())
|
||||
}
|
||||
|
||||
func AntigravityTokenCacheKey(account *Account) string {
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID != "" {
|
||||
|
||||
@@ -61,6 +61,11 @@ func applyErrorPassthroughRule(
|
||||
errMsg = *rule.CustomMessage
|
||||
}
|
||||
|
||||
// 命中 skip_monitoring 时在 context 中标记,供 ops_error_logger 跳过记录。
|
||||
if rule.SkipMonitoring {
|
||||
c.Set(OpsSkipPassthroughKey, true)
|
||||
}
|
||||
|
||||
// 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。
|
||||
errType = "upstream_error"
|
||||
return status, errType, errMsg, true
|
||||
|
||||
@@ -194,6 +194,63 @@ func TestGeminiWriteGeminiMappedError_AppliesRuleFor422(t *testing.T) {
|
||||
assert.Equal(t, "Gemini上游失败", errField["message"])
|
||||
}
|
||||
|
||||
func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
rule := newNonFailoverPassthroughRule(http.StatusBadRequest, "prompt is too long", http.StatusBadRequest, "上下文超限")
|
||||
rule.SkipMonitoring = true
|
||||
|
||||
ruleSvc := &ErrorPassthroughService{}
|
||||
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{rule})
|
||||
BindErrorPassthroughService(c, ruleSvc)
|
||||
|
||||
_, _, _, matched := applyErrorPassthroughRule(
|
||||
c,
|
||||
PlatformAnthropic,
|
||||
http.StatusBadRequest,
|
||||
[]byte(`{"error":{"message":"prompt is too long"}}`),
|
||||
http.StatusBadGateway,
|
||||
"upstream_error",
|
||||
"Upstream request failed",
|
||||
)
|
||||
|
||||
assert.True(t, matched)
|
||||
v, exists := c.Get(OpsSkipPassthroughKey)
|
||||
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
|
||||
boolVal, ok := v.(bool)
|
||||
assert.True(t, ok, "value should be bool")
|
||||
assert.True(t, boolVal)
|
||||
}
|
||||
|
||||
func TestApplyErrorPassthroughRule_NoSkipMonitoringDoesNotSetContextKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
rule := newNonFailoverPassthroughRule(http.StatusBadRequest, "prompt is too long", http.StatusBadRequest, "上下文超限")
|
||||
rule.SkipMonitoring = false
|
||||
|
||||
ruleSvc := &ErrorPassthroughService{}
|
||||
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{rule})
|
||||
BindErrorPassthroughService(c, ruleSvc)
|
||||
|
||||
_, _, _, matched := applyErrorPassthroughRule(
|
||||
c,
|
||||
PlatformAnthropic,
|
||||
http.StatusBadRequest,
|
||||
[]byte(`{"error":{"message":"prompt is too long"}}`),
|
||||
http.StatusBadGateway,
|
||||
"upstream_error",
|
||||
"Upstream request failed",
|
||||
)
|
||||
|
||||
assert.True(t, matched)
|
||||
_, exists := c.Get(OpsSkipPassthroughKey)
|
||||
assert.False(t, exists, "OpsSkipPassthroughKey should NOT be set when skip_monitoring=false")
|
||||
}
|
||||
|
||||
func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule {
|
||||
return &model.ErrorPassthroughRule{
|
||||
ID: 1,
|
||||
|
||||
@@ -45,10 +45,20 @@ type ErrorPassthroughService struct {
|
||||
cache ErrorPassthroughCache
|
||||
|
||||
// 本地内存缓存,用于快速匹配
|
||||
localCache []*model.ErrorPassthroughRule
|
||||
localCache []*cachedPassthroughRule
|
||||
localCacheMu sync.RWMutex
|
||||
}
|
||||
|
||||
// cachedPassthroughRule 预计算的规则缓存,避免运行时重复 ToLower
|
||||
type cachedPassthroughRule struct {
|
||||
*model.ErrorPassthroughRule
|
||||
lowerKeywords []string // 预计算的小写关键词
|
||||
lowerPlatforms []string // 预计算的小写平台
|
||||
errorCodeSet map[int]struct{} // 预计算的 error code set
|
||||
}
|
||||
|
||||
const maxBodyMatchLen = 8 << 10 // 8KB,错误信息不会在 8KB 之后才出现
|
||||
|
||||
// NewErrorPassthroughService 创建错误透传规则服务
|
||||
func NewErrorPassthroughService(
|
||||
repo ErrorPassthroughRepository,
|
||||
@@ -150,17 +160,19 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod
|
||||
return nil
|
||||
}
|
||||
|
||||
bodyStr := strings.ToLower(string(body))
|
||||
lowerPlatform := strings.ToLower(platform)
|
||||
var bodyLower string // 延迟初始化,只在需要关键词匹配时计算
|
||||
var bodyLowerDone bool
|
||||
|
||||
for _, rule := range rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
if !s.platformMatches(rule, platform) {
|
||||
if !s.platformMatchesCached(rule, lowerPlatform) {
|
||||
continue
|
||||
}
|
||||
if s.ruleMatches(rule, statusCode, bodyStr) {
|
||||
return rule
|
||||
if s.ruleMatchesOptimized(rule, statusCode, body, &bodyLower, &bodyLowerDone) {
|
||||
return rule.ErrorPassthroughRule
|
||||
}
|
||||
}
|
||||
|
||||
@@ -168,7 +180,7 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod
|
||||
}
|
||||
|
||||
// getCachedRules 获取缓存的规则列表(按优先级排序)
|
||||
func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule {
|
||||
func (s *ErrorPassthroughService) getCachedRules() []*cachedPassthroughRule {
|
||||
s.localCacheMu.RLock()
|
||||
rules := s.localCache
|
||||
s.localCacheMu.RUnlock()
|
||||
@@ -223,17 +235,39 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setLocalCache 设置本地缓存
|
||||
// setLocalCache 设置本地缓存,预计算小写值和 set 以避免运行时重复计算
|
||||
func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) {
|
||||
cached := make([]*cachedPassthroughRule, len(rules))
|
||||
for i, r := range rules {
|
||||
cr := &cachedPassthroughRule{ErrorPassthroughRule: r}
|
||||
if len(r.Keywords) > 0 {
|
||||
cr.lowerKeywords = make([]string, len(r.Keywords))
|
||||
for j, kw := range r.Keywords {
|
||||
cr.lowerKeywords[j] = strings.ToLower(kw)
|
||||
}
|
||||
}
|
||||
if len(r.Platforms) > 0 {
|
||||
cr.lowerPlatforms = make([]string, len(r.Platforms))
|
||||
for j, p := range r.Platforms {
|
||||
cr.lowerPlatforms[j] = strings.ToLower(p)
|
||||
}
|
||||
}
|
||||
if len(r.ErrorCodes) > 0 {
|
||||
cr.errorCodeSet = make(map[int]struct{}, len(r.ErrorCodes))
|
||||
for _, code := range r.ErrorCodes {
|
||||
cr.errorCodeSet[code] = struct{}{}
|
||||
}
|
||||
}
|
||||
cached[i] = cr
|
||||
}
|
||||
|
||||
// 按优先级排序
|
||||
sorted := make([]*model.ErrorPassthroughRule, len(rules))
|
||||
copy(sorted, rules)
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].Priority < sorted[j].Priority
|
||||
sort.Slice(cached, func(i, j int) bool {
|
||||
return cached[i].Priority < cached[j].Priority
|
||||
})
|
||||
|
||||
s.localCacheMu.Lock()
|
||||
s.localCache = sorted
|
||||
s.localCache = cached
|
||||
s.localCacheMu.Unlock()
|
||||
}
|
||||
|
||||
@@ -273,62 +307,79 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// platformMatches 检查平台是否匹配
|
||||
func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool {
|
||||
// 如果没有配置平台限制,则匹配所有平台
|
||||
if len(rule.Platforms) == 0 {
|
||||
// ensureBodyLower 延迟初始化 body 的小写版本,只做一次转换,限制 8KB
|
||||
func ensureBodyLower(body []byte, bodyLower *string, done *bool) string {
|
||||
if *done {
|
||||
return *bodyLower
|
||||
}
|
||||
b := body
|
||||
if len(b) > maxBodyMatchLen {
|
||||
b = b[:maxBodyMatchLen]
|
||||
}
|
||||
*bodyLower = strings.ToLower(string(b))
|
||||
*done = true
|
||||
return *bodyLower
|
||||
}
|
||||
|
||||
// platformMatchesCached 使用预计算的小写平台检查是否匹配
|
||||
func (s *ErrorPassthroughService) platformMatchesCached(rule *cachedPassthroughRule, lowerPlatform string) bool {
|
||||
if len(rule.lowerPlatforms) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
platform = strings.ToLower(platform)
|
||||
for _, p := range rule.Platforms {
|
||||
if strings.ToLower(p) == platform {
|
||||
for _, p := range rule.lowerPlatforms {
|
||||
if p == lowerPlatform {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ruleMatches 检查规则是否匹配
|
||||
func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool {
|
||||
hasErrorCodes := len(rule.ErrorCodes) > 0
|
||||
hasKeywords := len(rule.Keywords) > 0
|
||||
// ruleMatchesOptimized 优化的规则匹配,支持短路和延迟 body 转换
|
||||
func (s *ErrorPassthroughService) ruleMatchesOptimized(rule *cachedPassthroughRule, statusCode int, body []byte, bodyLower *string, bodyLowerDone *bool) bool {
|
||||
hasErrorCodes := len(rule.errorCodeSet) > 0
|
||||
hasKeywords := len(rule.lowerKeywords) > 0
|
||||
|
||||
// 如果没有配置任何条件,不匹配
|
||||
if !hasErrorCodes && !hasKeywords {
|
||||
return false
|
||||
}
|
||||
|
||||
codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode)
|
||||
keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords)
|
||||
codeMatch := !hasErrorCodes || s.containsIntSet(rule.errorCodeSet, statusCode)
|
||||
|
||||
if rule.MatchMode == model.MatchModeAll {
|
||||
// "all" 模式:所有配置的条件都必须满足
|
||||
return codeMatch && keywordMatch
|
||||
// "all" 模式:所有配置的条件都必须满足,短路
|
||||
if hasErrorCodes && !codeMatch {
|
||||
return false
|
||||
}
|
||||
if hasKeywords {
|
||||
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
|
||||
}
|
||||
return codeMatch
|
||||
}
|
||||
|
||||
// "any" 模式:任一条件满足即可
|
||||
// "any" 模式:任一条件满足即可,短路
|
||||
if hasErrorCodes && hasKeywords {
|
||||
return codeMatch || keywordMatch
|
||||
if codeMatch {
|
||||
return true
|
||||
}
|
||||
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
|
||||
}
|
||||
return codeMatch && keywordMatch
|
||||
// 只配置了一种条件
|
||||
if hasKeywords {
|
||||
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
|
||||
}
|
||||
return codeMatch
|
||||
}
|
||||
|
||||
// containsInt 检查切片是否包含指定整数
|
||||
func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool {
|
||||
for _, v := range slice {
|
||||
if v == val {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写)
|
||||
func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool {
|
||||
for _, kw := range keywords {
|
||||
if strings.Contains(bodyLower, strings.ToLower(kw)) {
|
||||
// containsIntSet 使用 map 查找替代线性扫描
|
||||
func (s *ErrorPassthroughService) containsIntSet(set map[int]struct{}, val int) bool {
|
||||
_, ok := set[val]
|
||||
return ok
|
||||
}
|
||||
|
||||
// containsAnyKeywordCached 使用预计算的小写关键词检查匹配
|
||||
func (s *ErrorPassthroughService) containsAnyKeywordCached(bodyLower string, lowerKeywords []string) bool {
|
||||
for _, kw := range lowerKeywords {
|
||||
if strings.Contains(bodyLower, kw) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,32 +145,58 @@ func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughServic
|
||||
return svc
|
||||
}
|
||||
|
||||
// newCachedRuleForTest 从 model.ErrorPassthroughRule 创建 cachedPassthroughRule(测试用)
|
||||
func newCachedRuleForTest(rule *model.ErrorPassthroughRule) *cachedPassthroughRule {
|
||||
cr := &cachedPassthroughRule{ErrorPassthroughRule: rule}
|
||||
if len(rule.Keywords) > 0 {
|
||||
cr.lowerKeywords = make([]string, len(rule.Keywords))
|
||||
for j, kw := range rule.Keywords {
|
||||
cr.lowerKeywords[j] = strings.ToLower(kw)
|
||||
}
|
||||
}
|
||||
if len(rule.Platforms) > 0 {
|
||||
cr.lowerPlatforms = make([]string, len(rule.Platforms))
|
||||
for j, p := range rule.Platforms {
|
||||
cr.lowerPlatforms[j] = strings.ToLower(p)
|
||||
}
|
||||
}
|
||||
if len(rule.ErrorCodes) > 0 {
|
||||
cr.errorCodeSet = make(map[int]struct{}, len(rule.ErrorCodes))
|
||||
for _, code := range rule.ErrorCodes {
|
||||
cr.errorCodeSet[code] = struct{}{}
|
||||
}
|
||||
}
|
||||
return cr
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 ruleMatches 核心匹配逻辑
|
||||
// 测试 ruleMatchesOptimized 核心匹配逻辑
|
||||
// =============================================================================
|
||||
|
||||
func TestRuleMatches_NoConditions(t *testing.T) {
|
||||
// 没有配置任何条件时,不应该匹配
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{},
|
||||
Keywords: []string{},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
})
|
||||
|
||||
assert.False(t, svc.ruleMatches(rule, 422, "some error message"),
|
||||
var bodyLower string
|
||||
var bodyLowerDone bool
|
||||
assert.False(t, svc.ruleMatchesOptimized(rule, 422, []byte("some error message"), &bodyLower, &bodyLowerDone),
|
||||
"没有配置条件时不应该匹配")
|
||||
}
|
||||
|
||||
func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -186,7 +212,9 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
var bodyLower string
|
||||
var bodyLowerDone bool
|
||||
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
@@ -194,12 +222,12 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
|
||||
|
||||
func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{},
|
||||
Keywords: []string{"context limit", "model not supported"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -210,16 +238,14 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
|
||||
{"关键词匹配 context limit", 500, "error: context limit reached", true},
|
||||
{"关键词匹配 model not supported", 400, "the model not supported here", true},
|
||||
{"关键词不匹配", 422, "some other error", false},
|
||||
// 注意:ruleMatches 接收的 body 参数应该是已经转换为小写的
|
||||
// 实际使用时,MatchRule 会先将 body 转换为小写再传给 ruleMatches
|
||||
{"关键词大小写 - 输入已小写", 500, "context limit exceeded", true},
|
||||
{"关键词大小写 - 自动转换", 500, "Context Limit exceeded", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模拟 MatchRule 的行为:先转换为小写
|
||||
bodyLower := strings.ToLower(tt.body)
|
||||
result := svc.ruleMatches(rule, tt.statusCode, bodyLower)
|
||||
var bodyLower string
|
||||
var bodyLowerDone bool
|
||||
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
@@ -228,12 +254,12 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
|
||||
func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
|
||||
// any 模式:错误码 OR 关键词
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{"context limit"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -274,7 +300,9 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
var bodyLower string
|
||||
var bodyLowerDone bool
|
||||
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
|
||||
assert.Equal(t, tt.expected, result, tt.reason)
|
||||
})
|
||||
}
|
||||
@@ -283,12 +311,12 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
|
||||
func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
|
||||
// all 模式:错误码 AND 关键词
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{"context limit"},
|
||||
MatchMode: model.MatchModeAll,
|
||||
}
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -329,14 +357,16 @@ func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
var bodyLower string
|
||||
var bodyLowerDone bool
|
||||
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
|
||||
assert.Equal(t, tt.expected, result, tt.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 platformMatches 平台匹配逻辑
|
||||
// 测试 platformMatchesCached 平台匹配逻辑
|
||||
// =============================================================================
|
||||
|
||||
func TestPlatformMatches(t *testing.T) {
|
||||
@@ -394,10 +424,10 @@ func TestPlatformMatches(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Platforms: tt.rulePlatforms,
|
||||
}
|
||||
result := svc.platformMatches(rule, tt.requestPlatform)
|
||||
})
|
||||
result := svc.platformMatchesCached(rule, strings.ToLower(tt.requestPlatform))
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -368,15 +368,31 @@ type ForwardResult struct {
|
||||
|
||||
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||
type UpstreamFailoverError struct {
|
||||
StatusCode int
|
||||
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
|
||||
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
|
||||
StatusCode int
|
||||
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
|
||||
ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
|
||||
RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换
|
||||
}
|
||||
|
||||
func (e *UpstreamFailoverError) Error() string {
|
||||
return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode)
|
||||
}
|
||||
|
||||
// TempUnscheduleRetryableError 对 RetryableOnSameAccount 类型的 failover 错误触发临时封禁。
|
||||
// 由 handler 层在同账号重试全部用尽、切换账号时调用。
|
||||
func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *UpstreamFailoverError) {
|
||||
if failoverErr == nil || !failoverErr.RetryableOnSameAccount {
|
||||
return
|
||||
}
|
||||
// 根据状态码选择封禁策略
|
||||
switch failoverErr.StatusCode {
|
||||
case http.StatusBadRequest:
|
||||
tempUnscheduleGoogleConfigError(ctx, s.accountRepo, accountID, "[handler]")
|
||||
case http.StatusBadGateway:
|
||||
tempUnscheduleEmptyResponse(ctx, s.accountRepo, accountID, "[handler]")
|
||||
}
|
||||
}
|
||||
|
||||
// GatewayService handles API gateway operations
|
||||
type GatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
|
||||
@@ -880,6 +880,37 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
|
||||
// ErrorPolicyNone → 原有逻辑
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
// 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
|
||||
if isGoogleProjectConfigError(msg400) {
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||
}
|
||||
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: upstreamReqID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: true}
|
||||
}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
@@ -1330,6 +1361,34 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
|
||||
// ErrorPolicyNone → 原有逻辑
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
// 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
|
||||
if isGoogleProjectConfigError(msg400) {
|
||||
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(evBody)))
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(evBody), maxBytes)
|
||||
}
|
||||
log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: requestID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody, RetryableOnSameAccount: true}
|
||||
}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
|
||||
|
||||
@@ -20,6 +20,10 @@ const (
|
||||
// retry the specific upstream attempt (not just the client request).
|
||||
// This value is sanitized+trimmed before being persisted.
|
||||
OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
|
||||
|
||||
// OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。
|
||||
// ops_error_logger 中间件检查此 key,为 true 时跳过错误记录。
|
||||
OpsSkipPassthroughKey = "ops_skip_passthrough"
|
||||
)
|
||||
|
||||
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
|
||||
@@ -103,6 +107,37 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
|
||||
evCopy := ev
|
||||
existing = append(existing, &evCopy)
|
||||
c.Set(OpsUpstreamErrorsKey, existing)
|
||||
|
||||
checkSkipMonitoringForUpstreamEvent(c, &evCopy)
|
||||
}
|
||||
|
||||
// checkSkipMonitoringForUpstreamEvent checks whether the upstream error event
|
||||
// matches a passthrough rule with skip_monitoring=true and, if so, sets the
|
||||
// OpsSkipPassthroughKey on the context. This ensures intermediate retry /
|
||||
// failover errors (which never go through the final applyErrorPassthroughRule
|
||||
// path) can still suppress ops_error_logs recording.
|
||||
func checkSkipMonitoringForUpstreamEvent(c *gin.Context, ev *OpsUpstreamErrorEvent) {
|
||||
if ev.UpstreamStatusCode == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
svc := getBoundErrorPassthroughService(c)
|
||||
if svc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Use the best available body representation for keyword matching.
|
||||
// Even when body is empty, MatchRule can still match rules that only
|
||||
// specify ErrorCodes (no Keywords), so we always call it.
|
||||
body := ev.Detail
|
||||
if body == "" {
|
||||
body = ev.Message
|
||||
}
|
||||
|
||||
rule := svc.MatchRule(ev.Platform, ev.UpstreamStatusCode, []byte(body))
|
||||
if rule != nil && rule.SkipMonitoring {
|
||||
c.Set(OpsSkipPassthroughKey, true)
|
||||
}
|
||||
}
|
||||
|
||||
func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string {
|
||||
|
||||
Reference in New Issue
Block a user