diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index c4859383..e3498680 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -100,7 +100,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream) accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) - concurrencyService := service.NewConcurrencyService(concurrencyCache) + concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) oAuthHandler := admin.NewOAuthHandler(oAuthService) @@ -128,10 +128,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { identityService := service.NewIdentityService(identityCache) timingWheelService := service.ProvideTimingWheelService() deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index aeeddcb4..8c154a9d 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -3,6 +3,7 @@ package config import ( "fmt" "strings" + "time" "github.com/spf13/viper" ) @@ -119,6 +120,26 @@ type GatewayConfig struct { // ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟) // 应大于最长 LLM 请求时间,防止请求完成前槽位过期 ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` + + // Scheduling: 账号调度相关配置 + Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"` +} + +// GatewaySchedulingConfig accounts scheduling configuration. +type GatewaySchedulingConfig struct { + // 粘性会话排队配置 + StickySessionMaxWaiting int `mapstructure:"sticky_session_max_waiting"` + StickySessionWaitTimeout time.Duration `mapstructure:"sticky_session_wait_timeout"` + + // 兜底排队配置 + FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"` + FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"` + + // 负载计算 + LoadBatchEnabled bool `mapstructure:"load_batch_enabled"` + + // 过期槽位清理周期(0 表示禁用) + SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"` } func (s *ServerConfig) Address() string { @@ -323,6 +344,12 @@ func setDefaults() { viper.SetDefault("gateway.max_upstream_clients", 5000) viper.SetDefault("gateway.client_idle_ttl_seconds", 900) viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求) + viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) + viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second) + viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second) + viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) + viper.SetDefault("gateway.scheduling.load_batch_enabled", true) + viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) // TokenRefresh viper.SetDefault("token_refresh.enabled", true) @@ -411,6 +438,21 @@ func (c *Config) Validate() error { if c.Gateway.ConcurrencySlotTTLMinutes <= 0 { return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive") } + if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 { + return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive") + } + if c.Gateway.Scheduling.StickySessionWaitTimeout <= 0 { + return fmt.Errorf("gateway.scheduling.sticky_session_wait_timeout must be positive") + } + if c.Gateway.Scheduling.FallbackWaitTimeout <= 0 { + return fmt.Errorf("gateway.scheduling.fallback_wait_timeout must be positive") + } + if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 { + return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive") + } + if c.Gateway.Scheduling.SlotCleanupInterval < 0 { + return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative") + } return nil } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 1f1becb8..6e722a54 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1,6 +1,11 @@ package config -import "testing" +import ( + "testing" + "time" + + "github.com/spf13/viper" +) func TestNormalizeRunMode(t *testing.T) { tests := []struct { @@ -21,3 +26,45 @@ func TestNormalizeRunMode(t *testing.T) { } } } + +func TestLoadDefaultSchedulingConfig(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 { + t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting) + } + if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 45*time.Second { + t.Fatalf("StickySessionWaitTimeout = %v, want 45s", cfg.Gateway.Scheduling.StickySessionWaitTimeout) + } + if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second { + t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout) + } + if cfg.Gateway.Scheduling.FallbackMaxWaiting != 100 { + t.Fatalf("FallbackMaxWaiting = %d, want 100", cfg.Gateway.Scheduling.FallbackMaxWaiting) + } + if !cfg.Gateway.Scheduling.LoadBatchEnabled { + t.Fatalf("LoadBatchEnabled = false, want true") + } + if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second { + t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval) + } +} + +func TestLoadSchedulingConfigFromEnv(t *testing.T) { + viper.Reset() + t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5") + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 5 { + t.Fatalf("StickySessionMaxWaiting = %d, want 5", cfg.Gateway.Scheduling.StickySessionMaxWaiting) + } +} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index a2f833ff..769e6700 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -141,6 +141,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } else if apiKey.Group != nil { platform = apiKey.Group.Platform } + sessionKey := sessionHash + if platform == service.PlatformGemini && sessionHash != "" { + sessionKey = "gemini:" + sessionHash + } if platform == service.PlatformGemini { const maxAccountSwitches = 3 @@ -149,7 +153,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { lastFailoverStatus := 0 for { - account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs) if err != nil { if len(failedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -158,9 +162,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) return } + account := selection.Account // 检查预热请求拦截(在账号选择后、转发前检查) if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { + if selection.Acquired && selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } if reqStream { sendMockWarmupStream(c, reqModel) } else { @@ -170,11 +178,44 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 3. 获取账号并发槽位 - accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) - if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return + accountReleaseFunc := selection.ReleaseFunc + var accountWaitRelease func() + if !selection.Acquired { + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + log.Printf("Increment account wait count failed: %v", err) + } else if !canWait { + log.Printf("Account wait queue full: account=%d", account.ID) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) + return + } + accountWaitRelease = func() { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + if accountWaitRelease != nil { + accountWaitRelease() + } + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { + log.Printf("Bind sticky session failed: %v", err) + } } // 转发请求 - 根据账号平台分流 @@ -187,6 +228,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } + if accountWaitRelease != nil { + accountWaitRelease() + } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { @@ -231,7 +275,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { for { // 选择支持该模型的账号 - account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs) if err != nil { if len(failedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -240,9 +284,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) return } + account := selection.Account // 检查预热请求拦截(在账号选择后、转发前检查) if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { + if selection.Acquired && selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } if reqStream { sendMockWarmupStream(c, reqModel) } else { @@ -252,11 +300,44 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 3. 获取账号并发槽位 - accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) - if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return + accountReleaseFunc := selection.ReleaseFunc + var accountWaitRelease func() + if !selection.Acquired { + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + log.Printf("Increment account wait count failed: %v", err) + } else if !canWait { + log.Printf("Account wait queue full: account=%d", account.ID) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) + return + } + accountWaitRelease = func() { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + if accountWaitRelease != nil { + accountWaitRelease() + } + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { + log.Printf("Bind sticky session failed: %v", err) + } } // 转发请求 - 根据账号平台分流 @@ -269,6 +350,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } + if accountWaitRelease != nil { + accountWaitRelease() + } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 4c7bd0f0..4e049dbb 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -83,6 +83,16 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64 h.concurrencyService.DecrementWaitCount(ctx, userID) } +// IncrementAccountWaitCount increments the wait count for an account +func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { + return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait) +} + +// DecrementAccountWaitCount decrements the wait count for an account +func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) { + h.concurrencyService.DecrementAccountWaitCount(ctx, accountID) +} + // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // For streaming requests, sends ping events during the wait. // streamStarted is updated if streaming response has begun. @@ -126,7 +136,12 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. // streamStarted pointer is updated when streaming begins (for proper error handling by caller). func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { - ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait) + return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted) +} + +// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout. +func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { + ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) defer cancel() // Determine if ping is needed (streaming + ping format defined) @@ -200,6 +215,11 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, } } +// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping). +func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { + return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted) +} + // nextBackoff 计算下一次退避时间 // 性能优化:使用指数退避 + 随机抖动,避免惊群效应 // current: 当前退避时间 diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 4e99e00d..1959c0f3 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -197,13 +197,17 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 3) select account (sticky session based on request body) parsedReq, _ := service.ParseGatewayRequest(body) sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) + sessionKey := sessionHash + if sessionHash != "" { + sessionKey = "gemini:" + sessionHash + } const maxAccountSwitches = 3 switchCount := 0 failedAccountIDs := make(map[int64]struct{}) lastFailoverStatus := 0 for { - account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs) if err != nil { if len(failedAccountIDs) == 0 { googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) @@ -212,12 +216,46 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { handleGeminiFailoverExhausted(c, lastFailoverStatus) return } + account := selection.Account // 4) account concurrency slot - accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted) - if err != nil { - googleError(c, http.StatusTooManyRequests, err.Error()) - return + accountReleaseFunc := selection.ReleaseFunc + var accountWaitRelease func() + if !selection.Acquired { + if selection.WaitPlan == nil { + googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts") + return + } + canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + log.Printf("Increment account wait count failed: %v", err) + } else if !canWait { + log.Printf("Account wait queue full: account=%d", account.ID) + googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") + return + } + accountWaitRelease = func() { + geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + + accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + stream, + &streamStarted, + ) + if err != nil { + if accountWaitRelease != nil { + accountWaitRelease() + } + googleError(c, http.StatusTooManyRequests, err.Error()) + return + } + if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { + log.Printf("Bind sticky session failed: %v", err) + } } // 5) forward (根据平台分流) @@ -230,6 +268,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } + if accountWaitRelease != nil { + accountWaitRelease() + } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 7c9934c6..c6b969bc 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -146,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { for { // Select account supporting the requested model log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel) - account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) if err != nil { log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) if len(failedAccountIDs) == 0 { @@ -156,14 +156,48 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) return } + account := selection.Account log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) // 3. Acquire account concurrency slot - accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) - if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return + accountReleaseFunc := selection.ReleaseFunc + var accountWaitRelease func() + if !selection.Acquired { + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + log.Printf("Increment account wait count failed: %v", err) + } else if !canWait { + log.Printf("Account wait queue full: account=%d", account.ID) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) + return + } + accountWaitRelease = func() { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + if accountWaitRelease != nil { + accountWaitRelease() + } + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil { + log.Printf("Bind sticky session failed: %v", err) + } } // Forward request @@ -171,6 +205,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } + if accountWaitRelease != nil { + accountWaitRelease() + } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index 9205230b..d8d6989b 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -2,7 +2,9 @@ package repository import ( "context" + "errors" "fmt" + "strconv" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/redis/go-redis/v9" @@ -27,6 +29,8 @@ const ( userSlotKeyPrefix = "concurrency:user:" // 等待队列计数器格式: concurrency:wait:{userID} waitQueueKeyPrefix = "concurrency:wait:" + // 账号级等待队列计数器格式: wait:account:{accountID} + accountWaitKeyPrefix = "wait:account:" // 默认槽位过期时间(分钟),可通过配置覆盖 defaultSlotTTLMinutes = 15 @@ -112,33 +116,112 @@ var ( redis.call('EXPIRE', KEYS[1], ARGV[2]) end - return 1 - `) + return 1 + `) + + // incrementAccountWaitScript - account-level wait queue count + incrementAccountWaitScript = redis.NewScript(` + local current = redis.call('GET', KEYS[1]) + if current == false then + current = 0 + else + current = tonumber(current) + end + + if current >= tonumber(ARGV[1]) then + return 0 + end + + local newVal = redis.call('INCR', KEYS[1]) + + -- Only set TTL on first creation to avoid refreshing zombie data + if newVal == 1 then + redis.call('EXPIRE', KEYS[1], ARGV[2]) + end + + return 1 + `) // decrementWaitScript - same as before decrementWaitScript = redis.NewScript(` - local current = redis.call('GET', KEYS[1]) - if current ~= false and tonumber(current) > 0 then - redis.call('DECR', KEYS[1]) - end - return 1 - `) + local current = redis.call('GET', KEYS[1]) + if current ~= false and tonumber(current) > 0 then + redis.call('DECR', KEYS[1]) + end + return 1 + `) + + // getAccountsLoadBatchScript - batch load query (read-only) + // ARGV[1] = slot TTL (seconds, retained for compatibility) + // ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ... + getAccountsLoadBatchScript = redis.NewScript(` + local result = {} + + local i = 2 + while i <= #ARGV do + local accountID = ARGV[i] + local maxConcurrency = tonumber(ARGV[i + 1]) + + local slotKey = 'concurrency:account:' .. accountID + local currentConcurrency = redis.call('ZCARD', slotKey) + + local waitKey = 'wait:account:' .. accountID + local waitingCount = redis.call('GET', waitKey) + if waitingCount == false then + waitingCount = 0 + else + waitingCount = tonumber(waitingCount) + end + + local loadRate = 0 + if maxConcurrency > 0 then + loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency) + end + + table.insert(result, accountID) + table.insert(result, currentConcurrency) + table.insert(result, waitingCount) + table.insert(result, loadRate) + + i = i + 2 + end + + return result + `) + + // cleanupExpiredSlotsScript - remove expired slots + // KEYS[1] = concurrency:account:{accountID} + // ARGV[1] = TTL (seconds) + cleanupExpiredSlotsScript = redis.NewScript(` + local key = KEYS[1] + local ttl = tonumber(ARGV[1]) + local timeResult = redis.call('TIME') + local now = tonumber(timeResult[1]) + local expireBefore = now - ttl + return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) + `) ) type concurrencyCache struct { - rdb *redis.Client - slotTTLSeconds int // 槽位过期时间(秒) + rdb *redis.Client + slotTTLSeconds int // 槽位过期时间(秒) + waitQueueTTLSeconds int // 等待队列过期时间(秒) } // NewConcurrencyCache 创建并发控制缓存 // slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟 -func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.ConcurrencyCache { +// waitQueueTTLSeconds: 等待队列过期时间(秒),0 或负数使用 slot TTL +func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache { if slotTTLMinutes <= 0 { slotTTLMinutes = defaultSlotTTLMinutes } + if waitQueueTTLSeconds <= 0 { + waitQueueTTLSeconds = slotTTLMinutes * 60 + } return &concurrencyCache{ - rdb: rdb, - slotTTLSeconds: slotTTLMinutes * 60, + rdb: rdb, + slotTTLSeconds: slotTTLMinutes * 60, + waitQueueTTLSeconds: waitQueueTTLSeconds, } } @@ -155,6 +238,10 @@ func waitQueueKey(userID int64) string { return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) } +func accountWaitKey(accountID int64) string { + return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) +} + // Account slot operations func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { @@ -225,3 +312,75 @@ func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() return err } + +// Account wait queue operations + +func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { + key := accountWaitKey(accountID) + result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int() + if err != nil { + return false, err + } + return result == 1, nil +} + +func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { + key := accountWaitKey(accountID) + _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() + return err +} + +func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + key := accountWaitKey(accountID) + val, err := c.rdb.Get(ctx, key).Int() + if err != nil && !errors.Is(err, redis.Nil) { + return 0, err + } + if errors.Is(err, redis.Nil) { + return 0, nil + } + return val, nil +} + +func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { + if len(accounts) == 0 { + return map[int64]*service.AccountLoadInfo{}, nil + } + + args := []interface{}{c.slotTTLSeconds} + for _, acc := range accounts { + args = append(args, acc.ID, acc.MaxConcurrency) + } + + result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice() + if err != nil { + return nil, err + } + + loadMap := make(map[int64]*service.AccountLoadInfo) + for i := 0; i < len(result); i += 4 { + if i+3 >= len(result) { + break + } + + accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64) + currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1])) + waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2])) + loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3])) + + loadMap[accountID] = &service.AccountLoadInfo{ + AccountID: accountID, + CurrentConcurrency: currentConcurrency, + WaitingCount: waitingCount, + LoadRate: loadRate, + } + } + + return loadMap, nil +} + +func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { + key := accountSlotKey(accountID) + _, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result() + return err +} diff --git a/backend/internal/repository/concurrency_cache_benchmark_test.go b/backend/internal/repository/concurrency_cache_benchmark_test.go index cafab9cb..25697ab1 100644 --- a/backend/internal/repository/concurrency_cache_benchmark_test.go +++ b/backend/internal/repository/concurrency_cache_benchmark_test.go @@ -22,7 +22,7 @@ func BenchmarkAccountConcurrency(b *testing.B) { _ = rdb.Close() }() - cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache) + cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache) ctx := context.Background() for _, size := range []int{10, 100, 1000} { diff --git a/backend/internal/repository/concurrency_cache_integration_test.go b/backend/internal/repository/concurrency_cache_integration_test.go index 6a7c83f4..f3d70ef1 100644 --- a/backend/internal/repository/concurrency_cache_integration_test.go +++ b/backend/internal/repository/concurrency_cache_integration_test.go @@ -27,7 +27,7 @@ type ConcurrencyCacheSuite struct { func (s *ConcurrencyCacheSuite) SetupTest() { s.IntegrationRedisSuite.SetupTest() - s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes) + s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds())) } func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() { @@ -218,6 +218,48 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() { require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count") } +func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() { + accountID := int64(30) + waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) + + ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2) + require.NoError(s.T(), err, "IncrementAccountWaitCount 1") + require.True(s.T(), ok) + + ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2) + require.NoError(s.T(), err, "IncrementAccountWaitCount 2") + require.True(s.T(), ok) + + ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2) + require.NoError(s.T(), err, "IncrementAccountWaitCount 3") + require.False(s.T(), ok, "expected account wait increment over max to fail") + + ttl, err := s.rdb.TTL(s.ctx, waitKey).Result() + require.NoError(s.T(), err, "TTL account waitKey") + s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL) + + require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount") + + val, err := s.rdb.Get(s.ctx, waitKey).Int() + if !errors.Is(err, redis.Nil) { + require.NoError(s.T(), err, "Get waitKey") + } + require.Equal(s.T(), 1, val, "expected account wait count 1") +} + +func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() { + accountID := int64(301) + waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) + + require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key") + + val, err := s.rdb.Get(s.ctx, waitKey).Int() + if !errors.Is(err, redis.Nil) { + require.NoError(s.T(), err, "Get waitKey") + } + require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty") +} + func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() { // When no slots exist, GetAccountConcurrency should return 0 cur, err := s.cache.GetAccountConcurrency(s.ctx, 999) diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index edeaf782..f1a8d4cf 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -10,7 +10,14 @@ import ( // ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数 // 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景 func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache { - return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes) + waitTTLSeconds := int(cfg.Gateway.Scheduling.StickySessionWaitTimeout.Seconds()) + if cfg.Gateway.Scheduling.FallbackWaitTimeout > cfg.Gateway.Scheduling.StickySessionWaitTimeout { + waitTTLSeconds = int(cfg.Gateway.Scheduling.FallbackWaitTimeout.Seconds()) + } + if waitTTLSeconds <= 0 { + waitTTLSeconds = cfg.Gateway.ConcurrencySlotTTLMinutes * 60 + } + return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds) } // ProviderSet is the Wire provider set for all repositories diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index b5229491..65ef16db 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -18,6 +18,11 @@ type ConcurrencyCache interface { ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) + // 账号等待队列(账号级) + IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) + DecrementAccountWaitCount(ctx context.Context, accountID int64) error + GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) + // 用户槽位管理 // 键格式: concurrency:user:{userID}(有序集合,成员为 requestID) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) @@ -27,6 +32,12 @@ type ConcurrencyCache interface { // 等待队列计数(只在首次创建时设置 TTL) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) DecrementWaitCount(ctx context.Context, userID int64) error + + // 批量负载查询(只读) + GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) + + // 清理过期槽位(后台任务) + CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error } // generateRequestID generates a unique request ID for concurrency slot tracking @@ -61,6 +72,18 @@ type AcquireResult struct { ReleaseFunc func() // Must be called when done (typically via defer) } +type AccountWithConcurrency struct { + ID int64 + MaxConcurrency int +} + +type AccountLoadInfo struct { + AccountID int64 + CurrentConcurrency int + WaitingCount int + LoadRate int // 0-100+ (percent) +} + // AcquireAccountSlot attempts to acquire a concurrency slot for an account. // If the account is at max concurrency, it waits until a slot is available or timeout. // Returns a release function that MUST be called when the request completes. @@ -177,6 +200,42 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6 } } +// IncrementAccountWaitCount increments the wait queue counter for an account. +func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { + if s.cache == nil { + return true, nil + } + + result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait) + if err != nil { + log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err) + return true, nil + } + return result, nil +} + +// DecrementAccountWaitCount decrements the wait queue counter for an account. +func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) { + if s.cache == nil { + return + } + + bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil { + log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err) + } +} + +// GetAccountWaitingCount gets current wait queue count for an account. +func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + if s.cache == nil { + return 0, nil + } + return s.cache.GetAccountWaitingCount(ctx, accountID) +} + // CalculateMaxWait calculates the maximum wait queue size for a user // maxWait = userConcurrency + defaultExtraWaitSlots func CalculateMaxWait(userConcurrency int) int { @@ -186,6 +245,57 @@ func CalculateMaxWait(userConcurrency int) int { return userConcurrency + defaultExtraWaitSlots } +// GetAccountsLoadBatch returns load info for multiple accounts. +func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + if s.cache == nil { + return map[int64]*AccountLoadInfo{}, nil + } + return s.cache.GetAccountsLoadBatch(ctx, accounts) +} + +// CleanupExpiredAccountSlots removes expired slots for one account (background task). +func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { + if s.cache == nil { + return nil + } + return s.cache.CleanupExpiredAccountSlots(ctx, accountID) +} + +// StartSlotCleanupWorker starts a background cleanup worker for expired account slots. +func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) { + if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 { + return + } + + runCleanup := func() { + listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + accounts, err := accountRepo.ListSchedulable(listCtx) + cancel() + if err != nil { + log.Printf("Warning: list schedulable accounts failed: %v", err) + return + } + for _, account := range accounts { + accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second) + err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID) + accountCancel() + if err != nil { + log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err) + } + } + } + + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + runCleanup() + for range ticker.C { + runCleanup() + } + }() +} + // GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts // Returns a map of accountID -> current concurrency count func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index d779bcfa..e1b61632 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -261,6 +261,34 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户") } +func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey}, + {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户") +} + // TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户 func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) { ctx := context.Background() @@ -576,6 +604,32 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) { func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { ctx := context.Background() + t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) { + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey}, + {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户") + }) + t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) { repo := &mockAccountRepoForPlatform{ accounts: []Account{ diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index d542e9c2..6c45ff0f 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -13,6 +13,7 @@ import ( "log" "net/http" "regexp" + "sort" "strings" "time" @@ -66,6 +67,20 @@ type GatewayCache interface { RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error } +type AccountWaitPlan struct { + AccountID int64 + MaxConcurrency int + Timeout time.Duration + MaxWaiting int +} + +type AccountSelectionResult struct { + Account *Account + Acquired bool + ReleaseFunc func() + WaitPlan *AccountWaitPlan // nil means no wait allowed +} + // ClaudeUsage 表示Claude API返回的usage信息 type ClaudeUsage struct { InputTokens int `json:"input_tokens"` @@ -108,6 +123,7 @@ type GatewayService struct { identityService *IdentityService httpUpstream HTTPUpstream deferredService *DeferredService + concurrencyService *ConcurrencyService } // NewGatewayService creates a new GatewayService @@ -119,6 +135,7 @@ func NewGatewayService( userSubRepo UserSubscriptionRepository, cache GatewayCache, cfg *config.Config, + concurrencyService *ConcurrencyService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, @@ -134,6 +151,7 @@ func NewGatewayService( userSubRepo: userSubRepo, cache: cache, cfg: cfg, + concurrencyService: concurrencyService, billingService: billingService, rateLimitService: rateLimitService, billingCacheService: billingCacheService, @@ -183,6 +201,14 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { return "" } +// BindStickySession sets session -> account binding with standard TTL. +func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { + if sessionHash == "" || accountID <= 0 { + return nil + } + return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL) +} + func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { if parsed == nil { return "" @@ -332,8 +358,360 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) } +// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. +func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { + cfg := s.schedulingConfig() + var stickyAccountID int64 + if sessionHash != "" && s.cache != nil { + if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil { + stickyAccountID = accountID + } + } + if s.concurrencyService == nil || !cfg.LoadBatchEnabled { + account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) + if err != nil { + return nil, err + } + result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) + if err == nil && result.Acquired { + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + } + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil + } + + platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID) + if err != nil { + return nil, err + } + preferOAuth := platform == PlatformGemini + + accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + if err != nil { + return nil, err + } + if len(accounts) == 0 { + return nil, errors.New("no available accounts") + } + + isExcluded := func(accountID int64) bool { + if excludedIDs == nil { + return false + } + _, excluded := excludedIDs[accountID] + return excluded + } + + // ============ Layer 1: 粘性会话优先 ============ + if sessionHash != "" { + accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) + if err == nil && accountID > 0 && !isExcluded(accountID) { + account, err := s.accountRepo.GetByID(ctx, accountID) + if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) && + account.IsSchedulable() && + (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if err == nil && result.Acquired { + _ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + } + } + } + + // ============ Layer 2: 负载感知选择 ============ + candidates := make([]*Account, 0, len(accounts)) + for i := range accounts { + acc := &accounts[i] + if isExcluded(acc.ID) { + continue + } + if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { + continue + } + if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { + continue + } + candidates = append(candidates, acc) + } + + if len(candidates) == 0 { + return nil, errors.New("no available accounts") + } + + accountLoads := make([]AccountWithConcurrency, 0, len(candidates)) + for _, acc := range candidates { + accountLoads = append(accountLoads, AccountWithConcurrency{ + ID: acc.ID, + MaxConcurrency: acc.Concurrency, + }) + } + + loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) + if err != nil { + if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok { + return result, nil + } + } else { + type accountWithLoad struct { + account *Account + loadInfo *AccountLoadInfo + } + var available []accountWithLoad + for _, acc := range candidates { + loadInfo := loadMap[acc.ID] + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: acc.ID} + } + if loadInfo.LoadRate < 100 { + available = append(available, accountWithLoad{ + account: acc, + loadInfo: loadInfo, + }) + } + } + + if len(available) > 0 { + sort.SliceStable(available, func(i, j int) bool { + a, b := available[i], available[j] + if a.account.Priority != b.account.Priority { + return a.account.Priority < b.account.Priority + } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return a.loadInfo.LoadRate < b.loadInfo.LoadRate + } + switch { + case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: + return true + case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: + return false + case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: + if preferOAuth && a.account.Type != b.account.Type { + return a.account.Type == AccountTypeOAuth + } + return false + default: + return a.account.LastUsedAt.Before(*b.account.LastUsedAt) + } + }) + + for _, item := range available { + result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) + if err == nil && result.Acquired { + if sessionHash != "" { + _ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL) + } + return &AccountSelectionResult{ + Account: item.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + } + } + } + + // ============ Layer 3: 兜底排队 ============ + sortAccountsByPriorityAndLastUsed(candidates, preferOAuth) + for _, acc := range candidates { + return &AccountSelectionResult{ + Account: acc, + WaitPlan: &AccountWaitPlan{ + AccountID: acc.ID, + MaxConcurrency: acc.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil + } + return nil, errors.New("no available accounts") +} + +func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { + ordered := append([]*Account(nil), candidates...) + sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) + + for _, acc := range ordered { + result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) + if err == nil && result.Acquired { + if sessionHash != "" { + _ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL) + } + return &AccountSelectionResult{ + Account: acc, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, true + } + } + + return nil, false +} + +func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { + if s.cfg != nil { + return s.cfg.Gateway.Scheduling + } + return config.GatewaySchedulingConfig{ + StickySessionMaxWaiting: 3, + StickySessionWaitTimeout: 45 * time.Second, + FallbackWaitTimeout: 30 * time.Second, + FallbackMaxWaiting: 100, + LoadBatchEnabled: true, + SlotCleanupInterval: 30 * time.Second, + } +} + +func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) { + forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) + if hasForcePlatform && forcePlatform != "" { + return forcePlatform, true, nil + } + if groupID != nil { + group, err := s.groupRepo.GetByID(ctx, *groupID) + if err != nil { + return "", false, fmt.Errorf("get group failed: %w", err) + } + return group.Platform, false, nil + } + return PlatformAnthropic, false, nil +} + +func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { + useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform + if useMixed { + platforms := []string{platform, PlatformAntigravity} + var accounts []Account + var err error + if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) + } else { + accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) + } + if err != nil { + return nil, useMixed, err + } + filtered := make([]Account, 0, len(accounts)) + for _, acc := range accounts { + if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { + continue + } + filtered = append(filtered, acc) + } + return filtered, useMixed, nil + } + + var accounts []Account + var err error + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + } else if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) + if err == nil && len(accounts) == 0 && hasForcePlatform { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + } + } else { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + } + if err != nil { + return nil, useMixed, err + } + return accounts, useMixed, nil +} + +func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool { + if account == nil { + return false + } + if useMixed { + if account.Platform == platform { + return true + } + return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() + } + return account.Platform == platform +} + +func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { + if s.concurrencyService == nil { + return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil + } + return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) +} + +func sortAccountsByPriority(accounts []*Account) { + sort.SliceStable(accounts, func(i, j int) bool { + return accounts[i].Priority < accounts[j].Priority + }) +} + +func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { + sort.SliceStable(accounts, func(i, j int) bool { + a, b := accounts[i], accounts[j] + if a.Priority != b.Priority { + return a.Priority < b.Priority + } + switch { + case a.LastUsedAt == nil && b.LastUsedAt != nil: + return true + case a.LastUsedAt != nil && b.LastUsedAt == nil: + return false + case a.LastUsedAt == nil && b.LastUsedAt == nil: + if preferOAuth && a.Type != b.Type { + return a.Type == AccountTypeOAuth + } + return false + default: + return a.LastUsedAt.Before(*b.LastUsedAt) + } + }) +} + // selectAccountForModelWithPlatform 选择单平台账户(完全隔离) func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { + preferOAuth := platform == PlatformGemini // 1. 查询粘性会话 if sessionHash != "" { accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) @@ -389,7 +767,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, case acc.LastUsedAt != nil && selected.LastUsedAt == nil: // keep selected (never used is preferred) case acc.LastUsedAt == nil && selected.LastUsedAt == nil: - // keep selected (both never used) + if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { + selected = acc + } default: if acc.LastUsedAt.Before(*selected.LastUsedAt) { selected = acc @@ -419,6 +799,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, // 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户 func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) { platforms := []string{nativePlatform, PlatformAntigravity} + preferOAuth := nativePlatform == PlatformGemini // 1. 查询粘性会话 if sessionHash != "" { @@ -478,7 +859,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g case acc.LastUsedAt != nil && selected.LastUsedAt == nil: // keep selected (never used is preferred) case acc.LastUsedAt == nil && selected.LastUsedAt == nil: - // keep selected (both never used) + if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { + selected = acc + } default: if acc.LastUsedAt.Before(*selected.LastUsedAt) { selected = acc diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 84e98679..f8eb29bd 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -13,6 +13,7 @@ import ( "log" "net/http" "regexp" + "sort" "strconv" "strings" "time" @@ -80,6 +81,7 @@ type OpenAIGatewayService struct { userSubRepo UserSubscriptionRepository cache GatewayCache cfg *config.Config + concurrencyService *ConcurrencyService billingService *BillingService rateLimitService *RateLimitService billingCacheService *BillingCacheService @@ -95,6 +97,7 @@ func NewOpenAIGatewayService( userSubRepo UserSubscriptionRepository, cache GatewayCache, cfg *config.Config, + concurrencyService *ConcurrencyService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, @@ -108,6 +111,7 @@ func NewOpenAIGatewayService( userSubRepo: userSubRepo, cache: cache, cfg: cfg, + concurrencyService: concurrencyService, billingService: billingService, rateLimitService: rateLimitService, billingCacheService: billingCacheService, @@ -126,6 +130,14 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string { return hex.EncodeToString(hash[:]) } +// BindStickySession sets session -> account binding with standard TTL. +func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { + if sessionHash == "" || accountID <= 0 { + return nil + } + return s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, accountID, openaiStickySessionTTL) +} + // SelectAccount selects an OpenAI account with sticky session support func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { return s.SelectAccountForModel(ctx, groupID, sessionHash, "") @@ -218,6 +230,254 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C return selected, nil } +// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. +func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { + cfg := s.schedulingConfig() + var stickyAccountID int64 + if sessionHash != "" && s.cache != nil { + if accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash); err == nil { + stickyAccountID = accountID + } + } + if s.concurrencyService == nil || !cfg.LoadBatchEnabled { + account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) + if err != nil { + return nil, err + } + result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) + if err == nil && result.Acquired { + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + } + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil + } + + accounts, err := s.listSchedulableAccounts(ctx, groupID) + if err != nil { + return nil, err + } + if len(accounts) == 0 { + return nil, errors.New("no available accounts") + } + + isExcluded := func(accountID int64) bool { + if excludedIDs == nil { + return false + } + _, excluded := excludedIDs[accountID] + return excluded + } + + // ============ Layer 1: Sticky session ============ + if sessionHash != "" { + accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash) + if err == nil && accountID > 0 && !isExcluded(accountID) { + account, err := s.accountRepo.GetByID(ctx, accountID) + if err == nil && account.IsSchedulable() && account.IsOpenAI() && + (requestedModel == "" || account.IsModelSupported(requestedModel)) { + result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if err == nil && result.Acquired { + _ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + } + } + } + + // ============ Layer 2: Load-aware selection ============ + candidates := make([]*Account, 0, len(accounts)) + for i := range accounts { + acc := &accounts[i] + if isExcluded(acc.ID) { + continue + } + if requestedModel != "" && !acc.IsModelSupported(requestedModel) { + continue + } + candidates = append(candidates, acc) + } + + if len(candidates) == 0 { + return nil, errors.New("no available accounts") + } + + accountLoads := make([]AccountWithConcurrency, 0, len(candidates)) + for _, acc := range candidates { + accountLoads = append(accountLoads, AccountWithConcurrency{ + ID: acc.ID, + MaxConcurrency: acc.Concurrency, + }) + } + + loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) + if err != nil { + ordered := append([]*Account(nil), candidates...) + sortAccountsByPriorityAndLastUsed(ordered, false) + for _, acc := range ordered { + result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) + if err == nil && result.Acquired { + if sessionHash != "" { + _ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, acc.ID, openaiStickySessionTTL) + } + return &AccountSelectionResult{ + Account: acc, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + } + } else { + type accountWithLoad struct { + account *Account + loadInfo *AccountLoadInfo + } + var available []accountWithLoad + for _, acc := range candidates { + loadInfo := loadMap[acc.ID] + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: acc.ID} + } + if loadInfo.LoadRate < 100 { + available = append(available, accountWithLoad{ + account: acc, + loadInfo: loadInfo, + }) + } + } + + if len(available) > 0 { + sort.SliceStable(available, func(i, j int) bool { + a, b := available[i], available[j] + if a.account.Priority != b.account.Priority { + return a.account.Priority < b.account.Priority + } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return a.loadInfo.LoadRate < b.loadInfo.LoadRate + } + switch { + case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: + return true + case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: + return false + case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: + return false + default: + return a.account.LastUsedAt.Before(*b.account.LastUsedAt) + } + }) + + for _, item := range available { + result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) + if err == nil && result.Acquired { + if sessionHash != "" { + _ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL) + } + return &AccountSelectionResult{ + Account: item.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + } + } + } + + // ============ Layer 3: Fallback wait ============ + sortAccountsByPriorityAndLastUsed(candidates, false) + for _, acc := range candidates { + return &AccountSelectionResult{ + Account: acc, + WaitPlan: &AccountWaitPlan{ + AccountID: acc.ID, + MaxConcurrency: acc.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil + } + + return nil, errors.New("no available accounts") +} + +func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) { + var accounts []Account + var err error + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) + } else if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) + } else { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) + } + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + return accounts, nil +} + +func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { + if s.concurrencyService == nil { + return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil + } + return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) +} + +func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig { + if s.cfg != nil { + return s.cfg.Gateway.Scheduling + } + return config.GatewaySchedulingConfig{ + StickySessionMaxWaiting: 3, + StickySessionWaitTimeout: 45 * time.Second, + FallbackWaitTimeout: 30 * time.Second, + FallbackMaxWaiting: 100, + LoadBatchEnabled: true, + SlotCleanupInterval: 30 * time.Second, + } +} + // GetAccessToken gets the access token for an OpenAI account func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 81e01d47..a202ccf2 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -73,6 +73,15 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh return svc } +// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker. +func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService { + svc := NewConcurrencyService(cache) + if cfg != nil { + svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval) + } + return svc +} + // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services @@ -107,7 +116,7 @@ var ProviderSet = wire.NewSet( ProvideEmailQueueService, NewTurnstileService, NewSubscriptionService, - NewConcurrencyService, + ProvideConcurrencyService, NewIdentityService, NewCRSSyncService, ProvideUpdateService,