Merge pull request #723 from zqq-nuli/fix/oauth-401-temp-unschedulable

fix: OAuth 401 不再永久锁死账号,改用临时不可调度实现自动恢复
This commit is contained in:
Wesley Liddick
2026-03-03 16:33:07 +08:00
committed by GitHub
7 changed files with 175 additions and 49 deletions

View File

@@ -222,7 +222,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService)

View File

@@ -873,6 +873,7 @@ type DefaultConfig struct {
type RateLimitConfig struct {
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
OAuth401CooldownMinutes int `mapstructure:"oauth_401_cooldown_minutes"` // OAuth 401临时不可调度冷却(分钟)
}
// APIKeyAuthCacheConfig API Key 认证缓存配置
@@ -1260,6 +1261,7 @@ func setDefaults() {
// RateLimit
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10)
// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit避免分支漂移
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json")

View File

@@ -146,13 +146,29 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
} else {
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
}
// 3. 临时不可调度,替代 SetError保持 status=active 让刷新服务能拾取)
msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" {
msg = "OAuth 401: " + upstreamMsg
}
cooldownMinutes := s.cfg.RateLimit.OAuth401CooldownMinutes
if cooldownMinutes <= 0 {
cooldownMinutes = 10
}
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, msg); err != nil {
slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
}
shouldDisable = true
} else {
// 非 OAuth 账号APIKey保持原有 SetError 行为
msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
shouldDisable = true
}
case 402:
// 支付要求:余额不足或计费问题,停止调度
msg := "Payment required (402): insufficient balance or billing issue"

View File

@@ -41,7 +41,7 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc
return r.err
}
func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) {
func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
tests := []struct {
name string
platform string
@@ -76,9 +76,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) {
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, 0, repo.tempCalls)
require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)")
require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls)
require.Len(t, invalidator.accounts, 1)
})
}
@@ -98,7 +97,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls)
require.Len(t, invalidator.accounts, 1)
}

View File

@@ -19,6 +19,7 @@ type TokenRefreshService struct {
cfg *config.TokenRefreshConfig
cacheInvalidator TokenCacheInvalidator
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存
stopCh chan struct{}
wg sync.WaitGroup
@@ -34,12 +35,14 @@ func NewTokenRefreshService(
cacheInvalidator TokenCacheInvalidator,
schedulerCache SchedulerCache,
cfg *config.Config,
tempUnschedCache TempUnschedCache,
) *TokenRefreshService {
s := &TokenRefreshService{
accountRepo: accountRepo,
cfg: &cfg.TokenRefresh,
cacheInvalidator: cacheInvalidator,
schedulerCache: schedulerCache,
tempUnschedCache: tempUnschedCache,
stopCh: make(chan struct{}),
}
@@ -231,6 +234,26 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
}
}
// 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景)
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
if clearErr := s.accountRepo.ClearTempUnschedulable(ctx, account.ID); clearErr != nil {
slog.Warn("token_refresh.clear_temp_unschedulable_failed",
"account_id", account.ID,
"error", clearErr,
)
} else {
slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID)
}
// 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态
if s.tempUnschedCache != nil {
if clearErr := s.tempUnschedCache.DeleteTempUnsched(ctx, account.ID); clearErr != nil {
slog.Warn("token_refresh.clear_temp_unsched_cache_failed",
"account_id", account.ID,
"error", clearErr,
)
}
}
}
// 对所有 OAuth 账号调用缓存失效InvalidateToken 内部根据平台判断是否需要处理)
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
@@ -257,8 +280,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
return nil
}
// Antigravity 账户:不可重试错误直接标记 error 状态并返回
if account.Platform == PlatformAntigravity && isNonRetryableRefreshError(err) {
// 不可重试错误invalid_grant/invalid_client 等)直接标记 error 状态并返回
if isNonRetryableRefreshError(err) {
errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err)
if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil {
slog.Error("token_refresh.set_error_status_failed",
@@ -285,23 +308,13 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
}
}
// Antigravity 账户:其他错误仅记录日志,不标记 error可能是临时网络问题
// 其他平台账户:重试失败后标记 error
if account.Platform == PlatformAntigravity {
slog.Warn("token_refresh.retry_exhausted_antigravity",
// 可重试错误耗尽:仅记录日志,不标记 error可能是临时网络问题,下个周期继续重试
slog.Warn("token_refresh.retry_exhausted",
"account_id", account.ID,
"platform", account.Platform,
"max_retries", s.cfg.MaxRetries,
"error", lastErr,
)
} else {
errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", s.cfg.MaxRetries, lastErr)
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
slog.Error("token_refresh.set_error_status_failed",
"account_id", account.ID,
"error", err,
)
}
}
return lastErr
}

