Merge branch 'mt21625457/main'
This commit is contained in:
@@ -151,6 +151,7 @@ type GatewayService struct {
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache GatewayCache
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
@@ -169,6 +170,7 @@ func NewGatewayService(
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
cache GatewayCache,
|
||||
cfg *config.Config,
|
||||
schedulerSnapshot *SchedulerSnapshotService,
|
||||
concurrencyService *ConcurrencyService,
|
||||
billingService *BillingService,
|
||||
rateLimitService *RateLimitService,
|
||||
@@ -185,6 +187,7 @@ func NewGatewayService(
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
concurrencyService: concurrencyService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
@@ -745,6 +748,9 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
|
||||
}
|
||||
|
||||
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
}
|
||||
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
|
||||
if useMixed {
|
||||
platforms := []string{platform, PlatformAntigravity}
|
||||
@@ -821,6 +827,13 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
|
||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
}
|
||||
|
||||
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
}
|
||||
return s.accountRepo.GetByID(ctx, accountID)
|
||||
}
|
||||
|
||||
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
||||
sort.SliceStable(accounts, func(i, j int) bool {
|
||||
a, b := accounts[i], accounts[j]
|
||||
@@ -851,7 +864,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
@@ -864,16 +877,11 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
}
|
||||
|
||||
// 2. 获取可调度账号列表(单平台)
|
||||
var accounts []Account
|
||||
var err error
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
// 简易模式:忽略 groupID,查询所有可用账号
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
} else if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform == "" {
|
||||
hasForcePlatform = false
|
||||
}
|
||||
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
@@ -935,7 +943,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
// selectAccountWithMixedScheduling 选择账户(支持混合调度)
|
||||
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
|
||||
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
|
||||
platforms := []string{nativePlatform, PlatformAntigravity}
|
||||
preferOAuth := nativePlatform == PlatformGemini
|
||||
|
||||
// 1. 查询粘性会话
|
||||
@@ -943,7 +950,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
@@ -958,13 +965,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
}
|
||||
|
||||
// 2. 获取可调度账号列表
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -40,6 +40,7 @@ type GeminiMessagesCompatService struct {
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
cache GatewayCache
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
tokenProvider *GeminiTokenProvider
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream
|
||||
@@ -51,6 +52,7 @@ func NewGeminiMessagesCompatService(
|
||||
accountRepo AccountRepository,
|
||||
groupRepo GroupRepository,
|
||||
cache GatewayCache,
|
||||
schedulerSnapshot *SchedulerSnapshotService,
|
||||
tokenProvider *GeminiTokenProvider,
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
@@ -61,6 +63,7 @@ func NewGeminiMessagesCompatService(
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
tokenProvider: tokenProvider,
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
@@ -105,12 +108,6 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||
// 注意:强制平台模式不走混合调度
|
||||
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
|
||||
var queryPlatforms []string
|
||||
if useMixedScheduling {
|
||||
queryPlatforms = []string{PlatformGemini, PlatformAntigravity}
|
||||
} else {
|
||||
queryPlatforms = []string{platform}
|
||||
}
|
||||
|
||||
cacheKey := "gemini:" + sessionHash
|
||||
|
||||
@@ -118,7 +115,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
valid := false
|
||||
@@ -149,22 +146,16 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
}
|
||||
|
||||
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
|
||||
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
// 强制平台模式下,分组中找不到账户时回退查询全部
|
||||
if len(accounts) == 0 && groupID != nil && hasForcePlatform {
|
||||
accounts, err = s.listSchedulableAccountsOnce(ctx, nil, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
// 强制平台模式下,分组中找不到账户时回退查询全部
|
||||
if len(accounts) == 0 && hasForcePlatform {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||||
}
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
|
||||
var selected *Account
|
||||
@@ -245,6 +236,31 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
|
||||
return s.antigravityGatewayService
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
}
|
||||
return s.accountRepo.GetByID(ctx, accountID)
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
|
||||
queryPlatforms := []string{platform}
|
||||
if useMixedScheduling {
|
||||
queryPlatforms = []string{platform, PlatformAntigravity}
|
||||
}
|
||||
|
||||
if groupID != nil {
|
||||
return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
|
||||
}
|
||||
return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
||||
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||
@@ -266,13 +282,7 @@ func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (strin
|
||||
|
||||
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
|
||||
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAntigravity)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAntigravity)
|
||||
}
|
||||
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, PlatformAntigravity, false)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -288,13 +298,7 @@ func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context
|
||||
// 3) OAuth accounts explicitly marked as ai_studio
|
||||
// 4) Any remaining Gemini accounts (fallback)
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) {
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
|
||||
}
|
||||
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, PlatformGemini, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -85,6 +85,7 @@ type OpenAIGatewayService struct {
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache GatewayCache
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
concurrencyService *ConcurrencyService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
@@ -101,6 +102,7 @@ func NewOpenAIGatewayService(
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
cache GatewayCache,
|
||||
cfg *config.Config,
|
||||
schedulerSnapshot *SchedulerSnapshotService,
|
||||
concurrencyService *ConcurrencyService,
|
||||
billingService *BillingService,
|
||||
rateLimitService *RateLimitService,
|
||||
@@ -115,6 +117,7 @@ func NewOpenAIGatewayService(
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
concurrencyService: concurrencyService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
@@ -159,7 +162,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// Refresh sticky session TTL
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
@@ -170,16 +173,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
||||
}
|
||||
|
||||
// 2. Get schedulable OpenAI accounts
|
||||
var accounts []Account
|
||||
var err error
|
||||
// 简易模式:忽略分组限制,查询所有可用账号
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
||||
} else if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
||||
}
|
||||
accounts, err := s.listSchedulableAccounts(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
@@ -301,7 +295,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
|
||||
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
@@ -446,6 +440,10 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, PlatformOpenAI, false)
|
||||
return accounts, err
|
||||
}
|
||||
var accounts []Account
|
||||
var err error
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
@@ -468,6 +466,13 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun
|
||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
}
|
||||
return s.accountRepo.GetByID(ctx, accountID)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
if s.cfg != nil {
|
||||
return s.cfg.Gateway.Scheduling
|
||||
|
||||
68
backend/internal/service/scheduler_cache.go
Normal file
68
backend/internal/service/scheduler_cache.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
SchedulerModeSingle = "single"
|
||||
SchedulerModeMixed = "mixed"
|
||||
SchedulerModeForced = "forced"
|
||||
)
|
||||
|
||||
type SchedulerBucket struct {
|
||||
GroupID int64
|
||||
Platform string
|
||||
Mode string
|
||||
}
|
||||
|
||||
func (b SchedulerBucket) String() string {
|
||||
return fmt.Sprintf("%d:%s:%s", b.GroupID, b.Platform, b.Mode)
|
||||
}
|
||||
|
||||
func ParseSchedulerBucket(raw string) (SchedulerBucket, bool) {
|
||||
parts := strings.Split(raw, ":")
|
||||
if len(parts) != 3 {
|
||||
return SchedulerBucket{}, false
|
||||
}
|
||||
groupID, err := strconv.ParseInt(parts[0], 10, 64)
|
||||
if err != nil {
|
||||
return SchedulerBucket{}, false
|
||||
}
|
||||
if parts[1] == "" || parts[2] == "" {
|
||||
return SchedulerBucket{}, false
|
||||
}
|
||||
return SchedulerBucket{
|
||||
GroupID: groupID,
|
||||
Platform: parts[1],
|
||||
Mode: parts[2],
|
||||
}, true
|
||||
}
|
||||
|
||||
// SchedulerCache 负责调度快照与账号快照的缓存读写。
|
||||
type SchedulerCache interface {
|
||||
// GetSnapshot 读取快照并返回命中与否(ready + active + 数据完整)。
|
||||
GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error)
|
||||
// SetSnapshot 写入快照并切换激活版本。
|
||||
SetSnapshot(ctx context.Context, bucket SchedulerBucket, accounts []Account) error
|
||||
// GetAccount 获取单账号快照。
|
||||
GetAccount(ctx context.Context, accountID int64) (*Account, error)
|
||||
// SetAccount 写入单账号快照(包含不可调度状态)。
|
||||
SetAccount(ctx context.Context, account *Account) error
|
||||
// DeleteAccount 删除单账号快照。
|
||||
DeleteAccount(ctx context.Context, accountID int64) error
|
||||
// UpdateLastUsed 批量更新账号的最后使用时间。
|
||||
UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
|
||||
// TryLockBucket 尝试获取分桶重建锁。
|
||||
TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error)
|
||||
// ListBuckets 返回已注册的分桶集合。
|
||||
ListBuckets(ctx context.Context) ([]SchedulerBucket, error)
|
||||
// GetOutboxWatermark 读取 outbox 水位。
|
||||
GetOutboxWatermark(ctx context.Context) (int64, error)
|
||||
// SetOutboxWatermark 保存 outbox 水位。
|
||||
SetOutboxWatermark(ctx context.Context, id int64) error
|
||||
}
|
||||
10
backend/internal/service/scheduler_events.go
Normal file
10
backend/internal/service/scheduler_events.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package service
|
||||
|
||||
const (
|
||||
SchedulerOutboxEventAccountChanged = "account_changed"
|
||||
SchedulerOutboxEventAccountGroupsChanged = "account_groups_changed"
|
||||
SchedulerOutboxEventAccountBulkChanged = "account_bulk_changed"
|
||||
SchedulerOutboxEventAccountLastUsed = "account_last_used"
|
||||
SchedulerOutboxEventGroupChanged = "group_changed"
|
||||
SchedulerOutboxEventFullRebuild = "full_rebuild"
|
||||
)
|
||||
21
backend/internal/service/scheduler_outbox.go
Normal file
21
backend/internal/service/scheduler_outbox.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SchedulerOutboxEvent struct {
|
||||
ID int64
|
||||
EventType string
|
||||
AccountID *int64
|
||||
GroupID *int64
|
||||
Payload map[string]any
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// SchedulerOutboxRepository 提供调度 outbox 的读取接口。
|
||||
type SchedulerOutboxRepository interface {
|
||||
ListAfter(ctx context.Context, afterID int64, limit int) ([]SchedulerOutboxEvent, error)
|
||||
MaxID(ctx context.Context) (int64, error)
|
||||
}
|
||||
786
backend/internal/service/scheduler_snapshot_service.go
Normal file
786
backend/internal/service/scheduler_snapshot_service.go
Normal file
@@ -0,0 +1,786 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSchedulerCacheNotReady = errors.New("scheduler cache not ready")
|
||||
ErrSchedulerFallbackLimited = errors.New("scheduler db fallback limited")
|
||||
)
|
||||
|
||||
const outboxEventTimeout = 2 * time.Minute
|
||||
|
||||
type SchedulerSnapshotService struct {
|
||||
cache SchedulerCache
|
||||
outboxRepo SchedulerOutboxRepository
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
cfg *config.Config
|
||||
stopCh chan struct{}
|
||||
stopOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
fallbackLimit *fallbackLimiter
|
||||
lagMu sync.Mutex
|
||||
lagFailures int
|
||||
}
|
||||
|
||||
func NewSchedulerSnapshotService(
|
||||
cache SchedulerCache,
|
||||
outboxRepo SchedulerOutboxRepository,
|
||||
accountRepo AccountRepository,
|
||||
groupRepo GroupRepository,
|
||||
cfg *config.Config,
|
||||
) *SchedulerSnapshotService {
|
||||
maxQPS := 0
|
||||
if cfg != nil {
|
||||
maxQPS = cfg.Gateway.Scheduling.DbFallbackMaxQPS
|
||||
}
|
||||
return &SchedulerSnapshotService{
|
||||
cache: cache,
|
||||
outboxRepo: outboxRepo,
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
cfg: cfg,
|
||||
stopCh: make(chan struct{}),
|
||||
fallbackLimit: newFallbackLimiter(maxQPS),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) Start() {
|
||||
if s == nil || s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.runInitialRebuild()
|
||||
}()
|
||||
|
||||
interval := s.outboxPollInterval()
|
||||
if s.outboxRepo != nil && interval > 0 {
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.runOutboxWorker(interval)
|
||||
}()
|
||||
}
|
||||
|
||||
fullInterval := s.fullRebuildInterval()
|
||||
if fullInterval > 0 {
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.runFullRebuildWorker(fullInterval)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) ListSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
|
||||
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
|
||||
mode := s.resolveMode(platform, hasForcePlatform)
|
||||
bucket := s.bucketFor(groupID, platform, mode)
|
||||
|
||||
if s.cache != nil {
|
||||
cached, hit, err := s.cache.GetSnapshot(ctx, bucket)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] cache read failed: bucket=%s err=%v", bucket.String(), err)
|
||||
} else if hit {
|
||||
return derefAccounts(cached), useMixed, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.guardFallback(ctx); err != nil {
|
||||
return nil, useMixed, err
|
||||
}
|
||||
|
||||
fallbackCtx, cancel := s.withFallbackTimeout(ctx)
|
||||
defer cancel()
|
||||
|
||||
accounts, err := s.loadAccountsFromDB(fallbackCtx, bucket, useMixed)
|
||||
if err != nil {
|
||||
return nil, useMixed, err
|
||||
}
|
||||
|
||||
if s.cache != nil {
|
||||
if err := s.cache.SetSnapshot(fallbackCtx, bucket, accounts); err != nil {
|
||||
log.Printf("[Scheduler] cache write failed: bucket=%s err=%v", bucket.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return accounts, useMixed, nil
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) GetAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if accountID <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if s.cache != nil {
|
||||
account, err := s.cache.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] account cache read failed: id=%d err=%v", accountID, err)
|
||||
} else if account != nil {
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.guardFallback(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fallbackCtx, cancel := s.withFallbackTimeout(ctx)
|
||||
defer cancel()
|
||||
return s.accountRepo.GetByID(fallbackCtx, accountID)
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) runInitialRebuild() {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
buckets, err := s.cache.ListBuckets(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] list buckets failed: %v", err)
|
||||
}
|
||||
if len(buckets) == 0 {
|
||||
buckets, err = s.defaultBuckets(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] default buckets failed: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := s.rebuildBuckets(ctx, buckets, "startup"); err != nil {
|
||||
log.Printf("[Scheduler] rebuild startup failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) runOutboxWorker(interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
s.pollOutbox()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.pollOutbox()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) runFullRebuildWorker(interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := s.triggerFullRebuild("interval"); err != nil {
|
||||
log.Printf("[Scheduler] full rebuild failed: %v", err)
|
||||
}
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) pollOutbox() {
|
||||
if s.outboxRepo == nil || s.cache == nil {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
watermark, err := s.cache.GetOutboxWatermark(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] outbox watermark read failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
events, err := s.outboxRepo.ListAfter(ctx, watermark, 200)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] outbox poll failed: %v", err)
|
||||
return
|
||||
}
|
||||
if len(events) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
watermarkForCheck := watermark
|
||||
for _, event := range events {
|
||||
eventCtx, cancel := context.WithTimeout(context.Background(), outboxEventTimeout)
|
||||
err := s.handleOutboxEvent(eventCtx, event)
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] outbox handle failed: id=%d type=%s err=%v", event.ID, event.EventType, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
lastID := events[len(events)-1].ID
|
||||
if err := s.cache.SetOutboxWatermark(ctx, lastID); err != nil {
|
||||
log.Printf("[Scheduler] outbox watermark write failed: %v", err)
|
||||
} else {
|
||||
watermarkForCheck = lastID
|
||||
}
|
||||
|
||||
s.checkOutboxLag(ctx, events[0], watermarkForCheck)
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) handleOutboxEvent(ctx context.Context, event SchedulerOutboxEvent) error {
|
||||
switch event.EventType {
|
||||
case SchedulerOutboxEventAccountLastUsed:
|
||||
return s.handleLastUsedEvent(ctx, event.Payload)
|
||||
case SchedulerOutboxEventAccountBulkChanged:
|
||||
return s.handleBulkAccountEvent(ctx, event.Payload)
|
||||
case SchedulerOutboxEventAccountGroupsChanged:
|
||||
return s.handleAccountEvent(ctx, event.AccountID, event.Payload)
|
||||
case SchedulerOutboxEventAccountChanged:
|
||||
return s.handleAccountEvent(ctx, event.AccountID, event.Payload)
|
||||
case SchedulerOutboxEventGroupChanged:
|
||||
return s.handleGroupEvent(ctx, event.GroupID)
|
||||
case SchedulerOutboxEventFullRebuild:
|
||||
return s.triggerFullRebuild("outbox")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) handleLastUsedEvent(ctx context.Context, payload map[string]any) error {
|
||||
if s.cache == nil || payload == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := payload["last_used"].(map[string]any)
|
||||
if !ok || len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
updates := make(map[int64]time.Time, len(raw))
|
||||
for key, value := range raw {
|
||||
id, err := strconv.ParseInt(key, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
continue
|
||||
}
|
||||
sec, ok := toInt64(value)
|
||||
if !ok || sec <= 0 {
|
||||
continue
|
||||
}
|
||||
updates[id] = time.Unix(sec, 0)
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.cache.UpdateLastUsed(ctx, updates)
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, payload map[string]any) error {
|
||||
if payload == nil {
|
||||
return nil
|
||||
}
|
||||
ids := parseInt64Slice(payload["account_ids"])
|
||||
for _, id := range ids {
|
||||
if err := s.handleAccountEvent(ctx, &id, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accountID *int64, payload map[string]any) error {
|
||||
if accountID == nil || *accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
if s.accountRepo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var groupIDs []int64
|
||||
if payload != nil {
|
||||
groupIDs = parseInt64Slice(payload["group_ids"])
|
||||
}
|
||||
|
||||
account, err := s.accountRepo.GetByID(ctx, *accountID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrAccountNotFound) {
|
||||
if s.cache != nil {
|
||||
if err := s.cache.DeleteAccount(ctx, *accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return s.rebuildByGroupIDs(ctx, groupIDs, "account_miss")
|
||||
}
|
||||
return err
|
||||
}
|
||||
if s.cache != nil {
|
||||
if err := s.cache.SetAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(groupIDs) == 0 {
|
||||
groupIDs = account.GroupIDs
|
||||
}
|
||||
return s.rebuildByAccount(ctx, account, groupIDs, "account_change")
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) handleGroupEvent(ctx context.Context, groupID *int64) error {
|
||||
if groupID == nil || *groupID <= 0 {
|
||||
return nil
|
||||
}
|
||||
groupIDs := []int64{*groupID}
|
||||
return s.rebuildByGroupIDs(ctx, groupIDs, "group_change")
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) rebuildByAccount(ctx context.Context, account *Account, groupIDs []int64, reason string) error {
|
||||
if account == nil {
|
||||
return nil
|
||||
}
|
||||
groupIDs = s.normalizeGroupIDs(groupIDs)
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var firstErr error
|
||||
if err := s.rebuildBucketsForPlatform(ctx, account.Platform, groupIDs, reason); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
if account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
|
||||
if err := s.rebuildBucketsForPlatform(ctx, PlatformAnthropic, groupIDs, reason); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
if err := s.rebuildBucketsForPlatform(ctx, PlatformGemini, groupIDs, reason); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupIDs []int64, reason string) error {
|
||||
groupIDs = s.normalizeGroupIDs(groupIDs)
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
|
||||
var firstErr error
|
||||
for _, platform := range platforms {
|
||||
if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) rebuildBucketsForPlatform(ctx context.Context, platform string, groupIDs []int64, reason string) error {
|
||||
if platform == "" {
|
||||
return nil
|
||||
}
|
||||
var firstErr error
|
||||
for _, gid := range groupIDs {
|
||||
if err := s.rebuildBucket(ctx, SchedulerBucket{GroupID: gid, Platform: platform, Mode: SchedulerModeSingle}, reason); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
if err := s.rebuildBucket(ctx, SchedulerBucket{GroupID: gid, Platform: platform, Mode: SchedulerModeForced}, reason); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
if platform == PlatformAnthropic || platform == PlatformGemini {
|
||||
if err := s.rebuildBucket(ctx, SchedulerBucket{GroupID: gid, Platform: platform, Mode: SchedulerModeMixed}, reason); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) rebuildBuckets(ctx context.Context, buckets []SchedulerBucket, reason string) error {
|
||||
var firstErr error
|
||||
for _, bucket := range buckets {
|
||||
if err := s.rebuildBucket(ctx, bucket, reason); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) rebuildBucket(ctx context.Context, bucket SchedulerBucket, reason string) error {
|
||||
if s.cache == nil {
|
||||
return ErrSchedulerCacheNotReady
|
||||
}
|
||||
ok, err := s.cache.TryLockBucket(ctx, bucket, 30*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
rebuildCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
accounts, err := s.loadAccountsFromDB(rebuildCtx, bucket, bucket.Mode == SchedulerModeMixed)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] rebuild failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err)
|
||||
return err
|
||||
}
|
||||
if err := s.cache.SetSnapshot(rebuildCtx, bucket, accounts); err != nil {
|
||||
log.Printf("[Scheduler] rebuild cache failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err)
|
||||
return err
|
||||
}
|
||||
log.Printf("[Scheduler] rebuild ok: bucket=%s reason=%s size=%d", bucket.String(), reason, len(accounts))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) triggerFullRebuild(reason string) error {
|
||||
if s.cache == nil {
|
||||
return ErrSchedulerCacheNotReady
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
buckets, err := s.cache.ListBuckets(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] list buckets failed: %v", err)
|
||||
return err
|
||||
}
|
||||
if len(buckets) == 0 {
|
||||
buckets, err = s.defaultBuckets(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] default buckets failed: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return s.rebuildBuckets(ctx, buckets, reason)
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) checkOutboxLag(ctx context.Context, oldest SchedulerOutboxEvent, watermark int64) {
|
||||
if oldest.CreatedAt.IsZero() || s.cfg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
lag := time.Since(oldest.CreatedAt)
|
||||
if lagSeconds := int(lag.Seconds()); lagSeconds >= s.cfg.Gateway.Scheduling.OutboxLagWarnSeconds && s.cfg.Gateway.Scheduling.OutboxLagWarnSeconds > 0 {
|
||||
log.Printf("[Scheduler] outbox lag warning: %ds", lagSeconds)
|
||||
}
|
||||
|
||||
if s.cfg.Gateway.Scheduling.OutboxLagRebuildSeconds > 0 && int(lag.Seconds()) >= s.cfg.Gateway.Scheduling.OutboxLagRebuildSeconds {
|
||||
s.lagMu.Lock()
|
||||
s.lagFailures++
|
||||
failures := s.lagFailures
|
||||
s.lagMu.Unlock()
|
||||
|
||||
if failures >= s.cfg.Gateway.Scheduling.OutboxLagRebuildFailures {
|
||||
log.Printf("[Scheduler] outbox lag rebuild triggered: lag=%s failures=%d", lag, failures)
|
||||
s.lagMu.Lock()
|
||||
s.lagFailures = 0
|
||||
s.lagMu.Unlock()
|
||||
if err := s.triggerFullRebuild("outbox_lag"); err != nil {
|
||||
log.Printf("[Scheduler] outbox lag rebuild failed: %v", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
s.lagMu.Lock()
|
||||
s.lagFailures = 0
|
||||
s.lagMu.Unlock()
|
||||
}
|
||||
|
||||
threshold := s.cfg.Gateway.Scheduling.OutboxBacklogRebuildRows
|
||||
if threshold <= 0 || s.outboxRepo == nil {
|
||||
return
|
||||
}
|
||||
maxID, err := s.outboxRepo.MaxID(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if maxID-watermark >= int64(threshold) {
|
||||
log.Printf("[Scheduler] outbox backlog rebuild triggered: backlog=%d", maxID-watermark)
|
||||
if err := s.triggerFullRebuild("outbox_backlog"); err != nil {
|
||||
log.Printf("[Scheduler] outbox backlog rebuild failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucket SchedulerBucket, useMixed bool) ([]Account, error) {
|
||||
if s.accountRepo == nil {
|
||||
return nil, ErrSchedulerCacheNotReady
|
||||
}
|
||||
groupID := bucket.GroupID
|
||||
if s.isRunModeSimple() {
|
||||
groupID = 0
|
||||
}
|
||||
|
||||
if useMixed {
|
||||
platforms := []string{bucket.Platform, PlatformAntigravity}
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID > 0 {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, groupID, platforms)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filtered := make([]Account, 0, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, acc)
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
if groupID > 0 {
|
||||
return s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, groupID, bucket.Platform)
|
||||
}
|
||||
return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform)
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) bucketFor(groupID *int64, platform string, mode string) SchedulerBucket {
|
||||
return SchedulerBucket{
|
||||
GroupID: s.normalizeGroupID(groupID),
|
||||
Platform: platform,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) normalizeGroupID(groupID *int64) int64 {
|
||||
if s.isRunModeSimple() {
|
||||
return 0
|
||||
}
|
||||
if groupID == nil || *groupID <= 0 {
|
||||
return 0
|
||||
}
|
||||
return *groupID
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) normalizeGroupIDs(groupIDs []int64) []int64 {
|
||||
if s.isRunModeSimple() {
|
||||
return []int64{0}
|
||||
}
|
||||
if len(groupIDs) == 0 {
|
||||
return []int64{0}
|
||||
}
|
||||
seen := make(map[int64]struct{}, len(groupIDs))
|
||||
out := make([]int64, 0, len(groupIDs))
|
||||
for _, id := range groupIDs {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return []int64{0}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) resolveMode(platform string, hasForcePlatform bool) string {
|
||||
if hasForcePlatform {
|
||||
return SchedulerModeForced
|
||||
}
|
||||
if platform == PlatformAnthropic || platform == PlatformGemini {
|
||||
return SchedulerModeMixed
|
||||
}
|
||||
return SchedulerModeSingle
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) guardFallback(ctx context.Context) error {
|
||||
if s.cfg == nil || s.cfg.Gateway.Scheduling.DbFallbackEnabled {
|
||||
if s.fallbackLimit == nil || s.fallbackLimit.Allow() {
|
||||
return nil
|
||||
}
|
||||
return ErrSchedulerFallbackLimited
|
||||
}
|
||||
return ErrSchedulerCacheNotReady
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) withFallbackTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if s.cfg == nil || s.cfg.Gateway.Scheduling.DbFallbackTimeoutSeconds <= 0 {
|
||||
return context.WithCancel(ctx)
|
||||
}
|
||||
timeout := time.Duration(s.cfg.Gateway.Scheduling.DbFallbackTimeoutSeconds) * time.Second
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
remaining := time.Until(deadline)
|
||||
if remaining <= 0 {
|
||||
return context.WithCancel(ctx)
|
||||
}
|
||||
if remaining < timeout {
|
||||
timeout = remaining
|
||||
}
|
||||
}
|
||||
return context.WithTimeout(ctx, timeout)
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) isRunModeSimple() bool {
|
||||
return s.cfg != nil && s.cfg.RunMode == config.RunModeSimple
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) outboxPollInterval() time.Duration {
|
||||
if s.cfg == nil {
|
||||
return time.Second
|
||||
}
|
||||
sec := s.cfg.Gateway.Scheduling.OutboxPollIntervalSeconds
|
||||
if sec <= 0 {
|
||||
return time.Second
|
||||
}
|
||||
return time.Duration(sec) * time.Second
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) fullRebuildInterval() time.Duration {
|
||||
if s.cfg == nil {
|
||||
return 0
|
||||
}
|
||||
sec := s.cfg.Gateway.Scheduling.FullRebuildIntervalSeconds
|
||||
if sec <= 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(sec) * time.Second
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) defaultBuckets(ctx context.Context) ([]SchedulerBucket, error) {
|
||||
buckets := make([]SchedulerBucket, 0)
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
|
||||
for _, platform := range platforms {
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeSingle})
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeForced})
|
||||
if platform == PlatformAnthropic || platform == PlatformGemini {
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeMixed})
|
||||
}
|
||||
}
|
||||
|
||||
if s.isRunModeSimple() || s.groupRepo == nil {
|
||||
return dedupeBuckets(buckets), nil
|
||||
}
|
||||
|
||||
groups, err := s.groupRepo.ListActive(ctx)
|
||||
if err != nil {
|
||||
return dedupeBuckets(buckets), nil
|
||||
}
|
||||
for _, group := range groups {
|
||||
if group.Platform == "" {
|
||||
continue
|
||||
}
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: group.ID, Platform: group.Platform, Mode: SchedulerModeSingle})
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: group.ID, Platform: group.Platform, Mode: SchedulerModeForced})
|
||||
if group.Platform == PlatformAnthropic || group.Platform == PlatformGemini {
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: group.ID, Platform: group.Platform, Mode: SchedulerModeMixed})
|
||||
}
|
||||
}
|
||||
return dedupeBuckets(buckets), nil
|
||||
}
|
||||
|
||||
func dedupeBuckets(in []SchedulerBucket) []SchedulerBucket {
|
||||
seen := make(map[string]struct{}, len(in))
|
||||
out := make([]SchedulerBucket, 0, len(in))
|
||||
for _, bucket := range in {
|
||||
key := bucket.String()
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, bucket)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func derefAccounts(accounts []*Account) []Account {
|
||||
if len(accounts) == 0 {
|
||||
return []Account{}
|
||||
}
|
||||
out := make([]Account, 0, len(accounts))
|
||||
for _, account := range accounts {
|
||||
if account == nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, *account)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseInt64Slice(value any) []int64 {
|
||||
raw, ok := value.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
out := make([]int64, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
if v, ok := toInt64(item); ok && v > 0 {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toInt64(value any) (int64, bool) {
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return int64(v), true
|
||||
case int64:
|
||||
return v, true
|
||||
case int:
|
||||
return int64(v), true
|
||||
case json.Number:
|
||||
parsed, err := strconv.ParseInt(v.String(), 10, 64)
|
||||
return parsed, err == nil
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
type fallbackLimiter struct {
|
||||
maxQPS int
|
||||
mu sync.Mutex
|
||||
window time.Time
|
||||
count int
|
||||
}
|
||||
|
||||
func newFallbackLimiter(maxQPS int) *fallbackLimiter {
|
||||
if maxQPS <= 0 {
|
||||
return nil
|
||||
}
|
||||
return &fallbackLimiter{
|
||||
maxQPS: maxQPS,
|
||||
window: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *fallbackLimiter) Allow() bool {
|
||||
if l == nil || l.maxQPS <= 0 {
|
||||
return true
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if now.Sub(l.window) >= time.Second {
|
||||
l.window = now
|
||||
l.count = 0
|
||||
}
|
||||
if l.count >= l.maxQPS {
|
||||
return false
|
||||
}
|
||||
l.count++
|
||||
return true
|
||||
}
|
||||
@@ -86,6 +86,19 @@ func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountReposi
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideSchedulerSnapshotService creates and starts SchedulerSnapshotService.
|
||||
func ProvideSchedulerSnapshotService(
|
||||
cache SchedulerCache,
|
||||
outboxRepo SchedulerOutboxRepository,
|
||||
accountRepo AccountRepository,
|
||||
groupRepo GroupRepository,
|
||||
cfg *config.Config,
|
||||
) *SchedulerSnapshotService {
|
||||
svc := NewSchedulerSnapshotService(cache, outboxRepo, accountRepo, groupRepo, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideRateLimitService creates RateLimitService with optional dependencies.
|
||||
func ProvideRateLimitService(
|
||||
accountRepo AccountRepository,
|
||||
@@ -217,6 +230,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewTurnstileService,
|
||||
NewSubscriptionService,
|
||||
ProvideConcurrencyService,
|
||||
ProvideSchedulerSnapshotService,
|
||||
NewIdentityService,
|
||||
NewCRSSyncService,
|
||||
ProvideUpdateService,
|
||||
|
||||
Reference in New Issue
Block a user