feat(gateway): 引入使用量记录有界 worker 池与自动扩缩容
- 新增 UsageRecordWorkerPool,支持有界队列、溢出降级策略与自动扩缩容
- 将 Gateway/OpenAI/Sora/Gemini 使用量记录改为提交到统一任务池执行
- 增加 usage_record 配置默认值与校验规则,并补充配置与任务提交相关测试
- 注入并托管 worker 池生命周期,服务退出时统一 StopAndWait

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -19,6 +19,13 @@ const (
|
||||
RunModeSimple = "simple"
|
||||
)
|
||||
|
||||
// Overflow policies for the usage-record queue, i.e. the accepted values of
// gateway.usage_record.overflow_policy (what to do when the queue is full).
|
||||
const (
|
||||
// UsageRecordOverflowPolicyDrop is the "drop" policy value.
// NOTE(review): presumably the overflowing task is discarded — confirm against the worker pool.
UsageRecordOverflowPolicyDrop = "drop"
|
||||
// UsageRecordOverflowPolicySample is the "sample" policy value; a configured
// percentage (overflow_sample_percent) of tasks is written back synchronously.
UsageRecordOverflowPolicySample = "sample"
|
||||
// UsageRecordOverflowPolicySync is the "sync" policy value.
// NOTE(review): presumably the overflowing task runs synchronously on the caller — confirm against the worker pool.
UsageRecordOverflowPolicySync = "sync"
|
||||
)
|
||||
|
||||
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support.
|
||||
// The __CSP_NONCE__ placeholder (it appears in script-src only) will be replaced with the actual nonce at request time by the SecurityHeaders middleware.
|
||||
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
||||
@@ -413,6 +420,42 @@ type GatewayConfig struct {
|
||||
|
||||
// TLSFingerprint: TLS指纹伪装配置
|
||||
TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"`
|
||||
|
||||
// UsageRecord: 使用量记录异步队列配置(有界队列 + 固定 worker)
|
||||
UsageRecord GatewayUsageRecordConfig `mapstructure:"usage_record"`
|
||||
}
|
||||
|
||||
// GatewayUsageRecordConfig configures the asynchronous usage-record queue
// (loaded from the gateway.usage_record config section).
|
||||
type GatewayUsageRecordConfig struct {
|
||||
// WorkerCount: initial number of workers (with auto scaling enabled it is the initial concurrency limit). Validate requires a positive value.
|
||||
WorkerCount int `mapstructure:"worker_count"`
|
||||
// QueueSize: capacity of the bounded queue. Validate requires a positive value.
|
||||
QueueSize int `mapstructure:"queue_size"`
|
||||
// TaskTimeoutSeconds: timeout of a single usage-record task, in seconds. Validate requires a positive value.
|
||||
TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"`
|
||||
// OverflowPolicy: policy applied when the queue is full (drop/sample/sync).
|
||||
OverflowPolicy string `mapstructure:"overflow_policy"`
|
||||
// OverflowSamplePercent: under the sample policy, the percentage of tasks sampled for synchronous write-back (1-100).
// Validate accepts 0-100 in general and requires a positive value when the policy is sample.
|
||||
OverflowSamplePercent int `mapstructure:"overflow_sample_percent"`
|
||||
|
||||
// AutoScaleEnabled: whether worker auto scaling is enabled. Validate checks the AutoScale* fields below only when this is true.
|
||||
AutoScaleEnabled bool `mapstructure:"auto_scale_enabled"`
|
||||
// AutoScaleMinWorkers: minimum number of workers under auto scaling.
|
||||
AutoScaleMinWorkers int `mapstructure:"auto_scale_min_workers"`
|
||||
// AutoScaleMaxWorkers: maximum number of workers under auto scaling (must be >= AutoScaleMinWorkers).
|
||||
AutoScaleMaxWorkers int `mapstructure:"auto_scale_max_workers"`
|
||||
// AutoScaleUpQueuePercent: scale-up is triggered when queue occupancy reaches this threshold (1-100).
|
||||
AutoScaleUpQueuePercent int `mapstructure:"auto_scale_up_queue_percent"`
|
||||
// AutoScaleDownQueuePercent: scale-down is triggered when queue occupancy falls below this threshold (0-99, and less than AutoScaleUpQueuePercent).
|
||||
AutoScaleDownQueuePercent int `mapstructure:"auto_scale_down_queue_percent"`
|
||||
// AutoScaleUpStep: step size of each scale-up.
|
||||
AutoScaleUpStep int `mapstructure:"auto_scale_up_step"`
|
||||
// AutoScaleDownStep: step size of each scale-down.
|
||||
AutoScaleDownStep int `mapstructure:"auto_scale_down_step"`
|
||||
// AutoScaleCheckIntervalSeconds: interval between auto-scaling checks, in seconds.
|
||||
AutoScaleCheckIntervalSeconds int `mapstructure:"auto_scale_check_interval_seconds"`
|
||||
// AutoScaleCooldownSeconds: auto-scaling cooldown, in seconds (must be non-negative).
|
||||
AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"`
|
||||
}
|
||||
|
||||
// SoraModelFiltersConfig Sora 模型过滤配置
|
||||
@@ -1118,6 +1161,20 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
|
||||
viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
|
||||
viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
|
||||
viper.SetDefault("gateway.usage_record.worker_count", 128)
|
||||
viper.SetDefault("gateway.usage_record.queue_size", 16384)
|
||||
viper.SetDefault("gateway.usage_record.task_timeout_seconds", 5)
|
||||
viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySample)
|
||||
viper.SetDefault("gateway.usage_record.overflow_sample_percent", 10)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_enabled", true)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_min_workers", 128)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_max_workers", 512)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_up_queue_percent", 70)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_down_queue_percent", 15)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_up_step", 32)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_down_step", 16)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_check_interval_seconds", 3)
|
||||
viper.SetDefault("gateway.usage_record.auto_scale_cooldown_seconds", 10)
|
||||
// TLS指纹伪装配置(默认关闭,需要账号级别单独启用)
|
||||
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
|
||||
viper.SetDefault("concurrency.ping_interval", 10)
|
||||
@@ -1636,6 +1693,64 @@ func (c *Config) Validate() error {
|
||||
if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 {
|
||||
return fmt.Errorf("gateway.max_line_size must be at least 1MB")
|
||||
}
|
||||
if c.Gateway.UsageRecord.WorkerCount <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.worker_count must be positive")
|
||||
}
|
||||
if c.Gateway.UsageRecord.QueueSize <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.queue_size must be positive")
|
||||
}
|
||||
if c.Gateway.UsageRecord.TaskTimeoutSeconds <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.task_timeout_seconds must be positive")
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy)) {
|
||||
case UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync:
|
||||
default:
|
||||
return fmt.Errorf("gateway.usage_record.overflow_policy must be one of: %s/%s/%s",
|
||||
UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync)
|
||||
}
|
||||
if c.Gateway.UsageRecord.OverflowSamplePercent < 0 || c.Gateway.UsageRecord.OverflowSamplePercent > 100 {
|
||||
return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be between 0-100")
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy), UsageRecordOverflowPolicySample) &&
|
||||
c.Gateway.UsageRecord.OverflowSamplePercent <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be positive when overflow_policy=sample")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleEnabled {
|
||||
if c.Gateway.UsageRecord.AutoScaleMinWorkers <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_min_workers must be positive")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleMaxWorkers <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be positive")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleMaxWorkers < c.Gateway.UsageRecord.AutoScaleMinWorkers {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be >= auto_scale_min_workers")
|
||||
}
|
||||
if c.Gateway.UsageRecord.WorkerCount < c.Gateway.UsageRecord.AutoScaleMinWorkers ||
|
||||
c.Gateway.UsageRecord.WorkerCount > c.Gateway.UsageRecord.AutoScaleMaxWorkers {
|
||||
return fmt.Errorf("gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleUpQueuePercent <= 0 || c.Gateway.UsageRecord.AutoScaleUpQueuePercent > 100 {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_up_queue_percent must be between 1-100")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleDownQueuePercent < 0 || c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= 100 {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be between 0-99")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= c.Gateway.UsageRecord.AutoScaleUpQueuePercent {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be less than auto_scale_up_queue_percent")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleUpStep <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_up_step must be positive")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleDownStep <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_down_step must be positive")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds <= 0 {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_check_interval_seconds must be positive")
|
||||
}
|
||||
if c.Gateway.UsageRecord.AutoScaleCooldownSeconds < 0 {
|
||||
return fmt.Errorf("gateway.usage_record.auto_scale_cooldown_seconds must be non-negative")
|
||||
}
|
||||
}
|
||||
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
|
||||
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
|
||||
}
|
||||
|
||||
@@ -942,6 +942,74 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 },
|
||||
wantErr: "gateway.max_line_size must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record worker count",
|
||||
mutate: func(c *Config) { c.Gateway.UsageRecord.WorkerCount = 0 },
|
||||
wantErr: "gateway.usage_record.worker_count",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record queue size",
|
||||
mutate: func(c *Config) { c.Gateway.UsageRecord.QueueSize = 0 },
|
||||
wantErr: "gateway.usage_record.queue_size",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record timeout",
|
||||
mutate: func(c *Config) { c.Gateway.UsageRecord.TaskTimeoutSeconds = 0 },
|
||||
wantErr: "gateway.usage_record.task_timeout_seconds",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record overflow policy",
|
||||
mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowPolicy = "invalid" },
|
||||
wantErr: "gateway.usage_record.overflow_policy",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record sample percent range",
|
||||
mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowSamplePercent = 101 },
|
||||
wantErr: "gateway.usage_record.overflow_sample_percent",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record sample percent required for sample policy",
|
||||
mutate: func(c *Config) {
|
||||
c.Gateway.UsageRecord.OverflowPolicy = UsageRecordOverflowPolicySample
|
||||
c.Gateway.UsageRecord.OverflowSamplePercent = 0
|
||||
},
|
||||
wantErr: "gateway.usage_record.overflow_sample_percent must be positive",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record auto scale max gte min",
|
||||
mutate: func(c *Config) {
|
||||
c.Gateway.UsageRecord.AutoScaleMinWorkers = 256
|
||||
c.Gateway.UsageRecord.AutoScaleMaxWorkers = 128
|
||||
},
|
||||
wantErr: "gateway.usage_record.auto_scale_max_workers",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record worker in auto scale range",
|
||||
mutate: func(c *Config) {
|
||||
c.Gateway.UsageRecord.AutoScaleMinWorkers = 200
|
||||
c.Gateway.UsageRecord.AutoScaleMaxWorkers = 300
|
||||
c.Gateway.UsageRecord.WorkerCount = 128
|
||||
},
|
||||
wantErr: "gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record auto scale queue thresholds order",
|
||||
mutate: func(c *Config) {
|
||||
c.Gateway.UsageRecord.AutoScaleUpQueuePercent = 50
|
||||
c.Gateway.UsageRecord.AutoScaleDownQueuePercent = 50
|
||||
},
|
||||
wantErr: "gateway.usage_record.auto_scale_down_queue_percent must be less",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record auto scale up step",
|
||||
mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleUpStep = 0 },
|
||||
wantErr: "gateway.usage_record.auto_scale_up_step",
|
||||
},
|
||||
{
|
||||
name: "gateway usage record auto scale interval",
|
||||
mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0 },
|
||||
wantErr: "gateway.usage_record.auto_scale_check_interval_seconds",
|
||||
},
|
||||
{
|
||||
name: "gateway scheduling sticky waiting",
|
||||
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
|
||||
@@ -1025,6 +1093,99 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Gateway.UsageRecord.AutoScaleEnabled = false
|
||||
cfg.Gateway.UsageRecord.WorkerCount = 64
|
||||
|
||||
// 自动扩缩容关闭时,这些字段应被忽略,不应导致校验失败。
|
||||
cfg.Gateway.UsageRecord.AutoScaleMinWorkers = 0
|
||||
cfg.Gateway.UsageRecord.AutoScaleMaxWorkers = 0
|
||||
cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent = 0
|
||||
cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent = 100
|
||||
cfg.Gateway.UsageRecord.AutoScaleUpStep = 0
|
||||
cfg.Gateway.UsageRecord.AutoScaleDownStep = 0
|
||||
cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0
|
||||
cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds = -1
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Fatalf("Validate() should ignore auto scale fields when disabled: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfig_LogRequiredAndRotationBounds(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
mutate func(*Config)
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "log level required",
|
||||
mutate: func(c *Config) {
|
||||
c.Log.Level = ""
|
||||
},
|
||||
wantErr: "log.level is required",
|
||||
},
|
||||
{
|
||||
name: "log format required",
|
||||
mutate: func(c *Config) {
|
||||
c.Log.Format = ""
|
||||
},
|
||||
wantErr: "log.format is required",
|
||||
},
|
||||
{
|
||||
name: "log stacktrace required",
|
||||
mutate: func(c *Config) {
|
||||
c.Log.StacktraceLevel = ""
|
||||
},
|
||||
wantErr: "log.stacktrace_level is required",
|
||||
},
|
||||
{
|
||||
name: "log max backups non-negative",
|
||||
mutate: func(c *Config) {
|
||||
c.Log.Rotation.MaxBackups = -1
|
||||
},
|
||||
wantErr: "log.rotation.max_backups must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "log max age non-negative",
|
||||
mutate: func(c *Config) {
|
||||
c.Log.Rotation.MaxAgeDays = -1
|
||||
},
|
||||
wantErr: "log.rotation.max_age_days must be non-negative",
|
||||
},
|
||||
{
|
||||
name: "sampling thereafter non-negative when disabled",
|
||||
mutate: func(c *Config) {
|
||||
c.Log.Sampling.Enabled = false
|
||||
c.Log.Sampling.Thereafter = -1
|
||||
},
|
||||
wantErr: "log.sampling.thereafter must be non-negative",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
tt.mutate(cfg)
|
||||
err = cfg.Validate()
|
||||
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Fatalf("Validate() error = %v, want %q", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraCurlCFFISidecarDefaults(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
|
||||
@@ -1112,3 +1273,53 @@ func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) {
|
||||
t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
|
||||
resetViperWithJWTSecret(t)
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.WorkerCount != 128 {
|
||||
t.Fatalf("worker_count = %d, want 128", cfg.Gateway.UsageRecord.WorkerCount)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.QueueSize != 16384 {
|
||||
t.Fatalf("queue_size = %d, want 16384", cfg.Gateway.UsageRecord.QueueSize)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.TaskTimeoutSeconds != 5 {
|
||||
t.Fatalf("task_timeout_seconds = %d, want 5", cfg.Gateway.UsageRecord.TaskTimeoutSeconds)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySample {
|
||||
t.Fatalf("overflow_policy = %s, want %s", cfg.Gateway.UsageRecord.OverflowPolicy, UsageRecordOverflowPolicySample)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.OverflowSamplePercent != 10 {
|
||||
t.Fatalf("overflow_sample_percent = %d, want 10", cfg.Gateway.UsageRecord.OverflowSamplePercent)
|
||||
}
|
||||
if !cfg.Gateway.UsageRecord.AutoScaleEnabled {
|
||||
t.Fatalf("auto_scale_enabled = false, want true")
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleMinWorkers != 128 {
|
||||
t.Fatalf("auto_scale_min_workers = %d, want 128", cfg.Gateway.UsageRecord.AutoScaleMinWorkers)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleMaxWorkers != 512 {
|
||||
t.Fatalf("auto_scale_max_workers = %d, want 512", cfg.Gateway.UsageRecord.AutoScaleMaxWorkers)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent != 70 {
|
||||
t.Fatalf("auto_scale_up_queue_percent = %d, want 70", cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent != 15 {
|
||||
t.Fatalf("auto_scale_down_queue_percent = %d, want 15", cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleUpStep != 32 {
|
||||
t.Fatalf("auto_scale_up_step = %d, want 32", cfg.Gateway.UsageRecord.AutoScaleUpStep)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleDownStep != 16 {
|
||||
t.Fatalf("auto_scale_down_step = %d, want 16", cfg.Gateway.UsageRecord.AutoScaleDownStep)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds != 3 {
|
||||
t.Fatalf("auto_scale_check_interval_seconds = %d, want 3", cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds)
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds != 10 {
|
||||
t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ type GatewayHandler struct {
|
||||
billingCacheService *service.BillingCacheService
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.APIKeyService
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
@@ -54,6 +55,7 @@ func NewGatewayHandler(
|
||||
billingCacheService *service.BillingCacheService,
|
||||
usageService *service.UsageService,
|
||||
apiKeyService *service.APIKeyService,
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
||||
errorPassthroughService *service.ErrorPassthroughService,
|
||||
cfg *config.Config,
|
||||
) *GatewayHandler {
|
||||
@@ -77,6 +79,7 @@ func NewGatewayHandler(
|
||||
billingCacheService: billingCacheService,
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
@@ -431,19 +434,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fcb,
|
||||
ForceCacheBilling: forceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
@@ -452,10 +453,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", usedAccount.ID),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("gateway.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
}(result, account, userAgent, clientIP, forceCacheBilling)
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -700,19 +701,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: usedAccount,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: ua,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fcb,
|
||||
ForceCacheBilling: forceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
@@ -721,10 +720,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
zap.Int64("api_key_id", currentAPIKey.ID),
|
||||
zap.Any("group_id", currentAPIKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", usedAccount.ID),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("gateway.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
}(result, account, userAgent, clientIP, forceCacheBilling)
|
||||
})
|
||||
reqLog.Debug("gateway.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
@@ -1508,3 +1507,17 @@ func billingErrorDetails(err error) (status int, code, message string) {
|
||||
}
|
||||
return http.StatusForbidden, "billing_error", msg
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
if h.usageRecordWorkerPool != nil {
|
||||
h.usageRecordWorkerPool.Submit(task)
|
||||
return
|
||||
}
|
||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
@@ -519,22 +518,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 6) record usage async (Gemini 使用长上下文双倍计费)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string, fcb bool) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: ip,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||
ForceCacheBilling: fcb,
|
||||
ForceCacheBilling: forceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
@@ -543,10 +539,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", modelName),
|
||||
zap.Int64("account_id", usedAccount.ID),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("gemini.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
}(result, account, userAgent, clientIP, forceCacheBilling)
|
||||
})
|
||||
reqLog.Debug("gemini.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
|
||||
@@ -26,6 +26,7 @@ type OpenAIGatewayHandler struct {
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
apiKeyService *service.APIKeyService
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
@@ -37,6 +38,7 @@ func NewOpenAIGatewayHandler(
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
apiKeyService *service.APIKeyService,
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
||||
errorPassthroughService *service.ErrorPassthroughService,
|
||||
cfg *config.Config,
|
||||
) *OpenAIGatewayHandler {
|
||||
@@ -52,6 +54,7 @@ func NewOpenAIGatewayHandler(
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
apiKeyService: apiKeyService,
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
@@ -378,18 +381,16 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// Async record usage
|
||||
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: ip,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
@@ -398,10 +399,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", usedAccount.ID),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("openai.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
}(result, account, userAgent, clientIP)
|
||||
})
|
||||
reqLog.Debug("openai.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
@@ -432,6 +433,20 @@ func getContextInt64(c *gin.Context, key string) (int64, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
if h.usageRecordWorkerPool != nil {
|
||||
h.usageRecordWorkerPool.Submit(task)
|
||||
return
|
||||
}
|
||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
||||
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||
|
||||
@@ -31,15 +31,16 @@ import (
|
||||
|
||||
// SoraGatewayHandler handles Sora chat completions requests
|
||||
type SoraGatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
soraGatewayService *service.SoraGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
streamMode string
|
||||
soraTLSEnabled bool
|
||||
soraMediaSigningKey string
|
||||
soraMediaRoot string
|
||||
gatewayService *service.GatewayService
|
||||
soraGatewayService *service.SoraGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
streamMode string
|
||||
soraTLSEnabled bool
|
||||
soraMediaSigningKey string
|
||||
soraMediaRoot string
|
||||
}
|
||||
|
||||
// NewSoraGatewayHandler creates a new SoraGatewayHandler
|
||||
@@ -48,6 +49,7 @@ func NewSoraGatewayHandler(
|
||||
soraGatewayService *service.SoraGatewayService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
||||
cfg *config.Config,
|
||||
) *SoraGatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
@@ -71,15 +73,16 @@ func NewSoraGatewayHandler(
|
||||
}
|
||||
}
|
||||
return &SoraGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
soraGatewayService: soraGatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
streamMode: strings.ToLower(streamMode),
|
||||
soraTLSEnabled: soraTLSEnabled,
|
||||
soraMediaSigningKey: signKey,
|
||||
soraMediaRoot: mediaRoot,
|
||||
gatewayService: gatewayService,
|
||||
soraGatewayService: soraGatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
streamMode: strings.ToLower(streamMode),
|
||||
soraTLSEnabled: soraTLSEnabled,
|
||||
soraMediaSigningKey: signKey,
|
||||
soraMediaRoot: mediaRoot,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -397,17 +400,16 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: ip,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||
@@ -415,10 +417,10 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", usedAccount.ID),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("sora.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
}(result, account, userAgent, clientIP)
|
||||
})
|
||||
reqLog.Debug("sora.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
@@ -448,6 +450,20 @@ func generateOpenAISessionHash(c *gin.Context, body []byte) string {
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
if h.usageRecordWorkerPool != nil {
|
||||
h.usageRecordWorkerPool.Submit(task)
|
||||
return
|
||||
}
|
||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
|
||||
@@ -432,7 +432,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
|
||||
soraGatewayService := service.NewSoraGatewayService(soraClient, nil, nil, cfg)
|
||||
|
||||
handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, cfg)
|
||||
handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, nil, cfg)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
136
backend/internal/handler/usage_record_submit_task_test.go
Normal file
136
backend/internal/handler/usage_record_submit_task_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newUsageRecordTestPool(t *testing.T) *service.UsageRecordWorkerPool {
|
||||
t.Helper()
|
||||
pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 1,
|
||||
QueueSize: 8,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: "drop",
|
||||
OverflowSamplePercent: 0,
|
||||
AutoScaleEnabled: false,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
return pool
|
||||
}
|
||||
|
||||
func TestGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||
pool := newUsageRecordTestPool(t)
|
||||
h := &GatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
done := make(chan struct{})
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
close(done)
|
||||
})
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("task not executed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
|
||||
h := &GatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
t.Fatal("expected deadline in fallback context")
|
||||
}
|
||||
called.Store(true)
|
||||
})
|
||||
|
||||
require.True(t, called.Load())
|
||||
}
|
||||
|
||||
func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
h := &GatewayHandler{}
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||
pool := newUsageRecordTestPool(t)
|
||||
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
done := make(chan struct{})
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
close(done)
|
||||
})
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("task not executed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
|
||||
h := &OpenAIGatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
t.Fatal("expected deadline in fallback context")
|
||||
}
|
||||
called.Store(true)
|
||||
})
|
||||
|
||||
require.True(t, called.Load())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
h := &OpenAIGatewayHandler{}
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||
pool := newUsageRecordTestPool(t)
|
||||
h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
done := make(chan struct{})
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
close(done)
|
||||
})
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("task not executed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
|
||||
h := &SoraGatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
t.Fatal("expected deadline in fallback context")
|
||||
}
|
||||
called.Store(true)
|
||||
})
|
||||
|
||||
require.True(t, called.Load())
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
h := &SoraGatewayHandler{}
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(nil)
|
||||
})
|
||||
}
|
||||
496
backend/internal/service/usage_record_worker_pool.go
Normal file
496
backend/internal/service/usage_record_worker_pool.go
Normal file
@@ -0,0 +1,496 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/alitto/pond/v2"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultUsageRecordWorkerCount = 128
|
||||
defaultUsageRecordQueueSize = 16384
|
||||
defaultUsageRecordTaskTimeoutSeconds = 5
|
||||
defaultUsageRecordOverflowPolicy = config.UsageRecordOverflowPolicySample
|
||||
defaultUsageRecordOverflowSampleRatio = 10
|
||||
defaultUsageRecordAutoScaleEnabled = true
|
||||
defaultUsageRecordAutoScaleMinWorkers = 128
|
||||
defaultUsageRecordAutoScaleMaxWorkers = 512
|
||||
defaultUsageRecordAutoScaleUpPercent = 70
|
||||
defaultUsageRecordAutoScaleDownPercent = 15
|
||||
defaultUsageRecordAutoScaleUpStep = 32
|
||||
defaultUsageRecordAutoScaleDownStep = 16
|
||||
defaultUsageRecordAutoScaleInterval = 3 * time.Second
|
||||
defaultUsageRecordAutoScaleCooldown = 10 * time.Second
|
||||
usageRecordDropLogInterval = 5 * time.Second
|
||||
)
|
||||
|
||||
// UsageRecordTask is a unit of work submitted to the usage-record pool.
//
// The task is responsible for logging its own business errors; the pool only
// takes care of scheduling and of bounding the task with a context deadline.
type UsageRecordTask func(ctx context.Context)
|
||||
|
||||
// UsageRecordSubmitMode describes what Submit did with a task.
type UsageRecordSubmitMode string