View File

@@ -16,6 +16,7 @@ type tokenRefreshAccountRepo struct {
mockAccountRepoForGemini
updateCalls int
setErrorCalls int
clearTempCalls int
lastAccount *Account
updateErr error
}
@@ -31,6 +32,11 @@ func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorM
return nil
}
func (r *tokenRefreshAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error {
r.clearTempCalls++
return nil
}
type tokenCacheInvalidatorStub struct {
calls int
err error
@@ -41,6 +47,23 @@ func (s *tokenCacheInvalidatorStub) InvalidateToken(ctx context.Context, account
return s.err
}
type tempUnschedCacheStub struct {
deleteCalls int
}
func (s *tempUnschedCacheStub) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error {
return nil
}
func (s *tempUnschedCacheStub) GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error) {
return nil, nil
}
func (s *tempUnschedCacheStub) DeleteTempUnsched(ctx context.Context, accountID int64) error {
s.deleteCalls++
return nil
}
type tokenRefresherStub struct {
credentials map[string]any
err error
@@ -70,7 +93,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 5,
Platform: PlatformGemini,
@@ -98,7 +121,7 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 6,
Platform: PlatformGemini,
@@ -124,7 +147,7 @@ func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
account := &Account{
ID: 7,
Platform: PlatformGemini,
@@ -151,7 +174,7 @@ func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 8,
Platform: PlatformAntigravity,
@@ -179,7 +202,7 @@ func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 9,
Platform: PlatformGemini,
@@ -207,7 +230,7 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 10,
Platform: PlatformOpenAI, // OpenAI OAuth 账户
@@ -235,7 +258,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 11,
Platform: PlatformGemini,
@@ -254,7 +277,7 @@ func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
require.Equal(t, 0, invalidator.calls) // 更新失败时不应触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_RefreshFailed 测试刷新失败的情况
// TestTokenRefreshService_RefreshWithRetry_RefreshFailed 测试可重试错误耗尽不标记 error
func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
@@ -264,7 +287,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 12,
Platform: PlatformGemini,
@@ -278,7 +301,7 @@ func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效
require.Equal(t, 1, repo.setErrorCalls) // 应设置错误状态
require.Equal(t, 0, repo.setErrorCalls) // 可重试错误耗尽不标记 error下个周期继续重试
}
// TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed 测试 Antigravity 刷新失败不设置错误状态
@@ -291,7 +314,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testin
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 13,
Platform: PlatformAntigravity,
@@ -318,7 +341,7 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg)
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 14,
Platform: PlatformAntigravity,
@@ -335,6 +358,77 @@ func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *te
require.Equal(t, 1, repo.setErrorCalls) // 不可重试错误应设置错误状态
}
// TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable 测试刷新成功后清除临时不可调度DB + Redis
func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
tempCache := &tempUnschedCacheStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, tempCache)
until := time.Now().Add(10 * time.Minute)
account := &Account{
ID: 15,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
TempUnschedulableUntil: &until,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "new-token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.clearTempCalls) // DB 清除
require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除
}
// TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms 测试所有平台不可重试错误都 SetError
func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *testing.T) {
tests := []struct {
name string
platform string
}{
{name: "gemini", platform: PlatformGemini},
{name: "anthropic", platform: PlatformAnthropic},
{name: "openai", platform: PlatformOpenAI},
{name: "antigravity", platform: PlatformAntigravity},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 3,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 16,
Platform: tt.platform,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
err: errors.New("invalid_grant: token revoked"),
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Equal(t, 1, repo.setErrorCalls) // 所有平台不可重试错误都应 SetError
})
}
}
// TestIsNonRetryableRefreshError 测试不可重试错误判断
func TestIsNonRetryableRefreshError(t *testing.T) {
tests := []struct {

View File

@@ -48,8 +48,9 @@ func ProvideTokenRefreshService(
cacheInvalidator TokenCacheInvalidator,
schedulerCache SchedulerCache,
cfg *config.Config,
tempUnschedCache TempUnschedCache,
) *TokenRefreshService {
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg)
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
svc.SetSoraAccountRepo(soraAccountRepo)
svc.Start()