diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 509cf13a..03e613d5 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -84,7 +84,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { } dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig) dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService) - accountRepository := repository.NewAccountRepository(client, db) + schedulerCache := repository.NewSchedulerCache(redisClient) + accountRepository := repository.NewAccountRepository(client, db, schedulerCache) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) @@ -127,7 +128,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { adminRedeemHandler := admin.NewRedeemHandler(adminService) promoHandler := admin.NewPromoHandler(promoService) opsRepository := repository.NewOpsRepository(db) - schedulerCache := repository.NewSchedulerCache(redisClient) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig) diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index f7725820..84bd7b9e 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -39,9 +39,15 @@ import ( // 设计说明: // - client: Ent 客户端,用于类型安全的 ORM 操作 // - sql: 原生 SQL 执行器,用于复杂查询和批量操作 +// - schedulerCache: 调度器缓存,用于在账号状态变更时同步快照 type accountRepository struct { client *dbent.Client // 
Ent ORM 客户端 sql sqlExecutor // 原生 SQL 执行接口 + // schedulerCache 用于在账号状态变更时主动同步快照到缓存, + // 确保粘性会话能及时感知账号不可用状态。 + // Used to proactively sync account snapshot to cache when status changes, + // ensuring sticky sessions can promptly detect unavailable accounts. + schedulerCache service.SchedulerCache } type tempUnschedSnapshot struct { @@ -51,14 +57,14 @@ type tempUnschedSnapshot struct { // NewAccountRepository 创建账户仓储实例。 // 这是对外暴露的构造函数,返回接口类型以便于依赖注入。 -func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository { - return newAccountRepositoryWithSQL(client, sqlDB) +func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository { + return newAccountRepositoryWithSQL(client, sqlDB, schedulerCache) } // newAccountRepositoryWithSQL 是内部构造函数,支持依赖注入 SQL 执行器。 // 这种设计便于单元测试时注入 mock 对象。 -func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accountRepository { - return &accountRepository{client: client, sql: sqlq} +func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor, schedulerCache service.SchedulerCache) *accountRepository { + return &accountRepository{client: client, sql: sqlq, schedulerCache: schedulerCache} } func (r *accountRepository) Create(ctx context.Context, account *service.Account) error { @@ -356,6 +362,9 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil { log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err) } + if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable { + r.syncSchedulerAccountSnapshot(ctx, account.ID) + } return nil } @@ -540,9 +549,32 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str if err 
:= enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err) } + r.syncSchedulerAccountSnapshot(ctx, id) return nil } +// syncSchedulerAccountSnapshot 在账号状态变更时主动同步快照到调度器缓存。 +// 当账号被设置为错误、禁用、不可调度或临时不可调度时调用, +// 确保调度器和粘性会话逻辑能及时感知账号的最新状态,避免继续使用不可用账号。 +// +// syncSchedulerAccountSnapshot proactively syncs account snapshot to scheduler cache +// when account status changes. Called when account is set to error, disabled, +// unschedulable, or temporarily unschedulable, ensuring scheduler and sticky session +// logic can promptly detect the latest account state and avoid using unavailable accounts. +func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, accountID int64) { + if r == nil || r.schedulerCache == nil || accountID <= 0 { + return + } + account, err := r.GetByID(ctx, accountID) + if err != nil { + log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err) + return + } + if err := r.schedulerCache.SetAccount(ctx, account); err != nil { + log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err) + } +} + func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { _, err := r.client.AccountGroup.Create(). SetAccountID(accountID). 
@@ -864,6 +896,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err) } + r.syncSchedulerAccountSnapshot(ctx, id) return nil } @@ -974,6 +1007,9 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err) } + if !schedulable { + r.syncSchedulerAccountSnapshot(ctx, id) + } return nil } @@ -1128,6 +1164,18 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil { log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err) } + shouldSync := false + if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) { + shouldSync = true + } + if updates.Schedulable != nil && !*updates.Schedulable { + shouldSync = true + } + if shouldSync { + for _, id := range ids { + r.syncSchedulerAccountSnapshot(ctx, id) + } + } } return rows, nil } diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index 250b141d..a054b6d6 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -21,11 +21,56 @@ type AccountRepoSuite struct { repo *accountRepository } +type schedulerCacheRecorder struct { + setAccounts []*service.Account +} + +func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) 
([]*service.Account, bool, error) { + return nil, false, nil +} + +func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error { + return nil +} + +func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) { + return nil, nil +} + +func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error { + s.setAccounts = append(s.setAccounts, account) + return nil +} + +func (s *schedulerCacheRecorder) DeleteAccount(ctx context.Context, accountID int64) error { + return nil +} + +func (s *schedulerCacheRecorder) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error { + return nil +} + +func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) { + return true, nil +} + +func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) { + return nil, nil +} + +func (s *schedulerCacheRecorder) GetOutboxWatermark(ctx context.Context) (int64, error) { + return 0, nil +} + +func (s *schedulerCacheRecorder) SetOutboxWatermark(ctx context.Context, id int64) error { + return nil +} + func (s *AccountRepoSuite) SetupTest() { s.ctx = context.Background() tx := testEntTx(s.T()) s.client = tx.Client() - s.repo = newAccountRepositoryWithSQL(s.client, tx) + s.repo = newAccountRepositoryWithSQL(s.client, tx, nil) } func TestAccountRepoSuite(t *testing.T) { @@ -73,6 +118,20 @@ func (s *AccountRepoSuite) TestUpdate() { s.Require().Equal("updated", got.Name) } +func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() { + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "sync-update", Status: service.StatusActive, Schedulable: true}) + cacheRecorder := &schedulerCacheRecorder{} + s.repo.schedulerCache = cacheRecorder + + account.Status = service.StatusDisabled + err := 
s.repo.Update(s.ctx, account) + s.Require().NoError(err, "Update") + + s.Require().Len(cacheRecorder.setAccounts, 1) + s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID) + s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status) +} + func (s *AccountRepoSuite) TestDelete() { account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"}) @@ -174,7 +233,7 @@ func (s *AccountRepoSuite) TestListWithFilters() { // 每个 case 重新获取隔离资源 tx := testEntTx(s.T()) client := tx.Client() - repo := newAccountRepositoryWithSQL(client, tx) + repo := newAccountRepositoryWithSQL(client, tx, nil) ctx := context.Background() tt.setup(client) @@ -365,12 +424,38 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() { func (s *AccountRepoSuite) TestSetSchedulable() { account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true}) + cacheRecorder := &schedulerCacheRecorder{} + s.repo.schedulerCache = cacheRecorder s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false)) got, err := s.repo.GetByID(s.ctx, account.ID) s.Require().NoError(err) s.Require().False(got.Schedulable) + s.Require().Len(cacheRecorder.setAccounts, 1) + s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID) +} + +func (s *AccountRepoSuite) TestBulkUpdate_SyncSchedulerSnapshotOnDisabled() { + account1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-1", Status: service.StatusActive, Schedulable: true}) + account2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-2", Status: service.StatusActive, Schedulable: true}) + cacheRecorder := &schedulerCacheRecorder{} + s.repo.schedulerCache = cacheRecorder + + disabled := service.StatusDisabled + rows, err := s.repo.BulkUpdate(s.ctx, []int64{account1.ID, account2.ID}, service.AccountBulkUpdate{ + Status: &disabled, + }) + s.Require().NoError(err) + s.Require().Equal(int64(2), rows) + + 
s.Require().Len(cacheRecorder.setAccounts, 2) + ids := map[int64]struct{}{} + for _, acc := range cacheRecorder.setAccounts { + ids[acc.ID] = struct{}{} + } + s.Require().Contains(ids, account1.ID) + s.Require().Contains(ids, account2.ID) } // --- SetOverloaded / SetRateLimited / ClearRateLimit --- diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index 40a9ad05..58291b66 100644 --- a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -39,3 +39,15 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses key := buildSessionKey(groupID, sessionHash) return c.rdb.Expire(ctx, key, ttl).Err() } + +// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。 +// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用, +// 以便下次请求能够重新选择可用账号。 +// +// DeleteSessionAccountID removes the sticky session binding for the given session. +// Called when the bound account becomes unavailable (e.g., error status, disabled, +// or unschedulable), allowing subsequent requests to select a new available account. 
+func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { + key := buildSessionKey(groupID, sessionHash) + return c.rdb.Del(ctx, key).Err() +} diff --git a/backend/internal/repository/gateway_cache_integration_test.go b/backend/internal/repository/gateway_cache_integration_test.go index d8885bca..0eebc33f 100644 --- a/backend/internal/repository/gateway_cache_integration_test.go +++ b/backend/internal/repository/gateway_cache_integration_test.go @@ -78,6 +78,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() { require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error") } +func (s *GatewayCacheSuite) TestDeleteSessionAccountID() { + sessionID := "openai:s4" + accountID := int64(102) + groupID := int64(1) + sessionTTL := 1 * time.Minute + + require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID") + require.NoError(s.T(), s.cache.DeleteSessionAccountID(s.ctx, groupID, sessionID), "DeleteSessionAccountID") + + _, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID) + require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete") +} + func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() { sessionID := "corrupted" groupID := int64(1) diff --git a/backend/internal/repository/gateway_routing_integration_test.go b/backend/internal/repository/gateway_routing_integration_test.go index 5566d2e9..77591fe3 100644 --- a/backend/internal/repository/gateway_routing_integration_test.go +++ b/backend/internal/repository/gateway_routing_integration_test.go @@ -24,7 +24,7 @@ func (s *GatewayRoutingSuite) SetupTest() { s.ctx = context.Background() tx := testEntTx(s.T()) s.client = tx.Client() - s.accountRepo = newAccountRepositoryWithSQL(s.client, tx) + s.accountRepo = newAccountRepositoryWithSQL(s.client, tx, nil) } func TestGatewayRoutingSuite(t *testing.T) { diff --git 
a/backend/internal/repository/scheduler_snapshot_outbox_integration_test.go b/backend/internal/repository/scheduler_snapshot_outbox_integration_test.go index e442a125..a88b74ef 100644 --- a/backend/internal/repository/scheduler_snapshot_outbox_integration_test.go +++ b/backend/internal/repository/scheduler_snapshot_outbox_integration_test.go @@ -19,7 +19,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) { _, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox") - accountRepo := newAccountRepositoryWithSQL(client, integrationDB) + accountRepo := newAccountRepositoryWithSQL(client, integrationDB, nil) outboxRepo := NewSchedulerOutboxRepository(integrationDB) cache := NewSchedulerCache(rdb) diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 76d73286..e6c1c75d 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -179,6 +179,7 @@ var _ AccountRepository = (*mockAccountRepoForPlatform)(nil) // mockGatewayCacheForPlatform 单平台测试用的 cache mock type mockGatewayCacheForPlatform struct { sessionBindings map[string]int64 + deletedSessions map[string]int } func (m *mockGatewayCacheForPlatform) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { @@ -200,6 +201,18 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro return nil } +func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { + if m.sessionBindings == nil { + return nil + } + if m.deletedSessions == nil { + m.deletedSessions = make(map[string]int) + } + m.deletedSessions[sessionHash]++ + delete(m.sessionBindings, sessionHash) + return nil +} + type mockGroupRepoForGateway struct { groups map[int64]*Group getByIDCalls int @@ -623,6 +636,363 @@ func 
TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testi }) } +func TestGatewayService_SelectAccountForModelWithExclusions_ForcePlatform(t *testing.T) { + ctx := context.Background() + ctx = context.WithValue(ctx, ctxkey.ForcePlatform, PlatformAntigravity) + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID) + require.Equal(t, PlatformAntigravity, acc.Platform) +} + +func TestGatewayService_SelectAccountForModelWithPlatform_RoutedStickySessionClears(t *testing.T) { + ctx := context.Background() + groupID := int64(10) + requestedModel := "claude-3-5-sonnet-20241022" + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusDisabled, Schedulable: true}, + {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{ + sessionBindings: map[string]int64{"session-123": 1}, + } + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Name: "route-group", + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ModelRoutingEnabled: true, + 
ModelRouting: map[string][]int64{ + requestedModel: {1, 2}, + }, + }, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + groupRepo: groupRepo, + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "session-123", requestedModel, nil, PlatformAnthropic) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID) + require.Equal(t, 1, cache.deletedSessions["session-123"]) + require.Equal(t, int64(2), cache.sessionBindings["session-123"]) +} + +func TestGatewayService_SelectAccountForModelWithPlatform_RoutedStickySessionHit(t *testing.T) { + ctx := context.Background() + groupID := int64(11) + requestedModel := "claude-3-5-sonnet-20241022" + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{ + sessionBindings: map[string]int64{"session-456": 1}, + } + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Name: "route-group-hit", + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ModelRoutingEnabled: true, + ModelRouting: map[string][]int64{ + requestedModel: {1, 2}, + }, + }, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + groupRepo: groupRepo, + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "session-456", requestedModel, nil, PlatformAnthropic) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(1), acc.ID) +} + +func TestGatewayService_SelectAccountForModelWithPlatform_RoutedFallbackToNormal(t *testing.T) { + ctx := 
context.Background() + groupID := int64(12) + requestedModel := "claude-3-5-sonnet-20241022" + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Name: "route-fallback", + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ModelRoutingEnabled: true, + ModelRouting: map[string][]int64{ + requestedModel: {99}, + }, + }, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + groupRepo: groupRepo, + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", requestedModel, nil, PlatformAnthropic) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(1), acc.ID) +} + +func TestGatewayService_SelectAccountForModelWithPlatform_NoModelSupport(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + { + ID: 1, + Platform: PlatformAnthropic, + Priority: 1, + Status: StatusActive, + Schedulable: true, + Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-haiku-20241022": "claude-3-5-haiku-20241022"}}, + }, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) + require.Error(t, err) 
+ require.Nil(t, acc) + require.Contains(t, err.Error(), "supporting model") +} + +func TestGatewayService_SelectAccountForModelWithPlatform_GeminiPreferOAuth(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey}, + {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID) +} + +func TestGatewayService_SelectAccountForModelWithPlatform_StickyInGroup(t *testing.T) { + ctx := context.Background() + groupID := int64(50) + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, AccountGroups: []AccountGroup{{GroupID: groupID}}}, + {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, AccountGroups: []AccountGroup{{GroupID: groupID}}}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{ + sessionBindings: map[string]int64{"session-group": 1}, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "session-group", "", nil, PlatformAnthropic) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, 
int64(1), acc.ID) +} + +func TestGatewayService_SelectAccountForModelWithPlatform_StickyModelMismatchFallback(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + { + ID: 1, + Platform: PlatformAnthropic, + Priority: 1, + Status: StatusActive, + Schedulable: true, + Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-haiku-20241022": "claude-3-5-haiku-20241022"}}, + }, + {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{ + sessionBindings: map[string]int64{"session-miss": 1}, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-miss", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID) +} + +func TestGatewayService_SelectAccountForModelWithPlatform_PreferNeverUsed(t *testing.T) { + ctx := context.Background() + lastUsed := time.Now().Add(-1 * time.Hour) + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &lastUsed}, + {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) + require.NoError(t, err) + require.NotNil(t, 
acc) + require.Equal(t, int64(2), acc.ID) +} + +func TestGatewayService_SelectAccountForModelWithPlatform_NoAccounts(t *testing.T) { + ctx := context.Background() + repo := &mockAccountRepoForPlatform{ + accounts: []Account{}, + accountsByID: map[int64]*Account{}, + } + + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformAnthropic) + require.Error(t, err) + require.Nil(t, acc) + require.Contains(t, err.Error(), "no available accounts") +} + func TestGatewayService_isModelSupportedByAccount(t *testing.T) { svc := &GatewayService{} @@ -740,6 +1110,301 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { require.Equal(t, int64(2), acc.ID, "应选择优先级最高的账户(包含启用混合调度的antigravity)") }) + t.Run("混合调度-路由优先选择路由账号", func(t *testing.T) { + groupID := int64(30) + requestedModel := "claude-3-5-sonnet-20241022" + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Name: "route-mixed-select", + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ModelRoutingEnabled: true, + ModelRouting: map[string][]int64{ + requestedModel: {2}, + }, + }, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + groupRepo: groupRepo, + } + + acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "", requestedModel, nil, 
PlatformAnthropic) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID) + }) + + t.Run("混合调度-路由粘性命中", func(t *testing.T) { + groupID := int64(31) + requestedModel := "claude-3-5-sonnet-20241022" + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}, AccountGroups: []AccountGroup{{GroupID: groupID}}}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{ + sessionBindings: map[string]int64{"session-777": 2}, + } + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Name: "route-mixed-sticky", + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ModelRoutingEnabled: true, + ModelRouting: map[string][]int64{ + requestedModel: {2}, + }, + }, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + groupRepo: groupRepo, + } + + acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "session-777", requestedModel, nil, PlatformAnthropic) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID) + }) + + t.Run("混合调度-路由账号缺失回退", func(t *testing.T) { + groupID := int64(32) + requestedModel := "claude-3-5-sonnet-20241022" + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + 
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Name: "route-mixed-miss", + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ModelRoutingEnabled: true, + ModelRouting: map[string][]int64{ + requestedModel: {99}, + }, + }, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + groupRepo: groupRepo, + } + + acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "", requestedModel, nil, PlatformAnthropic) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(1), acc.ID) + }) + + t.Run("混合调度-路由账号未启用mixed_scheduling回退", func(t *testing.T) { + groupID := int64(33) + requestedModel := "claude-3-5-sonnet-20241022" + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, // 未启用 mixed_scheduling + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Name: "route-mixed-disabled", + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ModelRoutingEnabled: true, + ModelRouting: map[string][]int64{ + requestedModel: {2}, + }, + }, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + groupRepo: groupRepo, + } + + acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "", requestedModel, nil, PlatformAnthropic) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(1), acc.ID) + }) + + t.Run("混合调度-路由过滤覆盖", func(t 
*testing.T) { + groupID := int64(35) + requestedModel := "claude-3-5-sonnet-20241022" + resetAt := time.Now().Add(10 * time.Minute) + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: false}, + {ID: 3, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, + { + ID: 4, + Platform: PlatformAnthropic, + Priority: 1, + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + "claude_sonnet": map[string]any{ + "rate_limit_reset_at": resetAt.Format(time.RFC3339), + }, + }, + }, + }, + { + ID: 5, + Platform: PlatformAnthropic, + Priority: 1, + Status: StatusActive, + Schedulable: true, + Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-haiku-20241022": "claude-3-5-haiku-20241022"}}, + }, + {ID: 6, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true}, + {ID: 7, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Name: "route-mixed-filter", + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ModelRoutingEnabled: true, + ModelRouting: map[string][]int64{ + requestedModel: {1, 2, 3, 4, 5, 6, 7}, + }, + }, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + groupRepo: groupRepo, + } + + excluded := map[int64]struct{}{1: {}} + acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "", requestedModel, excluded, PlatformAnthropic) + require.NoError(t, err) + 
require.NotNil(t, acc) + require.Equal(t, int64(7), acc.ID) + }) + + t.Run("混合调度-粘性命中分组账号", func(t *testing.T) { + groupID := int64(34) + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, AccountGroups: []AccountGroup{{GroupID: groupID}}}, + {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, AccountGroups: []AccountGroup{{GroupID: groupID}}}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{ + sessionBindings: map[string]int64{"session-group": 1}, + } + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + }, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + groupRepo: groupRepo, + } + + acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "session-group", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(1), acc.ID) + }) + t.Run("混合调度-过滤未启用mixed_scheduling的antigravity账户", func(t *testing.T) { repo := &mockAccountRepoForPlatform{ accounts: []Account{ @@ -823,6 +1488,85 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { require.Equal(t, int64(1), acc.ID, "粘性会话绑定的账户未启用mixed_scheduling,应降级选择anthropic账户") }) + t.Run("混合调度-粘性会话不可调度-清理并回退", func(t *testing.T) { + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusDisabled, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}}, + {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := 
range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{ + sessionBindings: map[string]int64{"session-123": 1}, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID) + require.Equal(t, 1, cache.deletedSessions["session-123"]) + require.Equal(t, int64(2), cache.sessionBindings["session-123"]) + }) + + t.Run("混合调度-路由粘性不可调度-清理并回退", func(t *testing.T) { + groupID := int64(12) + requestedModel := "claude-3-5-sonnet-20241022" + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusDisabled, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}}, + {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{ + sessionBindings: map[string]int64{"session-123": 1}, + } + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Name: "route-mixed", + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ModelRoutingEnabled: true, + ModelRouting: map[string][]int64{ + requestedModel: {1, 2}, + }, + }, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + groupRepo: groupRepo, + } + + acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "session-123", requestedModel, nil, PlatformAnthropic) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID) + require.Equal(t, 1, cache.deletedSessions["session-123"]) + 
require.Equal(t, int64(2), cache.sessionBindings["session-123"]) + }) + t.Run("混合调度-仅有启用mixed_scheduling的antigravity账户", func(t *testing.T) { repo := &mockAccountRepoForPlatform{ accounts: []Account{ @@ -873,6 +1617,65 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { require.Nil(t, acc) require.Contains(t, err.Error(), "no available accounts") }) + + t.Run("混合调度-不支持模型返回错误", func(t *testing.T) { + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + { + ID: 1, + Platform: PlatformAnthropic, + Priority: 1, + Status: StatusActive, + Schedulable: true, + Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-haiku-20241022": "claude-3-5-haiku-20241022"}}, + }, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) + require.Error(t, err) + require.Nil(t, acc) + require.Contains(t, err.Error(), "supporting model") + }) + + t.Run("混合调度-优先未使用账号", func(t *testing.T) { + lastUsed := time.Now().Add(-2 * time.Hour) + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &lastUsed}, + {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) + 
require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID) + }) } // TestAccount_IsMixedSchedulingEnabled 测试混合调度开关检查 @@ -959,10 +1762,20 @@ func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, acc type mockConcurrencyCache struct { acquireAccountCalls int loadBatchCalls int + acquireResults map[int64]bool + loadBatchErr error + loadMap map[int64]*AccountLoadInfo + waitCounts map[int64]int + skipDefaultLoad bool } func (m *mockConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { m.acquireAccountCalls++ + if m.acquireResults != nil { + if result, ok := m.acquireResults[accountID]; ok { + return result, nil + } + } return true, nil } @@ -983,6 +1796,11 @@ func (m *mockConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, ac } func (m *mockConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + if m.waitCounts != nil { + if count, ok := m.waitCounts[accountID]; ok { + return count, nil + } + } return 0, nil } @@ -1008,8 +1826,25 @@ func (m *mockConcurrencyCache) DecrementWaitCount(ctx context.Context, userID in func (m *mockConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { m.loadBatchCalls++ + if m.loadBatchErr != nil { + return nil, m.loadBatchErr + } result := make(map[int64]*AccountLoadInfo, len(accounts)) + if m.skipDefaultLoad && m.loadMap != nil { + for _, acc := range accounts { + if load, ok := m.loadMap[acc.ID]; ok { + result[acc.ID] = load + } + } + return result, nil + } for _, acc := range accounts { + if m.loadMap != nil { + if load, ok := m.loadMap[acc.ID]; ok { + result[acc.ID] = load + continue + } + } result[acc.ID] = &AccountLoadInfo{ AccountID: acc.ID, CurrentConcurrency: 0, @@ -1248,6 +2083,48 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { require.Equal(t, 1, 
concurrencyCache.loadBatchCalls, "应继续进行负载批量查询") }) + t.Run("粘性账号禁用-清理会话并回退选择", func(t *testing.T) { + testCtx := context.WithValue(ctx, ctxkey.ForcePlatform, PlatformAnthropic) + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: false, Concurrency: 5}, + {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + repo.listPlatformFunc = func(ctx context.Context, platform string) ([]Account, error) { + return repo.accounts, nil + } + + cache := &mockGatewayCacheForPlatform{ + sessionBindings: map[string]int64{"sticky": 1}, + } + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = true + + concurrencyCache := &mockConcurrencyCache{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + result, err := svc.SelectAccountWithLoadAwareness(testCtx, nil, "sticky", "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(2), result.Account.ID, "粘性账号禁用时应回退到可用账号") + updatedID, ok := cache.sessionBindings["sticky"] + require.True(t, ok, "粘性会话应更新绑定") + require.Equal(t, int64(2), updatedID, "粘性会话应绑定到新账号") + }) + t.Run("无可用账号-返回错误", func(t *testing.T) { repo := &mockAccountRepoForPlatform{ accounts: []Account{}, @@ -1337,6 +2214,751 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { require.NotNil(t, result.Account) require.Equal(t, int64(2), result.Account.ID, "应跳过过载账号,选择可用账号") }) + + t.Run("粘性账号槽位满-返回粘性等待计划", func(t *testing.T) { + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, 
Schedulable: true, Concurrency: 5}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{ + sessionBindings: map[string]int64{"sticky": 1}, + } + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = true + cfg.Gateway.Scheduling.StickySessionMaxWaiting = 1 + + concurrencyCache := &mockConcurrencyCache{ + acquireResults: map[int64]bool{1: false}, + waitCounts: map[int64]int{1: 0}, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.WaitPlan) + require.Equal(t, int64(1), result.Account.ID) + require.Equal(t, 0, concurrencyCache.loadBatchCalls) + }) + + t.Run("负载批量查询失败-降级旧顺序选择", func(t *testing.T) { + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, + {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = true + + concurrencyCache := &mockConcurrencyCache{ + loadBatchErr: errors.New("load batch failed"), + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "legacy", "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, result) + 
require.NotNil(t, result.Account) + require.Equal(t, int64(2), result.Account.ID) + require.Equal(t, int64(2), cache.sessionBindings["legacy"]) + }) + + t.Run("模型路由-粘性账号等待计划", func(t *testing.T) { + groupID := int64(20) + sessionHash := "route-sticky" + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, + {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{ + sessionBindings: map[string]int64{sessionHash: 1}, + } + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ModelRoutingEnabled: true, + ModelRouting: map[string][]int64{ + "claude-3-5-sonnet-20241022": {1, 2}, + }, + }, + }, + } + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = true + cfg.Gateway.Scheduling.StickySessionMaxWaiting = 1 + + concurrencyCache := &mockConcurrencyCache{ + acquireResults: map[int64]bool{1: false}, + waitCounts: map[int64]int{1: 0}, + } + + svc := &GatewayService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.WaitPlan) + require.Equal(t, int64(1), result.Account.ID) + }) + + t.Run("模型路由-粘性账号命中", func(t *testing.T) { + groupID := int64(20) + sessionHash := "route-hit" + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, 
Schedulable: true, Concurrency: 5}, + {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{ + sessionBindings: map[string]int64{sessionHash: 1}, + } + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ModelRoutingEnabled: true, + ModelRouting: map[string][]int64{ + "claude-3-5-sonnet-20241022": {1, 2}, + }, + }, + }, + } + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = true + + concurrencyCache := &mockConcurrencyCache{} + + svc := &GatewayService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(1), result.Account.ID) + require.Equal(t, 0, concurrencyCache.loadBatchCalls) + }) + + t.Run("模型路由-粘性账号缺失-清理并回退", func(t *testing.T) { + groupID := int64(22) + sessionHash := "route-missing" + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{ + sessionBindings: map[string]int64{sessionHash: 1}, + } + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + 
ModelRoutingEnabled: true, + ModelRouting: map[string][]int64{ + "claude-3-5-sonnet-20241022": {1, 2}, + }, + }, + }, + } + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = true + + concurrencyCache := &mockConcurrencyCache{} + + svc := &GatewayService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(2), result.Account.ID) + require.Equal(t, 1, cache.deletedSessions[sessionHash]) + require.Equal(t, int64(2), cache.sessionBindings[sessionHash]) + }) + + t.Run("模型路由-按负载选择账号", func(t *testing.T) { + groupID := int64(21) + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, + {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ModelRoutingEnabled: true, + ModelRouting: map[string][]int64{ + "claude-3-5-sonnet-20241022": {1, 2}, + }, + }, + }, + } + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = true + + concurrencyCache := &mockConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, LoadRate: 80}, + 2: {AccountID: 2, LoadRate: 20}, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + cfg: cfg, + concurrencyService: 
NewConcurrencyService(concurrencyCache), + } + + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route", "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(2), result.Account.ID) + require.Equal(t, int64(2), cache.sessionBindings["route"]) + }) + + t.Run("模型路由-路由账号全满返回等待计划", func(t *testing.T) { + groupID := int64(23) + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, + {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ModelRoutingEnabled: true, + ModelRouting: map[string][]int64{ + "claude-3-5-sonnet-20241022": {1, 2}, + }, + }, + }, + } + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = true + + concurrencyCache := &mockConcurrencyCache{ + acquireResults: map[int64]bool{1: false, 2: false}, + loadMap: map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, LoadRate: 10}, + 2: {AccountID: 2, LoadRate: 20}, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route-full", "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.WaitPlan) + require.Equal(t, int64(1), result.Account.ID) + }) + + t.Run("模型路由-路由账号全满-回退普通选择", func(t *testing.T) { + groupID := 
int64(22) + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, + {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, + {ID: 3, Platform: PlatformAnthropic, Priority: 0, Status: StatusActive, Schedulable: true, Concurrency: 5}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ModelRoutingEnabled: true, + ModelRouting: map[string][]int64{ + "claude-3-5-sonnet-20241022": {1, 2}, + }, + }, + }, + } + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = true + + concurrencyCache := &mockConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, LoadRate: 100}, + 2: {AccountID: 2, LoadRate: 100}, + 3: {AccountID: 3, LoadRate: 0}, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "fallback", "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(3), result.Account.ID) + require.Equal(t, int64(3), cache.sessionBindings["fallback"]) + }) + + t.Run("负载批量失败且无法获取-兜底等待", func(t *testing.T) { + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, + {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, + }, + 
accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = true + + concurrencyCache := &mockConcurrencyCache{ + loadBatchErr: errors.New("load batch failed"), + acquireResults: map[int64]bool{1: false, 2: false}, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.WaitPlan) + require.Equal(t, int64(1), result.Account.ID) + }) + + t.Run("Gemini负载排序-优先OAuth", func(t *testing.T) { + groupID := int64(24) + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, Type: AccountTypeAPIKey}, + {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, Type: AccountTypeOAuth}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformGemini, + Status: StatusActive, + Hydrated: true, + }, + }, + } + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = true + + concurrencyCache := &mockConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, LoadRate: 10}, + 2: {AccountID: 2, LoadRate: 10}, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + result, err := 
svc.SelectAccountWithLoadAwareness(ctx, &groupID, "gemini", "gemini-2.5-pro", nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(2), result.Account.ID) + }) + + t.Run("模型路由-过滤路径覆盖", func(t *testing.T) { + groupID := int64(70) + now := time.Now().Add(10 * time.Minute) + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, + {ID: 3, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: false, Concurrency: 5}, + {ID: 4, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, + { + ID: 5, + Platform: PlatformAnthropic, + Priority: 1, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + "claude_sonnet": map[string]any{ + "rate_limit_reset_at": now.Format(time.RFC3339), + }, + }, + }, + }, + { + ID: 6, + Platform: PlatformAnthropic, + Priority: 1, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-haiku-20241022": "claude-3-5-haiku-20241022"}}, + }, + {ID: 7, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ModelRoutingEnabled: true, + ModelRouting: map[string][]int64{ + "claude-3-5-sonnet-20241022": {1, 2, 3, 4, 5, 6}, + }, + }, + }, + } + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = true + + concurrencyCache := 
&mockConcurrencyCache{} + + svc := &GatewayService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + excluded := map[int64]struct{}{1: {}} + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", excluded) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(7), result.Account.ID) + }) + + t.Run("ClaudeCode限制-回退分组", func(t *testing.T) { + groupID := int64(60) + fallbackID := int64(61) + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ClaudeCodeOnly: true, + FallbackGroupID: func() *int64 { + v := fallbackID + return &v + }(), + }, + fallbackID: { + ID: fallbackID, + Platform: PlatformGemini, + Status: StatusActive, + Hydrated: true, + }, + }, + } + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = false + + svc := &GatewayService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: &mockGatewayCacheForPlatform{}, + cfg: cfg, + concurrencyService: nil, + } + + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "gemini-2.5-pro", nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(1), result.Account.ID) + }) + + t.Run("ClaudeCode限制-无降级返回错误", func(t *testing.T) { + groupID := int64(62) + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, 
+ ClaudeCodeOnly: true, + }, + }, + } + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = false + + svc := &GatewayService{ + accountRepo: &mockAccountRepoForPlatform{}, + groupRepo: groupRepo, + cache: &mockGatewayCacheForPlatform{}, + cfg: cfg, + concurrencyService: nil, + } + + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil) + require.Error(t, err) + require.Nil(t, result) + require.ErrorIs(t, err, ErrClaudeCodeOnly) + }) + + t.Run("负载可用但无法获取槽位-兜底等待", func(t *testing.T) { + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, + {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = true + + concurrencyCache := &mockConcurrencyCache{ + acquireResults: map[int64]bool{1: false, 2: false}, + loadMap: map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, LoadRate: 10}, + 2: {AccountID: 2, LoadRate: 20}, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "wait", "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.WaitPlan) + require.Equal(t, int64(1), result.Account.ID) + }) + + t.Run("负载信息缺失-使用默认负载", func(t *testing.T) { + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, + {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: 
StatusActive, Schedulable: true, Concurrency: 5}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = true + + concurrencyCache := &mockConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, LoadRate: 50}, + }, + skipDefaultLoad: true, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "missing-load", "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(2), result.Account.ID) + }) } func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) { diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 1e3221d3..df010b6f 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -97,11 +97,24 @@ var allowedHeaders = map[string]bool{ "content-type": true, } -// GatewayCache defines cache operations for gateway service +// GatewayCache 定义网关服务的缓存操作接口。 +// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。 +// +// GatewayCache defines cache operations for gateway service. +// Provides sticky session storage, retrieval, refresh and deletion capabilities. 
type GatewayCache interface { + // GetSessionAccountID 获取粘性会话绑定的账号 ID + // Get the account ID bound to a sticky session GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) + // SetSessionAccountID 设置粘性会话与账号的绑定关系 + // Set the binding between sticky session and account SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error + // RefreshSessionTTL 刷新粘性会话的过期时间 + // Refresh the expiration time of a sticky session RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error + // DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理 + // Delete sticky session binding, used to proactively clean up when account becomes unavailable + DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error } // derefGroupID safely dereferences *int64 to int64, returning 0 if nil @@ -112,6 +125,28 @@ func derefGroupID(groupID *int64) int64 { return *groupID } +// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。 +// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。 +// 这确保后续请求不会继续使用不可用的账号。 +// +// shouldClearStickySession checks if an account is in an unschedulable state +// and the sticky session binding should be cleared. +// Returns true when account status is error/disabled, schedulable is false, +// or within temporary unschedulable period. +// This ensures subsequent requests won't continue using unavailable accounts. 
+func shouldClearStickySession(account *Account) bool { + if account == nil { + return false + } + if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable { + return true + } + if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { + return true + } + return false +} + type AccountWaitPlan struct { AccountID int64 MaxConcurrency int @@ -601,6 +636,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } // 粘性账号槽位满且等待队列已满,继续使用负载感知选择 } + } else { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } } } @@ -696,31 +733,37 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) if err == nil && accountID > 0 && !isExcluded(accountID) { account, ok := accountByID[accountID] - if ok && s.isAccountInGroup(account, groupID) && - s.isAccountAllowedForPlatform(account, platform, useMixed) && - account.IsSchedulableForModel(requestedModel) && - (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { - result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) - if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + if ok { + clearSticky := shouldClearStickySession(account) + if clearSticky { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } + if !clearSticky && s.isAccountInGroup(account, groupID) && + s.isAccountAllowedForPlatform(account, platform, useMixed) && + account.IsSchedulableForModel(requestedModel) && + (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + 
if err == nil && result.Acquired { + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) - if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } } } } @@ -1133,14 +1176,20 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.getSchedulableAccount(ctx, accountID) // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) - if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { - if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) + if err == nil { + clearSticky := shouldClearStickySession(account) + if clearSticky { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), 
requestedModel, shortSessionHash(sessionHash), accountID) + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { + log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) + } + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) + } + return account, nil } - return account, nil } } } @@ -1230,11 +1279,17 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.getSchedulableAccount(ctx, accountID) // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) - if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { - if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) + if err == nil { + clearSticky := shouldClearStickySession(account) + if clearSticky { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + } + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { + log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) + 
} + return account, nil } - return account, nil } } } @@ -1334,15 +1389,21 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.getSchedulableAccount(ctx, accountID) // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 - if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { - if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { - if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) + if err == nil { + clearSticky := shouldClearStickySession(account) + if clearSticky { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + } + if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { + if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { + log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) + } + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) + } + return account, nil } - if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) - } - return 
account, nil } } } @@ -1433,12 +1494,18 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.getSchedulableAccount(ctx, accountID) // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 - if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { - if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { - if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) + if err == nil { + clearSticky := shouldClearStickySession(account) + if clearSticky { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + } + if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { + if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { + log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) + } + return account, nil } - return account, nil } } } diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 75de90f2..7234540f 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -82,70 +82,23 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, } func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx 
context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { - // 优先检查 context 中的强制平台(/antigravity 路由) - var platform string - forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) - if hasForcePlatform && forcePlatform != "" { - platform = forcePlatform - } else if groupID != nil { - // 根据分组 platform 决定查询哪种账号 - var group *Group - if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID { - group = ctxGroup - } else { - var err error - group, err = s.groupRepo.GetByIDLite(ctx, *groupID) - if err != nil { - return nil, fmt.Errorf("get group failed: %w", err) - } - } - platform = group.Platform - } else { - // 无分组时只使用原生 gemini 平台 - platform = PlatformGemini + // 1. 确定目标平台和调度模式 + // Determine target platform and scheduling mode + platform, useMixedScheduling, hasForcePlatform, err := s.resolvePlatformAndSchedulingMode(ctx, groupID) + if err != nil { + return nil, err } - // gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) - // 注意:强制平台模式不走混合调度 - useMixedScheduling := platform == PlatformGemini && !hasForcePlatform - cacheKey := "gemini:" + sessionHash - if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey) - if err == nil && accountID > 0 { - if _, excluded := excludedIDs[accountID]; !excluded { - account, err := s.getSchedulableAccount(ctx, accountID) - // 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度 - if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { - valid := false - if account.Platform == platform { - valid = true - } else if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() { - valid = true - } - if valid { - usable := true - if s.rateLimitService != nil && requestedModel != "" { - ok, err := 
s.rateLimitService.PreCheckUsage(ctx, account, requestedModel) - if err != nil { - log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err) - } - if !ok { - usable = false - } - } - if usable { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL) - return account, nil - } - } - } - } - } + // 2. 尝试粘性会话命中 + // Try sticky session hit + if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs, platform, useMixedScheduling); account != nil { + return account, nil } - // 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部) + // 3. 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部) + // Query schedulable accounts (force platform mode: try group first, fallback to all) accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform) if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) @@ -158,56 +111,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co } } - var selected *Account - for i := range accounts { - acc := &accounts[i] - if _, excluded := excludedIDs[acc.ID]; excluded { - continue - } - // 混合调度模式下:原生平台直接通过,antigravity 需要启用 mixed_scheduling - // 非混合调度模式(antigravity 分组):不需要过滤 - if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { - continue - } - if !acc.IsSchedulableForModel(requestedModel) { - continue - } - if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { - continue - } - if s.rateLimitService != nil && requestedModel != "" { - ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel) - if err != nil { - log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err) - } - if !ok { - continue - } - } - if selected == nil { - selected = acc - continue - } - if acc.Priority < selected.Priority { - selected = acc - } else if acc.Priority == selected.Priority { - switch { - case acc.LastUsedAt == nil && selected.LastUsedAt != 
nil: - selected = acc - case acc.LastUsedAt != nil && selected.LastUsedAt == nil: - // keep selected (never used is preferred) - case acc.LastUsedAt == nil && selected.LastUsedAt == nil: - // Prefer OAuth accounts when both are unused (more compatible for Code Assist flows). - if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth { - selected = acc - } - default: - if acc.LastUsedAt.Before(*selected.LastUsedAt) { - selected = acc - } - } - } - } + // 4. 按优先级 + LRU 选择最佳账号 + // Select best account by priority + LRU + selected := s.selectBestGeminiAccount(ctx, accounts, requestedModel, excludedIDs, platform, useMixedScheduling) if selected == nil { if requestedModel != "" { @@ -216,6 +122,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co return nil, errors.New("no available Gemini accounts") } + // 5. 设置粘性会话绑定 + // Set sticky session binding if sessionHash != "" { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL) } @@ -223,6 +131,229 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co return selected, nil } +// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。 +// 返回:平台名称、是否使用混合调度、是否强制平台、错误。 +// +// resolvePlatformAndSchedulingMode resolves target platform and scheduling mode. +// Returns: platform name, whether to use mixed scheduling, whether force platform, error. 
+func (s *GeminiMessagesCompatService) resolvePlatformAndSchedulingMode(ctx context.Context, groupID *int64) (platform string, useMixedScheduling bool, hasForcePlatform bool, err error) { + // 优先检查 context 中的强制平台(/antigravity 路由) + forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) + if hasForcePlatform && forcePlatform != "" { + return forcePlatform, false, true, nil + } + + if groupID != nil { + // 根据分组 platform 决定查询哪种账号 + var group *Group + if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID { + group = ctxGroup + } else { + group, err = s.groupRepo.GetByIDLite(ctx, *groupID) + if err != nil { + return "", false, false, fmt.Errorf("get group failed: %w", err) + } + } + // gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) + return group.Platform, group.Platform == PlatformGemini, false, nil + } + + // 无分组时只使用原生 gemini 平台 + return PlatformGemini, true, false, nil +} + +// tryStickySessionHit 尝试从粘性会话获取账号。 +// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。 +// +// tryStickySessionHit attempts to get account from sticky session. +// Returns account if hit and usable; clears session and returns nil if account unavailable. 
+func (s *GeminiMessagesCompatService) tryStickySessionHit( + ctx context.Context, + groupID *int64, + sessionHash, cacheKey, requestedModel string, + excludedIDs map[int64]struct{}, + platform string, + useMixedScheduling bool, +) *Account { + if sessionHash == "" { + return nil + } + + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey) + if err != nil || accountID <= 0 { + return nil + } + + if _, excluded := excludedIDs[accountID]; excluded { + return nil + } + + account, err := s.getSchedulableAccount(ctx, accountID) + if err != nil { + return nil + } + + // 检查账号是否需要清理粘性会话 + // Check if sticky session should be cleared + if shouldClearStickySession(account) { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey) + return nil + } + + // 验证账号是否可用于当前请求 + // Verify account is usable for current request + if !s.isAccountUsableForRequest(ctx, account, requestedModel, platform, useMixedScheduling) { + return nil + } + + // 刷新会话 TTL 并返回账号 + // Refresh session TTL and return account + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL) + return account +} + +// isAccountUsableForRequest 检查账号是否可用于当前请求。 +// 验证:模型调度、模型支持、平台匹配、速率限制预检。 +// +// isAccountUsableForRequest checks if account is usable for current request. +// Validates: model scheduling, model support, platform matching, rate limit precheck. 
+func (s *GeminiMessagesCompatService) isAccountUsableForRequest( + ctx context.Context, + account *Account, + requestedModel, platform string, + useMixedScheduling bool, +) bool { + // 检查模型调度能力 + // Check model scheduling capability + if !account.IsSchedulableForModel(requestedModel) { + return false + } + + // 检查模型支持 + // Check model support + if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) { + return false + } + + // 检查平台匹配 + // Check platform matching + if !s.isAccountValidForPlatform(account, platform, useMixedScheduling) { + return false + } + + // 速率限制预检 + // Rate limit precheck + if !s.passesRateLimitPreCheck(ctx, account, requestedModel) { + return false + } + + return true +} + +// isAccountValidForPlatform 检查账号是否匹配目标平台。 +// 原生平台直接匹配;混合调度模式下 antigravity 需要启用 mixed_scheduling。 +// +// isAccountValidForPlatform checks if account matches target platform. +// Native platform matches directly; mixed scheduling mode requires antigravity to enable mixed_scheduling. +func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account, platform string, useMixedScheduling bool) bool { + if account.Platform == platform { + return true + } + if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() { + return true + } + return false +} + +// passesRateLimitPreCheck 执行速率限制预检。 +// 返回 true 表示通过预检或无需预检。 +// +// passesRateLimitPreCheck performs rate limit precheck. +// Returns true if passed or precheck not required. 
+func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool { + if s.rateLimitService == nil || requestedModel == "" { + return true + } + ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel) + if err != nil { + log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err) + } + return ok +} + +// selectBestGeminiAccount 从候选账号中选择最佳账号(优先级 + LRU + OAuth 优先)。 +// 返回 nil 表示无可用账号。 +// +// selectBestGeminiAccount selects best account from candidates (priority + LRU + OAuth preferred). +// Returns nil if no available account. +func (s *GeminiMessagesCompatService) selectBestGeminiAccount( + ctx context.Context, + accounts []Account, + requestedModel string, + excludedIDs map[int64]struct{}, + platform string, + useMixedScheduling bool, +) *Account { + var selected *Account + + for i := range accounts { + acc := &accounts[i] + + // 跳过被排除的账号 + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } + + // 检查账号是否可用于当前请求 + if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) { + continue + } + + // 选择最佳账号 + if selected == nil { + selected = acc + continue + } + + if s.isBetterGeminiAccount(acc, selected) { + selected = acc + } + } + + return selected +} + +// isBetterGeminiAccount 判断 candidate 是否比 current 更优。 +// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。 +// +// isBetterGeminiAccount checks if candidate is better than current. +// Rules: higher priority (lower value) wins; same priority: never used (OAuth > non-OAuth) > least recently used. 
+func (s *GeminiMessagesCompatService) isBetterGeminiAccount(candidate, current *Account) bool { + // 优先级更高(数值更小) + if candidate.Priority < current.Priority { + return true + } + if candidate.Priority > current.Priority { + return false + } + + // 同优先级,比较最后使用时间 + switch { + case candidate.LastUsedAt == nil && current.LastUsedAt != nil: + // candidate 从未使用,优先 + return true + case candidate.LastUsedAt != nil && current.LastUsedAt == nil: + // current 从未使用,保持 + return false + case candidate.LastUsedAt == nil && current.LastUsedAt == nil: + // 都未使用,优先选择 OAuth 账号(更兼容 Code Assist 流程) + return candidate.Type == AccountTypeOAuth && current.Type != AccountTypeOAuth + default: + // 都使用过,选择最久未使用的 + return candidate.LastUsedAt.Before(*current.LastUsedAt) + } +} + // isModelSupportedByAccount 根据账户平台检查模型支持 func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool { if account.Platform == PlatformAntigravity { diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 03f5d757..35f1222d 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -15,8 +15,10 @@ import ( // mockAccountRepoForGemini Gemini 测试用的 mock type mockAccountRepoForGemini struct { - accounts []Account - accountsByID map[int64]*Account + accounts []Account + accountsByID map[int64]*Account + listByGroupFunc func(ctx context.Context, groupID int64, platforms []string) ([]Account, error) + listByPlatformFunc func(ctx context.Context, platforms []string) ([]Account, error) } func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) { @@ -104,6 +106,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context, return nil, nil } func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + if 
m.listByPlatformFunc != nil { + return m.listByPlatformFunc(ctx, platforms) + } var result []Account platformSet := make(map[string]bool) for _, p := range platforms { @@ -117,6 +122,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Contex return result, nil } func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) { + if m.listByGroupFunc != nil { + return m.listByGroupFunc(ctx, groupID, platforms) + } return m.ListSchedulableByPlatforms(ctx, platforms) } func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { @@ -212,6 +220,7 @@ var _ GroupRepository = (*mockGroupRepoForGemini)(nil) // mockGatewayCacheForGemini Gemini 测试用的 cache mock type mockGatewayCacheForGemini struct { sessionBindings map[string]int64 + deletedSessions map[string]int } func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { @@ -233,6 +242,18 @@ func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, group return nil } +func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { + if m.sessionBindings == nil { + return nil + } + if m.deletedSessions == nil { + m.deletedSessions = make(map[string]int) + } + m.deletedSessions[sessionHash]++ + delete(m.sessionBindings, sessionHash) + return nil +} + // TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择 func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) { ctx := context.Background() @@ -523,6 +544,274 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyS // 粘性会话未命中,按优先级选择 require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择") }) + + t.Run("粘性会话不可调度-清理并回退选择", func(t *testing.T) { + repo := 
&mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusDisabled, Schedulable: true}, + {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{ + sessionBindings: map[string]int64{"gemini:session-123": 1}, + } + groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}} + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID) + require.Equal(t, 1, cache.deletedSessions["gemini:session-123"]) + require.Equal(t, int64(2), cache.sessionBindings["gemini:session-123"]) + }) +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ForcePlatformFallback(t *testing.T) { + ctx := context.Background() + groupID := int64(9) + ctx = context.WithValue(ctx, ctxkey.ForcePlatform, PlatformAntigravity) + + repo := &mockAccountRepoForGemini{ + listByGroupFunc: func(ctx context.Context, groupID int64, platforms []string) ([]Account, error) { + return nil, nil + }, + listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) { + return []Account{ + {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, + }, nil + }, + accountsByID: map[int64]*Account{ + 1: {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + } + + cache := &mockGatewayCacheForGemini{} + groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}} + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + } + + acc, err := 
svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(1), acc.ID) +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + { + ID: 1, + Platform: PlatformGemini, + Priority: 1, + Status: StatusActive, + Schedulable: true, + Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.0-pro": "gemini-1.0-pro"}}, + }, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{} + groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}} + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil) + require.Error(t, err) + require.Nil(t, acc) + require.Contains(t, err.Error(), "supporting model") +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyMixedScheduling(t *testing.T) { + ctx := context.Background() + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}}, + {ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{ + sessionBindings: map[string]int64{"gemini:session-999": 1}, + } + groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}} + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, 
+ } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-999", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(1), acc.ID) +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_SkipDisabledMixedScheduling(t *testing.T) { + ctx := context.Background() + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{} + groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}} + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID) +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ExcludedAccount(t *testing.T) { + ctx := context.Background() + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{} + groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}} + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + } + + excluded := map[int64]struct{}{1: {}} + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", 
"gemini-2.5-flash", excluded) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID) +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ListError(t *testing.T) { + ctx := context.Background() + repo := &mockAccountRepoForGemini{ + listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) { + return nil, errors.New("query failed") + }, + } + + cache := &mockGatewayCacheForGemini{} + groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}} + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil) + require.Error(t, err) + require.Nil(t, acc) + require.Contains(t, err.Error(), "query accounts failed") +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferOAuth(t *testing.T) { + ctx := context.Background() + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey}, + {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{} + groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}} + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID) +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferLeastRecentlyUsed(t *testing.T) { + ctx := context.Background() + oldTime := time.Now().Add(-2 * 
time.Hour) + newTime := time.Now().Add(-1 * time.Hour) + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &newTime}, + {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &oldTime}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{} + groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}} + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID) } // TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑 diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index c7d94882..81f2c12d 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -162,67 +162,26 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI } // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. +// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。 func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { - // 1. 
Check sticky session - if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) - if err == nil && accountID > 0 { - if _, excluded := excludedIDs[accountID]; !excluded { - account, err := s.getSchedulableAccount(ctx, accountID) - if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { - // Refresh sticky session TTL - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL) - return account, nil - } - } - } + cacheKey := "openai:" + sessionHash + + // 1. 尝试粘性会话命中 + // Try sticky session hit + if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs); account != nil { + return account, nil } - // 2. Get schedulable OpenAI accounts + // 2. 获取可调度的 OpenAI 账号 + // Get schedulable OpenAI accounts accounts, err := s.listSchedulableAccounts(ctx, groupID) if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) } - // 3. Select by priority + LRU - var selected *Account - for i := range accounts { - acc := &accounts[i] - if _, excluded := excludedIDs[acc.ID]; excluded { - continue - } - // Scheduler snapshots can be temporarily stale; re-check schedulability here to - // avoid selecting accounts that were recently rate-limited/overloaded. 
- if !acc.IsSchedulable() { - continue - } - // Check model support - if requestedModel != "" && !acc.IsModelSupported(requestedModel) { - continue - } - if selected == nil { - selected = acc - continue - } - // Lower priority value means higher priority - if acc.Priority < selected.Priority { - selected = acc - } else if acc.Priority == selected.Priority { - switch { - case acc.LastUsedAt == nil && selected.LastUsedAt != nil: - selected = acc - case acc.LastUsedAt != nil && selected.LastUsedAt == nil: - // keep selected (never used is preferred) - case acc.LastUsedAt == nil && selected.LastUsedAt == nil: - // keep selected (both never used) - default: - // Same priority, select least recently used - if acc.LastUsedAt.Before(*selected.LastUsedAt) { - selected = acc - } - } - } - } + // 3. 按优先级 + LRU 选择最佳账号 + // Select by priority + LRU + selected := s.selectBestAccount(accounts, requestedModel, excludedIDs) if selected == nil { if requestedModel != "" { @@ -231,14 +190,138 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C return nil, errors.New("no available OpenAI accounts") } - // 4. Set sticky session + // 4. 设置粘性会话绑定 + // Set sticky session binding if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, selected.ID, openaiStickySessionTTL) + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, openaiStickySessionTTL) } return selected, nil } +// tryStickySessionHit 尝试从粘性会话获取账号。 +// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。 +// +// tryStickySessionHit attempts to get account from sticky session. +// Returns account if hit and usable; clears session and returns nil if account is unavailable. 
+func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, cacheKey, requestedModel string, excludedIDs map[int64]struct{}) *Account { + if sessionHash == "" { + return nil + } + + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey) + if err != nil || accountID <= 0 { + return nil + } + + if _, excluded := excludedIDs[accountID]; excluded { + return nil + } + + account, err := s.getSchedulableAccount(ctx, accountID) + if err != nil { + return nil + } + + // 检查账号是否需要清理粘性会话 + // Check if sticky session should be cleared + if shouldClearStickySession(account) { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey) + return nil + } + + // 验证账号是否可用于当前请求 + // Verify account is usable for current request + if !account.IsSchedulable() || !account.IsOpenAI() { + return nil + } + if requestedModel != "" && !account.IsModelSupported(requestedModel) { + return nil + } + + // 刷新会话 TTL 并返回账号 + // Refresh session TTL and return account + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, openaiStickySessionTTL) + return account +} + +// selectBestAccount 从候选账号中选择最佳账号(优先级 + LRU)。 +// 返回 nil 表示无可用账号。 +// +// selectBestAccount selects the best account from candidates (priority + LRU). +// Returns nil if no available account. 
+func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { + var selected *Account + + for i := range accounts { + acc := &accounts[i] + + // 跳过被排除的账号 + // Skip excluded accounts + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } + + // 调度器快照可能暂时过时,这里重新检查可调度性和平台 + // Scheduler snapshots can be temporarily stale; re-check schedulability and platform + if !acc.IsSchedulable() || !acc.IsOpenAI() { + continue + } + + // 检查模型支持 + // Check model support + if requestedModel != "" && !acc.IsModelSupported(requestedModel) { + continue + } + + // 选择优先级最高且最久未使用的账号 + // Select highest priority and least recently used + if selected == nil { + selected = acc + continue + } + + if s.isBetterAccount(acc, selected) { + selected = acc + } + } + + return selected +} + +// isBetterAccount 判断 candidate 是否比 current 更优。 +// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。 +// +// isBetterAccount checks if candidate is better than current. +// Rules: higher priority (lower value) wins; same priority: never used > least recently used. +func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool { + // 优先级更高(数值更小) + // Higher priority (lower value) + if candidate.Priority < current.Priority { + return true + } + if candidate.Priority > current.Priority { + return false + } + + // 同优先级,比较最后使用时间 + // Same priority, compare last used time + switch { + case candidate.LastUsedAt == nil && current.LastUsedAt != nil: + // candidate 从未使用,优先 + return true + case candidate.LastUsedAt != nil && current.LastUsedAt == nil: + // current 从未使用,保持 + return false + case candidate.LastUsedAt == nil && current.LastUsedAt == nil: + // 都未使用,保持 + return false + default: + // 都使用过,选择最久未使用的 + return candidate.LastUsedAt.Before(*current.LastUsedAt) + } +} + // SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. 
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { cfg := s.schedulingConfig() @@ -307,29 +390,35 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) if err == nil && accountID > 0 && !isExcluded(accountID) { account, err := s.getSchedulableAccount(ctx, accountID) - if err == nil && account.IsSchedulable() && account.IsOpenAI() && - (requestedModel == "" || account.IsModelSupported(requestedModel)) { - result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) - if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL) - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + if err == nil { + clearSticky := shouldClearStickySession(account) + if clearSticky { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) } + if !clearSticky && account.IsSchedulable() && account.IsOpenAI() && + (requestedModel == "" || account.IsModelSupported(requestedModel)) { + result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if err == nil && result.Acquired { + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) - if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: 
cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } } } } diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 42b88b7d..14394bde 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -21,19 +21,50 @@ type stubOpenAIAccountRepo struct { accounts []Account } +func (r stubOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) { + for i := range r.accounts { + if r.accounts[i].ID == id { + return &r.accounts[i], nil + } + } + return nil, errors.New("account not found") +} + func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) { - return append([]Account(nil), r.accounts...), nil + var result []Account + for _, acc := range r.accounts { + if acc.Platform == platform { + result = append(result, acc) + } + } + return result, nil } func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) { - return append([]Account(nil), r.accounts...), nil + var result []Account + for _, acc := range r.accounts { + if acc.Platform == platform { + result = append(result, acc) + } + } + return result, nil } type stubConcurrencyCache struct { ConcurrencyCache + loadBatchErr error + loadMap map[int64]*AccountLoadInfo + acquireResults map[int64]bool + waitCounts map[int64]int + skipDefaultLoad bool } func (c stubConcurrencyCache) AcquireAccountSlot(ctx 
context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + if c.acquireResults != nil { + if result, ok := c.acquireResults[accountID]; ok { + return result, nil + } + } return true, nil } @@ -42,13 +73,75 @@ func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID } func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + if c.loadBatchErr != nil { + return nil, c.loadBatchErr + } out := make(map[int64]*AccountLoadInfo, len(accounts)) + if c.skipDefaultLoad && c.loadMap != nil { + for _, acc := range accounts { + if load, ok := c.loadMap[acc.ID]; ok { + out[acc.ID] = load + } + } + return out, nil + } for _, acc := range accounts { + if c.loadMap != nil { + if load, ok := c.loadMap[acc.ID]; ok { + out[acc.ID] = load + continue + } + } out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0} } return out, nil } +func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + if c.waitCounts != nil { + if count, ok := c.waitCounts[accountID]; ok { + return count, nil + } + } + return 0, nil +} + +type stubGatewayCache struct { + sessionBindings map[string]int64 + deletedSessions map[string]int +} + +func (c *stubGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { + if id, ok := c.sessionBindings[sessionHash]; ok { + return id, nil + } + return 0, errors.New("not found") +} + +func (c *stubGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error { + if c.sessionBindings == nil { + c.sessionBindings = make(map[string]int64) + } + c.sessionBindings[sessionHash] = accountID + return nil +} + +func (c *stubGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error { + return nil +} + +func (c 
*stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { + if c.sessionBindings == nil { + return nil + } + if c.deletedSessions == nil { + c.deletedSessions = make(map[string]int) + } + c.deletedSessions[sessionHash]++ + delete(c.sessionBindings, sessionHash) + return nil +} + func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) { now := time.Now() resetAt := now.Add(10 * time.Minute) @@ -139,6 +232,515 @@ func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurre } } +func TestOpenAISelectAccountForModelWithExclusions_StickyUnschedulableClearsSession(t *testing.T) { + sessionHash := "session-1" + repo := stubOpenAIAccountRepo{ + accounts: []Account{ + {ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1}, + {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1}, + }, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{"openai:" + sessionHash: 1}, + } + + svc := &OpenAIGatewayService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil) + if err != nil { + t.Fatalf("SelectAccountForModelWithExclusions error: %v", err) + } + if acc == nil || acc.ID != 2 { + t.Fatalf("expected account 2, got %+v", acc) + } + if cache.deletedSessions["openai:"+sessionHash] != 1 { + t.Fatalf("expected sticky session to be deleted") + } + if cache.sessionBindings["openai:"+sessionHash] != 2 { + t.Fatalf("expected sticky session to bind to account 2") + } +} + +func TestOpenAISelectAccountWithLoadAwareness_StickyUnschedulableClearsSession(t *testing.T) { + sessionHash := "session-2" + groupID := int64(1) + repo := stubOpenAIAccountRepo{ + accounts: []Account{ + {ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1}, + {ID: 2, Platform: PlatformOpenAI, Status: 
StatusActive, Schedulable: true, Concurrency: 1}, + }, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{"openai:" + sessionHash: 1}, + } + + svc := &OpenAIGatewayService{ + accountRepo: repo, + cache: cache, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil) + if err != nil { + t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) + } + if selection == nil || selection.Account == nil || selection.Account.ID != 2 { + t.Fatalf("expected account 2, got %+v", selection) + } + if cache.deletedSessions["openai:"+sessionHash] != 1 { + t.Fatalf("expected sticky session to be deleted") + } + if cache.sessionBindings["openai:"+sessionHash] != 2 { + t.Fatalf("expected sticky session to bind to account 2") + } + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAISelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) { + repo := stubOpenAIAccountRepo{ + accounts: []Account{ + { + ID: 1, + Platform: PlatformOpenAI, + Status: StatusActive, + Schedulable: true, + Credentials: map[string]any{"model_mapping": map[string]any{"gpt-3.5-turbo": "gpt-3.5-turbo"}}, + }, + }, + } + cache := &stubGatewayCache{} + + svc := &OpenAIGatewayService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil) + if err == nil { + t.Fatalf("expected error for unsupported model") + } + if acc != nil { + t.Fatalf("expected nil account for unsupported model") + } + if !strings.Contains(err.Error(), "supporting model") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorFallback(t *testing.T) { + groupID := int64(1) + repo := stubOpenAIAccountRepo{ + accounts: []Account{ + {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 
1, Priority: 2}, + {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, + }, + } + cache := &stubGatewayCache{} + concurrencyCache := stubConcurrencyCache{ + loadBatchErr: errors.New("load batch failed"), + } + + svc := &OpenAIGatewayService{ + accountRepo: repo, + cache: cache, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "fallback", "gpt-4", nil) + if err != nil { + t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) + } + if selection == nil || selection.Account == nil { + t.Fatalf("expected selection") + } + if selection.Account.ID != 2 { + t.Fatalf("expected account 2, got %d", selection.Account.ID) + } + if cache.sessionBindings["openai:fallback"] != 2 { + t.Fatalf("expected sticky session updated") + } + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAISelectAccountWithLoadAwareness_NoSlotFallbackWait(t *testing.T) { + groupID := int64(1) + repo := stubOpenAIAccountRepo{ + accounts: []Account{ + {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, + }, + } + cache := &stubGatewayCache{} + concurrencyCache := stubConcurrencyCache{ + acquireResults: map[int64]bool{1: false}, + loadMap: map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, LoadRate: 10}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: repo, + cache: cache, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil) + if err != nil { + t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) + } + if selection == nil || selection.WaitPlan == nil { + t.Fatalf("expected wait plan fallback") + } + if selection.Account == nil || selection.Account.ID != 1 { + t.Fatalf("expected account 1") + } +} + +func 
TestOpenAISelectAccountForModelWithExclusions_SetsStickyBinding(t *testing.T) { + sessionHash := "bind" + repo := stubOpenAIAccountRepo{ + accounts: []Account{ + {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, + }, + } + cache := &stubGatewayCache{} + + svc := &OpenAIGatewayService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil) + if err != nil { + t.Fatalf("SelectAccountForModelWithExclusions error: %v", err) + } + if acc == nil || acc.ID != 1 { + t.Fatalf("expected account 1") + } + if cache.sessionBindings["openai:"+sessionHash] != 1 { + t.Fatalf("expected sticky session binding") + } +} + +func TestOpenAISelectAccountWithLoadAwareness_StickyWaitPlan(t *testing.T) { + sessionHash := "sticky-wait" + groupID := int64(1) + repo := stubOpenAIAccountRepo{ + accounts: []Account{ + {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, + }, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{"openai:" + sessionHash: 1}, + } + concurrencyCache := stubConcurrencyCache{ + acquireResults: map[int64]bool{1: false}, + waitCounts: map[int64]int{1: 0}, + } + + svc := &OpenAIGatewayService{ + accountRepo: repo, + cache: cache, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil) + if err != nil { + t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) + } + if selection == nil || selection.WaitPlan == nil { + t.Fatalf("expected sticky wait plan") + } + if selection.Account == nil || selection.Account.ID != 1 { + t.Fatalf("expected account 1") + } +} + +func TestOpenAISelectAccountWithLoadAwareness_PrefersLowerLoad(t *testing.T) { + groupID := int64(1) + repo := stubOpenAIAccountRepo{ + accounts: []Account{ + {ID: 1, 
Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, + {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, + }, + } + cache := &stubGatewayCache{} + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, LoadRate: 80}, + 2: {AccountID: 2, LoadRate: 10}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: repo, + cache: cache, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "load", "gpt-4", nil) + if err != nil { + t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) + } + if selection == nil || selection.Account == nil || selection.Account.ID != 2 { + t.Fatalf("expected account 2") + } + if cache.sessionBindings["openai:load"] != 2 { + t.Fatalf("expected sticky session updated") + } +} + +func TestOpenAISelectAccountForModelWithExclusions_StickyExcludedFallback(t *testing.T) { + sessionHash := "excluded" + repo := stubOpenAIAccountRepo{ + accounts: []Account{ + {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, + {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2}, + }, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{"openai:" + sessionHash: 1}, + } + + svc := &OpenAIGatewayService{ + accountRepo: repo, + cache: cache, + } + + excluded := map[int64]struct{}{1: {}} + acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", excluded) + if err != nil { + t.Fatalf("SelectAccountForModelWithExclusions error: %v", err) + } + if acc == nil || acc.ID != 2 { + t.Fatalf("expected account 2") + } +} + +func TestOpenAISelectAccountForModelWithExclusions_StickyNonOpenAI(t *testing.T) { + sessionHash := "non-openai" + repo := stubOpenAIAccountRepo{ + 
accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, + {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2}, + }, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{"openai:" + sessionHash: 1}, + } + + svc := &OpenAIGatewayService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil) + if err != nil { + t.Fatalf("SelectAccountForModelWithExclusions error: %v", err) + } + if acc == nil || acc.ID != 2 { + t.Fatalf("expected account 2") + } +} + +func TestOpenAISelectAccountForModelWithExclusions_NoAccounts(t *testing.T) { + repo := stubOpenAIAccountRepo{accounts: []Account{}} + cache := &stubGatewayCache{} + + svc := &OpenAIGatewayService{ + accountRepo: repo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "", nil) + if err == nil { + t.Fatalf("expected error for no accounts") + } + if acc != nil { + t.Fatalf("expected nil account") + } + if !strings.Contains(err.Error(), "no available OpenAI accounts") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestOpenAISelectAccountWithLoadAwareness_NoCandidates(t *testing.T) { + groupID := int64(1) + resetAt := time.Now().Add(1 * time.Hour) + repo := stubOpenAIAccountRepo{ + accounts: []Account{ + {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, RateLimitResetAt: &resetAt}, + }, + } + cache := &stubGatewayCache{} + concurrencyCache := stubConcurrencyCache{} + + svc := &OpenAIGatewayService{ + accountRepo: repo, + cache: cache, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil) + if err == nil { + t.Fatalf("expected error for no candidates") 
+	}
+	if selection != nil {
+		t.Fatalf("expected nil selection")
+	}
+}
+
+// TestOpenAISelectAccountWithLoadAwareness_AllFullWaitPlan checks the case
+// where the only candidate account (ID 1, Concurrency 1) is reported by the
+// concurrency stub with LoadRate 100. The call must succeed and the returned
+// selection must carry a non-nil WaitPlan.
+func TestOpenAISelectAccountWithLoadAwareness_AllFullWaitPlan(t *testing.T) {
+	groupID := int64(1)
+	repo := stubOpenAIAccountRepo{
+		accounts: []Account{
+			{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+		},
+	}
+	cache := &stubGatewayCache{}
+	concurrencyCache := stubConcurrencyCache{
+		loadMap: map[int64]*AccountLoadInfo{
+			1: {AccountID: 1, LoadRate: 100},
+		},
+	}
+
+	svc := &OpenAIGatewayService{
+		accountRepo:        repo,
+		cache:              cache,
+		concurrencyService: NewConcurrencyService(concurrencyCache),
+	}
+
+	selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
+	if err != nil {
+		t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
+	}
+	if selection == nil || selection.WaitPlan == nil {
+		t.Fatalf("expected wait plan")
+	}
+}
+
+// TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorNoAcquire configures
+// the concurrency stub with loadBatchErr set and acquireResults[1] = false,
+// then expects the call to succeed and return a selection with a WaitPlan.
+// NOTE(review): presumably loadBatchErr makes the batch load lookup fail and
+// acquireResults[1] = false refuses the slot for account 1; the stub is
+// defined outside this chunk, so confirm against its implementation.
+func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorNoAcquire(t *testing.T) {
+	groupID := int64(1)
+	repo := stubOpenAIAccountRepo{
+		accounts: []Account{
+			{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+		},
+	}
+	cache := &stubGatewayCache{}
+	concurrencyCache := stubConcurrencyCache{
+		loadBatchErr:   errors.New("load batch failed"),
+		acquireResults: map[int64]bool{1: false},
+	}
+
+	svc := &OpenAIGatewayService{
+		accountRepo:        repo,
+		cache:              cache,
+		concurrencyService: NewConcurrencyService(concurrencyCache),
+	}
+
+	selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
+	if err != nil {
+		t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
+	}
+	if selection == nil || selection.WaitPlan == nil {
+		t.Fatalf("expected wait plan")
+	}
+}
+
+// TestOpenAISelectAccountWithLoadAwareness_MissingLoadInfo sets up two
+// accounts with equal Priority but gives the stub a load entry only for
+// account 1 (LoadRate 50), with skipDefaultLoad enabled. The test expects
+// account 2, the one without a load entry, to be selected.
+// NOTE(review): skipDefaultLoad presumably stops the stub from synthesizing
+// entries for accounts missing from loadMap; confirm against the stub.
+func TestOpenAISelectAccountWithLoadAwareness_MissingLoadInfo(t *testing.T) {
+	groupID := int64(1)
+	repo := stubOpenAIAccountRepo{
+		accounts: []Account{
+			{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+			{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+		},
+	}
+	cache := &stubGatewayCache{}
+	concurrencyCache := stubConcurrencyCache{
+		loadMap: map[int64]*AccountLoadInfo{
+			1: {AccountID: 1, LoadRate: 50},
+		},
+		skipDefaultLoad: true,
+	}
+
+	svc := &OpenAIGatewayService{
+		accountRepo:        repo,
+		cache:              cache,
+		concurrencyService: NewConcurrencyService(concurrencyCache),
+	}
+
+	selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
+	if err != nil {
+		t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
+	}
+	if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
+		t.Fatalf("expected account 2")
+	}
+}
+
+// TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed sets up two
+// accounts with equal Priority: account 1 last used one hour ago and account 2
+// last used two hours ago. No concurrency service is configured and the group
+// argument is nil. The test expects account 2, the one with the older
+// LastUsedAt, to be returned.
+func TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed(t *testing.T) {
+	oldTime := time.Now().Add(-2 * time.Hour)
+	newTime := time.Now().Add(-1 * time.Hour)
+	repo := stubOpenAIAccountRepo{
+		accounts: []Account{
+			{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &newTime},
+			{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &oldTime},
+		},
+	}
+	cache := &stubGatewayCache{}
+
+	svc := &OpenAIGatewayService{
+		accountRepo: repo,
+		cache:       cache,
+	}
+
+	acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
+	if err != nil {
+		t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
+	}
+	if acc == nil || acc.ID != 2 {
+		t.Fatalf("expected account 2")
+	}
+}
+
+// TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed sets up two
+// accounts with equal Priority and equal reported load (LoadRate 10). Account
+// 1 has a LastUsedAt one hour in the past, while account 2 has no LastUsedAt
+// at all. The test expects account 2, the never-used one, to be selected.
+func TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed(t *testing.T) {
+	groupID := int64(1)
+	lastUsed := time.Now().Add(-1 * time.Hour)
+	repo := stubOpenAIAccountRepo{
+		accounts: []Account{
+			{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, LastUsedAt: &lastUsed},
+			{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
+		},
+	}
+	cache := &stubGatewayCache{}
+	concurrencyCache := stubConcurrencyCache{
+		loadMap: map[int64]*AccountLoadInfo{
+			1: {AccountID: 1, LoadRate: 10},
+			2: {AccountID: 2, LoadRate: 10},
+		},
+	}
+
+	svc := &OpenAIGatewayService{
+		accountRepo:        repo,
+		cache:              cache,
+		concurrencyService: NewConcurrencyService(concurrencyCache),
+	}
+
+	selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
+	if err != nil {
+		t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
+	}
+	if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
+		t.Fatalf("expected account 2")
+	}
+}
+
 func TestOpenAIStreamingTimeout(t *testing.T) {
 	gin.SetMode(gin.TestMode)
 	cfg := &config.Config{
diff --git a/backend/internal/service/sticky_session_test.go b/backend/internal/service/sticky_session_test.go
new file mode 100644
index 00000000..4bd06b7b
--- /dev/null
+++ b/backend/internal/service/sticky_session_test.go
@@ -0,0 +1,54 @@
+//go:build unit
+
+// Package service provides the core services of the API gateway.
+//
+// This file contains unit tests for the shouldClearStickySession function,
+// verifying correct sticky session clearing behavior under various account states.
+// The //go:build unit constraint above means this file is only compiled when
+// the "unit" build tag is set.
+package service
+
+import (
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+// TestShouldClearStickySession tests the sticky session clearing decision logic.
+// It verifies whether clearing is correctly decided in the following cases:
+// - nil account: do not clear (returns false)
+// - status is error or disabled: clear
+// - not schedulable: clear
+// - temporarily unschedulable and not yet expired: clear
+// - temporarily unschedulable but already expired: do not clear
+// - normal schedulable state: do not clear
+//
+// Each case is a table entry executed as its own subtest via t.Run; the
+// temporarily-unschedulable cases use timestamps one hour after and one hour
+// before time.Now().
+func TestShouldClearStickySession(t *testing.T) {
+	// Instants one hour either side of the current time, for TempUnschedulableUntil.
+	reference := time.Now()
+	later, earlier := reference.Add(time.Hour), reference.Add(-time.Hour)
+
+	cases := []struct {
+		name    string
+		account *Account
+		want    bool
+	}{
+		{"nil account", nil, false},
+		{"status error", &Account{Status: StatusError, Schedulable: true}, true},
+		{"status disabled", &Account{Status: StatusDisabled, Schedulable: true}, true},
+		{"schedulable false", &Account{Status: StatusActive, Schedulable: false}, true},
+		{"temp unschedulable", &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &later}, true},
+		{"temp unschedulable expired", &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &earlier}, false},
+		{"active schedulable", &Account{Status: StatusActive, Schedulable: true}, false},
+	}
+	for i := range cases {
+		tc := cases[i] // per-iteration copy captured by the subtest closure
+		t.Run(tc.name, func(t *testing.T) {
+			require.Equal(t, tc.want, shouldClearStickySession(tc.account))
+		})
+	}
+}