fix(account): preserve runtime state during credentials-only updates
This commit is contained in:
@@ -14,19 +14,40 @@ import (
|
||||
|
||||
type tokenRefreshAccountRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
updateCalls int
|
||||
setErrorCalls int
|
||||
clearTempCalls int
|
||||
lastAccount *Account
|
||||
updateErr error
|
||||
updateCalls int
|
||||
fullUpdateCalls int
|
||||
updateCredentialsCalls int
|
||||
setErrorCalls int
|
||||
clearTempCalls int
|
||||
lastAccount *Account
|
||||
updateErr error
|
||||
}
|
||||
|
||||
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
|
||||
r.updateCalls++
|
||||
r.fullUpdateCalls++
|
||||
r.lastAccount = account
|
||||
return r.updateErr
|
||||
}
|
||||
|
||||
func (r *tokenRefreshAccountRepo) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
|
||||
r.updateCalls++
|
||||
r.updateCredentialsCalls++
|
||||
if r.updateErr != nil {
|
||||
return r.updateErr
|
||||
}
|
||||
cloned := cloneCredentials(credentials)
|
||||
if r.accountsByID != nil {
|
||||
if acc, ok := r.accountsByID[id]; ok && acc != nil {
|
||||
acc.Credentials = cloned
|
||||
r.lastAccount = acc
|
||||
return nil
|
||||
}
|
||||
}
|
||||
r.lastAccount = &Account{ID: id, Credentials: cloned}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
r.setErrorCalls++
|
||||
return nil
|
||||
@@ -112,6 +133,8 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||
require.Equal(t, 0, repo.fullUpdateCalls)
|
||||
require.Equal(t, 1, invalidator.calls)
|
||||
require.Equal(t, "new-token", account.GetCredential("access_token"))
|
||||
}
|
||||
@@ -249,9 +272,43 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
|
||||
}
|
||||
|
||||
func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
cfg := &config.Config{
|
||||
TokenRefresh: config.TokenRefreshConfig{
|
||||
MaxRetries: 1,
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil)
|
||||
resetAt := time.Now().Add(30 * time.Minute)
|
||||
account := &Account{
|
||||
ID: 17,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
RateLimitResetAt: &resetAt,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
},
|
||||
}
|
||||
refresher := &tokenRefresherStub{
|
||||
credentials: map[string]any{
|
||||
"access_token": "new-token",
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCredentialsCalls)
|
||||
require.Equal(t, 0, repo.fullUpdateCalls)
|
||||
require.NotNil(t, account.RateLimitResetAt)
|
||||
require.WithinDuration(t, resetAt, *account.RateLimitResetAt, time.Second)
|
||||
}
|
||||
|
||||
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
|
||||
func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")}
|
||||
@@ -390,7 +447,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, repo.clearTempCalls) // DB 清除
|
||||
require.Equal(t, 1, repo.clearTempCalls) // DB 清除
|
||||
require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user