feat(gateway): 引入使用量记录有界 worker 池与自动扩缩容

- 新增 UsageRecordWorkerPool,支持有界队列、溢出降级策略与自动扩缩容
- 将 Gateway/OpenAI/Sora/Gemini 使用量记录改为提交到统一任务池执行
- 增加 usage_record 配置默认值与校验规则,并补充配置与任务提交相关测试
- 注入并托管 worker 池生命周期,服务退出时统一 StopAndWait

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-22 12:56:57 +08:00
parent 50b9897182
commit 33db7a0fb6
15 changed files with 1575 additions and 70 deletions

View File

@@ -77,6 +77,7 @@ func provideCleanup(
pricing *service.PricingService, pricing *service.PricingService,
emailQueue *service.EmailQueueService, emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService, billingCache *service.BillingCacheService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
oauth *service.OAuthService, oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService, openaiOAuth *service.OpenAIOAuthService,
@@ -176,6 +177,12 @@ func provideCleanup(
billingCache.Stop() billingCache.Stop()
return nil return nil
}}, }},
{"UsageRecordWorkerPool", func() error {
if usageRecordWorkerPool != nil {
usageRecordWorkerPool.Stop()
}
return nil
}},
{"OAuthService", func() error { {"OAuthService", func() error {
oauth.Stop() oauth.Stop()
return nil return nil

View File

@@ -182,12 +182,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache) errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService) errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig) soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig) soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
totpHandler := handler.NewTotpHandler(totpService) totpHandler := handler.NewTotpHandler(totpService)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler, totpHandler) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler, totpHandler)
@@ -205,7 +206,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository) accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{ application := &Application{
Server: httpServer, Server: httpServer,
Cleanup: v, Cleanup: v,
@@ -245,6 +246,7 @@ func provideCleanup(
pricing *service.PricingService, pricing *service.PricingService,
emailQueue *service.EmailQueueService, emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService, billingCache *service.BillingCacheService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
oauth *service.OAuthService, oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService, openaiOAuth *service.OpenAIOAuthService,
@@ -343,6 +345,12 @@ func provideCleanup(
billingCache.Stop() billingCache.Stop()
return nil return nil
}}, }},
{"UsageRecordWorkerPool", func() error {
if usageRecordWorkerPool != nil {
usageRecordWorkerPool.Stop()
}
return nil
}},
{"OAuthService", func() error { {"OAuthService", func() error {
oauth.Stop() oauth.Stop()
return nil return nil

View File

@@ -5,6 +5,7 @@ go 1.25.7
require ( require (
entgo.io/ent v0.14.5 entgo.io/ent v0.14.5
github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/alitto/pond/v2 v2.6.2
github.com/cespare/xxhash/v2 v2.3.0 github.com/cespare/xxhash/v2 v2.3.0
github.com/dgraph-io/ristretto v0.2.0 github.com/dgraph-io/ristretto v0.2.0
github.com/gin-gonic/gin v1.9.1 github.com/gin-gonic/gin v1.9.1

View File

@@ -14,6 +14,8 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo= github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo=
github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558=
github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw=
github.com/alitto/pond/v2 v2.6.2/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=

View File

@@ -19,6 +19,13 @@ const (
RunModeSimple = "simple" RunModeSimple = "simple"
) )
// Overflow policies applied when the usage-record queue is full:
// drop the record, sample a fraction synchronously, or always fall back to
// synchronous recording.
const (
UsageRecordOverflowPolicyDrop = "drop"
UsageRecordOverflowPolicySample = "sample"
UsageRecordOverflowPolicySync = "sync"
)
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support // DefaultCSPPolicy is the default Content-Security-Policy with nonce support
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware // __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
@@ -413,6 +420,42 @@ type GatewayConfig struct {
// TLSFingerprint: TLS指纹伪装配置 // TLSFingerprint: TLS指纹伪装配置
TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"` TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"`
// UsageRecord: 使用量记录异步队列配置(有界队列 + 固定 worker
UsageRecord GatewayUsageRecordConfig `mapstructure:"usage_record"`
}
// GatewayUsageRecordConfig 使用量记录异步队列配置
type GatewayUsageRecordConfig struct {
// WorkerCount: worker 初始数量(自动扩缩容开启时作为初始并发上限)
WorkerCount int `mapstructure:"worker_count"`
// QueueSize: 队列容量(有界)
QueueSize int `mapstructure:"queue_size"`
// TaskTimeoutSeconds: 单个使用量记录任务超时(秒)
TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"`
// OverflowPolicy: 队列满时策略drop/sample/sync
OverflowPolicy string `mapstructure:"overflow_policy"`
// OverflowSamplePercent: sample 策略下同步回写采样百分比1-100
OverflowSamplePercent int `mapstructure:"overflow_sample_percent"`
// AutoScaleEnabled: 是否启用 worker 自动扩缩容
AutoScaleEnabled bool `mapstructure:"auto_scale_enabled"`
// AutoScaleMinWorkers: 自动扩缩容最小 worker 数
AutoScaleMinWorkers int `mapstructure:"auto_scale_min_workers"`
// AutoScaleMaxWorkers: 自动扩缩容最大 worker 数
AutoScaleMaxWorkers int `mapstructure:"auto_scale_max_workers"`
// AutoScaleUpQueuePercent: 队列占用率达到该阈值时触发扩容
AutoScaleUpQueuePercent int `mapstructure:"auto_scale_up_queue_percent"`
// AutoScaleDownQueuePercent: 队列占用率低于该阈值时触发缩容
AutoScaleDownQueuePercent int `mapstructure:"auto_scale_down_queue_percent"`
// AutoScaleUpStep: 每次扩容步长
AutoScaleUpStep int `mapstructure:"auto_scale_up_step"`
// AutoScaleDownStep: 每次缩容步长
AutoScaleDownStep int `mapstructure:"auto_scale_down_step"`
// AutoScaleCheckIntervalSeconds: 自动扩缩容检测间隔(秒)
AutoScaleCheckIntervalSeconds int `mapstructure:"auto_scale_check_interval_seconds"`
// AutoScaleCooldownSeconds: 自动扩缩容冷却时间(秒)
AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"`
} }
// SoraModelFiltersConfig Sora 模型过滤配置 // SoraModelFiltersConfig Sora 模型过滤配置
@@ -1118,6 +1161,20 @@ func setDefaults() {
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3) viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000) viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300) viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
viper.SetDefault("gateway.usage_record.worker_count", 128)
viper.SetDefault("gateway.usage_record.queue_size", 16384)
viper.SetDefault("gateway.usage_record.task_timeout_seconds", 5)
viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySample)
viper.SetDefault("gateway.usage_record.overflow_sample_percent", 10)
viper.SetDefault("gateway.usage_record.auto_scale_enabled", true)
viper.SetDefault("gateway.usage_record.auto_scale_min_workers", 128)
viper.SetDefault("gateway.usage_record.auto_scale_max_workers", 512)
viper.SetDefault("gateway.usage_record.auto_scale_up_queue_percent", 70)
viper.SetDefault("gateway.usage_record.auto_scale_down_queue_percent", 15)
viper.SetDefault("gateway.usage_record.auto_scale_up_step", 32)
viper.SetDefault("gateway.usage_record.auto_scale_down_step", 16)
viper.SetDefault("gateway.usage_record.auto_scale_check_interval_seconds", 3)
viper.SetDefault("gateway.usage_record.auto_scale_cooldown_seconds", 10)
// TLS指纹伪装配置默认关闭需要账号级别单独启用 // TLS指纹伪装配置默认关闭需要账号级别单独启用
viper.SetDefault("gateway.tls_fingerprint.enabled", true) viper.SetDefault("gateway.tls_fingerprint.enabled", true)
viper.SetDefault("concurrency.ping_interval", 10) viper.SetDefault("concurrency.ping_interval", 10)
@@ -1636,6 +1693,64 @@ func (c *Config) Validate() error {
if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 { if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 {
return fmt.Errorf("gateway.max_line_size must be at least 1MB") return fmt.Errorf("gateway.max_line_size must be at least 1MB")
} }
if c.Gateway.UsageRecord.WorkerCount <= 0 {
return fmt.Errorf("gateway.usage_record.worker_count must be positive")
}
if c.Gateway.UsageRecord.QueueSize <= 0 {
return fmt.Errorf("gateway.usage_record.queue_size must be positive")
}
if c.Gateway.UsageRecord.TaskTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.usage_record.task_timeout_seconds must be positive")
}
switch strings.ToLower(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy)) {
case UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync:
default:
return fmt.Errorf("gateway.usage_record.overflow_policy must be one of: %s/%s/%s",
UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync)
}
if c.Gateway.UsageRecord.OverflowSamplePercent < 0 || c.Gateway.UsageRecord.OverflowSamplePercent > 100 {
return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be between 0-100")
}
if strings.EqualFold(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy), UsageRecordOverflowPolicySample) &&
c.Gateway.UsageRecord.OverflowSamplePercent <= 0 {
return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be positive when overflow_policy=sample")
}
if c.Gateway.UsageRecord.AutoScaleEnabled {
if c.Gateway.UsageRecord.AutoScaleMinWorkers <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_min_workers must be positive")
}
if c.Gateway.UsageRecord.AutoScaleMaxWorkers <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be positive")
}
if c.Gateway.UsageRecord.AutoScaleMaxWorkers < c.Gateway.UsageRecord.AutoScaleMinWorkers {
return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be >= auto_scale_min_workers")
}
if c.Gateway.UsageRecord.WorkerCount < c.Gateway.UsageRecord.AutoScaleMinWorkers ||
c.Gateway.UsageRecord.WorkerCount > c.Gateway.UsageRecord.AutoScaleMaxWorkers {
return fmt.Errorf("gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers")
}
if c.Gateway.UsageRecord.AutoScaleUpQueuePercent <= 0 || c.Gateway.UsageRecord.AutoScaleUpQueuePercent > 100 {
return fmt.Errorf("gateway.usage_record.auto_scale_up_queue_percent must be between 1-100")
}
if c.Gateway.UsageRecord.AutoScaleDownQueuePercent < 0 || c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= 100 {
return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be between 0-99")
}
if c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= c.Gateway.UsageRecord.AutoScaleUpQueuePercent {
return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be less than auto_scale_up_queue_percent")
}
if c.Gateway.UsageRecord.AutoScaleUpStep <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_up_step must be positive")
}
if c.Gateway.UsageRecord.AutoScaleDownStep <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_down_step must be positive")
}
if c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds <= 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_check_interval_seconds must be positive")
}
if c.Gateway.UsageRecord.AutoScaleCooldownSeconds < 0 {
return fmt.Errorf("gateway.usage_record.auto_scale_cooldown_seconds must be non-negative")
}
}
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 { if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive") return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
} }

