From a54b81cf74ffaadec468e6f611d0c512647b5e7d Mon Sep 17 00:00:00 2001 From: Edric Li Date: Tue, 10 Feb 2026 21:40:31 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E9=94=99=E8=AF=AF=E5=A4=84=E7=90=86?= =?UTF-8?q?=E6=80=A7=E8=83=BD=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - MatchRule 延迟/限制 body ToLower,先用 statusCode 短路,只在需要关键词匹配时转换且限制 8KB - 预计算规则的小写关键词/平台和 error code set,消除运行时重复 ToLower 和线性扫描 - MODEL_CAPACITY_EXHAUSTED 全局去重,避免并发请求重复重试同一模型 - 503 重试 body 读取限制从 2MB 降至 8KB - time.After 替换为 time.NewTimer,防止 context 取消时 timer 泄漏 --- .../service/antigravity_gateway_service.go | 59 ++++++- .../service/error_passthrough_service.go | 145 ++++++++++++------ .../service/error_passthrough_service_test.go | 80 +++++++--- 3 files changed, 206 insertions(+), 78 deletions(-) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index a110f4e0..7d3e5f19 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -16,6 +16,7 @@ import ( "os" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -66,6 +67,9 @@ const ( // 单账号 503 退避重试:原地重试的总累计等待时间上限 // 超过此上限将不再重试,直接返回 503 antigravitySingleAccountSmartRetryTotalMaxWait = 30 * time.Second + + // MODEL_CAPACITY_EXHAUSTED 全局去重:重试全部失败后的 cooldown 时间 + antigravityModelCapacityCooldown = 10 * time.Second ) // antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写) @@ -74,6 +78,12 @@ var antigravityPassthroughErrorMessages = []string{ "prompt is too long", } +// MODEL_CAPACITY_EXHAUSTED 全局去重:避免多个并发请求同时对同一模型进行容量耗尽重试 +var ( + modelCapacityExhaustedMu sync.RWMutex + modelCapacityExhaustedUntil = make(map[string]time.Time) // modelName -> cooldown until +) + const ( antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL" antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS" @@ -211,17 +221,38 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam if isModelCapacityExhausted { maxAttempts = antigravityModelCapacityRetryMaxAttempts waitDuration = antigravityModelCapacityRetryWait + + // 全局去重:如果其他 goroutine 已在重试同一模型且尚在 cooldown 中,直接返回 503 + if modelName != "" { + modelCapacityExhaustedMu.RLock() + cooldownUntil, exists := modelCapacityExhaustedUntil[modelName] + modelCapacityExhaustedMu.RUnlock() + if exists && time.Now().Before(cooldownUntil) { + log.Printf("%s status=%d model_capacity_exhausted_dedup model=%s account=%d cooldown_until=%v (skip retry)", + p.prefix, resp.StatusCode, modelName, p.account.ID, cooldownUntil.Format("15:04:05")) + return &smartRetryResult{ + action: smartRetryActionBreakWithResp, + resp: &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + }, + } + } + } } for attempt := 1; attempt <= maxAttempts; attempt++ { log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d", p.prefix, resp.StatusCode, attempt, maxAttempts, waitDuration, modelName, p.account.ID) + timer := time.NewTimer(waitDuration) select { case <-p.ctx.Done(): + timer.Stop() log.Printf("%s status=context_canceled_during_smart_retry", p.prefix) return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()} - case <-time.After(waitDuration): + case <-timer.C: } // 智能重试:创建新请求 @@ -242,6 +273,12 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency) if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable { log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, maxAttempts) + // 重试成功,清除 MODEL_CAPACITY_EXHAUSTED cooldown + if isModelCapacityExhausted && modelName != "" { + modelCapacityExhaustedMu.Lock() + delete(modelCapacityExhaustedUntil, modelName) + modelCapacityExhaustedMu.Unlock() + } return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp} } @@ -257,7 +294,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam } lastRetryResp = retryResp if retryResp != nil { - lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) + lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10)) _ = retryResp.Body.Close() } @@ -283,6 +320,12 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam // MODEL_CAPACITY_EXHAUSTED:模型容量不足,切换账号无意义 // 直接返回上游错误响应,不设置模型限流,不切换账号 if isModelCapacityExhausted { + // 设置 cooldown,让后续请求快速失败,避免重复重试 + if modelName != "" { + modelCapacityExhaustedMu.Lock() + modelCapacityExhaustedUntil[modelName] = time.Now().Add(antigravityModelCapacityCooldown) + modelCapacityExhaustedMu.Unlock() + } log.Printf("%s status=%d smart_retry_exhausted_model_capacity attempts=%d model=%s account=%d body=%s (model capacity exhausted, not switching account)", p.prefix, resp.StatusCode, maxAttempts, modelName, p.account.ID, truncateForLog(retryBody, 200)) return &smartRetryResult{ @@ -395,11 +438,13 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace( log.Printf("%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d", p.prefix, resp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, waitDuration, totalWaited, modelName, p.account.ID) + timer := time.NewTimer(waitDuration) select { case <-p.ctx.Done(): + timer.Stop() log.Printf("%s status=context_canceled_during_single_account_retry", p.prefix) return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()} - case <-time.After(waitDuration): + case <-timer.C: } totalWaited += waitDuration @@ -433,7 +478,7 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace( _ = lastRetryResp.Body.Close() } lastRetryResp = retryResp - lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) + lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10)) _ = retryResp.Body.Close() // 解析新的重试信息,更新下次等待时间 @@ -1404,7 +1449,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, break } - retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) + retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 8<<10)) _ = retryResp.Body.Close() if retryResp.StatusCode == http.StatusTooManyRequests { retryBaseURL := "" @@ -2211,10 +2256,12 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { sleepFor = 0 } + timer := time.NewTimer(sleepFor) select { case <-ctx.Done(): + timer.Stop() return false - case <-time.After(sleepFor): + case <-timer.C: return true } } diff --git a/backend/internal/service/error_passthrough_service.go b/backend/internal/service/error_passthrough_service.go index c3e0f630..caf12676 100644 --- a/backend/internal/service/error_passthrough_service.go +++ b/backend/internal/service/error_passthrough_service.go @@ -45,10 +45,20 @@ type ErrorPassthroughService struct { cache ErrorPassthroughCache // 本地内存缓存,用于快速匹配 - localCache []*model.ErrorPassthroughRule + localCache []*cachedPassthroughRule localCacheMu sync.RWMutex } +// cachedPassthroughRule 预计算的规则缓存,避免运行时重复 ToLower +type cachedPassthroughRule struct { + *model.ErrorPassthroughRule + lowerKeywords []string // 预计算的小写关键词 + lowerPlatforms []string // 预计算的小写平台 + errorCodeSet map[int]struct{} // 预计算的 error code set +} + +const maxBodyMatchLen = 8 << 10 // 8KB,错误信息不会在 8KB 之后才出现 + // NewErrorPassthroughService 创建错误透传规则服务 func NewErrorPassthroughService( repo ErrorPassthroughRepository, @@ -150,17 +160,19 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod return nil } - bodyStr := strings.ToLower(string(body)) + lowerPlatform := strings.ToLower(platform) + var bodyLower string // 延迟初始化,只在需要关键词匹配时计算 + var bodyLowerDone bool for _, rule := range rules { if !rule.Enabled { continue } - if !s.platformMatches(rule, platform) { + if !s.platformMatchesCached(rule, lowerPlatform) { continue } - if s.ruleMatches(rule, statusCode, bodyStr) { - return rule + if s.ruleMatchesOptimized(rule, statusCode, body, &bodyLower, &bodyLowerDone) { + return rule.ErrorPassthroughRule } } @@ -168,7 +180,7 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod } // getCachedRules 获取缓存的规则列表(按优先级排序) -func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule { +func (s *ErrorPassthroughService) getCachedRules() []*cachedPassthroughRule { s.localCacheMu.RLock() rules := s.localCache s.localCacheMu.RUnlock() @@ -223,17 +235,39 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error { return nil } -// setLocalCache 设置本地缓存 +// setLocalCache 设置本地缓存,预计算小写值和 set 以避免运行时重复计算 func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) { + cached := make([]*cachedPassthroughRule, len(rules)) + for i, r := range rules { + cr := &cachedPassthroughRule{ErrorPassthroughRule: r} + if len(r.Keywords) > 0 { + cr.lowerKeywords = make([]string, len(r.Keywords)) + for j, kw := range r.Keywords { + cr.lowerKeywords[j] = strings.ToLower(kw) + } + } + if len(r.Platforms) > 0 { + cr.lowerPlatforms = make([]string, len(r.Platforms)) + for j, p := range r.Platforms { + cr.lowerPlatforms[j] = strings.ToLower(p) + } + } + if len(r.ErrorCodes) > 0 { + cr.errorCodeSet = make(map[int]struct{}, len(r.ErrorCodes)) + for _, code := range r.ErrorCodes { + cr.errorCodeSet[code] = struct{}{} + } + } + cached[i] = cr + } + // 按优先级排序 - sorted := make([]*model.ErrorPassthroughRule, len(rules)) - copy(sorted, rules) - sort.Slice(sorted, func(i, j int) bool { - return sorted[i].Priority < sorted[j].Priority + sort.Slice(cached, func(i, j int) bool { + return cached[i].Priority < cached[j].Priority }) s.localCacheMu.Lock() - s.localCache = sorted + s.localCache = cached s.localCacheMu.Unlock() } @@ -273,62 +307,79 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) { } } -// platformMatches 检查平台是否匹配 -func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool { - // 如果没有配置平台限制,则匹配所有平台 - if len(rule.Platforms) == 0 { +// ensureBodyLower 延迟初始化 body 的小写版本,只做一次转换,限制 8KB +func ensureBodyLower(body []byte, bodyLower *string, done *bool) string { + if *done { + return *bodyLower + } + b := body + if len(b) > maxBodyMatchLen { + b = b[:maxBodyMatchLen] + } + *bodyLower = strings.ToLower(string(b)) + *done = true + return *bodyLower +} + +// platformMatchesCached 使用预计算的小写平台检查是否匹配 +func (s *ErrorPassthroughService) platformMatchesCached(rule *cachedPassthroughRule, lowerPlatform string) bool { + if len(rule.lowerPlatforms) == 0 { return true } - - platform = strings.ToLower(platform) - for _, p := range rule.Platforms { - if strings.ToLower(p) == platform { + for _, p := range rule.lowerPlatforms { + if p == lowerPlatform { return true } } - return false } -// ruleMatches 检查规则是否匹配 -func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool { - hasErrorCodes := len(rule.ErrorCodes) > 0 - hasKeywords := len(rule.Keywords) > 0 +// ruleMatchesOptimized 优化的规则匹配,支持短路和延迟 body 转换 +func (s *ErrorPassthroughService) ruleMatchesOptimized(rule *cachedPassthroughRule, statusCode int, body []byte, bodyLower *string, bodyLowerDone *bool) bool { + hasErrorCodes := len(rule.errorCodeSet) > 0 + hasKeywords := len(rule.lowerKeywords) > 0 - // 如果没有配置任何条件,不匹配 if !hasErrorCodes && !hasKeywords { return false } - codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode) - keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords) + codeMatch := !hasErrorCodes || s.containsIntSet(rule.errorCodeSet, statusCode) if rule.MatchMode == model.MatchModeAll { - // "all" 模式:所有配置的条件都必须满足 - return codeMatch && keywordMatch + // "all" 模式:所有配置的条件都必须满足,短路 + if hasErrorCodes && !codeMatch { + return false + } + if hasKeywords { + return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords) + } + return codeMatch } - // "any" 模式:任一条件满足即可 + // "any" 模式:任一条件满足即可,短路 if hasErrorCodes && hasKeywords { - return codeMatch || keywordMatch + if codeMatch { + return true + } + return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords) } - return codeMatch && keywordMatch + // 只配置了一种条件 + if hasKeywords { + return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords) + } + return codeMatch } -// containsInt 检查切片是否包含指定整数 -func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool { - for _, v := range slice { - if v == val { - return true - } - } - return false -} - -// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写) -func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool { - for _, kw := range keywords { - if strings.Contains(bodyLower, strings.ToLower(kw)) { +// containsIntSet 使用 map 查找替代线性扫描 +func (s *ErrorPassthroughService) containsIntSet(set map[int]struct{}, val int) bool { + _, ok := set[val] + return ok +} + +// containsAnyKeywordCached 使用预计算的小写关键词检查匹配 +func (s *ErrorPassthroughService) containsAnyKeywordCached(bodyLower string, lowerKeywords []string) bool { + for _, kw := range lowerKeywords { + if strings.Contains(bodyLower, kw) { return true } } diff --git a/backend/internal/service/error_passthrough_service_test.go b/backend/internal/service/error_passthrough_service_test.go index 74c98d86..96ddd637 100644 --- a/backend/internal/service/error_passthrough_service_test.go +++ b/backend/internal/service/error_passthrough_service_test.go @@ -145,32 +145,58 @@ func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughServic return svc } +// newCachedRuleForTest 从 model.ErrorPassthroughRule 创建 cachedPassthroughRule(测试用) +func newCachedRuleForTest(rule *model.ErrorPassthroughRule) *cachedPassthroughRule { + cr := &cachedPassthroughRule{ErrorPassthroughRule: rule} + if len(rule.Keywords) > 0 { + cr.lowerKeywords = make([]string, len(rule.Keywords)) + for j, kw := range rule.Keywords { + cr.lowerKeywords[j] = strings.ToLower(kw) + } + } + if len(rule.Platforms) > 0 { + cr.lowerPlatforms = make([]string, len(rule.Platforms)) + for j, p := range rule.Platforms { + cr.lowerPlatforms[j] = strings.ToLower(p) + } + } + if len(rule.ErrorCodes) > 0 { + cr.errorCodeSet = make(map[int]struct{}, len(rule.ErrorCodes)) + for _, code := range rule.ErrorCodes { + cr.errorCodeSet[code] = struct{}{} + } + } + return cr +} + // ============================================================================= -// 测试 ruleMatches 核心匹配逻辑 +// 测试 ruleMatchesOptimized 核心匹配逻辑 // ============================================================================= func TestRuleMatches_NoConditions(t *testing.T) { // 没有配置任何条件时,不应该匹配 svc := newTestService(nil) - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{}, Keywords: []string{}, MatchMode: model.MatchModeAny, - } + }) - assert.False(t, svc.ruleMatches(rule, 422, "some error message"), + var bodyLower string + var bodyLowerDone bool + assert.False(t, svc.ruleMatchesOptimized(rule, 422, []byte("some error message"), &bodyLower, &bodyLowerDone), "没有配置条件时不应该匹配") } func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) { svc := newTestService(nil) - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{422, 400}, Keywords: []string{}, MatchMode: model.MatchModeAny, - } + }) tests := []struct { name string @@ -186,7 +212,9 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := svc.ruleMatches(rule, tt.statusCode, tt.body) + var bodyLower string + var bodyLowerDone bool + result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone) assert.Equal(t, tt.expected, result) }) } @@ -194,12 +222,12 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) { func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) { svc := newTestService(nil) - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{}, Keywords: []string{"context limit", "model not supported"}, MatchMode: model.MatchModeAny, - } + }) tests := []struct { name string @@ -210,16 +238,14 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) { {"关键词匹配 context limit", 500, "error: context limit reached", true}, {"关键词匹配 model not supported", 400, "the model not supported here", true}, {"关键词不匹配", 422, "some other error", false}, - // 注意:ruleMatches 接收的 body 参数应该是已经转换为小写的 - // 实际使用时,MatchRule 会先将 body 转换为小写再传给 ruleMatches - {"关键词大小写 - 输入已小写", 500, "context limit exceeded", true}, + {"关键词大小写 - 自动转换", 500, "Context Limit exceeded", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // 模拟 MatchRule 的行为:先转换为小写 - bodyLower := strings.ToLower(tt.body) - result := svc.ruleMatches(rule, tt.statusCode, bodyLower) + var bodyLower string + var bodyLowerDone bool + result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone) assert.Equal(t, tt.expected, result) }) } @@ -228,12 +254,12 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) { func TestRuleMatches_BothConditions_AnyMode(t *testing.T) { // any 模式:错误码 OR 关键词 svc := newTestService(nil) - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{422, 400}, Keywords: []string{"context limit"}, MatchMode: model.MatchModeAny, - } + }) tests := []struct { name string @@ -274,7 +300,9 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := svc.ruleMatches(rule, tt.statusCode, tt.body) + var bodyLower string + var bodyLowerDone bool + result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone) assert.Equal(t, tt.expected, result, tt.reason) }) } @@ -283,12 +311,12 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) { func TestRuleMatches_BothConditions_AllMode(t *testing.T) { // all 模式:错误码 AND 关键词 svc := newTestService(nil) - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{422, 400}, Keywords: []string{"context limit"}, MatchMode: model.MatchModeAll, - } + }) tests := []struct { name string @@ -329,14 +357,16 @@ func TestRuleMatches_BothConditions_AllMode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := svc.ruleMatches(rule, tt.statusCode, tt.body) + var bodyLower string + var bodyLowerDone bool + result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone) assert.Equal(t, tt.expected, result, tt.reason) }) } } // ============================================================================= -// 测试 platformMatches 平台匹配逻辑 +// 测试 platformMatchesCached 平台匹配逻辑 // ============================================================================= func TestPlatformMatches(t *testing.T) { @@ -394,10 +424,10 @@ func TestPlatformMatches(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Platforms: tt.rulePlatforms, - } - result := svc.platformMatches(rule, tt.requestPlatform) + }) + result := svc.platformMatchesCached(rule, strings.ToLower(tt.requestPlatform)) assert.Equal(t, tt.expected, result) }) }