diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index fdc5c6ac..aff9a0ff 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -100,7 +100,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { } dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig) dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService) - schedulerCache := repository.NewSchedulerCache(redisClient) + schedulerCache := repository.ProvideSchedulerCache(redisClient, configConfig) accountRepository := repository.NewAccountRepository(client, db, schedulerCache) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 9b430377..ad023dc1 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -620,6 +620,10 @@ type GatewaySchedulingConfig struct { // 负载计算 LoadBatchEnabled bool `mapstructure:"load_batch_enabled"` + // 快照桶读取时的 MGET 分块大小 + SnapshotMGetChunkSize int `mapstructure:"snapshot_mget_chunk_size"` + // 快照重建时的缓存写入分块大小 + SnapshotWriteChunkSize int `mapstructure:"snapshot_write_chunk_size"` // 过期槽位清理周期(0 表示禁用) SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"` @@ -1340,6 +1344,8 @@ func setDefaults() { viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used") viper.SetDefault("gateway.scheduling.load_batch_enabled", true) + viper.SetDefault("gateway.scheduling.snapshot_mget_chunk_size", 128) + viper.SetDefault("gateway.scheduling.snapshot_write_chunk_size", 256) viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) viper.SetDefault("gateway.scheduling.db_fallback_enabled", true) viper.SetDefault("gateway.scheduling.db_fallback_timeout_seconds", 0) @@ -2001,6 +2007,12 @@ func (c *Config) Validate() error { if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 { return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive") } + if c.Gateway.Scheduling.SnapshotMGetChunkSize <= 0 { + return fmt.Errorf("gateway.scheduling.snapshot_mget_chunk_size must be positive") + } + if c.Gateway.Scheduling.SnapshotWriteChunkSize <= 0 { + return fmt.Errorf("gateway.scheduling.snapshot_write_chunk_size must be positive") + } if c.Gateway.Scheduling.SlotCleanupInterval < 0 { return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative") } diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 4caef955..acea3780 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -34,7 +34,12 @@ func (f *fakeSchedulerCache) GetSnapshot(_ context.Context, _ service.SchedulerB func (f *fakeSchedulerCache) SetSnapshot(_ context.Context, _ service.SchedulerBucket, _ []service.Account) error { return nil } -func (f *fakeSchedulerCache) GetAccount(_ context.Context, _ int64) (*service.Account, error) { +func (f *fakeSchedulerCache) GetAccount(_ context.Context, id int64) (*service.Account, error) { + for _, account := range f.accounts { + if account != nil && account.ID == id { + return account, nil + } + } return nil, nil } func (f *fakeSchedulerCache) SetAccount(_ context.Context, _ *service.Account) error { return nil } diff --git a/backend/internal/repository/integration_harness_test.go b/backend/internal/repository/integration_harness_test.go index fb9c26c4..5857fbcb 100644 --- a/backend/internal/repository/integration_harness_test.go +++ b/backend/internal/repository/integration_harness_test.go @@ -332,6 +332,10 @@ func (h prefixHook) prefixCmd(cmd redisclient.Cmder) { "hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists", "zadd", "zcard", "zrange", "zrangebyscore", "zrem", "zremrangebyscore", "zrevrange", "zrevrangebyscore", "zscore": prefixOne(1) + case "mget": + for i := 1; i < len(args); i++ { + prefixOne(i) + } case "del", "unlink": for i := 1; i < len(args); i++ { prefixOne(i) diff --git a/backend/internal/repository/scheduler_cache.go b/backend/internal/repository/scheduler_cache.go index 4f447e4f..35345a8b 100644 --- a/backend/internal/repository/scheduler_cache.go +++ b/backend/internal/repository/scheduler_cache.go @@ -15,19 +15,39 @@ const ( schedulerBucketSetKey = "sched:buckets" schedulerOutboxWatermarkKey = "sched:outbox:watermark" schedulerAccountPrefix = "sched:acc:" + schedulerAccountMetaPrefix = "sched:meta:" schedulerActivePrefix = "sched:active:" schedulerReadyPrefix = "sched:ready:" schedulerVersionPrefix = "sched:ver:" schedulerSnapshotPrefix = "sched:" schedulerLockPrefix = "sched:lock:" + + defaultSchedulerSnapshotMGetChunkSize = 128 + defaultSchedulerSnapshotWriteChunkSize = 256 ) type schedulerCache struct { - rdb *redis.Client + rdb *redis.Client + mgetChunkSize int + writeChunkSize int } func NewSchedulerCache(rdb *redis.Client) service.SchedulerCache { - return &schedulerCache{rdb: rdb} + return newSchedulerCacheWithChunkSizes(rdb, defaultSchedulerSnapshotMGetChunkSize, defaultSchedulerSnapshotWriteChunkSize) +} + +func newSchedulerCacheWithChunkSizes(rdb *redis.Client, mgetChunkSize, writeChunkSize int) service.SchedulerCache { + if mgetChunkSize <= 0 { + mgetChunkSize = defaultSchedulerSnapshotMGetChunkSize + } + if writeChunkSize <= 0 { + writeChunkSize = defaultSchedulerSnapshotWriteChunkSize + } + return &schedulerCache{ + rdb: rdb, + mgetChunkSize: mgetChunkSize, + writeChunkSize: writeChunkSize, + } } func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) { @@ -65,9 +85,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul keys := make([]string, 0, len(ids)) for _, id := range ids { - keys = append(keys, schedulerAccountKey(id)) + keys = append(keys, schedulerAccountMetaKey(id)) } - values, err := c.rdb.MGet(ctx, keys...).Result() + values, err := c.mgetChunked(ctx, keys) if err != nil { return nil, false, err } @@ -100,14 +120,11 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul versionStr := strconv.FormatInt(version, 10) snapshotKey := schedulerSnapshotKey(bucket, versionStr) - pipe := c.rdb.Pipeline() - for _, account := range accounts { - payload, err := json.Marshal(account) - if err != nil { - return err - } - pipe.Set(ctx, schedulerAccountKey(strconv.FormatInt(account.ID, 10)), payload, 0) + if err := c.writeAccounts(ctx, accounts); err != nil { + return err } + + pipe := c.rdb.Pipeline() if len(accounts) > 0 { // 使用序号作为 score,保持数据库返回的排序语义。 members := make([]redis.Z, 0, len(accounts)) @@ -117,7 +134,13 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul Member: strconv.FormatInt(account.ID, 10), }) } - pipe.ZAdd(ctx, snapshotKey, members...) + for start := 0; start < len(members); start += c.writeChunkSize { + end := start + c.writeChunkSize + if end > len(members) { + end = len(members) + } + pipe.ZAdd(ctx, snapshotKey, members[start:end]...) + } } else { pipe.Del(ctx, snapshotKey) } @@ -151,20 +174,15 @@ func (c *schedulerCache) SetAccount(ctx context.Context, account *service.Accoun if account == nil || account.ID <= 0 { return nil } - payload, err := json.Marshal(account) - if err != nil { - return err - } - key := schedulerAccountKey(strconv.FormatInt(account.ID, 10)) - return c.rdb.Set(ctx, key, payload, 0).Err() + return c.writeAccounts(ctx, []service.Account{*account}) } func (c *schedulerCache) DeleteAccount(ctx context.Context, accountID int64) error { if accountID <= 0 { return nil } - key := schedulerAccountKey(strconv.FormatInt(accountID, 10)) - return c.rdb.Del(ctx, key).Err() + id := strconv.FormatInt(accountID, 10) + return c.rdb.Del(ctx, schedulerAccountKey(id), schedulerAccountMetaKey(id)).Err() } func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error { @@ -179,7 +197,7 @@ func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]t ids = append(ids, id) } - values, err := c.rdb.MGet(ctx, keys...).Result() + values, err := c.mgetChunked(ctx, keys) if err != nil { return err } @@ -198,7 +216,12 @@ func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]t if err != nil { return err } + metaPayload, err := json.Marshal(buildSchedulerMetadataAccount(*account)) + if err != nil { + return err + } pipe.Set(ctx, keys[i], updated, 0) + pipe.Set(ctx, schedulerAccountMetaKey(strconv.FormatInt(ids[i], 10)), metaPayload, 0) } _, err = pipe.Exec(ctx) return err @@ -256,6 +279,10 @@ func schedulerAccountKey(id string) string { return schedulerAccountPrefix + id } +func schedulerAccountMetaKey(id string) string { + return schedulerAccountMetaPrefix + id +} + func ptrTime(t time.Time) *time.Time { return &t } @@ -276,3 +303,137 @@ func decodeCachedAccount(val any) (*service.Account, error) { } return &account, nil } + +func (c *schedulerCache) writeAccounts(ctx context.Context, accounts []service.Account) error { + if len(accounts) == 0 { + return nil + } + + pipe := c.rdb.Pipeline() + pending := 0 + flush := func() error { + if pending == 0 { + return nil + } + if _, err := pipe.Exec(ctx); err != nil { + return err + } + pipe = c.rdb.Pipeline() + pending = 0 + return nil + } + + for _, account := range accounts { + fullPayload, err := json.Marshal(account) + if err != nil { + return err + } + metaPayload, err := json.Marshal(buildSchedulerMetadataAccount(account)) + if err != nil { + return err + } + + id := strconv.FormatInt(account.ID, 10) + pipe.Set(ctx, schedulerAccountKey(id), fullPayload, 0) + pipe.Set(ctx, schedulerAccountMetaKey(id), metaPayload, 0) + pending++ + if pending >= c.writeChunkSize { + if err := flush(); err != nil { + return err + } + } + } + + return flush() +} + +func (c *schedulerCache) mgetChunked(ctx context.Context, keys []string) ([]any, error) { + if len(keys) == 0 { + return []any{}, nil + } + + out := make([]any, 0, len(keys)) + chunkSize := c.mgetChunkSize + if chunkSize <= 0 { + chunkSize = defaultSchedulerSnapshotMGetChunkSize + } + for start := 0; start < len(keys); start += chunkSize { + end := start + chunkSize + if end > len(keys) { + end = len(keys) + } + part, err := c.rdb.MGet(ctx, keys[start:end]...).Result() + if err != nil { + return nil, err + } + out = append(out, part...) + } + return out, nil +} + +func buildSchedulerMetadataAccount(account service.Account) service.Account { + return service.Account{ + ID: account.ID, + Name: account.Name, + Platform: account.Platform, + Type: account.Type, + Concurrency: account.Concurrency, + Priority: account.Priority, + RateMultiplier: account.RateMultiplier, + Status: account.Status, + LastUsedAt: account.LastUsedAt, + ExpiresAt: account.ExpiresAt, + AutoPauseOnExpired: account.AutoPauseOnExpired, + Schedulable: account.Schedulable, + RateLimitedAt: account.RateLimitedAt, + RateLimitResetAt: account.RateLimitResetAt, + OverloadUntil: account.OverloadUntil, + TempUnschedulableUntil: account.TempUnschedulableUntil, + TempUnschedulableReason: account.TempUnschedulableReason, + SessionWindowStart: account.SessionWindowStart, + SessionWindowEnd: account.SessionWindowEnd, + SessionWindowStatus: account.SessionWindowStatus, + Credentials: filterSchedulerCredentials(account.Credentials), + Extra: filterSchedulerExtra(account.Extra), + } +} + +func filterSchedulerCredentials(credentials map[string]any) map[string]any { + if len(credentials) == 0 { + return nil + } + keys := []string{"model_mapping", "api_key", "project_id", "oauth_type"} + filtered := make(map[string]any) + for _, key := range keys { + if value, ok := credentials[key]; ok && value != nil { + filtered[key] = value + } + } + if len(filtered) == 0 { + return nil + } + return filtered +} + +func filterSchedulerExtra(extra map[string]any) map[string]any { + if len(extra) == 0 { + return nil + } + keys := []string{ + "mixed_scheduling", + "window_cost_limit", + "window_cost_sticky_reserve", + "max_sessions", + "session_idle_timeout_minutes", + } + filtered := make(map[string]any) + for _, key := range keys { + if value, ok := extra[key]; ok && value != nil { + filtered[key] = value + } + } + if len(filtered) == 0 { + return nil + } + return filtered +} diff --git a/backend/internal/repository/scheduler_cache_integration_test.go b/backend/internal/repository/scheduler_cache_integration_test.go new file mode 100644 index 00000000..134a6a07 --- /dev/null +++ b/backend/internal/repository/scheduler_cache_integration_test.go @@ -0,0 +1,88 @@ +//go:build integration + +package repository + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T) { + ctx := context.Background() + rdb := testRedis(t) + cache := NewSchedulerCache(rdb) + + bucket := service.SchedulerBucket{GroupID: 2, Platform: service.PlatformGemini, Mode: service.SchedulerModeSingle} + now := time.Now().UTC().Truncate(time.Second) + limitReset := now.Add(10 * time.Minute) + overloadUntil := now.Add(2 * time.Minute) + tempUnschedUntil := now.Add(3 * time.Minute) + windowEnd := now.Add(5 * time.Hour) + + account := service.Account{ + ID: 101, + Name: "gemini-heavy", + Platform: service.PlatformGemini, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + Schedulable: true, + Concurrency: 3, + Priority: 7, + LastUsedAt: &now, + Credentials: map[string]any{ + "api_key": "gemini-api-key", + "access_token": "secret-access-token", + "project_id": "proj-1", + "oauth_type": "ai_studio", + "model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"}, + "huge_blob": strings.Repeat("x", 4096), + }, + Extra: map[string]any{ + "mixed_scheduling": true, + "window_cost_limit": 12.5, + "window_cost_sticky_reserve": 8.0, + "max_sessions": 4, + "session_idle_timeout_minutes": 11, + "unused_large_field": strings.Repeat("y", 4096), + }, + RateLimitResetAt: &limitReset, + OverloadUntil: &overloadUntil, + TempUnschedulableUntil: &tempUnschedUntil, + SessionWindowStart: &now, + SessionWindowEnd: &windowEnd, + SessionWindowStatus: "active", + } + + require.NoError(t, cache.SetSnapshot(ctx, bucket, []service.Account{account})) + + snapshot, hit, err := cache.GetSnapshot(ctx, bucket) + require.NoError(t, err) + require.True(t, hit) + require.Len(t, snapshot, 1) + + got := snapshot[0] + require.NotNil(t, got) + require.Equal(t, "gemini-api-key", got.GetCredential("api_key")) + require.Equal(t, "proj-1", got.GetCredential("project_id")) + require.Equal(t, "ai_studio", got.GetCredential("oauth_type")) + require.NotEmpty(t, got.GetModelMapping()) + require.Empty(t, got.GetCredential("access_token")) + require.Empty(t, got.GetCredential("huge_blob")) + require.Equal(t, true, got.Extra["mixed_scheduling"]) + require.Equal(t, 12.5, got.GetWindowCostLimit()) + require.Equal(t, 8.0, got.GetWindowCostStickyReserve()) + require.Equal(t, 4, got.GetMaxSessions()) + require.Equal(t, 11, got.GetSessionIdleTimeoutMinutes()) + require.Nil(t, got.Extra["unused_large_field"]) + + full, err := cache.GetAccount(ctx, account.ID) + require.NoError(t, err) + require.NotNil(t, full) + require.Equal(t, "secret-access-token", full.GetCredential("access_token")) + require.Equal(t, strings.Repeat("x", 4096), full.GetCredential("huge_blob")) +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 657e3ed6..d3adb4a0 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -47,6 +47,21 @@ func ProvideSessionLimitCache(rdb *redis.Client, cfg *config.Config) service.Ses return NewSessionLimitCache(rdb, defaultIdleTimeoutMinutes) } +// ProvideSchedulerCache 创建调度快照缓存,并注入快照分块参数。 +func ProvideSchedulerCache(rdb *redis.Client, cfg *config.Config) service.SchedulerCache { + mgetChunkSize := defaultSchedulerSnapshotMGetChunkSize + writeChunkSize := defaultSchedulerSnapshotWriteChunkSize + if cfg != nil { + if cfg.Gateway.Scheduling.SnapshotMGetChunkSize > 0 { + mgetChunkSize = cfg.Gateway.Scheduling.SnapshotMGetChunkSize + } + if cfg.Gateway.Scheduling.SnapshotWriteChunkSize > 0 { + writeChunkSize = cfg.Gateway.Scheduling.SnapshotWriteChunkSize + } + } + return newSchedulerCacheWithChunkSizes(rdb, mgetChunkSize, writeChunkSize) +} + // ProviderSet is the Wire provider set for all repositories var ProviderSet = wire.NewSet( NewUserRepository, @@ -92,7 +107,7 @@ var ProviderSet = wire.NewSet( NewRedeemCache, NewUpdateCache, NewGeminiTokenCache, - NewSchedulerCache, + ProvideSchedulerCache, NewSchedulerOutboxRepository, NewProxyLatencyCache, NewTotpCache, diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index a4733649..8b0bdc2a 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -1192,12 +1192,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // 注意:强制平台模式不走混合调度 if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { - return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + if err != nil { + return nil, err + } + return s.hydrateSelectedAccount(ctx, account) } // antigravity 分组、强制平台模式或无分组使用单平台选择 // 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询 - return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + if err != nil { + return nil, err + } + return s.hydrateSelectedAccount(ctx, account) } // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. @@ -1273,11 +1281,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro localExcluded[account.ID] = struct{}{} // 排除此账号 continue // 重新选择 } - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) } // 对于等待计划的情况,也需要先检查会话限制 @@ -1289,26 +1293,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }) } } - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }) } } @@ -1455,11 +1453,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) } - return &AccountSelectionResult{ - Account: stickyAccount, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, stickyAccount, true, result.ReleaseFunc, nil) } } @@ -1570,11 +1564,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } - return &AccountSelectionResult{ - Account: item.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, item.account, true, result.ReleaseFunc, nil) } } @@ -1587,15 +1577,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } - return &AccountSelectionResult{ - Account: item.account, - WaitPlan: &AccountWaitPlan{ - AccountID: item.account.ID, - MaxConcurrency: item.account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, item.account, false, nil, &AccountWaitPlan{ + AccountID: item.account.ID, + MaxConcurrency: item.account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }) } // 所有路由账号会话限制都已满,继续到 Layer 2 回退 } @@ -1631,11 +1618,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.checkAndRegisterSession(ctx, account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续到 Layer 2 } else { - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + if s.cache != nil { + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) + } + return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) } } @@ -1647,15 +1633,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro // 会话限制已满,继续到 Layer 2 // Session limit full, continue to Layer 2 } else { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }) } } } @@ -1714,7 +1697,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) if err != nil { - if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { + if result, ok, legacyErr := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); legacyErr != nil { + return nil, legacyErr + } else if ok { return result, nil } } else { @@ -1753,11 +1738,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) } - return &AccountSelectionResult{ - Account: selected.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, selected.account, true, result.ReleaseFunc, nil) } } @@ -1780,20 +1761,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.checkAndRegisterSession(ctx, acc, sessionHash) { continue // 会话限制已满,尝试下一个账号 } - return &AccountSelectionResult{ - Account: acc, - WaitPlan: &AccountWaitPlan{ - AccountID: acc.ID, - MaxConcurrency: acc.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, acc, false, nil, &AccountWaitPlan{ + AccountID: acc.ID, + MaxConcurrency: acc.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }) } return nil, ErrNoAvailableAccounts } -func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { +func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool, error) { ordered := append([]*Account(nil), candidates...) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) @@ -1808,15 +1786,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL) } - return &AccountSelectionResult{ - Account: acc, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, true + selection, err := s.newSelectionResult(ctx, acc, true, result.ReleaseFunc, nil) + if err != nil { + return nil, false, err + } + return selection, true, nil } } - return nil, false + return nil, false, nil } func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { @@ -2431,6 +2409,33 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in return s.accountRepo.GetByID(ctx, accountID) } +func (s *GatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) { + if account == nil || s.schedulerSnapshot == nil { + return account, nil + } + hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID) + if err != nil { + return nil, err + } + if hydrated == nil { + return nil, fmt.Errorf("selected gateway account %d not found during hydration", account.ID) + } + return hydrated, nil +} + +func (s *GatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) { + hydrated, err := s.hydrateSelectedAccount(ctx, account) + if err != nil { + return nil, err + } + return &AccountSelectionResult{ + Account: hydrated, + Acquired: acquired, + ReleaseFunc: release, + WaitPlan: waitPlan, + }, nil +} + // filterByMinPriority 过滤出优先级最小的账号集合 func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad { if len(accounts) == 0 { diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 32bf21c0..5a9490f3 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -137,7 +137,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL) } - return selected, nil + return s.hydrateSelectedAccount(ctx, selected) } // resolvePlatformAndSchedulingMode 解析目标平台和调度模式。 @@ -416,6 +416,20 @@ func (s *GeminiMessagesCompatService) getSchedulableAccount(ctx context.Context, return s.accountRepo.GetByID(ctx, accountID) } +func (s *GeminiMessagesCompatService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) { + if account == nil || s.schedulerSnapshot == nil { + return account, nil + } + hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID) + if err != nil { + return nil, err + } + if hydrated == nil { + return nil, fmt.Errorf("selected gemini account %d not found during hydration", account.ID) + } + return hydrated, nil +} + func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, error) { if s.schedulerSnapshot != nil { accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) @@ -546,7 +560,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont if selected == nil { return nil, errors.New("no available Gemini accounts") } - return selected, nil + return s.hydrateSelectedAccount(ctx, selected) } func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 2623d773..dbc53869 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1243,7 +1243,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL) } - return selected, nil + return s.hydrateSelectedAccount(ctx, selected) } // tryStickySessionHit 尝试从粘性会话获取账号。 @@ -1408,35 +1408,25 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex } result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) if err == nil && result.Acquired { - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) } if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }) } } - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }) } accounts, err := s.listSchedulableAccounts(ctx, groupID) @@ -1476,24 +1466,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) } waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }) } } } @@ -1552,11 +1535,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if sessionHash != "" { _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) } - return &AccountSelectionResult{ - Account: fresh, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil) } } } else { @@ -1609,11 +1588,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if sessionHash != "" { _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) } - return &AccountSelectionResult{ - Account: fresh, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil) } } } @@ -1629,15 +1604,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { continue } - return &AccountSelectionResult{ - Account: fresh, - WaitPlan: &AccountWaitPlan{ - AccountID: fresh.ID, - MaxConcurrency: fresh.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, fresh, false, nil, &AccountWaitPlan{ + AccountID: fresh.ID, + MaxConcurrency: fresh.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }) } return nil, ErrNoAvailableAccounts @@ -1732,6 +1704,33 @@ func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accoun return account, nil } +func (s *OpenAIGatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) { + if account == nil || s.schedulerSnapshot == nil { + return account, nil + } + hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID) + if err != nil { + return nil, err + } + if hydrated == nil { + return nil, fmt.Errorf("selected openai account %d not found during hydration", account.ID) + } + return hydrated, nil +} + +func (s *OpenAIGatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) { + hydrated, err := s.hydrateSelectedAccount(ctx, account) + if err != nil { + return nil, err + } + return &AccountSelectionResult{ + Account: hydrated, + Acquired: acquired, + ReleaseFunc: release, + WaitPlan: waitPlan, + }, nil +} + func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig { if s.cfg != nil { return s.cfg.Gateway.Scheduling diff --git a/backend/internal/service/scheduler_snapshot_hydration_test.go b/backend/internal/service/scheduler_snapshot_hydration_test.go new file mode 100644 index 00000000..5c0b289b --- /dev/null +++ b/backend/internal/service/scheduler_snapshot_hydration_test.go @@ -0,0 +1,159 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" +) + +type snapshotHydrationCache struct { + snapshot []*Account + accounts map[int64]*Account +} + +func (c *snapshotHydrationCache) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) { + return c.snapshot, true, nil +} + +func (c *snapshotHydrationCache) SetSnapshot(ctx context.Context, bucket SchedulerBucket, accounts []Account) error { + return nil +} + +func (c *snapshotHydrationCache) GetAccount(ctx context.Context, accountID int64) (*Account, error) { + if c.accounts == nil { + return nil, nil + } + return c.accounts[accountID], nil +} + +func (c *snapshotHydrationCache) SetAccount(ctx context.Context, account *Account) error { + return nil +} + +func (c *snapshotHydrationCache) DeleteAccount(ctx context.Context, accountID int64) error { + return nil +} + +func (c *snapshotHydrationCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error { + return nil +} + +func (c *snapshotHydrationCache) TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error) { + return true, nil +} + +func (c *snapshotHydrationCache) ListBuckets(ctx context.Context) ([]SchedulerBucket, error) { + return nil, nil +} + +func (c *snapshotHydrationCache) GetOutboxWatermark(ctx context.Context) (int64, error) { + return 0, nil +} + +func (c *snapshotHydrationCache) SetOutboxWatermark(ctx context.Context, id int64) error { + return nil +} + +func TestOpenAISelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedulerSnapshot(t *testing.T) { + cache := &snapshotHydrationCache{ + snapshot: []*Account{ + { + ID: 1, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 1, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-4": "gpt-4", + }, + }, + }, + }, + accounts: map[int64]*Account{ + 1: { + ID: 1, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 1, + Credentials: map[string]any{ + "api_key": "sk-live", + "model_mapping": map[string]any{"gpt-4": "gpt-4"}, + }, + }, + }, + } + + schedulerSnapshot := NewSchedulerSnapshotService(cache, nil, nil, nil, nil) + groupID := int64(2) + svc := &OpenAIGatewayService{ + schedulerSnapshot: schedulerSnapshot, + cache: &stubGatewayCache{}, + } + + selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil) + if err != nil { + t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) + } + if selection == nil || selection.Account == nil { + t.Fatalf("expected selected account") + } + if got := selection.Account.GetOpenAIApiKey(); got != "sk-live" { + t.Fatalf("expected hydrated api key, got %q", got) + } +} + +func TestGatewaySelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedulerSnapshot(t *testing.T) { + cache := &snapshotHydrationCache{ + snapshot: []*Account{ + { + ID: 9, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 1, + }, + }, + accounts: map[int64]*Account{ + 9: { + ID: 9, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 1, + Credentials: map[string]any{ + "api_key": "anthropic-live-key", + }, + }, + }, + } + + schedulerSnapshot := NewSchedulerSnapshotService(cache, nil, nil, nil, nil) + svc := &GatewayService{ + schedulerSnapshot: schedulerSnapshot, + cache: &mockGatewayCacheForPlatform{}, + cfg: testConfig(), + } + + result, err := svc.SelectAccountWithLoadAwareness(context.Background(), nil, "", "claude-3-5-sonnet-20241022", nil, "", 0) + if err != nil { + t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) + } + if result == nil || result.Account == nil { + t.Fatalf("expected selected account") + } + if got := result.Account.GetCredential("api_key"); got != "anthropic-live-key" { + t.Fatalf("expected hydrated api key, got %q", got) + } +} diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 8f60acd5..45440761 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -347,6 +347,12 @@ gateway: # Enable batch load calculation for scheduling # 启用调度批量负载计算 load_batch_enabled: true + # Snapshot bucket MGET chunk size + # 调度快照分桶读取时的 MGET 分块大小 + snapshot_mget_chunk_size: 128 + # Snapshot bucket write chunk size + # 调度快照重建写入时的分块大小 + snapshot_write_chunk_size: 256 # Slot cleanup interval (duration) # 并发槽位清理周期(时间段) slot_cleanup_interval: 30s