View File

@@ -942,6 +942,74 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 }, mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 },
wantErr: "gateway.max_line_size must be non-negative", wantErr: "gateway.max_line_size must be non-negative",
}, },
{
name: "gateway usage record worker count",
mutate: func(c *Config) { c.Gateway.UsageRecord.WorkerCount = 0 },
wantErr: "gateway.usage_record.worker_count",
},
{
name: "gateway usage record queue size",
mutate: func(c *Config) { c.Gateway.UsageRecord.QueueSize = 0 },
wantErr: "gateway.usage_record.queue_size",
},
{
name: "gateway usage record timeout",
mutate: func(c *Config) { c.Gateway.UsageRecord.TaskTimeoutSeconds = 0 },
wantErr: "gateway.usage_record.task_timeout_seconds",
},
{
name: "gateway usage record overflow policy",
mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowPolicy = "invalid" },
wantErr: "gateway.usage_record.overflow_policy",
},
{
name: "gateway usage record sample percent range",
mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowSamplePercent = 101 },
wantErr: "gateway.usage_record.overflow_sample_percent",
},
{
name: "gateway usage record sample percent required for sample policy",
mutate: func(c *Config) {
c.Gateway.UsageRecord.OverflowPolicy = UsageRecordOverflowPolicySample
c.Gateway.UsageRecord.OverflowSamplePercent = 0
},
wantErr: "gateway.usage_record.overflow_sample_percent must be positive",
},
{
name: "gateway usage record auto scale max gte min",
mutate: func(c *Config) {
c.Gateway.UsageRecord.AutoScaleMinWorkers = 256
c.Gateway.UsageRecord.AutoScaleMaxWorkers = 128
},
wantErr: "gateway.usage_record.auto_scale_max_workers",
},
{
name: "gateway usage record worker in auto scale range",
mutate: func(c *Config) {
c.Gateway.UsageRecord.AutoScaleMinWorkers = 200
c.Gateway.UsageRecord.AutoScaleMaxWorkers = 300
c.Gateway.UsageRecord.WorkerCount = 128
},
wantErr: "gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers",
},
{
name: "gateway usage record auto scale queue thresholds order",
mutate: func(c *Config) {
c.Gateway.UsageRecord.AutoScaleUpQueuePercent = 50
c.Gateway.UsageRecord.AutoScaleDownQueuePercent = 50
},
wantErr: "gateway.usage_record.auto_scale_down_queue_percent must be less",
},
{
name: "gateway usage record auto scale up step",
mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleUpStep = 0 },
wantErr: "gateway.usage_record.auto_scale_up_step",
},
{
name: "gateway usage record auto scale interval",
mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0 },
wantErr: "gateway.usage_record.auto_scale_check_interval_seconds",
},
{ {
name: "gateway scheduling sticky waiting", name: "gateway scheduling sticky waiting",
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 }, mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
@@ -1025,6 +1093,99 @@ func TestValidateConfigErrors(t *testing.T) {
} }
} }
// TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields verifies that
// when auto scaling is disabled, Validate ignores every auto-scale field,
// even values that would otherwise fail validation.
func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) {
	resetViperWithJWTSecret(t)
	cfg, err := Load()
	if err != nil {
		t.Fatalf("Load() error: %v", err)
	}
	ur := &cfg.Gateway.UsageRecord
	ur.AutoScaleEnabled = false
	ur.WorkerCount = 64
	// With auto scaling turned off these fields must be ignored and must
	// not cause validation to fail.
	ur.AutoScaleMinWorkers = 0
	ur.AutoScaleMaxWorkers = 0
	ur.AutoScaleUpQueuePercent = 0
	ur.AutoScaleDownQueuePercent = 100
	ur.AutoScaleUpStep = 0
	ur.AutoScaleDownStep = 0
	ur.AutoScaleCheckIntervalSeconds = 0
	ur.AutoScaleCooldownSeconds = -1
	if err := cfg.Validate(); err != nil {
		t.Fatalf("Validate() should ignore auto scale fields when disabled: %v", err)
	}
}
// TestValidateConfig_LogRequiredAndRotationBounds covers validation of the
// required log fields and the non-negative bounds on rotation and sampling
// settings.
func TestValidateConfig_LogRequiredAndRotationBounds(t *testing.T) {
	resetViperWithJWTSecret(t)

	tests := []struct {
		name    string
		mutate  func(*Config)
		wantErr string
	}{
		{"log level required", func(c *Config) { c.Log.Level = "" }, "log.level is required"},
		{"log format required", func(c *Config) { c.Log.Format = "" }, "log.format is required"},
		{"log stacktrace required", func(c *Config) { c.Log.StacktraceLevel = "" }, "log.stacktrace_level is required"},
		{"log max backups non-negative", func(c *Config) { c.Log.Rotation.MaxBackups = -1 }, "log.rotation.max_backups must be non-negative"},
		{"log max age non-negative", func(c *Config) { c.Log.Rotation.MaxAgeDays = -1 }, "log.rotation.max_age_days must be non-negative"},
		{"sampling thereafter non-negative when disabled", func(c *Config) {
			c.Log.Sampling.Enabled = false
			c.Log.Sampling.Thereafter = -1
		}, "log.sampling.thereafter must be non-negative"},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			cfg, err := Load()
			if err != nil {
				t.Fatalf("Load() error: %v", err)
			}
			tc.mutate(cfg)
			// Each mutation must be rejected with the expected message.
			if verr := cfg.Validate(); verr == nil || !strings.Contains(verr.Error(), tc.wantErr) {
				t.Fatalf("Validate() error = %v, want %q", verr, tc.wantErr)
			}
		})
	}
}
func TestSoraCurlCFFISidecarDefaults(t *testing.T) { func TestSoraCurlCFFISidecarDefaults(t *testing.T) {
resetViperWithJWTSecret(t) resetViperWithJWTSecret(t)
@@ -1112,3 +1273,53 @@ func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) {
t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err) t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err)
} }
} }
// TestLoad_DefaultGatewayUsageRecordConfig asserts the default values that
// Load installs for the gateway.usage_record section.
func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
	resetViperWithJWTSecret(t)
	cfg, err := Load()
	if err != nil {
		t.Fatalf("Load() error: %v", err)
	}
	// Alias the nested section once to keep the assertions readable.
	ur := cfg.Gateway.UsageRecord
	if ur.WorkerCount != 128 {
		t.Fatalf("worker_count = %d, want 128", ur.WorkerCount)
	}
	if ur.QueueSize != 16384 {
		t.Fatalf("queue_size = %d, want 16384", ur.QueueSize)
	}
	if ur.TaskTimeoutSeconds != 5 {
		t.Fatalf("task_timeout_seconds = %d, want 5", ur.TaskTimeoutSeconds)
	}
	if ur.OverflowPolicy != UsageRecordOverflowPolicySample {
		t.Fatalf("overflow_policy = %s, want %s", ur.OverflowPolicy, UsageRecordOverflowPolicySample)
	}
	if ur.OverflowSamplePercent != 10 {
		t.Fatalf("overflow_sample_percent = %d, want 10", ur.OverflowSamplePercent)
	}
	if !ur.AutoScaleEnabled {
		t.Fatalf("auto_scale_enabled = false, want true")
	}
	if ur.AutoScaleMinWorkers != 128 {
		t.Fatalf("auto_scale_min_workers = %d, want 128", ur.AutoScaleMinWorkers)
	}
	if ur.AutoScaleMaxWorkers != 512 {
		t.Fatalf("auto_scale_max_workers = %d, want 512", ur.AutoScaleMaxWorkers)
	}
	if ur.AutoScaleUpQueuePercent != 70 {
		t.Fatalf("auto_scale_up_queue_percent = %d, want 70", ur.AutoScaleUpQueuePercent)
	}
	if ur.AutoScaleDownQueuePercent != 15 {
		t.Fatalf("auto_scale_down_queue_percent = %d, want 15", ur.AutoScaleDownQueuePercent)
	}
	if ur.AutoScaleUpStep != 32 {
		t.Fatalf("auto_scale_up_step = %d, want 32", ur.AutoScaleUpStep)
	}
	if ur.AutoScaleDownStep != 16 {
		t.Fatalf("auto_scale_down_step = %d, want 16", ur.AutoScaleDownStep)
	}
	if ur.AutoScaleCheckIntervalSeconds != 3 {
		t.Fatalf("auto_scale_check_interval_seconds = %d, want 3", ur.AutoScaleCheckIntervalSeconds)
	}
	if ur.AutoScaleCooldownSeconds != 10 {
		t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", ur.AutoScaleCooldownSeconds)
	}
}

View File

