Files
xinghuoapi/backend/internal/service/token_refresh_service_test.go
zqq61 ec6bcfeb83 fix: OAuth 401 不再永久锁死账号,改用临时不可调度实现自动恢复
OAuth 账号收到 401 时,原逻辑同时设置 expires_at=now() 和 SetError(),
但刷新服务只查询 status=active 的账号,导致 error 状态的账号永远无法
被刷新服务拾取,expires_at=now() 实际上是死代码。

修复:
- OAuth 401 使用 SetTempUnschedulable 替代 SetError,保持 status=active
- 新增 oauth_401_cooldown_minutes 配置项(默认 10 分钟)
- 刷新成功后同步清除 DB 和 Redis 中的临时不可调度状态
- 不可重试错误检查(invalid_grant 等)从 Antigravity 推广到所有平台
- 可重试错误耗尽后不再标记 error,下个刷新周期继续重试

恢复流程:
OAuth 401 → temp_unschedulable + expires_at=now → 刷新服务拾取
  → 成功: 清除 temp_unschedulable → 自动恢复
  → invalid_grant: SetError → 永久禁用
  → 网络错误: 仅记日志 → 下周期重试
2026-03-02 22:54:38 +08:00

456 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type tokenRefreshAccountRepo struct {
mockAccountRepoForGemini
updateCalls int
setErrorCalls int
clearTempCalls int
lastAccount *Account
updateErr error
}
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
r.updateCalls++
r.lastAccount = account
return r.updateErr
}
func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
return nil
}
func (r *tokenRefreshAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error {
r.clearTempCalls++
return nil
}
type tokenCacheInvalidatorStub struct {
calls int
err error
}
func (s *tokenCacheInvalidatorStub) InvalidateToken(ctx context.Context, account *Account) error {
s.calls++
return s.err
}
type tempUnschedCacheStub struct {
deleteCalls int
}
func (s *tempUnschedCacheStub) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error {
return nil
}
func (s *tempUnschedCacheStub) GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error) {
return nil, nil
}
func (s *tempUnschedCacheStub) DeleteTempUnsched(ctx context.Context, accountID int64) error {
s.deleteCalls++
return nil
}
type tokenRefresherStub struct {
credentials map[string]any
err error
}
func (r *tokenRefresherStub) CanRefresh(account *Account) bool {
return true
}
func (r *tokenRefresherStub) NeedsRefresh(account *Account, refreshWindowDuration time.Duration) bool {
return true
}
func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
if r.err != nil {
return nil, r.err
}
return r.credentials, nil
}
func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 5,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "new-token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, invalidator.calls)
require.Equal(t, "new-token", account.GetCredential("access_token"))
}
func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{err: errors.New("invalidate failed")}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 6,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, invalidator.calls)
}
func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
account := &Account{
ID: 7,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
}
// TestTokenRefreshService_RefreshWithRetry_Antigravity 测试 Antigravity 平台的缓存失效
func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 8,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "ag-token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, invalidator.calls) // Antigravity 也应触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount 测试非 OAuth 账号不触发缓存失效
func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 9,
Platform: PlatformGemini,
Type: AccountTypeAPIKey, // 非 OAuth
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试所有 OAuth 平台都触发缓存失效
func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 10,
Platform: PlatformOpenAI, // OpenAI OAuth 账户
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 11,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Contains(t, err.Error(), "failed to save credentials")
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 0, invalidator.calls) // 更新失败时不应触发缓存失效
}
// TestTokenRefreshService_RefreshWithRetry_RefreshFailed 测试可重试错误耗尽不标记 error
func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 2,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 12,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
err: errors.New("refresh failed"),
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效
require.Equal(t, 0, repo.setErrorCalls) // 可重试错误耗尽不标记 error下个周期继续重试
}
// TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed 测试 Antigravity 刷新失败不设置错误状态
func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 13,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
err: errors.New("network error"), // 可重试错误
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls)
require.Equal(t, 0, invalidator.calls)
require.Equal(t, 0, repo.setErrorCalls) // Antigravity 可重试错误不设置错误状态
}
// TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError 测试 Antigravity 不可重试错误
func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 3,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 14,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
err: errors.New("invalid_grant: token revoked"), // 不可重试错误
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Equal(t, 0, repo.updateCalls)
require.Equal(t, 0, invalidator.calls)
require.Equal(t, 1, repo.setErrorCalls) // 不可重试错误应设置错误状态
}
// TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable 测试刷新成功后清除临时不可调度DB + Redis
func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
tempCache := &tempUnschedCacheStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 1,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, tempCache)
until := time.Now().Add(10 * time.Minute)
account := &Account{
ID: 15,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
TempUnschedulableUntil: &until,
}
refresher := &tokenRefresherStub{
credentials: map[string]any{
"access_token": "new-token",
},
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.NoError(t, err)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, 1, repo.clearTempCalls) // DB 清除
require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除
}
// TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms 测试所有平台不可重试错误都 SetError
func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *testing.T) {
tests := []struct {
name string
platform string
}{
{name: "gemini", platform: PlatformGemini},
{name: "anthropic", platform: PlatformAnthropic},
{name: "openai", platform: PlatformOpenAI},
{name: "antigravity", platform: PlatformAntigravity},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
cfg := &config.Config{
TokenRefresh: config.TokenRefreshConfig{
MaxRetries: 3,
RetryBackoffSeconds: 0,
},
}
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, nil, cfg, nil)
account := &Account{
ID: 16,
Platform: tt.platform,
Type: AccountTypeOAuth,
}
refresher := &tokenRefresherStub{
err: errors.New("invalid_grant: token revoked"),
}
err := service.refreshWithRetry(context.Background(), account, refresher)
require.Error(t, err)
require.Equal(t, 1, repo.setErrorCalls) // 所有平台不可重试错误都应 SetError
})
}
}
// TestIsNonRetryableRefreshError 测试不可重试错误判断
func TestIsNonRetryableRefreshError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{name: "nil_error", err: nil, expected: false},
{name: "network_error", err: errors.New("network timeout"), expected: false},
{name: "invalid_grant", err: errors.New("invalid_grant"), expected: true},
{name: "invalid_client", err: errors.New("invalid_client"), expected: true},
{name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true},
{name: "access_denied", err: errors.New("access_denied"), expected: true},
{name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true},
{name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isNonRetryableRefreshError(tt.err)
require.Equal(t, tt.expected, result)
})
}
}