diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go
index c426eec1..c0199258 100644
--- a/backend/cmd/server/wire.go
+++ b/backend/cmd/server/wire.go
@@ -77,6 +77,7 @@ func provideCleanup(
     pricing *service.PricingService,
     emailQueue *service.EmailQueueService,
     billingCache *service.BillingCacheService,
+    usageRecordWorkerPool *service.UsageRecordWorkerPool,
     subscriptionService *service.SubscriptionService,
     oauth *service.OAuthService,
     openaiOAuth *service.OpenAIOAuthService,
@@ -176,6 +177,12 @@ func provideCleanup(
         billingCache.Stop()
         return nil
     }},
+    {"UsageRecordWorkerPool", func() error {
+        if usageRecordWorkerPool != nil {
+            usageRecordWorkerPool.Stop()
+        }
+        return nil
+    }},
     {"OAuthService", func() error {
         oauth.Stop()
         return nil
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index a0f8807a..0b57334b 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -182,12 +182,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
     errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
     errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
     adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
-    gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig)
-    openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig)
+    usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
+    gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
+    openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
     soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
     soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
     soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)
-    soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig)
+    soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig)
     handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
     totpHandler := handler.NewTotpHandler(totpService)
     handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler, totpHandler)
@@ -205,7 +206,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
     tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
     accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
     subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
-    v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
+    v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
     application := &Application{
         Server:  httpServer,
         Cleanup: v,
@@ -245,6 +246,7 @@ func provideCleanup(
     pricing *service.PricingService,
     emailQueue *service.EmailQueueService,
     billingCache *service.BillingCacheService,
+    usageRecordWorkerPool *service.UsageRecordWorkerPool,
     subscriptionService *service.SubscriptionService,
     oauth *service.OAuthService,
     openaiOAuth *service.OpenAIOAuthService,
@@ -343,6 +345,12 @@ func provideCleanup(
         billingCache.Stop()
         return nil
     }},
+    {"UsageRecordWorkerPool", func() error {
+        if usageRecordWorkerPool != nil {
+            usageRecordWorkerPool.Stop()
+        }
+        return nil
+    }},
     {"OAuthService", func() error {
         oauth.Stop()
         return nil
diff --git a/backend/go.mod b/backend/go.mod
index 2a79c203..94b6fcbb 100644
--- a/backend/go.mod
+++ b/backend/go.mod
@@ -5,6 +5,7 @@ go 1.25.7
 require (
     entgo.io/ent v0.14.5
     github.com/DATA-DOG/go-sqlmock v1.5.2
+    github.com/alitto/pond/v2 v2.6.2
     github.com/cespare/xxhash/v2 v2.3.0
     github.com/dgraph-io/ristretto v0.2.0
     github.com/gin-gonic/gin v1.9.1
diff --git a/backend/go.sum b/backend/go.sum
index eda2af99..fa84988a 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -14,6 +14,8 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
 github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
 github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo=
 github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558=
+github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw=
+github.com/alitto/pond/v2 v2.6.2/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE=
 github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
 github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
 github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
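The new dependency, `github.com/alitto/pond/v2`, supplies the bounded pool that `UsageRecordWorkerPool` wraps further down in this diff. A minimal sketch of the pond calls the pool relies on (every call shown here also appears in the new service file; the task body and sizes are placeholders):

```go
package main

import (
	"fmt"
	"time"

	"github.com/alitto/pond/v2"
)

func main() {
	// A bounded pool: up to 4 concurrent workers, at most 16 queued tasks.
	pool := pond.NewPool(4, pond.WithQueueSize(16))

	// TrySubmit is non-blocking: ok reports whether the task was accepted.
	_, ok := pool.TrySubmit(func() {
		time.Sleep(10 * time.Millisecond) // stand-in for real work
	})
	fmt.Println("enqueued:", ok)

	pool.Resize(8)     // raise the concurrency ceiling at runtime
	pool.StopAndWait() // drain queued tasks, then stop
}
```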
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index 330ae0c1..777993cd 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -19,6 +19,13 @@ const (
     RunModeSimple = "simple"
 )
 
+// Usage-record queue overflow policies
+const (
+    UsageRecordOverflowPolicyDrop   = "drop"
+    UsageRecordOverflowPolicySample = "sample"
+    UsageRecordOverflowPolicySync   = "sync"
+)
+
 // DefaultCSPPolicy is the default Content-Security-Policy with nonce support
 // __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
 const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
@@ -413,6 +420,42 @@ type GatewayConfig struct {
 
     // TLSFingerprint: TLS fingerprint spoofing configuration
     TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"`
+
+    // UsageRecord: async usage-record queue configuration (bounded queue + fixed workers)
+    UsageRecord GatewayUsageRecordConfig `mapstructure:"usage_record"`
+}
+
+// GatewayUsageRecordConfig configures the async usage-record queue
+type GatewayUsageRecordConfig struct {
+    // WorkerCount: initial number of workers (the initial concurrency ceiling when auto-scaling is enabled)
+    WorkerCount int `mapstructure:"worker_count"`
+    // QueueSize: queue capacity (bounded)
+    QueueSize int `mapstructure:"queue_size"`
+    // TaskTimeoutSeconds: timeout for a single usage-record task (seconds)
+    TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"`
+    // OverflowPolicy: policy when the queue is full (drop/sample/sync)
+    OverflowPolicy string `mapstructure:"overflow_policy"`
+    // OverflowSamplePercent: under the sample policy, the percentage of overflow written back synchronously (1-100)
+    OverflowSamplePercent int `mapstructure:"overflow_sample_percent"`
+
+    // AutoScaleEnabled: whether worker auto-scaling is enabled
+    AutoScaleEnabled bool `mapstructure:"auto_scale_enabled"`
+    // AutoScaleMinWorkers: minimum worker count for auto-scaling
+    AutoScaleMinWorkers int `mapstructure:"auto_scale_min_workers"`
+    // AutoScaleMaxWorkers: maximum worker count for auto-scaling
+    AutoScaleMaxWorkers int `mapstructure:"auto_scale_max_workers"`
+    // AutoScaleUpQueuePercent: scale up once queue occupancy reaches this threshold
+    AutoScaleUpQueuePercent int `mapstructure:"auto_scale_up_queue_percent"`
+    // AutoScaleDownQueuePercent: scale down once queue occupancy falls below this threshold
+    AutoScaleDownQueuePercent int `mapstructure:"auto_scale_down_queue_percent"`
+    // AutoScaleUpStep: workers added per scale-up
+    AutoScaleUpStep int `mapstructure:"auto_scale_up_step"`
+    // AutoScaleDownStep: workers removed per scale-down
+    AutoScaleDownStep int `mapstructure:"auto_scale_down_step"`
+    // AutoScaleCheckIntervalSeconds: auto-scaling check interval (seconds)
+    AutoScaleCheckIntervalSeconds int `mapstructure:"auto_scale_check_interval_seconds"`
+    // AutoScaleCooldownSeconds: auto-scaling cooldown (seconds)
+    AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"`
 }
 
 // SoraModelFiltersConfig configures Sora model filtering
@@ -1118,6 +1161,20 @@ func setDefaults() {
     viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
     viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
     viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
+    viper.SetDefault("gateway.usage_record.worker_count", 128)
+    viper.SetDefault("gateway.usage_record.queue_size", 16384)
+    viper.SetDefault("gateway.usage_record.task_timeout_seconds", 5)
+    viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySample)
+    viper.SetDefault("gateway.usage_record.overflow_sample_percent", 10)
+    viper.SetDefault("gateway.usage_record.auto_scale_enabled", true)
+    viper.SetDefault("gateway.usage_record.auto_scale_min_workers", 128)
+    viper.SetDefault("gateway.usage_record.auto_scale_max_workers", 512)
+    viper.SetDefault("gateway.usage_record.auto_scale_up_queue_percent", 70)
+    viper.SetDefault("gateway.usage_record.auto_scale_down_queue_percent", 15)
+    viper.SetDefault("gateway.usage_record.auto_scale_up_step", 32)
+    viper.SetDefault("gateway.usage_record.auto_scale_down_step", 16)
+    viper.SetDefault("gateway.usage_record.auto_scale_check_interval_seconds", 3)
+    viper.SetDefault("gateway.usage_record.auto_scale_cooldown_seconds", 10)
     // TLS fingerprint spoofing configuration (off by default; enabled per account)
     viper.SetDefault("gateway.tls_fingerprint.enabled", true)
     viper.SetDefault("concurrency.ping_interval", 10)
@@ -1636,6 +1693,64 @@ func (c *Config) Validate() error {
     if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 {
         return fmt.Errorf("gateway.max_line_size must be at least 1MB")
     }
+    if c.Gateway.UsageRecord.WorkerCount <= 0 {
+        return fmt.Errorf("gateway.usage_record.worker_count must be positive")
+    }
+    if c.Gateway.UsageRecord.QueueSize <= 0 {
+        return fmt.Errorf("gateway.usage_record.queue_size must be positive")
+    }
+    if c.Gateway.UsageRecord.TaskTimeoutSeconds <= 0 {
+        return fmt.Errorf("gateway.usage_record.task_timeout_seconds must be positive")
+    }
+    switch strings.ToLower(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy)) {
+    case UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync:
+    default:
+        return fmt.Errorf("gateway.usage_record.overflow_policy must be one of: %s/%s/%s",
+            UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync)
+    }
+    if c.Gateway.UsageRecord.OverflowSamplePercent < 0 || c.Gateway.UsageRecord.OverflowSamplePercent > 100 {
+        return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be between 0-100")
+    }
+    if strings.EqualFold(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy), UsageRecordOverflowPolicySample) &&
+        c.Gateway.UsageRecord.OverflowSamplePercent <= 0 {
+        return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be positive when overflow_policy=sample")
+    }
+    if c.Gateway.UsageRecord.AutoScaleEnabled {
+        if c.Gateway.UsageRecord.AutoScaleMinWorkers <= 0 {
+            return fmt.Errorf("gateway.usage_record.auto_scale_min_workers must be positive")
+        }
+        if c.Gateway.UsageRecord.AutoScaleMaxWorkers <= 0 {
+            return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be positive")
+        }
+        if c.Gateway.UsageRecord.AutoScaleMaxWorkers < c.Gateway.UsageRecord.AutoScaleMinWorkers {
+            return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be >= auto_scale_min_workers")
+        }
+        if c.Gateway.UsageRecord.WorkerCount < c.Gateway.UsageRecord.AutoScaleMinWorkers ||
+            c.Gateway.UsageRecord.WorkerCount > c.Gateway.UsageRecord.AutoScaleMaxWorkers {
+            return fmt.Errorf("gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers")
+        }
+        if c.Gateway.UsageRecord.AutoScaleUpQueuePercent <= 0 || c.Gateway.UsageRecord.AutoScaleUpQueuePercent > 100 {
+            return fmt.Errorf("gateway.usage_record.auto_scale_up_queue_percent must be between 1-100")
+        }
+        if c.Gateway.UsageRecord.AutoScaleDownQueuePercent < 0 || c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= 100 {
+            return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be between 0-99")
+        }
+        if c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= c.Gateway.UsageRecord.AutoScaleUpQueuePercent {
+            return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be less than auto_scale_up_queue_percent")
+        }
+        if c.Gateway.UsageRecord.AutoScaleUpStep <= 0 {
+            return fmt.Errorf("gateway.usage_record.auto_scale_up_step must be positive")
+        }
+        if c.Gateway.UsageRecord.AutoScaleDownStep <= 0 {
+            return fmt.Errorf("gateway.usage_record.auto_scale_down_step must be positive")
+        }
+        if c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds <= 0 {
+            return fmt.Errorf("gateway.usage_record.auto_scale_check_interval_seconds must be positive")
+        }
+        if c.Gateway.UsageRecord.AutoScaleCooldownSeconds < 0 {
+            return fmt.Errorf("gateway.usage_record.auto_scale_cooldown_seconds must be non-negative")
+        }
+    }
     if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
         return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
     }
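For reference, the defaults registered above correspond to the following YAML fragment. This is a sketch: the `gateway.usage_record.*` key paths come directly from the `viper.SetDefault` calls, while the surrounding file layout is assumed.

```yaml
gateway:
  usage_record:
    worker_count: 128              # initial concurrency ceiling
    queue_size: 16384              # bounded queue capacity
    task_timeout_seconds: 5
    overflow_policy: sample        # drop | sample | sync
    overflow_sample_percent: 10    # share of overflow written back synchronously
    auto_scale_enabled: true
    auto_scale_min_workers: 128
    auto_scale_max_workers: 512
    auto_scale_up_queue_percent: 70
    auto_scale_down_queue_percent: 15
    auto_scale_up_step: 32
    auto_scale_down_step: 16
    auto_scale_check_interval_seconds: 3
    auto_scale_cooldown_seconds: 10
```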
diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go
index dcc60879..2e79e5ed 100644
--- a/backend/internal/config/config_test.go
+++ b/backend/internal/config/config_test.go
@@ -942,6 +942,74 @@ func TestValidateConfigErrors(t *testing.T) {
         {
             mutate:  func(c *Config) { c.Gateway.MaxLineSize = -1 },
             wantErr: "gateway.max_line_size must be non-negative",
         },
+        {
+            name:    "gateway usage record worker count",
+            mutate:  func(c *Config) { c.Gateway.UsageRecord.WorkerCount = 0 },
+            wantErr: "gateway.usage_record.worker_count",
+        },
+        {
+            name:    "gateway usage record queue size",
+            mutate:  func(c *Config) { c.Gateway.UsageRecord.QueueSize = 0 },
+            wantErr: "gateway.usage_record.queue_size",
+        },
+        {
+            name:    "gateway usage record timeout",
+            mutate:  func(c *Config) { c.Gateway.UsageRecord.TaskTimeoutSeconds = 0 },
+            wantErr: "gateway.usage_record.task_timeout_seconds",
+        },
+        {
+            name:    "gateway usage record overflow policy",
+            mutate:  func(c *Config) { c.Gateway.UsageRecord.OverflowPolicy = "invalid" },
+            wantErr: "gateway.usage_record.overflow_policy",
+        },
+        {
+            name:    "gateway usage record sample percent range",
+            mutate:  func(c *Config) { c.Gateway.UsageRecord.OverflowSamplePercent = 101 },
+            wantErr: "gateway.usage_record.overflow_sample_percent",
+        },
+        {
+            name: "gateway usage record sample percent required for sample policy",
+            mutate: func(c *Config) {
+                c.Gateway.UsageRecord.OverflowPolicy = UsageRecordOverflowPolicySample
+                c.Gateway.UsageRecord.OverflowSamplePercent = 0
+            },
+            wantErr: "gateway.usage_record.overflow_sample_percent must be positive",
+        },
+        {
+            name: "gateway usage record auto scale max gte min",
+            mutate: func(c *Config) {
+                c.Gateway.UsageRecord.AutoScaleMinWorkers = 256
+                c.Gateway.UsageRecord.AutoScaleMaxWorkers = 128
+            },
+            wantErr: "gateway.usage_record.auto_scale_max_workers",
+        },
+        {
+            name: "gateway usage record worker in auto scale range",
+            mutate: func(c *Config) {
+                c.Gateway.UsageRecord.AutoScaleMinWorkers = 200
+                c.Gateway.UsageRecord.AutoScaleMaxWorkers = 300
+                c.Gateway.UsageRecord.WorkerCount = 128
+            },
+            wantErr: "gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers",
+        },
+        {
+            name: "gateway usage record auto scale queue thresholds order",
+            mutate: func(c *Config) {
+                c.Gateway.UsageRecord.AutoScaleUpQueuePercent = 50
+                c.Gateway.UsageRecord.AutoScaleDownQueuePercent = 50
+            },
+            wantErr: "gateway.usage_record.auto_scale_down_queue_percent must be less",
+        },
+        {
+            name:    "gateway usage record auto scale up step",
+            mutate:  func(c *Config) { c.Gateway.UsageRecord.AutoScaleUpStep = 0 },
+            wantErr: "gateway.usage_record.auto_scale_up_step",
+        },
+        {
+            name:    "gateway usage record auto scale interval",
+            mutate:  func(c *Config) { c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0 },
+            wantErr: "gateway.usage_record.auto_scale_check_interval_seconds",
+        },
         {
             name:   "gateway scheduling sticky waiting",
             mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
@@ -1025,6 +1093,99 @@ func TestValidateConfigErrors(t *testing.T) {
     }
 }
 
+func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) {
+    resetViperWithJWTSecret(t)
+    cfg, err := Load()
+    if err != nil {
+        t.Fatalf("Load() error: %v", err)
+    }
+
+    cfg.Gateway.UsageRecord.AutoScaleEnabled = false
+    cfg.Gateway.UsageRecord.WorkerCount = 64
+
+    // With auto-scaling disabled, these fields should be ignored and must not fail validation.
+    cfg.Gateway.UsageRecord.AutoScaleMinWorkers = 0
+    cfg.Gateway.UsageRecord.AutoScaleMaxWorkers = 0
+    cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent = 0
+    cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent = 100
+    cfg.Gateway.UsageRecord.AutoScaleUpStep = 0
+    cfg.Gateway.UsageRecord.AutoScaleDownStep = 0
+    cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0
+    cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds = -1
+
+    if err := cfg.Validate(); err != nil {
+        t.Fatalf("Validate() should ignore auto scale fields when disabled: %v", err)
+    }
+}
+
+func TestValidateConfig_LogRequiredAndRotationBounds(t *testing.T) {
+    resetViperWithJWTSecret(t)
+
+    cases := []struct {
+        name    string
+        mutate  func(*Config)
+        wantErr string
+    }{
+        {
+            name: "log level required",
+            mutate: func(c *Config) {
+                c.Log.Level = ""
+            },
+            wantErr: "log.level is required",
+        },
+        {
+            name: "log format required",
+            mutate: func(c *Config) {
+                c.Log.Format = ""
+            },
+            wantErr: "log.format is required",
+        },
+        {
+            name: "log stacktrace required",
+            mutate: func(c *Config) {
+                c.Log.StacktraceLevel = ""
+            },
+            wantErr: "log.stacktrace_level is required",
+        },
+        {
+            name: "log max backups non-negative",
+            mutate: func(c *Config) {
+                c.Log.Rotation.MaxBackups = -1
+            },
+            wantErr: "log.rotation.max_backups must be non-negative",
+        },
+        {
+            name: "log max age non-negative",
+            mutate: func(c *Config) {
+                c.Log.Rotation.MaxAgeDays = -1
+            },
+            wantErr: "log.rotation.max_age_days must be non-negative",
+        },
+        {
+            name: "sampling thereafter non-negative when disabled",
+            mutate: func(c *Config) {
+                c.Log.Sampling.Enabled = false
+                c.Log.Sampling.Thereafter = -1
+            },
+            wantErr: "log.sampling.thereafter must be non-negative",
+        },
+    }
+
+    for _, tt := range cases {
+        t.Run(tt.name, func(t *testing.T) {
+            cfg, err := Load()
+            if err != nil {
+                t.Fatalf("Load() error: %v", err)
+            }
+            tt.mutate(cfg)
+            err = cfg.Validate()
+            if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
+                t.Fatalf("Validate() error = %v, want %q", err, tt.wantErr)
+            }
+        })
+    }
+}
+
 func TestSoraCurlCFFISidecarDefaults(t *testing.T) {
     resetViperWithJWTSecret(t)
 
@@ -1112,3 +1273,53 @@ func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) {
         t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err)
     }
 }
+
+func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
+    resetViperWithJWTSecret(t)
+    cfg, err := Load()
+    if err != nil {
+        t.Fatalf("Load() error: %v", err)
+    }
+    if cfg.Gateway.UsageRecord.WorkerCount != 128 {
+        t.Fatalf("worker_count = %d, want 128", cfg.Gateway.UsageRecord.WorkerCount)
+    }
+    if cfg.Gateway.UsageRecord.QueueSize != 16384 {
+        t.Fatalf("queue_size = %d, want 16384", cfg.Gateway.UsageRecord.QueueSize)
+    }
+    if cfg.Gateway.UsageRecord.TaskTimeoutSeconds != 5 {
+        t.Fatalf("task_timeout_seconds = %d, want 5", cfg.Gateway.UsageRecord.TaskTimeoutSeconds)
+    }
+    if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySample {
+        t.Fatalf("overflow_policy = %s, want %s", cfg.Gateway.UsageRecord.OverflowPolicy, UsageRecordOverflowPolicySample)
+    }
+    if cfg.Gateway.UsageRecord.OverflowSamplePercent != 10 {
+        t.Fatalf("overflow_sample_percent = %d, want 10", cfg.Gateway.UsageRecord.OverflowSamplePercent)
+    }
+    if !cfg.Gateway.UsageRecord.AutoScaleEnabled {
+        t.Fatalf("auto_scale_enabled = false, want true")
+    }
+    if cfg.Gateway.UsageRecord.AutoScaleMinWorkers != 128 {
+        t.Fatalf("auto_scale_min_workers = %d, want 128", cfg.Gateway.UsageRecord.AutoScaleMinWorkers)
+    }
+    if cfg.Gateway.UsageRecord.AutoScaleMaxWorkers != 512 {
+        t.Fatalf("auto_scale_max_workers = %d, want 512", cfg.Gateway.UsageRecord.AutoScaleMaxWorkers)
+    }
+    if cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent != 70 {
+        t.Fatalf("auto_scale_up_queue_percent = %d, want 70", cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent)
+    }
+    if cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent != 15 {
+        t.Fatalf("auto_scale_down_queue_percent = %d, want 15", cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent)
+    }
+    if cfg.Gateway.UsageRecord.AutoScaleUpStep != 32 {
+        t.Fatalf("auto_scale_up_step = %d, want 32", cfg.Gateway.UsageRecord.AutoScaleUpStep)
+    }
+    if cfg.Gateway.UsageRecord.AutoScaleDownStep != 16 {
+        t.Fatalf("auto_scale_down_step = %d, want 16", cfg.Gateway.UsageRecord.AutoScaleDownStep)
+    }
+    if cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds != 3 {
+        t.Fatalf("auto_scale_check_interval_seconds = %d, want 3", cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds)
+    }
+    if cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds != 10 {
+        t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds)
+    }
+}
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index ca5ee9d7..bbe73689 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -37,6 +37,7 @@ type GatewayHandler struct {
     billingCacheService     *service.BillingCacheService
     usageService            *service.UsageService
     apiKeyService           *service.APIKeyService
+    usageRecordWorkerPool   *service.UsageRecordWorkerPool
     errorPassthroughService *service.ErrorPassthroughService
     concurrencyHelper       *ConcurrencyHelper
     maxAccountSwitches      int
@@ -54,6 +55,7 @@ func NewGatewayHandler(
     billingCacheService *service.BillingCacheService,
     usageService *service.UsageService,
     apiKeyService *service.APIKeyService,
+    usageRecordWorkerPool *service.UsageRecordWorkerPool,
     errorPassthroughService *service.ErrorPassthroughService,
     cfg *config.Config,
 ) *GatewayHandler {
@@ -77,6 +79,7 @@ func NewGatewayHandler(
     billingCacheService:     billingCacheService,
     usageService:            usageService,
     apiKeyService:           apiKeyService,
+    usageRecordWorkerPool:   usageRecordWorkerPool,
     errorPassthroughService: errorPassthroughService,
     concurrencyHelper:       NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
     maxAccountSwitches:      maxAccountSwitches,
@@ -431,19 +434,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
     userAgent := c.GetHeader("User-Agent")
     clientIP := ip.GetClientIP(c)
 
-    // Record usage asynchronously (subscription was already fetched at the top of the function)
-    go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
-        ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-        defer cancel()
+    // Usage records go through the bounded worker pool so the request hot path never spawns unbounded goroutines.
+    h.submitUsageRecordTask(func(ctx context.Context) {
         if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
             Result:            result,
             APIKey:            apiKey,
             User:              apiKey.User,
-            Account:           usedAccount,
+            Account:           account,
             Subscription:      subscription,
-            UserAgent:         ua,
+            UserAgent:         userAgent,
             IPAddress:         clientIP,
-            ForceCacheBilling: fcb,
+            ForceCacheBilling: forceCacheBilling,
             APIKeyService:     h.apiKeyService,
         }); err != nil {
             logger.L().With(
@@ -452,10 +453,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
                 zap.Int64("api_key_id", apiKey.ID),
                 zap.Any("group_id", apiKey.GroupID),
                 zap.String("model", reqModel),
-                zap.Int64("account_id", usedAccount.ID),
+                zap.Int64("account_id", account.ID),
             ).Error("gateway.record_usage_failed", zap.Error(err))
         }
-    }(result, account, userAgent, clientIP, forceCacheBilling)
+    })
     return
         }
     }
@@ -700,19 +701,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
     userAgent := c.GetHeader("User-Agent")
     clientIP := ip.GetClientIP(c)
 
-    // Record usage asynchronously (subscription was already fetched at the top of the function)
-    go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
-        ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-        defer cancel()
+    // Usage records go through the bounded worker pool so the request hot path never spawns unbounded goroutines.
+    h.submitUsageRecordTask(func(ctx context.Context) {
         if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
             Result:            result,
             APIKey:            currentAPIKey,
             User:              currentAPIKey.User,
-            Account:           usedAccount,
+            Account:           account,
             Subscription:      currentSubscription,
-            UserAgent:         ua,
+            UserAgent:         userAgent,
             IPAddress:         clientIP,
-            ForceCacheBilling: fcb,
+            ForceCacheBilling: forceCacheBilling,
             APIKeyService:     h.apiKeyService,
         }); err != nil {
             logger.L().With(
@@ -721,10 +720,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
                 zap.Int64("api_key_id", currentAPIKey.ID),
                 zap.Any("group_id", currentAPIKey.GroupID),
                 zap.String("model", reqModel),
-                zap.Int64("account_id", usedAccount.ID),
+                zap.Int64("account_id", account.ID),
             ).Error("gateway.record_usage_failed", zap.Error(err))
         }
-    }(result, account, userAgent, clientIP, forceCacheBilling)
+    })
     reqLog.Debug("gateway.request_completed",
         zap.Int64("account_id", account.ID),
         zap.Int("switch_count", switchCount),
@@ -1508,3 +1507,17 @@ func billingErrorDetails(err error) (status int, code, message string) {
     }
     return http.StatusForbidden, "billing_error", msg
 }
+
+func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
+    if task == nil {
+        return
+    }
+    if h.usageRecordWorkerPool != nil {
+        h.usageRecordWorkerPool.Submit(task)
+        return
+    }
+    // Fallback path: when no worker pool is injected, run synchronously rather than reverting to unbounded goroutines.
+    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+    defer cancel()
+    task(ctx)
+}
diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go
index 8b73aad7..86c2e4a4 100644
--- a/backend/internal/handler/gemini_v1beta_handler.go
+++ b/backend/internal/handler/gemini_v1beta_handler.go
@@ -11,7 +11,6 @@ import (
     "net/http"
     "regexp"
     "strings"
-    "time"
 
     "github.com/Wei-Shaw/sub2api/internal/domain"
     "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
@@ -519,22 +518,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
         }
     }
 
-    // 6) record usage async (Gemini applies long-context double billing)
-    go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string, fcb bool) {
-        ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-        defer cancel()
-
+    // Usage records go through the bounded worker pool so the request hot path never spawns unbounded goroutines.
+    h.submitUsageRecordTask(func(ctx context.Context) {
         if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
             Result:                result,
             APIKey:                apiKey,
             User:                  apiKey.User,
-            Account:               usedAccount,
+            Account:               account,
             Subscription:          subscription,
-            UserAgent:             ua,
-            IPAddress:             ip,
+            UserAgent:             userAgent,
+            IPAddress:             clientIP,
             LongContextThreshold:  200000, // Gemini 200K threshold
             LongContextMultiplier: 2.0,    // double-bill the portion above the threshold
-            ForceCacheBilling:     fcb,
+            ForceCacheBilling:     forceCacheBilling,
             APIKeyService:         h.apiKeyService,
         }); err != nil {
             logger.L().With(
@@ -543,10 +539,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
                 zap.Int64("api_key_id", apiKey.ID),
                 zap.Any("group_id", apiKey.GroupID),
                 zap.String("model", modelName),
-                zap.Int64("account_id", usedAccount.ID),
+                zap.Int64("account_id", account.ID),
             ).Error("gemini.record_usage_failed", zap.Error(err))
         }
-    }(result, account, userAgent, clientIP, forceCacheBilling)
+    })
 
     reqLog.Debug("gemini.request_completed",
         zap.Int64("account_id", account.ID),
         zap.Int("switch_count", switchCount),
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index f5db385b..50af684d 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -26,6 +26,7 @@ type OpenAIGatewayHandler struct {
     gatewayService          *service.OpenAIGatewayService
     billingCacheService     *service.BillingCacheService
     apiKeyService           *service.APIKeyService
+    usageRecordWorkerPool   *service.UsageRecordWorkerPool
     errorPassthroughService *service.ErrorPassthroughService
     concurrencyHelper       *ConcurrencyHelper
     maxAccountSwitches      int
@@ -37,6 +38,7 @@ func NewOpenAIGatewayHandler(
     concurrencyService *service.ConcurrencyService,
     billingCacheService *service.BillingCacheService,
    apiKeyService *service.APIKeyService,
+    usageRecordWorkerPool *service.UsageRecordWorkerPool,
     errorPassthroughService *service.ErrorPassthroughService,
     cfg *config.Config,
 ) *OpenAIGatewayHandler {
@@ -52,6 +54,7 @@ func NewOpenAIGatewayHandler(
     gatewayService:          gatewayService,
     billingCacheService:     billingCacheService,
     apiKeyService:           apiKeyService,
+    usageRecordWorkerPool:   usageRecordWorkerPool,
     errorPassthroughService: errorPassthroughService,
     concurrencyHelper:       NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
     maxAccountSwitches:      maxAccountSwitches,
@@ -378,18 +381,16 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
     userAgent := c.GetHeader("User-Agent")
     clientIP := ip.GetClientIP(c)
 
-    // Async record usage
-    go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) {
-        ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-        defer cancel()
+    // Usage records go through the bounded worker pool so the request hot path never spawns unbounded goroutines.
+    h.submitUsageRecordTask(func(ctx context.Context) {
         if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
             Result:        result,
             APIKey:        apiKey,
             User:          apiKey.User,
-            Account:       usedAccount,
+            Account:       account,
             Subscription:  subscription,
-            UserAgent:     ua,
-            IPAddress:     ip,
+            UserAgent:     userAgent,
+            IPAddress:     clientIP,
             APIKeyService: h.apiKeyService,
         }); err != nil {
             logger.L().With(
@@ -398,10 +399,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
                 zap.Int64("api_key_id", apiKey.ID),
                 zap.Any("group_id", apiKey.GroupID),
                 zap.String("model", reqModel),
-                zap.Int64("account_id", usedAccount.ID),
+                zap.Int64("account_id", account.ID),
             ).Error("openai.record_usage_failed", zap.Error(err))
         }
-    }(result, account, userAgent, clientIP)
+    })
 
     reqLog.Debug("openai.request_completed",
         zap.Int64("account_id", account.ID),
         zap.Int("switch_count", switchCount),
@@ -432,6 +433,20 @@ func getContextInt64(c *gin.Context, key string) (int64, bool) {
     }
 }
 
+func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
+    if task == nil {
+        return
+    }
+    if h.usageRecordWorkerPool != nil {
+        h.usageRecordWorkerPool.Submit(task)
+        return
+    }
+    // Fallback path: when no worker pool is injected, run synchronously rather than reverting to unbounded goroutines.
+    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+    defer cancel()
+    task(ctx)
+}
+
 // handleConcurrencyError handles concurrency-related errors with proper 429 response
 func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
     h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go
index b958a133..ab3a3f14 100644
--- a/backend/internal/handler/sora_gateway_handler.go
+++ b/backend/internal/handler/sora_gateway_handler.go
@@ -31,15 +31,16 @@ import (
 
 // SoraGatewayHandler handles Sora chat completions requests
 type SoraGatewayHandler struct {
-    gatewayService      *service.GatewayService
-    soraGatewayService  *service.SoraGatewayService
-    billingCacheService *service.BillingCacheService
-    concurrencyHelper   *ConcurrencyHelper
-    maxAccountSwitches  int
-    streamMode          string
-    soraTLSEnabled      bool
-    soraMediaSigningKey string
-    soraMediaRoot       string
+    gatewayService        *service.GatewayService
+    soraGatewayService    *service.SoraGatewayService
+    billingCacheService   *service.BillingCacheService
+    usageRecordWorkerPool *service.UsageRecordWorkerPool
+    concurrencyHelper     *ConcurrencyHelper
+    maxAccountSwitches    int
+    streamMode            string
+    soraTLSEnabled        bool
+    soraMediaSigningKey   string
+    soraMediaRoot         string
 }
 
 // NewSoraGatewayHandler creates a new SoraGatewayHandler
@@ -48,6 +49,7 @@ func NewSoraGatewayHandler(
     soraGatewayService *service.SoraGatewayService,
     concurrencyService *service.ConcurrencyService,
     billingCacheService *service.BillingCacheService,
+    usageRecordWorkerPool *service.UsageRecordWorkerPool,
     cfg *config.Config,
 ) *SoraGatewayHandler {
     pingInterval := time.Duration(0)
@@ -71,15 +73,16 @@ func NewSoraGatewayHandler(
         }
     }
     return &SoraGatewayHandler{
-        gatewayService:      gatewayService,
-        soraGatewayService:  soraGatewayService,
-        billingCacheService: billingCacheService,
-        concurrencyHelper:   NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
-        maxAccountSwitches:  maxAccountSwitches,
-        streamMode:          strings.ToLower(streamMode),
-        soraTLSEnabled:      soraTLSEnabled,
-        soraMediaSigningKey: signKey,
-        soraMediaRoot:       mediaRoot,
+        gatewayService:        gatewayService,
+        soraGatewayService:    soraGatewayService,
+        billingCacheService:   billingCacheService,
+        usageRecordWorkerPool: usageRecordWorkerPool,
+        concurrencyHelper:     NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
+        maxAccountSwitches:    maxAccountSwitches,
+        streamMode:            strings.ToLower(streamMode),
+        soraTLSEnabled:        soraTLSEnabled,
+        soraMediaSigningKey:   signKey,
+        soraMediaRoot:         mediaRoot,
     }
 }
 
@@ -397,17 +400,16 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
     userAgent := c.GetHeader("User-Agent")
     clientIP := ip.GetClientIP(c)
 
-    go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
-        ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-        defer cancel()
+    // Usage records go through the bounded worker pool so the request hot path never spawns unbounded goroutines.
+    h.submitUsageRecordTask(func(ctx context.Context) {
         if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
             Result:       result,
             APIKey:       apiKey,
             User:         apiKey.User,
-            Account:      usedAccount,
+            Account:      account,
             Subscription: subscription,
-            UserAgent:    ua,
-            IPAddress:    ip,
+            UserAgent:    userAgent,
+            IPAddress:    clientIP,
         }); err != nil {
             logger.L().With(
                 zap.String("component", "handler.sora_gateway.chat_completions"),
@@ -415,10 +417,10 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
                 zap.Int64("api_key_id", apiKey.ID),
                 zap.Any("group_id", apiKey.GroupID),
                 zap.String("model", reqModel),
-                zap.Int64("account_id", usedAccount.ID),
+                zap.Int64("account_id", account.ID),
             ).Error("sora.record_usage_failed", zap.Error(err))
         }
-    }(result, account, userAgent, clientIP)
+    })
 
     reqLog.Debug("sora.request_completed",
         zap.Int64("account_id", account.ID),
         zap.Int64("proxy_id", proxyID),
@@ -448,6 +450,20 @@ func generateOpenAISessionHash(c *gin.Context, body []byte) string {
     return hex.EncodeToString(hash[:])
 }
 
+func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
+    if task == nil {
+        return
+    }
+    if h.usageRecordWorkerPool != nil {
+        h.usageRecordWorkerPool.Submit(task)
+        return
+    }
+    // Fallback path: when no worker pool is injected, run synchronously rather than reverting to unbounded goroutines.
+    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+    defer cancel()
+    task(ctx)
+}
+
 func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
     h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
         fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go
index 3f6ef10e..cc792350 100644
--- a/backend/internal/handler/sora_gateway_handler_test.go
+++ b/backend/internal/handler/sora_gateway_handler_test.go
@@ -432,7 +432,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
     soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
     soraGatewayService := service.NewSoraGatewayService(soraClient, nil, nil, cfg)
 
-    handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, cfg)
+    handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, nil, cfg)
 
     rec := httptest.NewRecorder()
     c, _ := gin.CreateTestContext(rec)
diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go
new file mode 100644
index 00000000..df759f44
--- /dev/null
+++ b/backend/internal/handler/usage_record_submit_task_test.go
@@ -0,0 +1,136 @@
+package handler
+
+import (
+    "context"
+    "sync/atomic"
+    "testing"
+    "time"
+
+    "github.com/Wei-Shaw/sub2api/internal/service"
+    "github.com/stretchr/testify/require"
+)
+
+func newUsageRecordTestPool(t *testing.T) *service.UsageRecordWorkerPool {
+    t.Helper()
+    pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
+        WorkerCount:           1,
+        QueueSize:             8,
+        TaskTimeout:           time.Second,
+        OverflowPolicy:        "drop",
+        OverflowSamplePercent: 0,
+        AutoScaleEnabled:      false,
+    })
+    t.Cleanup(pool.Stop)
+    return pool
+}
+
+func TestGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
+    pool := newUsageRecordTestPool(t)
+    h := &GatewayHandler{usageRecordWorkerPool: pool}
+
+    done := make(chan struct{})
+    h.submitUsageRecordTask(func(ctx context.Context) {
+        close(done)
+    })
+
+    select {
+    case <-done:
+    case <-time.After(time.Second):
+        t.Fatal("task not executed")
+    }
+}
+
+func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
+    h := &GatewayHandler{}
+    var called atomic.Bool
+
+    h.submitUsageRecordTask(func(ctx context.Context) {
+        if _, ok := ctx.Deadline(); !ok {
+            t.Fatal("expected deadline in fallback context")
+        }
+        called.Store(true)
+    })
+
+    require.True(t, called.Load())
+}
+
+func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
+    h := &GatewayHandler{}
+    require.NotPanics(t, func() {
+        h.submitUsageRecordTask(nil)
+    })
+}
+
+func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
+    pool := newUsageRecordTestPool(t)
+    h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
+
+    done := make(chan struct{})
+    h.submitUsageRecordTask(func(ctx context.Context) {
+        close(done)
+    })
+
+    select {
+    case <-done:
+    case <-time.After(time.Second):
+        t.Fatal("task not executed")
+    }
+}
+
+func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
+    h := &OpenAIGatewayHandler{}
+    var called atomic.Bool
+
+    h.submitUsageRecordTask(func(ctx context.Context) {
+        if _, ok := ctx.Deadline(); !ok {
+            t.Fatal("expected deadline in fallback context")
+        }
+        called.Store(true)
+    })
+
+    require.True(t, called.Load())
+}
+
+func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
+    h := &OpenAIGatewayHandler{}
+    require.NotPanics(t, func() {
+        h.submitUsageRecordTask(nil)
+    })
+}
+
+func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
+    pool := newUsageRecordTestPool(t)
+    h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
+
+    done := make(chan struct{})
+    h.submitUsageRecordTask(func(ctx context.Context) {
+        close(done)
+    })
+
+    select {
+    case <-done:
+    case <-time.After(time.Second):
+        t.Fatal("task not executed")
+    }
+}
+
+func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
+    h := &SoraGatewayHandler{}
+    var called atomic.Bool
+
+    h.submitUsageRecordTask(func(ctx context.Context) {
+        if _, ok := ctx.Deadline(); !ok {
+            t.Fatal("expected deadline in fallback context")
+        }
+        called.Store(true)
+    })
+
+    require.True(t, called.Load())
+}
+
+func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
+    h := &SoraGatewayHandler{}
+    require.NotPanics(t, func() {
+        h.submitUsageRecordTask(nil)
+    })
+}
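The service file that follows defines the pool itself. Its `Submit` contract is easiest to read from the caller's side; a sketch using only the types declared below (illustrative only — `service` and `config` are internal packages, so this shows the call contract rather than a standalone program, and the sizes and task body are placeholders):

```go
package main

import (
	"context"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/Wei-Shaw/sub2api/internal/service"
)

func main() {
	pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
		WorkerCount:           2,
		QueueSize:             64,
		TaskTimeout:           5 * time.Second,
		OverflowPolicy:        config.UsageRecordOverflowPolicySample,
		OverflowSamplePercent: 10,
	})
	defer pool.Stop()

	switch pool.Submit(func(ctx context.Context) { /* write the usage record here */ }) {
	case service.UsageRecordSubmitModeEnqueued:
		// accepted by the bounded queue (the normal case)
	case service.UsageRecordSubmitModeSync:
		// queue full; this call ran the task inline (sampled write-back)
	case service.UsageRecordSubmitModeDropped:
		// queue full (or pool stopped); the record was discarded
	}
}
```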
diff --git a/backend/internal/service/usage_record_worker_pool.go b/backend/internal/service/usage_record_worker_pool.go
new file mode 100644
index 00000000..5da0b890
--- /dev/null
+++ b/backend/internal/service/usage_record_worker_pool.go
@@ -0,0 +1,496 @@
+package service
+
+import (
+    "context"
+    "fmt"
+    "strings"
+    "sync"
+    "sync/atomic"
+    "time"
+
+    "github.com/Wei-Shaw/sub2api/internal/config"
+    "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+    "github.com/alitto/pond/v2"
+    "go.uber.org/zap"
+)
+
+const (
+    defaultUsageRecordWorkerCount          = 128
+    defaultUsageRecordQueueSize            = 16384
+    defaultUsageRecordTaskTimeoutSeconds   = 5
+    defaultUsageRecordOverflowPolicy       = config.UsageRecordOverflowPolicySample
+    defaultUsageRecordOverflowSampleRatio  = 10
+    defaultUsageRecordAutoScaleEnabled     = true
+    defaultUsageRecordAutoScaleMinWorkers  = 128
+    defaultUsageRecordAutoScaleMaxWorkers  = 512
+    defaultUsageRecordAutoScaleUpPercent   = 70
+    defaultUsageRecordAutoScaleDownPercent = 15
+    defaultUsageRecordAutoScaleUpStep      = 32
+    defaultUsageRecordAutoScaleDownStep    = 16
+    defaultUsageRecordAutoScaleInterval    = 3 * time.Second
+    defaultUsageRecordAutoScaleCooldown    = 10 * time.Second
+    usageRecordDropLogInterval             = 5 * time.Second
+)
+
+// UsageRecordTask is a task submitted to the usage-record pool.
+// Task implementations handle their own business-error logging; the pool is only
+// responsible for scheduling and timeout control.
+type UsageRecordTask func(ctx context.Context)
+
+// UsageRecordSubmitMode describes the outcome of a task submission.
+type UsageRecordSubmitMode string
+
+const (
+    UsageRecordSubmitModeEnqueued UsageRecordSubmitMode = "enqueued"
+    UsageRecordSubmitModeDropped  UsageRecordSubmitMode = "dropped"
+    UsageRecordSubmitModeSync     UsageRecordSubmitMode = "sync_fallback"
+)
+
+// UsageRecordWorkerPoolOptions configures the usage-record pool.
+type UsageRecordWorkerPoolOptions struct {
+    WorkerCount           int
+    QueueSize             int
+    TaskTimeout           time.Duration
+    OverflowPolicy        string
+    OverflowSamplePercent int
+    AutoScaleEnabled      bool
+    AutoScaleMinWorkers   int
+    AutoScaleMaxWorkers   int
+    AutoScaleUpPercent    int
+    AutoScaleDownPercent  int
+    AutoScaleUpStep       int
+    AutoScaleDownStep     int
+    AutoScaleInterval     time.Duration
+    AutoScaleCooldown     time.Duration
+}
+
+// UsageRecordWorkerPoolStats reports the pool's runtime statistics.
+type UsageRecordWorkerPoolStats struct {
+    MaxConcurrency     int
+    RunningWorkers     int64
+    WaitingTasks       uint64
+    SubmittedTasks     uint64
+    CompletedTasks     uint64
+    SuccessfulTasks    uint64
+    FailedTasks        uint64
+    DroppedTasks       uint64
+    DroppedQueueFull   uint64
+    DroppedPoolStopped uint64
+    SyncFallbackTasks  uint64
+}
+
+// UsageRecordWorkerPool is an asynchronous executor built on a bounded queue
+// plus a fixed set of workers. It replaces raw goroutines on the request path,
+// so records cannot pile up without bound under high concurrency.
+type UsageRecordWorkerPool struct {
+    pool                  pond.Pool
+    taskTimeout           time.Duration
+    overflowPolicy        string
+    overflowSamplePercent int
+    overflowCounter       atomic.Uint64
+    droppedQueueFull      atomic.Uint64
+    droppedPoolStopped    atomic.Uint64
+    syncFallback          atomic.Uint64
+    lastDropLogNanos      atomic.Int64
+    autoScaleEnabled      bool
+    autoScaleMinWorkers   int
+    autoScaleMaxWorkers   int
+    autoScaleUpPercent    int
+    autoScaleDownPercent  int
+    autoScaleUpStep       int
+    autoScaleDownStep     int
+    autoScaleInterval     time.Duration
+    autoScaleCooldown     time.Duration
+    lastScaleNanos        atomic.Int64
+    autoScaleCancel       context.CancelFunc
+    lifecycleWg           sync.WaitGroup
+    stopOnce              sync.Once
+}
+
+// NewUsageRecordWorkerPool builds a usage-record pool from the configuration.
+func NewUsageRecordWorkerPool(cfg *config.Config) *UsageRecordWorkerPool {
+    opts := usageRecordPoolOptionsFromConfig(cfg)
+    return NewUsageRecordWorkerPoolWithOptions(opts)
+}
+
+// NewUsageRecordWorkerPoolWithOptions builds a usage-record pool from the given options.
+func NewUsageRecordWorkerPoolWithOptions(opts UsageRecordWorkerPoolOptions) *UsageRecordWorkerPool {
+    opts = normalizeUsageRecordPoolOptions(opts)
+
+    p := &UsageRecordWorkerPool{
+        taskTimeout:           opts.TaskTimeout,
+        overflowPolicy:        opts.OverflowPolicy,
+        overflowSamplePercent: opts.OverflowSamplePercent,
+        autoScaleEnabled:      opts.AutoScaleEnabled,
+        autoScaleMinWorkers:   opts.AutoScaleMinWorkers,
+        autoScaleMaxWorkers:   opts.AutoScaleMaxWorkers,
+        autoScaleUpPercent:    opts.AutoScaleUpPercent,
+        autoScaleDownPercent:  opts.AutoScaleDownPercent,
+        autoScaleUpStep:       opts.AutoScaleUpStep,
+        autoScaleDownStep:     opts.AutoScaleDownStep,
+        autoScaleInterval:     opts.AutoScaleInterval,
+        autoScaleCooldown:     opts.AutoScaleCooldown,
+    }
+
+    p.pool = pond.NewPool(
+        opts.WorkerCount,
+        pond.WithQueueSize(opts.QueueSize),
+    )
+    if p.autoScaleEnabled {
+        p.startAutoScaler()
+    }
+    return p
+}
+
+// Submit submits a usage-record task.
+// When submission fails (queue full), the overflow policy selects the degradation
+// strategy: drop/sample/sync.
+func (p *UsageRecordWorkerPool) Submit(task UsageRecordTask) UsageRecordSubmitMode {
+    if p == nil || task == nil {
+        return UsageRecordSubmitModeDropped
+    }
+    if p.pool == nil || p.pool.Stopped() {
+        p.droppedPoolStopped.Add(1)
+        p.logDrop("stopped")
+        return UsageRecordSubmitModeDropped
+    }
+
+    _, ok := p.pool.TrySubmit(func() {
+        p.execute(task)
+    })
+    if ok {
+        return UsageRecordSubmitModeEnqueued
+    }
+
+    if p.pool.Stopped() {
+        p.droppedPoolStopped.Add(1)
+        p.logDrop("stopped")
+        return UsageRecordSubmitModeDropped
+    }
+
+    switch p.overflowPolicy {
+    case config.UsageRecordOverflowPolicySync:
+        p.syncFallback.Add(1)
+        p.execute(task)
+        return UsageRecordSubmitModeSync
+    case config.UsageRecordOverflowPolicySample:
+        if p.shouldSyncFallback() {
+            p.syncFallback.Add(1)
+            p.execute(task)
+            return UsageRecordSubmitModeSync
+        }
+    }
+
+    p.droppedQueueFull.Add(1)
+    p.logDrop("full")
+    return UsageRecordSubmitModeDropped
+}
+
+// Stats returns the current pool state and counters.
+func (p *UsageRecordWorkerPool) Stats() UsageRecordWorkerPoolStats {
+    if p == nil || p.pool == nil {
+        return UsageRecordWorkerPoolStats{}
+    }
+    return UsageRecordWorkerPoolStats{
+        MaxConcurrency:     p.pool.MaxConcurrency(),
+        RunningWorkers:     p.pool.RunningWorkers(),
+        WaitingTasks:       p.pool.WaitingTasks(),
+        SubmittedTasks:     p.pool.SubmittedTasks(),
+        CompletedTasks:     p.pool.CompletedTasks(),
+        SuccessfulTasks:    p.pool.SuccessfulTasks(),
+        FailedTasks:        p.pool.FailedTasks(),
+        DroppedTasks:       p.pool.DroppedTasks(),
+        DroppedQueueFull:   p.droppedQueueFull.Load(),
+        DroppedPoolStopped: p.droppedPoolStopped.Load(),
+        SyncFallbackTasks:  p.syncFallback.Load(),
+    }
+}
+
+// Stop stops the pool and waits for queued tasks to finish.
+func (p *UsageRecordWorkerPool) Stop() {
+    if p == nil || p.pool == nil {
+        return
+    }
+    p.stopOnce.Do(func() {
+        if p.autoScaleCancel != nil {
+            p.autoScaleCancel()
+        }
+        p.lifecycleWg.Wait()
+        p.pool.StopAndWait()
+    })
+}
+
+func (p *UsageRecordWorkerPool) startAutoScaler() {
+    ctx, cancel := context.WithCancel(context.Background())
+    p.autoScaleCancel = cancel
+
+    p.lifecycleWg.Add(1)
+    go func() {
+        defer p.lifecycleWg.Done()
+
+        ticker := time.NewTicker(p.autoScaleInterval)
+        defer ticker.Stop()
+
+        for {
+            select {
+            case <-ctx.Done():
+                return
+            case <-ticker.C:
+                p.autoScaleTick()
+            }
+        }
+    }()
+}
+
+func (p *UsageRecordWorkerPool) autoScaleTick() {
+    if p == nil || p.pool == nil || p.pool.Stopped() {
+        return
+    }
+    queueSize := p.pool.QueueSize()
+    if queueSize <= 0 {
+        return
+    }
+    current := p.pool.MaxConcurrency()
+    waiting := int(p.pool.WaitingTasks())
+    running := int(p.pool.RunningWorkers())
+    if current <= 0 || waiting < 0 {
+        return
+    }
+    queuePercent := waiting * 100 / queueSize
+    runningPercent := 0
+    if current > 0 {
+        runningPercent = running * 100 / current
+    }
+
+    now := time.Now()
+    lastScaleNanos := p.lastScaleNanos.Load()
+    if lastScaleNanos > 0 && now.Sub(time.Unix(0, lastScaleNanos)) < p.autoScaleCooldown {
+        return
+    }
+
+    // Scale-up takes priority: once queue occupancy crosses the threshold, raise the concurrency ceiling by the configured step.
+    if queuePercent >= p.autoScaleUpPercent && current < p.autoScaleMaxWorkers {
+        target := current + p.autoScaleUpStep
+        if target > p.autoScaleMaxWorkers {
+            target = p.autoScaleMaxWorkers
+        }
+        p.resizePool(current, target, queuePercent, waiting, runningPercent, queueSize, "scale_up")
+        return
+    }
+
+    // Scale-down: shrink only when the queue is empty and running utilization is low, so a
+    // momentarily empty queue under heavy load cannot trigger a spurious shrink and oscillation.
+    if queuePercent <= p.autoScaleDownPercent && waiting == 0 &&
+        runningPercent <= p.autoScaleDownPercent &&
+        current > p.autoScaleMinWorkers {
+        target := current - p.autoScaleDownStep
+        if target < p.autoScaleMinWorkers {
+            target = p.autoScaleMinWorkers
+        }
+        p.resizePool(current, target, queuePercent, waiting, runningPercent, queueSize, "scale_down")
+    }
+}
+
+func (p *UsageRecordWorkerPool) resizePool(current, target, queuePercent, waiting, runningPercent, queueSize int, action string) {
+    if target == current {
+        return
+    }
+    p.pool.Resize(target)
+    p.lastScaleNanos.Store(time.Now().UnixNano())
+
+    logger.L().With(
+        zap.String("component", "service.usage_record_worker_pool"),
+        zap.String("action", action),
+        zap.Int("from_workers", current),
+        zap.Int("to_workers", target),
+        zap.Int("queue_percent", queuePercent),
+        zap.Int("waiting_tasks", waiting),
+        zap.Int("running_percent", runningPercent),
+        zap.Int("queue_size", queueSize),
+    ).Info("usage_record.auto_scale")
+}
+
+func (p *UsageRecordWorkerPool) shouldSyncFallback() bool {
+    if p.overflowSamplePercent <= 0 {
+        return false
+    }
+    n := p.overflowCounter.Add(1)
+    return int((n-1)%100) < p.overflowSamplePercent
+}
+
+func (p *UsageRecordWorkerPool) execute(task UsageRecordTask) {
+    ctx, cancel := context.WithTimeout(context.Background(), p.taskTimeout)
+    defer cancel()
+
+    defer func() {
+        if recovered := recover(); recovered != nil {
+            logger.L().With(
+                zap.String("component", "service.usage_record_worker_pool"),
+                zap.Any("panic", recovered),
+            ).Error("usage_record.task_panic")
+        }
+    }()
+
+    task(ctx)
+}
+
+func (p *UsageRecordWorkerPool) logDrop(reason string) {
+    now := time.Now().UnixNano()
+    last := p.lastDropLogNanos.Load()
+    if now-last < int64(usageRecordDropLogInterval) {
+        return
+    }
+    if !p.lastDropLogNanos.CompareAndSwap(last, now) {
+        return
+    }
+
+    stats := p.Stats()
+    logger.L().With(
+        zap.String("component", "service.usage_record_worker_pool"),
+        zap.String("reason", reason),
+        zap.String("overflow_policy", p.overflowPolicy),
+        zap.Int64("running_workers", stats.RunningWorkers),
+        zap.Uint64("waiting_tasks", stats.WaitingTasks),
+        zap.Uint64("dropped_queue_full", stats.DroppedQueueFull),
+        zap.Uint64("dropped_pool_stopped", stats.DroppedPoolStopped),
+        zap.Uint64("sync_fallback_tasks", stats.SyncFallbackTasks),
+    ).Warn("usage_record.task_dropped")
+}
+
+func usageRecordPoolOptionsFromConfig(cfg *config.Config) UsageRecordWorkerPoolOptions {
+    opts := UsageRecordWorkerPoolOptions{
+        WorkerCount:           defaultUsageRecordWorkerCount,
+        QueueSize:             defaultUsageRecordQueueSize,
+        TaskTimeout:           time.Duration(defaultUsageRecordTaskTimeoutSeconds) * time.Second,
+        OverflowPolicy:        defaultUsageRecordOverflowPolicy,
+        OverflowSamplePercent: defaultUsageRecordOverflowSampleRatio,
+        AutoScaleEnabled:      defaultUsageRecordAutoScaleEnabled,
+        AutoScaleMinWorkers:   defaultUsageRecordAutoScaleMinWorkers,
+        AutoScaleMaxWorkers:   defaultUsageRecordAutoScaleMaxWorkers,
+        AutoScaleUpPercent:    defaultUsageRecordAutoScaleUpPercent,
+        AutoScaleDownPercent:  defaultUsageRecordAutoScaleDownPercent,
+        AutoScaleUpStep:       defaultUsageRecordAutoScaleUpStep,
+        AutoScaleDownStep:     defaultUsageRecordAutoScaleDownStep,
+        AutoScaleInterval:     defaultUsageRecordAutoScaleInterval,
+        AutoScaleCooldown:     defaultUsageRecordAutoScaleCooldown,
+    }
+    if cfg == nil {
+        return opts
+    }
+    if cfg.Gateway.UsageRecord.WorkerCount > 0 {
+        opts.WorkerCount = cfg.Gateway.UsageRecord.WorkerCount
+    }
+    if cfg.Gateway.UsageRecord.QueueSize > 0 {
+        opts.QueueSize = cfg.Gateway.UsageRecord.QueueSize
+    }
+    if cfg.Gateway.UsageRecord.TaskTimeoutSeconds > 0 {
+        opts.TaskTimeout = time.Duration(cfg.Gateway.UsageRecord.TaskTimeoutSeconds) * time.Second
+    }
+    if policy := strings.TrimSpace(strings.ToLower(cfg.Gateway.UsageRecord.OverflowPolicy)); policy != "" {
+        opts.OverflowPolicy = policy
+    }
+    if cfg.Gateway.UsageRecord.OverflowSamplePercent >= 0 {
+        opts.OverflowSamplePercent = cfg.Gateway.UsageRecord.OverflowSamplePercent
+    }
+    opts.AutoScaleEnabled = cfg.Gateway.UsageRecord.AutoScaleEnabled
+    if cfg.Gateway.UsageRecord.AutoScaleMinWorkers > 0 {
+        opts.AutoScaleMinWorkers = cfg.Gateway.UsageRecord.AutoScaleMinWorkers
+    }
+    if cfg.Gateway.UsageRecord.AutoScaleMaxWorkers > 0 {
+        opts.AutoScaleMaxWorkers = cfg.Gateway.UsageRecord.AutoScaleMaxWorkers
+    }
+    if cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent > 0 {
+        opts.AutoScaleUpPercent = cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent
+    }
+    if cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent >= 0 {
+        opts.AutoScaleDownPercent = cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent
+    }
+    if cfg.Gateway.UsageRecord.AutoScaleUpStep > 0 {
+        opts.AutoScaleUpStep = cfg.Gateway.UsageRecord.AutoScaleUpStep
+    }
+    if cfg.Gateway.UsageRecord.AutoScaleDownStep > 0 {
+        opts.AutoScaleDownStep = cfg.Gateway.UsageRecord.AutoScaleDownStep
+    }
+    if cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds > 0 {
+        opts.AutoScaleInterval = time.Duration(cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds) * time.Second
+    }
+    if cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds >= 0 {
+        opts.AutoScaleCooldown = time.Duration(cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds) * time.Second
+    }
+    return normalizeUsageRecordPoolOptions(opts)
+}
+
+func normalizeUsageRecordPoolOptions(opts UsageRecordWorkerPoolOptions) UsageRecordWorkerPoolOptions {
+    if opts.WorkerCount <= 0 {
+        opts.WorkerCount = defaultUsageRecordWorkerCount
+    }
+    if opts.QueueSize <= 0 {
+        opts.QueueSize = defaultUsageRecordQueueSize
+    }
+    if opts.TaskTimeout <= 0 {
+        opts.TaskTimeout = time.Duration(defaultUsageRecordTaskTimeoutSeconds) * time.Second
+    }
+    switch strings.ToLower(strings.TrimSpace(opts.OverflowPolicy)) {
+    case config.UsageRecordOverflowPolicyDrop,
+        config.UsageRecordOverflowPolicySample,
+        config.UsageRecordOverflowPolicySync:
+        opts.OverflowPolicy = strings.ToLower(strings.TrimSpace(opts.OverflowPolicy))
+    default:
+        opts.OverflowPolicy = defaultUsageRecordOverflowPolicy
+    }
+    if opts.OverflowSamplePercent < 0 {
+        opts.OverflowSamplePercent = 0
+    }
+    if opts.OverflowSamplePercent > 100 {
+        opts.OverflowSamplePercent = 100
+    }
+    if opts.OverflowPolicy == config.UsageRecordOverflowPolicySample && opts.OverflowSamplePercent == 0 {
+        opts.OverflowSamplePercent = defaultUsageRecordOverflowSampleRatio
+    }
+    if opts.AutoScaleEnabled {
+        if opts.AutoScaleMinWorkers <= 0 {
+            opts.AutoScaleMinWorkers = defaultUsageRecordAutoScaleMinWorkers
+        }
+        if opts.AutoScaleMaxWorkers <= 0 {
+            opts.AutoScaleMaxWorkers = defaultUsageRecordAutoScaleMaxWorkers
+        }
+        if opts.AutoScaleMaxWorkers < opts.AutoScaleMinWorkers {
+            opts.AutoScaleMaxWorkers = opts.AutoScaleMinWorkers
+        }
+        if opts.WorkerCount < opts.AutoScaleMinWorkers {
+            opts.WorkerCount = opts.AutoScaleMinWorkers
+        }
+        if opts.WorkerCount > opts.AutoScaleMaxWorkers {
+            opts.WorkerCount = opts.AutoScaleMaxWorkers
+        }
+        if opts.AutoScaleUpPercent <= 0 || opts.AutoScaleUpPercent > 100 {
+            opts.AutoScaleUpPercent = defaultUsageRecordAutoScaleUpPercent
+        }
+        if opts.AutoScaleDownPercent < 0 || opts.AutoScaleDownPercent >= 100 {
+            opts.AutoScaleDownPercent = defaultUsageRecordAutoScaleDownPercent
+        }
+        if opts.AutoScaleDownPercent >= opts.AutoScaleUpPercent {
+            opts.AutoScaleDownPercent = max(0, opts.AutoScaleUpPercent/2)
+        }
+        if opts.AutoScaleUpStep <= 0 {
+            opts.AutoScaleUpStep = defaultUsageRecordAutoScaleUpStep
+        }
+        if opts.AutoScaleDownStep <= 0 {
+            opts.AutoScaleDownStep = defaultUsageRecordAutoScaleDownStep
+        }
+        if opts.AutoScaleInterval <= 0 {
+            opts.AutoScaleInterval = defaultUsageRecordAutoScaleInterval
+        }
+        if opts.AutoScaleCooldown < 0 {
+            opts.AutoScaleCooldown = defaultUsageRecordAutoScaleCooldown
+        }
+    } else {
+        opts.AutoScaleMinWorkers = opts.WorkerCount
+        opts.AutoScaleMaxWorkers = opts.WorkerCount
+    }
+    return opts
+}
+
+func (m UsageRecordSubmitMode) String() string {
+    return string(m)
+}
+
+func (s UsageRecordWorkerPoolStats) String() string {
+    return fmt.Sprintf("running=%d waiting=%d submitted=%d dropped=%d", s.RunningWorkers, s.WaitingTasks, s.SubmittedTasks, s.DroppedTasks)
+}
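To make the scaler's integer math concrete, here is a small runnable sketch of a hypothetical tick under the defaults (the numbers are illustrative, not from the source):

```go
package main

import "fmt"

func main() {
	// Hypothetical tick under the defaults: queue_size=16384, up threshold 70%.
	queueSize, waiting := 16384, 12000
	queuePercent := waiting * 100 / queueSize // integer division, as in autoScaleTick
	fmt.Println(queuePercent, queuePercent >= 70)
	// Prints "73 true": MaxConcurrency grows by auto_scale_up_step (32), capped
	// at auto_scale_max_workers (512); the 10s cooldown then suppresses further
	// resizes. A shrink instead requires waiting == 0 AND running utilization <= 15%.
}
```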
+func TestUsageRecordWorkerPool_OverflowSync(t *testing.T) {
+	pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
+		WorkerCount:           1,
+		QueueSize:             1,
+		TaskTimeout:           time.Second,
+		OverflowPolicy:        config.UsageRecordOverflowPolicySync,
+		OverflowSamplePercent: 0,
+	})
+	t.Cleanup(pool.Stop)
+
+	block := make(chan struct{})
+	started := make(chan struct{})
+	secondDone := make(chan struct{})
+	var syncExecuted atomic.Bool
+
+	require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
+		close(started)
+		<-block
+	}))
+	<-started
+
+	require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
+		close(secondDone)
+	}))
+
+	mode := pool.Submit(func(ctx context.Context) {
+		syncExecuted.Store(true)
+	})
+	require.Equal(t, UsageRecordSubmitModeSync, mode)
+	require.True(t, syncExecuted.Load())
+
+	close(block)
+	select {
+	case <-secondDone:
+	case <-time.After(time.Second):
+		t.Fatal("queued task not executed")
+	}
+
+	require.Eventually(t, func() bool {
+		return pool.Stats().SyncFallbackTasks >= 1
+	}, time.Second, 10*time.Millisecond)
+}
+
+func TestUsageRecordWorkerPool_OverflowSample(t *testing.T) {
+	pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
+		WorkerCount:           1,
+		QueueSize:             1,
+		TaskTimeout:           time.Second,
+		OverflowPolicy:        config.UsageRecordOverflowPolicySample,
+		OverflowSamplePercent: 1,
+	})
+	t.Cleanup(pool.Stop)
+
+	block := make(chan struct{})
+	started := make(chan struct{})
+	secondDone := make(chan struct{})
+	var syncExecuted atomic.Bool
+
+	require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
+		close(started)
+		<-block
+	}))
+	<-started
+
+	require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
+		close(secondDone)
+	}))
+
+	firstOverflow := pool.Submit(func(ctx context.Context) {
+		syncExecuted.Store(true)
+	})
+	require.Equal(t, UsageRecordSubmitModeSync, firstOverflow)
+	require.True(t, syncExecuted.Load())
+
+	secondOverflow := pool.Submit(func(ctx context.Context) {})
+	require.Equal(t, UsageRecordSubmitModeDropped, secondOverflow)
+
+	close(block)
+	select {
+	case <-secondDone:
+	case <-time.After(time.Second):
+		t.Fatal("queued task not executed")
+	}
+
+	require.Eventually(t, func() bool {
+		stats := pool.Stats()
+		return stats.SyncFallbackTasks >= 1 && stats.DroppedQueueFull >= 1
+	}, time.Second, 10*time.Millisecond)
+}
+
+func TestUsageRecordWorkerPool_SubmitAfterStop(t *testing.T) {
+	pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
+		WorkerCount:           1,
+		QueueSize:             1,
+		TaskTimeout:           time.Second,
+		OverflowPolicy:        config.UsageRecordOverflowPolicyDrop,
+		OverflowSamplePercent: 0,
+	})
+
+	pool.Stop()
+	mode := pool.Submit(func(ctx context.Context) {})
+	require.Equal(t, UsageRecordSubmitModeDropped, mode)
+	require.GreaterOrEqual(t, pool.Stats().DroppedPoolStopped, uint64(1))
+}
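+
+// Auto-scaling, as exercised below (a sketch; the exact bookkeeping lives in
+// the pool implementation): every AutoScaleInterval the pool samples queue
+// utilization. Above AutoScaleUpPercent it adds AutoScaleUpStep workers up to
+// AutoScaleMaxWorkers; below AutoScaleDownPercent it removes
+// AutoScaleDownStep workers down to AutoScaleMinWorkers, but only when
+// running-worker utilization is also low. AutoScaleCooldown spaces resizes.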
+func TestUsageRecordWorkerPool_AutoScaleUpAndDown(t *testing.T) {
+	pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
+		WorkerCount:           2,
+		QueueSize:             8,
+		TaskTimeout:           time.Second,
+		OverflowPolicy:        config.UsageRecordOverflowPolicyDrop,
+		OverflowSamplePercent: 0,
+		AutoScaleEnabled:      true,
+		AutoScaleMinWorkers:   1,
+		AutoScaleMaxWorkers:   4,
+		AutoScaleUpPercent:    40,
+		AutoScaleDownPercent:  10,
+		AutoScaleUpStep:       1,
+		AutoScaleDownStep:     1,
+		AutoScaleInterval:     20 * time.Millisecond,
+		AutoScaleCooldown:     20 * time.Millisecond,
+	})
+	t.Cleanup(pool.Stop)
+
+	block := make(chan struct{})
+
+	// Fill the running slots plus the queue to trip the scale-up threshold.
+	for i := 0; i < 8; i++ {
+		require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
+			<-block
+		}))
+	}
+
+	require.Eventually(t, func() bool {
+		return pool.Stats().MaxConcurrency >= 3
+	}, 2*time.Second, 20*time.Millisecond)
+
+	close(block)
+
+	require.Eventually(t, func() bool {
+		return pool.Stats().CompletedTasks >= 8
+	}, 2*time.Second, 20*time.Millisecond)
+
+	require.Eventually(t, func() bool {
+		return pool.Stats().MaxConcurrency == 1
+	}, 2*time.Second, 20*time.Millisecond)
+}
+
+func TestUsageRecordWorkerPool_AutoScaleDownRequiresLowRunningUtilization(t *testing.T) {
+	pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
+		WorkerCount:           2,
+		QueueSize:             8,
+		TaskTimeout:           time.Second,
+		OverflowPolicy:        config.UsageRecordOverflowPolicyDrop,
+		OverflowSamplePercent: 0,
+		AutoScaleEnabled:      true,
+		AutoScaleMinWorkers:   1,
+		AutoScaleMaxWorkers:   2,
+		AutoScaleUpPercent:    80,
+		AutoScaleDownPercent:  50,
+		AutoScaleUpStep:       1,
+		AutoScaleDownStep:     1,
+		AutoScaleInterval:     20 * time.Millisecond,
+		AutoScaleCooldown:     20 * time.Millisecond,
+	})
+	t.Cleanup(pool.Stop)
+
+	block := make(chan struct{})
+	for i := 0; i < 2; i++ {
+		require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
+			<-block
+		}))
+	}
+
+	// Even though waiting == 0, running utilization is 100%, so the pool must not scale down.
+	time.Sleep(200 * time.Millisecond)
+	require.Equal(t, 2, pool.Stats().MaxConcurrency)
+
+	close(block)
+	require.Eventually(t, func() bool {
+		return pool.Stats().MaxConcurrency == 1
+	}, 2*time.Second, 20*time.Millisecond)
+}
+
+func TestUsageRecordWorkerPool_SubmitNilReceiverAndNilTask(t *testing.T) {
+	var nilPool *UsageRecordWorkerPool
+	require.Equal(t, UsageRecordSubmitModeDropped, nilPool.Submit(func(ctx context.Context) {}))
+
+	pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
+		WorkerCount:           1,
+		QueueSize:             1,
+		TaskTimeout:           time.Second,
+		OverflowPolicy:        config.UsageRecordOverflowPolicyDrop,
+		OverflowSamplePercent: 0,
+		AutoScaleEnabled:      false,
+	})
+	t.Cleanup(pool.Stop)
+
+	require.Equal(t, UsageRecordSubmitModeDropped, pool.Submit(nil))
+}
+
+func TestUsageRecordWorkerPool_AutoScaleDisabledKeepsFixedConcurrency(t *testing.T) {
+	pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
+		WorkerCount:           2,
+		QueueSize:             4,
+		TaskTimeout:           time.Second,
+		OverflowPolicy:        config.UsageRecordOverflowPolicyDrop,
+		OverflowSamplePercent: 0,
+		AutoScaleEnabled:      false,
+		AutoScaleMinWorkers:   1,
+		AutoScaleMaxWorkers:   4,
+		AutoScaleUpPercent:    10,
+		AutoScaleDownPercent:  1,
+		AutoScaleUpStep:       2,
+		AutoScaleDownStep:     2,
+		AutoScaleInterval:     10 * time.Millisecond,
+		AutoScaleCooldown:     10 * time.Millisecond,
+	})
+	t.Cleanup(pool.Stop)
+
+	require.Equal(t, 2, pool.Stats().MaxConcurrency)
+
+	block := make(chan struct{})
+	for i := 0; i < 4; i++ {
+		require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
+			<-block
+		}))
+	}
+
+	time.Sleep(120 * time.Millisecond)
+	require.Equal(t, 2, pool.Stats().MaxConcurrency)
+	close(block)
+}
+
+func TestUsageRecordWorkerPool_OptionsFromConfig_AutoScaleDisabled(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.UsageRecord.WorkerCount = 64
+	cfg.Gateway.UsageRecord.QueueSize = 128
+	cfg.Gateway.UsageRecord.TaskTimeoutSeconds = 7
+	cfg.Gateway.UsageRecord.OverflowPolicy = config.UsageRecordOverflowPolicyDrop
+	cfg.Gateway.UsageRecord.OverflowSamplePercent = 0
+	cfg.Gateway.UsageRecord.AutoScaleEnabled = false
+	cfg.Gateway.UsageRecord.AutoScaleMinWorkers = 1
+	cfg.Gateway.UsageRecord.AutoScaleMaxWorkers = 512
+
+	opts := usageRecordPoolOptionsFromConfig(cfg)
+	require.False(t, opts.AutoScaleEnabled)
+	require.Equal(t, 64, opts.WorkerCount)
+	require.Equal(t, 64, opts.AutoScaleMinWorkers)
+	require.Equal(t, 64, opts.AutoScaleMaxWorkers)
+	require.Equal(t, 7*time.Second, opts.TaskTimeout)
+}
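+
+// Stats().String() renders a fixed-order summary (see the String method on
+// UsageRecordWorkerPoolStats), e.g. "running=2 waiting=3 submitted=5 dropped=1"
+// for the struct constructed below.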
+func TestUsageRecordWorkerPool_StringHelpers(t *testing.T) {
+	require.Equal(t, "enqueued", UsageRecordSubmitModeEnqueued.String())
+	stats := UsageRecordWorkerPoolStats{RunningWorkers: 2, WaitingTasks: 3, SubmittedTasks: 5, DroppedTasks: 1}
+	require.Contains(t, stats.String(), "running=2")
+	require.Contains(t, stats.String(), "waiting=3")
+}
+
+func TestNewUsageRecordWorkerPool_FromConfig(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.UsageRecord.WorkerCount = 3
+	cfg.Gateway.UsageRecord.QueueSize = 16
+	cfg.Gateway.UsageRecord.TaskTimeoutSeconds = 2
+	cfg.Gateway.UsageRecord.OverflowPolicy = config.UsageRecordOverflowPolicyDrop
+	cfg.Gateway.UsageRecord.AutoScaleEnabled = false
+
+	pool := NewUsageRecordWorkerPool(cfg)
+	t.Cleanup(pool.Stop)
+
+	stats := pool.Stats()
+	require.Equal(t, 3, stats.MaxConcurrency)
+}
+
+func TestUsageRecordWorkerPool_OptionsFromConfig_NilConfig(t *testing.T) {
+	opts := usageRecordPoolOptionsFromConfig(nil)
+	require.Equal(t, defaultUsageRecordWorkerCount, opts.WorkerCount)
+	require.Equal(t, defaultUsageRecordQueueSize, opts.QueueSize)
+	require.Equal(t, time.Duration(defaultUsageRecordTaskTimeoutSeconds)*time.Second, opts.TaskTimeout)
+	require.Equal(t, defaultUsageRecordOverflowPolicy, opts.OverflowPolicy)
+	require.Equal(t, defaultUsageRecordOverflowSampleRatio, opts.OverflowSamplePercent)
+	require.True(t, opts.AutoScaleEnabled)
+	require.Equal(t, defaultUsageRecordAutoScaleMinWorkers, opts.AutoScaleMinWorkers)
+	require.Equal(t, defaultUsageRecordAutoScaleMaxWorkers, opts.AutoScaleMaxWorkers)
+}
+
+func TestUsageRecordWorkerPool_NormalizeOptions_BoundsAndDefaults(t *testing.T) {
+	opts := normalizeUsageRecordPoolOptions(UsageRecordWorkerPoolOptions{
+		WorkerCount:           0,
+		QueueSize:             0,
+		TaskTimeout:           0,
+		OverflowPolicy:        "invalid",
+		OverflowSamplePercent: 300,
+		AutoScaleEnabled:      true,
+		AutoScaleMinWorkers:   0,
+		AutoScaleMaxWorkers:   0,
+		AutoScaleUpPercent:    0,
+		AutoScaleDownPercent:  100,
+		AutoScaleUpStep:       0,
+		AutoScaleDownStep:     0,
+		AutoScaleInterval:     0,
+		AutoScaleCooldown:     -time.Second,
+	})
+
+	require.Equal(t, defaultUsageRecordWorkerCount, opts.WorkerCount)
+	require.Equal(t, defaultUsageRecordQueueSize, opts.QueueSize)
+	require.Equal(t, time.Duration(defaultUsageRecordTaskTimeoutSeconds)*time.Second, opts.TaskTimeout)
+	require.Equal(t, defaultUsageRecordOverflowPolicy, opts.OverflowPolicy)
+	require.Equal(t, 100, opts.OverflowSamplePercent)
+	require.Equal(t, defaultUsageRecordAutoScaleMinWorkers, opts.AutoScaleMinWorkers)
+	require.Equal(t, defaultUsageRecordAutoScaleMaxWorkers, opts.AutoScaleMaxWorkers)
+	require.Equal(t, defaultUsageRecordAutoScaleUpPercent, opts.AutoScaleUpPercent)
+	require.Equal(t, defaultUsageRecordAutoScaleDownPercent, opts.AutoScaleDownPercent)
+	require.Equal(t, defaultUsageRecordAutoScaleUpStep, opts.AutoScaleUpStep)
+	require.Equal(t, defaultUsageRecordAutoScaleDownStep, opts.AutoScaleDownStep)
+	require.Equal(t, defaultUsageRecordAutoScaleInterval, opts.AutoScaleInterval)
+	require.Equal(t, defaultUsageRecordAutoScaleCooldown, opts.AutoScaleCooldown)
+}
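+
+// Two more normalization rules pinned down below: conflicting worker bounds
+// are reconciled (AutoScaleMaxWorkers is raised to AutoScaleMinWorkers, then
+// WorkerCount is clamped into that range, hence 64 everywhere), and a down
+// threshold at or above the up threshold collapses to half the up threshold
+// (40 >= 30, so 30/2 = 15).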
+func TestUsageRecordWorkerPool_NormalizeOptions_SampleAndAutoScaleDisabled(t *testing.T) {
+	sampleOpts := normalizeUsageRecordPoolOptions(UsageRecordWorkerPoolOptions{
+		WorkerCount:           32,
+		QueueSize:             128,
+		TaskTimeout:           time.Second,
+		OverflowPolicy:        config.UsageRecordOverflowPolicySample,
+		OverflowSamplePercent: 0,
+		AutoScaleEnabled:      true,
+		AutoScaleMinWorkers:   64,
+		AutoScaleMaxWorkers:   48,
+		AutoScaleUpPercent:    30,
+		AutoScaleDownPercent:  40,
+		AutoScaleUpStep:       1,
+		AutoScaleDownStep:     1,
+		AutoScaleInterval:     time.Second,
+		AutoScaleCooldown:     time.Second,
+	})
+	require.Equal(t, defaultUsageRecordOverflowSampleRatio, sampleOpts.OverflowSamplePercent)
+	require.Equal(t, 64, sampleOpts.AutoScaleMinWorkers)
+	require.Equal(t, 64, sampleOpts.AutoScaleMaxWorkers)
+	require.Equal(t, 64, sampleOpts.WorkerCount)
+	require.Equal(t, 15, sampleOpts.AutoScaleDownPercent)
+
+	fixedOpts := normalizeUsageRecordPoolOptions(UsageRecordWorkerPoolOptions{
+		WorkerCount:      20,
+		AutoScaleEnabled: false,
+	})
+	require.Equal(t, 20, fixedOpts.AutoScaleMinWorkers)
+	require.Equal(t, 20, fixedOpts.AutoScaleMaxWorkers)
+}
+
+func TestUsageRecordWorkerPool_ShouldSyncFallbackEdgeCases(t *testing.T) {
+	pool := &UsageRecordWorkerPool{overflowSamplePercent: 0}
+	require.False(t, pool.shouldSyncFallback())
+
+	pool.overflowSamplePercent = 100
+	require.True(t, pool.shouldSyncFallback())
+	require.True(t, pool.shouldSyncFallback())
+}
+
+func TestUsageRecordWorkerPool_StatsAndStop_NilBranches(t *testing.T) {
+	var nilPool *UsageRecordWorkerPool
+	require.Equal(t, UsageRecordWorkerPoolStats{}, nilPool.Stats())
+	require.NotPanics(t, func() { nilPool.Stop() })
+
+	emptyPool := &UsageRecordWorkerPool{}
+	require.Equal(t, UsageRecordWorkerPoolStats{}, emptyPool.Stats())
+	require.NotPanics(t, func() { emptyPool.Stop() })
+}
+
+func TestUsageRecordWorkerPool_Execute_PanicAndTimeout(t *testing.T) {
+	pool := &UsageRecordWorkerPool{taskTimeout: 30 * time.Millisecond}
+
+	require.NotPanics(t, func() {
+		pool.execute(func(ctx context.Context) {
+			panic("boom")
+		})
+	})
+
+	done := make(chan struct{})
+	pool.execute(func(ctx context.Context) {
+		<-ctx.Done()
+		close(done)
+	})
+	select {
+	case <-done:
+	case <-time.After(time.Second):
+		t.Fatal("timeout context not cancelled")
+	}
+}
+
+func TestUsageRecordWorkerPool_ResizeAndLogDropBranches(t *testing.T) {
+	pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
+		WorkerCount:      1,
+		QueueSize:        8,
+		TaskTimeout:      time.Second,
+		OverflowPolicy:   config.UsageRecordOverflowPolicyDrop,
+		AutoScaleEnabled: false,
+	})
+	t.Cleanup(pool.Stop)
+
+	// The target equals the current size, so resizePool should return immediately.
+	pool.resizePool(1, 1, 0, 0, 0, 8, "noop")
+	require.Equal(t, 1, pool.Stats().MaxConcurrency)
+
+	// Inside the rate-limit window, logDrop should return silently.
+	pool.lastDropLogNanos.Store(time.Now().UnixNano())
+	require.NotPanics(t, func() {
+		pool.logDrop("full")
+	})
+}
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index 652f9e00..bfc2ea48 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -300,6 +300,7 @@ var ProviderSet = wire.NewSet(
 	NewTurnstileService,
 	NewSubscriptionService,
 	ProvideConcurrencyService,
+	NewUsageRecordWorkerPool,
 	ProvideSchedulerSnapshotService,
 	NewIdentityService,
 	NewCRSSyncService,