@@ -37,6 +37,7 @@ type GatewayHandler struct {
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
usageService *service.UsageService usageService *service.UsageService
apiKeyService *service.APIKeyService apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int maxAccountSwitches int
@@ -54,6 +55,7 @@ func NewGatewayHandler(
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
usageService *service.UsageService, usageService *service.UsageService,
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService, errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config, cfg *config.Config,
) *GatewayHandler { ) *GatewayHandler {
@@ -77,6 +79,7 @@ func NewGatewayHandler(
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
usageService: usageService, usageService: usageService,
apiKeyService: apiKeyService, apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService, errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
@@ -431,19 +434,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
// 异步记录使用量subscription已在函数开头获取 // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) { h.submitUsageRecordTask(func(ctx context.Context) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: usedAccount, Account: account,
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
ForceCacheBilling: fcb, ForceCacheBilling: forceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
@@ -452,10 +453,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
zap.Int64("api_key_id", apiKey.ID), zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID), zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel), zap.String("model", reqModel),
zap.Int64("account_id", usedAccount.ID), zap.Int64("account_id", account.ID),
).Error("gateway.record_usage_failed", zap.Error(err)) ).Error("gateway.record_usage_failed", zap.Error(err))
} }
}(result, account, userAgent, clientIP, forceCacheBilling) })
return return
} }
} }
@@ -700,19 +701,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
// 异步记录使用量subscription已在函数开头获取 // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) { h.submitUsageRecordTask(func(ctx context.Context) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
APIKey: currentAPIKey, APIKey: currentAPIKey,
User: currentAPIKey.User, User: currentAPIKey.User,
Account: usedAccount, Account: account,
Subscription: currentSubscription, Subscription: currentSubscription,
UserAgent: ua, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
ForceCacheBilling: fcb, ForceCacheBilling: forceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
@@ -721,10 +720,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
zap.Int64("api_key_id", currentAPIKey.ID), zap.Int64("api_key_id", currentAPIKey.ID),
zap.Any("group_id", currentAPIKey.GroupID), zap.Any("group_id", currentAPIKey.GroupID),
zap.String("model", reqModel), zap.String("model", reqModel),
zap.Int64("account_id", usedAccount.ID), zap.Int64("account_id", account.ID),
).Error("gateway.record_usage_failed", zap.Error(err)) ).Error("gateway.record_usage_failed", zap.Error(err))
} }
}(result, account, userAgent, clientIP, forceCacheBilling) })
reqLog.Debug("gateway.request_completed", reqLog.Debug("gateway.request_completed",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount), zap.Int("switch_count", switchCount),
@@ -1508,3 +1507,17 @@ func billingErrorDetails(err error) (status int, code, message string) {
} }
return http.StatusForbidden, "billing_error", msg return http.StatusForbidden, "billing_error", msg
} }
// submitUsageRecordTask hands a usage-record task to the bounded worker pool.
// When no pool was injected it degrades to a synchronous, deadline-bound call
// so the handler never falls back to spawning unbounded goroutines.
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
	if task == nil {
		return
	}
	if pool := h.usageRecordWorkerPool; pool != nil {
		pool.Submit(task)
		return
	}
	// Fallback path: execute inline with a bounded timeout.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	task(ctx)
}

View File

@@ -11,7 +11,6 @@ import (
"net/http" "net/http"
"regexp" "regexp"
"strings" "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
@@ -519,22 +518,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
} }
// 6) record usage async (Gemini 使用长上下文双倍计费) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string, fcb bool) { h.submitUsageRecordTask(func(ctx context.Context) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
Result: result, Result: result,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: usedAccount, Account: account,
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: userAgent,
IPAddress: ip, IPAddress: clientIP,
LongContextThreshold: 200000, // Gemini 200K 阈值 LongContextThreshold: 200000, // Gemini 200K 阈值
LongContextMultiplier: 2.0, // 超出部分双倍计费 LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: fcb, ForceCacheBilling: forceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
@@ -543,10 +539,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
zap.Int64("api_key_id", apiKey.ID), zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID), zap.Any("group_id", apiKey.GroupID),
zap.String("model", modelName), zap.String("model", modelName),
zap.Int64("account_id", usedAccount.ID), zap.Int64("account_id", account.ID),
).Error("gemini.record_usage_failed", zap.Error(err)) ).Error("gemini.record_usage_failed", zap.Error(err))
} }
}(result, account, userAgent, clientIP, forceCacheBilling) })
reqLog.Debug("gemini.request_completed", reqLog.Debug("gemini.request_completed",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount), zap.Int("switch_count", switchCount),

View File

@@ -26,6 +26,7 @@ type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
apiKeyService *service.APIKeyService apiKeyService *service.APIKeyService
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int maxAccountSwitches int
@@ -37,6 +38,7 @@ func NewOpenAIGatewayHandler(
concurrencyService *service.ConcurrencyService, concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService, errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config, cfg *config.Config,
) *OpenAIGatewayHandler { ) *OpenAIGatewayHandler {
@@ -52,6 +54,7 @@ func NewOpenAIGatewayHandler(
gatewayService: gatewayService, gatewayService: gatewayService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
apiKeyService: apiKeyService, apiKeyService: apiKeyService,
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService, errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
@@ -378,18 +381,16 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
// Async record usage // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) { h.submitUsageRecordTask(func(ctx context.Context) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result, Result: result,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: usedAccount, Account: account,
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: userAgent,
IPAddress: ip, IPAddress: clientIP,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
@@ -398,10 +399,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
zap.Int64("api_key_id", apiKey.ID), zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID), zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel), zap.String("model", reqModel),
zap.Int64("account_id", usedAccount.ID), zap.Int64("account_id", account.ID),
).Error("openai.record_usage_failed", zap.Error(err)) ).Error("openai.record_usage_failed", zap.Error(err))
} }
}(result, account, userAgent, clientIP) })
reqLog.Debug("openai.request_completed", reqLog.Debug("openai.request_completed",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount), zap.Int("switch_count", switchCount),
@@ -432,6 +433,20 @@ func getContextInt64(c *gin.Context, key string) (int64, bool) {
} }
} }
// submitUsageRecordTask routes a usage-record task into the bounded worker
// pool; without an injected pool it runs the task synchronously under a
// timeout instead of reverting to unbounded goroutine creation.
func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
	if task == nil {
		return
	}
	if pool := h.usageRecordWorkerPool; pool != nil {
		pool.Submit(task)
		return
	}
	// Fallback path: inline execution with a bounded deadline.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	task(ctx)
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response // handleConcurrencyError handles concurrency-related errors with proper 429 response
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",

View File

@@ -34,6 +34,7 @@ type SoraGatewayHandler struct {
gatewayService *service.GatewayService gatewayService *service.GatewayService
soraGatewayService *service.SoraGatewayService soraGatewayService *service.SoraGatewayService
billingCacheService *service.BillingCacheService billingCacheService *service.BillingCacheService
usageRecordWorkerPool *service.UsageRecordWorkerPool
concurrencyHelper *ConcurrencyHelper concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int maxAccountSwitches int
streamMode string streamMode string
@@ -48,6 +49,7 @@ func NewSoraGatewayHandler(
soraGatewayService *service.SoraGatewayService, soraGatewayService *service.SoraGatewayService,
concurrencyService *service.ConcurrencyService, concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService, billingCacheService *service.BillingCacheService,
usageRecordWorkerPool *service.UsageRecordWorkerPool,
cfg *config.Config, cfg *config.Config,
) *SoraGatewayHandler { ) *SoraGatewayHandler {
pingInterval := time.Duration(0) pingInterval := time.Duration(0)
@@ -74,6 +76,7 @@ func NewSoraGatewayHandler(
gatewayService: gatewayService, gatewayService: gatewayService,
soraGatewayService: soraGatewayService, soraGatewayService: soraGatewayService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
usageRecordWorkerPool: usageRecordWorkerPool,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches, maxAccountSwitches: maxAccountSwitches,
streamMode: strings.ToLower(streamMode), streamMode: strings.ToLower(streamMode),
@@ -397,17 +400,16 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c) clientIP := ip.GetClientIP(c)
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) { // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) h.submitUsageRecordTask(func(ctx context.Context) {
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: usedAccount, Account: account,
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: userAgent,
IPAddress: ip, IPAddress: clientIP,
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.sora_gateway.chat_completions"), zap.String("component", "handler.sora_gateway.chat_completions"),
@@ -415,10 +417,10 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
zap.Int64("api_key_id", apiKey.ID), zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID), zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel), zap.String("model", reqModel),
zap.Int64("account_id", usedAccount.ID), zap.Int64("account_id", account.ID),
).Error("sora.record_usage_failed", zap.Error(err)) ).Error("sora.record_usage_failed", zap.Error(err))
} }
}(result, account, userAgent, clientIP) })
reqLog.Debug("sora.request_completed", reqLog.Debug("sora.request_completed",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
zap.Int64("proxy_id", proxyID), zap.Int64("proxy_id", proxyID),
@@ -448,6 +450,20 @@ func generateOpenAISessionHash(c *gin.Context, body []byte) string {
return hex.EncodeToString(hash[:]) return hex.EncodeToString(hash[:])
} }
// submitUsageRecordTask forwards a usage-record task to the bounded worker
// pool. If the pool is absent the task is executed on the caller goroutine
// under a timeout, avoiding a regression to unbounded goroutines.
func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
	if task == nil {
		return
	}
	if pool := h.usageRecordWorkerPool; pool != nil {
		pool.Submit(task)
		return
	}
	// Fallback path: synchronous execution with a bounded deadline.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	task(ctx)
}
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)

View File

@@ -432,7 +432,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}} soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
soraGatewayService := service.NewSoraGatewayService(soraClient, nil, nil, cfg) soraGatewayService := service.NewSoraGatewayService(soraClient, nil, nil, cfg)
handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, cfg) handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, nil, cfg)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec) c, _ := gin.CreateTestContext(rec)

View File

