fix(认证): OAuth 401 直接标记错误状态
- OAuth 401 清理缓存并设置错误状态 - 移除 oauth_401_cooldown_minutes 配置及示例 - 更新 401 相关单测 破坏性变更: OAuth 401 不再临时不可调度,需手动恢复
This commit is contained in:
@@ -435,8 +435,7 @@ type DefaultConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type RateLimitConfig struct {
|
type RateLimitConfig struct {
|
||||||
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
|
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
|
||||||
OAuth401CooldownMinutes int `mapstructure:"oauth_401_cooldown_minutes"` // OAuth 401 临时不可调度冷却时间(分钟)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeyAuthCacheConfig API Key 认证缓存配置
|
// APIKeyAuthCacheConfig API Key 认证缓存配置
|
||||||
@@ -710,7 +709,6 @@ func setDefaults() {
|
|||||||
|
|
||||||
// RateLimit
|
// RateLimit
|
||||||
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
|
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
|
||||||
viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 5)
|
|
||||||
|
|
||||||
// Pricing - 从 price-mirror 分支同步,该分支维护了 sha256 哈希文件用于增量更新检查
|
// Pricing - 从 price-mirror 分支同步,该分支维护了 sha256 哈希文件用于增量更新检查
|
||||||
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json")
|
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json")
|
||||||
|
|||||||
@@ -73,10 +73,8 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
isOAuth401 := statusCode == 401 && account.Type == AccountTypeOAuth &&
|
|
||||||
(account.Platform == PlatformAntigravity || account.Platform == PlatformGemini)
|
|
||||||
tempMatched := false
|
tempMatched := false
|
||||||
if !isOAuth401 || account.IsTempUnschedulableEnabled() {
|
if statusCode != 401 {
|
||||||
tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
|
tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
|
||||||
}
|
}
|
||||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
|
||||||
@@ -87,18 +85,13 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
|||||||
|
|
||||||
switch statusCode {
|
switch statusCode {
|
||||||
case 401:
|
case 401:
|
||||||
if isOAuth401 {
|
if account.Type == AccountTypeOAuth &&
|
||||||
if tempMatched {
|
(account.Platform == PlatformAntigravity || account.Platform == PlatformGemini) {
|
||||||
if s.tokenCacheInvalidator != nil {
|
if s.tokenCacheInvalidator != nil {
|
||||||
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||||
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
|
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
shouldDisable = true
|
|
||||||
} else {
|
|
||||||
shouldDisable = s.handleOAuth401TempUnschedulable(ctx, account, upstreamMsg)
|
|
||||||
}
|
}
|
||||||
break
|
|
||||||
}
|
}
|
||||||
msg := "Authentication failed (401): invalid or expired credentials"
|
msg := "Authentication failed (401): invalid or expired credentials"
|
||||||
if upstreamMsg != "" {
|
if upstreamMsg != "" {
|
||||||
@@ -150,63 +143,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
|||||||
return shouldDisable
|
return shouldDisable
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *RateLimitService) handleOAuth401TempUnschedulable(ctx context.Context, account *Account, upstreamMsg string) bool {
|
|
||||||
if account == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.tokenCacheInvalidator != nil {
|
|
||||||
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
|
||||||
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
until := now.Add(s.oauth401Cooldown())
|
|
||||||
msg := "Authentication failed (401): invalid or expired credentials"
|
|
||||||
if upstreamMsg != "" {
|
|
||||||
msg = "Authentication failed (401): " + upstreamMsg
|
|
||||||
}
|
|
||||||
|
|
||||||
state := &TempUnschedState{
|
|
||||||
UntilUnix: until.Unix(),
|
|
||||||
TriggeredAtUnix: now.Unix(),
|
|
||||||
StatusCode: 401,
|
|
||||||
MatchedKeyword: "oauth_401",
|
|
||||||
RuleIndex: -1, // -1 表示非规则触发,而是 OAuth 401 特殊处理
|
|
||||||
ErrorMessage: msg,
|
|
||||||
}
|
|
||||||
|
|
||||||
reason := ""
|
|
||||||
if raw, err := json.Marshal(state); err == nil {
|
|
||||||
reason = string(raw)
|
|
||||||
}
|
|
||||||
if reason == "" {
|
|
||||||
reason = msg
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
|
||||||
slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.tempUnschedCache != nil {
|
|
||||||
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
|
|
||||||
slog.Warn("oauth_401_set_temp_unsched_cache_failed", "account_id", account.ID, "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("oauth_401_temp_unschedulable", "account_id", account.ID, "until", until)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *RateLimitService) oauth401Cooldown() time.Duration {
|
|
||||||
if s != nil && s.cfg != nil && s.cfg.RateLimit.OAuth401CooldownMinutes > 0 {
|
|
||||||
return time.Duration(s.cfg.RateLimit.OAuth401CooldownMinutes) * time.Minute
|
|
||||||
}
|
|
||||||
return 5 * time.Minute
|
|
||||||
}
|
|
||||||
|
|
||||||
// PreCheckUsage proactively checks local quota before dispatching a request.
|
// PreCheckUsage proactively checks local quota before dispatching a request.
|
||||||
// Returns false when the account should be skipped.
|
// Returns false when the account should be skipped.
|
||||||
func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) {
|
func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) {
|
||||||
|
|||||||
@@ -15,21 +15,19 @@ import (
|
|||||||
|
|
||||||
type rateLimitAccountRepoStub struct {
|
type rateLimitAccountRepoStub struct {
|
||||||
mockAccountRepoForGemini
|
mockAccountRepoForGemini
|
||||||
tempCalls int
|
|
||||||
tempUntil time.Time
|
|
||||||
tempReason string
|
|
||||||
setErrorCalls int
|
setErrorCalls int
|
||||||
}
|
tempCalls int
|
||||||
|
lastErrorMsg string
|
||||||
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
|
||||||
r.tempCalls++
|
|
||||||
r.tempUntil = until
|
|
||||||
r.tempReason = reason
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||||
r.setErrorCalls++
|
r.setErrorCalls++
|
||||||
|
r.lastErrorMsg = errorMsg
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||||
|
r.tempCalls++
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,7 +41,7 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc
|
|||||||
return r.err
|
return r.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRateLimitService_HandleUpstreamError_OAuth401TempUnschedulable(t *testing.T) {
|
func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
platform string
|
platform string
|
||||||
@@ -62,17 +60,26 @@ func TestRateLimitService_HandleUpstreamError_OAuth401TempUnschedulable(t *testi
|
|||||||
ID: 100,
|
ID: 100,
|
||||||
Platform: tt.platform,
|
Platform: tt.platform,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"temp_unschedulable_enabled": true,
|
||||||
|
"temp_unschedulable_rules": []any{
|
||||||
|
map[string]any{
|
||||||
|
"error_code": 401,
|
||||||
|
"keywords": []any{"unauthorized"},
|
||||||
|
"duration_minutes": 30,
|
||||||
|
"description": "custom rule",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||||
|
|
||||||
require.True(t, shouldDisable)
|
require.True(t, shouldDisable)
|
||||||
require.Equal(t, 1, repo.tempCalls)
|
require.Equal(t, 1, repo.setErrorCalls)
|
||||||
require.Equal(t, 0, repo.setErrorCalls)
|
require.Equal(t, 0, repo.tempCalls)
|
||||||
|
require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)")
|
||||||
require.Len(t, invalidator.accounts, 1)
|
require.Len(t, invalidator.accounts, 1)
|
||||||
require.WithinDuration(t, start.Add(5*time.Minute), repo.tempUntil, 10*time.Second)
|
|
||||||
require.NotEmpty(t, repo.tempReason)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -91,43 +98,10 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
|
|||||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||||
|
|
||||||
require.True(t, shouldDisable)
|
require.True(t, shouldDisable)
|
||||||
require.Equal(t, 1, repo.tempCalls)
|
require.Equal(t, 1, repo.setErrorCalls)
|
||||||
require.Equal(t, 0, repo.setErrorCalls)
|
|
||||||
require.Len(t, invalidator.accounts, 1)
|
require.Len(t, invalidator.accounts, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRateLimitService_HandleUpstreamError_OAuth401CustomRule(t *testing.T) {
|
|
||||||
repo := &rateLimitAccountRepoStub{}
|
|
||||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
|
||||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
||||||
service.SetTokenCacheInvalidator(invalidator)
|
|
||||||
account := &Account{
|
|
||||||
ID: 103,
|
|
||||||
Platform: PlatformGemini,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Credentials: map[string]any{
|
|
||||||
"temp_unschedulable_enabled": true,
|
|
||||||
"temp_unschedulable_rules": []any{
|
|
||||||
map[string]any{
|
|
||||||
"error_code": 401,
|
|
||||||
"keywords": []any{"unauthorized"},
|
|
||||||
"duration_minutes": 30,
|
|
||||||
"description": "custom rule",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
|
||||||
|
|
||||||
require.True(t, shouldDisable)
|
|
||||||
require.Equal(t, 1, repo.tempCalls)
|
|
||||||
require.Equal(t, 0, repo.setErrorCalls)
|
|
||||||
require.Len(t, invalidator.accounts, 1)
|
|
||||||
require.WithinDuration(t, start.Add(30*time.Minute), repo.tempUntil, 10*time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
|
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
|
||||||
repo := &rateLimitAccountRepoStub{}
|
repo := &rateLimitAccountRepoStub{}
|
||||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||||
@@ -142,212 +116,6 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
|
|||||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||||
|
|
||||||
require.True(t, shouldDisable)
|
require.True(t, shouldDisable)
|
||||||
require.Equal(t, 0, repo.tempCalls)
|
|
||||||
require.Equal(t, 1, repo.setErrorCalls)
|
require.Equal(t, 1, repo.setErrorCalls)
|
||||||
require.Empty(t, invalidator.accounts)
|
require.Empty(t, invalidator.accounts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRateLimitService_HandleOAuth401_NilAccount 测试 account 为 nil 的情况
|
|
||||||
func TestRateLimitService_HandleOAuth401_NilAccount(t *testing.T) {
|
|
||||||
repo := &rateLimitAccountRepoStub{}
|
|
||||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
||||||
|
|
||||||
result := service.handleOAuth401TempUnschedulable(context.Background(), nil, "error")
|
|
||||||
|
|
||||||
require.False(t, result)
|
|
||||||
require.Equal(t, 0, repo.tempCalls)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRateLimitService_HandleOAuth401_NilInvalidator 测试 tokenCacheInvalidator 为 nil 的情况
|
|
||||||
func TestRateLimitService_HandleOAuth401_NilInvalidator(t *testing.T) {
|
|
||||||
repo := &rateLimitAccountRepoStub{}
|
|
||||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
||||||
// 不设置 tokenCacheInvalidator
|
|
||||||
account := &Account{
|
|
||||||
ID: 200,
|
|
||||||
Platform: PlatformGemini,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
}
|
|
||||||
|
|
||||||
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
|
|
||||||
|
|
||||||
require.True(t, result)
|
|
||||||
require.Equal(t, 1, repo.tempCalls)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRateLimitService_HandleOAuth401_SetTempUnschedulableFailed 测试 SetTempUnschedulable 失败的情况
|
|
||||||
func TestRateLimitService_HandleOAuth401_SetTempUnschedulableFailed(t *testing.T) {
|
|
||||||
repo := &rateLimitAccountRepoStubWithError{
|
|
||||||
setTempErr: errors.New("db error"),
|
|
||||||
}
|
|
||||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
|
||||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
||||||
service.SetTokenCacheInvalidator(invalidator)
|
|
||||||
account := &Account{
|
|
||||||
ID: 201,
|
|
||||||
Platform: PlatformGemini,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
}
|
|
||||||
|
|
||||||
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
|
|
||||||
|
|
||||||
require.False(t, result) // 失败应返回 false
|
|
||||||
require.Len(t, invalidator.accounts, 1) // 但 invalidator 仍然被调用
|
|
||||||
}
|
|
||||||
|
|
||||||
// rateLimitAccountRepoStubWithError 支持返回错误的 stub
|
|
||||||
type rateLimitAccountRepoStubWithError struct {
|
|
||||||
mockAccountRepoForGemini
|
|
||||||
setTempErr error
|
|
||||||
setErrorCalls int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *rateLimitAccountRepoStubWithError) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
|
||||||
return r.setTempErr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *rateLimitAccountRepoStubWithError) SetError(ctx context.Context, id int64, errorMsg string) error {
|
|
||||||
r.setErrorCalls++
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRateLimitService_HandleOAuth401_WithTempUnschedCache 测试 tempUnschedCache 存在的情况
|
|
||||||
func TestRateLimitService_HandleOAuth401_WithTempUnschedCache(t *testing.T) {
|
|
||||||
repo := &rateLimitAccountRepoStub{}
|
|
||||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
|
||||||
tempCache := &tempUnschedCacheStub{}
|
|
||||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, tempCache)
|
|
||||||
service.SetTokenCacheInvalidator(invalidator)
|
|
||||||
account := &Account{
|
|
||||||
ID: 202,
|
|
||||||
Platform: PlatformGemini,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
}
|
|
||||||
|
|
||||||
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
|
|
||||||
|
|
||||||
require.True(t, result)
|
|
||||||
require.Equal(t, 1, repo.tempCalls)
|
|
||||||
require.Equal(t, 1, tempCache.setCalls)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRateLimitService_HandleOAuth401_TempUnschedCacheError 测试 tempUnschedCache 设置失败的情况
|
|
||||||
func TestRateLimitService_HandleOAuth401_TempUnschedCacheError(t *testing.T) {
|
|
||||||
repo := &rateLimitAccountRepoStub{}
|
|
||||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
|
||||||
tempCache := &tempUnschedCacheStub{setErr: errors.New("cache error")}
|
|
||||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, tempCache)
|
|
||||||
service.SetTokenCacheInvalidator(invalidator)
|
|
||||||
account := &Account{
|
|
||||||
ID: 203,
|
|
||||||
Platform: PlatformGemini,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
}
|
|
||||||
|
|
||||||
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
|
|
||||||
|
|
||||||
require.True(t, result) // 缓存错误不影响主流程
|
|
||||||
require.Equal(t, 1, repo.tempCalls)
|
|
||||||
}
|
|
||||||
|
|
||||||
// tempUnschedCacheStub 用于测试的 TempUnschedCache stub
|
|
||||||
type tempUnschedCacheStub struct {
|
|
||||||
setCalls int
|
|
||||||
setErr error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *tempUnschedCacheStub) GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *tempUnschedCacheStub) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error {
|
|
||||||
c.setCalls++
|
|
||||||
return c.setErr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *tempUnschedCacheStub) DeleteTempUnsched(ctx context.Context, accountID int64) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRateLimitService_OAuth401Cooldown 测试 oauth401Cooldown 函数
|
|
||||||
func TestRateLimitService_OAuth401Cooldown(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
cfg *config.Config
|
|
||||||
expected time.Duration
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "default_when_config_zero",
|
|
||||||
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 0}},
|
|
||||||
expected: 5 * time.Minute,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "custom_cooldown_10_minutes",
|
|
||||||
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 10}},
|
|
||||||
expected: 10 * time.Minute,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "custom_cooldown_1_minute",
|
|
||||||
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 1}},
|
|
||||||
expected: 1 * time.Minute,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative_value_uses_default",
|
|
||||||
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: -5}},
|
|
||||||
expected: 5 * time.Minute,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
service := NewRateLimitService(nil, nil, tt.cfg, nil, nil)
|
|
||||||
result := service.oauth401Cooldown()
|
|
||||||
require.Equal(t, tt.expected, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRateLimitService_OAuth401Cooldown_NilConfig 测试 cfg 为 nil 的情况
|
|
||||||
func TestRateLimitService_OAuth401Cooldown_NilConfig(t *testing.T) {
|
|
||||||
service := &RateLimitService{cfg: nil}
|
|
||||||
result := service.oauth401Cooldown()
|
|
||||||
require.Equal(t, 5*time.Minute, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRateLimitService_HandleOAuth401_WithCustomCooldown 测试自定义 cooldown 配置
|
|
||||||
func TestRateLimitService_HandleOAuth401_WithCustomCooldown(t *testing.T) {
|
|
||||||
repo := &rateLimitAccountRepoStub{}
|
|
||||||
cfg := &config.Config{
|
|
||||||
RateLimit: config.RateLimitConfig{
|
|
||||||
OAuth401CooldownMinutes: 15,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
service := NewRateLimitService(repo, nil, cfg, nil, nil)
|
|
||||||
account := &Account{
|
|
||||||
ID: 204,
|
|
||||||
Platform: PlatformAntigravity,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
}
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
|
|
||||||
|
|
||||||
require.True(t, result)
|
|
||||||
require.WithinDuration(t, start.Add(15*time.Minute), repo.tempUntil, 10*time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRateLimitService_HandleOAuth401_EmptyUpstreamMsg 测试 upstreamMsg 为空的情况
|
|
||||||
func TestRateLimitService_HandleOAuth401_EmptyUpstreamMsg(t *testing.T) {
|
|
||||||
repo := &rateLimitAccountRepoStub{}
|
|
||||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
||||||
account := &Account{
|
|
||||||
ID: 205,
|
|
||||||
Platform: PlatformGemini,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
}
|
|
||||||
|
|
||||||
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "")
|
|
||||||
|
|
||||||
require.True(t, result)
|
|
||||||
require.Contains(t, repo.tempReason, "Authentication failed (401)")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -387,9 +387,6 @@ rate_limit:
|
|||||||
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
|
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
|
||||||
# 上游返回 529(过载)时的冷却时间(分钟)
|
# 上游返回 529(过载)时的冷却时间(分钟)
|
||||||
overload_cooldown_minutes: 10
|
overload_cooldown_minutes: 10
|
||||||
# Cooldown time (in minutes) for OAuth 401 temporary unschedulable
|
|
||||||
# OAuth 401 临时不可调度冷却时间(分钟)
|
|
||||||
oauth_401_cooldown_minutes: 5
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Pricing Data Source (Optional)
|
# Pricing Data Source (Optional)
|
||||||
|
|||||||
@@ -76,9 +76,6 @@ JWT_EXPIRE_HOUR=24
|
|||||||
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
|
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
|
||||||
# 上游返回 529(过载)时的冷却时间(分钟)
|
# 上游返回 529(过载)时的冷却时间(分钟)
|
||||||
RATE_LIMIT_OVERLOAD_COOLDOWN_MINUTES=10
|
RATE_LIMIT_OVERLOAD_COOLDOWN_MINUTES=10
|
||||||
# Cooldown time (in minutes) for OAuth 401 temporary unschedulable
|
|
||||||
# OAuth 401 临时不可调度冷却时间(分钟)
|
|
||||||
RATE_LIMIT_OAUTH_401_COOLDOWN_MINUTES=5
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Gateway Scheduling (Optional)
|
# Gateway Scheduling (Optional)
|
||||||
|
|||||||
@@ -429,9 +429,6 @@ rate_limit:
|
|||||||
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
|
# Cooldown time (in minutes) when upstream returns 529 (overloaded)
|
||||||
# 上游返回 529(过载)时的冷却时间(分钟)
|
# 上游返回 529(过载)时的冷却时间(分钟)
|
||||||
overload_cooldown_minutes: 10
|
overload_cooldown_minutes: 10
|
||||||
# Cooldown time (in minutes) for OAuth 401 temporary unschedulable
|
|
||||||
# OAuth 401 临时不可调度冷却时间(分钟)
|
|
||||||
oauth_401_cooldown_minutes: 5
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Pricing Data Source (Optional)
|
# Pricing Data Source (Optional)
|
||||||
|
|||||||
Reference in New Issue
Block a user