fix(认证): OAuth 401 直接标记错误状态

- OAuth 401 清理缓存并设置错误状态

- 移除 oauth_401_cooldown_minutes 配置及示例

- 更新 401 相关单测

破坏性变更: OAuth 401 不再临时不可调度,需手动恢复
This commit is contained in:
yangjianbo
2026-01-15 15:06:34 +08:00
parent 52ad7c6e9c
commit a458e684bc
6 changed files with 31 additions and 338 deletions

View File

@@ -15,21 +15,19 @@ import (
type rateLimitAccountRepoStub struct {
mockAccountRepoForGemini
tempCalls int
tempUntil time.Time
tempReason string
setErrorCalls int
}
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.tempCalls++
r.tempUntil = until
r.tempReason = reason
return nil
tempCalls int
lastErrorMsg string
}
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
r.lastErrorMsg = errorMsg
return nil
}
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.tempCalls++
return nil
}
@@ -43,7 +41,7 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc
return r.err
}
func TestRateLimitService_HandleUpstreamError_OAuth401TempUnschedulable(t *testing.T) {
func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) {
tests := []struct {
name string
platform string
@@ -62,17 +60,26 @@ func TestRateLimitService_HandleUpstreamError_OAuth401TempUnschedulable(t *testi
ID: 100,
Platform: tt.platform,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": 401,
"keywords": []any{"unauthorized"},
"duration_minutes": 30,
"description": "custom rule",
},
},
},
}
start := time.Now()
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.tempCalls)
require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, 0, repo.tempCalls)
require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)")
require.Len(t, invalidator.accounts, 1)
require.WithinDuration(t, start.Add(5*time.Minute), repo.tempUntil, 10*time.Second)
require.NotEmpty(t, repo.tempReason)
})
}
}
@@ -91,43 +98,10 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.tempCalls)
require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.setErrorCalls)
require.Len(t, invalidator.accounts, 1)
}
func TestRateLimitService_HandleUpstreamError_OAuth401CustomRule(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 103,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": 401,
"keywords": []any{"unauthorized"},
"duration_minutes": 30,
"description": "custom rule",
},
},
},
}
start := time.Now()
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.tempCalls)
require.Equal(t, 0, repo.setErrorCalls)
require.Len(t, invalidator.accounts, 1)
require.WithinDuration(t, start.Add(30*time.Minute), repo.tempUntil, 10*time.Second)
}
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
@@ -142,212 +116,6 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 0, repo.tempCalls)
require.Equal(t, 1, repo.setErrorCalls)
require.Empty(t, invalidator.accounts)
}
// TestRateLimitService_HandleOAuth401_NilAccount 测试 account 为 nil 的情况
func TestRateLimitService_HandleOAuth401_NilAccount(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
result := service.handleOAuth401TempUnschedulable(context.Background(), nil, "error")
require.False(t, result)
require.Equal(t, 0, repo.tempCalls)
}
// TestRateLimitService_HandleOAuth401_NilInvalidator 测试 tokenCacheInvalidator 为 nil 的情况
func TestRateLimitService_HandleOAuth401_NilInvalidator(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
// 不设置 tokenCacheInvalidator
account := &Account{
ID: 200,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
require.True(t, result)
require.Equal(t, 1, repo.tempCalls)
}
// TestRateLimitService_HandleOAuth401_SetTempUnschedulableFailed 测试 SetTempUnschedulable 失败的情况
func TestRateLimitService_HandleOAuth401_SetTempUnschedulableFailed(t *testing.T) {
repo := &rateLimitAccountRepoStubWithError{
setTempErr: errors.New("db error"),
}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 201,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
require.False(t, result) // 失败应返回 false
require.Len(t, invalidator.accounts, 1) // 但 invalidator 仍然被调用
}
// rateLimitAccountRepoStubWithError 支持返回错误的 stub
type rateLimitAccountRepoStubWithError struct {
mockAccountRepoForGemini
setTempErr error
setErrorCalls int
}
func (r *rateLimitAccountRepoStubWithError) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return r.setTempErr
}
func (r *rateLimitAccountRepoStubWithError) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
return nil
}
// TestRateLimitService_HandleOAuth401_WithTempUnschedCache 测试 tempUnschedCache 存在的情况
func TestRateLimitService_HandleOAuth401_WithTempUnschedCache(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
tempCache := &tempUnschedCacheStub{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, tempCache)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 202,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
require.True(t, result)
require.Equal(t, 1, repo.tempCalls)
require.Equal(t, 1, tempCache.setCalls)
}
// TestRateLimitService_HandleOAuth401_TempUnschedCacheError 测试 tempUnschedCache 设置失败的情况
func TestRateLimitService_HandleOAuth401_TempUnschedCacheError(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
tempCache := &tempUnschedCacheStub{setErr: errors.New("cache error")}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, tempCache)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 203,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
require.True(t, result) // 缓存错误不影响主流程
require.Equal(t, 1, repo.tempCalls)
}
// tempUnschedCacheStub 用于测试的 TempUnschedCache stub
type tempUnschedCacheStub struct {
setCalls int
setErr error
}
func (c *tempUnschedCacheStub) GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error) {
return nil, nil
}
func (c *tempUnschedCacheStub) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error {
c.setCalls++
return c.setErr
}
func (c *tempUnschedCacheStub) DeleteTempUnsched(ctx context.Context, accountID int64) error {
return nil
}
// TestRateLimitService_OAuth401Cooldown 测试 oauth401Cooldown 函数
func TestRateLimitService_OAuth401Cooldown(t *testing.T) {
tests := []struct {
name string
cfg *config.Config
expected time.Duration
}{
{
name: "default_when_config_zero",
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 0}},
expected: 5 * time.Minute,
},
{
name: "custom_cooldown_10_minutes",
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 10}},
expected: 10 * time.Minute,
},
{
name: "custom_cooldown_1_minute",
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 1}},
expected: 1 * time.Minute,
},
{
name: "negative_value_uses_default",
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: -5}},
expected: 5 * time.Minute,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
service := NewRateLimitService(nil, nil, tt.cfg, nil, nil)
result := service.oauth401Cooldown()
require.Equal(t, tt.expected, result)
})
}
}
// TestRateLimitService_OAuth401Cooldown_NilConfig 测试 cfg 为 nil 的情况
func TestRateLimitService_OAuth401Cooldown_NilConfig(t *testing.T) {
service := &RateLimitService{cfg: nil}
result := service.oauth401Cooldown()
require.Equal(t, 5*time.Minute, result)
}
// TestRateLimitService_HandleOAuth401_WithCustomCooldown 测试自定义 cooldown 配置
func TestRateLimitService_HandleOAuth401_WithCustomCooldown(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
cfg := &config.Config{
RateLimit: config.RateLimitConfig{
OAuth401CooldownMinutes: 15,
},
}
service := NewRateLimitService(repo, nil, cfg, nil, nil)
account := &Account{
ID: 204,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
start := time.Now()
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
require.True(t, result)
require.WithinDuration(t, start.Add(15*time.Minute), repo.tempUntil, 10*time.Second)
}
// TestRateLimitService_HandleOAuth401_EmptyUpstreamMsg 测试 upstreamMsg 为空的情况
func TestRateLimitService_HandleOAuth401_EmptyUpstreamMsg(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 205,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "")
require.True(t, result)
require.Contains(t, repo.tempReason, "Authentication failed (401)")
}