@@ -0,0 +1,136 @@
package handler
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// newUsageRecordTestPool builds a minimal single-worker pool for handler tests
// and registers cleanup so the pool is always stopped and drained.
func newUsageRecordTestPool(t *testing.T) *service.UsageRecordWorkerPool {
	t.Helper()
	opts := service.UsageRecordWorkerPoolOptions{
		WorkerCount:           1,
		QueueSize:             8,
		TaskTimeout:           time.Second,
		OverflowPolicy:        "drop",
		OverflowSamplePercent: 0,
		AutoScaleEnabled:      false,
	}
	pool := service.NewUsageRecordWorkerPoolWithOptions(opts)
	t.Cleanup(pool.Stop)
	return pool
}
func TestGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &GatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
close(done)
})
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("task not executed")
}
}
// Without a pool the fallback runs synchronously on the caller goroutine and
// must provide a deadline-bound context to the task.
func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
	h := &GatewayHandler{}
	var called atomic.Bool
	h.submitUsageRecordTask(func(ctx context.Context) {
		if _, hasDeadline := ctx.Deadline(); !hasDeadline {
			t.Fatal("expected deadline in fallback context")
		}
		called.Store(true)
	})
	require.True(t, called.Load())
}
func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &GatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
})
}
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
close(done)
})
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("task not executed")
}
}
// Without a pool the OpenAI handler falls back to synchronous execution with
// a deadline-bound context.
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
	h := &OpenAIGatewayHandler{}
	var called atomic.Bool
	h.submitUsageRecordTask(func(ctx context.Context) {
		if _, hasDeadline := ctx.Deadline(); !hasDeadline {
			t.Fatal("expected deadline in fallback context")
		}
		called.Store(true)
	})
	require.True(t, called.Load())
}
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &OpenAIGatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
})
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
done := make(chan struct{})
h.submitUsageRecordTask(func(ctx context.Context) {
close(done)
})
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("task not executed")
}
}
// Without a pool the Sora handler falls back to synchronous execution with a
// deadline-bound context.
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
	h := &SoraGatewayHandler{}
	var called atomic.Bool
	h.submitUsageRecordTask(func(ctx context.Context) {
		if _, hasDeadline := ctx.Deadline(); !hasDeadline {
			t.Fatal("expected deadline in fallback context")
		}
		called.Store(true)
	})
	require.True(t, called.Load())
}
func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &SoraGatewayHandler{}
require.NotPanics(t, func() {
h.submitUsageRecordTask(nil)
})
}

View File

@@ -0,0 +1,496 @@
package service
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/alitto/pond/v2"
"go.uber.org/zap"
)
// Compiled-in defaults for the usage-record worker pool. Each value is used
// whenever the corresponding configuration field is absent or invalid (see
// usageRecordPoolOptionsFromConfig / normalizeUsageRecordPoolOptions).
const (
	defaultUsageRecordWorkerCount         = 128
	defaultUsageRecordQueueSize           = 16384
	defaultUsageRecordTaskTimeoutSeconds  = 5
	defaultUsageRecordOverflowPolicy      = config.UsageRecordOverflowPolicySample
	defaultUsageRecordOverflowSampleRatio = 10
	defaultUsageRecordAutoScaleEnabled    = true
	defaultUsageRecordAutoScaleMinWorkers = 128
	defaultUsageRecordAutoScaleMaxWorkers = 512
	// Queue-occupancy thresholds (percent) that trigger scaling decisions.
	defaultUsageRecordAutoScaleUpPercent   = 70
	defaultUsageRecordAutoScaleDownPercent = 15
	// Worker-count deltas applied per scaling step.
	defaultUsageRecordAutoScaleUpStep   = 32
	defaultUsageRecordAutoScaleDownStep = 16
	defaultUsageRecordAutoScaleInterval = 3 * time.Second
	defaultUsageRecordAutoScaleCooldown = 10 * time.Second
	// Minimum spacing between "task dropped" warning log lines (rate limit).
	usageRecordDropLogInterval = 5 * time.Second
)
// UsageRecordTask is a unit of work submitted to the usage-record pool.
// Implementations must do their own business-level error logging; the pool
// itself is only responsible for scheduling and timeout control.
type UsageRecordTask func(ctx context.Context)

// UsageRecordSubmitMode describes the outcome of a Submit call.
type UsageRecordSubmitMode string

const (
	// UsageRecordSubmitModeEnqueued: the task was accepted into the queue.
	UsageRecordSubmitModeEnqueued UsageRecordSubmitMode = "enqueued"
	// UsageRecordSubmitModeDropped: the task was discarded (queue full or pool stopped).
	UsageRecordSubmitModeDropped UsageRecordSubmitMode = "dropped"
	// UsageRecordSubmitModeSync: the task ran synchronously on the caller goroutine.
	UsageRecordSubmitModeSync UsageRecordSubmitMode = "sync_fallback"
)
// UsageRecordWorkerPoolOptions configures the usage-record worker pool.
// Zero or out-of-range values are replaced with defaults by
// normalizeUsageRecordPoolOptions.
type UsageRecordWorkerPoolOptions struct {
	WorkerCount int           // initial maximum concurrency
	QueueSize   int           // bounded queue capacity
	TaskTimeout time.Duration // per-task context deadline

	// OverflowPolicy decides what happens when the queue is full:
	// drop, sample (sync-execute a percentage), or sync (always sync-execute).
	OverflowPolicy        string
	OverflowSamplePercent int // 0-100; only meaningful for the sample policy

	AutoScaleEnabled     bool
	AutoScaleMinWorkers  int
	AutoScaleMaxWorkers  int
	AutoScaleUpPercent   int // queue-occupancy percent that triggers scale-up
	AutoScaleDownPercent int // occupancy/utilisation percent allowing scale-down
	AutoScaleUpStep      int
	AutoScaleDownStep    int
	AutoScaleInterval    time.Duration // how often the auto-scaler samples the pool
	AutoScaleCooldown    time.Duration // minimum time between two resizes
}
// UsageRecordWorkerPoolStats is a point-in-time snapshot of pool state and
// counters, combining pond's own metrics with the pool's drop/fallback counters.
type UsageRecordWorkerPoolStats struct {
	MaxConcurrency int    // current worker-count ceiling
	RunningWorkers int64  // workers currently executing tasks
	WaitingTasks   uint64 // tasks queued but not yet started
	SubmittedTasks uint64
	CompletedTasks uint64
	SuccessfulTasks uint64
	FailedTasks    uint64
	DroppedTasks   uint64 // drops counted by pond itself
	// Drops and fallbacks counted by this wrapper, split by cause.
	DroppedQueueFull   uint64
	DroppedPoolStopped uint64
	SyncFallbackTasks  uint64
}
// UsageRecordWorkerPool provides a "bounded queue + fixed workers" async
// executor. It replaces direct goroutine spawning on the request path so
// that high concurrency cannot cause unbounded goroutine build-up.
type UsageRecordWorkerPool struct {
	pool        pond.Pool
	taskTimeout time.Duration

	// Overflow handling configuration (see UsageRecordWorkerPoolOptions).
	overflowPolicy        string
	overflowSamplePercent int
	overflowCounter       atomic.Uint64 // monotonically increments per overflow; drives sampling

	// Drop/fallback counters, split by cause.
	droppedQueueFull   atomic.Uint64
	droppedPoolStopped atomic.Uint64
	syncFallback       atomic.Uint64
	lastDropLogNanos   atomic.Int64 // rate limiter for drop warnings

	// Auto-scaling configuration and state.
	autoScaleEnabled     bool
	autoScaleMinWorkers  int
	autoScaleMaxWorkers  int
	autoScaleUpPercent   int
	autoScaleDownPercent int
	autoScaleUpStep      int
	autoScaleDownStep    int
	autoScaleInterval    time.Duration
	autoScaleCooldown    time.Duration
	lastScaleNanos       atomic.Int64 // nanosecond timestamp of last resize (cooldown)

	// Lifecycle management for the auto-scaler goroutine.
	autoScaleCancel context.CancelFunc
	lifecycleWg     sync.WaitGroup
	stopOnce        sync.Once
}
// NewUsageRecordWorkerPool builds a usage-record pool from application config.
func NewUsageRecordWorkerPool(cfg *config.Config) *UsageRecordWorkerPool {
	return NewUsageRecordWorkerPoolWithOptions(usageRecordPoolOptionsFromConfig(cfg))
}
// NewUsageRecordWorkerPoolWithOptions builds a usage-record pool from explicit
// options. Options are normalized first; when auto-scaling is enabled a
// background scaler goroutine is started immediately.
func NewUsageRecordWorkerPoolWithOptions(opts UsageRecordWorkerPoolOptions) *UsageRecordWorkerPool {
	opts = normalizeUsageRecordPoolOptions(opts)
	wp := &UsageRecordWorkerPool{
		taskTimeout:           opts.TaskTimeout,
		overflowPolicy:        opts.OverflowPolicy,
		overflowSamplePercent: opts.OverflowSamplePercent,
		autoScaleEnabled:      opts.AutoScaleEnabled,
		autoScaleMinWorkers:   opts.AutoScaleMinWorkers,
		autoScaleMaxWorkers:   opts.AutoScaleMaxWorkers,
		autoScaleUpPercent:    opts.AutoScaleUpPercent,
		autoScaleDownPercent:  opts.AutoScaleDownPercent,
		autoScaleUpStep:       opts.AutoScaleUpStep,
		autoScaleDownStep:     opts.AutoScaleDownStep,
		autoScaleInterval:     opts.AutoScaleInterval,
		autoScaleCooldown:     opts.AutoScaleCooldown,
	}
	wp.pool = pond.NewPool(opts.WorkerCount, pond.WithQueueSize(opts.QueueSize))
	if wp.autoScaleEnabled {
		wp.startAutoScaler()
	}
	return wp
}
// Submit hands a usage-record task to the pool. When the bounded queue is
// full, the configured overflow policy decides between dropping the task
// (drop), synchronously executing a sampled fraction (sample), or always
// executing it on the caller goroutine (sync).
func (p *UsageRecordWorkerPool) Submit(task UsageRecordTask) UsageRecordSubmitMode {
	if p == nil || task == nil {
		return UsageRecordSubmitModeDropped
	}
	if p.pool == nil || p.pool.Stopped() {
		p.droppedPoolStopped.Add(1)
		p.logDrop("stopped")
		return UsageRecordSubmitModeDropped
	}
	if _, ok := p.pool.TrySubmit(func() { p.execute(task) }); ok {
		return UsageRecordSubmitModeEnqueued
	}
	// TrySubmit can also fail because the pool was stopped concurrently;
	// attribute the drop to the right cause.
	if p.pool.Stopped() {
		p.droppedPoolStopped.Add(1)
		p.logDrop("stopped")
		return UsageRecordSubmitModeDropped
	}
	runSync := p.overflowPolicy == config.UsageRecordOverflowPolicySync ||
		(p.overflowPolicy == config.UsageRecordOverflowPolicySample && p.shouldSyncFallback())
	if runSync {
		p.syncFallback.Add(1)
		p.execute(task)
		return UsageRecordSubmitModeSync
	}
	p.droppedQueueFull.Add(1)
	p.logDrop("full")
	return UsageRecordSubmitModeDropped
}
// Stats returns the current pool state plus wrapper-level drop/fallback counters.
func (p *UsageRecordWorkerPool) Stats() UsageRecordWorkerPoolStats {
	if p == nil || p.pool == nil {
		return UsageRecordWorkerPoolStats{}
	}
	stats := UsageRecordWorkerPoolStats{
		DroppedQueueFull:   p.droppedQueueFull.Load(),
		DroppedPoolStopped: p.droppedPoolStopped.Load(),
		SyncFallbackTasks:  p.syncFallback.Load(),
	}
	stats.MaxConcurrency = p.pool.MaxConcurrency()
	stats.RunningWorkers = p.pool.RunningWorkers()
	stats.WaitingTasks = p.pool.WaitingTasks()
	stats.SubmittedTasks = p.pool.SubmittedTasks()
	stats.CompletedTasks = p.pool.CompletedTasks()
	stats.SuccessfulTasks = p.pool.SuccessfulTasks()
	stats.FailedTasks = p.pool.FailedTasks()
	stats.DroppedTasks = p.pool.DroppedTasks()
	return stats
}
// Stop shuts the pool down exactly once: it cancels the auto-scaler, waits
// for it to exit, and then drains and stops the underlying worker pool.
func (p *UsageRecordWorkerPool) Stop() {
	if p == nil || p.pool == nil {
		return
	}
	p.stopOnce.Do(func() {
		if cancel := p.autoScaleCancel; cancel != nil {
			cancel()
		}
		// Ensure the scaler goroutine is gone before stopping the pool so it
		// cannot resize a pool that is shutting down.
		p.lifecycleWg.Wait()
		p.pool.StopAndWait()
	})
}
// startAutoScaler launches the background goroutine that periodically
// evaluates scaling decisions until the pool is stopped.
func (p *UsageRecordWorkerPool) startAutoScaler() {
	ctx, cancel := context.WithCancel(context.Background())
	p.autoScaleCancel = cancel
	p.lifecycleWg.Add(1)
	go p.runAutoScaler(ctx)
}

