diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 87221256..944e0f84 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -435,8 +435,7 @@ type DefaultConfig struct { } type RateLimitConfig struct { - OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟) - OAuth401CooldownMinutes int `mapstructure:"oauth_401_cooldown_minutes"` // OAuth 401 临时不可调度冷却时间(分钟) + OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟) } // APIKeyAuthCacheConfig API Key 认证缓存配置 @@ -710,7 +709,6 @@ func setDefaults() { // RateLimit viper.SetDefault("rate_limit.overload_cooldown_minutes", 10) - viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 5) // Pricing - 从 price-mirror 分支同步,该分支维护了 sha256 哈希文件用于增量更新检查 viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json") diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 8ef641ba..ca479486 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -73,10 +73,8 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc return false } - isOAuth401 := statusCode == 401 && account.Type == AccountTypeOAuth && - (account.Platform == PlatformAntigravity || account.Platform == PlatformGemini) tempMatched := false - if !isOAuth401 || account.IsTempUnschedulableEnabled() { + if statusCode != 401 { tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody) } upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody)) @@ -87,18 +85,13 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc switch statusCode { case 401: - if isOAuth401 { - if tempMatched { - if s.tokenCacheInvalidator != nil { - if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil { - slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err) - } + if account.Type == AccountTypeOAuth && + (account.Platform == PlatformAntigravity || account.Platform == PlatformGemini) { + if s.tokenCacheInvalidator != nil { + if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil { + slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err) } - shouldDisable = true - } else { - shouldDisable = s.handleOAuth401TempUnschedulable(ctx, account, upstreamMsg) } - break } msg := "Authentication failed (401): invalid or expired credentials" if upstreamMsg != "" { @@ -150,63 +143,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc return shouldDisable } -func (s *RateLimitService) handleOAuth401TempUnschedulable(ctx context.Context, account *Account, upstreamMsg string) bool { - if account == nil { - return false - } - - if s.tokenCacheInvalidator != nil { - if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil { - slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err) - } - } - - now := time.Now() - until := now.Add(s.oauth401Cooldown()) - msg := "Authentication failed (401): invalid or expired credentials" - if upstreamMsg != "" { - msg = "Authentication failed (401): " + upstreamMsg - } - - state := &TempUnschedState{ - UntilUnix: until.Unix(), - TriggeredAtUnix: now.Unix(), - StatusCode: 401, - MatchedKeyword: "oauth_401", - RuleIndex: -1, // -1 表示非规则触发,而是 OAuth 401 特殊处理 - ErrorMessage: msg, - } - - reason := "" - if raw, err := json.Marshal(state); err == nil { - reason = string(raw) - } - if reason == "" { - reason = msg - } - - if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { - slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err) - return false - } - - if s.tempUnschedCache != nil { - if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil { - slog.Warn("oauth_401_set_temp_unsched_cache_failed", "account_id", account.ID, "error", err) - } - } - - slog.Info("oauth_401_temp_unschedulable", "account_id", account.ID, "until", until) - return true -} - -func (s *RateLimitService) oauth401Cooldown() time.Duration { - if s != nil && s.cfg != nil && s.cfg.RateLimit.OAuth401CooldownMinutes > 0 { - return time.Duration(s.cfg.RateLimit.OAuth401CooldownMinutes) * time.Minute - } - return 5 * time.Minute -} - // PreCheckUsage proactively checks local quota before dispatching a request. // Returns false when the account should be skipped. func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) { diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go index 2c43b1cf..36357a4b 100644 --- a/backend/internal/service/ratelimit_service_401_test.go +++ b/backend/internal/service/ratelimit_service_401_test.go @@ -15,21 +15,19 @@ import ( type rateLimitAccountRepoStub struct { mockAccountRepoForGemini - tempCalls int - tempUntil time.Time - tempReason string setErrorCalls int -} - -func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { - r.tempCalls++ - r.tempUntil = until - r.tempReason = reason - return nil + tempCalls int + lastErrorMsg string } func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { r.setErrorCalls++ + r.lastErrorMsg = errorMsg + return nil +} + +func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + r.tempCalls++ return nil } @@ -43,7 +41,7 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc return r.err } -func TestRateLimitService_HandleUpstreamError_OAuth401TempUnschedulable(t *testing.T) { +func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) { tests := []struct { name string platform string @@ -62,17 +60,26 @@ func TestRateLimitService_HandleUpstreamError_OAuth401TempUnschedulable(t *testi ID: 100, Platform: tt.platform, Type: AccountTypeOAuth, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": 401, + "keywords": []any{"unauthorized"}, + "duration_minutes": 30, + "description": "custom rule", + }, + }, + }, } - start := time.Now() shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) - require.Equal(t, 1, repo.tempCalls) - require.Equal(t, 0, repo.setErrorCalls) + require.Equal(t, 1, repo.setErrorCalls) + require.Equal(t, 0, repo.tempCalls) + require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)") require.Len(t, invalidator.accounts, 1) - require.WithinDuration(t, start.Add(5*time.Minute), repo.tempUntil, 10*time.Second) - require.NotEmpty(t, repo.tempReason) }) } } @@ -91,43 +98,10 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) - require.Equal(t, 1, repo.tempCalls) - require.Equal(t, 0, repo.setErrorCalls) + require.Equal(t, 1, repo.setErrorCalls) require.Len(t, invalidator.accounts, 1) } -func TestRateLimitService_HandleUpstreamError_OAuth401CustomRule(t *testing.T) { - repo := &rateLimitAccountRepoStub{} - invalidator := &tokenCacheInvalidatorRecorder{} - service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) - service.SetTokenCacheInvalidator(invalidator) - account := &Account{ - ID: 103, - Platform: PlatformGemini, - Type: AccountTypeOAuth, - Credentials: map[string]any{ - "temp_unschedulable_enabled": true, - "temp_unschedulable_rules": []any{ - map[string]any{ - "error_code": 401, - "keywords": []any{"unauthorized"}, - "duration_minutes": 30, - "description": "custom rule", - }, - }, - }, - } - - start := time.Now() - shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) - - require.True(t, shouldDisable) - require.Equal(t, 1, repo.tempCalls) - require.Equal(t, 0, repo.setErrorCalls) - require.Len(t, invalidator.accounts, 1) - require.WithinDuration(t, start.Add(30*time.Minute), repo.tempUntil, 10*time.Second) -} - func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) { repo := &rateLimitAccountRepoStub{} invalidator := &tokenCacheInvalidatorRecorder{} @@ -142,212 +116,6 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) { shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) - require.Equal(t, 0, repo.tempCalls) require.Equal(t, 1, repo.setErrorCalls) require.Empty(t, invalidator.accounts) } - -// TestRateLimitService_HandleOAuth401_NilAccount 测试 account 为 nil 的情况 -func TestRateLimitService_HandleOAuth401_NilAccount(t *testing.T) { - repo := &rateLimitAccountRepoStub{} - service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) - - result := service.handleOAuth401TempUnschedulable(context.Background(), nil, "error") - - require.False(t, result) - require.Equal(t, 0, repo.tempCalls) -} - -// TestRateLimitService_HandleOAuth401_NilInvalidator 测试 tokenCacheInvalidator 为 nil 的情况 -func TestRateLimitService_HandleOAuth401_NilInvalidator(t *testing.T) { - repo := &rateLimitAccountRepoStub{} - service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) - // 不设置 tokenCacheInvalidator - account := &Account{ - ID: 200, - Platform: PlatformGemini, - Type: AccountTypeOAuth, - } - - result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error") - - require.True(t, result) - require.Equal(t, 1, repo.tempCalls) -} - -// TestRateLimitService_HandleOAuth401_SetTempUnschedulableFailed 测试 SetTempUnschedulable 失败的情况 -func TestRateLimitService_HandleOAuth401_SetTempUnschedulableFailed(t *testing.T) { - repo := &rateLimitAccountRepoStubWithError{ - setTempErr: errors.New("db error"), - } - invalidator := &tokenCacheInvalidatorRecorder{} - service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) - service.SetTokenCacheInvalidator(invalidator) - account := &Account{ - ID: 201, - Platform: PlatformGemini, - Type: AccountTypeOAuth, - } - - result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error") - - require.False(t, result) // 失败应返回 false - require.Len(t, invalidator.accounts, 1) // 但 invalidator 仍然被调用 -} - -// rateLimitAccountRepoStubWithError 支持返回错误的 stub -type rateLimitAccountRepoStubWithError struct { - mockAccountRepoForGemini - setTempErr error - setErrorCalls int -} - -func (r *rateLimitAccountRepoStubWithError) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { - return r.setTempErr -} - -func (r *rateLimitAccountRepoStubWithError) SetError(ctx context.Context, id int64, errorMsg string) error { - r.setErrorCalls++ - return nil -} - -// TestRateLimitService_HandleOAuth401_WithTempUnschedCache 测试 tempUnschedCache 存在的情况 -func TestRateLimitService_HandleOAuth401_WithTempUnschedCache(t *testing.T) { - repo := &rateLimitAccountRepoStub{} - invalidator := &tokenCacheInvalidatorRecorder{} - tempCache := &tempUnschedCacheStub{} - service := NewRateLimitService(repo, nil, &config.Config{}, nil, tempCache) - service.SetTokenCacheInvalidator(invalidator) - account := &Account{ - ID: 202, - Platform: PlatformGemini, - Type: AccountTypeOAuth, - } - - result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error") - - require.True(t, result) - require.Equal(t, 1, repo.tempCalls) - require.Equal(t, 1, tempCache.setCalls) -} - -// TestRateLimitService_HandleOAuth401_TempUnschedCacheError 测试 tempUnschedCache 设置失败的情况 -func TestRateLimitService_HandleOAuth401_TempUnschedCacheError(t *testing.T) { - repo := &rateLimitAccountRepoStub{} - invalidator := &tokenCacheInvalidatorRecorder{} - tempCache := &tempUnschedCacheStub{setErr: errors.New("cache error")} - service := NewRateLimitService(repo, nil, &config.Config{}, nil, tempCache) - service.SetTokenCacheInvalidator(invalidator) - account := &Account{ - ID: 203, - Platform: PlatformGemini, - Type: AccountTypeOAuth, - } - - result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error") - - require.True(t, result) // 缓存错误不影响主流程 - require.Equal(t, 1, repo.tempCalls) -} - -// tempUnschedCacheStub 用于测试的 TempUnschedCache stub -type tempUnschedCacheStub struct { - setCalls int - setErr error -} - -func (c *tempUnschedCacheStub) GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error) { - return nil, nil -} - -func (c *tempUnschedCacheStub) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error { - c.setCalls++ - return c.setErr -} - -func (c *tempUnschedCacheStub) DeleteTempUnsched(ctx context.Context, accountID int64) error { - return nil -} - -// TestRateLimitService_OAuth401Cooldown 测试 oauth401Cooldown 函数 -func TestRateLimitService_OAuth401Cooldown(t *testing.T) { - tests := []struct { - name string - cfg *config.Config - expected time.Duration - }{ - { - name: "default_when_config_zero", - cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 0}}, - expected: 5 * time.Minute, - }, - { - name: "custom_cooldown_10_minutes", - cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 10}}, - expected: 10 * time.Minute, - }, - { - name: "custom_cooldown_1_minute", - cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 1}}, - expected: 1 * time.Minute, - }, - { - name: "negative_value_uses_default", - cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: -5}}, - expected: 5 * time.Minute, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - service := NewRateLimitService(nil, nil, tt.cfg, nil, nil) - result := service.oauth401Cooldown() - require.Equal(t, tt.expected, result) - }) - } -} - -// TestRateLimitService_OAuth401Cooldown_NilConfig 测试 cfg 为 nil 的情况 -func TestRateLimitService_OAuth401Cooldown_NilConfig(t *testing.T) { - service := &RateLimitService{cfg: nil} - result := service.oauth401Cooldown() - require.Equal(t, 5*time.Minute, result) -} - -// TestRateLimitService_HandleOAuth401_WithCustomCooldown 测试自定义 cooldown 配置 -func TestRateLimitService_HandleOAuth401_WithCustomCooldown(t *testing.T) { - repo := &rateLimitAccountRepoStub{} - cfg := &config.Config{ - RateLimit: config.RateLimitConfig{ - OAuth401CooldownMinutes: 15, - }, - } - service := NewRateLimitService(repo, nil, cfg, nil, nil) - account := &Account{ - ID: 204, - Platform: PlatformAntigravity, - Type: AccountTypeOAuth, - } - - start := time.Now() - result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error") - - require.True(t, result) - require.WithinDuration(t, start.Add(15*time.Minute), repo.tempUntil, 10*time.Second) -} - -// TestRateLimitService_HandleOAuth401_EmptyUpstreamMsg 测试 upstreamMsg 为空的情况 -func TestRateLimitService_HandleOAuth401_EmptyUpstreamMsg(t *testing.T) { - repo := &rateLimitAccountRepoStub{} - service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) - account := &Account{ - ID: 205, - Platform: PlatformGemini, - Type: AccountTypeOAuth, - } - - result := service.handleOAuth401TempUnschedulable(context.Background(), account, "") - - require.True(t, result) - require.Contains(t, repo.tempReason, "Authentication failed (401)") -} diff --git a/config.yaml b/config.yaml index bd399874..424ce9eb 100644 --- a/config.yaml +++ b/config.yaml @@ -387,9 +387,6 @@ rate_limit: # Cooldown time (in minutes) when upstream returns 529 (overloaded) # 上游返回 529(过载)时的冷却时间(分钟) overload_cooldown_minutes: 10 - # Cooldown time (in minutes) for OAuth 401 temporary unschedulable - # OAuth 401 临时不可调度冷却时间(分钟) - oauth_401_cooldown_minutes: 5 # ============================================================================= # Pricing Data Source (Optional) diff --git a/deploy/.env.example b/deploy/.env.example index 3af969ef..f21a3c62 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -76,9 +76,6 @@ JWT_EXPIRE_HOUR=24 # Cooldown time (in minutes) when upstream returns 529 (overloaded) # 上游返回 529(过载)时的冷却时间(分钟) RATE_LIMIT_OVERLOAD_COOLDOWN_MINUTES=10 -# Cooldown time (in minutes) for OAuth 401 temporary unschedulable -# OAuth 401 临时不可调度冷却时间(分钟) -RATE_LIMIT_OAUTH_401_COOLDOWN_MINUTES=5 # ----------------------------------------------------------------------------- # Gateway Scheduling (Optional) diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index fa8a30c7..ce2439f4 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -429,9 +429,6 @@ rate_limit: # Cooldown time (in minutes) when upstream returns 529 (overloaded) # 上游返回 529(过载)时的冷却时间(分钟) overload_cooldown_minutes: 10 - # Cooldown time (in minutes) for OAuth 401 temporary unschedulable - # OAuth 401 临时不可调度冷却时间(分钟) - oauth_401_cooldown_minutes: 5 # ============================================================================= # Pricing Data Source (Optional)