perf: 错误处理性能优化
- MatchRule 延迟/限制 body ToLower,先用 statusCode 短路,只在需要关键词匹配时转换且限制 8KB - 预计算规则的小写关键词/平台和 error code set,消除运行时重复 ToLower 和线性扫描 - MODEL_CAPACITY_EXHAUSTED 全局去重,避免并发请求重复重试同一模型 - 503 重试 body 读取限制从 2MB 降至 8KB - time.After 替换为 time.NewTimer,防止 context 取消时 timer 泄漏
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -66,6 +67,9 @@ const (
|
||||
// 单账号 503 退避重试:原地重试的总累计等待时间上限
|
||||
// 超过此上限将不再重试,直接返回 503
|
||||
antigravitySingleAccountSmartRetryTotalMaxWait = 30 * time.Second
|
||||
|
||||
// MODEL_CAPACITY_EXHAUSTED 全局去重:重试全部失败后的 cooldown 时间
|
||||
antigravityModelCapacityCooldown = 10 * time.Second
|
||||
)
|
||||
|
||||
// antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写)
|
||||
@@ -74,6 +78,12 @@ var antigravityPassthroughErrorMessages = []string{
|
||||
"prompt is too long",
|
||||
}
|
||||
|
||||
// MODEL_CAPACITY_EXHAUSTED 全局去重:避免多个并发请求同时对同一模型进行容量耗尽重试
|
||||
var (
|
||||
modelCapacityExhaustedMu sync.RWMutex
|
||||
modelCapacityExhaustedUntil = make(map[string]time.Time) // modelName -> cooldown until
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
|
||||
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
|
||||
@@ -211,17 +221,38 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
if isModelCapacityExhausted {
|
||||
maxAttempts = antigravityModelCapacityRetryMaxAttempts
|
||||
waitDuration = antigravityModelCapacityRetryWait
|
||||
|
||||
// 全局去重:如果其他 goroutine 已在重试同一模型且尚在 cooldown 中,直接返回 503
|
||||
if modelName != "" {
|
||||
modelCapacityExhaustedMu.RLock()
|
||||
cooldownUntil, exists := modelCapacityExhaustedUntil[modelName]
|
||||
modelCapacityExhaustedMu.RUnlock()
|
||||
if exists && time.Now().Before(cooldownUntil) {
|
||||
log.Printf("%s status=%d model_capacity_exhausted_dedup model=%s account=%d cooldown_until=%v (skip retry)",
|
||||
p.prefix, resp.StatusCode, modelName, p.account.ID, cooldownUntil.Format("15:04:05"))
|
||||
return &smartRetryResult{
|
||||
action: smartRetryActionBreakWithResp,
|
||||
resp: &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d",
|
||||
p.prefix, resp.StatusCode, attempt, maxAttempts, waitDuration, modelName, p.account.ID)
|
||||
|
||||
timer := time.NewTimer(waitDuration)
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
timer.Stop()
|
||||
log.Printf("%s status=context_canceled_during_smart_retry", p.prefix)
|
||||
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
|
||||
case <-time.After(waitDuration):
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
// 智能重试:创建新请求
|
||||
@@ -242,6 +273,12 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||||
if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable {
|
||||
log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, maxAttempts)
|
||||
// 重试成功,清除 MODEL_CAPACITY_EXHAUSTED cooldown
|
||||
if isModelCapacityExhausted && modelName != "" {
|
||||
modelCapacityExhaustedMu.Lock()
|
||||
delete(modelCapacityExhaustedUntil, modelName)
|
||||
modelCapacityExhaustedMu.Unlock()
|
||||
}
|
||||
return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp}
|
||||
}
|
||||
|
||||
@@ -257,7 +294,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
}
|
||||
lastRetryResp = retryResp
|
||||
if retryResp != nil {
|
||||
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
|
||||
_ = retryResp.Body.Close()
|
||||
}
|
||||
|
||||
@@ -283,6 +320,12 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam
|
||||
// MODEL_CAPACITY_EXHAUSTED:模型容量不足,切换账号无意义
|
||||
// 直接返回上游错误响应,不设置模型限流,不切换账号
|
||||
if isModelCapacityExhausted {
|
||||
// 设置 cooldown,让后续请求快速失败,避免重复重试
|
||||
if modelName != "" {
|
||||
modelCapacityExhaustedMu.Lock()
|
||||
modelCapacityExhaustedUntil[modelName] = time.Now().Add(antigravityModelCapacityCooldown)
|
||||
modelCapacityExhaustedMu.Unlock()
|
||||
}
|
||||
log.Printf("%s status=%d smart_retry_exhausted_model_capacity attempts=%d model=%s account=%d body=%s (model capacity exhausted, not switching account)",
|
||||
p.prefix, resp.StatusCode, maxAttempts, modelName, p.account.ID, truncateForLog(retryBody, 200))
|
||||
return &smartRetryResult{
|
||||
@@ -395,11 +438,13 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
|
||||
log.Printf("%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d",
|
||||
p.prefix, resp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, waitDuration, totalWaited, modelName, p.account.ID)
|
||||
|
||||
timer := time.NewTimer(waitDuration)
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
timer.Stop()
|
||||
log.Printf("%s status=context_canceled_during_single_account_retry", p.prefix)
|
||||
return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()}
|
||||
case <-time.After(waitDuration):
|
||||
case <-timer.C:
|
||||
}
|
||||
totalWaited += waitDuration
|
||||
|
||||
@@ -433,7 +478,7 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace(
|
||||
_ = lastRetryResp.Body.Close()
|
||||
}
|
||||
lastRetryResp = retryResp
|
||||
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||
lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
|
||||
_ = retryResp.Body.Close()
|
||||
|
||||
// 解析新的重试信息,更新下次等待时间
|
||||
@@ -1404,7 +1449,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
break
|
||||
}
|
||||
|
||||
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 8<<10))
|
||||
_ = retryResp.Body.Close()
|
||||
if retryResp.StatusCode == http.StatusTooManyRequests {
|
||||
retryBaseURL := ""
|
||||
@@ -2211,10 +2256,12 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
|
||||
sleepFor = 0
|
||||
}
|
||||
|
||||
timer := time.NewTimer(sleepFor)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return false
|
||||
case <-time.After(sleepFor):
|
||||
case <-timer.C:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,10 +45,20 @@ type ErrorPassthroughService struct {
|
||||
cache ErrorPassthroughCache
|
||||
|
||||
// 本地内存缓存,用于快速匹配
|
||||
localCache []*model.ErrorPassthroughRule
|
||||
localCache []*cachedPassthroughRule
|
||||
localCacheMu sync.RWMutex
|
||||
}
|
||||
|
||||
// cachedPassthroughRule 预计算的规则缓存,避免运行时重复 ToLower
|
||||
type cachedPassthroughRule struct {
|
||||
*model.ErrorPassthroughRule
|
||||
lowerKeywords []string // 预计算的小写关键词
|
||||
lowerPlatforms []string // 预计算的小写平台
|
||||
errorCodeSet map[int]struct{} // 预计算的 error code set
|
||||
}
|
||||
|
||||
const maxBodyMatchLen = 8 << 10 // 8KB,错误信息不会在 8KB 之后才出现
|
||||
|
||||
// NewErrorPassthroughService 创建错误透传规则服务
|
||||
func NewErrorPassthroughService(
|
||||
repo ErrorPassthroughRepository,
|
||||
@@ -150,17 +160,19 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod
|
||||
return nil
|
||||
}
|
||||
|
||||
bodyStr := strings.ToLower(string(body))
|
||||
lowerPlatform := strings.ToLower(platform)
|
||||
var bodyLower string // 延迟初始化,只在需要关键词匹配时计算
|
||||
var bodyLowerDone bool
|
||||
|
||||
for _, rule := range rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
if !s.platformMatches(rule, platform) {
|
||||
if !s.platformMatchesCached(rule, lowerPlatform) {
|
||||
continue
|
||||
}
|
||||
if s.ruleMatches(rule, statusCode, bodyStr) {
|
||||
return rule
|
||||
if s.ruleMatchesOptimized(rule, statusCode, body, &bodyLower, &bodyLowerDone) {
|
||||
return rule.ErrorPassthroughRule
|
||||
}
|
||||
}
|
||||
|
||||
@@ -168,7 +180,7 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod
|
||||
}
|
||||
|
||||
// getCachedRules 获取缓存的规则列表(按优先级排序)
|
||||
func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule {
|
||||
func (s *ErrorPassthroughService) getCachedRules() []*cachedPassthroughRule {
|
||||
s.localCacheMu.RLock()
|
||||
rules := s.localCache
|
||||
s.localCacheMu.RUnlock()
|
||||
@@ -223,17 +235,39 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setLocalCache 设置本地缓存
|
||||
// setLocalCache 设置本地缓存,预计算小写值和 set 以避免运行时重复计算
|
||||
func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) {
|
||||
cached := make([]*cachedPassthroughRule, len(rules))
|
||||
for i, r := range rules {
|
||||
cr := &cachedPassthroughRule{ErrorPassthroughRule: r}
|
||||
if len(r.Keywords) > 0 {
|
||||
cr.lowerKeywords = make([]string, len(r.Keywords))
|
||||
for j, kw := range r.Keywords {
|
||||
cr.lowerKeywords[j] = strings.ToLower(kw)
|
||||
}
|
||||
}
|
||||
if len(r.Platforms) > 0 {
|
||||
cr.lowerPlatforms = make([]string, len(r.Platforms))
|
||||
for j, p := range r.Platforms {
|
||||
cr.lowerPlatforms[j] = strings.ToLower(p)
|
||||
}
|
||||
}
|
||||
if len(r.ErrorCodes) > 0 {
|
||||
cr.errorCodeSet = make(map[int]struct{}, len(r.ErrorCodes))
|
||||
for _, code := range r.ErrorCodes {
|
||||
cr.errorCodeSet[code] = struct{}{}
|
||||
}
|
||||
}
|
||||
cached[i] = cr
|
||||
}
|
||||
|
||||
// 按优先级排序
|
||||
sorted := make([]*model.ErrorPassthroughRule, len(rules))
|
||||
copy(sorted, rules)
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].Priority < sorted[j].Priority
|
||||
sort.Slice(cached, func(i, j int) bool {
|
||||
return cached[i].Priority < cached[j].Priority
|
||||
})
|
||||
|
||||
s.localCacheMu.Lock()
|
||||
s.localCache = sorted
|
||||
s.localCache = cached
|
||||
s.localCacheMu.Unlock()
|
||||
}
|
||||
|
||||
@@ -273,62 +307,79 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// platformMatches 检查平台是否匹配
|
||||
func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool {
|
||||
// 如果没有配置平台限制,则匹配所有平台
|
||||
if len(rule.Platforms) == 0 {
|
||||
// ensureBodyLower 延迟初始化 body 的小写版本,只做一次转换,限制 8KB
|
||||
func ensureBodyLower(body []byte, bodyLower *string, done *bool) string {
|
||||
if *done {
|
||||
return *bodyLower
|
||||
}
|
||||
b := body
|
||||
if len(b) > maxBodyMatchLen {
|
||||
b = b[:maxBodyMatchLen]
|
||||
}
|
||||
*bodyLower = strings.ToLower(string(b))
|
||||
*done = true
|
||||
return *bodyLower
|
||||
}
|
||||
|
||||
// platformMatchesCached 使用预计算的小写平台检查是否匹配
|
||||
func (s *ErrorPassthroughService) platformMatchesCached(rule *cachedPassthroughRule, lowerPlatform string) bool {
|
||||
if len(rule.lowerPlatforms) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
platform = strings.ToLower(platform)
|
||||
for _, p := range rule.Platforms {
|
||||
if strings.ToLower(p) == platform {
|
||||
for _, p := range rule.lowerPlatforms {
|
||||
if p == lowerPlatform {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ruleMatches 检查规则是否匹配
|
||||
func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool {
|
||||
hasErrorCodes := len(rule.ErrorCodes) > 0
|
||||
hasKeywords := len(rule.Keywords) > 0
|
||||
// ruleMatchesOptimized 优化的规则匹配,支持短路和延迟 body 转换
|
||||
func (s *ErrorPassthroughService) ruleMatchesOptimized(rule *cachedPassthroughRule, statusCode int, body []byte, bodyLower *string, bodyLowerDone *bool) bool {
|
||||
hasErrorCodes := len(rule.errorCodeSet) > 0
|
||||
hasKeywords := len(rule.lowerKeywords) > 0
|
||||
|
||||
// 如果没有配置任何条件,不匹配
|
||||
if !hasErrorCodes && !hasKeywords {
|
||||
return false
|
||||
}
|
||||
|
||||
codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode)
|
||||
keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords)
|
||||
codeMatch := !hasErrorCodes || s.containsIntSet(rule.errorCodeSet, statusCode)
|
||||
|
||||
if rule.MatchMode == model.MatchModeAll {
|
||||
// "all" 模式:所有配置的条件都必须满足
|
||||
return codeMatch && keywordMatch
|
||||
// "all" 模式:所有配置的条件都必须满足,短路
|
||||
if hasErrorCodes && !codeMatch {
|
||||
return false
|
||||
}
|
||||
if hasKeywords {
|
||||
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
|
||||
}
|
||||
return codeMatch
|
||||
}
|
||||
|
||||
// "any" 模式:任一条件满足即可
|
||||
// "any" 模式:任一条件满足即可,短路
|
||||
if hasErrorCodes && hasKeywords {
|
||||
return codeMatch || keywordMatch
|
||||
if codeMatch {
|
||||
return true
|
||||
}
|
||||
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
|
||||
}
|
||||
return codeMatch && keywordMatch
|
||||
// 只配置了一种条件
|
||||
if hasKeywords {
|
||||
return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords)
|
||||
}
|
||||
return codeMatch
|
||||
}
|
||||
|
||||
// containsInt 检查切片是否包含指定整数
|
||||
func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool {
|
||||
for _, v := range slice {
|
||||
if v == val {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写)
|
||||
func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool {
|
||||
for _, kw := range keywords {
|
||||
if strings.Contains(bodyLower, strings.ToLower(kw)) {
|
||||
// containsIntSet 使用 map 查找替代线性扫描
|
||||
func (s *ErrorPassthroughService) containsIntSet(set map[int]struct{}, val int) bool {
|
||||
_, ok := set[val]
|
||||
return ok
|
||||
}
|
||||
|
||||
// containsAnyKeywordCached 使用预计算的小写关键词检查匹配
|
||||
func (s *ErrorPassthroughService) containsAnyKeywordCached(bodyLower string, lowerKeywords []string) bool {
|
||||
for _, kw := range lowerKeywords {
|
||||
if strings.Contains(bodyLower, kw) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,32 +145,58 @@ func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughServic
|
||||
return svc
|
||||
}
|
||||
|
||||
// newCachedRuleForTest 从 model.ErrorPassthroughRule 创建 cachedPassthroughRule(测试用)
|
||||
func newCachedRuleForTest(rule *model.ErrorPassthroughRule) *cachedPassthroughRule {
|
||||
cr := &cachedPassthroughRule{ErrorPassthroughRule: rule}
|
||||
if len(rule.Keywords) > 0 {
|
||||
cr.lowerKeywords = make([]string, len(rule.Keywords))
|
||||
for j, kw := range rule.Keywords {
|
||||
cr.lowerKeywords[j] = strings.ToLower(kw)
|
||||
}
|
||||
}
|
||||
if len(rule.Platforms) > 0 {
|
||||
cr.lowerPlatforms = make([]string, len(rule.Platforms))
|
||||
for j, p := range rule.Platforms {
|
||||
cr.lowerPlatforms[j] = strings.ToLower(p)
|
||||
}
|
||||
}
|
||||
if len(rule.ErrorCodes) > 0 {
|
||||
cr.errorCodeSet = make(map[int]struct{}, len(rule.ErrorCodes))
|
||||
for _, code := range rule.ErrorCodes {
|
||||
cr.errorCodeSet[code] = struct{}{}
|
||||
}
|
||||
}
|
||||
return cr
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 ruleMatches 核心匹配逻辑
|
||||
// 测试 ruleMatchesOptimized 核心匹配逻辑
|
||||
// =============================================================================
|
||||
|
||||
func TestRuleMatches_NoConditions(t *testing.T) {
|
||||
// 没有配置任何条件时,不应该匹配
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{},
|
||||
Keywords: []string{},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
})
|
||||
|
||||
assert.False(t, svc.ruleMatches(rule, 422, "some error message"),
|
||||
var bodyLower string
|
||||
var bodyLowerDone bool
|
||||
assert.False(t, svc.ruleMatchesOptimized(rule, 422, []byte("some error message"), &bodyLower, &bodyLowerDone),
|
||||
"没有配置条件时不应该匹配")
|
||||
}
|
||||
|
||||
func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -186,7 +212,9 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
var bodyLower string
|
||||
var bodyLowerDone bool
|
||||
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
@@ -194,12 +222,12 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
|
||||
|
||||
func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{},
|
||||
Keywords: []string{"context limit", "model not supported"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -210,16 +238,14 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
|
||||
{"关键词匹配 context limit", 500, "error: context limit reached", true},
|
||||
{"关键词匹配 model not supported", 400, "the model not supported here", true},
|
||||
{"关键词不匹配", 422, "some other error", false},
|
||||
// 注意:ruleMatches 接收的 body 参数应该是已经转换为小写的
|
||||
// 实际使用时,MatchRule 会先将 body 转换为小写再传给 ruleMatches
|
||||
{"关键词大小写 - 输入已小写", 500, "context limit exceeded", true},
|
||||
{"关键词大小写 - 自动转换", 500, "Context Limit exceeded", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模拟 MatchRule 的行为:先转换为小写
|
||||
bodyLower := strings.ToLower(tt.body)
|
||||
result := svc.ruleMatches(rule, tt.statusCode, bodyLower)
|
||||
var bodyLower string
|
||||
var bodyLowerDone bool
|
||||
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
@@ -228,12 +254,12 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
|
||||
func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
|
||||
// any 模式:错误码 OR 关键词
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{"context limit"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -274,7 +300,9 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
var bodyLower string
|
||||
var bodyLowerDone bool
|
||||
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
|
||||
assert.Equal(t, tt.expected, result, tt.reason)
|
||||
})
|
||||
}
|
||||
@@ -283,12 +311,12 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
|
||||
func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
|
||||
// all 模式:错误码 AND 关键词
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{"context limit"},
|
||||
MatchMode: model.MatchModeAll,
|
||||
}
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -329,14 +357,16 @@ func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
var bodyLower string
|
||||
var bodyLowerDone bool
|
||||
result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone)
|
||||
assert.Equal(t, tt.expected, result, tt.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 platformMatches 平台匹配逻辑
|
||||
// 测试 platformMatchesCached 平台匹配逻辑
|
||||
// =============================================================================
|
||||
|
||||
func TestPlatformMatches(t *testing.T) {
|
||||
@@ -394,10 +424,10 @@ func TestPlatformMatches(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
rule := newCachedRuleForTest(&model.ErrorPassthroughRule{
|
||||
Platforms: tt.rulePlatforms,
|
||||
}
|
||||
result := svc.platformMatches(rule, tt.requestPlatform)
|
||||
})
|
||||
result := svc.platformMatchesCached(rule, strings.ToLower(tt.requestPlatform))
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user