// runAutoScaler is the auto-scaler loop body; it exits when ctx is cancelled.
func (p *UsageRecordWorkerPool) runAutoScaler(ctx context.Context) {
	defer p.lifecycleWg.Done()
	ticker := time.NewTicker(p.autoScaleInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			p.autoScaleTick()
		}
	}
}
// autoScaleTick samples queue occupancy and worker utilisation, then resizes
// the pool within [min, max] bounds, honouring the cooldown between resizes.
func (p *UsageRecordWorkerPool) autoScaleTick() {
	if p == nil || p.pool == nil || p.pool.Stopped() {
		return
	}
	queueSize := p.pool.QueueSize()
	if queueSize <= 0 {
		return
	}
	current := p.pool.MaxConcurrency()
	waiting := int(p.pool.WaitingTasks())
	running := int(p.pool.RunningWorkers())
	if current <= 0 || waiting < 0 {
		return
	}
	queuePercent := waiting * 100 / queueSize
	runningPercent := running * 100 / current
	// Cooldown: suppress resizes that would come too soon after the last one.
	if last := p.lastScaleNanos.Load(); last > 0 && time.Since(time.Unix(0, last)) < p.autoScaleCooldown {
		return
	}
	switch {
	// Scale-up first: raise the concurrency ceiling when the queue is filling.
	case queuePercent >= p.autoScaleUpPercent && current < p.autoScaleMaxWorkers:
		target := min(current+p.autoScaleUpStep, p.autoScaleMaxWorkers)
		p.resizePool(current, target, queuePercent, waiting, runningPercent, queueSize, "scale_up")
	// Scale-down only with an empty queue AND low utilisation, so a busy pool
	// with no backlog is not shrunk into oscillation.
	case queuePercent <= p.autoScaleDownPercent && waiting == 0 &&
		runningPercent <= p.autoScaleDownPercent && current > p.autoScaleMinWorkers:
		target := max(current-p.autoScaleDownStep, p.autoScaleMinWorkers)
		p.resizePool(current, target, queuePercent, waiting, runningPercent, queueSize, "scale_down")
	}
}
// resizePool applies a new concurrency ceiling, records the resize time for
// the cooldown, and emits a structured audit log entry.
func (p *UsageRecordWorkerPool) resizePool(current, target, queuePercent, waiting, runningPercent, queueSize int, action string) {
	if target == current {
		return
	}
	p.pool.Resize(target)
	p.lastScaleNanos.Store(time.Now().UnixNano())
	fields := []zap.Field{
		zap.String("component", "service.usage_record_worker_pool"),
		zap.String("action", action),
		zap.Int("from_workers", current),
		zap.Int("to_workers", target),
		zap.Int("queue_percent", queuePercent),
		zap.Int("waiting_tasks", waiting),
		zap.Int("running_percent", runningPercent),
		zap.Int("queue_size", queueSize),
	}
	logger.L().With(fields...).Info("usage_record.auto_scale")
}
// shouldSyncFallback decides, via a deterministic modulo-100 counter, whether
// this overflowed task falls inside the configured synchronous sample window.
func (p *UsageRecordWorkerPool) shouldSyncFallback() bool {
	percent := p.overflowSamplePercent
	if percent <= 0 {
		return false
	}
	seq := p.overflowCounter.Add(1) - 1
	return int(seq%100) < percent
}
// execute runs a task under the pool's timeout and converts panics into log
// entries so one misbehaving task cannot take a worker goroutine down.
func (p *UsageRecordWorkerPool) execute(task UsageRecordTask) {
	ctx, cancel := context.WithTimeout(context.Background(), p.taskTimeout)
	defer cancel()
	defer func() {
		if r := recover(); r != nil {
			logger.L().With(
				zap.String("component", "service.usage_record_worker_pool"),
				zap.Any("panic", r),
			).Error("usage_record.task_panic")
		}
	}()
	task(ctx)
}
// logDrop emits a rate-limited warning when a task is dropped: at most one
// log line per usageRecordDropLogInterval. Goroutines that lose the CAS race
// simply skip logging for this interval.
func (p *UsageRecordWorkerPool) logDrop(reason string) {
	now := time.Now().UnixNano()
	last := p.lastDropLogNanos.Load()
	if now-last < int64(usageRecordDropLogInterval) || !p.lastDropLogNanos.CompareAndSwap(last, now) {
		return
	}
	stats := p.Stats()
	logger.L().With(
		zap.String("component", "service.usage_record_worker_pool"),
		zap.String("reason", reason),
		zap.String("overflow_policy", p.overflowPolicy),
		zap.Int64("running_workers", stats.RunningWorkers),
		zap.Uint64("waiting_tasks", stats.WaitingTasks),
		zap.Uint64("dropped_queue_full", stats.DroppedQueueFull),
		zap.Uint64("dropped_pool_stopped", stats.DroppedPoolStopped),
		zap.Uint64("sync_fallback_tasks", stats.SyncFallbackTasks),
	).Warn("usage_record.task_dropped")
}
// usageRecordPoolOptionsFromConfig maps gateway usage_record configuration
// onto pool options: it starts from the compiled-in defaults and overrides
// only fields whose configured values pass the per-field validity guards
// below. The result is normalized before being returned.
func usageRecordPoolOptionsFromConfig(cfg *config.Config) UsageRecordWorkerPoolOptions {
	// Baseline: every field starts from the package defaults.
	opts := UsageRecordWorkerPoolOptions{
		WorkerCount:           defaultUsageRecordWorkerCount,
		QueueSize:             defaultUsageRecordQueueSize,
		TaskTimeout:           time.Duration(defaultUsageRecordTaskTimeoutSeconds) * time.Second,
		OverflowPolicy:        defaultUsageRecordOverflowPolicy,
		OverflowSamplePercent: defaultUsageRecordOverflowSampleRatio,
		AutoScaleEnabled:      defaultUsageRecordAutoScaleEnabled,
		AutoScaleMinWorkers:   defaultUsageRecordAutoScaleMinWorkers,
		AutoScaleMaxWorkers:   defaultUsageRecordAutoScaleMaxWorkers,
		AutoScaleUpPercent:    defaultUsageRecordAutoScaleUpPercent,
		AutoScaleDownPercent:  defaultUsageRecordAutoScaleDownPercent,
		AutoScaleUpStep:       defaultUsageRecordAutoScaleUpStep,
		AutoScaleDownStep:     defaultUsageRecordAutoScaleDownStep,
		AutoScaleInterval:     defaultUsageRecordAutoScaleInterval,
		AutoScaleCooldown:     defaultUsageRecordAutoScaleCooldown,
	}
	// A nil config keeps the defaults untouched.
	if cfg == nil {
		return opts
	}
	// Positive-only overrides: zero/negative config values keep the default.
	if cfg.Gateway.UsageRecord.WorkerCount > 0 {
		opts.WorkerCount = cfg.Gateway.UsageRecord.WorkerCount
	}
	if cfg.Gateway.UsageRecord.QueueSize > 0 {
		opts.QueueSize = cfg.Gateway.UsageRecord.QueueSize
	}
	if cfg.Gateway.UsageRecord.TaskTimeoutSeconds > 0 {
		opts.TaskTimeout = time.Duration(cfg.Gateway.UsageRecord.TaskTimeoutSeconds) * time.Second
	}
	// Policy string is case/whitespace-insensitive; empty keeps the default.
	if policy := strings.TrimSpace(strings.ToLower(cfg.Gateway.UsageRecord.OverflowPolicy)); policy != "" {
		opts.OverflowPolicy = policy
	}
	// >= 0: an explicit 0 is a valid "never sample" setting.
	if cfg.Gateway.UsageRecord.OverflowSamplePercent >= 0 {
		opts.OverflowSamplePercent = cfg.Gateway.UsageRecord.OverflowSamplePercent
	}
	// Boolean is taken verbatim — there is no zero-value sentinel for bools.
	opts.AutoScaleEnabled = cfg.Gateway.UsageRecord.AutoScaleEnabled
	if cfg.Gateway.UsageRecord.AutoScaleMinWorkers > 0 {
		opts.AutoScaleMinWorkers = cfg.Gateway.UsageRecord.AutoScaleMinWorkers
	}
	if cfg.Gateway.UsageRecord.AutoScaleMaxWorkers > 0 {
		opts.AutoScaleMaxWorkers = cfg.Gateway.UsageRecord.AutoScaleMaxWorkers
	}
	if cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent > 0 {
		opts.AutoScaleUpPercent = cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent
	}
	// >= 0: an explicit 0 down-threshold is meaningful (shrink only when idle).
	if cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent >= 0 {
		opts.AutoScaleDownPercent = cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent
	}
	if cfg.Gateway.UsageRecord.AutoScaleUpStep > 0 {
		opts.AutoScaleUpStep = cfg.Gateway.UsageRecord.AutoScaleUpStep
	}
	if cfg.Gateway.UsageRecord.AutoScaleDownStep > 0 {
		opts.AutoScaleDownStep = cfg.Gateway.UsageRecord.AutoScaleDownStep
	}
	if cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds > 0 {
		opts.AutoScaleInterval = time.Duration(cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds) * time.Second
	}
	// >= 0: an explicit 0 cooldown disables the resize spacing entirely.
	if cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds >= 0 {
		opts.AutoScaleCooldown = time.Duration(cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds) * time.Second
	}
	return normalizeUsageRecordPoolOptions(opts)
}
// normalizeUsageRecordPoolOptions clamps and defaults every option so the
// pool is always constructed with safe, mutually-consistent values. The
// clamping order matters: min/max bounds are fixed first, then WorkerCount
// is pulled inside them, then thresholds/steps are validated.
func normalizeUsageRecordPoolOptions(opts UsageRecordWorkerPoolOptions) UsageRecordWorkerPoolOptions {
	if opts.WorkerCount <= 0 {
		opts.WorkerCount = defaultUsageRecordWorkerCount
	}
	if opts.QueueSize <= 0 {
		opts.QueueSize = defaultUsageRecordQueueSize
	}
	if opts.TaskTimeout <= 0 {
		opts.TaskTimeout = time.Duration(defaultUsageRecordTaskTimeoutSeconds) * time.Second
	}
	// Only the three known policies are accepted; anything else falls back
	// to the default (sample).
	switch strings.ToLower(strings.TrimSpace(opts.OverflowPolicy)) {
	case config.UsageRecordOverflowPolicyDrop,
		config.UsageRecordOverflowPolicySample,
		config.UsageRecordOverflowPolicySync:
		opts.OverflowPolicy = strings.ToLower(strings.TrimSpace(opts.OverflowPolicy))
	default:
		opts.OverflowPolicy = defaultUsageRecordOverflowPolicy
	}
	// Clamp the sample percentage into [0, 100].
	if opts.OverflowSamplePercent < 0 {
		opts.OverflowSamplePercent = 0
	}
	if opts.OverflowSamplePercent > 100 {
		opts.OverflowSamplePercent = 100
	}
	// A sample policy with a 0% sample would behave like drop; give it the
	// default ratio instead so "sample" keeps its meaning.
	if opts.OverflowPolicy == config.UsageRecordOverflowPolicySample && opts.OverflowSamplePercent == 0 {
		opts.OverflowSamplePercent = defaultUsageRecordOverflowSampleRatio
	}
	if opts.AutoScaleEnabled {
		if opts.AutoScaleMinWorkers <= 0 {
			opts.AutoScaleMinWorkers = defaultUsageRecordAutoScaleMinWorkers
		}
		if opts.AutoScaleMaxWorkers <= 0 {
			opts.AutoScaleMaxWorkers = defaultUsageRecordAutoScaleMaxWorkers
		}
		// Guarantee min <= max, then pull the initial WorkerCount inside the band.
		if opts.AutoScaleMaxWorkers < opts.AutoScaleMinWorkers {
			opts.AutoScaleMaxWorkers = opts.AutoScaleMinWorkers
		}
		if opts.WorkerCount < opts.AutoScaleMinWorkers {
			opts.WorkerCount = opts.AutoScaleMinWorkers
		}
		if opts.WorkerCount > opts.AutoScaleMaxWorkers {
			opts.WorkerCount = opts.AutoScaleMaxWorkers
		}
		if opts.AutoScaleUpPercent <= 0 || opts.AutoScaleUpPercent > 100 {
			opts.AutoScaleUpPercent = defaultUsageRecordAutoScaleUpPercent
		}
		if opts.AutoScaleDownPercent < 0 || opts.AutoScaleDownPercent >= 100 {
			opts.AutoScaleDownPercent = defaultUsageRecordAutoScaleDownPercent
		}
		// Keep a gap between the thresholds so the scaler cannot oscillate.
		if opts.AutoScaleDownPercent >= opts.AutoScaleUpPercent {
			opts.AutoScaleDownPercent = max(0, opts.AutoScaleUpPercent/2)
		}
		if opts.AutoScaleUpStep <= 0 {
			opts.AutoScaleUpStep = defaultUsageRecordAutoScaleUpStep
		}
		if opts.AutoScaleDownStep <= 0 {
			opts.AutoScaleDownStep = defaultUsageRecordAutoScaleDownStep
		}
		if opts.AutoScaleInterval <= 0 {
			opts.AutoScaleInterval = defaultUsageRecordAutoScaleInterval
		}
		// Cooldown of exactly 0 is allowed (no spacing between resizes).
		if opts.AutoScaleCooldown < 0 {
			opts.AutoScaleCooldown = defaultUsageRecordAutoScaleCooldown
		}
	} else {
		// With auto-scaling off, pin the band to the fixed worker count.
		opts.AutoScaleMinWorkers = opts.WorkerCount
		opts.AutoScaleMaxWorkers = opts.WorkerCount
	}
	return opts
}
// String implements fmt.Stringer for submit modes.
func (m UsageRecordSubmitMode) String() string { return string(m) }
// String renders a compact one-line summary of the most relevant counters.
func (s UsageRecordWorkerPoolStats) String() string {
	return fmt.Sprintf("running=%d waiting=%d submitted=%d dropped=%d",
		s.RunningWorkers, s.WaitingTasks, s.SubmittedTasks, s.DroppedTasks)
}

