feat(scheduler): introduce scheduler snapshot cache and outbox replay
- The scheduling hot path now reads the Redis snapshot first, preserving group-sort semantics
- Outbox replay plus full-rebuild correction; failed retries do not advance the watermark
- Automatic Atlas baseline alignment, with the scheduling config example kept in sync
@@ -67,6 +67,7 @@ func provideCleanup(
 	opsAlertEvaluator *service.OpsAlertEvaluatorService,
 	opsCleanup *service.OpsCleanupService,
 	opsScheduledReport *service.OpsScheduledReportService,
+	schedulerSnapshot *service.SchedulerSnapshotService,
 	tokenRefresh *service.TokenRefreshService,
 	accountExpiry *service.AccountExpiryService,
 	pricing *service.PricingService,
@@ -116,6 +117,12 @@ func provideCleanup(
 		}
 		return nil
 	}},
+	{"SchedulerSnapshotService", func() error {
+		if schedulerSnapshot != nil {
+			schedulerSnapshot.Stop()
+		}
+		return nil
+	}},
 	{"TokenRefreshService", func() error {
 		tokenRefresh.Stop()
 		return nil

@@ -111,6 +111,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
 	concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
 	concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
+	schedulerCache := repository.NewSchedulerCache(redisClient)
+	schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
+	schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
 	crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
 	accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
 	oAuthHandler := admin.NewOAuthHandler(oAuthService)
@@ -130,9 +133,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	identityCache := repository.NewIdentityCache(redisClient)
 	identityService := service.NewIdentityService(identityCache)
 	deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
-	gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
-	openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
-	geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
+	gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
+	openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
+	geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
 	opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
 	settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
 	opsHandler := admin.NewOpsHandler(opsService)
@@ -164,7 +167,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
 	tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
 	accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
-	v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
+	v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
 	application := &Application{
 		Server:  httpServer,
 		Cleanup: v,
@@ -194,6 +197,7 @@ func provideCleanup(
 	opsAlertEvaluator *service.OpsAlertEvaluatorService,
 	opsCleanup *service.OpsCleanupService,
 	opsScheduledReport *service.OpsScheduledReportService,
+	schedulerSnapshot *service.SchedulerSnapshotService,
 	tokenRefresh *service.TokenRefreshService,
 	accountExpiry *service.AccountExpiryService,
 	pricing *service.PricingService,
@@ -242,6 +246,12 @@ func provideCleanup(
 		}
 		return nil
 	}},
+	{"SchedulerSnapshotService", func() error {
+		if schedulerSnapshot != nil {
+			schedulerSnapshot.Stop()
+		}
+		return nil
+	}},
 	{"TokenRefreshService", func() error {
 		tokenRefresh.Stop()
 		return nil
@@ -270,6 +270,29 @@ type GatewaySchedulingConfig struct {
 
 	// Cleanup interval for expired slots (0 disables it)
 	SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
+
+	// Controlled DB-fallback settings
+	DbFallbackEnabled bool `mapstructure:"db_fallback_enabled"`
+	// Controlled fallback timeout in seconds; 0 means no extra tightening of the timeout
+	DbFallbackTimeoutSeconds int `mapstructure:"db_fallback_timeout_seconds"`
+	// Controlled fallback rate limit (per-instance QPS); 0 means unlimited
+	DbFallbackMaxQPS int `mapstructure:"db_fallback_max_qps"`
+
+	// Outbox polling and lag thresholds
+	// Outbox poll interval in seconds
+	OutboxPollIntervalSeconds int `mapstructure:"outbox_poll_interval_seconds"`
+	// Outbox lag warning threshold in seconds
+	OutboxLagWarnSeconds int `mapstructure:"outbox_lag_warn_seconds"`
+	// Outbox lag threshold that forces a full rebuild, in seconds
+	OutboxLagRebuildSeconds int `mapstructure:"outbox_lag_rebuild_seconds"`
+	// Number of consecutive lag violations that trigger a rebuild
+	OutboxLagRebuildFailures int `mapstructure:"outbox_lag_rebuild_failures"`
+	// Outbox backlog threshold (row count) that triggers a rebuild
+	OutboxBacklogRebuildRows int `mapstructure:"outbox_backlog_rebuild_rows"`
+
+	// Full-rebuild cadence
+	// Full rebuild interval in seconds; 0 disables it
+	FullRebuildIntervalSeconds int `mapstructure:"full_rebuild_interval_seconds"`
 }
 
 func (s *ServerConfig) Address() string {
@@ -749,6 +772,15 @@ func setDefaults() {
 	viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
 	viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
 	viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
+	viper.SetDefault("gateway.scheduling.db_fallback_enabled", true)
+	viper.SetDefault("gateway.scheduling.db_fallback_timeout_seconds", 0)
+	viper.SetDefault("gateway.scheduling.db_fallback_max_qps", 0)
+	viper.SetDefault("gateway.scheduling.outbox_poll_interval_seconds", 1)
+	viper.SetDefault("gateway.scheduling.outbox_lag_warn_seconds", 5)
+	viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_seconds", 10)
+	viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
+	viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
+	viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
 	viper.SetDefault("concurrency.ping_interval", 10)
 
 	// TokenRefresh
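Taken together, the mapstructure tags and the defaults above imply a YAML shape like the following. This is a sketch for reference only; the repository's actual example config file is not part of this diff:

    gateway:
      scheduling:
        db_fallback_enabled: true
        db_fallback_timeout_seconds: 0    # 0 = no extra timeout tightening
        db_fallback_max_qps: 0            # 0 = unlimited per-instance QPS
        outbox_poll_interval_seconds: 1
        outbox_lag_warn_seconds: 5
        outbox_lag_rebuild_seconds: 10    # validated to be >= outbox_lag_warn_seconds
        outbox_lag_rebuild_failures: 3
        outbox_backlog_rebuild_rows: 10000
        full_rebuild_interval_seconds: 300  # 0 disables periodic full rebuilds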
@@ -1021,6 +1053,35 @@ func (c *Config) Validate() error {
 	if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
 		return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
 	}
+	if c.Gateway.Scheduling.DbFallbackTimeoutSeconds < 0 {
+		return fmt.Errorf("gateway.scheduling.db_fallback_timeout_seconds must be non-negative")
+	}
+	if c.Gateway.Scheduling.DbFallbackMaxQPS < 0 {
+		return fmt.Errorf("gateway.scheduling.db_fallback_max_qps must be non-negative")
+	}
+	if c.Gateway.Scheduling.OutboxPollIntervalSeconds <= 0 {
+		return fmt.Errorf("gateway.scheduling.outbox_poll_interval_seconds must be positive")
+	}
+	if c.Gateway.Scheduling.OutboxLagWarnSeconds < 0 {
+		return fmt.Errorf("gateway.scheduling.outbox_lag_warn_seconds must be non-negative")
+	}
+	if c.Gateway.Scheduling.OutboxLagRebuildSeconds < 0 {
+		return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be non-negative")
+	}
+	if c.Gateway.Scheduling.OutboxLagRebuildFailures <= 0 {
+		return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_failures must be positive")
+	}
+	if c.Gateway.Scheduling.OutboxBacklogRebuildRows < 0 {
+		return fmt.Errorf("gateway.scheduling.outbox_backlog_rebuild_rows must be non-negative")
+	}
+	if c.Gateway.Scheduling.FullRebuildIntervalSeconds < 0 {
+		return fmt.Errorf("gateway.scheduling.full_rebuild_interval_seconds must be non-negative")
+	}
+	if c.Gateway.Scheduling.OutboxLagWarnSeconds > 0 &&
+		c.Gateway.Scheduling.OutboxLagRebuildSeconds > 0 &&
+		c.Gateway.Scheduling.OutboxLagRebuildSeconds < c.Gateway.Scheduling.OutboxLagWarnSeconds {
+		return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be >= outbox_lag_warn_seconds")
+	}
 	if c.Ops.MetricsCollectorCache.TTL < 0 {
 		return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative")
 	}
@@ -15,6 +15,7 @@ import (
 	"database/sql"
 	"encoding/json"
 	"errors"
+	"log"
 	"strconv"
 	"time"
 
@@ -115,6 +116,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
 	account.ID = created.ID
 	account.CreatedAt = created.CreatedAt
 	account.UpdatedAt = created.UpdatedAt
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
+	}
 	return nil
 }
 
@@ -341,10 +345,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
 		return translatePersistenceError(err, service.ErrAccountNotFound, nil)
 	}
 	account.UpdatedAt = updated.UpdatedAt
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
+	}
 	return nil
 }
 
 func (r *accountRepository) Delete(ctx context.Context, id int64) error {
+	groupIDs, err := r.loadAccountGroupIDs(ctx, id)
+	if err != nil {
+		return err
+	}
 	// Use a transaction so the account and its group bindings are deleted atomically
 	tx, err := r.client.Tx(ctx)
 	if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
@@ -368,7 +379,12 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
 	}
 
 	if tx != nil {
-		return tx.Commit()
+		if err := tx.Commit(); err != nil {
+			return err
+		}
 	}
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
+	}
 	return nil
 }
@@ -455,7 +471,18 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error
 		Where(dbaccount.IDEQ(id)).
 		SetLastUsedAt(now).
 		Save(ctx)
-	return err
+	if err != nil {
+		return err
+	}
+	payload := map[string]any{
+		"last_used": map[string]int64{
+			strconv.FormatInt(id, 10): now.Unix(),
+		},
+	}
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
+	}
+	return nil
 }
 
 func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
@@ -479,7 +506,18 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
 	args = append(args, pq.Array(ids))
 
 	_, err := r.sql.ExecContext(ctx, caseSQL, args...)
-	return err
+	if err != nil {
+		return err
+	}
+	lastUsedPayload := make(map[string]int64, len(updates))
+	for id, ts := range updates {
+		lastUsedPayload[strconv.FormatInt(id, 10)] = ts.Unix()
+	}
+	payload := map[string]any{"last_used": lastUsedPayload}
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
+	}
+	return nil
 }
 
 func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
@@ -488,7 +526,13 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
 		SetStatus(service.StatusError).
 		SetErrorMessage(errorMsg).
 		Save(ctx)
-	return err
+	if err != nil {
+		return err
+	}
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
+	}
+	return nil
 }
 
 func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
@@ -497,7 +541,14 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
 		SetGroupID(groupID).
 		SetPriority(priority).
 		Save(ctx)
-	return err
+	if err != nil {
+		return err
+	}
+	payload := buildSchedulerGroupPayload([]int64{groupID})
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
+	}
+	return nil
 }
 
 func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
@@ -507,7 +558,14 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou
 			dbaccountgroup.GroupIDEQ(groupID),
 		).
 		Exec(ctx)
-	return err
+	if err != nil {
+		return err
+	}
+	payload := buildSchedulerGroupPayload([]int64{groupID})
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
+	}
+	return nil
 }
 
 func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
@@ -528,6 +586,10 @@ func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]s
 }
 
 func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
+	existingGroupIDs, err := r.loadAccountGroupIDs(ctx, accountID)
+	if err != nil {
+		return err
+	}
 	// Use a transaction so removing old bindings and creating new ones is atomic
 	tx, err := r.client.Tx(ctx)
 	if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
@@ -568,7 +630,13 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
 	}
 
 	if tx != nil {
-		return tx.Commit()
+		if err := tx.Commit(); err != nil {
+			return err
+		}
 	}
+	payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs))
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
+	}
 	return nil
 }
@@ -672,7 +740,13 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
 		SetRateLimitedAt(now).
 		SetRateLimitResetAt(resetAt).
 		Save(ctx)
-	return err
+	if err != nil {
+		return err
+	}
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
+	}
+	return nil
 }
 
 func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
@@ -706,6 +780,9 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
 	if affected == 0 {
 		return service.ErrAccountNotFound
 	}
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
+	}
 	return nil
 }
 
@@ -714,7 +791,13 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
 		Where(dbaccount.IDEQ(id)).
 		SetOverloadUntil(until).
 		Save(ctx)
-	return err
+	if err != nil {
+		return err
+	}
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
+	}
+	return nil
 }
 
 func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
@@ -727,7 +810,13 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
 		  AND deleted_at IS NULL
 		  AND (temp_unschedulable_until IS NULL OR temp_unschedulable_until < $1)
 	`, until, reason, id)
-	return err
+	if err != nil {
+		return err
+	}
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
+	}
+	return nil
 }
 
 func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64) error {
@@ -739,7 +828,13 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64
 		WHERE id = $1
 		  AND deleted_at IS NULL
 	`, id)
-	return err
+	if err != nil {
+		return err
+	}
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
+	}
+	return nil
 }
 
 func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
@@ -749,7 +844,13 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
 		ClearRateLimitResetAt().
 		ClearOverloadUntil().
 		Save(ctx)
-	return err
+	if err != nil {
+		return err
+	}
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
+	}
+	return nil
 }
 
 func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
@@ -770,6 +871,9 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
 	if affected == 0 {
 		return service.ErrAccountNotFound
 	}
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
+	}
 	return nil
 }
 
@@ -792,7 +896,13 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
 		Where(dbaccount.IDEQ(id)).
 		SetSchedulable(schedulable).
 		Save(ctx)
-	return err
+	if err != nil {
+		return err
+	}
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
+	}
+	return nil
 }
 
 func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
@@ -813,6 +923,11 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti
 	if err != nil {
 		return 0, err
 	}
+	if rows > 0 {
+		if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil {
+			log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
+		}
+	}
 	return rows, nil
 }
 
@@ -844,6 +959,9 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
 	if affected == 0 {
 		return service.ErrAccountNotFound
 	}
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
+	}
 	return nil
 }
 
@@ -928,6 +1046,12 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
 	if err != nil {
 		return 0, err
 	}
+	if rows > 0 {
+		payload := map[string]any{"account_ids": ids}
+		if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
+			log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
+		}
+	}
 	return rows, nil
 }
 
@@ -1170,6 +1294,54 @@ func (r *accountRepository) loadAccountGroups(ctx context.Context, accountIDs []
 	return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil
 }
 
+func (r *accountRepository) loadAccountGroupIDs(ctx context.Context, accountID int64) ([]int64, error) {
+	entries, err := r.client.AccountGroup.
+		Query().
+		Where(dbaccountgroup.AccountIDEQ(accountID)).
+		All(ctx)
+	if err != nil {
+		return nil, err
+	}
+	ids := make([]int64, 0, len(entries))
+	for _, entry := range entries {
+		ids = append(ids, entry.GroupID)
+	}
+	return ids, nil
+}
+
+func mergeGroupIDs(a []int64, b []int64) []int64 {
+	seen := make(map[int64]struct{}, len(a)+len(b))
+	out := make([]int64, 0, len(a)+len(b))
+	for _, id := range a {
+		if id <= 0 {
+			continue
+		}
+		if _, ok := seen[id]; ok {
+			continue
+		}
+		seen[id] = struct{}{}
+		out = append(out, id)
+	}
+	for _, id := range b {
+		if id <= 0 {
+			continue
+		}
+		if _, ok := seen[id]; ok {
+			continue
+		}
+		seen[id] = struct{}{}
+		out = append(out, id)
+	}
+	return out
+}
+
+func buildSchedulerGroupPayload(groupIDs []int64) map[string]any {
+	if len(groupIDs) == 0 {
+		return nil
+	}
+	return map[string]any{"group_ids": groupIDs}
+}
+
 func accountEntityToService(m *dbent.Account) *service.Account {
 	if m == nil {
 		return nil
@@ -4,6 +4,7 @@ import (
 	"context"
 	"database/sql"
 	"errors"
+	"log"
 
 	dbent "github.com/Wei-Shaw/sub2api/ent"
 	"github.com/Wei-Shaw/sub2api/ent/apikey"
@@ -55,6 +56,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
 		groupIn.ID = created.ID
 		groupIn.CreatedAt = created.CreatedAt
 		groupIn.UpdatedAt = created.UpdatedAt
+		if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
+			log.Printf("[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err)
+		}
 	}
 	return translatePersistenceError(err, nil, service.ErrGroupExists)
 }
@@ -111,12 +115,21 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
 		return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
 	}
 	groupIn.UpdatedAt = updated.UpdatedAt
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err)
+	}
 	return nil
 }
 
 func (r *groupRepository) Delete(ctx context.Context, id int64) error {
 	_, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx)
-	return translatePersistenceError(err, service.ErrGroupNotFound, nil)
+	if err != nil {
+		return translatePersistenceError(err, service.ErrGroupNotFound, nil)
+	}
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err)
+	}
+	return nil
 }
 
 func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
@@ -246,6 +259,9 @@ func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, grou
 		return 0, err
 	}
 	affected, _ := res.RowsAffected()
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err)
+	}
 	return affected, nil
 }
 
@@ -353,6 +369,9 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
 			return nil, err
 		}
 	}
+	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
+		log.Printf("[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err)
+	}
 
 	return affectedUserIDs, nil
 }
@@ -28,6 +28,23 @@ CREATE TABLE IF NOT EXISTS schema_migrations (
 );
 `
 
+const atlasSchemaRevisionsTableDDL = `
+CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
+    version TEXT PRIMARY KEY,
+    description TEXT NOT NULL,
+    type INTEGER NOT NULL,
+    applied INTEGER NOT NULL DEFAULT 0,
+    total INTEGER NOT NULL DEFAULT 0,
+    executed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    execution_time BIGINT NOT NULL DEFAULT 0,
+    error TEXT NULL,
+    error_stmt TEXT NULL,
+    hash TEXT NOT NULL DEFAULT '',
+    partial_hashes TEXT[] NULL,
+    operator_version TEXT NULL
+);
+`
+
 // migrationsAdvisoryLockID is the PostgreSQL advisory lock ID used to serialize migration runs.
 // In multi-instance deployments, the lock ensures only one instance runs migrations at a time.
 // Any stable int64 value works, as long as it does not collide with other locks in the same database.
@@ -94,6 +111,11 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
 		return fmt.Errorf("create schema_migrations: %w", err)
 	}
 
+	// Auto-align the Atlas baseline (when a legacy schema_migrations table exists but atlas_schema_revisions is missing).
+	if err := ensureAtlasBaselineAligned(ctx, db, fsys); err != nil {
+		return err
+	}
+
 	// Collect all .sql migration files and sort them by file name.
 	// Naming convention: zero-padded numeric prefixes (e.g. 001_init.sql, 002_add_users.sql).
 	files, err := fs.Glob(fsys, "*.sql")
@@ -172,6 +194,80 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
 	return nil
 }
 
+func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error {
+	hasLegacy, err := tableExists(ctx, db, "schema_migrations")
+	if err != nil {
+		return fmt.Errorf("check schema_migrations: %w", err)
+	}
+	if !hasLegacy {
+		return nil
+	}
+
+	hasAtlas, err := tableExists(ctx, db, "atlas_schema_revisions")
+	if err != nil {
+		return fmt.Errorf("check atlas_schema_revisions: %w", err)
+	}
+	if !hasAtlas {
+		if _, err := db.ExecContext(ctx, atlasSchemaRevisionsTableDDL); err != nil {
+			return fmt.Errorf("create atlas_schema_revisions: %w", err)
+		}
+	}
+
+	var count int
+	if err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM atlas_schema_revisions").Scan(&count); err != nil {
+		return fmt.Errorf("count atlas_schema_revisions: %w", err)
+	}
+	if count > 0 {
+		return nil
+	}
+
+	version, description, hash, err := latestMigrationBaseline(fsys)
+	if err != nil {
+		return fmt.Errorf("atlas baseline version: %w", err)
+	}
+
+	if _, err := db.ExecContext(ctx, `
+		INSERT INTO atlas_schema_revisions (version, description, type, applied, total, executed_at, execution_time, hash)
+		VALUES ($1, $2, $3, 0, 0, NOW(), 0, $4)
+	`, version, description, 1, hash); err != nil {
+		return fmt.Errorf("insert atlas baseline: %w", err)
+	}
+	return nil
+}
+
+func tableExists(ctx context.Context, db *sql.DB, tableName string) (bool, error) {
+	var exists bool
+	err := db.QueryRowContext(ctx, `
+		SELECT EXISTS (
+			SELECT 1
+			FROM information_schema.tables
+			WHERE table_schema = 'public' AND table_name = $1
+		)
+	`, tableName).Scan(&exists)
+	return exists, err
+}
+
+func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
+	files, err := fs.Glob(fsys, "*.sql")
+	if err != nil {
+		return "", "", "", err
+	}
+	if len(files) == 0 {
+		return "baseline", "baseline", "", nil
+	}
+	sort.Strings(files)
+	name := files[len(files)-1]
+	contentBytes, err := fs.ReadFile(fsys, name)
+	if err != nil {
+		return "", "", "", err
+	}
+	content := strings.TrimSpace(string(contentBytes))
+	sum := sha256.Sum256([]byte(content))
+	hash := hex.EncodeToString(sum[:])
+	version := strings.TrimSuffix(name, ".sql")
+	return version, version, hash, nil
+}
+
 // pgAdvisoryLock acquires a PostgreSQL advisory lock.
 // Advisory locks are a lightweight locking mechanism not tied to any specific database object.
 // They are well suited to application-level distributed locking, such as serializing migrations.
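A minimal in-package sketch of how latestMigrationBaseline behaves, using an in-memory filesystem (the package name and migration file names here are hypothetical; fstest is part of the Go standard library):

    package db // assumed package name
    
    import (
    	"testing"
    	"testing/fstest"
    )
    
    // The lexicographically last *.sql file becomes the baseline; its trimmed
    // content is hashed with SHA-256 and ".sql" is stripped to form the version.
    func TestLatestMigrationBaseline(t *testing.T) {
    	fsys := fstest.MapFS{
    		"001_init.sql":      &fstest.MapFile{Data: []byte("CREATE TABLE a (id BIGINT);")},
    		"002_add_users.sql": &fstest.MapFile{Data: []byte("CREATE TABLE users (id BIGINT);")},
    	}
    	version, description, hash, err := latestMigrationBaseline(fsys)
    	if err != nil {
    		t.Fatal(err)
    	}
    	if version != "002_add_users" || description != "002_add_users" {
    		t.Fatalf("unexpected baseline: %s / %s", version, description)
    	}
    	if len(hash) != 64 { // hex-encoded SHA-256
    		t.Fatalf("unexpected hash: %s", hash)
    	}
    }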
backend/internal/repository/scheduler_cache.go (new file, 276 lines)
@@ -0,0 +1,276 @@
+package repository
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"strconv"
+	"time"
+
+	"github.com/Wei-Shaw/sub2api/internal/service"
+	"github.com/redis/go-redis/v9"
+)
+
+const (
+	schedulerBucketSetKey       = "sched:buckets"
+	schedulerOutboxWatermarkKey = "sched:outbox:watermark"
+	schedulerAccountPrefix      = "sched:acc:"
+	schedulerActivePrefix       = "sched:active:"
+	schedulerReadyPrefix        = "sched:ready:"
+	schedulerVersionPrefix      = "sched:ver:"
+	schedulerSnapshotPrefix     = "sched:"
+	schedulerLockPrefix         = "sched:lock:"
+)
+
+type schedulerCache struct {
+	rdb *redis.Client
+}
+
+func NewSchedulerCache(rdb *redis.Client) service.SchedulerCache {
+	return &schedulerCache{rdb: rdb}
+}
+
+func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
+	readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket)
+	readyVal, err := c.rdb.Get(ctx, readyKey).Result()
+	if err == redis.Nil {
+		return nil, false, nil
+	}
+	if err != nil {
+		return nil, false, err
+	}
+	if readyVal != "1" {
+		return nil, false, nil
+	}
+
+	activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
+	activeVal, err := c.rdb.Get(ctx, activeKey).Result()
+	if err == redis.Nil {
+		return nil, false, nil
+	}
+	if err != nil {
+		return nil, false, err
+	}
+
+	snapshotKey := schedulerSnapshotKey(bucket, activeVal)
+	ids, err := c.rdb.ZRange(ctx, snapshotKey, 0, -1).Result()
+	if err != nil {
+		return nil, false, err
+	}
+	if len(ids) == 0 {
+		return []*service.Account{}, true, nil
+	}
+
+	keys := make([]string, 0, len(ids))
+	for _, id := range ids {
+		keys = append(keys, schedulerAccountKey(id))
+	}
+	values, err := c.rdb.MGet(ctx, keys...).Result()
+	if err != nil {
+		return nil, false, err
+	}
+
+	accounts := make([]*service.Account, 0, len(values))
+	for _, val := range values {
+		if val == nil {
+			return nil, false, nil
+		}
+		account, err := decodeCachedAccount(val)
+		if err != nil {
+			return nil, false, err
+		}
+		accounts = append(accounts, account)
+	}
+
+	return accounts, true, nil
+}
+
+func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
+	activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
+	oldActive, _ := c.rdb.Get(ctx, activeKey).Result()
+
+	versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket)
+	version, err := c.rdb.Incr(ctx, versionKey).Result()
+	if err != nil {
+		return err
+	}
+
+	versionStr := strconv.FormatInt(version, 10)
+	snapshotKey := schedulerSnapshotKey(bucket, versionStr)
+
+	pipe := c.rdb.Pipeline()
+	for _, account := range accounts {
+		payload, err := json.Marshal(account)
+		if err != nil {
+			return err
+		}
+		pipe.Set(ctx, schedulerAccountKey(strconv.FormatInt(account.ID, 10)), payload, 0)
+	}
+	if len(accounts) > 0 {
+		// Use the index as the score to preserve the sort order returned by the database.
+		members := make([]redis.Z, 0, len(accounts))
+		for idx, account := range accounts {
+			members = append(members, redis.Z{
+				Score:  float64(idx),
+				Member: strconv.FormatInt(account.ID, 10),
+			})
+		}
+		pipe.ZAdd(ctx, snapshotKey, members...)
+	} else {
+		pipe.Del(ctx, snapshotKey)
+	}
+	pipe.Set(ctx, activeKey, versionStr, 0)
+	pipe.Set(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket), "1", 0)
+	pipe.SAdd(ctx, schedulerBucketSetKey, bucket.String())
+	if _, err := pipe.Exec(ctx); err != nil {
+		return err
+	}
+
+	if oldActive != "" && oldActive != versionStr {
+		_ = c.rdb.Del(ctx, schedulerSnapshotKey(bucket, oldActive)).Err()
+	}
+
+	return nil
+}
+
+func (c *schedulerCache) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
+	key := schedulerAccountKey(strconv.FormatInt(accountID, 10))
+	val, err := c.rdb.Get(ctx, key).Result()
+	if err == redis.Nil {
+		return nil, nil
+	}
+	if err != nil {
+		return nil, err
+	}
+	return decodeCachedAccount(val)
+}
+
+func (c *schedulerCache) SetAccount(ctx context.Context, account *service.Account) error {
+	if account == nil || account.ID <= 0 {
+		return nil
+	}
+	payload, err := json.Marshal(account)
+	if err != nil {
+		return err
+	}
+	key := schedulerAccountKey(strconv.FormatInt(account.ID, 10))
+	return c.rdb.Set(ctx, key, payload, 0).Err()
+}
+
+func (c *schedulerCache) DeleteAccount(ctx context.Context, accountID int64) error {
+	if accountID <= 0 {
+		return nil
+	}
+	key := schedulerAccountKey(strconv.FormatInt(accountID, 10))
+	return c.rdb.Del(ctx, key).Err()
+}
+
+func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
+	if len(updates) == 0 {
+		return nil
+	}
+
+	keys := make([]string, 0, len(updates))
+	ids := make([]int64, 0, len(updates))
+	for id := range updates {
+		keys = append(keys, schedulerAccountKey(strconv.FormatInt(id, 10)))
+		ids = append(ids, id)
+	}
+
+	values, err := c.rdb.MGet(ctx, keys...).Result()
+	if err != nil {
+		return err
+	}
+
+	pipe := c.rdb.Pipeline()
+	for i, val := range values {
+		if val == nil {
+			continue
+		}
+		account, err := decodeCachedAccount(val)
+		if err != nil {
+			return err
+		}
+		account.LastUsedAt = ptrTime(updates[ids[i]])
+		updated, err := json.Marshal(account)
+		if err != nil {
+			return err
+		}
+		pipe.Set(ctx, keys[i], updated, 0)
+	}
+	_, err = pipe.Exec(ctx)
+	return err
+}
+
+func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
+	key := schedulerBucketKey(schedulerLockPrefix, bucket)
+	return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result()
+}
+
+func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
+	raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result()
+	if err != nil {
+		return nil, err
+	}
+	out := make([]service.SchedulerBucket, 0, len(raw))
+	for _, entry := range raw {
+		bucket, ok := service.ParseSchedulerBucket(entry)
+		if !ok {
+			continue
+		}
+		out = append(out, bucket)
+	}
+	return out, nil
+}
+
+func (c *schedulerCache) GetOutboxWatermark(ctx context.Context) (int64, error) {
+	val, err := c.rdb.Get(ctx, schedulerOutboxWatermarkKey).Result()
+	if err == redis.Nil {
+		return 0, nil
+	}
+	if err != nil {
+		return 0, err
+	}
+	id, err := strconv.ParseInt(val, 10, 64)
+	if err != nil {
+		return 0, err
+	}
+	return id, nil
+}
+
+func (c *schedulerCache) SetOutboxWatermark(ctx context.Context, id int64) error {
+	return c.rdb.Set(ctx, schedulerOutboxWatermarkKey, strconv.FormatInt(id, 10), 0).Err()
+}
+
+func schedulerBucketKey(prefix string, bucket service.SchedulerBucket) string {
+	return fmt.Sprintf("%s%d:%s:%s", prefix, bucket.GroupID, bucket.Platform, bucket.Mode)
+}
+
+func schedulerSnapshotKey(bucket service.SchedulerBucket, version string) string {
+	return fmt.Sprintf("%s%d:%s:%s:v%s", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode, version)
+}
+
+func schedulerAccountKey(id string) string {
+	return schedulerAccountPrefix + id
+}
+
+func ptrTime(t time.Time) *time.Time {
+	return &t
+}
+
+func decodeCachedAccount(val any) (*service.Account, error) {
+	var payload []byte
+	switch raw := val.(type) {
+	case string:
+		payload = []byte(raw)
+	case []byte:
+		payload = raw
+	default:
+		return nil, fmt.Errorf("unexpected account cache type: %T", val)
+	}
+	var account service.Account
+	if err := json.Unmarshal(payload, &account); err != nil {
+		return nil, err
+	}
+	return &account, nil
+}
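A minimal usage sketch of the snapshot round trip (it assumes a local Redis and that service.SchedulerBucket carries GroupID/Platform/Mode fields, as the key helpers above imply; the bucket values are hypothetical):

    package main
    
    import (
    	"context"
    	"fmt"
    
    	"github.com/Wei-Shaw/sub2api/internal/repository"
    	"github.com/Wei-Shaw/sub2api/internal/service"
    	"github.com/redis/go-redis/v9"
    )
    
    func main() {
    	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
    	cache := repository.NewSchedulerCache(rdb)
    	ctx := context.Background()
    
    	bucket := service.SchedulerBucket{GroupID: 1, Platform: "anthropic", Mode: "default"}
    	accounts := []service.Account{{ID: 101}, {ID: 102}} // slice order = DB sort order
    
    	if err := cache.SetSnapshot(ctx, bucket, accounts); err != nil {
    		panic(err)
    	}
    	got, ok, err := cache.GetSnapshot(ctx, bucket)
    	fmt.Println(len(got), ok, err) // 2 true <nil>; IDs come back in snapshot order
    }

Note the design choice in SetSnapshot: each write bumps the `sched:ver:*` counter, builds the new ZSET under a fresh versioned key, and only then swaps the `sched:active:*` pointer, so hot-path readers never observe a half-written snapshot.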
backend/internal/repository/scheduler_outbox_repo.go (new file, 96 lines)
@@ -0,0 +1,96 @@
+package repository
+
+import (
+	"context"
+	"database/sql"
+	"encoding/json"
+
+	"github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+type schedulerOutboxRepository struct {
+	db *sql.DB
+}
+
+func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository {
+	return &schedulerOutboxRepository{db: db}
+}
+
+func (r *schedulerOutboxRepository) ListAfter(ctx context.Context, afterID int64, limit int) ([]service.SchedulerOutboxEvent, error) {
+	if limit <= 0 {
+		limit = 100
+	}
+	rows, err := r.db.QueryContext(ctx, `
+		SELECT id, event_type, account_id, group_id, payload, created_at
+		FROM scheduler_outbox
+		WHERE id > $1
+		ORDER BY id ASC
+		LIMIT $2
+	`, afterID, limit)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		_ = rows.Close()
+	}()
+
+	events := make([]service.SchedulerOutboxEvent, 0, limit)
+	for rows.Next() {
+		var (
+			payloadRaw []byte
+			accountID  sql.NullInt64
+			groupID    sql.NullInt64
+			event      service.SchedulerOutboxEvent
+		)
+		if err := rows.Scan(&event.ID, &event.EventType, &accountID, &groupID, &payloadRaw, &event.CreatedAt); err != nil {
+			return nil, err
+		}
+		if accountID.Valid {
+			v := accountID.Int64
+			event.AccountID = &v
+		}
+		if groupID.Valid {
+			v := groupID.Int64
+			event.GroupID = &v
+		}
+		if len(payloadRaw) > 0 {
+			var payload map[string]any
+			if err := json.Unmarshal(payloadRaw, &payload); err != nil {
+				return nil, err
+			}
+			event.Payload = payload
+		}
+		events = append(events, event)
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return events, nil
+}
+
+func (r *schedulerOutboxRepository) MaxID(ctx context.Context) (int64, error) {
+	var maxID int64
+	if err := r.db.QueryRowContext(ctx, "SELECT COALESCE(MAX(id), 0) FROM scheduler_outbox").Scan(&maxID); err != nil {
+		return 0, err
+	}
+	return maxID, nil
+}
+
+func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType string, accountID *int64, groupID *int64, payload any) error {
+	if exec == nil {
+		return nil
+	}
+	var payloadJSON []byte
+	if payload != nil {
+		encoded, err := json.Marshal(payload)
+		if err != nil {
+			return err
+		}
+		payloadJSON = encoded
+	}
+	_, err := exec.ExecContext(ctx, `
+		INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
+		VALUES ($1, $2, $3, $4)
+	`, eventType, accountID, groupID, payloadJSON)
+	return err
+}
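The scheduler_outbox migration itself is not part of this excerpt. From the INSERT and SELECT above, a compatible table would look roughly like the sketch below, written in the same Go-const style the migrations file uses; the column types and defaults are assumptions, not the shipped schema:

    // A sketch only: column types are inferred from scheduler_outbox_repo.go.
    const schedulerOutboxTableDDL = `
    CREATE TABLE IF NOT EXISTS scheduler_outbox (
        id         BIGSERIAL PRIMARY KEY,
        event_type TEXT NOT NULL,
        account_id BIGINT NULL,
        group_id   BIGINT NULL,
        payload    JSONB NULL,
        created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
    );
    `

The monotonically increasing id is what makes replay safe: ListAfter pages strictly by `id > watermark` in ascending order, so when a batch fails the consumer simply does not advance the watermark (per the commit message) and the same events are re-read on the next poll.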
@@ -66,6 +66,8 @@ var ProviderSet = wire.NewSet(
 	NewRedeemCache,
 	NewUpdateCache,
 	NewGeminiTokenCache,
+	NewSchedulerCache,
+	NewSchedulerOutboxRepository,
 
 	// HTTP service ports (DI Strategy A: return interface directly)
 	NewTurnstileVerifier,
@@ -151,6 +151,7 @@ type GatewayService struct {
 	userSubRepo         UserSubscriptionRepository
 	cache               GatewayCache
 	cfg                 *config.Config
+	schedulerSnapshot   *SchedulerSnapshotService
 	billingService      *BillingService
 	rateLimitService    *RateLimitService
 	billingCacheService *BillingCacheService
@@ -169,6 +170,7 @@ func NewGatewayService(
 	userSubRepo UserSubscriptionRepository,
 	cache GatewayCache,
 	cfg *config.Config,
+	schedulerSnapshot *SchedulerSnapshotService,
 	concurrencyService *ConcurrencyService,
 	billingService *BillingService,
 	rateLimitService *RateLimitService,
@@ -185,6 +187,7 @@ func NewGatewayService(
 		userSubRepo:        userSubRepo,
 		cache:              cache,
 		cfg:                cfg,
+		schedulerSnapshot:  schedulerSnapshot,
 		concurrencyService: concurrencyService,
 		billingService:     billingService,
 		rateLimitService:   rateLimitService,
@@ -745,6 +748,9 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
 }
 
 func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
+	if s.schedulerSnapshot != nil {
+		return s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
+	}
 	useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
 	if useMixed {
 		platforms := []string{platform, PlatformAntigravity}
@@ -821,6 +827,13 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
 	return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
 }
 
+func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
+	if s.schedulerSnapshot != nil {
+		return s.schedulerSnapshot.GetAccount(ctx, accountID)
+	}
+	return s.accountRepo.GetByID(ctx, accountID)
+}
+
 func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
 	sort.SliceStable(accounts, func(i, j int) bool {
 		a, b := accounts[i], accounts[j]
@@ -851,7 +864,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
 	accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
 	if err == nil && accountID > 0 {
 		if _, excluded := excludedIDs[accountID]; !excluded {
-			account, err := s.accountRepo.GetByID(ctx, accountID)
+			account, err := s.getSchedulableAccount(ctx, accountID)
 			// Verify group membership and platform match (sticky sessions must not cross groups or platforms)
 			if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
 				if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
@@ -864,16 +877,11 @@
 	}
 
 	// 2. Fetch schedulable accounts (single platform)
-	var accounts []Account
-	var err error
-	if s.cfg.RunMode == config.RunModeSimple {
-		// Simple mode: ignore groupID and query all available accounts
-		accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
-	} else if groupID != nil {
-		accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
-	} else {
-		accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
-	}
+	forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
+	if hasForcePlatform && forcePlatform == "" {
+		hasForcePlatform = false
+	}
+	accounts, _, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
 	if err != nil {
 		return nil, fmt.Errorf("query accounts failed: %w", err)
 	}
@@ -935,7 +943,6 @@
 // selectAccountWithMixedScheduling selects an account (with mixed-scheduling support).
 // It queries native-platform accounts plus antigravity accounts with mixed_scheduling enabled.
 func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
-	platforms := []string{nativePlatform, PlatformAntigravity}
 	preferOAuth := nativePlatform == PlatformGemini
 
 	// 1. Look up the sticky session
@@ -943,7 +950,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
 	accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
 	if err == nil && accountID > 0 {
 		if _, excluded := excludedIDs[accountID]; !excluded {
-			account, err := s.accountRepo.GetByID(ctx, accountID)
+			account, err := s.getSchedulableAccount(ctx, accountID)
 			// Verify group membership and validity: the native platform matches directly; antigravity requires mixed scheduling enabled
 			if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
 				if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
@@ -958,13 +965,7 @@
 	}
 
 	// 2. Fetch schedulable accounts
-	var accounts []Account
-	var err error
-	if groupID != nil {
-		accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
-	} else {
-		accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
-	}
+	accounts, _, err := s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
 	if err != nil {
 		return nil, fmt.Errorf("query accounts failed: %w", err)
 	}
@@ -40,6 +40,7 @@ type GeminiMessagesCompatService struct {
 	accountRepo       AccountRepository
 	groupRepo         GroupRepository
 	cache             GatewayCache
+	schedulerSnapshot *SchedulerSnapshotService
 	tokenProvider     *GeminiTokenProvider
 	rateLimitService  *RateLimitService
 	httpUpstream      HTTPUpstream
@@ -51,6 +52,7 @@ func NewGeminiMessagesCompatService(
 	accountRepo AccountRepository,
 	groupRepo GroupRepository,
 	cache GatewayCache,
+	schedulerSnapshot *SchedulerSnapshotService,
 	tokenProvider *GeminiTokenProvider,
 	rateLimitService *RateLimitService,
 	httpUpstream HTTPUpstream,
@@ -61,6 +63,7 @@ func NewGeminiMessagesCompatService(
 		accountRepo:       accountRepo,
 		groupRepo:         groupRepo,
 		cache:             cache,
+		schedulerSnapshot: schedulerSnapshot,
 		tokenProvider:     tokenProvider,
 		rateLimitService:  rateLimitService,
 		httpUpstream:      httpUpstream,
@@ -105,12 +108,6 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
 	// Gemini groups support mixed scheduling (including antigravity accounts with mixed_scheduling enabled)
 	// Note: force-platform mode bypasses mixed scheduling
 	useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
-	var queryPlatforms []string
-	if useMixedScheduling {
-		queryPlatforms = []string{PlatformGemini, PlatformAntigravity}
-	} else {
-		queryPlatforms = []string{platform}
-	}
 
 	cacheKey := "gemini:" + sessionHash
 
@@ -118,7 +115,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
 	accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
 	if err == nil && accountID > 0 {
 		if _, excluded := excludedIDs[accountID]; !excluded {
-			account, err := s.accountRepo.GetByID(ctx, accountID)
+			account, err := s.getSchedulableAccount(ctx, accountID)
 			// Check account validity: the native platform matches directly; antigravity requires mixed scheduling enabled
 			if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
 				valid := false
@@ -149,22 +146,16 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
 	}
 
 	// Query schedulable accounts (force-platform mode: try the group first, then fall back to all)
-	var accounts []Account
-	var err error
-	if groupID != nil {
-		accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
-		// In force-platform mode, fall back to querying all accounts when the group has none
-		if len(accounts) == 0 && hasForcePlatform {
-			accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
-		}
-	} else {
-		accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
-	}
+	accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
 	if err != nil {
 		return nil, fmt.Errorf("query accounts failed: %w", err)
 	}
+	// In force-platform mode, fall back to querying all accounts when the group has none
+	if len(accounts) == 0 && groupID != nil && hasForcePlatform {
+		accounts, err = s.listSchedulableAccountsOnce(ctx, nil, platform, hasForcePlatform)
+		if err != nil {
+			return nil, fmt.Errorf("query accounts failed: %w", err)
+		}
+	}
 
 	var selected *Account
@@ -245,6 +236,31 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
 	return s.antigravityGatewayService
 }
 
+func (s *GeminiMessagesCompatService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
+	if s.schedulerSnapshot != nil {
+		return s.schedulerSnapshot.GetAccount(ctx, accountID)
+	}
+	return s.accountRepo.GetByID(ctx, accountID)
+}
+
+func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, error) {
+	if s.schedulerSnapshot != nil {
+		accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
+		return accounts, err
+	}
+
+	useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
+	queryPlatforms := []string{platform}
+	if useMixedScheduling {
+		queryPlatforms = []string{platform, PlatformAntigravity}
+	}
+
+	if groupID != nil {
+		return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
+	}
+	return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
+}
+
 func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
 	if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
 		normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
@@ -266,13 +282,7 @@
 
 // HasAntigravityAccounts reports whether any antigravity accounts are available
 func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
-	var accounts []Account
-	var err error
-	if groupID != nil {
-		accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAntigravity)
-	} else {
-		accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAntigravity)
-	}
+	accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, PlatformAntigravity, false)
 	if err != nil {
 		return false, err
 	}
@@ -288,13 +298,7 @@ func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context
 // 3) OAuth accounts explicitly marked as ai_studio
 // 4) Any remaining Gemini accounts (fallback)
 func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) {
-	var accounts []Account
-	var err error
-	if groupID != nil {
-		accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
-	} else {
-		accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
-	}
+	accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, PlatformGemini, true)
 	if err != nil {
 		return nil, fmt.Errorf("query accounts failed: %w", err)
 	}
@@ -85,6 +85,7 @@ type OpenAIGatewayService struct {
	userSubRepo        UserSubscriptionRepository
	cache              GatewayCache
	cfg                *config.Config
	schedulerSnapshot  *SchedulerSnapshotService
	concurrencyService *ConcurrencyService
	billingService     *BillingService
	rateLimitService   *RateLimitService
@@ -101,6 +102,7 @@ func NewOpenAIGatewayService(
	userSubRepo UserSubscriptionRepository,
	cache GatewayCache,
	cfg *config.Config,
	schedulerSnapshot *SchedulerSnapshotService,
	concurrencyService *ConcurrencyService,
	billingService *BillingService,
	rateLimitService *RateLimitService,
@@ -115,6 +117,7 @@ func NewOpenAIGatewayService(
	userSubRepo:        userSubRepo,
	cache:              cache,
	cfg:                cfg,
	schedulerSnapshot:  schedulerSnapshot,
	concurrencyService: concurrencyService,
	billingService:     billingService,
	rateLimitService:   rateLimitService,
@@ -159,7 +162,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
	accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
	if err == nil && accountID > 0 {
		if _, excluded := excludedIDs[accountID]; !excluded {
			account, err := s.accountRepo.GetByID(ctx, accountID)
			account, err := s.getSchedulableAccount(ctx, accountID)
			if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
				// Refresh sticky session TTL
				_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
@@ -170,16 +173,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
	}

	// 2. Get schedulable OpenAI accounts
	var accounts []Account
	var err error
	// Simple mode: ignore group restrictions and query all available accounts
	if s.cfg.RunMode == config.RunModeSimple {
		accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
	} else if groupID != nil {
		accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
	} else {
		accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
	}
	accounts, err := s.listSchedulableAccounts(ctx, groupID)
	if err != nil {
		return nil, fmt.Errorf("query accounts failed: %w", err)
	}
@@ -301,7 +295,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
	if sessionHash != "" {
		accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
		if err == nil && accountID > 0 && !isExcluded(accountID) {
			account, err := s.accountRepo.GetByID(ctx, accountID)
			account, err := s.getSchedulableAccount(ctx, accountID)
			if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
				(requestedModel == "" || account.IsModelSupported(requestedModel)) {
				result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
@@ -446,6 +440,10 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}

func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
	if s.schedulerSnapshot != nil {
		accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, PlatformOpenAI, false)
		return accounts, err
	}
	var accounts []Account
	var err error
	if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
@@ -468,6 +466,13 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun
	return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}

func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
	if s.schedulerSnapshot != nil {
		return s.schedulerSnapshot.GetAccount(ctx, accountID)
	}
	return s.accountRepo.GetByID(ctx, accountID)
}

func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
	if s.cfg != nil {
		return s.cfg.Gateway.Scheduling

68 backend/internal/service/scheduler_cache.go Normal file
@@ -0,0 +1,68 @@
package service

import (
	"context"
	"fmt"
	"strconv"
	"strings"
	"time"
)

const (
	SchedulerModeSingle = "single"
	SchedulerModeMixed  = "mixed"
	SchedulerModeForced = "forced"
)

type SchedulerBucket struct {
	GroupID  int64
	Platform string
	Mode     string
}

func (b SchedulerBucket) String() string {
	return fmt.Sprintf("%d:%s:%s", b.GroupID, b.Platform, b.Mode)
}

func ParseSchedulerBucket(raw string) (SchedulerBucket, bool) {
	parts := strings.Split(raw, ":")
	if len(parts) != 3 {
		return SchedulerBucket{}, false
	}
	groupID, err := strconv.ParseInt(parts[0], 10, 64)
	if err != nil {
		return SchedulerBucket{}, false
	}
	if parts[1] == "" || parts[2] == "" {
		return SchedulerBucket{}, false
	}
	return SchedulerBucket{
		GroupID:  groupID,
		Platform: parts[1],
		Mode:     parts[2],
	}, true
}

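A bucket key is the snapshot's cache identity: group ID, platform, and scheduling mode joined by colons. A quick round-trip sketch (the values are illustrative, and it assumes PlatformOpenAI == "openai"):

	// Illustrative values; shows the "<groupID>:<platform>:<mode>" key shape.
	bucket := SchedulerBucket{GroupID: 42, Platform: PlatformOpenAI, Mode: SchedulerModeSingle}
	key := bucket.String()                    // "42:openai:single"
	parsed, ok := ParseSchedulerBucket(key)   // parsed == bucket, ok == true
	_, bad := ParseSchedulerBucket("42:openai") // bad == false: exactly three segments required
	_, _ = parsed, bad
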
// SchedulerCache handles cache reads/writes for scheduling snapshots and single-account snapshots.
type SchedulerCache interface {
	// GetSnapshot reads a snapshot and reports whether it hit (ready + active + data complete).
	GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error)
	// SetSnapshot writes a snapshot and switches the active version.
	SetSnapshot(ctx context.Context, bucket SchedulerBucket, accounts []Account) error
	// GetAccount fetches a single-account snapshot.
	GetAccount(ctx context.Context, accountID int64) (*Account, error)
	// SetAccount writes a single-account snapshot (including non-schedulable state).
	SetAccount(ctx context.Context, account *Account) error
	// DeleteAccount removes a single-account snapshot.
	DeleteAccount(ctx context.Context, accountID int64) error
	// UpdateLastUsed batch-updates accounts' last-used timestamps.
	UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
	// TryLockBucket attempts to acquire the per-bucket rebuild lock.
	TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error)
	// ListBuckets returns the set of registered buckets.
	ListBuckets(ctx context.Context) ([]SchedulerBucket, error)
	// GetOutboxWatermark reads the outbox watermark.
	GetOutboxWatermark(ctx context.Context) (int64, error)
	// SetOutboxWatermark stores the outbox watermark.
	SetOutboxWatermark(ctx context.Context, id int64) error
}
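The Redis-backed implementation (repository.NewSchedulerCache) is not part of this excerpt. One key layout consistent with this interface — a versioned snapshot plus an active-version pointer, so SetSnapshot can write a fresh copy and then flip the pointer — might look like the following; every key name here is a guess, not the repository's actual scheme:

	// Hypothetical key scheme (assumed for illustration only):
	const (
		keySnapshot       = "scheduler:snapshot:%s:v%d"    // serialized accounts per bucket version
		keySnapshotActive = "scheduler:snapshot:%s:active" // version pointer flipped by SetSnapshot
		keyAccount        = "scheduler:account:%d"         // single-account snapshot
		keyBucketLock     = "scheduler:lock:%s"            // TryLockBucket via SET NX with TTL
		keyWatermark      = "scheduler:outbox:watermark"   // last applied outbox event ID
	)
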
10 backend/internal/service/scheduler_events.go Normal file
@@ -0,0 +1,10 @@
package service

const (
	SchedulerOutboxEventAccountChanged       = "account_changed"
	SchedulerOutboxEventAccountGroupsChanged = "account_groups_changed"
	SchedulerOutboxEventAccountBulkChanged   = "account_bulk_changed"
	SchedulerOutboxEventAccountLastUsed      = "account_last_used"
	SchedulerOutboxEventGroupChanged         = "group_changed"
	SchedulerOutboxEventFullRebuild          = "full_rebuild"
)
21 backend/internal/service/scheduler_outbox.go Normal file
@@ -0,0 +1,21 @@
package service

import (
	"context"
	"time"
)

type SchedulerOutboxEvent struct {
	ID        int64
	EventType string
	AccountID *int64
	GroupID   *int64
	Payload   map[string]any
	CreatedAt time.Time
}

// SchedulerOutboxRepository provides read access to the scheduler outbox.
type SchedulerOutboxRepository interface {
	ListAfter(ctx context.Context, afterID int64, limit int) ([]SchedulerOutboxEvent, error)
	MaxID(ctx context.Context) (int64, error)
}
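Only the read side is defined here; the writers that append events live elsewhere in the codebase. Under the standard outbox pattern the event row is inserted in the same transaction as the account mutation, so replay can never observe a change without its event. A sketch under that assumption, against the scheduler_outbox table created in migration 036 (enqueueAccountChanged and its tx argument are illustrative, not code from this commit):

	// Sketch only (imports: context, database/sql, encoding/json).
	// The caller's *sql.Tx also carries the account UPDATE, so the event
	// commits or rolls back together with the change it describes.
	func enqueueAccountChanged(ctx context.Context, tx *sql.Tx, accountID int64, groupIDs []int64) error {
		payload, err := json.Marshal(map[string]any{"group_ids": groupIDs})
		if err != nil {
			return err
		}
		_, err = tx.ExecContext(ctx,
			`INSERT INTO scheduler_outbox (event_type, account_id, payload) VALUES ($1, $2, $3)`,
			SchedulerOutboxEventAccountChanged, accountID, payload)
		return err
	}
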
786 backend/internal/service/scheduler_snapshot_service.go Normal file
@@ -0,0 +1,786 @@
package service

import (
	"context"
	"encoding/json"
	"errors"
	"log"
	"strconv"
	"sync"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
)

var (
	ErrSchedulerCacheNotReady   = errors.New("scheduler cache not ready")
	ErrSchedulerFallbackLimited = errors.New("scheduler db fallback limited")
)

const outboxEventTimeout = 2 * time.Minute

type SchedulerSnapshotService struct {
	cache         SchedulerCache
	outboxRepo    SchedulerOutboxRepository
	accountRepo   AccountRepository
	groupRepo     GroupRepository
	cfg           *config.Config
	stopCh        chan struct{}
	stopOnce      sync.Once
	wg            sync.WaitGroup
	fallbackLimit *fallbackLimiter
	lagMu         sync.Mutex
	lagFailures   int
}

func NewSchedulerSnapshotService(
	cache SchedulerCache,
	outboxRepo SchedulerOutboxRepository,
	accountRepo AccountRepository,
	groupRepo GroupRepository,
	cfg *config.Config,
) *SchedulerSnapshotService {
	maxQPS := 0
	if cfg != nil {
		maxQPS = cfg.Gateway.Scheduling.DbFallbackMaxQPS
	}
	return &SchedulerSnapshotService{
		cache:         cache,
		outboxRepo:    outboxRepo,
		accountRepo:   accountRepo,
		groupRepo:     groupRepo,
		cfg:           cfg,
		stopCh:        make(chan struct{}),
		fallbackLimit: newFallbackLimiter(maxQPS),
	}
}

func (s *SchedulerSnapshotService) Start() {
	if s == nil || s.cache == nil {
		return
	}

	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		s.runInitialRebuild()
	}()

	interval := s.outboxPollInterval()
	if s.outboxRepo != nil && interval > 0 {
		s.wg.Add(1)
		go func() {
			defer s.wg.Done()
			s.runOutboxWorker(interval)
		}()
	}

	fullInterval := s.fullRebuildInterval()
	if fullInterval > 0 {
		s.wg.Add(1)
		go func() {
			defer s.wg.Done()
			s.runFullRebuildWorker(fullInterval)
		}()
	}
}

func (s *SchedulerSnapshotService) Stop() {
	if s == nil {
		return
	}
	s.stopOnce.Do(func() {
		close(s.stopCh)
	})
	s.wg.Wait()
}

func (s *SchedulerSnapshotService) ListSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
	useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
	mode := s.resolveMode(platform, hasForcePlatform)
	bucket := s.bucketFor(groupID, platform, mode)

	if s.cache != nil {
		cached, hit, err := s.cache.GetSnapshot(ctx, bucket)
		if err != nil {
			log.Printf("[Scheduler] cache read failed: bucket=%s err=%v", bucket.String(), err)
		} else if hit {
			return derefAccounts(cached), useMixed, nil
		}
	}

	if err := s.guardFallback(ctx); err != nil {
		return nil, useMixed, err
	}

	fallbackCtx, cancel := s.withFallbackTimeout(ctx)
	defer cancel()

	accounts, err := s.loadAccountsFromDB(fallbackCtx, bucket, useMixed)
	if err != nil {
		return nil, useMixed, err
	}

	if s.cache != nil {
		if err := s.cache.SetSnapshot(fallbackCtx, bucket, accounts); err != nil {
			log.Printf("[Scheduler] cache write failed: bucket=%s err=%v", bucket.String(), err)
		}
	}

	return accounts, useMixed, nil
}

func (s *SchedulerSnapshotService) GetAccount(ctx context.Context, accountID int64) (*Account, error) {
	if accountID <= 0 {
		return nil, nil
	}
	if s.cache != nil {
		account, err := s.cache.GetAccount(ctx, accountID)
		if err != nil {
			log.Printf("[Scheduler] account cache read failed: id=%d err=%v", accountID, err)
		} else if account != nil {
			return account, nil
		}
	}

	if err := s.guardFallback(ctx); err != nil {
		return nil, err
	}
	fallbackCtx, cancel := s.withFallbackTimeout(ctx)
	defer cancel()
	return s.accountRepo.GetByID(fallbackCtx, accountID)
}

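The read path degrades in steps: a snapshot hit serves straight from Redis; a miss falls through to the database only when the guarded fallback allows it, and the result is written back so the next reader hits. A caller-side sketch of how the two sentinel errors might be distinguished (snapshot stands in for a *SchedulerSnapshotService; the handling shown is illustrative):

	// Illustrative consumer of the degrade ladder.
	accounts, mixed, err := snapshot.ListSchedulableAccounts(ctx, groupID, PlatformOpenAI, false)
	switch {
	case errors.Is(err, ErrSchedulerFallbackLimited):
		// Cache miss and the DB fallback budget is spent: fail fast instead of piling onto the DB.
	case errors.Is(err, ErrSchedulerCacheNotReady):
		// Fallback disabled by config and no usable snapshot yet.
	case err == nil:
		_ = mixed // reports whether antigravity accounts were blended into the result
		_ = accounts
	}
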
func (s *SchedulerSnapshotService) runInitialRebuild() {
	if s.cache == nil {
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()
	buckets, err := s.cache.ListBuckets(ctx)
	if err != nil {
		log.Printf("[Scheduler] list buckets failed: %v", err)
	}
	if len(buckets) == 0 {
		buckets, err = s.defaultBuckets(ctx)
		if err != nil {
			log.Printf("[Scheduler] default buckets failed: %v", err)
			return
		}
	}
	if err := s.rebuildBuckets(ctx, buckets, "startup"); err != nil {
		log.Printf("[Scheduler] rebuild startup failed: %v", err)
	}
}

func (s *SchedulerSnapshotService) runOutboxWorker(interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	s.pollOutbox()
	for {
		select {
		case <-ticker.C:
			s.pollOutbox()
		case <-s.stopCh:
			return
		}
	}
}

func (s *SchedulerSnapshotService) runFullRebuildWorker(interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if err := s.triggerFullRebuild("interval"); err != nil {
				log.Printf("[Scheduler] full rebuild failed: %v", err)
			}
		case <-s.stopCh:
			return
		}
	}
}

func (s *SchedulerSnapshotService) pollOutbox() {
	if s.outboxRepo == nil || s.cache == nil {
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	watermark, err := s.cache.GetOutboxWatermark(ctx)
	if err != nil {
		log.Printf("[Scheduler] outbox watermark read failed: %v", err)
		return
	}

	events, err := s.outboxRepo.ListAfter(ctx, watermark, 200)
	if err != nil {
		log.Printf("[Scheduler] outbox poll failed: %v", err)
		return
	}
	if len(events) == 0 {
		return
	}

	watermarkForCheck := watermark
	for _, event := range events {
		eventCtx, cancel := context.WithTimeout(context.Background(), outboxEventTimeout)
		err := s.handleOutboxEvent(eventCtx, event)
		cancel()
		if err != nil {
			log.Printf("[Scheduler] outbox handle failed: id=%d type=%s err=%v", event.ID, event.EventType, err)
			return
		}
	}

	lastID := events[len(events)-1].ID
	if err := s.cache.SetOutboxWatermark(ctx, lastID); err != nil {
		log.Printf("[Scheduler] outbox watermark write failed: %v", err)
	} else {
		watermarkForCheck = lastID
	}

	s.checkOutboxLag(ctx, events[0], watermarkForCheck)
}

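Note the retry contract above: the first failing event aborts the batch before SetOutboxWatermark runs, so the watermark only ever advances past fully applied events and the next poll replays from the same point. That gives handlers at-least-once delivery, so they must be idempotent; re-applying SetAccount or re-running a bucket rebuild converges to the same state. Illustrated with example IDs:

	// Replay contract, with illustrative event IDs:
	//   watermark = 100, batch = [101 102 103], handler fails on 102
	//   -> watermark stays 100; the next poll re-delivers 101 and 102.
	// Event 101 is therefore applied twice, which is safe only because
	// every handler is an idempotent upsert or rebuild.
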
func (s *SchedulerSnapshotService) handleOutboxEvent(ctx context.Context, event SchedulerOutboxEvent) error {
	switch event.EventType {
	case SchedulerOutboxEventAccountLastUsed:
		return s.handleLastUsedEvent(ctx, event.Payload)
	case SchedulerOutboxEventAccountBulkChanged:
		return s.handleBulkAccountEvent(ctx, event.Payload)
	case SchedulerOutboxEventAccountGroupsChanged:
		return s.handleAccountEvent(ctx, event.AccountID, event.Payload)
	case SchedulerOutboxEventAccountChanged:
		return s.handleAccountEvent(ctx, event.AccountID, event.Payload)
	case SchedulerOutboxEventGroupChanged:
		return s.handleGroupEvent(ctx, event.GroupID)
	case SchedulerOutboxEventFullRebuild:
		return s.triggerFullRebuild("outbox")
	default:
		return nil
	}
}

func (s *SchedulerSnapshotService) handleLastUsedEvent(ctx context.Context, payload map[string]any) error {
	if s.cache == nil || payload == nil {
		return nil
	}
	raw, ok := payload["last_used"].(map[string]any)
	if !ok || len(raw) == 0 {
		return nil
	}
	updates := make(map[int64]time.Time, len(raw))
	for key, value := range raw {
		id, err := strconv.ParseInt(key, 10, 64)
		if err != nil || id <= 0 {
			continue
		}
		sec, ok := toInt64(value)
		if !ok || sec <= 0 {
			continue
		}
		updates[id] = time.Unix(sec, 0)
	}
	if len(updates) == 0 {
		return nil
	}
	return s.cache.UpdateLastUsed(ctx, updates)
}

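The payload shape the parser above expects: keys are account IDs as decimal strings, values are unix seconds (the IDs and timestamps below are examples):

	// Example JSONB payload for an account_last_used event:
	//   {"last_used": {"17": 1717000000, "42": 1717000123}}
	// Decoded via map[string]any, the numbers arrive as float64,
	// which is why toInt64 accepts float64 alongside int64 and json.Number.
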
func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, payload map[string]any) error {
	if payload == nil {
		return nil
	}
	ids := parseInt64Slice(payload["account_ids"])
	for _, id := range ids {
		if err := s.handleAccountEvent(ctx, &id, payload); err != nil {
			return err
		}
	}
	return nil
}

func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accountID *int64, payload map[string]any) error {
	if accountID == nil || *accountID <= 0 {
		return nil
	}
	if s.accountRepo == nil {
		return nil
	}

	var groupIDs []int64
	if payload != nil {
		groupIDs = parseInt64Slice(payload["group_ids"])
	}

	account, err := s.accountRepo.GetByID(ctx, *accountID)
	if err != nil {
		if errors.Is(err, ErrAccountNotFound) {
			if s.cache != nil {
				if err := s.cache.DeleteAccount(ctx, *accountID); err != nil {
					return err
				}
			}
			return s.rebuildByGroupIDs(ctx, groupIDs, "account_miss")
		}
		return err
	}
	if s.cache != nil {
		if err := s.cache.SetAccount(ctx, account); err != nil {
			return err
		}
	}
	if len(groupIDs) == 0 {
		groupIDs = account.GroupIDs
	}
	return s.rebuildByAccount(ctx, account, groupIDs, "account_change")
}

func (s *SchedulerSnapshotService) handleGroupEvent(ctx context.Context, groupID *int64) error {
	if groupID == nil || *groupID <= 0 {
		return nil
	}
	groupIDs := []int64{*groupID}
	return s.rebuildByGroupIDs(ctx, groupIDs, "group_change")
}

func (s *SchedulerSnapshotService) rebuildByAccount(ctx context.Context, account *Account, groupIDs []int64, reason string) error {
	if account == nil {
		return nil
	}
	groupIDs = s.normalizeGroupIDs(groupIDs)
	if len(groupIDs) == 0 {
		return nil
	}

	var firstErr error
	if err := s.rebuildBucketsForPlatform(ctx, account.Platform, groupIDs, reason); err != nil && firstErr == nil {
		firstErr = err
	}
	if account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
		if err := s.rebuildBucketsForPlatform(ctx, PlatformAnthropic, groupIDs, reason); err != nil && firstErr == nil {
			firstErr = err
		}
		if err := s.rebuildBucketsForPlatform(ctx, PlatformGemini, groupIDs, reason); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}

func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupIDs []int64, reason string) error {
	groupIDs = s.normalizeGroupIDs(groupIDs)
	if len(groupIDs) == 0 {
		return nil
	}
	platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
	var firstErr error
	for _, platform := range platforms {
		if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}

func (s *SchedulerSnapshotService) rebuildBucketsForPlatform(ctx context.Context, platform string, groupIDs []int64, reason string) error {
	if platform == "" {
		return nil
	}
	var firstErr error
	for _, gid := range groupIDs {
		if err := s.rebuildBucket(ctx, SchedulerBucket{GroupID: gid, Platform: platform, Mode: SchedulerModeSingle}, reason); err != nil && firstErr == nil {
			firstErr = err
		}
		if err := s.rebuildBucket(ctx, SchedulerBucket{GroupID: gid, Platform: platform, Mode: SchedulerModeForced}, reason); err != nil && firstErr == nil {
			firstErr = err
		}
		if platform == PlatformAnthropic || platform == PlatformGemini {
			if err := s.rebuildBucket(ctx, SchedulerBucket{GroupID: gid, Platform: platform, Mode: SchedulerModeMixed}, reason); err != nil && firstErr == nil {
				firstErr = err
			}
		}
	}
	return firstErr
}

func (s *SchedulerSnapshotService) rebuildBuckets(ctx context.Context, buckets []SchedulerBucket, reason string) error {
	var firstErr error
	for _, bucket := range buckets {
		if err := s.rebuildBucket(ctx, bucket, reason); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}

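Fan-out example, derived from the loops above: changing an antigravity account that has mixed scheduling enabled and belongs to group 7 (an illustrative ID) rebuilds its own platform buckets plus the mixed-capable anthropic and gemini buckets:

	// Buckets rebuilt for that account change (group 7 illustrative):
	//   7:antigravity:single  7:antigravity:forced
	//   7:anthropic:single    7:anthropic:forced   7:anthropic:mixed
	//   7:gemini:single       7:gemini:forced      7:gemini:mixed
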
func (s *SchedulerSnapshotService) rebuildBucket(ctx context.Context, bucket SchedulerBucket, reason string) error {
	if s.cache == nil {
		return ErrSchedulerCacheNotReady
	}
	ok, err := s.cache.TryLockBucket(ctx, bucket, 30*time.Second)
	if err != nil {
		return err
	}
	if !ok {
		return nil
	}

	rebuildCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	accounts, err := s.loadAccountsFromDB(rebuildCtx, bucket, bucket.Mode == SchedulerModeMixed)
	if err != nil {
		log.Printf("[Scheduler] rebuild failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err)
		return err
	}
	if err := s.cache.SetSnapshot(rebuildCtx, bucket, accounts); err != nil {
		log.Printf("[Scheduler] rebuild cache failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err)
		return err
	}
	log.Printf("[Scheduler] rebuild ok: bucket=%s reason=%s size=%d", bucket.String(), reason, len(accounts))
	return nil
}

func (s *SchedulerSnapshotService) triggerFullRebuild(reason string) error {
	if s.cache == nil {
		return ErrSchedulerCacheNotReady
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	buckets, err := s.cache.ListBuckets(ctx)
	if err != nil {
		log.Printf("[Scheduler] list buckets failed: %v", err)
		return err
	}
	if len(buckets) == 0 {
		buckets, err = s.defaultBuckets(ctx)
		if err != nil {
			log.Printf("[Scheduler] default buckets failed: %v", err)
			return err
		}
	}
	return s.rebuildBuckets(ctx, buckets, reason)
}

func (s *SchedulerSnapshotService) checkOutboxLag(ctx context.Context, oldest SchedulerOutboxEvent, watermark int64) {
	if oldest.CreatedAt.IsZero() || s.cfg == nil {
		return
	}

	lag := time.Since(oldest.CreatedAt)
	if lagSeconds := int(lag.Seconds()); lagSeconds >= s.cfg.Gateway.Scheduling.OutboxLagWarnSeconds && s.cfg.Gateway.Scheduling.OutboxLagWarnSeconds > 0 {
		log.Printf("[Scheduler] outbox lag warning: %ds", lagSeconds)
	}

	if s.cfg.Gateway.Scheduling.OutboxLagRebuildSeconds > 0 && int(lag.Seconds()) >= s.cfg.Gateway.Scheduling.OutboxLagRebuildSeconds {
		s.lagMu.Lock()
		s.lagFailures++
		failures := s.lagFailures
		s.lagMu.Unlock()

		if failures >= s.cfg.Gateway.Scheduling.OutboxLagRebuildFailures {
			log.Printf("[Scheduler] outbox lag rebuild triggered: lag=%s failures=%d", lag, failures)
			s.lagMu.Lock()
			s.lagFailures = 0
			s.lagMu.Unlock()
			if err := s.triggerFullRebuild("outbox_lag"); err != nil {
				log.Printf("[Scheduler] outbox lag rebuild failed: %v", err)
			}
		}
	} else {
		s.lagMu.Lock()
		s.lagFailures = 0
		s.lagMu.Unlock()
	}

	threshold := s.cfg.Gateway.Scheduling.OutboxBacklogRebuildRows
	if threshold <= 0 || s.outboxRepo == nil {
		return
	}
	maxID, err := s.outboxRepo.MaxID(ctx)
	if err != nil {
		return
	}
	if maxID-watermark >= int64(threshold) {
		log.Printf("[Scheduler] outbox backlog rebuild triggered: backlog=%d", maxID-watermark)
		if err := s.triggerFullRebuild("outbox_backlog"); err != nil {
			log.Printf("[Scheduler] outbox backlog rebuild failed: %v", err)
		}
	}
}

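All thresholds above come from config.GatewaySchedulingConfig. The field names below are exactly the ones this file reads; the values are illustrative, not the project's defaults:

	// Illustrative values only.
	scheduling := config.GatewaySchedulingConfig{
		OutboxPollIntervalSeconds:  1,    // outbox replay cadence
		FullRebuildIntervalSeconds: 300,  // periodic corrective full rebuild (0 disables)
		OutboxLagWarnSeconds:       30,   // log a warning past this lag
		OutboxLagRebuildSeconds:    120,  // lag level counted as a failure
		OutboxLagRebuildFailures:   3,    // consecutive failures before a full rebuild
		OutboxBacklogRebuildRows:   5000, // unapplied-row backlog that forces a rebuild
		DbFallbackEnabled:          true, // allow DB reads on cache miss
		DbFallbackMaxQPS:           50,   // budget enforced by fallbackLimiter
		DbFallbackTimeoutSeconds:   3,    // per-fallback query deadline
	}
	_ = scheduling
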
func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucket SchedulerBucket, useMixed bool) ([]Account, error) {
	if s.accountRepo == nil {
		return nil, ErrSchedulerCacheNotReady
	}
	groupID := bucket.GroupID
	if s.isRunModeSimple() {
		groupID = 0
	}

	if useMixed {
		platforms := []string{bucket.Platform, PlatformAntigravity}
		var accounts []Account
		var err error
		if groupID > 0 {
			accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, groupID, platforms)
		} else {
			accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
		}
		if err != nil {
			return nil, err
		}
		filtered := make([]Account, 0, len(accounts))
		for _, acc := range accounts {
			if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
				continue
			}
			filtered = append(filtered, acc)
		}
		return filtered, nil
	}

	if groupID > 0 {
		return s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, groupID, bucket.Platform)
	}
	return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform)
}

func (s *SchedulerSnapshotService) bucketFor(groupID *int64, platform string, mode string) SchedulerBucket {
	return SchedulerBucket{
		GroupID:  s.normalizeGroupID(groupID),
		Platform: platform,
		Mode:     mode,
	}
}

func (s *SchedulerSnapshotService) normalizeGroupID(groupID *int64) int64 {
	if s.isRunModeSimple() {
		return 0
	}
	if groupID == nil || *groupID <= 0 {
		return 0
	}
	return *groupID
}

func (s *SchedulerSnapshotService) normalizeGroupIDs(groupIDs []int64) []int64 {
	if s.isRunModeSimple() {
		return []int64{0}
	}
	if len(groupIDs) == 0 {
		return []int64{0}
	}
	seen := make(map[int64]struct{}, len(groupIDs))
	out := make([]int64, 0, len(groupIDs))
	for _, id := range groupIDs {
		if id <= 0 {
			continue
		}
		if _, ok := seen[id]; ok {
			continue
		}
		seen[id] = struct{}{}
		out = append(out, id)
	}
	if len(out) == 0 {
		return []int64{0}
	}
	return out
}

func (s *SchedulerSnapshotService) resolveMode(platform string, hasForcePlatform bool) string {
	if hasForcePlatform {
		return SchedulerModeForced
	}
	if platform == PlatformAnthropic || platform == PlatformGemini {
		return SchedulerModeMixed
	}
	return SchedulerModeSingle
}

func (s *SchedulerSnapshotService) guardFallback(ctx context.Context) error {
	if s.cfg == nil || s.cfg.Gateway.Scheduling.DbFallbackEnabled {
		if s.fallbackLimit == nil || s.fallbackLimit.Allow() {
			return nil
		}
		return ErrSchedulerFallbackLimited
	}
	return ErrSchedulerCacheNotReady
}

func (s *SchedulerSnapshotService) withFallbackTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
	if s.cfg == nil || s.cfg.Gateway.Scheduling.DbFallbackTimeoutSeconds <= 0 {
		return context.WithCancel(ctx)
	}
	timeout := time.Duration(s.cfg.Gateway.Scheduling.DbFallbackTimeoutSeconds) * time.Second
	if deadline, ok := ctx.Deadline(); ok {
		remaining := time.Until(deadline)
		if remaining <= 0 {
			return context.WithCancel(ctx)
		}
		if remaining < timeout {
			timeout = remaining
		}
	}
	return context.WithTimeout(ctx, timeout)
}

func (s *SchedulerSnapshotService) isRunModeSimple() bool {
	return s.cfg != nil && s.cfg.RunMode == config.RunModeSimple
}

func (s *SchedulerSnapshotService) outboxPollInterval() time.Duration {
	if s.cfg == nil {
		return time.Second
	}
	sec := s.cfg.Gateway.Scheduling.OutboxPollIntervalSeconds
	if sec <= 0 {
		return time.Second
	}
	return time.Duration(sec) * time.Second
}

func (s *SchedulerSnapshotService) fullRebuildInterval() time.Duration {
	if s.cfg == nil {
		return 0
	}
	sec := s.cfg.Gateway.Scheduling.FullRebuildIntervalSeconds
	if sec <= 0 {
		return 0
	}
	return time.Duration(sec) * time.Second
}

func (s *SchedulerSnapshotService) defaultBuckets(ctx context.Context) ([]SchedulerBucket, error) {
	buckets := make([]SchedulerBucket, 0)
	platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
	for _, platform := range platforms {
		buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeSingle})
		buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeForced})
		if platform == PlatformAnthropic || platform == PlatformGemini {
			buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeMixed})
		}
	}

	if s.isRunModeSimple() || s.groupRepo == nil {
		return dedupeBuckets(buckets), nil
	}

	groups, err := s.groupRepo.ListActive(ctx)
	if err != nil {
		return dedupeBuckets(buckets), nil
	}
	for _, group := range groups {
		if group.Platform == "" {
			continue
		}
		buckets = append(buckets, SchedulerBucket{GroupID: group.ID, Platform: group.Platform, Mode: SchedulerModeSingle})
		buckets = append(buckets, SchedulerBucket{GroupID: group.ID, Platform: group.Platform, Mode: SchedulerModeForced})
		if group.Platform == PlatformAnthropic || group.Platform == PlatformGemini {
			buckets = append(buckets, SchedulerBucket{GroupID: group.ID, Platform: group.Platform, Mode: SchedulerModeMixed})
		}
	}
	return dedupeBuckets(buckets), nil
}

func dedupeBuckets(in []SchedulerBucket) []SchedulerBucket {
	seen := make(map[string]struct{}, len(in))
	out := make([]SchedulerBucket, 0, len(in))
	for _, bucket := range in {
		key := bucket.String()
		if _, ok := seen[key]; ok {
			continue
		}
		seen[key] = struct{}{}
		out = append(out, bucket)
	}
	return out
}

func derefAccounts(accounts []*Account) []Account {
	if len(accounts) == 0 {
		return []Account{}
	}
	out := make([]Account, 0, len(accounts))
	for _, account := range accounts {
		if account == nil {
			continue
		}
		out = append(out, *account)
	}
	return out
}

func parseInt64Slice(value any) []int64 {
	raw, ok := value.([]any)
	if !ok {
		return nil
	}
	out := make([]int64, 0, len(raw))
	for _, item := range raw {
		if v, ok := toInt64(item); ok && v > 0 {
			out = append(out, v)
		}
	}
	return out
}

func toInt64(value any) (int64, bool) {
	switch v := value.(type) {
	case float64:
		return int64(v), true
	case int64:
		return v, true
	case int:
		return int64(v), true
	case json.Number:
		parsed, err := strconv.ParseInt(v.String(), 10, 64)
		return parsed, err == nil
	default:
		return 0, false
	}
}

type fallbackLimiter struct {
	maxQPS int
	mu     sync.Mutex
	window time.Time
	count  int
}

func newFallbackLimiter(maxQPS int) *fallbackLimiter {
	if maxQPS <= 0 {
		return nil
	}
	return &fallbackLimiter{
		maxQPS: maxQPS,
		window: time.Now(),
	}
}

func (l *fallbackLimiter) Allow() bool {
	if l == nil || l.maxQPS <= 0 {
		return true
	}
	l.mu.Lock()
	defer l.mu.Unlock()

	now := time.Now()
	if now.Sub(l.window) >= time.Second {
		l.window = now
		l.count = 0
	}
	if l.count >= l.maxQPS {
		return false
	}
	l.count++
	return true
}

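fallbackLimiter is a fixed-window counter: at most maxQPS fallback queries per wall-clock second, with the window reset lazily on the next call (so a burst straddling a window boundary can briefly exceed the average rate). Behavior sketch:

	// With maxQPS = 2, two calls pass per window and the third is rejected:
	lim := newFallbackLimiter(2)
	a, b, c := lim.Allow(), lim.Allow(), lim.Allow() // true, true, false
	time.Sleep(time.Second)                          // next window
	d := lim.Allow()                                 // true again
	_, _, _, _ = a, b, c, d
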
@@ -86,6 +86,19 @@ func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountReposi
	return svc
}

// ProvideSchedulerSnapshotService creates and starts SchedulerSnapshotService.
func ProvideSchedulerSnapshotService(
	cache SchedulerCache,
	outboxRepo SchedulerOutboxRepository,
	accountRepo AccountRepository,
	groupRepo GroupRepository,
	cfg *config.Config,
) *SchedulerSnapshotService {
	svc := NewSchedulerSnapshotService(cache, outboxRepo, accountRepo, groupRepo, cfg)
	svc.Start()
	return svc
}

// ProvideOpsMetricsCollector creates and starts OpsMetricsCollector.
func ProvideOpsMetricsCollector(
	opsRepo OpsRepository,
@@ -201,6 +214,7 @@ var ProviderSet = wire.NewSet(
	NewTurnstileService,
	NewSubscriptionService,
	ProvideConcurrencyService,
	ProvideSchedulerSnapshotService,
	NewIdentityService,
	NewCRSSyncService,
	ProvideUpdateService,

10 backend/migrations/036_scheduler_outbox.sql Normal file
@@ -0,0 +1,10 @@
CREATE TABLE IF NOT EXISTS scheduler_outbox (
    id BIGSERIAL PRIMARY KEY,
    event_type TEXT NOT NULL,
    account_id BIGINT NULL,
    group_id BIGINT NULL,
    payload JSONB NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE INDEX IF NOT EXISTS idx_scheduler_outbox_created_at ON scheduler_outbox (created_at);