diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 1adabefe..83cba823 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -99,7 +99,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream) accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) - concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) + concurrencyService := service.NewConcurrencyService(concurrencyCache) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) oAuthHandler := admin.NewOAuthHandler(oAuthService) @@ -127,10 +127,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { identityService := service.NewIdentityService(identityCache) timingWheelService := service.ProvideTimingWheelService() deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 7927fec5..aeeddcb4 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -3,7 +3,6 @@ package config import ( "fmt" "strings" - "time" "github.com/spf13/viper" ) @@ -120,37 +119,6 @@ type GatewayConfig struct { // ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟) // 应大于最长 LLM 请求时间,防止请求完成前槽位过期 ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` - - // 是否记录上游错误响应体摘要(避免输出请求内容) - LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"` - // 上游错误响应体记录最大字节数(超过会截断) - LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"` - - // API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容) - InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"` - - // 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义) - FailoverOn400 bool `mapstructure:"failover_on_400"` - - // Scheduling: 账号调度相关配置 - Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"` -} - -// GatewaySchedulingConfig accounts scheduling configuration. -type GatewaySchedulingConfig struct { - // 粘性会话排队配置 - StickySessionMaxWaiting int `mapstructure:"sticky_session_max_waiting"` - StickySessionWaitTimeout time.Duration `mapstructure:"sticky_session_wait_timeout"` - - // 兜底排队配置 - FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"` - FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"` - - // 负载计算 - LoadBatchEnabled bool `mapstructure:"load_batch_enabled"` - - // 过期槽位清理周期(0 表示禁用) - SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"` } func (s *ServerConfig) Address() string { @@ -345,10 +313,6 @@ func setDefaults() { // Gateway viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久 - viper.SetDefault("gateway.log_upstream_error_body", false) - viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048) - viper.SetDefault("gateway.inject_beta_for_apikey", false) - viper.SetDefault("gateway.failover_on_400", false) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) // HTTP 上游连接池配置(针对 5000+ 并发用户优化) @@ -359,12 +323,6 @@ func setDefaults() { viper.SetDefault("gateway.max_upstream_clients", 5000) viper.SetDefault("gateway.client_idle_ttl_seconds", 900) viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求) - viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) - viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second) - viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second) - viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) - viper.SetDefault("gateway.scheduling.load_batch_enabled", true) - viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) // TokenRefresh viper.SetDefault("token_refresh.enabled", true) @@ -453,21 +411,6 @@ func (c *Config) Validate() error { if c.Gateway.ConcurrencySlotTTLMinutes <= 0 { return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive") } - if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 { - return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive") - } - if c.Gateway.Scheduling.StickySessionWaitTimeout <= 0 { - return fmt.Errorf("gateway.scheduling.sticky_session_wait_timeout must be positive") - } - if c.Gateway.Scheduling.FallbackWaitTimeout <= 0 { - return fmt.Errorf("gateway.scheduling.fallback_wait_timeout must be positive") - } - if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 { - return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive") - } - if c.Gateway.Scheduling.SlotCleanupInterval < 0 { - return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative") - } return nil } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 6e722a54..1f1becb8 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1,11 +1,6 @@ package config -import ( - "testing" - "time" - - "github.com/spf13/viper" -) +import "testing" func TestNormalizeRunMode(t *testing.T) { tests := []struct { @@ -26,45 +21,3 @@ func TestNormalizeRunMode(t *testing.T) { } } } - -func TestLoadDefaultSchedulingConfig(t *testing.T) { - viper.Reset() - - cfg, err := Load() - if err != nil { - t.Fatalf("Load() error: %v", err) - } - - if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 { - t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting) - } - if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 45*time.Second { - t.Fatalf("StickySessionWaitTimeout = %v, want 45s", cfg.Gateway.Scheduling.StickySessionWaitTimeout) - } - if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second { - t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout) - } - if cfg.Gateway.Scheduling.FallbackMaxWaiting != 100 { - t.Fatalf("FallbackMaxWaiting = %d, want 100", cfg.Gateway.Scheduling.FallbackMaxWaiting) - } - if !cfg.Gateway.Scheduling.LoadBatchEnabled { - t.Fatalf("LoadBatchEnabled = false, want true") - } - if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second { - t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval) - } -} - -func TestLoadSchedulingConfigFromEnv(t *testing.T) { - viper.Reset() - t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5") - - cfg, err := Load() - if err != nil { - t.Fatalf("Load() error: %v", err) - } - - if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 5 { - t.Fatalf("StickySessionMaxWaiting = %d, want 5", cfg.Gateway.Scheduling.StickySessionMaxWaiting) - } -} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 70b42ffe..a2f833ff 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -141,10 +141,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } else if apiKey.Group != nil { platform = apiKey.Group.Platform } - sessionKey := sessionHash - if platform == service.PlatformGemini && sessionHash != "" { - sessionKey = "gemini:" + sessionHash - } if platform == service.PlatformGemini { const maxAccountSwitches = 3 @@ -153,7 +149,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { lastFailoverStatus := 0 for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs) + account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) if err != nil { if len(failedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -162,13 +158,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) return } - account := selection.Account // 检查预热请求拦截(在账号选择后、转发前检查) if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { - if selection.Acquired && selection.ReleaseFunc != nil { - selection.ReleaseFunc() - } if reqStream { sendMockWarmupStream(c, reqModel) } else { @@ -178,46 +170,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 3. 获取账号并发槽位 - accountReleaseFunc := selection.ReleaseFunc - var accountWaitRelease func() - if !selection.Acquired { - if selection.WaitPlan == nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) - return - } - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) - if err != nil { - log.Printf("Increment account wait count failed: %v", err) - } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } else { - // Only set release function if increment succeeded - accountWaitRelease = func() { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - } - - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - reqStream, - &streamStarted, - ) - if err != nil { - if accountWaitRelease != nil { - accountWaitRelease() - } - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) - } + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return } // 转发请求 - 根据账号平台分流 @@ -230,9 +187,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } - if accountWaitRelease != nil { - accountWaitRelease() - } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { @@ -277,7 +231,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { for { // 选择支持该模型的账号 - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs) + account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) if err != nil { if len(failedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -286,13 +240,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) return } - account := selection.Account // 检查预热请求拦截(在账号选择后、转发前检查) if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { - if selection.Acquired && selection.ReleaseFunc != nil { - selection.ReleaseFunc() - } if reqStream { sendMockWarmupStream(c, reqModel) } else { @@ -302,46 +252,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 3. 获取账号并发槽位 - accountReleaseFunc := selection.ReleaseFunc - var accountWaitRelease func() - if !selection.Acquired { - if selection.WaitPlan == nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) - return - } - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) - if err != nil { - log.Printf("Increment account wait count failed: %v", err) - } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } else { - // Only set release function if increment succeeded - accountWaitRelease = func() { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - } - - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - reqStream, - &streamStarted, - ) - if err != nil { - if accountWaitRelease != nil { - accountWaitRelease() - } - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) - } + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return } // 转发请求 - 根据账号平台分流 @@ -354,9 +269,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } - if accountWaitRelease != nil { - accountWaitRelease() - } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 4e049dbb..4c7bd0f0 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -83,16 +83,6 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64 h.concurrencyService.DecrementWaitCount(ctx, userID) } -// IncrementAccountWaitCount increments the wait count for an account -func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { - return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait) -} - -// DecrementAccountWaitCount decrements the wait count for an account -func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) { - h.concurrencyService.DecrementAccountWaitCount(ctx, accountID) -} - // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // For streaming requests, sends ping events during the wait. // streamStarted is updated if streaming response has begun. @@ -136,12 +126,7 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. // streamStarted pointer is updated when streaming begins (for proper error handling by caller). func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { - return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted) -} - -// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout. -func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { - ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) + ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait) defer cancel() // Determine if ping is needed (streaming + ping format defined) @@ -215,11 +200,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType } } -// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping). -func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { - return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted) -} - // nextBackoff 计算下一次退避时间 // 性能优化:使用指数退避 + 随机抖动,避免惊群效应 // current: 当前退避时间 diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 93ab23c9..4e99e00d 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -197,17 +197,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 3) select account (sticky session based on request body) parsedReq, _ := service.ParseGatewayRequest(body) sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) - sessionKey := sessionHash - if sessionHash != "" { - sessionKey = "gemini:" + sessionHash - } const maxAccountSwitches = 3 switchCount := 0 failedAccountIDs := make(map[int64]struct{}) lastFailoverStatus := 0 for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs) + account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs) if err != nil { if len(failedAccountIDs) == 0 { googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) @@ -216,48 +212,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { handleGeminiFailoverExhausted(c, lastFailoverStatus) return } - account := selection.Account // 4) account concurrency slot - accountReleaseFunc := selection.ReleaseFunc - var accountWaitRelease func() - if !selection.Acquired { - if selection.WaitPlan == nil { - googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts") - return - } - canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) - if err != nil { - log.Printf("Increment account wait count failed: %v", err) - } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) - googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") - return - } else { - // Only set release function if increment succeeded - accountWaitRelease = func() { - geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - } - - accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - stream, - &streamStarted, - ) - if err != nil { - if accountWaitRelease != nil { - accountWaitRelease() - } - googleError(c, http.StatusTooManyRequests, err.Error()) - return - } - if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) - } + accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted) + if err != nil { + googleError(c, http.StatusTooManyRequests, err.Error()) + return } // 5) forward (根据平台分流) @@ -270,9 +230,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } - if accountWaitRelease != nil { - accountWaitRelease() - } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 9931052d..7c9934c6 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -146,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { for { // Select account supporting the requested model log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel) - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) + account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) if err != nil { log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) if len(failedAccountIDs) == 0 { @@ -156,50 +156,14 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) return } - account := selection.Account log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) // 3. Acquire account concurrency slot - accountReleaseFunc := selection.ReleaseFunc - var accountWaitRelease func() - if !selection.Acquired { - if selection.WaitPlan == nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) - return - } - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) - if err != nil { - log.Printf("Increment account wait count failed: %v", err) - } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } else { - // Only set release function if increment succeeded - accountWaitRelease = func() { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - } - - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - reqStream, - &streamStarted, - ) - if err != nil { - if accountWaitRelease != nil { - accountWaitRelease() - } - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) - } + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return } // Forward request @@ -207,9 +171,6 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } - if accountWaitRelease != nil { - accountWaitRelease() - } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index 34e6b1f4..01b805cd 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -54,9 +54,6 @@ type CustomToolSpec struct { InputSchema map[string]any `json:"input_schema"` } -// ClaudeCustomToolSpec 兼容旧命名(MCP custom 工具规格) -type ClaudeCustomToolSpec = CustomToolSpec - // SystemBlock system prompt 数组形式的元素 type SystemBlock struct { Type string `json:"type"` diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 83b87a32..e0b5b886 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -14,16 +14,13 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st // 用于存储 tool_use id -> name 映射 toolIDToName := make(map[string]string) + // 检测是否启用 thinking + isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" + // 只有 Gemini 模型支持 dummy thought workaround // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures allowDummyThought := strings.HasPrefix(mappedModel, "gemini-") - // 检测是否启用 thinking - requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" - // 为避免 Claude 模型的 thought signature/消息块约束导致 400(上游要求 thinking 块开头等), - // 非 Gemini 模型默认不启用 thinking(除非未来支持完整签名链路)。 - isThinkingEnabled := requestedThinkingEnabled && allowDummyThought - // 1. 构建 contents contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought) if err != nil { @@ -34,15 +31,7 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model) // 3. 构建 generationConfig - reqForGen := claudeReq - if requestedThinkingEnabled && !allowDummyThought { - log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel) - // shallow copy to avoid mutating caller's request - clone := *claudeReq - clone.Thinking = nil - reqForGen = &clone - } - generationConfig := buildGenerationConfig(reqForGen) + generationConfig := buildGenerationConfig(claudeReq) // 4. 构建 tools tools := buildTools(claudeReq.Tools) @@ -159,9 +148,8 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT if !hasThoughtPart && len(parts) > 0 { // 在开头添加 dummy thinking block parts = append([]GeminiPart{{ - Text: "Thinking...", - Thought: true, - ThoughtSignature: dummyThoughtSignature, + Text: "Thinking...", + Thought: true, }}, parts...) } } @@ -183,34 +171,6 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT // 参考: https://ai.google.dev/gemini-api/docs/thought-signatures const dummyThoughtSignature = "skip_thought_signature_validator" -// isValidThoughtSignature 验证 thought signature 是否有效 -// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节 -func isValidThoughtSignature(signature string) bool { - // 空字符串无效 - if signature == "" { - return false - } - - // signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节) - // 参考 Claude API 文档和实际观察到的有效 signature - if len(signature) < 40 { - log.Printf("[Debug] Signature too short: len=%d", len(signature)) - return false - } - - // 检查是否是有效的 base64 字符 - // base64 字符集: A-Z, a-z, 0-9, +, /, = - for i, c := range signature { - if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') && - (c < '0' || c > '9') && c != '+' && c != '/' && c != '=' { - log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c) - return false - } - } - - return true -} - // buildParts 构建消息的 parts // allowDummyThought: 只有 Gemini 模型支持 dummy thought signature func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) { @@ -239,30 +199,22 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu } case "thinking": - if allowDummyThought { - // Gemini 模型可以使用 dummy signature - parts = append(parts, GeminiPart{ - Text: block.Thinking, - Thought: true, - ThoughtSignature: dummyThoughtSignature, - }) + part := GeminiPart{ + Text: block.Thinking, + Thought: true, + } + // 保留原有 signature(Claude 模型需要有效的 signature) + if block.Signature != "" { + part.ThoughtSignature = block.Signature + } else if !allowDummyThought { + // Claude 模型需要有效 signature,跳过无 signature 的 thinking block + log.Printf("Warning: skipping thinking block without signature for Claude model") continue + } else { + // Gemini 模型使用 dummy signature + part.ThoughtSignature = dummyThoughtSignature } - - // Claude 模型:仅在提供有效 signature 时保留 thinking block;否则跳过以避免上游校验失败。 - signature := strings.TrimSpace(block.Signature) - if signature == "" || signature == dummyThoughtSignature { - log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)") - continue - } - if !isValidThoughtSignature(signature) { - log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature)) - } - parts = append(parts, GeminiPart{ - Text: block.Thinking, - Thought: true, - ThoughtSignature: signature, - }) + parts = append(parts, part) case "image": if block.Source != nil && block.Source.Type == "base64" { @@ -287,9 +239,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ID: block.ID, }, } - // 只有 Gemini 模型使用 dummy signature - // Claude 模型不设置 signature(避免验证问题) - if allowDummyThought { + // 保留原有 signature,或对 Gemini 模型使用 dummy signature + if block.Signature != "" { + part.ThoughtSignature = block.Signature + } else if allowDummyThought { part.ThoughtSignature = dummyThoughtSignature } parts = append(parts, part) @@ -433,9 +386,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { // 普通工具 var funcDecls []GeminiFunctionDecl - for i, tool := range tools { + for _, tool := range tools { // 跳过无效工具名称 - if strings.TrimSpace(tool.Name) == "" { + if tool.Name == "" { log.Printf("Warning: skipping tool with empty name") continue } @@ -444,18 +397,10 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { var inputSchema map[string]any // 检查是否为 custom 类型工具 (MCP) - if tool.Type == "custom" { - if tool.Custom == nil || tool.Custom.InputSchema == nil { - log.Printf("[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema", tool.Name) - continue - } + if tool.Type == "custom" && tool.Custom != nil { + // Custom 格式: 从 custom 字段获取 description 和 input_schema description = tool.Custom.Description inputSchema = tool.Custom.InputSchema - - // 调试日志:记录 custom 工具的 schema - if schemaJSON, err := json.Marshal(inputSchema); err == nil { - log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON)) - } } else { // 标准格式: 从顶层字段获取 description = tool.Description @@ -464,6 +409,7 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { // 清理 JSON Schema params := cleanJSONSchema(inputSchema) + // 为 nil schema 提供默认值 if params == nil { params = map[string]any{ @@ -472,11 +418,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { } } - // 调试日志:记录清理后的 schema - if paramsJSON, err := json.Marshal(params); err == nil { - log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON)) - } - funcDecls = append(funcDecls, GeminiFunctionDecl{ Name: tool.Name, Description: description, @@ -538,64 +479,31 @@ func cleanJSONSchema(schema map[string]any) map[string]any { } // excludedSchemaKeys 不支持的 schema 字段 -// 基于 Claude API (Vertex AI) 的实际支持情况 -// 支持: type, description, enum, properties, required, additionalProperties, items -// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段 var excludedSchemaKeys = map[string]bool{ - // 元 schema 字段 - "$schema": true, - "$id": true, - "$ref": true, - - // 字符串验证(Gemini 不支持) - "minLength": true, - "maxLength": true, - "pattern": true, - - // 数字验证(Claude API 通过 Vertex AI 不支持这些字段) - "minimum": true, - "maximum": true, - "exclusiveMinimum": true, - "exclusiveMaximum": true, - "multipleOf": true, - - // 数组验证(Claude API 通过 Vertex AI 不支持这些字段) - "uniqueItems": true, - "minItems": true, - "maxItems": true, - - // 组合 schema(Gemini 不支持) - "oneOf": true, - "anyOf": true, - "allOf": true, - "not": true, - "if": true, - "then": true, - "else": true, - "$defs": true, - "definitions": true, - - // 对象验证(仅保留 properties/required/additionalProperties) - "minProperties": true, - "maxProperties": true, - "patternProperties": true, - "propertyNames": true, - "dependencies": true, - "dependentSchemas": true, - "dependentRequired": true, - - // 其他不支持的字段 - "default": true, - "const": true, - "examples": true, - "deprecated": true, - "readOnly": true, - "writeOnly": true, - "contentMediaType": true, - "contentEncoding": true, - - // Claude 特有字段 - "strict": true, + "$schema": true, + "$id": true, + "$ref": true, + "additionalProperties": true, + "minLength": true, + "maxLength": true, + "minItems": true, + "maxItems": true, + "uniqueItems": true, + "minimum": true, + "maximum": true, + "exclusiveMinimum": true, + "exclusiveMaximum": true, + "pattern": true, + "format": true, + "default": true, + "strict": true, + "const": true, + "examples": true, + "deprecated": true, + "readOnly": true, + "writeOnly": true, + "contentMediaType": true, + "contentEncoding": true, } // cleanSchemaValue 递归清理 schema 值 @@ -615,31 +523,6 @@ func cleanSchemaValue(value any) any { continue } - // 特殊处理 format 字段:只保留 Gemini 支持的 format 值 - if k == "format" { - if formatStr, ok := val.(string); ok { - // Gemini 只支持 date-time, date, time - if formatStr == "date-time" || formatStr == "date" || formatStr == "time" { - result[k] = val - } - // 其他 format 值直接跳过 - } - continue - } - - // 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象 - if k == "additionalProperties" { - if boolVal, ok := val.(bool); ok { - result[k] = boolVal - log.Printf("[Debug] additionalProperties is bool: %v", boolVal) - } else { - // 如果是 schema 对象,转换为 false(更安全的默认值) - result[k] = false - log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val) - } - continue - } - // 递归清理所有值 result[k] = cleanSchemaValue(val) } diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go deleted file mode 100644 index 56eebad0..00000000 --- a/backend/internal/pkg/antigravity/request_transformer_test.go +++ /dev/null @@ -1,179 +0,0 @@ -package antigravity - -import ( - "encoding/json" - "testing" -) - -// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理 -func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) { - tests := []struct { - name string - content string - allowDummyThought bool - expectedParts int - description string - }{ - { - name: "Claude model - skip thinking block without signature", - content: `[ - {"type": "text", "text": "Hello"}, - {"type": "thinking", "thinking": "Let me think...", "signature": ""}, - {"type": "text", "text": "World"} - ]`, - allowDummyThought: false, - expectedParts: 2, // 只有两个text block - description: "Claude模型应该跳过无signature的thinking block", - }, - { - name: "Claude model - keep thinking block with signature", - content: `[ - {"type": "text", "text": "Hello"}, - {"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"}, - {"type": "text", "text": "World"} - ]`, - allowDummyThought: false, - expectedParts: 3, // 三个block都保留 - description: "Claude模型应该保留有signature的thinking block", - }, - { - name: "Gemini model - use dummy signature", - content: `[ - {"type": "text", "text": "Hello"}, - {"type": "thinking", "thinking": "Let me think...", "signature": ""}, - {"type": "text", "text": "World"} - ]`, - allowDummyThought: true, - expectedParts: 3, // 三个block都保留,thinking使用dummy signature - description: "Gemini模型应该为无signature的thinking block使用dummy signature", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - toolIDToName := make(map[string]string) - parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought) - - if err != nil { - t.Fatalf("buildParts() error = %v", err) - } - - if len(parts) != tt.expectedParts { - t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts) - } - }) - } -} - -// TestBuildTools_CustomTypeTools 测试custom类型工具转换 -func TestBuildTools_CustomTypeTools(t *testing.T) { - tests := []struct { - name string - tools []ClaudeTool - expectedLen int - description string - }{ - { - name: "Standard tool format", - tools: []ClaudeTool{ - { - Name: "get_weather", - Description: "Get weather information", - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "location": map[string]any{"type": "string"}, - }, - }, - }, - }, - expectedLen: 1, - description: "标准工具格式应该正常转换", - }, - { - name: "Custom type tool (MCP format)", - tools: []ClaudeTool{ - { - Type: "custom", - Name: "mcp_tool", - Custom: &ClaudeCustomToolSpec{ - Description: "MCP tool description", - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "param": map[string]any{"type": "string"}, - }, - }, - }, - }, - }, - expectedLen: 1, - description: "Custom类型工具应该从Custom字段读取description和input_schema", - }, - { - name: "Mixed standard and custom tools", - tools: []ClaudeTool{ - { - Name: "standard_tool", - Description: "Standard tool", - InputSchema: map[string]any{"type": "object"}, - }, - { - Type: "custom", - Name: "custom_tool", - Custom: &ClaudeCustomToolSpec{ - Description: "Custom tool", - InputSchema: map[string]any{"type": "object"}, - }, - }, - }, - expectedLen: 1, // 返回一个GeminiToolDeclaration,包含2个function declarations - description: "混合标准和custom工具应该都能正确转换", - }, - { - name: "Invalid custom tool - nil Custom field", - tools: []ClaudeTool{ - { - Type: "custom", - Name: "invalid_custom", - // Custom 为 nil - }, - }, - expectedLen: 0, // 应该被跳过 - description: "Custom字段为nil的custom工具应该被跳过", - }, - { - name: "Invalid custom tool - nil InputSchema", - tools: []ClaudeTool{ - { - Type: "custom", - Name: "invalid_custom", - Custom: &ClaudeCustomToolSpec{ - Description: "Invalid", - // InputSchema 为 nil - }, - }, - }, - expectedLen: 0, // 应该被跳过 - description: "InputSchema为nil的custom工具应该被跳过", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := buildTools(tt.tools) - - if len(result) != tt.expectedLen { - t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen) - } - - // 验证function declarations存在 - if len(result) > 0 && result[0].FunctionDeclarations != nil { - if len(result[0].FunctionDeclarations) != len(tt.tools) { - t.Errorf("%s: got %d function declarations, want %d", - tt.description, len(result[0].FunctionDeclarations), len(tt.tools)) - } - } - }) - } -} diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index 0db3ed4a..97ad6c83 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -16,12 +16,6 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav // HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta) const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking -// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth) -const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming - -// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code) -const ApiKeyHaikuBetaHeader = BetaInterleavedThinking - // Claude Code 客户端默认请求头 var DefaultHeaders = map[string]string{ "User-Agent": "claude-cli/2.0.62 (external, cli)", diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index 35296497..9205230b 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -2,9 +2,7 @@ package repository import ( "context" - "errors" "fmt" - "strconv" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/redis/go-redis/v9" @@ -29,8 +27,6 @@ const ( userSlotKeyPrefix = "concurrency:user:" // 等待队列计数器格式: concurrency:wait:{userID} waitQueueKeyPrefix = "concurrency:wait:" - // 账号级等待队列计数器格式: wait:account:{accountID} - accountWaitKeyPrefix = "wait:account:" // 默认槽位过期时间(分钟),可通过配置覆盖 defaultSlotTTLMinutes = 15 @@ -116,112 +112,33 @@ var ( redis.call('EXPIRE', KEYS[1], ARGV[2]) end - return 1 - `) - - // incrementAccountWaitScript - account-level wait queue count - incrementAccountWaitScript = redis.NewScript(` - local current = redis.call('GET', KEYS[1]) - if current == false then - current = 0 - else - current = tonumber(current) - end - - if current >= tonumber(ARGV[1]) then - return 0 - end - - local newVal = redis.call('INCR', KEYS[1]) - - -- Only set TTL on first creation to avoid refreshing zombie data - if newVal == 1 then - redis.call('EXPIRE', KEYS[1], ARGV[2]) - end - - return 1 - `) + return 1 + `) // decrementWaitScript - same as before decrementWaitScript = redis.NewScript(` - local current = redis.call('GET', KEYS[1]) - if current ~= false and tonumber(current) > 0 then - redis.call('DECR', KEYS[1]) - end - return 1 - `) - - // getAccountsLoadBatchScript - batch load query (read-only) - // ARGV[1] = slot TTL (seconds, retained for compatibility) - // ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ... - getAccountsLoadBatchScript = redis.NewScript(` - local result = {} - - local i = 2 - while i <= #ARGV do - local accountID = ARGV[i] - local maxConcurrency = tonumber(ARGV[i + 1]) - - local slotKey = 'concurrency:account:' .. accountID - local currentConcurrency = redis.call('ZCARD', slotKey) - - local waitKey = 'wait:account:' .. accountID - local waitingCount = redis.call('GET', waitKey) - if waitingCount == false then - waitingCount = 0 - else - waitingCount = tonumber(waitingCount) - end - - local loadRate = 0 - if maxConcurrency > 0 then - loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency) - end - - table.insert(result, accountID) - table.insert(result, currentConcurrency) - table.insert(result, waitingCount) - table.insert(result, loadRate) - - i = i + 2 - end - - return result - `) - - // cleanupExpiredSlotsScript - remove expired slots - // KEYS[1] = concurrency:account:{accountID} - // ARGV[1] = TTL (seconds) - cleanupExpiredSlotsScript = redis.NewScript(` - local key = KEYS[1] - local ttl = tonumber(ARGV[1]) - local timeResult = redis.call('TIME') - local now = tonumber(timeResult[1]) - local expireBefore = now - ttl - return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) - `) + local current = redis.call('GET', KEYS[1]) + if current ~= false and tonumber(current) > 0 then + redis.call('DECR', KEYS[1]) + end + return 1 + `) ) type concurrencyCache struct { - rdb *redis.Client - slotTTLSeconds int // 槽位过期时间(秒) - waitQueueTTLSeconds int // 等待队列过期时间(秒) + rdb *redis.Client + slotTTLSeconds int // 槽位过期时间(秒) } // NewConcurrencyCache 创建并发控制缓存 // slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟 -// waitQueueTTLSeconds: 等待队列过期时间(秒),0 或负数使用 slot TTL -func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache { +func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.ConcurrencyCache { if slotTTLMinutes <= 0 { slotTTLMinutes = defaultSlotTTLMinutes } - if waitQueueTTLSeconds <= 0 { - waitQueueTTLSeconds = slotTTLMinutes * 60 - } return &concurrencyCache{ - rdb: rdb, - slotTTLSeconds: slotTTLMinutes * 60, - waitQueueTTLSeconds: waitQueueTTLSeconds, + rdb: rdb, + slotTTLSeconds: slotTTLMinutes * 60, } } @@ -238,10 +155,6 @@ func waitQueueKey(userID int64) string { return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) } -func accountWaitKey(accountID int64) string { - return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) -} - // Account slot operations func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { @@ -312,75 +225,3 @@ func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() return err } - -// Account wait queue operations - -func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { - key := accountWaitKey(accountID) - result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int() - if err != nil { - return false, err - } - return result == 1, nil -} - -func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { - key := accountWaitKey(accountID) - _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() - return err -} - -func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { - key := accountWaitKey(accountID) - val, err := c.rdb.Get(ctx, key).Int() - if err != nil && !errors.Is(err, redis.Nil) { - return 0, err - } - if errors.Is(err, redis.Nil) { - return 0, nil - } - return val, nil -} - -func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { - if len(accounts) == 0 { - return map[int64]*service.AccountLoadInfo{}, nil - } - - args := []any{c.slotTTLSeconds} - for _, acc := range accounts { - args = append(args, acc.ID, acc.MaxConcurrency) - } - - result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice() - if err != nil { - return nil, err - } - - loadMap := make(map[int64]*service.AccountLoadInfo) - for i := 0; i < len(result); i += 4 { - if i+3 >= len(result) { - break - } - - accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64) - currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1])) - waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2])) - loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3])) - - loadMap[accountID] = &service.AccountLoadInfo{ - AccountID: accountID, - CurrentConcurrency: currentConcurrency, - WaitingCount: waitingCount, - LoadRate: loadRate, - } - } - - return loadMap, nil -} - -func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { - key := accountSlotKey(accountID) - _, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result() - return err -} diff --git a/backend/internal/repository/concurrency_cache_benchmark_test.go b/backend/internal/repository/concurrency_cache_benchmark_test.go index 25697ab1..cafab9cb 100644 --- a/backend/internal/repository/concurrency_cache_benchmark_test.go +++ b/backend/internal/repository/concurrency_cache_benchmark_test.go @@ -22,7 +22,7 @@ func BenchmarkAccountConcurrency(b *testing.B) { _ = rdb.Close() }() - cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache) + cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache) ctx := context.Background() for _, size := range []int{10, 100, 1000} { diff --git a/backend/internal/repository/concurrency_cache_integration_test.go b/backend/internal/repository/concurrency_cache_integration_test.go index 5983c832..6a7c83f4 100644 --- a/backend/internal/repository/concurrency_cache_integration_test.go +++ b/backend/internal/repository/concurrency_cache_integration_test.go @@ -27,7 +27,7 @@ type ConcurrencyCacheSuite struct { func (s *ConcurrencyCacheSuite) SetupTest() { s.IntegrationRedisSuite.SetupTest() - s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds())) + s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes) } func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() { @@ -218,48 +218,6 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() { require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count") } -func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() { - accountID := int64(30) - waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) - - ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2) - require.NoError(s.T(), err, "IncrementAccountWaitCount 1") - require.True(s.T(), ok) - - ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2) - require.NoError(s.T(), err, "IncrementAccountWaitCount 2") - require.True(s.T(), ok) - - ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2) - require.NoError(s.T(), err, "IncrementAccountWaitCount 3") - require.False(s.T(), ok, "expected account wait increment over max to fail") - - ttl, err := s.rdb.TTL(s.ctx, waitKey).Result() - require.NoError(s.T(), err, "TTL account waitKey") - s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL) - - require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount") - - val, err := s.rdb.Get(s.ctx, waitKey).Int() - if !errors.Is(err, redis.Nil) { - require.NoError(s.T(), err, "Get waitKey") - } - require.Equal(s.T(), 1, val, "expected account wait count 1") -} - -func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() { - accountID := int64(301) - waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) - - require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key") - - val, err := s.rdb.Get(s.ctx, waitKey).Int() - if !errors.Is(err, redis.Nil) { - require.NoError(s.T(), err, "Get waitKey") - } - require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty") -} - func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() { // When no slots exist, GetAccountConcurrency should return 0 cur, err := s.cache.GetAccountConcurrency(s.ctx, 999) @@ -274,139 +232,6 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() { require.Equal(s.T(), 0, cur) } -func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() { - s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI") - // Setup: Create accounts with different load states - account1 := int64(100) - account2 := int64(101) - account3 := int64(102) - - // Account 1: 2/3 slots used, 1 waiting - ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1") - require.NoError(s.T(), err) - require.True(s.T(), ok) - ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2") - require.NoError(s.T(), err) - require.True(s.T(), ok) - ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5) - require.NoError(s.T(), err) - require.True(s.T(), ok) - - // Account 2: 1/2 slots used, 0 waiting - ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3") - require.NoError(s.T(), err) - require.True(s.T(), ok) - - // Account 3: 0/1 slots used, 0 waiting (idle) - - // Query batch load - accounts := []service.AccountWithConcurrency{ - {ID: account1, MaxConcurrency: 3}, - {ID: account2, MaxConcurrency: 2}, - {ID: account3, MaxConcurrency: 1}, - } - - loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts) - require.NoError(s.T(), err) - require.Len(s.T(), loadMap, 3) - - // Verify account1: (2 + 1) / 3 = 100% - load1 := loadMap[account1] - require.NotNil(s.T(), load1) - require.Equal(s.T(), account1, load1.AccountID) - require.Equal(s.T(), 2, load1.CurrentConcurrency) - require.Equal(s.T(), 1, load1.WaitingCount) - require.Equal(s.T(), 100, load1.LoadRate) - - // Verify account2: (1 + 0) / 2 = 50% - load2 := loadMap[account2] - require.NotNil(s.T(), load2) - require.Equal(s.T(), account2, load2.AccountID) - require.Equal(s.T(), 1, load2.CurrentConcurrency) - require.Equal(s.T(), 0, load2.WaitingCount) - require.Equal(s.T(), 50, load2.LoadRate) - - // Verify account3: (0 + 0) / 1 = 0% - load3 := loadMap[account3] - require.NotNil(s.T(), load3) - require.Equal(s.T(), account3, load3.AccountID) - require.Equal(s.T(), 0, load3.CurrentConcurrency) - require.Equal(s.T(), 0, load3.WaitingCount) - require.Equal(s.T(), 0, load3.LoadRate) -} - -func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() { - // Test with empty account list - loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{}) - require.NoError(s.T(), err) - require.Empty(s.T(), loadMap) -} - -func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() { - accountID := int64(200) - slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID) - - // Acquire 3 slots - ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1") - require.NoError(s.T(), err) - require.True(s.T(), ok) - ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2") - require.NoError(s.T(), err) - require.True(s.T(), ok) - ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3") - require.NoError(s.T(), err) - require.True(s.T(), ok) - - // Verify 3 slots exist - cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID) - require.NoError(s.T(), err) - require.Equal(s.T(), 3, cur) - - // Manually set old timestamps for req1 and req2 (simulate expired slots) - now := time.Now().Unix() - expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL - err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err() - require.NoError(s.T(), err) - err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err() - require.NoError(s.T(), err) - - // Run cleanup - err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID) - require.NoError(s.T(), err) - - // Verify only 1 slot remains (req3) - cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID) - require.NoError(s.T(), err) - require.Equal(s.T(), 1, cur) - - // Verify req3 still exists - members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result() - require.NoError(s.T(), err) - require.Len(s.T(), members, 1) - require.Equal(s.T(), "req3", members[0]) -} - -func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() { - accountID := int64(201) - - // Acquire 2 fresh slots - ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1") - require.NoError(s.T(), err) - require.True(s.T(), ok) - ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2") - require.NoError(s.T(), err) - require.True(s.T(), ok) - - // Run cleanup (should not remove anything) - err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID) - require.NoError(s.T(), err) - - // Verify both slots still exist - cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID) - require.NoError(s.T(), err) - require.Equal(s.T(), 2, cur) -} - func TestConcurrencyCacheSuite(t *testing.T) { suite.Run(t, new(ConcurrencyCacheSuite)) } diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 0d579b23..2de2d1de 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -15,14 +15,7 @@ import ( // ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数 // 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景 func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache { - waitTTLSeconds := int(cfg.Gateway.Scheduling.StickySessionWaitTimeout.Seconds()) - if cfg.Gateway.Scheduling.FallbackWaitTimeout > cfg.Gateway.Scheduling.StickySessionWaitTimeout { - waitTTLSeconds = int(cfg.Gateway.Scheduling.FallbackWaitTimeout.Seconds()) - } - if waitTTLSeconds <= 0 { - waitTTLSeconds = cfg.Gateway.ConcurrencySlotTTLMinutes * 60 - } - return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds) + return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes) } // ProviderSet is the Wire provider set for all repositories diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 5b3bf565..ae2976f8 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -358,15 +358,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return nil, fmt.Errorf("transform request: %w", err) } - // 调试:记录转换后的请求体(仅记录前 2000 字符) - if bodyJSON, err := json.Marshal(geminiBody); err == nil { - truncated := string(bodyJSON) - if len(truncated) > 2000 { - truncated = truncated[:2000] + "..." - } - log.Printf("[Debug] Transformed Gemini request: %s", truncated) - } - // 构建上游 action action := "generateContent" if claudeReq.Stream { diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index 65ef16db..b5229491 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -18,11 +18,6 @@ type ConcurrencyCache interface { ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) - // 账号等待队列(账号级) - IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) - DecrementAccountWaitCount(ctx context.Context, accountID int64) error - GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) - // 用户槽位管理 // 键格式: concurrency:user:{userID}(有序集合,成员为 requestID) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) @@ -32,12 +27,6 @@ type ConcurrencyCache interface { // 等待队列计数(只在首次创建时设置 TTL) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) DecrementWaitCount(ctx context.Context, userID int64) error - - // 批量负载查询(只读) - GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) - - // 清理过期槽位(后台任务) - CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error } // generateRequestID generates a unique request ID for concurrency slot tracking @@ -72,18 +61,6 @@ type AcquireResult struct { ReleaseFunc func() // Must be called when done (typically via defer) } -type AccountWithConcurrency struct { - ID int64 - MaxConcurrency int -} - -type AccountLoadInfo struct { - AccountID int64 - CurrentConcurrency int - WaitingCount int - LoadRate int // 0-100+ (percent) -} - // AcquireAccountSlot attempts to acquire a concurrency slot for an account. // If the account is at max concurrency, it waits until a slot is available or timeout. // Returns a release function that MUST be called when the request completes. @@ -200,42 +177,6 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6 } } -// IncrementAccountWaitCount increments the wait queue counter for an account. -func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { - if s.cache == nil { - return true, nil - } - - result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait) - if err != nil { - log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err) - return true, nil - } - return result, nil -} - -// DecrementAccountWaitCount decrements the wait queue counter for an account. -func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) { - if s.cache == nil { - return - } - - bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil { - log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err) - } -} - -// GetAccountWaitingCount gets current wait queue count for an account. -func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { - if s.cache == nil { - return 0, nil - } - return s.cache.GetAccountWaitingCount(ctx, accountID) -} - // CalculateMaxWait calculates the maximum wait queue size for a user // maxWait = userConcurrency + defaultExtraWaitSlots func CalculateMaxWait(userConcurrency int) int { @@ -245,57 +186,6 @@ func CalculateMaxWait(userConcurrency int) int { return userConcurrency + defaultExtraWaitSlots } -// GetAccountsLoadBatch returns load info for multiple accounts. -func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { - if s.cache == nil { - return map[int64]*AccountLoadInfo{}, nil - } - return s.cache.GetAccountsLoadBatch(ctx, accounts) -} - -// CleanupExpiredAccountSlots removes expired slots for one account (background task). -func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { - if s.cache == nil { - return nil - } - return s.cache.CleanupExpiredAccountSlots(ctx, accountID) -} - -// StartSlotCleanupWorker starts a background cleanup worker for expired account slots. -func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) { - if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 { - return - } - - runCleanup := func() { - listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - accounts, err := accountRepo.ListSchedulable(listCtx) - cancel() - if err != nil { - log.Printf("Warning: list schedulable accounts failed: %v", err) - return - } - for _, account := range accounts { - accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second) - err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID) - accountCancel() - if err != nil { - log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err) - } - } - } - - go func() { - ticker := time.NewTicker(interval) - defer ticker.Stop() - - runCleanup() - for range ticker.C { - runCleanup() - } - }() -} - // GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts // Returns a map of accountID -> current concurrency count func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 560c7767..d779bcfa 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -261,34 +261,6 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户") } -func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(t *testing.T) { - ctx := context.Background() - - repo := &mockAccountRepoForPlatform{ - accounts: []Account{ - {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey}, - {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, - }, - accountsByID: map[int64]*Account{}, - } - for i := range repo.accounts { - repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] - } - - cache := &mockGatewayCacheForPlatform{} - - svc := &GatewayService{ - accountRepo: repo, - cache: cache, - cfg: testConfig(), - } - - acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini) - require.NoError(t, err) - require.NotNil(t, acc) - require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户") -} - // TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户 func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) { ctx := context.Background() @@ -604,32 +576,6 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) { func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { ctx := context.Background() - t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) { - repo := &mockAccountRepoForPlatform{ - accounts: []Account{ - {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey}, - {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, - }, - accountsByID: map[int64]*Account{}, - } - for i := range repo.accounts { - repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] - } - - cache := &mockGatewayCacheForPlatform{} - - svc := &GatewayService{ - accountRepo: repo, - cache: cache, - cfg: testConfig(), - } - - acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini) - require.NoError(t, err) - require.NotNil(t, acc) - require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户") - }) - t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) { repo := &mockAccountRepoForPlatform{ accounts: []Account{ @@ -837,160 +783,3 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) { }) } } - -// mockConcurrencyService for testing -type mockConcurrencyService struct { - accountLoads map[int64]*AccountLoadInfo - accountWaitCounts map[int64]int - acquireResults map[int64]bool -} - -func (m *mockConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { - if m.accountLoads == nil { - return map[int64]*AccountLoadInfo{}, nil - } - result := make(map[int64]*AccountLoadInfo) - for _, acc := range accounts { - if load, ok := m.accountLoads[acc.ID]; ok { - result[acc.ID] = load - } else { - result[acc.ID] = &AccountLoadInfo{ - AccountID: acc.ID, - CurrentConcurrency: 0, - WaitingCount: 0, - LoadRate: 0, - } - } - } - return result, nil -} - -func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { - if m.accountWaitCounts == nil { - return 0, nil - } - return m.accountWaitCounts[accountID], nil -} - -// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection -func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { - ctx := context.Background() - - t.Run("禁用负载批量查询-降级到传统选择", func(t *testing.T) { - repo := &mockAccountRepoForPlatform{ - accounts: []Account{ - {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, - {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, - }, - accountsByID: map[int64]*Account{}, - } - for i := range repo.accounts { - repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] - } - - cache := &mockGatewayCacheForPlatform{} - - cfg := testConfig() - cfg.Gateway.Scheduling.LoadBatchEnabled = false - - svc := &GatewayService{ - accountRepo: repo, - cache: cache, - cfg: cfg, - concurrencyService: nil, // No concurrency service - } - - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) - require.NoError(t, err) - require.NotNil(t, result) - require.NotNil(t, result.Account) - require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号") - }) - - t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) { - repo := &mockAccountRepoForPlatform{ - accounts: []Account{ - {ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, - {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, - }, - accountsByID: map[int64]*Account{}, - } - for i := range repo.accounts { - repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] - } - - cache := &mockGatewayCacheForPlatform{} - - cfg := testConfig() - cfg.Gateway.Scheduling.LoadBatchEnabled = true - - svc := &GatewayService{ - accountRepo: repo, - cache: cache, - cfg: cfg, - concurrencyService: nil, - } - - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) - require.NoError(t, err) - require.NotNil(t, result) - require.NotNil(t, result.Account) - require.Equal(t, int64(2), result.Account.ID, "应选择优先级最高的账号") - }) - - t.Run("排除账号-不选择被排除的账号", func(t *testing.T) { - repo := &mockAccountRepoForPlatform{ - accounts: []Account{ - {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, - {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, - }, - accountsByID: map[int64]*Account{}, - } - for i := range repo.accounts { - repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] - } - - cache := &mockGatewayCacheForPlatform{} - - cfg := testConfig() - cfg.Gateway.Scheduling.LoadBatchEnabled = false - - svc := &GatewayService{ - accountRepo: repo, - cache: cache, - cfg: cfg, - concurrencyService: nil, - } - - excludedIDs := map[int64]struct{}{1: {}} - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs) - require.NoError(t, err) - require.NotNil(t, result) - require.NotNil(t, result.Account) - require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号") - }) - - t.Run("无可用账号-返回错误", func(t *testing.T) { - repo := &mockAccountRepoForPlatform{ - accounts: []Account{}, - accountsByID: map[int64]*Account{}, - } - - cache := &mockGatewayCacheForPlatform{} - - cfg := testConfig() - cfg.Gateway.Scheduling.LoadBatchEnabled = false - - svc := &GatewayService{ - accountRepo: repo, - cache: cache, - cfg: cfg, - concurrencyService: nil, - } - - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) - require.Error(t, err) - require.Nil(t, result) - require.Contains(t, err.Error(), "no available accounts") - }) -} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index cb60131b..d542e9c2 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -13,14 +13,12 @@ import ( "log" "net/http" "regexp" - "sort" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" - "github.com/tidwall/gjson" "github.com/tidwall/sjson" "github.com/gin-gonic/gin" @@ -68,20 +66,6 @@ type GatewayCache interface { RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error } -type AccountWaitPlan struct { - AccountID int64 - MaxConcurrency int - Timeout time.Duration - MaxWaiting int -} - -type AccountSelectionResult struct { - Account *Account - Acquired bool - ReleaseFunc func() - WaitPlan *AccountWaitPlan // nil means no wait allowed -} - // ClaudeUsage 表示Claude API返回的usage信息 type ClaudeUsage struct { InputTokens int `json:"input_tokens"` @@ -124,7 +108,6 @@ type GatewayService struct { identityService *IdentityService httpUpstream HTTPUpstream deferredService *DeferredService - concurrencyService *ConcurrencyService } // NewGatewayService creates a new GatewayService @@ -136,7 +119,6 @@ func NewGatewayService( userSubRepo UserSubscriptionRepository, cache GatewayCache, cfg *config.Config, - concurrencyService *ConcurrencyService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, @@ -152,7 +134,6 @@ func NewGatewayService( userSubRepo: userSubRepo, cache: cache, cfg: cfg, - concurrencyService: concurrencyService, billingService: billingService, rateLimitService: rateLimitService, billingCacheService: billingCacheService, @@ -202,14 +183,6 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { return "" } -// BindStickySession sets session -> account binding with standard TTL. -func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { - if sessionHash == "" || accountID <= 0 { - return nil - } - return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL) -} - func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { if parsed == nil { return "" @@ -359,354 +332,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) } -// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. -func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { - cfg := s.schedulingConfig() - var stickyAccountID int64 - if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil { - stickyAccountID = accountID - } - } - if s.concurrencyService == nil || !cfg.LoadBatchEnabled { - account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) - if err != nil { - return nil, err - } - result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) - if err == nil && result.Acquired { - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) - if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil - } - } - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil - } - - platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID) - if err != nil { - return nil, err - } - preferOAuth := platform == PlatformGemini - - accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) - if err != nil { - return nil, err - } - if len(accounts) == 0 { - return nil, errors.New("no available accounts") - } - - isExcluded := func(accountID int64) bool { - if excludedIDs == nil { - return false - } - _, excluded := excludedIDs[accountID] - return excluded - } - - // ============ Layer 1: 粘性会话优先 ============ - if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) - if err == nil && accountID > 0 && !isExcluded(accountID) { - account, err := s.accountRepo.GetByID(ctx, accountID) - if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) && - account.IsSchedulable() && - (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { - result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) - if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL) - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) - if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil - } - } - } - } - - // ============ Layer 2: 负载感知选择 ============ - candidates := make([]*Account, 0, len(accounts)) - for i := range accounts { - acc := &accounts[i] - if isExcluded(acc.ID) { - continue - } - if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { - continue - } - if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { - continue - } - candidates = append(candidates, acc) - } - - if len(candidates) == 0 { - return nil, errors.New("no available accounts") - } - - accountLoads := make([]AccountWithConcurrency, 0, len(candidates)) - for _, acc := range candidates { - accountLoads = append(accountLoads, AccountWithConcurrency{ - ID: acc.ID, - MaxConcurrency: acc.Concurrency, - }) - } - - loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) - if err != nil { - if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok { - return result, nil - } - } else { - type accountWithLoad struct { - account *Account - loadInfo *AccountLoadInfo - } - var available []accountWithLoad - for _, acc := range candidates { - loadInfo := loadMap[acc.ID] - if loadInfo == nil { - loadInfo = &AccountLoadInfo{AccountID: acc.ID} - } - if loadInfo.LoadRate < 100 { - available = append(available, accountWithLoad{ - account: acc, - loadInfo: loadInfo, - }) - } - } - - if len(available) > 0 { - sort.SliceStable(available, func(i, j int) bool { - a, b := available[i], available[j] - if a.account.Priority != b.account.Priority { - return a.account.Priority < b.account.Priority - } - if a.loadInfo.LoadRate != b.loadInfo.LoadRate { - return a.loadInfo.LoadRate < b.loadInfo.LoadRate - } - switch { - case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: - return true - case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: - return false - case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: - if preferOAuth && a.account.Type != b.account.Type { - return a.account.Type == AccountTypeOAuth - } - return false - default: - return a.account.LastUsedAt.Before(*b.account.LastUsedAt) - } - }) - - for _, item := range available { - result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) - if err == nil && result.Acquired { - if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL) - } - return &AccountSelectionResult{ - Account: item.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - } - } - } - - // ============ Layer 3: 兜底排队 ============ - sortAccountsByPriorityAndLastUsed(candidates, preferOAuth) - for _, acc := range candidates { - return &AccountSelectionResult{ - Account: acc, - WaitPlan: &AccountWaitPlan{ - AccountID: acc.ID, - MaxConcurrency: acc.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil - } - return nil, errors.New("no available accounts") -} - -func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { - ordered := append([]*Account(nil), candidates...) - sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) - - for _, acc := range ordered { - result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) - if err == nil && result.Acquired { - if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL) - } - return &AccountSelectionResult{ - Account: acc, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, true - } - } - - return nil, false -} - -func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { - if s.cfg != nil { - return s.cfg.Gateway.Scheduling - } - return config.GatewaySchedulingConfig{ - StickySessionMaxWaiting: 3, - StickySessionWaitTimeout: 45 * time.Second, - FallbackWaitTimeout: 30 * time.Second, - FallbackMaxWaiting: 100, - LoadBatchEnabled: true, - SlotCleanupInterval: 30 * time.Second, - } -} - -func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) { - forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) - if hasForcePlatform && forcePlatform != "" { - return forcePlatform, true, nil - } - if groupID != nil { - group, err := s.groupRepo.GetByID(ctx, *groupID) - if err != nil { - return "", false, fmt.Errorf("get group failed: %w", err) - } - return group.Platform, false, nil - } - return PlatformAnthropic, false, nil -} - -func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { - useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform - if useMixed { - platforms := []string{platform, PlatformAntigravity} - var accounts []Account - var err error - if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) - } else { - accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) - } - if err != nil { - return nil, useMixed, err - } - filtered := make([]Account, 0, len(accounts)) - for _, acc := range accounts { - if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { - continue - } - filtered = append(filtered, acc) - } - return filtered, useMixed, nil - } - - var accounts []Account - var err error - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) - } else if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) - if err == nil && len(accounts) == 0 && hasForcePlatform { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) - } - } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) - } - if err != nil { - return nil, useMixed, err - } - return accounts, useMixed, nil -} - -func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool { - if account == nil { - return false - } - if useMixed { - if account.Platform == platform { - return true - } - return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() - } - return account.Platform == platform -} - -func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { - if s.concurrencyService == nil { - return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil - } - return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) -} - -func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { - sort.SliceStable(accounts, func(i, j int) bool { - a, b := accounts[i], accounts[j] - if a.Priority != b.Priority { - return a.Priority < b.Priority - } - switch { - case a.LastUsedAt == nil && b.LastUsedAt != nil: - return true - case a.LastUsedAt != nil && b.LastUsedAt == nil: - return false - case a.LastUsedAt == nil && b.LastUsedAt == nil: - if preferOAuth && a.Type != b.Type { - return a.Type == AccountTypeOAuth - } - return false - default: - return a.LastUsedAt.Before(*b.LastUsedAt) - } - }) -} - // selectAccountForModelWithPlatform 选择单平台账户(完全隔离) func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { - preferOAuth := platform == PlatformGemini // 1. 查询粘性会话 if sessionHash != "" { accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) @@ -762,9 +389,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, case acc.LastUsedAt != nil && selected.LastUsedAt == nil: // keep selected (never used is preferred) case acc.LastUsedAt == nil && selected.LastUsedAt == nil: - if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { - selected = acc - } + // keep selected (both never used) default: if acc.LastUsedAt.Before(*selected.LastUsedAt) { selected = acc @@ -794,7 +419,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, // 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户 func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) { platforms := []string{nativePlatform, PlatformAntigravity} - preferOAuth := nativePlatform == PlatformGemini // 1. 查询粘性会话 if sessionHash != "" { @@ -854,9 +478,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g case acc.LastUsedAt != nil && selected.LastUsedAt == nil: // keep selected (never used is preferred) case acc.LastUsedAt == nil && selected.LastUsedAt == nil: - if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { - selected = acc - } + // keep selected (both never used) default: if acc.LastUsedAt.Before(*selected.LastUsedAt) { selected = acc @@ -1062,30 +684,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 处理错误响应(不可重试的错误) if resp.StatusCode >= 400 { - // 可选:对部分 400 触发 failover(默认关闭以保持语义) - if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { - respBody, readErr := io.ReadAll(resp.Body) - if readErr != nil { - // ReadAll failed, fall back to normal error handling without consuming the stream - return s.handleErrorResponse(ctx, resp, c, account) - } - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) - - if s.shouldFailoverOn400(respBody) { - if s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( - "Account %d: 400 error, attempting failover: %s", - account.ID, - truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), - ) - } else { - log.Printf("Account %d: 400 error, attempting failover", account.ID) - } - s.handleFailoverSideEffects(ctx, resp, account) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} - } - } return s.handleErrorResponse(ctx, resp, c, account) } @@ -1188,13 +786,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // 处理anthropic-beta header(OAuth账号需要特殊处理) if tokenType == "oauth" { req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) - } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { - // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) - if requestNeedsBetaFeatures(body) { - if beta := defaultApiKeyBetaHeader(body); beta != "" { - req.Header.Set("anthropic-beta", beta) - } - } } return req, nil @@ -1247,83 +838,6 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) return claude.DefaultBetaHeader } -func requestNeedsBetaFeatures(body []byte) bool { - tools := gjson.GetBytes(body, "tools") - if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 { - return true - } - if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") { - return true - } - return false -} - -func defaultApiKeyBetaHeader(body []byte) string { - modelID := gjson.GetBytes(body, "model").String() - if strings.Contains(strings.ToLower(modelID), "haiku") { - return claude.ApiKeyHaikuBetaHeader - } - return claude.ApiKeyBetaHeader -} - -func truncateForLog(b []byte, maxBytes int) string { - if maxBytes <= 0 { - maxBytes = 2048 - } - if len(b) > maxBytes { - b = b[:maxBytes] - } - s := string(b) - // 保持一行,避免污染日志格式 - s = strings.ReplaceAll(s, "\n", "\\n") - s = strings.ReplaceAll(s, "\r", "\\r") - return s -} - -func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { - // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。 - // 默认保守:无法识别则不切换。 - msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) - if msg == "" { - return false - } - - // 缺少/错误的 beta header:换账号/链路可能成功(尤其是混合调度时)。 - // 更精确匹配 beta 相关的兼容性问题,避免误触发切换。 - if strings.Contains(msg, "anthropic-beta") || - strings.Contains(msg, "beta feature") || - strings.Contains(msg, "requires beta") { - return true - } - - // thinking/tool streaming 等兼容性约束(常见于中间转换链路) - if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") { - return true - } - if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") { - return true - } - - return false -} - -func extractUpstreamErrorMessage(body []byte) string { - // Claude 风格:{"type":"error","error":{"type":"...","message":"..."}} - if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" { - inner := strings.TrimSpace(m) - // 有些上游会把完整 JSON 作为字符串塞进 message - if strings.HasPrefix(inner, "{") { - if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" { - return innerMsg - } - } - return m - } - - // 兜底:尝试顶层 message - return gjson.GetBytes(body, "message").String() -} - func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { body, _ := io.ReadAll(resp.Body) @@ -1336,16 +850,6 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res switch resp.StatusCode { case 400: - // 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开 - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( - "Upstream 400 error (account=%d platform=%s type=%s): %s", - account.ID, - account.Platform, - account.Type, - truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), - ) - } c.Data(http.StatusBadRequest, "application/json", body) return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) case 401: @@ -1825,18 +1329,6 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, // 标记账号状态(429/529等) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - // 记录上游错误摘要便于排障(不回显请求内容) - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( - "count_tokens upstream error %d (account=%d platform=%s type=%s): %s", - resp.StatusCode, - account.ID, - account.Platform, - account.Type, - truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), - ) - } - // 返回简化的错误响应 errMsg := "Upstream request failed" switch resp.StatusCode { @@ -1917,13 +1409,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) - } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { - // API-key:与 messages 同步的按需 beta 注入(默认关闭) - if requestNeedsBetaFeatures(body) { - if beta := defaultApiKeyBetaHeader(body); beta != "" { - req.Header.Set("anthropic-beta", beta) - } - } } return req, nil diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index b1877800..a0bf1b6a 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -2278,13 +2278,11 @@ func convertClaudeToolsToGeminiTools(tools any) []any { "properties": map[string]any{}, } } - // 清理 JSON Schema - cleanedParams := cleanToolSchema(params) funcDecls = append(funcDecls, map[string]any{ "name": name, "description": desc, - "parameters": cleanedParams, + "parameters": params, }) } @@ -2298,41 +2296,6 @@ func convertClaudeToolsToGeminiTools(tools any) []any { } } -// cleanToolSchema 清理工具的 JSON Schema,移除 Gemini 不支持的字段 -func cleanToolSchema(schema any) any { - if schema == nil { - return nil - } - - switch v := schema.(type) { - case map[string]any: - cleaned := make(map[string]any) - for key, value := range v { - // 跳过不支持的字段 - if key == "$schema" || key == "$id" || key == "$ref" || - key == "additionalProperties" || key == "minLength" || - key == "maxLength" || key == "minItems" || key == "maxItems" { - continue - } - // 递归清理嵌套对象 - cleaned[key] = cleanToolSchema(value) - } - // 规范化 type 字段为大写 - if typeVal, ok := cleaned["type"].(string); ok { - cleaned["type"] = strings.ToUpper(typeVal) - } - return cleaned - case []any: - cleaned := make([]any, len(v)) - for i, item := range v { - cleaned[i] = cleanToolSchema(item) - } - return cleaned - default: - return v - } -} - func convertClaudeGenerationConfig(req map[string]any) map[string]any { out := make(map[string]any) if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 { diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go deleted file mode 100644 index d49f2eb3..00000000 --- a/backend/internal/service/gemini_messages_compat_service_test.go +++ /dev/null @@ -1,128 +0,0 @@ -package service - -import ( - "testing" -) - -// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换 -func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { - tests := []struct { - name string - tools any - expectedLen int - description string - }{ - { - name: "Standard tools", - tools: []any{ - map[string]any{ - "name": "get_weather", - "description": "Get weather info", - "input_schema": map[string]any{"type": "object"}, - }, - }, - expectedLen: 1, - description: "标准工具格式应该正常转换", - }, - { - name: "Custom type tool (MCP format)", - tools: []any{ - map[string]any{ - "type": "custom", - "name": "mcp_tool", - "custom": map[string]any{ - "description": "MCP tool description", - "input_schema": map[string]any{"type": "object"}, - }, - }, - }, - expectedLen: 1, - description: "Custom类型工具应该从custom字段读取", - }, - { - name: "Mixed standard and custom tools", - tools: []any{ - map[string]any{ - "name": "standard_tool", - "description": "Standard", - "input_schema": map[string]any{"type": "object"}, - }, - map[string]any{ - "type": "custom", - "name": "custom_tool", - "custom": map[string]any{ - "description": "Custom", - "input_schema": map[string]any{"type": "object"}, - }, - }, - }, - expectedLen: 1, - description: "混合工具应该都能正确转换", - }, - { - name: "Custom tool without custom field", - tools: []any{ - map[string]any{ - "type": "custom", - "name": "invalid_custom", - // 缺少 custom 字段 - }, - }, - expectedLen: 0, // 应该被跳过 - description: "缺少custom字段的custom工具应该被跳过", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := convertClaudeToolsToGeminiTools(tt.tools) - - if tt.expectedLen == 0 { - if result != nil { - t.Errorf("%s: expected nil result, got %v", tt.description, result) - } - return - } - - if result == nil { - t.Fatalf("%s: expected non-nil result", tt.description) - } - - if len(result) != 1 { - t.Errorf("%s: expected 1 tool declaration, got %d", tt.description, len(result)) - return - } - - toolDecl, ok := result[0].(map[string]any) - if !ok { - t.Fatalf("%s: result[0] is not map[string]any", tt.description) - } - - funcDecls, ok := toolDecl["functionDeclarations"].([]any) - if !ok { - t.Fatalf("%s: functionDeclarations is not []any", tt.description) - } - - toolsArr, _ := tt.tools.([]any) - expectedFuncCount := 0 - for _, tool := range toolsArr { - toolMap, _ := tool.(map[string]any) - if toolMap["name"] != "" { - // 检查是否为有效的custom工具 - if toolMap["type"] == "custom" { - if toolMap["custom"] != nil { - expectedFuncCount++ - } - } else { - expectedFuncCount++ - } - } - } - - if len(funcDecls) != expectedFuncCount { - t.Errorf("%s: expected %d function declarations, got %d", - tt.description, expectedFuncCount, len(funcDecls)) - } - }) - } -} diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index 221bd0f2..e4bda5f8 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "net/http" - "regexp" "strconv" "strings" "time" @@ -164,45 +163,6 @@ type GeminiTokenInfo struct { Scope string `json:"scope,omitempty"` ProjectID string `json:"project_id,omitempty"` OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio" - TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA -} - -// validateTierID validates tier_id format and length -func validateTierID(tierID string) error { - if tierID == "" { - return nil // Empty is allowed - } - if len(tierID) > 64 { - return fmt.Errorf("tier_id exceeds maximum length of 64 characters") - } - // Allow alphanumeric, underscore, hyphen, and slash (for tier paths) - if !regexp.MustCompile(`^[a-zA-Z0-9_/-]+$`).MatchString(tierID) { - return fmt.Errorf("tier_id contains invalid characters") - } - return nil -} - -// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response -// Prioritizes IsDefault tier, falls back to first non-empty tier -func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string { - tierID := "LEGACY" - // First pass: look for default tier - for _, tier := range allowedTiers { - if tier.IsDefault && strings.TrimSpace(tier.ID) != "" { - tierID = strings.TrimSpace(tier.ID) - break - } - } - // Second pass: if still LEGACY, take first non-empty tier - if tierID == "LEGACY" { - for _, tier := range allowedTiers { - if strings.TrimSpace(tier.ID) != "" { - tierID = strings.TrimSpace(tier.ID) - break - } - } - } - return tierID } func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { @@ -263,14 +223,13 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 projectID := sessionProjectID - var tierID string // 对于 code_assist 模式,project_id 是必需的 // 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API) if oauthType == "code_assist" { if projectID == "" { var err error - projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) + projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) if err != nil { // 记录警告但不阻断流程,允许后续补充 project_id fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err) @@ -289,7 +248,6 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch ExpiresAt: expiresAt, Scope: tokenResp.Scope, ProjectID: projectID, - TierID: tierID, OAuthType: oauthType, }, nil } @@ -399,7 +357,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A // For Code Assist, project_id is required. Auto-detect if missing. // For AI Studio OAuth, project_id is optional and should not block refresh. if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" { - projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) + projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) if err != nil { return nil, fmt.Errorf("failed to auto-detect project_id: %w", err) } @@ -408,7 +366,6 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A return nil, fmt.Errorf("failed to auto-detect project_id: empty result") } tokenInfo.ProjectID = projectID - tokenInfo.TierID = tierID } return tokenInfo, nil @@ -431,13 +388,6 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) if tokenInfo.ProjectID != "" { creds["project_id"] = tokenInfo.ProjectID } - if tokenInfo.TierID != "" { - // Validate tier_id before storing - if err := validateTierID(tokenInfo.TierID); err == nil { - creds["tier_id"] = tokenInfo.TierID - } - // Silently skip invalid tier_id (don't block account creation) - } if tokenInfo.OAuthType != "" { creds["oauth_type"] = tokenInfo.OAuthType } @@ -448,26 +398,34 @@ func (s *GeminiOAuthService) Stop() { s.sessionStore.Stop() } -func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, string, error) { +func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) { if s.codeAssist == nil { - return "", "", errors.New("code assist client not configured") + return "", errors.New("code assist client not configured") } loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil) - - // Extract tierID from response (works whether CloudAICompanionProject is set or not) - tierID := "LEGACY" - if loadResp != nil { - tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers) - } - - // If LoadCodeAssist returned a project, use it if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" { - return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil + return strings.TrimSpace(loadResp.CloudAICompanionProject), nil } // Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID. - // (tierID already extracted above, reuse it) + tierID := "LEGACY" + if loadResp != nil { + for _, tier := range loadResp.AllowedTiers { + if tier.IsDefault && strings.TrimSpace(tier.ID) != "" { + tierID = strings.TrimSpace(tier.ID) + break + } + } + if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" { + for _, tier := range loadResp.AllowedTiers { + if strings.TrimSpace(tier.ID) != "" { + tierID = strings.TrimSpace(tier.ID) + break + } + } + } + } req := &geminicli.OnboardUserRequest{ TierID: tierID, @@ -485,39 +443,39 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr // If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects. fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - return strings.TrimSpace(fallback), tierID, nil + return strings.TrimSpace(fallback), nil } - return "", "", err + return "", err } if resp.Done { if resp.Response != nil && resp.Response.CloudAICompanionProject != nil { switch v := resp.Response.CloudAICompanionProject.(type) { case string: - return strings.TrimSpace(v), tierID, nil + return strings.TrimSpace(v), nil case map[string]any: if id, ok := v["id"].(string); ok { - return strings.TrimSpace(id), tierID, nil + return strings.TrimSpace(id), nil } } } fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - return strings.TrimSpace(fallback), tierID, nil + return strings.TrimSpace(fallback), nil } - return "", "", errors.New("onboardUser completed but no project_id returned") + return "", errors.New("onboardUser completed but no project_id returned") } time.Sleep(2 * time.Second) } fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - return strings.TrimSpace(fallback), tierID, nil + return strings.TrimSpace(fallback), nil } if loadErr != nil { - return "", "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) + return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) } - return "", "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) + return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) } type googleCloudProject struct { diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index 5f369de5..2195ec55 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -112,7 +112,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou } } - detected, tierID, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL) + detected, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL) if err != nil { log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err) return accessToken, nil @@ -123,9 +123,6 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou account.Credentials = make(map[string]any) } account.Credentials["project_id"] = detected - if tierID != "" { - account.Credentials["tier_id"] = tierID - } _ = p.accountRepo.Update(ctx, account) } } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f8eb29bd..84e98679 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -13,7 +13,6 @@ import ( "log" "net/http" "regexp" - "sort" "strconv" "strings" "time" @@ -81,7 +80,6 @@ type OpenAIGatewayService struct { userSubRepo UserSubscriptionRepository cache GatewayCache cfg *config.Config - concurrencyService *ConcurrencyService billingService *BillingService rateLimitService *RateLimitService billingCacheService *BillingCacheService @@ -97,7 +95,6 @@ func NewOpenAIGatewayService( userSubRepo UserSubscriptionRepository, cache GatewayCache, cfg *config.Config, - concurrencyService *ConcurrencyService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, @@ -111,7 +108,6 @@ func NewOpenAIGatewayService( userSubRepo: userSubRepo, cache: cache, cfg: cfg, - concurrencyService: concurrencyService, billingService: billingService, rateLimitService: rateLimitService, billingCacheService: billingCacheService, @@ -130,14 +126,6 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string { return hex.EncodeToString(hash[:]) } -// BindStickySession sets session -> account binding with standard TTL. -func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { - if sessionHash == "" || accountID <= 0 { - return nil - } - return s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, accountID, openaiStickySessionTTL) -} - // SelectAccount selects an OpenAI account with sticky session support func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { return s.SelectAccountForModel(ctx, groupID, sessionHash, "") @@ -230,254 +218,6 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C return selected, nil } -// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. -func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { - cfg := s.schedulingConfig() - var stickyAccountID int64 - if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash); err == nil { - stickyAccountID = accountID - } - } - if s.concurrencyService == nil || !cfg.LoadBatchEnabled { - account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) - if err != nil { - return nil, err - } - result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) - if err == nil && result.Acquired { - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) - if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil - } - } - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil - } - - accounts, err := s.listSchedulableAccounts(ctx, groupID) - if err != nil { - return nil, err - } - if len(accounts) == 0 { - return nil, errors.New("no available accounts") - } - - isExcluded := func(accountID int64) bool { - if excludedIDs == nil { - return false - } - _, excluded := excludedIDs[accountID] - return excluded - } - - // ============ Layer 1: Sticky session ============ - if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash) - if err == nil && accountID > 0 && !isExcluded(accountID) { - account, err := s.accountRepo.GetByID(ctx, accountID) - if err == nil && account.IsSchedulable() && account.IsOpenAI() && - (requestedModel == "" || account.IsModelSupported(requestedModel)) { - result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) - if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL) - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) - if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil - } - } - } - } - - // ============ Layer 2: Load-aware selection ============ - candidates := make([]*Account, 0, len(accounts)) - for i := range accounts { - acc := &accounts[i] - if isExcluded(acc.ID) { - continue - } - if requestedModel != "" && !acc.IsModelSupported(requestedModel) { - continue - } - candidates = append(candidates, acc) - } - - if len(candidates) == 0 { - return nil, errors.New("no available accounts") - } - - accountLoads := make([]AccountWithConcurrency, 0, len(candidates)) - for _, acc := range candidates { - accountLoads = append(accountLoads, AccountWithConcurrency{ - ID: acc.ID, - MaxConcurrency: acc.Concurrency, - }) - } - - loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) - if err != nil { - ordered := append([]*Account(nil), candidates...) - sortAccountsByPriorityAndLastUsed(ordered, false) - for _, acc := range ordered { - result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) - if err == nil && result.Acquired { - if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, acc.ID, openaiStickySessionTTL) - } - return &AccountSelectionResult{ - Account: acc, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - } - } else { - type accountWithLoad struct { - account *Account - loadInfo *AccountLoadInfo - } - var available []accountWithLoad - for _, acc := range candidates { - loadInfo := loadMap[acc.ID] - if loadInfo == nil { - loadInfo = &AccountLoadInfo{AccountID: acc.ID} - } - if loadInfo.LoadRate < 100 { - available = append(available, accountWithLoad{ - account: acc, - loadInfo: loadInfo, - }) - } - } - - if len(available) > 0 { - sort.SliceStable(available, func(i, j int) bool { - a, b := available[i], available[j] - if a.account.Priority != b.account.Priority { - return a.account.Priority < b.account.Priority - } - if a.loadInfo.LoadRate != b.loadInfo.LoadRate { - return a.loadInfo.LoadRate < b.loadInfo.LoadRate - } - switch { - case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: - return true - case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: - return false - case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: - return false - default: - return a.account.LastUsedAt.Before(*b.account.LastUsedAt) - } - }) - - for _, item := range available { - result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) - if err == nil && result.Acquired { - if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL) - } - return &AccountSelectionResult{ - Account: item.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - } - } - } - - // ============ Layer 3: Fallback wait ============ - sortAccountsByPriorityAndLastUsed(candidates, false) - for _, acc := range candidates { - return &AccountSelectionResult{ - Account: acc, - WaitPlan: &AccountWaitPlan{ - AccountID: acc.ID, - MaxConcurrency: acc.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil - } - - return nil, errors.New("no available accounts") -} - -func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) { - var accounts []Account - var err error - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) - } else if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) - } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) - } - if err != nil { - return nil, fmt.Errorf("query accounts failed: %w", err) - } - return accounts, nil -} - -func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { - if s.concurrencyService == nil { - return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil - } - return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) -} - -func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig { - if s.cfg != nil { - return s.cfg.Gateway.Scheduling - } - return config.GatewaySchedulingConfig{ - StickySessionMaxWaiting: 3, - StickySessionWaitTimeout: 45 * time.Second, - FallbackWaitTimeout: 30 * time.Second, - FallbackMaxWaiting: 100, - LoadBatchEnabled: true, - SlotCleanupInterval: 30 * time.Second, - } -} - // GetAccessToken gets the access token for an OpenAI account func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index a202ccf2..81e01d47 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -73,15 +73,6 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh return svc } -// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker. -func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService { - svc := NewConcurrencyService(cache) - if cfg != nil { - svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval) - } - return svc -} - // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services @@ -116,7 +107,7 @@ var ProviderSet = wire.NewSet( ProvideEmailQueueService, NewTurnstileService, NewSubscriptionService, - ProvideConcurrencyService, + NewConcurrencyService, NewIdentityService, NewCRSSyncService, ProvideUpdateService, diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 5478d151..5bd85d7d 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -122,21 +122,6 @@ pricing: # Hash check interval in minutes hash_check_interval_minutes: 10 -# ============================================================================= -# Gateway (Optional) -# ============================================================================= -gateway: - # Wait time (in seconds) for upstream response headers (streaming body not affected) - response_header_timeout: 300 - # Log upstream error response body summary (safe/truncated; does not log request content) - log_upstream_error_body: false - # Max bytes to log from upstream error body - log_upstream_error_body_max_bytes: 2048 - # Auto inject anthropic-beta for API-key accounts when needed (default off) - inject_beta_for_apikey: false - # Allow failover on selected 400 errors (default off) - failover_on_400: false - # ============================================================================= # Gemini OAuth (Required for Gemini accounts) # ============================================================================= diff --git a/deploy/flow.md b/deploy/flow.md deleted file mode 100644 index 0904c72f..00000000 --- a/deploy/flow.md +++ /dev/null @@ -1,222 +0,0 @@ -```mermaid -flowchart TD - %% Master dispatch - A[HTTP Request] --> B{Route} - B -->|v1 messages| GA0 - B -->|openai v1 responses| OA0 - B -->|v1beta models model action| GM0 - B -->|v1 messages count tokens| GT0 - B -->|v1beta models list or get| GL0 - - %% ========================= - %% FLOW A: Claude Gateway - %% ========================= - subgraph FLOW_A["v1 messages Claude Gateway"] - GA0[Auth middleware] --> GA1[Read body] - GA1 -->|empty| GA1E[400 invalid_request_error] - GA1 --> GA2[ParseGatewayRequest] - GA2 -->|parse error| GA2E[400 invalid_request_error] - GA2 --> GA3{model present} - GA3 -->|no| GA3E[400 invalid_request_error] - GA3 --> GA4[streamStarted false] - GA4 --> GA5[IncrementWaitCount user] - GA5 -->|queue full| GA5E[429 rate_limit_error] - GA5 --> GA6[AcquireUserSlotWithWait] - GA6 -->|timeout or fail| GA6E[429 rate_limit_error] - GA6 --> GA7[BillingEligibility check post wait] - GA7 -->|fail| GA7E[403 billing_error] - GA7 --> GA8[Generate sessionHash] - GA8 --> GA9[Resolve platform] - GA9 --> GA10{platform gemini} - GA10 -->|yes| GA10Y[sessionKey gemini hash] - GA10 -->|no| GA10N[sessionKey hash] - GA10Y --> GA11 - GA10N --> GA11 - - GA11[SelectAccountWithLoadAwareness] -->|err and no failed| GA11E1[503 no available accounts] - GA11 -->|err and failed| GA11E2[map failover error] - GA11 --> GA12[Warmup intercept] - GA12 -->|yes| GA12Y[return mock and release if held] - GA12 -->|no| GA13[Acquire account slot or wait] - GA13 -->|wait queue full| GA13E1[429 rate_limit_error] - GA13 -->|wait timeout| GA13E2[429 concurrency limit] - GA13 --> GA14[BindStickySession if waited] - GA14 --> GA15{account platform antigravity} - GA15 -->|yes| GA15Y[ForwardGemini antigravity] - GA15 -->|no| GA15N[Forward Claude] - GA15Y --> GA16[Release account slot and dec account wait] - GA15N --> GA16 - GA16 --> GA17{UpstreamFailoverError} - GA17 -->|yes| GA18[mark failedAccountIDs and map error if exceed] - GA18 -->|loop| GA11 - GA17 -->|no| GA19[success async RecordUsage and return] - GA19 --> GA20[defer release user slot and dec wait count] - end - - %% ========================= - %% FLOW B: OpenAI - %% ========================= - subgraph FLOW_B["openai v1 responses"] - OA0[Auth middleware] --> OA1[Read body] - OA1 -->|empty| OA1E[400 invalid_request_error] - OA1 --> OA2[json Unmarshal body] - OA2 -->|parse error| OA2E[400 invalid_request_error] - OA2 --> OA3{model present} - OA3 -->|no| OA3E[400 invalid_request_error] - OA3 --> OA4{User Agent Codex CLI} - OA4 -->|no| OA4N[set default instructions] - OA4 -->|yes| OA4Y[no change] - OA4N --> OA5 - OA4Y --> OA5 - OA5[streamStarted false] --> OA6[IncrementWaitCount user] - OA6 -->|queue full| OA6E[429 rate_limit_error] - OA6 --> OA7[AcquireUserSlotWithWait] - OA7 -->|timeout or fail| OA7E[429 rate_limit_error] - OA7 --> OA8[BillingEligibility check post wait] - OA8 -->|fail| OA8E[403 billing_error] - OA8 --> OA9[sessionHash sha256 session_id] - OA9 --> OA10[SelectAccountWithLoadAwareness] - OA10 -->|err and no failed| OA10E1[503 no available accounts] - OA10 -->|err and failed| OA10E2[map failover error] - OA10 --> OA11[Acquire account slot or wait] - OA11 -->|wait queue full| OA11E1[429 rate_limit_error] - OA11 -->|wait timeout| OA11E2[429 concurrency limit] - OA11 --> OA12[BindStickySession openai hash if waited] - OA12 --> OA13[Forward OpenAI upstream] - OA13 --> OA14[Release account slot and dec account wait] - OA14 --> OA15{UpstreamFailoverError} - OA15 -->|yes| OA16[mark failedAccountIDs and map error if exceed] - OA16 -->|loop| OA10 - OA15 -->|no| OA17[success async RecordUsage and return] - OA17 --> OA18[defer release user slot and dec wait count] - end - - %% ========================= - %% FLOW C: Gemini Native - %% ========================= - subgraph FLOW_C["v1beta models model action Gemini Native"] - GM0[Auth middleware] --> GM1[Validate platform] - GM1 -->|invalid| GM1E[400 googleError] - GM1 --> GM2[Parse path modelName action] - GM2 -->|invalid| GM2E[400 googleError] - GM2 --> GM3{action supported} - GM3 -->|no| GM3E[404 googleError] - GM3 --> GM4[Read body] - GM4 -->|empty| GM4E[400 googleError] - GM4 --> GM5[streamStarted false] - GM5 --> GM6[IncrementWaitCount user] - GM6 -->|queue full| GM6E[429 googleError] - GM6 --> GM7[AcquireUserSlotWithWait] - GM7 -->|timeout or fail| GM7E[429 googleError] - GM7 --> GM8[BillingEligibility check post wait] - GM8 -->|fail| GM8E[403 googleError] - GM8 --> GM9[Generate sessionHash] - GM9 --> GM10[sessionKey gemini hash] - GM10 --> GM11[SelectAccountWithLoadAwareness] - GM11 -->|err and no failed| GM11E1[503 googleError] - GM11 -->|err and failed| GM11E2[mapGeminiUpstreamError] - GM11 --> GM12[Acquire account slot or wait] - GM12 -->|wait queue full| GM12E1[429 googleError] - GM12 -->|wait timeout| GM12E2[429 googleError] - GM12 --> GM13[BindStickySession if waited] - GM13 --> GM14{account platform antigravity} - GM14 -->|yes| GM14Y[ForwardGemini antigravity] - GM14 -->|no| GM14N[ForwardNative] - GM14Y --> GM15[Release account slot and dec account wait] - GM14N --> GM15 - GM15 --> GM16{UpstreamFailoverError} - GM16 -->|yes| GM17[mark failedAccountIDs and map error if exceed] - GM17 -->|loop| GM11 - GM16 -->|no| GM18[success async RecordUsage and return] - GM18 --> GM19[defer release user slot and dec wait count] - end - - %% ========================= - %% FLOW D: CountTokens - %% ========================= - subgraph FLOW_D["v1 messages count tokens"] - GT0[Auth middleware] --> GT1[Read body] - GT1 -->|empty| GT1E[400 invalid_request_error] - GT1 --> GT2[ParseGatewayRequest] - GT2 -->|parse error| GT2E[400 invalid_request_error] - GT2 --> GT3{model present} - GT3 -->|no| GT3E[400 invalid_request_error] - GT3 --> GT4[BillingEligibility check] - GT4 -->|fail| GT4E[403 billing_error] - GT4 --> GT5[ForwardCountTokens] - end - - %% ========================= - %% FLOW E: Gemini Models List Get - %% ========================= - subgraph FLOW_E["v1beta models list or get"] - GL0[Auth middleware] --> GL1[Validate platform] - GL1 -->|invalid| GL1E[400 googleError] - GL1 --> GL2{force platform antigravity} - GL2 -->|yes| GL2Y[return static fallback models] - GL2 -->|no| GL3[SelectAccountForAIStudioEndpoints] - GL3 -->|no gemini and has antigravity| GL3Y[return fallback models] - GL3 -->|no accounts| GL3E[503 googleError] - GL3 --> GL4[ForwardAIStudioGET] - GL4 -->|error| GL4E[502 googleError] - GL4 --> GL5[Passthrough response or fallback] - end - - %% ========================= - %% SHARED: Account Selection - %% ========================= - subgraph SELECT["SelectAccountWithLoadAwareness detail"] - S0[Start] --> S1{concurrencyService nil OR load batch disabled} - S1 -->|yes| S2[SelectAccountForModelWithExclusions legacy] - S2 --> S3[tryAcquireAccountSlot] - S3 -->|acquired| S3Y[SelectionResult Acquired true ReleaseFunc] - S3 -->|not acquired| S3N[WaitPlan FallbackTimeout MaxWaiting] - S1 -->|no| S4[Resolve platform] - S4 --> S5[List schedulable accounts] - S5 --> S6[Layer1 Sticky session] - S6 -->|hit and valid| S6A[tryAcquireAccountSlot] - S6A -->|acquired| S6AY[SelectionResult Acquired true] - S6A -->|not acquired and waitingCount < StickyMax| S6AN[WaitPlan StickyTimeout Max] - S6 --> S7[Layer2 Load aware] - S7 --> S7A[Load batch concurrency plus wait to loadRate] - S7A --> S7B[Sort priority load LRU OAuth prefer for Gemini] - S7B --> S7C[tryAcquireAccountSlot in order] - S7C -->|first success| S7CY[SelectionResult Acquired true] - S7C -->|none| S8[Layer3 Fallback wait] - S8 --> S8A[Sort priority LRU] - S8A --> S8B[WaitPlan FallbackTimeout Max] - end - - %% ========================= - %% SHARED: Wait Acquire - %% ========================= - subgraph WAIT["AcquireXSlotWithWait detail"] - W0[Try AcquireXSlot immediately] -->|acquired| W1[return ReleaseFunc] - W0 -->|not acquired| W2[Wait loop with timeout] - W2 --> W3[Backoff 100ms x1.5 jitter max2s] - W2 --> W4[If streaming and ping format send SSE ping] - W2 --> W5[Retry AcquireXSlot on timer] - W5 -->|acquired| W1 - W2 -->|timeout| W6[ConcurrencyError IsTimeout true] - end - - %% ========================= - %% SHARED: Account Wait Queue - %% ========================= - subgraph AQ["Account Wait Queue Redis Lua"] - Q1[IncrementAccountWaitCount] --> Q2{current >= max} - Q2 -->|yes| Q2Y[return false] - Q2 -->|no| Q3[INCR and if first set TTL] - Q3 --> Q4[return true] - Q5[DecrementAccountWaitCount] --> Q6[if current > 0 then DECR] - end - - %% ========================= - %% SHARED: Background cleanup - %% ========================= - subgraph CLEANUP["Slot Cleanup Worker"] - C0[StartSlotCleanupWorker interval] --> C1[List schedulable accounts] - C1 --> C2[CleanupExpiredAccountSlots per account] - C2 --> C3[Repeat every interval] - end -``` diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 1770a985..6563ee0c 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -952,7 +952,6 @@ "integrity": "sha512-N2clP5pJhB2YnZJ3PIHFk5RkygRX5WO/5f0WC08tp0wd+sv0rsJk3MqWn3CbNmT2J505a5336jaQj4ph1AdMug==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "undici-types": "~6.21.0" } @@ -1368,7 +1367,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "baseline-browser-mapping": "^2.9.0", "caniuse-lite": "^1.0.30001759", @@ -1445,7 +1443,6 @@ "resolved": "https://registry.npmmirror.com/chart.js/-/chart.js-4.5.1.tgz", "integrity": "sha512-GIjfiT9dbmHRiYi6Nl2yFCq7kkwdkp1W/lp2J99rX0yo9tgJGn3lKQATztIjb5tVtevcBtIdICNWqlq5+E8/Pw==", "license": "MIT", - "peer": true, "dependencies": { "@kurkle/color": "^0.3.0" }, @@ -2043,7 +2040,6 @@ "integrity": "sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==", "dev": true, "license": "MIT", - "peer": true, "bin": { "jiti": "bin/jiti.js" } @@ -2352,7 +2348,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1", @@ -2826,7 +2821,6 @@ "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", - "peer": true, "engines": { "node": ">=12" }, @@ -2860,7 +2854,6 @@ "integrity": "sha512-hjcS1mhfuyi4WW8IWtjP7brDrG2cuDZukyrYrSauoXGNgx0S7zceP07adYkJycEr56BOUTNPzbInooiN3fn1qw==", "devOptional": true, "license": "Apache-2.0", - "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -2933,7 +2926,6 @@ "integrity": "sha512-o5a9xKjbtuhY6Bi5S3+HvbRERmouabWbyUcpXXUA1u+GNUKoROi9byOJ8M0nHbHYHkYICiMlqxkg1KkYmm25Sw==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "esbuild": "^0.21.3", "postcss": "^8.4.43", @@ -3105,7 +3097,6 @@ "resolved": "https://registry.npmmirror.com/vue/-/vue-3.5.25.tgz", "integrity": "sha512-YLVdgv2K13WJ6n+kD5owehKtEXwdwXuj2TTyJMsO7pSeKw2bfRNZGjhB7YzrpbMYj5b5QsUebHpOqR3R3ziy/g==", "license": "MIT", - "peer": true, "dependencies": { "@vue/compiler-dom": "3.5.25", "@vue/compiler-sfc": "3.5.25", @@ -3199,7 +3190,6 @@ "integrity": "sha512-P7OP77b2h/Pmk+lZdJ0YWs+5tJ6J2+uOQPo7tlBnY44QqQSPYvS0qVT4wqDJgwrZaLe47etJLLQRFia71GYITw==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@volar/typescript": "2.4.15", "@vue/language-core": "2.2.12" diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue index 914678a5..c1ca08fa 100644 --- a/frontend/src/components/account/AccountStatusIndicator.vue +++ b/frontend/src/components/account/AccountStatusIndicator.vue @@ -83,14 +83,6 @@ > - - - - {{ tierDisplay }} - @@ -148,23 +140,4 @@ const statusText = computed(() => { return props.account.status }) -// Computed: tier display -const tierDisplay = computed(() => { - const credentials = props.account.credentials as Record | undefined - const tierId = credentials?.tier_id - if (!tierId || tierId === 'unknown') return null - - const tierMap: Record = { - 'free': 'Free', - 'payg': 'Pay-as-you-go', - 'pay-as-you-go': 'Pay-as-you-go', - 'enterprise': 'Enterprise', - 'LEGACY': 'Legacy', - 'PRO': 'Pro', - 'ULTRA': 'Ultra' - } - - return tierMap[tierId] || tierId -}) -