View File

@@ -0,0 +1,488 @@
package service
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// A submit into a non-full queue must be enqueued, executed, and counted.
func TestUsageRecordWorkerPool_SubmitEnqueued(t *testing.T) {
	pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
		WorkerCount:           1,
		QueueSize:             8,
		TaskTimeout:           time.Second,
		OverflowPolicy:        config.UsageRecordOverflowPolicyDrop,
		OverflowSamplePercent: 0,
	})
	t.Cleanup(pool.Stop)
	executed := make(chan struct{})
	mode := pool.Submit(func(ctx context.Context) { close(executed) })
	require.Equal(t, UsageRecordSubmitModeEnqueued, mode)
	select {
	case <-time.After(time.Second):
		t.Fatal("task not executed")
	case <-executed:
	}
	require.Eventually(t, func() bool {
		s := pool.Stats()
		return s.SubmittedTasks == 1 && s.SuccessfulTasks == 1
	}, time.Second, 10*time.Millisecond)
}
// TestUsageRecordWorkerPool_OverflowDrop verifies the "drop" overflow policy:
// with the single worker busy and the one-slot queue full, a further Submit
// is rejected as dropped and counted in DroppedQueueFull, while the task that
// did make it into the queue still runs once the worker frees up.
func TestUsageRecordWorkerPool_OverflowDrop(t *testing.T) {
	pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
		WorkerCount:           1,
		QueueSize:             1,
		TaskTimeout:           time.Second,
		OverflowPolicy:        config.UsageRecordOverflowPolicyDrop,
		OverflowSamplePercent: 0,
	})
	t.Cleanup(pool.Stop)
	block := make(chan struct{})
	started := make(chan struct{})
	secondDone := make(chan struct{})
	// First task pins the only worker until block is closed.
	require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
		close(started)
		<-block
	}))
	<-started
	// Second task occupies the single queue slot.
	require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
		close(secondDone)
	}))
	// Third task has nowhere to go and must be dropped.
	require.Equal(t, UsageRecordSubmitModeDropped, pool.Submit(func(ctx context.Context) {}))
	close(block)
	select {
	case <-secondDone:
	case <-time.After(time.Second):
		t.Fatal("queued task not executed")
	}
	require.Eventually(t, func() bool {
		return pool.Stats().DroppedQueueFull >= 1
	}, time.Second, 10*time.Millisecond)
}
// TestUsageRecordWorkerPool_OverflowSync verifies the "sync" overflow policy:
// when both the worker and the queue are saturated, Submit runs the task
// inline on the caller's goroutine, reports UsageRecordSubmitModeSync, and
// bumps the SyncFallbackTasks counter.
func TestUsageRecordWorkerPool_OverflowSync(t *testing.T) {
	pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
		WorkerCount:           1,
		QueueSize:             1,
		TaskTimeout:           time.Second,
		OverflowPolicy:        config.UsageRecordOverflowPolicySync,
		OverflowSamplePercent: 0,
	})
	t.Cleanup(pool.Stop)
	block := make(chan struct{})
	started := make(chan struct{})
	secondDone := make(chan struct{})
	var syncExecuted atomic.Bool
	// Pin the single worker so the queue backs up.
	require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
		close(started)
		<-block
	}))
	<-started
	// Fill the single queue slot.
	require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
		close(secondDone)
	}))
	mode := pool.Submit(func(ctx context.Context) {
		syncExecuted.Store(true)
	})
	require.Equal(t, UsageRecordSubmitModeSync, mode)
	// The sync fallback completes before Submit returns, so the flag is
	// already observable here.
	require.True(t, syncExecuted.Load())
	close(block)
	select {
	case <-secondDone:
	case <-time.After(time.Second):
		t.Fatal("queued task not executed")
	}
	require.Eventually(t, func() bool {
		return pool.Stats().SyncFallbackTasks >= 1
	}, time.Second, 10*time.Millisecond)
}
// TestUsageRecordWorkerPool_OverflowSample verifies the "sample" overflow
// policy: with OverflowSamplePercent set to 1, the first overflowing Submit
// is sampled into a synchronous inline execution while the next overflow is
// dropped, so both SyncFallbackTasks and DroppedQueueFull end up incremented.
func TestUsageRecordWorkerPool_OverflowSample(t *testing.T) {
	pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
		WorkerCount:           1,
		QueueSize:             1,
		TaskTimeout:           time.Second,
		OverflowPolicy:        config.UsageRecordOverflowPolicySample,
		OverflowSamplePercent: 1,
	})
	t.Cleanup(pool.Stop)
	block := make(chan struct{})
	started := make(chan struct{})
	secondDone := make(chan struct{})
	var syncExecuted atomic.Bool
	// Pin the single worker so subsequent submissions overflow.
	require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
		close(started)
		<-block
	}))
	<-started
	// Fill the single queue slot.
	require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
		close(secondDone)
	}))
	// First overflow: sampled, executed inline before Submit returns.
	firstOverflow := pool.Submit(func(ctx context.Context) {
		syncExecuted.Store(true)
	})
	require.Equal(t, UsageRecordSubmitModeSync, firstOverflow)
	require.True(t, syncExecuted.Load())
	// Second overflow: outside the sample, dropped.
	secondOverflow := pool.Submit(func(ctx context.Context) {})
	require.Equal(t, UsageRecordSubmitModeDropped, secondOverflow)
	close(block)
	select {
	case <-secondDone:
	case <-time.After(time.Second):
		t.Fatal("queued task not executed")
	}
	require.Eventually(t, func() bool {
		stats := pool.Stats()
		return stats.SyncFallbackTasks >= 1 && stats.DroppedQueueFull >= 1
	}, time.Second, 10*time.Millisecond)
}
// TestUsageRecordWorkerPool_SubmitAfterStop checks that a stopped pool
// rejects new work: Submit reports the dropped mode and the
// DroppedPoolStopped counter records the rejection.
func TestUsageRecordWorkerPool_SubmitAfterStop(t *testing.T) {
	opts := UsageRecordWorkerPoolOptions{
		WorkerCount:           1,
		QueueSize:             1,
		TaskTimeout:           time.Second,
		OverflowPolicy:        config.UsageRecordOverflowPolicyDrop,
		OverflowSamplePercent: 0,
	}
	pool := NewUsageRecordWorkerPoolWithOptions(opts)
	pool.Stop()

	got := pool.Submit(func(ctx context.Context) {})
	require.Equal(t, UsageRecordSubmitModeDropped, got)

	stats := pool.Stats()
	require.GreaterOrEqual(t, stats.DroppedPoolStopped, uint64(1))
}
// TestUsageRecordWorkerPool_AutoScaleUpAndDown drives the pool through a full
// auto-scale cycle: saturating the workers and queue pushes MaxConcurrency
// above the initial 2, and once the backlog drains the pool shrinks back to
// the configured minimum of 1.
func TestUsageRecordWorkerPool_AutoScaleUpAndDown(t *testing.T) {
	pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
		WorkerCount:           2,
		QueueSize:             8,
		TaskTimeout:           time.Second,
		OverflowPolicy:        config.UsageRecordOverflowPolicyDrop,
		OverflowSamplePercent: 0,
		AutoScaleEnabled:      true,
		AutoScaleMinWorkers:   1,
		AutoScaleMaxWorkers:   4,
		AutoScaleUpPercent:    40,
		AutoScaleDownPercent:  10,
		AutoScaleUpStep:       1,
		AutoScaleDownStep:     1,
		AutoScaleInterval:     20 * time.Millisecond,
		AutoScaleCooldown:     20 * time.Millisecond,
	})
	t.Cleanup(pool.Stop)
	block := make(chan struct{})
	// Fill the running slots plus the queue to trip the scale-up threshold.
	for i := 0; i < 8; i++ {
		require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
			<-block
		}))
	}
	// Scale-up is periodic (AutoScaleInterval), so poll for it.
	require.Eventually(t, func() bool {
		return pool.Stats().MaxConcurrency >= 3
	}, 2*time.Second, 20*time.Millisecond)
	close(block)
	require.Eventually(t, func() bool {
		return pool.Stats().CompletedTasks >= 8
	}, 2*time.Second, 20*time.Millisecond)
	// With no load left, the pool should scale back down to the minimum.
	require.Eventually(t, func() bool {
		return pool.Stats().MaxConcurrency == 1
	}, 2*time.Second, 20*time.Millisecond)
}
// TestUsageRecordWorkerPool_AutoScaleDownRequiresLowRunningUtilization pins
// both workers on blocking tasks and asserts the pool does NOT scale down
// while running utilization is high, even though the wait queue is empty;
// it only shrinks to the minimum after the tasks complete.
func TestUsageRecordWorkerPool_AutoScaleDownRequiresLowRunningUtilization(t *testing.T) {
	pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
		WorkerCount:           2,
		QueueSize:             8,
		TaskTimeout:           time.Second,
		OverflowPolicy:        config.UsageRecordOverflowPolicyDrop,
		OverflowSamplePercent: 0,
		AutoScaleEnabled:      true,
		AutoScaleMinWorkers:   1,
		AutoScaleMaxWorkers:   2,
		AutoScaleUpPercent:    80,
		AutoScaleDownPercent:  50,
		AutoScaleUpStep:       1,
		AutoScaleDownStep:     1,
		AutoScaleInterval:     20 * time.Millisecond,
		AutoScaleCooldown:     20 * time.Millisecond,
	})
	t.Cleanup(pool.Stop)
	block := make(chan struct{})
	for i := 0; i < 2; i++ {
		require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
			<-block
		}))
	}
	// Even though waiting=0, running utilization is 100%, so the pool must
	// not scale down. Sleep spans several auto-scale intervals to prove it.
	time.Sleep(200 * time.Millisecond)
	require.Equal(t, 2, pool.Stats().MaxConcurrency)
	close(block)
	require.Eventually(t, func() bool {
		return pool.Stats().MaxConcurrency == 1
	}, 2*time.Second, 20*time.Millisecond)
}
// TestUsageRecordWorkerPool_SubmitNilReceiverAndNilTask covers the defensive
// branches of Submit: calling it on a nil *UsageRecordWorkerPool and passing
// a nil task must both report the dropped mode instead of panicking.
func TestUsageRecordWorkerPool_SubmitNilReceiverAndNilTask(t *testing.T) {
	// Submit on a nil receiver is a safe no-op that reports "dropped".
	var absent *UsageRecordWorkerPool
	noop := func(ctx context.Context) {}
	require.Equal(t, UsageRecordSubmitModeDropped, absent.Submit(noop))

	pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
		WorkerCount:           1,
		QueueSize:             1,
		TaskTimeout:           time.Second,
		OverflowPolicy:        config.UsageRecordOverflowPolicyDrop,
		OverflowSamplePercent: 0,
		AutoScaleEnabled:      false,
	})
	t.Cleanup(pool.Stop)

	// A nil task is likewise rejected rather than enqueued.
	require.Equal(t, UsageRecordSubmitModeDropped, pool.Submit(nil))
}
// TestUsageRecordWorkerPool_AutoScaleDisabledKeepsFixedConcurrency ensures
// that with AutoScaleEnabled=false the pool ignores all auto-scale knobs and
// keeps MaxConcurrency pinned at WorkerCount even under a sustained backlog
// that would otherwise trigger scaling.
func TestUsageRecordWorkerPool_AutoScaleDisabledKeepsFixedConcurrency(t *testing.T) {
	pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
		WorkerCount:           2,
		QueueSize:             4,
		TaskTimeout:           time.Second,
		OverflowPolicy:        config.UsageRecordOverflowPolicyDrop,
		OverflowSamplePercent: 0,
		AutoScaleEnabled:      false,
		// Deliberately aggressive auto-scale settings; all must be ignored.
		AutoScaleMinWorkers:  1,
		AutoScaleMaxWorkers:  4,
		AutoScaleUpPercent:   10,
		AutoScaleDownPercent: 1,
		AutoScaleUpStep:      2,
		AutoScaleDownStep:    2,
		AutoScaleInterval:    10 * time.Millisecond,
		AutoScaleCooldown:    10 * time.Millisecond,
	})
	t.Cleanup(pool.Stop)
	require.Equal(t, 2, pool.Stats().MaxConcurrency)
	block := make(chan struct{})
	// Saturate both workers and the queue.
	for i := 0; i < 4; i++ {
		require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
			<-block
		}))
	}
	// Sleep spans many intervals; concurrency must remain fixed.
	time.Sleep(120 * time.Millisecond)
	require.Equal(t, 2, pool.Stats().MaxConcurrency)
	close(block)
}
// TestUsageRecordWorkerPool_OptionsFromConfig_AutoScaleDisabled verifies that
// when auto-scaling is disabled in the config, the derived options pin the
// auto-scale min/max to the fixed worker count and carry the timeout over.
func TestUsageRecordWorkerPool_OptionsFromConfig_AutoScaleDisabled(t *testing.T) {
	cfg := &config.Config{}
	ur := &cfg.Gateway.UsageRecord
	ur.WorkerCount = 64
	ur.QueueSize = 128
	ur.TaskTimeoutSeconds = 7
	ur.OverflowPolicy = config.UsageRecordOverflowPolicyDrop
	ur.OverflowSamplePercent = 0
	ur.AutoScaleEnabled = false
	ur.AutoScaleMinWorkers = 1
	ur.AutoScaleMaxWorkers = 512

	opts := usageRecordPoolOptionsFromConfig(cfg)

	require.False(t, opts.AutoScaleEnabled)
	require.Equal(t, 64, opts.WorkerCount)
	// With auto-scaling off, the configured min (1) and max (512) are both
	// overridden by the fixed worker count.
	require.Equal(t, 64, opts.AutoScaleMinWorkers)
	require.Equal(t, 64, opts.AutoScaleMaxWorkers)
	require.Equal(t, 7*time.Second, opts.TaskTimeout)
}
// TestUsageRecordWorkerPool_StringHelpers exercises the Stringer
// implementations of the submit mode and the pool stats snapshot.
func TestUsageRecordWorkerPool_StringHelpers(t *testing.T) {
	require.Equal(t, "enqueued", UsageRecordSubmitModeEnqueued.String())

	stats := UsageRecordWorkerPoolStats{
		RunningWorkers: 2,
		WaitingTasks:   3,
		SubmittedTasks: 5,
		DroppedTasks:   1,
	}
	rendered := stats.String()
	require.Contains(t, rendered, "running=2")
	require.Contains(t, rendered, "waiting=3")
}
// TestNewUsageRecordWorkerPool_FromConfig builds a pool straight from a
// config.Config and confirms the configured worker count becomes the pool's
// fixed max concurrency.
func TestNewUsageRecordWorkerPool_FromConfig(t *testing.T) {
	cfg := &config.Config{}
	usage := &cfg.Gateway.UsageRecord
	usage.WorkerCount = 3
	usage.QueueSize = 16
	usage.TaskTimeoutSeconds = 2
	usage.OverflowPolicy = config.UsageRecordOverflowPolicyDrop
	usage.AutoScaleEnabled = false

	pool := NewUsageRecordWorkerPool(cfg)
	t.Cleanup(pool.Stop)

	require.Equal(t, 3, pool.Stats().MaxConcurrency)
}
// TestUsageRecordWorkerPool_OptionsFromConfig_NilConfig confirms that a nil
// config yields the package defaults for every pool option, with
// auto-scaling enabled.
func TestUsageRecordWorkerPool_OptionsFromConfig_NilConfig(t *testing.T) {
	got := usageRecordPoolOptionsFromConfig(nil)

	require.Equal(t, defaultUsageRecordWorkerCount, got.WorkerCount)
	require.Equal(t, defaultUsageRecordQueueSize, got.QueueSize)
	wantTimeout := time.Duration(defaultUsageRecordTaskTimeoutSeconds) * time.Second
	require.Equal(t, wantTimeout, got.TaskTimeout)
	require.Equal(t, defaultUsageRecordOverflowPolicy, got.OverflowPolicy)
	require.Equal(t, defaultUsageRecordOverflowSampleRatio, got.OverflowSamplePercent)
	require.True(t, got.AutoScaleEnabled)
	require.Equal(t, defaultUsageRecordAutoScaleMinWorkers, got.AutoScaleMinWorkers)
	require.Equal(t, defaultUsageRecordAutoScaleMaxWorkers, got.AutoScaleMaxWorkers)
}
// TestUsageRecordWorkerPool_NormalizeOptions_BoundsAndDefaults feeds
// normalizeUsageRecordPoolOptions an out-of-range value for every field and
// checks each one is clamped or replaced by its package default.
func TestUsageRecordWorkerPool_NormalizeOptions_BoundsAndDefaults(t *testing.T) {
	input := UsageRecordWorkerPoolOptions{
		WorkerCount:           0,
		QueueSize:             0,
		TaskTimeout:           0,
		OverflowPolicy:        "invalid",
		OverflowSamplePercent: 300,
		AutoScaleEnabled:      true,
		AutoScaleMinWorkers:   0,
		AutoScaleMaxWorkers:   0,
		AutoScaleUpPercent:    0,
		AutoScaleDownPercent:  100,
		AutoScaleUpStep:       0,
		AutoScaleDownStep:     0,
		AutoScaleInterval:     0,
		AutoScaleCooldown:     -time.Second,
	}
	got := normalizeUsageRecordPoolOptions(input)

	require.Equal(t, defaultUsageRecordWorkerCount, got.WorkerCount)
	require.Equal(t, defaultUsageRecordQueueSize, got.QueueSize)
	wantTimeout := time.Duration(defaultUsageRecordTaskTimeoutSeconds) * time.Second
	require.Equal(t, wantTimeout, got.TaskTimeout)
	require.Equal(t, defaultUsageRecordOverflowPolicy, got.OverflowPolicy)
	// A sample percent above 100 is clamped to 100 rather than defaulted.
	require.Equal(t, 100, got.OverflowSamplePercent)
	require.Equal(t, defaultUsageRecordAutoScaleMinWorkers, got.AutoScaleMinWorkers)
	require.Equal(t, defaultUsageRecordAutoScaleMaxWorkers, got.AutoScaleMaxWorkers)
	require.Equal(t, defaultUsageRecordAutoScaleUpPercent, got.AutoScaleUpPercent)
	require.Equal(t, defaultUsageRecordAutoScaleDownPercent, got.AutoScaleDownPercent)
	require.Equal(t, defaultUsageRecordAutoScaleUpStep, got.AutoScaleUpStep)
	require.Equal(t, defaultUsageRecordAutoScaleDownStep, got.AutoScaleDownStep)
	require.Equal(t, defaultUsageRecordAutoScaleInterval, got.AutoScaleInterval)
	require.Equal(t, defaultUsageRecordAutoScaleCooldown, got.AutoScaleCooldown)
}
// TestUsageRecordWorkerPool_NormalizeOptions_SampleAndAutoScaleDisabled
// covers two normalization branches: under the sample policy a zero sample
// percent falls back to the default ratio and min/max/worker counts are
// reconciled; with auto-scale disabled, min/max are pinned to WorkerCount.
func TestUsageRecordWorkerPool_NormalizeOptions_SampleAndAutoScaleDisabled(t *testing.T) {
	sampled := normalizeUsageRecordPoolOptions(UsageRecordWorkerPoolOptions{
		WorkerCount:           32,
		QueueSize:             128,
		TaskTimeout:           time.Second,
		OverflowPolicy:        config.UsageRecordOverflowPolicySample,
		OverflowSamplePercent: 0,
		AutoScaleEnabled:      true,
		AutoScaleMinWorkers:   64,
		AutoScaleMaxWorkers:   48,
		AutoScaleUpPercent:    30,
		AutoScaleDownPercent:  40,
		AutoScaleUpStep:       1,
		AutoScaleDownStep:     1,
		AutoScaleInterval:     time.Second,
		AutoScaleCooldown:     time.Second,
	})
	require.Equal(t, defaultUsageRecordOverflowSampleRatio, sampled.OverflowSamplePercent)
	// With min 64 above max 48, both bounds and the worker count settle at 64.
	require.Equal(t, 64, sampled.AutoScaleMinWorkers)
	require.Equal(t, 64, sampled.AutoScaleMaxWorkers)
	require.Equal(t, 64, sampled.WorkerCount)
	// Down percent (40) >= up percent (30) collapses to half the up percent.
	require.Equal(t, 15, sampled.AutoScaleDownPercent)

	fixed := normalizeUsageRecordPoolOptions(UsageRecordWorkerPoolOptions{
		WorkerCount:      20,
		AutoScaleEnabled: false,
	})
	require.Equal(t, 20, fixed.AutoScaleMinWorkers)
	require.Equal(t, 20, fixed.AutoScaleMaxWorkers)
}
// TestUsageRecordWorkerPool_ShouldSyncFallbackEdgeCases checks the sampling
// boundaries of shouldSyncFallback: 0 percent never selects the synchronous
// fallback, and 100 percent selects it on every call.
func TestUsageRecordWorkerPool_ShouldSyncFallbackEdgeCases(t *testing.T) {
	pool := &UsageRecordWorkerPool{overflowSamplePercent: 0}
	require.False(t, pool.shouldSyncFallback())

	pool.overflowSamplePercent = 100
	for i := 0; i < 2; i++ {
		require.True(t, pool.shouldSyncFallback())
	}
}
// TestUsageRecordWorkerPool_StatsAndStop_NilBranches makes sure Stats and
// Stop tolerate both a nil receiver and a zero-value pool without panicking,
// returning empty stats in both cases.
func TestUsageRecordWorkerPool_StatsAndStop_NilBranches(t *testing.T) {
	var missing *UsageRecordWorkerPool
	require.Equal(t, UsageRecordWorkerPoolStats{}, missing.Stats())
	require.NotPanics(t, missing.Stop)

	zero := &UsageRecordWorkerPool{}
	require.Equal(t, UsageRecordWorkerPoolStats{}, zero.Stats())
	require.NotPanics(t, zero.Stop)
}
// TestUsageRecordWorkerPool_Execute_PanicAndTimeout exercises execute
// directly on a bare pool: a panicking task must be recovered rather than
// crash the test process, and the per-task context must be cancelled once
// taskTimeout elapses.
func TestUsageRecordWorkerPool_Execute_PanicAndTimeout(t *testing.T) {
	pool := &UsageRecordWorkerPool{taskTimeout: 30 * time.Millisecond}
	// Panic inside the task is recovered by execute.
	require.NotPanics(t, func() {
		pool.execute(func(ctx context.Context) {
			panic("boom")
		})
	})
	// The task blocks on ctx.Done(); it only finishes if the timeout fires.
	done := make(chan struct{})
	pool.execute(func(ctx context.Context) {
		<-ctx.Done()
		close(done)
	})
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout context not cancelled")
	}
}
// TestUsageRecordWorkerPool_ResizeAndLogDropBranches covers two early-return
// branches: resizePool with a target equal to the current size, and logDrop
// while still inside the drop-log rate-limit window.
func TestUsageRecordWorkerPool_ResizeAndLogDropBranches(t *testing.T) {
	pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
		WorkerCount:    1,
		QueueSize:      8,
		TaskTimeout:    time.Second,
		OverflowPolicy: config.UsageRecordOverflowPolicyDrop,
		AutoScaleEnabled: false,
	})
	t.Cleanup(pool.Stop)
	// Target equals the current size, so resizePool should return directly.
	pool.resizePool(1, 1, 0, 0, 0, 8, "noop")
	require.Equal(t, 1, pool.Stats().MaxConcurrency)
	// Within the rate-limit window logDrop should return silently.
	pool.lastDropLogNanos.Store(time.Now().UnixNano())
	require.NotPanics(t, func() {
		pool.logDrop("full")
	})
}

View File

@@ -300,6 +300,7 @@ var ProviderSet = wire.NewSet(
NewTurnstileService, NewTurnstileService,
NewSubscriptionService, NewSubscriptionService,
ProvideConcurrencyService, ProvideConcurrencyService,
NewUsageRecordWorkerPool,
ProvideSchedulerSnapshotService, ProvideSchedulerSnapshotService,
NewIdentityService, NewIdentityService,
NewCRSSyncService, NewCRSSyncService,