const (
	// UsageRecordSubmitModeEnqueued means the pool accepted the task.
	UsageRecordSubmitModeEnqueued UsageRecordSubmitMode = "enqueued"
	// UsageRecordSubmitModeDropped means the task was discarded without running.
	UsageRecordSubmitModeDropped UsageRecordSubmitMode = "dropped"
	// UsageRecordSubmitModeSync means the task ran inline on the caller's goroutine.
	UsageRecordSubmitModeSync UsageRecordSubmitMode = "sync_fallback"
)
|
||||
|
||||
// UsageRecordWorkerPoolOptions holds the tunables of the usage-record pool.
// Values are sanitized by normalizeUsageRecordPoolOptions before use.
type UsageRecordWorkerPoolOptions struct {
	// Sizing and per-task deadline.
	WorkerCount int
	QueueSize   int
	TaskTimeout time.Duration

	// Queue-full behavior: policy name and, for the sample policy, the
	// percentage (0-100) of overflowing tasks that run synchronously.
	OverflowPolicy        string
	OverflowSamplePercent int

	// Auto-scaling: worker bounds, queue/utilization thresholds in percent,
	// resize step sizes, evaluation interval and cooldown between resizes.
	AutoScaleEnabled     bool
	AutoScaleMinWorkers  int
	AutoScaleMaxWorkers  int
	AutoScaleUpPercent   int
	AutoScaleDownPercent int
	AutoScaleUpStep      int
	AutoScaleDownStep    int
	AutoScaleInterval    time.Duration
	AutoScaleCooldown    time.Duration
}
|
||||
|
||||
// UsageRecordWorkerPoolStats is a point-in-time snapshot of the pool's state
// and counters, as returned by Stats.
type UsageRecordWorkerPoolStats struct {
	// Values reported by the underlying pool.
	MaxConcurrency  int
	RunningWorkers  int64
	WaitingTasks    uint64
	SubmittedTasks  uint64
	CompletedTasks  uint64
	SuccessfulTasks uint64
	FailedTasks     uint64
	DroppedTasks    uint64

	// Counters maintained by UsageRecordWorkerPool itself.
	DroppedQueueFull   uint64
	DroppedPoolStopped uint64
	SyncFallbackTasks  uint64
}
|
||||
|
||||
// UsageRecordWorkerPool 提供“有界队列 + 固定 worker”的异步执行器。
|
||||
// 用于替代请求路径里的直接 goroutine,避免高并发时无界堆积。
|
||||
type UsageRecordWorkerPool struct {
|
||||
pool pond.Pool
|
||||
taskTimeout time.Duration
|
||||
overflowPolicy string
|
||||
overflowSamplePercent int
|
||||
overflowCounter atomic.Uint64
|
||||
droppedQueueFull atomic.Uint64
|
||||
droppedPoolStopped atomic.Uint64
|
||||
syncFallback atomic.Uint64
|
||||
lastDropLogNanos atomic.Int64
|
||||
autoScaleEnabled bool
|
||||
autoScaleMinWorkers int
|
||||
autoScaleMaxWorkers int
|
||||
autoScaleUpPercent int
|
||||
autoScaleDownPercent int
|
||||
autoScaleUpStep int
|
||||
autoScaleDownStep int
|
||||
autoScaleInterval time.Duration
|
||||
autoScaleCooldown time.Duration
|
||||
lastScaleNanos atomic.Int64
|
||||
autoScaleCancel context.CancelFunc
|
||||
lifecycleWg sync.WaitGroup
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
// NewUsageRecordWorkerPool 从配置构建使用量记录池。
|
||||
func NewUsageRecordWorkerPool(cfg *config.Config) *UsageRecordWorkerPool {
|
||||
opts := usageRecordPoolOptionsFromConfig(cfg)
|
||||
return NewUsageRecordWorkerPoolWithOptions(opts)
|
||||
}
|
||||
|
||||
// NewUsageRecordWorkerPoolWithOptions 根据给定参数构建使用量记录池。
|
||||
func NewUsageRecordWorkerPoolWithOptions(opts UsageRecordWorkerPoolOptions) *UsageRecordWorkerPool {
|
||||
opts = normalizeUsageRecordPoolOptions(opts)
|
||||
|
||||
p := &UsageRecordWorkerPool{
|
||||
taskTimeout: opts.TaskTimeout,
|
||||
overflowPolicy: opts.OverflowPolicy,
|
||||
overflowSamplePercent: opts.OverflowSamplePercent,
|
||||
autoScaleEnabled: opts.AutoScaleEnabled,
|
||||
autoScaleMinWorkers: opts.AutoScaleMinWorkers,
|
||||
autoScaleMaxWorkers: opts.AutoScaleMaxWorkers,
|
||||
autoScaleUpPercent: opts.AutoScaleUpPercent,
|
||||
autoScaleDownPercent: opts.AutoScaleDownPercent,
|
||||
autoScaleUpStep: opts.AutoScaleUpStep,
|
||||
autoScaleDownStep: opts.AutoScaleDownStep,
|
||||
autoScaleInterval: opts.AutoScaleInterval,
|
||||
autoScaleCooldown: opts.AutoScaleCooldown,
|
||||
}
|
||||
|
||||
p.pool = pond.NewPool(
|
||||
opts.WorkerCount,
|
||||
pond.WithQueueSize(opts.QueueSize),
|
||||
)
|
||||
if p.autoScaleEnabled {
|
||||
p.startAutoScaler()
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// Submit 提交一个使用量记录任务。
|
||||
// 提交失败(队列满)时按 overflowPolicy 执行降级策略:drop/sample/sync。
|
||||
func (p *UsageRecordWorkerPool) Submit(task UsageRecordTask) UsageRecordSubmitMode {
|
||||
if p == nil || task == nil {
|
||||
return UsageRecordSubmitModeDropped
|
||||
}
|
||||
if p.pool == nil || p.pool.Stopped() {
|
||||
p.droppedPoolStopped.Add(1)
|
||||
p.logDrop("stopped")
|
||||
return UsageRecordSubmitModeDropped
|
||||
}
|
||||
|
||||
_, ok := p.pool.TrySubmit(func() {
|
||||
p.execute(task)
|
||||
})
|
||||
if ok {
|
||||
return UsageRecordSubmitModeEnqueued
|
||||
}
|
||||
|
||||
if p.pool.Stopped() {
|
||||
p.droppedPoolStopped.Add(1)
|
||||
p.logDrop("stopped")
|
||||
return UsageRecordSubmitModeDropped
|
||||
}
|
||||
|
||||
switch p.overflowPolicy {
|
||||
case config.UsageRecordOverflowPolicySync:
|
||||
p.syncFallback.Add(1)
|
||||
p.execute(task)
|
||||
return UsageRecordSubmitModeSync
|
||||
case config.UsageRecordOverflowPolicySample:
|
||||
if p.shouldSyncFallback() {
|
||||
p.syncFallback.Add(1)
|
||||
p.execute(task)
|
||||
return UsageRecordSubmitModeSync
|
||||
}
|
||||
}
|
||||
|
||||
p.droppedQueueFull.Add(1)
|
||||
p.logDrop("full")
|
||||
return UsageRecordSubmitModeDropped
|
||||
}
|
||||
|
||||
// Stats 返回当前池状态与计数器。
|
||||
func (p *UsageRecordWorkerPool) Stats() UsageRecordWorkerPoolStats {
|
||||
if p == nil || p.pool == nil {
|
||||
return UsageRecordWorkerPoolStats{}
|
||||
}
|
||||
return UsageRecordWorkerPoolStats{
|
||||
MaxConcurrency: p.pool.MaxConcurrency(),
|
||||
RunningWorkers: p.pool.RunningWorkers(),
|
||||
WaitingTasks: p.pool.WaitingTasks(),
|
||||
SubmittedTasks: p.pool.SubmittedTasks(),
|
||||
CompletedTasks: p.pool.CompletedTasks(),
|
||||
SuccessfulTasks: p.pool.SuccessfulTasks(),
|
||||
FailedTasks: p.pool.FailedTasks(),
|
||||
DroppedTasks: p.pool.DroppedTasks(),
|
||||
DroppedQueueFull: p.droppedQueueFull.Load(),
|
||||
DroppedPoolStopped: p.droppedPoolStopped.Load(),
|
||||
SyncFallbackTasks: p.syncFallback.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止池并等待队列任务完成。
|
||||
func (p *UsageRecordWorkerPool) Stop() {
|
||||
if p == nil || p.pool == nil {
|
||||
return
|
||||
}
|
||||
p.stopOnce.Do(func() {
|
||||
if p.autoScaleCancel != nil {
|
||||
p.autoScaleCancel()
|
||||
}
|
||||
p.lifecycleWg.Wait()
|
||||
p.pool.StopAndWait()
|
||||
})
|
||||
}
|
||||
|
||||
func (p *UsageRecordWorkerPool) startAutoScaler() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
p.autoScaleCancel = cancel
|
||||
|
||||
p.lifecycleWg.Add(1)
|
||||
go func() {
|
||||
defer p.lifecycleWg.Done()
|
||||
|
||||
ticker := time.NewTicker(p.autoScaleInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.autoScaleTick()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *UsageRecordWorkerPool) autoScaleTick() {
|
||||
if p == nil || p.pool == nil || p.pool.Stopped() {
|
||||
return
|
||||
}
|
||||
queueSize := p.pool.QueueSize()
|
||||
if queueSize <= 0 {
|
||||
return
|
||||
}
|
||||
current := p.pool.MaxConcurrency()
|
||||
waiting := int(p.pool.WaitingTasks())
|
||||
running := int(p.pool.RunningWorkers())
|
||||
if current <= 0 || waiting < 0 {
|
||||
return
|
||||
}
|
||||
queuePercent := waiting * 100 / queueSize
|
||||
runningPercent := 0
|
||||
if current > 0 {
|
||||
runningPercent = running * 100 / current
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
lastScaleNanos := p.lastScaleNanos.Load()
|
||||
if lastScaleNanos > 0 && now.Sub(time.Unix(0, lastScaleNanos)) < p.autoScaleCooldown {
|
||||
return
|
||||
}
|
||||
|
||||
// 扩容优先:队列占用率超过阈值时,按步长提升并发上限。
|
||||
if queuePercent >= p.autoScaleUpPercent && current < p.autoScaleMaxWorkers {
|
||||
target := current + p.autoScaleUpStep
|
||||
if target > p.autoScaleMaxWorkers {
|
||||
target = p.autoScaleMaxWorkers
|
||||
}
|
||||
p.resizePool(current, target, queuePercent, waiting, runningPercent, queueSize, "scale_up")
|
||||
return
|
||||
}
|
||||
|
||||
// 缩容:仅在队列为空且运行利用率低时收缩,避免高负载下“无排队误缩容”导致震荡。
|
||||
if queuePercent <= p.autoScaleDownPercent && waiting == 0 &&
|
||||
runningPercent <= p.autoScaleDownPercent &&
|
||||
current > p.autoScaleMinWorkers {
|
||||
target := current - p.autoScaleDownStep
|
||||
if target < p.autoScaleMinWorkers {
|
||||
target = p.autoScaleMinWorkers
|
||||
}
|
||||
p.resizePool(current, target, queuePercent, waiting, runningPercent, queueSize, "scale_down")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *UsageRecordWorkerPool) resizePool(current, target, queuePercent, waiting, runningPercent, queueSize int, action string) {
|
||||
if target == current {
|
||||
return
|
||||
}
|
||||
p.pool.Resize(target)
|
||||
p.lastScaleNanos.Store(time.Now().UnixNano())
|
||||
|
||||
logger.L().With(
|
||||
zap.String("component", "service.usage_record_worker_pool"),
|
||||
zap.String("action", action),
|
||||
zap.Int("from_workers", current),
|
||||
zap.Int("to_workers", target),
|
||||
zap.Int("queue_percent", queuePercent),
|
||||
zap.Int("waiting_tasks", waiting),
|
||||
zap.Int("running_percent", runningPercent),
|
||||
zap.Int("queue_size", queueSize),
|
||||
).Info("usage_record.auto_scale")
|
||||
}
|
||||
|
||||
func (p *UsageRecordWorkerPool) shouldSyncFallback() bool {
|
||||
if p.overflowSamplePercent <= 0 {
|
||||
return false
|
||||
}
|
||||
n := p.overflowCounter.Add(1)
|
||||
return int((n-1)%100) < p.overflowSamplePercent
|
||||
}
|
||||
|
||||
func (p *UsageRecordWorkerPool) execute(task UsageRecordTask) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.taskTimeout)
|
||||
defer cancel()
|
||||
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "service.usage_record_worker_pool"),
|
||||
zap.Any("panic", recovered),
|
||||
).Error("usage_record.task_panic")
|
||||
}
|
||||
}()
|
||||
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
func (p *UsageRecordWorkerPool) logDrop(reason string) {
|
||||
now := time.Now().UnixNano()
|
||||
last := p.lastDropLogNanos.Load()
|
||||
if now-last < int64(usageRecordDropLogInterval) {
|
||||
return
|
||||
}
|
||||
if !p.lastDropLogNanos.CompareAndSwap(last, now) {
|
||||
return
|
||||
}
|
||||
|
||||
stats := p.Stats()
|
||||
logger.L().With(
|
||||
zap.String("component", "service.usage_record_worker_pool"),
|
||||
zap.String("reason", reason),
|
||||
zap.String("overflow_policy", p.overflowPolicy),
|
||||
zap.Int64("running_workers", stats.RunningWorkers),
|
||||
zap.Uint64("waiting_tasks", stats.WaitingTasks),
|
||||
zap.Uint64("dropped_queue_full", stats.DroppedQueueFull),
|
||||
zap.Uint64("dropped_pool_stopped", stats.DroppedPoolStopped),
|
||||
zap.Uint64("sync_fallback_tasks", stats.SyncFallbackTasks),
|
||||
).Warn("usage_record.task_dropped")
|
||||
}
|
||||
|
||||
func usageRecordPoolOptionsFromConfig(cfg *config.Config) UsageRecordWorkerPoolOptions {
|
||||
opts := UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: defaultUsageRecordWorkerCount,
|
||||
QueueSize: defaultUsageRecordQueueSize,
|
||||
TaskTimeout: time.Duration(defaultUsageRecordTaskTimeoutSeconds) * time.Second,
|
||||
OverflowPolicy: defaultUsageRecordOverflowPolicy,
|
||||
OverflowSamplePercent: defaultUsageRecordOverflowSampleRatio,
|
||||
AutoScaleEnabled: defaultUsageRecordAutoScaleEnabled,
|
||||
AutoScaleMinWorkers: defaultUsageRecordAutoScaleMinWorkers,
|
||||
AutoScaleMaxWorkers: defaultUsageRecordAutoScaleMaxWorkers,
|
||||
AutoScaleUpPercent: defaultUsageRecordAutoScaleUpPercent,
|
||||
AutoScaleDownPercent: defaultUsageRecordAutoScaleDownPercent,
|
||||
AutoScaleUpStep: defaultUsageRecordAutoScaleUpStep,
|
||||
AutoScaleDownStep: defaultUsageRecordAutoScaleDownStep,
|
||||
AutoScaleInterval: defaultUsageRecordAutoScaleInterval,
|
||||
AutoScaleCooldown: defaultUsageRecordAutoScaleCooldown,
|
||||
}
|
||||
if cfg == nil {
|
||||
return opts
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.WorkerCount > 0 {
|
||||
opts.WorkerCount = cfg.Gateway.UsageRecord.WorkerCount
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.QueueSize > 0 {
|
||||
opts.QueueSize = cfg.Gateway.UsageRecord.QueueSize
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.TaskTimeoutSeconds > 0 {
|
||||
opts.TaskTimeout = time.Duration(cfg.Gateway.UsageRecord.TaskTimeoutSeconds) * time.Second
|
||||
}
|
||||
if policy := strings.TrimSpace(strings.ToLower(cfg.Gateway.UsageRecord.OverflowPolicy)); policy != "" {
|
||||
opts.OverflowPolicy = policy
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.OverflowSamplePercent >= 0 {
|
||||
opts.OverflowSamplePercent = cfg.Gateway.UsageRecord.OverflowSamplePercent
|
||||
}
|
||||
opts.AutoScaleEnabled = cfg.Gateway.UsageRecord.AutoScaleEnabled
|
||||
if cfg.Gateway.UsageRecord.AutoScaleMinWorkers > 0 {
|
||||
opts.AutoScaleMinWorkers = cfg.Gateway.UsageRecord.AutoScaleMinWorkers
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleMaxWorkers > 0 {
|
||||
opts.AutoScaleMaxWorkers = cfg.Gateway.UsageRecord.AutoScaleMaxWorkers
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent > 0 {
|
||||
opts.AutoScaleUpPercent = cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent >= 0 {
|
||||
opts.AutoScaleDownPercent = cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleUpStep > 0 {
|
||||
opts.AutoScaleUpStep = cfg.Gateway.UsageRecord.AutoScaleUpStep
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleDownStep > 0 {
|
||||
opts.AutoScaleDownStep = cfg.Gateway.UsageRecord.AutoScaleDownStep
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds > 0 {
|
||||
opts.AutoScaleInterval = time.Duration(cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds) * time.Second
|
||||
}
|
||||
if cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds >= 0 {
|
||||
opts.AutoScaleCooldown = time.Duration(cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds) * time.Second
|
||||
}
|
||||
return normalizeUsageRecordPoolOptions(opts)
|
||||
}
|
||||
|
||||
func normalizeUsageRecordPoolOptions(opts UsageRecordWorkerPoolOptions) UsageRecordWorkerPoolOptions {
|
||||
if opts.WorkerCount <= 0 {
|
||||
opts.WorkerCount = defaultUsageRecordWorkerCount
|
||||
}
|
||||
if opts.QueueSize <= 0 {
|
||||
opts.QueueSize = defaultUsageRecordQueueSize
|
||||
}
|
||||
if opts.TaskTimeout <= 0 {
|
||||
opts.TaskTimeout = time.Duration(defaultUsageRecordTaskTimeoutSeconds) * time.Second
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(opts.OverflowPolicy)) {
|
||||
case config.UsageRecordOverflowPolicyDrop,
|
||||
config.UsageRecordOverflowPolicySample,
|
||||
config.UsageRecordOverflowPolicySync:
|
||||
opts.OverflowPolicy = strings.ToLower(strings.TrimSpace(opts.OverflowPolicy))
|
||||
default:
|
||||
opts.OverflowPolicy = defaultUsageRecordOverflowPolicy
|
||||
}
|
||||
if opts.OverflowSamplePercent < 0 {
|
||||
opts.OverflowSamplePercent = 0
|
||||
}
|
||||
if opts.OverflowSamplePercent > 100 {
|
||||
opts.OverflowSamplePercent = 100
|
||||
}
|
||||
if opts.OverflowPolicy == config.UsageRecordOverflowPolicySample && opts.OverflowSamplePercent == 0 {
|
||||
opts.OverflowSamplePercent = defaultUsageRecordOverflowSampleRatio
|
||||
}
|
||||
if opts.AutoScaleEnabled {
|
||||
if opts.AutoScaleMinWorkers <= 0 {
|
||||
opts.AutoScaleMinWorkers = defaultUsageRecordAutoScaleMinWorkers
|
||||
}
|
||||
if opts.AutoScaleMaxWorkers <= 0 {
|
||||
opts.AutoScaleMaxWorkers = defaultUsageRecordAutoScaleMaxWorkers
|
||||
}
|
||||
if opts.AutoScaleMaxWorkers < opts.AutoScaleMinWorkers {
|
||||
opts.AutoScaleMaxWorkers = opts.AutoScaleMinWorkers
|
||||
}
|
||||
if opts.WorkerCount < opts.AutoScaleMinWorkers {
|
||||
opts.WorkerCount = opts.AutoScaleMinWorkers
|
||||
}
|
||||
if opts.WorkerCount > opts.AutoScaleMaxWorkers {
|
||||
opts.WorkerCount = opts.AutoScaleMaxWorkers
|
||||
}
|
||||
if opts.AutoScaleUpPercent <= 0 || opts.AutoScaleUpPercent > 100 {
|
||||
opts.AutoScaleUpPercent = defaultUsageRecordAutoScaleUpPercent
|
||||
}
|
||||
if opts.AutoScaleDownPercent < 0 || opts.AutoScaleDownPercent >= 100 {
|
||||
opts.AutoScaleDownPercent = defaultUsageRecordAutoScaleDownPercent
|
||||
}
|
||||
if opts.AutoScaleDownPercent >= opts.AutoScaleUpPercent {
|
||||
opts.AutoScaleDownPercent = max(0, opts.AutoScaleUpPercent/2)
|
||||
}
|
||||
if opts.AutoScaleUpStep <= 0 {
|
||||
opts.AutoScaleUpStep = defaultUsageRecordAutoScaleUpStep
|
||||
}
|
||||
if opts.AutoScaleDownStep <= 0 {
|
||||
opts.AutoScaleDownStep = defaultUsageRecordAutoScaleDownStep
|
||||
}
|
||||
if opts.AutoScaleInterval <= 0 {
|
||||
opts.AutoScaleInterval = defaultUsageRecordAutoScaleInterval
|
||||
}
|
||||
if opts.AutoScaleCooldown < 0 {
|
||||
opts.AutoScaleCooldown = defaultUsageRecordAutoScaleCooldown
|
||||
}
|
||||
} else {
|
||||
opts.AutoScaleMinWorkers = opts.WorkerCount
|
||||
opts.AutoScaleMaxWorkers = opts.WorkerCount
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func (m UsageRecordSubmitMode) String() string {
|
||||
return string(m)
|
||||
}
|
||||
|
||||
func (s UsageRecordWorkerPoolStats) String() string {
|
||||
return fmt.Sprintf("running=%d waiting=%d submitted=%d dropped=%d", s.RunningWorkers, s.WaitingTasks, s.SubmittedTasks, s.DroppedTasks)
|
||||
}
|
||||
488
backend/internal/service/usage_record_worker_pool_test.go
Normal file
488
backend/internal/service/usage_record_worker_pool_test.go
Normal file
@@ -0,0 +1,488 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUsageRecordWorkerPool_SubmitEnqueued(t *testing.T) {
|
||||
pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 1,
|
||||
QueueSize: 8,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: config.UsageRecordOverflowPolicyDrop,
|
||||
OverflowSamplePercent: 0,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
|
||||
done := make(chan struct{})
|
||||
mode := pool.Submit(func(ctx context.Context) {
|
||||
close(done)
|
||||
})
|
||||
require.Equal(t, UsageRecordSubmitModeEnqueued, mode)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("task not executed")
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
stats := pool.Stats()
|
||||
return stats.SubmittedTasks == 1 && stats.SuccessfulTasks == 1
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageRecordWorkerPool_OverflowDrop(t *testing.T) {
|
||||
pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 1,
|
||||
QueueSize: 1,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: config.UsageRecordOverflowPolicyDrop,
|
||||
OverflowSamplePercent: 0,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
|
||||
block := make(chan struct{})
|
||||
started := make(chan struct{})
|
||||
secondDone := make(chan struct{})
|
||||
|
||||
require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
|
||||
close(started)
|
||||
<-block
|
||||
}))
|
||||
<-started
|
||||
|
||||
require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
|
||||
close(secondDone)
|
||||
}))
|
||||
require.Equal(t, UsageRecordSubmitModeDropped, pool.Submit(func(ctx context.Context) {}))
|
||||
|
||||
close(block)
|
||||
select {
|
||||
case <-secondDone:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("queued task not executed")
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return pool.Stats().DroppedQueueFull >= 1
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageRecordWorkerPool_OverflowSync(t *testing.T) {
|
||||
pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 1,
|
||||
QueueSize: 1,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: config.UsageRecordOverflowPolicySync,
|
||||
OverflowSamplePercent: 0,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
|
||||
block := make(chan struct{})
|
||||
started := make(chan struct{})
|
||||
secondDone := make(chan struct{})
|
||||
var syncExecuted atomic.Bool
|
||||
|
||||
require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
|
||||
close(started)
|
||||
<-block
|
||||
}))
|
||||
<-started
|
||||
|
||||
require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
|
||||
close(secondDone)
|
||||
}))
|
||||
|
||||
mode := pool.Submit(func(ctx context.Context) {
|
||||
syncExecuted.Store(true)
|
||||
})
|
||||
require.Equal(t, UsageRecordSubmitModeSync, mode)
|
||||
require.True(t, syncExecuted.Load())
|
||||
|
||||
close(block)
|
||||
select {
|
||||
case <-secondDone:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("queued task not executed")
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return pool.Stats().SyncFallbackTasks >= 1
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageRecordWorkerPool_OverflowSample(t *testing.T) {
|
||||
pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 1,
|
||||
QueueSize: 1,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: config.UsageRecordOverflowPolicySample,
|
||||
OverflowSamplePercent: 1,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
|
||||
block := make(chan struct{})
|
||||
started := make(chan struct{})
|
||||
secondDone := make(chan struct{})
|
||||
var syncExecuted atomic.Bool
|
||||
|
||||
require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
|
||||
close(started)
|
||||
<-block
|
||||
}))
|
||||
<-started
|
||||
|
||||
require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
|
||||
close(secondDone)
|
||||
}))
|
||||
|
||||
firstOverflow := pool.Submit(func(ctx context.Context) {
|
||||
syncExecuted.Store(true)
|
||||
})
|
||||
require.Equal(t, UsageRecordSubmitModeSync, firstOverflow)
|
||||
require.True(t, syncExecuted.Load())
|
||||
|
||||
secondOverflow := pool.Submit(func(ctx context.Context) {})
|
||||
require.Equal(t, UsageRecordSubmitModeDropped, secondOverflow)
|
||||
|
||||
close(block)
|
||||
select {
|
||||
case <-secondDone:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("queued task not executed")
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
stats := pool.Stats()
|
||||
return stats.SyncFallbackTasks >= 1 && stats.DroppedQueueFull >= 1
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageRecordWorkerPool_SubmitAfterStop(t *testing.T) {
|
||||
pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 1,
|
||||
QueueSize: 1,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: config.UsageRecordOverflowPolicyDrop,
|
||||
OverflowSamplePercent: 0,
|
||||
})
|
||||
|
||||
pool.Stop()
|
||||
mode := pool.Submit(func(ctx context.Context) {})
|
||||
require.Equal(t, UsageRecordSubmitModeDropped, mode)
|
||||
require.GreaterOrEqual(t, pool.Stats().DroppedPoolStopped, uint64(1))
|
||||
}
|
||||
|
||||
func TestUsageRecordWorkerPool_AutoScaleUpAndDown(t *testing.T) {
|
||||
pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 2,
|
||||
QueueSize: 8,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: config.UsageRecordOverflowPolicyDrop,
|
||||
OverflowSamplePercent: 0,
|
||||
AutoScaleEnabled: true,
|
||||
AutoScaleMinWorkers: 1,
|
||||
AutoScaleMaxWorkers: 4,
|
||||
AutoScaleUpPercent: 40,
|
||||
AutoScaleDownPercent: 10,
|
||||
AutoScaleUpStep: 1,
|
||||
AutoScaleDownStep: 1,
|
||||
AutoScaleInterval: 20 * time.Millisecond,
|
||||
AutoScaleCooldown: 20 * time.Millisecond,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
|
||||
block := make(chan struct{})
|
||||
|
||||
// 填满运行槽位 + 队列,触发扩容阈值。
|
||||
for i := 0; i < 8; i++ {
|
||||
require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
|
||||
<-block
|
||||
}))
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return pool.Stats().MaxConcurrency >= 3
|
||||
}, 2*time.Second, 20*time.Millisecond)
|
||||
|
||||
close(block)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return pool.Stats().CompletedTasks >= 8
|
||||
}, 2*time.Second, 20*time.Millisecond)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return pool.Stats().MaxConcurrency == 1
|
||||
}, 2*time.Second, 20*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageRecordWorkerPool_AutoScaleDownRequiresLowRunningUtilization(t *testing.T) {
|
||||
pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 2,
|
||||
QueueSize: 8,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: config.UsageRecordOverflowPolicyDrop,
|
||||
OverflowSamplePercent: 0,
|
||||
AutoScaleEnabled: true,
|
||||
AutoScaleMinWorkers: 1,
|
||||
AutoScaleMaxWorkers: 2,
|
||||
AutoScaleUpPercent: 80,
|
||||
AutoScaleDownPercent: 50,
|
||||
AutoScaleUpStep: 1,
|
||||
AutoScaleDownStep: 1,
|
||||
AutoScaleInterval: 20 * time.Millisecond,
|
||||
AutoScaleCooldown: 20 * time.Millisecond,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
|
||||
block := make(chan struct{})
|
||||
for i := 0; i < 2; i++ {
|
||||
require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
|
||||
<-block
|
||||
}))
|
||||
}
|
||||
|
||||
// 虽然 waiting=0,但 running 利用率为 100%,不应缩容。
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
require.Equal(t, 2, pool.Stats().MaxConcurrency)
|
||||
|
||||
close(block)
|
||||
require.Eventually(t, func() bool {
|
||||
return pool.Stats().MaxConcurrency == 1
|
||||
}, 2*time.Second, 20*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageRecordWorkerPool_SubmitNilReceiverAndNilTask(t *testing.T) {
|
||||
var nilPool *UsageRecordWorkerPool
|
||||
require.Equal(t, UsageRecordSubmitModeDropped, nilPool.Submit(func(ctx context.Context) {}))
|
||||
|
||||
pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 1,
|
||||
QueueSize: 1,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: config.UsageRecordOverflowPolicyDrop,
|
||||
OverflowSamplePercent: 0,
|
||||
AutoScaleEnabled: false,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
|
||||
require.Equal(t, UsageRecordSubmitModeDropped, pool.Submit(nil))
|
||||
}
|
||||
|
||||
func TestUsageRecordWorkerPool_AutoScaleDisabledKeepsFixedConcurrency(t *testing.T) {
|
||||
pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 2,
|
||||
QueueSize: 4,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: config.UsageRecordOverflowPolicyDrop,
|
||||
OverflowSamplePercent: 0,
|
||||
AutoScaleEnabled: false,
|
||||
AutoScaleMinWorkers: 1,
|
||||
AutoScaleMaxWorkers: 4,
|
||||
AutoScaleUpPercent: 10,
|
||||
AutoScaleDownPercent: 1,
|
||||
AutoScaleUpStep: 2,
|
||||
AutoScaleDownStep: 2,
|
||||
AutoScaleInterval: 10 * time.Millisecond,
|
||||
AutoScaleCooldown: 10 * time.Millisecond,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
|
||||
require.Equal(t, 2, pool.Stats().MaxConcurrency)
|
||||
|
||||
block := make(chan struct{})
|
||||
for i := 0; i < 4; i++ {
|
||||
require.Equal(t, UsageRecordSubmitModeEnqueued, pool.Submit(func(ctx context.Context) {
|
||||
<-block
|
||||
}))
|
||||
}
|
||||
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
require.Equal(t, 2, pool.Stats().MaxConcurrency)
|
||||
close(block)
|
||||
}
|
||||
|
||||
func TestUsageRecordWorkerPool_OptionsFromConfig_AutoScaleDisabled(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.UsageRecord.WorkerCount = 64
|
||||
cfg.Gateway.UsageRecord.QueueSize = 128
|
||||
cfg.Gateway.UsageRecord.TaskTimeoutSeconds = 7
|
||||
cfg.Gateway.UsageRecord.OverflowPolicy = config.UsageRecordOverflowPolicyDrop
|
||||
cfg.Gateway.UsageRecord.OverflowSamplePercent = 0
|
||||
cfg.Gateway.UsageRecord.AutoScaleEnabled = false
|
||||
cfg.Gateway.UsageRecord.AutoScaleMinWorkers = 1
|
||||
cfg.Gateway.UsageRecord.AutoScaleMaxWorkers = 512
|
||||
|
||||
opts := usageRecordPoolOptionsFromConfig(cfg)
|
||||
require.False(t, opts.AutoScaleEnabled)
|
||||
require.Equal(t, 64, opts.WorkerCount)
|
||||
require.Equal(t, 64, opts.AutoScaleMinWorkers)
|
||||
require.Equal(t, 64, opts.AutoScaleMaxWorkers)
|
||||
require.Equal(t, 7*time.Second, opts.TaskTimeout)
|
||||
}
|
||||
|
||||
func TestUsageRecordWorkerPool_StringHelpers(t *testing.T) {
|
||||
require.Equal(t, "enqueued", UsageRecordSubmitModeEnqueued.String())
|
||||
stats := UsageRecordWorkerPoolStats{RunningWorkers: 2, WaitingTasks: 3, SubmittedTasks: 5, DroppedTasks: 1}
|
||||
require.Contains(t, stats.String(), "running=2")
|
||||
require.Contains(t, stats.String(), "waiting=3")
|
||||
}
|
||||
|
||||
func TestNewUsageRecordWorkerPool_FromConfig(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.UsageRecord.WorkerCount = 3
|
||||
cfg.Gateway.UsageRecord.QueueSize = 16
|
||||
cfg.Gateway.UsageRecord.TaskTimeoutSeconds = 2
|
||||
cfg.Gateway.UsageRecord.OverflowPolicy = config.UsageRecordOverflowPolicyDrop
|
||||
cfg.Gateway.UsageRecord.AutoScaleEnabled = false
|
||||
|
||||
pool := NewUsageRecordWorkerPool(cfg)
|
||||
t.Cleanup(pool.Stop)
|
||||
|
||||
stats := pool.Stats()
|
||||
require.Equal(t, 3, stats.MaxConcurrency)
|
||||
}
|
||||
|
||||
func TestUsageRecordWorkerPool_OptionsFromConfig_NilConfig(t *testing.T) {
|
||||
opts := usageRecordPoolOptionsFromConfig(nil)
|
||||
require.Equal(t, defaultUsageRecordWorkerCount, opts.WorkerCount)
|
||||
require.Equal(t, defaultUsageRecordQueueSize, opts.QueueSize)
|
||||
require.Equal(t, time.Duration(defaultUsageRecordTaskTimeoutSeconds)*time.Second, opts.TaskTimeout)
|
||||
require.Equal(t, defaultUsageRecordOverflowPolicy, opts.OverflowPolicy)
|
||||
require.Equal(t, defaultUsageRecordOverflowSampleRatio, opts.OverflowSamplePercent)
|
||||
require.True(t, opts.AutoScaleEnabled)
|
||||
require.Equal(t, defaultUsageRecordAutoScaleMinWorkers, opts.AutoScaleMinWorkers)
|
||||
require.Equal(t, defaultUsageRecordAutoScaleMaxWorkers, opts.AutoScaleMaxWorkers)
|
||||
}
|
||||
|
||||
func TestUsageRecordWorkerPool_NormalizeOptions_BoundsAndDefaults(t *testing.T) {
|
||||
opts := normalizeUsageRecordPoolOptions(UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 0,
|
||||
QueueSize: 0,
|
||||
TaskTimeout: 0,
|
||||
OverflowPolicy: "invalid",
|
||||
OverflowSamplePercent: 300,
|
||||
AutoScaleEnabled: true,
|
||||
AutoScaleMinWorkers: 0,
|
||||
AutoScaleMaxWorkers: 0,
|
||||
AutoScaleUpPercent: 0,
|
||||
AutoScaleDownPercent: 100,
|
||||
AutoScaleUpStep: 0,
|
||||
AutoScaleDownStep: 0,
|
||||
AutoScaleInterval: 0,
|
||||
AutoScaleCooldown: -time.Second,
|
||||
})
|
||||
|
||||
require.Equal(t, defaultUsageRecordWorkerCount, opts.WorkerCount)
|
||||
require.Equal(t, defaultUsageRecordQueueSize, opts.QueueSize)
|
||||
require.Equal(t, time.Duration(defaultUsageRecordTaskTimeoutSeconds)*time.Second, opts.TaskTimeout)
|
||||
require.Equal(t, defaultUsageRecordOverflowPolicy, opts.OverflowPolicy)
|
||||
require.Equal(t, 100, opts.OverflowSamplePercent)
|
||||
require.Equal(t, defaultUsageRecordAutoScaleMinWorkers, opts.AutoScaleMinWorkers)
|
||||
require.Equal(t, defaultUsageRecordAutoScaleMaxWorkers, opts.AutoScaleMaxWorkers)
|
||||
require.Equal(t, defaultUsageRecordAutoScaleUpPercent, opts.AutoScaleUpPercent)
|
||||
require.Equal(t, defaultUsageRecordAutoScaleDownPercent, opts.AutoScaleDownPercent)
|
||||
require.Equal(t, defaultUsageRecordAutoScaleUpStep, opts.AutoScaleUpStep)
|
||||
require.Equal(t, defaultUsageRecordAutoScaleDownStep, opts.AutoScaleDownStep)
|
||||
require.Equal(t, defaultUsageRecordAutoScaleInterval, opts.AutoScaleInterval)
|
||||
require.Equal(t, defaultUsageRecordAutoScaleCooldown, opts.AutoScaleCooldown)
|
||||
}
|
||||
|
||||
func TestUsageRecordWorkerPool_NormalizeOptions_SampleAndAutoScaleDisabled(t *testing.T) {
|
||||
sampleOpts := normalizeUsageRecordPoolOptions(UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 32,
|
||||
QueueSize: 128,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: config.UsageRecordOverflowPolicySample,
|
||||
OverflowSamplePercent: 0,
|
||||
AutoScaleEnabled: true,
|
||||
AutoScaleMinWorkers: 64,
|
||||
AutoScaleMaxWorkers: 48,
|
||||
AutoScaleUpPercent: 30,
|
||||
AutoScaleDownPercent: 40,
|
||||
AutoScaleUpStep: 1,
|
||||
AutoScaleDownStep: 1,
|
||||
AutoScaleInterval: time.Second,
|
||||
AutoScaleCooldown: time.Second,
|
||||
})
|
||||
require.Equal(t, defaultUsageRecordOverflowSampleRatio, sampleOpts.OverflowSamplePercent)
|
||||
require.Equal(t, 64, sampleOpts.AutoScaleMinWorkers)
|
||||
require.Equal(t, 64, sampleOpts.AutoScaleMaxWorkers)
|
||||
require.Equal(t, 64, sampleOpts.WorkerCount)
|
||||
require.Equal(t, 15, sampleOpts.AutoScaleDownPercent)
|
||||
|
||||
fixedOpts := normalizeUsageRecordPoolOptions(UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 20,
|
||||
AutoScaleEnabled: false,
|
||||
})
|
||||
require.Equal(t, 20, fixedOpts.AutoScaleMinWorkers)
|
||||
require.Equal(t, 20, fixedOpts.AutoScaleMaxWorkers)
|
||||
}
|
||||
|
||||
func TestUsageRecordWorkerPool_ShouldSyncFallbackEdgeCases(t *testing.T) {
|
||||
pool := &UsageRecordWorkerPool{overflowSamplePercent: 0}
|
||||
require.False(t, pool.shouldSyncFallback())
|
||||
|
||||
pool.overflowSamplePercent = 100
|
||||
require.True(t, pool.shouldSyncFallback())
|
||||
require.True(t, pool.shouldSyncFallback())
|
||||
}
|
||||
|
||||
func TestUsageRecordWorkerPool_StatsAndStop_NilBranches(t *testing.T) {
|
||||
var nilPool *UsageRecordWorkerPool
|
||||
require.Equal(t, UsageRecordWorkerPoolStats{}, nilPool.Stats())
|
||||
require.NotPanics(t, func() { nilPool.Stop() })
|
||||
|
||||
emptyPool := &UsageRecordWorkerPool{}
|
||||
require.Equal(t, UsageRecordWorkerPoolStats{}, emptyPool.Stats())
|
||||
require.NotPanics(t, func() { emptyPool.Stop() })
|
||||
}
|
||||
|
||||
func TestUsageRecordWorkerPool_Execute_PanicAndTimeout(t *testing.T) {
|
||||
pool := &UsageRecordWorkerPool{taskTimeout: 30 * time.Millisecond}
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
pool.execute(func(ctx context.Context) {
|
||||
panic("boom")
|
||||
})
|
||||
})
|
||||
|
||||
done := make(chan struct{})
|
||||
pool.execute(func(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
close(done)
|
||||
})
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout context not cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageRecordWorkerPool_ResizeAndLogDropBranches(t *testing.T) {
|
||||
pool := NewUsageRecordWorkerPoolWithOptions(UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 1,
|
||||
QueueSize: 8,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: config.UsageRecordOverflowPolicyDrop,
|
||||
AutoScaleEnabled: false,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
|
||||
// 目标值与当前值相同,应该直接返回。
|
||||
pool.resizePool(1, 1, 0, 0, 0, 8, "noop")
|
||||
require.Equal(t, 1, pool.Stats().MaxConcurrency)
|
||||
|
||||
// 在限流窗口内应静默返回。
|
||||
pool.lastDropLogNanos.Store(time.Now().UnixNano())
|
||||
require.NotPanics(t, func() {
|
||||
pool.logDrop("full")
|
||||
})
|
||||
}
|
||||
@@ -300,6 +300,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewTurnstileService,
|
||||
NewSubscriptionService,
|
||||
ProvideConcurrencyService,
|
||||
NewUsageRecordWorkerPool,
|
||||
ProvideSchedulerSnapshotService,
|
||||
NewIdentityService,
|
||||
NewCRSSyncService,
|
||||
|
||||
Reference in New Issue
Block a user