From ad7c10727a7721e07a3fc57e6fafee783293c724 Mon Sep 17 00:00:00 2001 From: Wang Lvyuan <74089601+LvyuanW@users.noreply.github.com> Date: Mon, 23 Mar 2026 03:49:28 +0800 Subject: [PATCH 1/2] fix(account): preserve runtime state during credentials-only updates --- backend/internal/repository/account_repo.go | 11 +++ .../account_credentials_persistence.go | 30 ++++++++ .../service/antigravity_token_provider.go | 2 +- backend/internal/service/crs_sync_service.go | 18 ++--- .../internal/service/gemini_token_provider.go | 2 +- backend/internal/service/oauth_refresh_api.go | 3 +- .../service/oauth_refresh_api_test.go | 56 ++++++++++++--- backend/internal/service/ratelimit_service.go | 2 +- .../service/ratelimit_service_401_test.go | 34 ++++++++- backend/internal/service/sora_sdk_client.go | 2 +- .../internal/service/token_refresh_service.go | 3 +- .../service/token_refresh_service_test.go | 69 +++++++++++++++++-- 12 files changed, 195 insertions(+), 37 deletions(-) create mode 100644 backend/internal/service/account_credentials_persistence.go diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 35b908de..d2bd8650 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -404,6 +404,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account return nil } +func (r *accountRepository) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error { + _, err := r.client.Account.UpdateOneID(id). + SetCredentials(normalizeJSONMap(credentials)). + Save(ctx) + if err != nil { + return translatePersistenceError(err, service.ErrAccountNotFound, nil) + } + r.syncSchedulerAccountSnapshot(ctx, id) + return nil +} + func (r *accountRepository) Delete(ctx context.Context, id int64) error { groupIDs, err := r.loadAccountGroupIDs(ctx, id) if err != nil { diff --git a/backend/internal/service/account_credentials_persistence.go b/backend/internal/service/account_credentials_persistence.go new file mode 100644 index 00000000..916df536 --- /dev/null +++ b/backend/internal/service/account_credentials_persistence.go @@ -0,0 +1,30 @@ +package service + +import "context" + +type accountCredentialsUpdater interface { + UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error +} + +func persistAccountCredentials(ctx context.Context, repo AccountRepository, account *Account, credentials map[string]any) error { + if repo == nil || account == nil { + return nil + } + + account.Credentials = cloneCredentials(credentials) + if updater, ok := any(repo).(accountCredentialsUpdater); ok { + return updater.UpdateCredentials(ctx, account.ID, account.Credentials) + } + return repo.Update(ctx, account) +} + +func cloneCredentials(in map[string]any) map[string]any { + if in == nil { + return map[string]any{} + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} diff --git a/backend/internal/service/antigravity_token_provider.go b/backend/internal/service/antigravity_token_provider.go index 5e53f434..1b360d93 100644 --- a/backend/internal/service/antigravity_token_provider.go +++ b/backend/internal/service/antigravity_token_provider.go @@ -138,7 +138,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * p.markBackfillAttempted(account.ID) if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" { account.Credentials["project_id"] = projectID - if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { + if updateErr := persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials); updateErr != nil { slog.Warn("antigravity_project_id_backfill_persist_failed", "account_id", account.ID, "error", updateErr, diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index 6a916740..b69b0639 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -367,8 +367,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput // 🔄 Refresh OAuth token after creation if targetType == AccountTypeOAuth { if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { - account.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, account) + _ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds) } } item.Action = "created" @@ -402,8 +401,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput // 🔄 Refresh OAuth token after update if targetType == AccountTypeOAuth { if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { - existing.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, existing) + _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds) } } @@ -620,8 +618,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } // 🔄 Refresh OAuth token after creation if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { - account.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, account) + _ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds) } item.Action = "created" result.Created++ @@ -652,8 +649,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput // 🔄 Refresh OAuth token after update if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { - existing.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, existing) + _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds) } item.Action = "updated" @@ -862,8 +858,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput continue } if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { - account.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, account) + _ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds) } item.Action = "created" result.Created++ @@ -893,8 +888,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { - existing.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, existing) + _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds) } item.Action = "updated" diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index 1dab67c4..7add3460 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -135,7 +135,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou if tierID != "" { account.Credentials["tier_id"] = tierID } - _ = p.accountRepo.Update(ctx, account) + _ = persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials) } } diff --git a/backend/internal/service/oauth_refresh_api.go b/backend/internal/service/oauth_refresh_api.go index 17b9128c..5dbba638 100644 --- a/backend/internal/service/oauth_refresh_api.go +++ b/backend/internal/service/oauth_refresh_api.go @@ -108,8 +108,7 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded( // 5. 设置版本号 + 更新 DB if newCredentials != nil { newCredentials["_token_version"] = time.Now().UnixMilli() - freshAccount.Credentials = newCredentials - if updateErr := api.accountRepo.Update(ctx, freshAccount); updateErr != nil { + if updateErr := persistAccountCredentials(ctx, api.accountRepo, freshAccount, newCredentials); updateErr != nil { slog.Error("oauth_refresh_update_failed", "account_id", freshAccount.ID, "error", updateErr, diff --git a/backend/internal/service/oauth_refresh_api_test.go b/backend/internal/service/oauth_refresh_api_test.go index 6cf9371f..c3b38ddf 100644 --- a/backend/internal/service/oauth_refresh_api_test.go +++ b/backend/internal/service/oauth_refresh_api_test.go @@ -16,10 +16,11 @@ import ( // refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests. type refreshAPIAccountRepo struct { mockAccountRepoForGemini - account *Account // returned by GetByID - getByIDErr error - updateErr error - updateCalls int + account *Account // returned by GetByID + getByIDErr error + updateErr error + updateCalls int + updateCredentialsCalls int } func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) { @@ -34,6 +35,19 @@ func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error { return r.updateErr } +func (r *refreshAPIAccountRepo) UpdateCredentials(_ context.Context, id int64, credentials map[string]any) error { + r.updateCalls++ + r.updateCredentialsCalls++ + if r.updateErr != nil { + return r.updateErr + } + if r.account == nil || r.account.ID != id { + r.account = &Account{ID: id} + } + r.account.Credentials = cloneCredentials(credentials) + return nil +} + // refreshAPIExecutorStub implements OAuthRefreshExecutor for tests. type refreshAPIExecutorStub struct { needsRefresh bool @@ -106,10 +120,36 @@ func TestRefreshIfNeeded_Success(t *testing.T) { require.Equal(t, "new-token", result.NewCredentials["access_token"]) require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set require.Equal(t, 1, repo.updateCalls) // DB updated - require.Equal(t, 1, cache.releaseCalls) // lock released + require.Equal(t, 1, repo.updateCredentialsCalls) + require.Equal(t, 1, cache.releaseCalls) // lock released require.Equal(t, 1, executor.refreshCalls) } +func TestRefreshIfNeeded_UpdateCredentialsPreservesRateLimitState(t *testing.T) { + resetAt := time.Now().Add(45 * time.Minute) + account := &Account{ + ID: 11, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + RateLimitResetAt: &resetAt, + } + repo := &refreshAPIAccountRepo{account: account} + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "safe-token"}, + } + + api := NewOAuthRefreshAPI(repo, cache) + result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err) + require.True(t, result.Refreshed) + require.Equal(t, 1, repo.updateCredentialsCalls) + require.NotNil(t, repo.account.RateLimitResetAt) + require.WithinDuration(t, resetAt, *repo.account.RateLimitResetAt, time.Second) +} + func TestRefreshIfNeeded_LockHeld(t *testing.T) { account := &Account{ID: 2, Platform: PlatformAnthropic} repo := &refreshAPIAccountRepo{account: account} @@ -193,7 +233,7 @@ func TestRefreshIfNeeded_RefreshError(t *testing.T) { require.Error(t, err) require.Nil(t, result) require.Contains(t, err.Error(), "invalid_grant") - require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error + require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error require.Equal(t, 1, cache.releaseCalls) // lock still released via defer } @@ -299,8 +339,8 @@ func TestMergeCredentials_NewOverridesOld(t *testing.T) { result := MergeCredentials(old, new) - require.Equal(t, "new-token", result["access_token"]) // overridden - require.Equal(t, "old-refresh", result["refresh_token"]) // preserved + require.Equal(t, "new-token", result["access_token"]) // overridden + require.Equal(t, "old-refresh", result["refresh_token"]) // preserved } // ========== BuildClaudeAccountCredentials tests ========== diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 5c6c26e1..afe5816d 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -163,7 +163,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc account.Credentials = make(map[string]any) } account.Credentials["expires_at"] = time.Now().Format(time.RFC3339) - if err := s.accountRepo.Update(ctx, account); err != nil { + if err := persistAccountCredentials(ctx, s.accountRepo, account, account.Credentials); err != nil { slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err) } else { slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform) diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go index 4a6e5d6c..67b22e52 100644 --- a/backend/internal/service/ratelimit_service_401_test.go +++ b/backend/internal/service/ratelimit_service_401_test.go @@ -15,9 +15,11 @@ import ( type rateLimitAccountRepoStub struct { mockAccountRepoForGemini - setErrorCalls int - tempCalls int - lastErrorMsg string + setErrorCalls int + tempCalls int + updateCredentialsCalls int + lastCredentials map[string]any + lastErrorMsg string } func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { @@ -31,6 +33,12 @@ func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id return nil } +func (r *rateLimitAccountRepoStub) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error { + r.updateCredentialsCalls++ + r.lastCredentials = cloneCredentials(credentials) + return nil +} + type tokenCacheInvalidatorRecorder struct { accounts []*Account err error @@ -110,6 +118,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin require.True(t, shouldDisable) require.Equal(t, 0, repo.setErrorCalls) require.Equal(t, 1, repo.tempCalls) + require.Equal(t, 1, repo.updateCredentialsCalls) require.Len(t, invalidator.accounts, 1) } @@ -130,3 +139,22 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) { require.Equal(t, 1, repo.setErrorCalls) require.Empty(t, invalidator.accounts) } + +func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 103, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token", + }, + } + + shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.updateCredentialsCalls) + require.NotEmpty(t, repo.lastCredentials["expires_at"]) +} diff --git a/backend/internal/service/sora_sdk_client.go b/backend/internal/service/sora_sdk_client.go index f9221c5b..6243f867 100644 --- a/backend/internal/service/sora_sdk_client.go +++ b/backend/internal/service/sora_sdk_client.go @@ -947,7 +947,7 @@ func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Accoun } if c.accountRepo != nil { - if err := c.accountRepo.Update(ctx, account); err != nil && c.debugEnabled() { + if err := persistAccountCredentials(ctx, c.accountRepo, account, account.Credentials); err != nil && c.debugEnabled() { c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) } } diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 582afcd3..24b7424f 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -280,8 +280,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc newCredentials, err = refresher.Refresh(ctx, account) if newCredentials != nil { newCredentials["_token_version"] = time.Now().UnixMilli() - account.Credentials = newCredentials - if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil { + if saveErr := persistAccountCredentials(ctx, s.accountRepo, account, newCredentials); saveErr != nil { return fmt.Errorf("failed to save credentials: %w", saveErr) } } diff --git a/backend/internal/service/token_refresh_service_test.go b/backend/internal/service/token_refresh_service_test.go index f48de65e..60ba4a96 100644 --- a/backend/internal/service/token_refresh_service_test.go +++ b/backend/internal/service/token_refresh_service_test.go @@ -14,19 +14,40 @@ import ( type tokenRefreshAccountRepo struct { mockAccountRepoForGemini - updateCalls int - setErrorCalls int - clearTempCalls int - lastAccount *Account - updateErr error + updateCalls int + fullUpdateCalls int + updateCredentialsCalls int + setErrorCalls int + clearTempCalls int + lastAccount *Account + updateErr error } func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error { r.updateCalls++ + r.fullUpdateCalls++ r.lastAccount = account return r.updateErr } +func (r *tokenRefreshAccountRepo) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error { + r.updateCalls++ + r.updateCredentialsCalls++ + if r.updateErr != nil { + return r.updateErr + } + cloned := cloneCredentials(credentials) + if r.accountsByID != nil { + if acc, ok := r.accountsByID[id]; ok && acc != nil { + acc.Credentials = cloned + r.lastAccount = acc + return nil + } + } + r.lastAccount = &Account{ID: id, Credentials: cloned} + return nil +} + func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { r.setErrorCalls++ return nil @@ -112,6 +133,8 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) + require.Equal(t, 1, repo.updateCredentialsCalls) + require.Equal(t, 0, repo.fullUpdateCalls) require.Equal(t, 1, invalidator.calls) require.Equal(t, "new-token", account.GetCredential("access_token")) } @@ -249,9 +272,43 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) + require.Equal(t, 1, repo.updateCredentialsCalls) require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效 } +func TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil) + resetAt := time.Now().Add(30 * time.Minute) + account := &Account{ + ID: 17, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + RateLimitResetAt: &resetAt, + Credentials: map[string]any{ + "access_token": "old-token", + }, + } + refresher := &tokenRefresherStub{ + credentials: map[string]any{ + "access_token": "new-token", + }, + } + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.NoError(t, err) + require.Equal(t, 1, repo.updateCredentialsCalls) + require.Equal(t, 0, repo.fullUpdateCalls) + require.NotNil(t, account.RateLimitResetAt) + require.WithinDuration(t, resetAt, *account.RateLimitResetAt, time.Second) +} + // TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况 func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) { repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")} @@ -390,7 +447,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) - require.Equal(t, 1, repo.clearTempCalls) // DB 清除 + require.Equal(t, 1, repo.clearTempCalls) // DB 清除 require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除 } From fef9259aaa9e2767d06f284d5eb91470901b3745 Mon Sep 17 00:00:00 2001 From: Wang Lvyuan <74089601+LvyuanW@users.noreply.github.com> Date: Mon, 23 Mar 2026 03:50:03 +0800 Subject: [PATCH 2/2] fix(openai): recheck runtime state from db before final account selection --- .../service/openai_account_scheduler.go | 9 +++ .../service/openai_account_scheduler_test.go | 55 ++++++++++++++ .../service/openai_gateway_service.go | 76 ++++++++++++++----- .../service/openai_ws_account_sticky_test.go | 52 +++++++++++++ .../internal/service/openai_ws_forwarder.go | 5 ++ 5 files changed, 177 insertions(+), 20 deletions(-) diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 789888cb..37e7ed2c 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -330,6 +330,11 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil } + account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel) + if account == nil { + _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) + return nil, nil + } result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if acquireErr == nil && result.Acquired { @@ -691,6 +696,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { continue } + fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + continue + } result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if acquireErr != nil { return nil, len(candidates), topK, loadSkew, acquireErr diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index 977c4ee8..088815ed 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -84,6 +84,61 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa require.Equal(t, int64(32002), account.ID) } +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount(t *testing.T) { + ctx := context.Background() + groupID := int64(10103) + rateLimitedUntil := time.Now().Add(30 * time.Minute) + staleSticky := &Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0} + staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} + dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}} + snapshotCache := &openAISnapshotCacheStub{ + snapshotAccounts: []*Account{staleSticky, staleBackup}, + accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup}, + } + snapshotService := &SchedulerSnapshotService{cache: snapshotCache} + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}}, + cache: cache, + cfg: &config.Config{}, + schedulerSnapshot: snapshotService, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(33002), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) +} + +func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeRecheckSkipsStaleCachedCandidate(t *testing.T) { + ctx := context.Background() + groupID := int64(10104) + rateLimitedUntil := time.Now().Add(30 * time.Minute) + stalePrimary := &Account{ID: 34001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0} + staleSecondary := &Account{ID: 34002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + dbPrimary := Account{ID: 34001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} + dbSecondary := Account{ID: 34002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + snapshotCache := &openAISnapshotCacheStub{ + snapshotAccounts: []*Account{stalePrimary, staleSecondary}, + accountsByID: map[int64]*Account{34001: stalePrimary, 34002: staleSecondary}, + } + snapshotService := &SchedulerSnapshotService{cache: snapshotCache} + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}}, + cfg: &config.Config{}, + schedulerSnapshot: snapshotService, + } + + account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, int64(34002), account.ID) +} + func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) { ctx := context.Background() groupID := int64(9) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 4e96cf05..9aed7551 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1201,6 +1201,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID if requestedModel != "" && !account.IsModelSupported(requestedModel) { return nil } + account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) + if account == nil { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + return nil + } // 刷新会话 TTL 并返回账号 // Refresh session TTL and return account @@ -1229,6 +1234,10 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [ if fresh == nil { continue } + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel) + if fresh == nil { + continue + } // 选择优先级最高且最久未使用的账号 // Select highest priority and least recently used @@ -1353,27 +1362,32 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex } if !clearSticky && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { - result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) - if err == nil && result.Acquired { - _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } + account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) + if account == nil { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + } else { + result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if err == nil && result.Acquired { + _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) - if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } } } } @@ -1560,6 +1574,28 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context. return fresh } +func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string) *Account { + if account == nil { + return nil + } + if s.schedulerSnapshot == nil || s.accountRepo == nil { + return account + } + + latest, err := s.accountRepo.GetByID(ctx, account.ID) + if err != nil || latest == nil { + return nil + } + syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, latest, time.Now()) + if !latest.IsSchedulable() || !latest.IsOpenAI() { + return nil + } + if requestedModel != "" && !latest.IsModelSupported(requestedModel) { + return nil + } + return latest +} + func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { var ( account *Account diff --git a/backend/internal/service/openai_ws_account_sticky_test.go b/backend/internal/service/openai_ws_account_sticky_test.go index 9a8803d3..a5b97ca9 100644 --- a/backend/internal/service/openai_ws_account_sticky_test.go +++ b/backend/internal/service/openai_ws_account_sticky_test.go @@ -85,6 +85,58 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss( require.Zero(t, boundAccountID) } +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheckRateLimitedMiss(t *testing.T) { + ctx := context.Background() + groupID := int64(24) + rateLimitedUntil := time.Now().Add(30 * time.Minute) + staleAccount := &Account{ + ID: 13, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + dbAccount := Account{ + ID: 13, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + RateLimitResetAt: &rateLimitedUntil, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + snapshotCache := &openAISnapshotCacheStub{ + accountsByID: map[int64]*Account{dbAccount.ID: staleAccount}, + } + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbAccount}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache}, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_db_rl", dbAccount.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil) + require.NoError(t, err) + require.Nil(t, selection, "DB 中已限流的账号不应继续命中 previous_response_id 粘连") + boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_db_rl") + require.NoError(t, getErr) + require.Zero(t, boundAccountID) +} + func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) { ctx := context.Background() groupID := int64(23) diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 814ec0bd..4f1837c4 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -3846,6 +3846,11 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( if requestedModel != "" && !account.IsModelSupported(requestedModel) { return nil, nil } + account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) + if account == nil { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if acquireErr == nil && result.Acquired {