From 3fcb0cc37c48a8e06022a60ea221a589a85ee4c7 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Tue, 10 Feb 2026 00:37:47 +0800 Subject: [PATCH] =?UTF-8?q?feat(subscription):=20=E6=9C=89=E7=95=8C?= =?UTF-8?q?=E9=98=9F=E5=88=97=E6=89=A7=E8=A1=8C=E7=BB=B4=E6=8A=A4=E5=B9=B6?= =?UTF-8?q?=E6=94=B9=E8=BF=9B=E9=89=B4=E6=9D=83=E8=A7=A3=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/cmd/server/wire.go | 7 + backend/cmd/server/wire_gen.go | 9 +- backend/internal/config/config.go | 92 ++++++++----- backend/internal/config/config_test.go | 114 +++++++++++++--- .../internal/server/middleware/admin_auth.go | 9 +- .../server/middleware/api_key_auth.go | 6 +- .../server/middleware/api_key_auth_test.go | 65 +++++++++ .../internal/server/middleware/jwt_auth.go | 4 +- .../server/middleware/jwt_auth_test.go | 22 +++ .../server/middleware/misc_coverage_test.go | 126 ++++++++++++++++++ .../service/subscription_maintenance_queue.go | 75 +++++++++++ .../subscription_maintenance_queue_test.go | 54 ++++++++ .../internal/service/subscription_service.go | 41 ++++++ 13 files changed, 558 insertions(+), 66 deletions(-) create mode 100644 backend/internal/server/middleware/misc_coverage_test.go create mode 100644 backend/internal/service/subscription_maintenance_queue.go create mode 100644 backend/internal/service/subscription_maintenance_queue_test.go diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index c55ea844..18515236 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -76,6 +76,7 @@ func provideCleanup( pricing *service.PricingService, emailQueue *service.EmailQueueService, billingCache *service.BillingCacheService, + subscriptionService *service.SubscriptionService, oauth *service.OAuthService, openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, @@ -150,6 +151,12 @@ func provideCleanup( subscriptionExpiry.Stop() return nil }}, + {"SubscriptionService", func() error { + if subscriptionService != nil { + subscriptionService.Stop() + } + return nil + }}, {"PricingService", func() error { pricing.Stop() return nil diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 8fb34a63..5c870934 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -204,7 +204,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) application := &Application{ Server: httpServer, Cleanup: v, @@ -243,6 +243,7 @@ func provideCleanup( pricing *service.PricingService, emailQueue *service.EmailQueueService, billingCache *service.BillingCacheService, + subscriptionService *service.SubscriptionService, oauth *service.OAuthService, openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, @@ -316,6 +317,12 @@ func provideCleanup( subscriptionExpiry.Stop() return nil }}, + {"SubscriptionService", func() error { + if subscriptionService != nil { + subscriptionService.Stop() + } + return nil + }}, {"PricingService", func() error { pricing.Stop() return nil diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index ac90f9a0..317ff1c1 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -38,33 +38,34 @@ const ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - CORS CORSConfig `mapstructure:"cors"` - Security SecurityConfig `mapstructure:"security"` - Billing BillingConfig `mapstructure:"billing"` - Turnstile TurnstileConfig `mapstructure:"turnstile"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - Ops OpsConfig `mapstructure:"ops"` - JWT JWTConfig `mapstructure:"jwt"` - Totp TotpConfig `mapstructure:"totp"` - LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` - SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"` - Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` - DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` - UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` - Concurrency ConcurrencyConfig `mapstructure:"concurrency"` - TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - Sora SoraConfig `mapstructure:"sora"` - RunMode string `mapstructure:"run_mode" yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" - Gemini GeminiConfig `mapstructure:"gemini"` - Update UpdateConfig `mapstructure:"update"` + Server ServerConfig `mapstructure:"server"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + Ops OpsConfig `mapstructure:"ops"` + JWT JWTConfig `mapstructure:"jwt"` + Totp TotpConfig `mapstructure:"totp"` + LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` + SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"` + SubscriptionMaintenance SubscriptionMaintenanceConfig `mapstructure:"subscription_maintenance"` + Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` + DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` + UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + Sora SoraConfig `mapstructure:"sora"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` + Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` + Update UpdateConfig `mapstructure:"update"` } type GeminiConfig struct { @@ -609,6 +610,13 @@ type SubscriptionCacheConfig struct { JitterPercent int `mapstructure:"jitter_percent"` } +// SubscriptionMaintenanceConfig 订阅窗口维护后台任务配置。 +// 用于将“请求路径触发的维护动作”有界化,避免高并发下 goroutine 膨胀。 +type SubscriptionMaintenanceConfig struct { + WorkerCount int `mapstructure:"worker_count"` + QueueSize int `mapstructure:"queue_size"` +} + // DashboardCacheConfig 仪表盘统计缓存配置 type DashboardCacheConfig struct { // Enabled: 是否启用仪表盘缓存 @@ -734,15 +742,6 @@ func Load() (*Config, error) { cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove) cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy) - if cfg.JWT.Secret == "" { - secret, err := generateJWTSecret(64) - if err != nil { - return nil, fmt.Errorf("generate jwt secret error: %w", err) - } - cfg.JWT.Secret = secret - log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.") - } - // Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256) cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey) if cfg.Totp.EncryptionKey == "" { @@ -1057,9 +1056,30 @@ func setDefaults() { // Security - proxy fallback viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false) + // Subscription Maintenance (bounded queue + worker pool) + viper.SetDefault("subscription_maintenance.worker_count", 2) + viper.SetDefault("subscription_maintenance.queue_size", 1024) + } func (c *Config) Validate() error { + jwtSecret := strings.TrimSpace(c.JWT.Secret) + if jwtSecret == "" { + return fmt.Errorf("jwt.secret is required") + } + // NOTE: 按 UTF-8 编码后的字节长度计算。 + // 选择 bytes 而不是 rune 计数,确保二进制/随机串的长度语义更接近“熵”而非“字符数”。 + if len([]byte(jwtSecret)) < 32 { + return fmt.Errorf("jwt.secret must be at least 32 bytes") + } + + if c.SubscriptionMaintenance.WorkerCount < 0 { + return fmt.Errorf("subscription_maintenance.worker_count must be non-negative") + } + if c.SubscriptionMaintenance.QueueSize < 0 { + return fmt.Errorf("subscription_maintenance.queue_size must be non-negative") + } + // Gemini OAuth 配置校验:client_id 与 client_secret 必须同时设置或同时留空。 // 留空时表示使用内置的 Gemini CLI OAuth 客户端(其 client_secret 通过环境变量注入)。 geminiClientID := strings.TrimSpace(c.Gemini.OAuth.ClientID) diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index a645d343..0f02a8bd 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -8,6 +8,12 @@ import ( "github.com/spf13/viper" ) +func resetViperWithJWTSecret(t *testing.T) { + t.Helper() + viper.Reset() + t.Setenv("JWT_SECRET", strings.Repeat("x", 32)) +} + func TestNormalizeRunMode(t *testing.T) { tests := []struct { input string @@ -29,7 +35,7 @@ func TestNormalizeRunMode(t *testing.T) { } func TestLoadDefaultSchedulingConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -57,7 +63,7 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) { } func TestLoadSchedulingConfigFromEnv(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5") cfg, err := Load() @@ -71,7 +77,7 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) { } func TestLoadDefaultSecurityToggles(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -93,7 +99,7 @@ func TestLoadDefaultSecurityToggles(t *testing.T) { } func TestLoadDefaultServerMode(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -106,7 +112,7 @@ func TestLoadDefaultServerMode(t *testing.T) { } func TestLoadDefaultDatabaseSSLMode(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -119,7 +125,7 @@ func TestLoadDefaultDatabaseSSLMode(t *testing.T) { } func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -144,7 +150,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { } func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -169,7 +175,7 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { } func TestLoadDefaultDashboardCacheConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -194,7 +200,7 @@ func TestLoadDefaultDashboardCacheConfig(t *testing.T) { } func TestValidateDashboardCacheConfigEnabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -214,7 +220,7 @@ func TestValidateDashboardCacheConfigEnabled(t *testing.T) { } func TestValidateDashboardCacheConfigDisabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -233,7 +239,7 @@ func TestValidateDashboardCacheConfigDisabled(t *testing.T) { } func TestLoadDefaultDashboardAggregationConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -270,7 +276,7 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) { } func TestValidateDashboardAggregationConfigDisabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -289,7 +295,7 @@ func TestValidateDashboardAggregationConfigDisabled(t *testing.T) { } func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -308,7 +314,7 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) { } func TestLoadDefaultUsageCleanupConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -333,7 +339,7 @@ func TestLoadDefaultUsageCleanupConfig(t *testing.T) { } func TestValidateUsageCleanupConfigEnabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -352,7 +358,7 @@ func TestValidateUsageCleanupConfigEnabled(t *testing.T) { } func TestValidateUsageCleanupConfigDisabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -451,7 +457,7 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) { } func TestValidateServerFrontendURL(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -505,6 +511,7 @@ func TestValidateFrontendRedirectURL(t *testing.T) { func TestWarnIfInsecureURL(t *testing.T) { warnIfInsecureURL("test", "http://example.com") warnIfInsecureURL("test", "bad://url") + warnIfInsecureURL("test", "://invalid") } func TestGenerateJWTSecretDefaultLength(t *testing.T) { @@ -518,7 +525,7 @@ func TestGenerateJWTSecretDefaultLength(t *testing.T) { } func TestValidateOpsCleanupScheduleRequired(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -536,7 +543,7 @@ func TestValidateOpsCleanupScheduleRequired(t *testing.T) { } func TestValidateConcurrencyPingInterval(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -553,14 +560,14 @@ func TestValidateConcurrencyPingInterval(t *testing.T) { } func TestProvideConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) if _, err := ProvideConfig(); err != nil { t.Fatalf("ProvideConfig() error: %v", err) } } func TestValidateConfigWithLinuxDoEnabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -604,6 +611,24 @@ func TestGenerateJWTSecretWithLength(t *testing.T) { } } +func TestDatabaseDSNWithTimezone_WithPassword(t *testing.T) { + d := &DatabaseConfig{ + Host: "localhost", + Port: 5432, + User: "u", + Password: "p", + DBName: "db", + SSLMode: "prefer", + } + got := d.DSNWithTimezone("UTC") + if !strings.Contains(got, "password=p") { + t.Fatalf("DSNWithTimezone should include password: %q", got) + } + if !strings.Contains(got, "TimeZone=UTC") { + t.Fatalf("DSNWithTimezone should include TimeZone=UTC: %q", got) + } +} + func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) { if err := ValidateAbsoluteHTTPURL("https://"); err == nil { t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host") @@ -626,10 +651,35 @@ func TestWarnIfInsecureURLHTTPS(t *testing.T) { warnIfInsecureURL("secure", "https://example.com") } +func TestValidateJWTSecret_UTF8Bytes(t *testing.T) { + resetViperWithJWTSecret(t) + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + // 31 bytes (< 32) even though it's 31 characters. + cfg.JWT.Secret = strings.Repeat("a", 31) + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() should reject 31-byte secret") + } + if !strings.Contains(err.Error(), "at least 32 bytes") { + t.Fatalf("Validate() error = %v", err) + } + + // 32 bytes OK. + cfg.JWT.Secret = strings.Repeat("a", 32) + err = cfg.Validate() + if err != nil { + t.Fatalf("Validate() should accept 32-byte secret: %v", err) + } +} + func TestValidateConfigErrors(t *testing.T) { buildValid := func(t *testing.T) *Config { t.Helper() - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { t.Fatalf("Load() error: %v", err) @@ -642,6 +692,26 @@ func TestValidateConfigErrors(t *testing.T) { mutate func(*Config) wantErr string }{ + { + name: "jwt secret required", + mutate: func(c *Config) { c.JWT.Secret = "" }, + wantErr: "jwt.secret is required", + }, + { + name: "jwt secret min bytes", + mutate: func(c *Config) { c.JWT.Secret = strings.Repeat("a", 31) }, + wantErr: "jwt.secret must be at least 32 bytes", + }, + { + name: "subscription maintenance worker_count non-negative", + mutate: func(c *Config) { c.SubscriptionMaintenance.WorkerCount = -1 }, + wantErr: "subscription_maintenance.worker_count", + }, + { + name: "subscription maintenance queue_size non-negative", + mutate: func(c *Config) { c.SubscriptionMaintenance.QueueSize = -1 }, + wantErr: "subscription_maintenance.queue_size", + }, { name: "jwt expire hour positive", mutate: func(c *Config) { c.JWT.ExpireHour = 0 }, diff --git a/backend/internal/server/middleware/admin_auth.go b/backend/internal/server/middleware/admin_auth.go index 4167b7ab..6f294ff0 100644 --- a/backend/internal/server/middleware/admin_auth.go +++ b/backend/internal/server/middleware/admin_auth.go @@ -58,8 +58,13 @@ func adminAuth( authHeader := c.GetHeader("Authorization") if authHeader != "" { parts := strings.SplitN(authHeader, " ", 2) - if len(parts) == 2 && parts[0] == "Bearer" { - if !validateJWTForAdmin(c, parts[1], authService, userService) { + if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") { + token := strings.TrimSpace(parts[1]) + if token == "" { + AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required") + return + } + if !validateJWTForAdmin(c, token, authService, userService) { return } c.Next() diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index 4525aee7..8e03f785 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -35,8 +35,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti if authHeader != "" { // 验证Bearer scheme parts := strings.SplitN(authHeader, " ", 2) - if len(parts) == 2 && parts[0] == "Bearer" { - apiKeyString = parts[1] + if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") { + apiKeyString = strings.TrimSpace(parts[1]) } } @@ -166,7 +166,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti // 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race if needsMaintenance { maintenanceCopy := *subscription - go subscriptionService.DoWindowMaintenance(&maintenanceCopy) + subscriptionService.DoWindowMaintenance(&maintenanceCopy) } } else { // 余额模式:检查用户余额 diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 6d1f8ecd..3e33c7e3 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -57,6 +57,57 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { }, } + t.Run("standard_mode_needs_maintenance_does_not_block_request", func(t *testing.T) { + cfg := &config.Config{RunMode: config.RunModeStandard} + cfg.SubscriptionMaintenance.WorkerCount = 1 + cfg.SubscriptionMaintenance.QueueSize = 1 + + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + + past := time.Now().Add(-48 * time.Hour) + sub := &service.UserSubscription{ + ID: 55, + UserID: user.ID, + GroupID: group.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + DailyWindowStart: &past, + DailyUsageUSD: 0, + } + maintenanceCalled := make(chan struct{}, 1) + subscriptionRepo := &stubUserSubscriptionRepo{ + getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + clone := *sub + return &clone, nil + }, + updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil }, + activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetDaily: func(ctx context.Context, id int64, start time.Time) error { + maintenanceCalled <- struct{}{} + return nil + }, + resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + } + subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg) + t.Cleanup(subscriptionService.Stop) + + router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + select { + case <-maintenanceCalled: + // ok + case <-time.After(time.Second): + t.Fatalf("expected maintenance to be scheduled") + } + }) + t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { cfg := &config.Config{RunMode: config.RunModeSimple} apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) @@ -71,6 +122,20 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) }) + t.Run("simple_mode_accepts_lowercase_bearer", func(t *testing.T) { + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("Authorization", "bearer "+apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + }) + t.Run("standard_mode_enforces_quota_check", func(t *testing.T) { cfg := &config.Config{RunMode: config.RunModeStandard} apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) diff --git a/backend/internal/server/middleware/jwt_auth.go b/backend/internal/server/middleware/jwt_auth.go index 9a89aab7..4aceb355 100644 --- a/backend/internal/server/middleware/jwt_auth.go +++ b/backend/internal/server/middleware/jwt_auth.go @@ -26,12 +26,12 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService) // 验证Bearer scheme parts := strings.SplitN(authHeader, " ", 2) - if len(parts) != 2 || parts[0] != "Bearer" { + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'") return } - tokenString := parts[1] + tokenString := strings.TrimSpace(parts[1]) if tokenString == "" { AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty") return diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go index e1b8e1ad..bc320958 100644 --- a/backend/internal/server/middleware/jwt_auth_test.go +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -84,6 +84,28 @@ func TestJWTAuth_ValidToken(t *testing.T) { require.Equal(t, "user", body["role"]) } +func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) { + user := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + Concurrency: 5, + TokenVersion: 1, + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user}) + + token, err := authSvc.GenerateToken(user) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) +} + func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) { router, _ := newJWTTestEnv(nil) diff --git a/backend/internal/server/middleware/misc_coverage_test.go b/backend/internal/server/middleware/misc_coverage_test.go new file mode 100644 index 00000000..c0adfc4d --- /dev/null +++ b/backend/internal/server/middleware/misc_coverage_test.go @@ -0,0 +1,126 @@ +//go:build unit + +package middleware + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestClientRequestID_GeneratesWhenMissing(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(ClientRequestID()) + r.GET("/t", func(c *gin.Context) { + v := c.Request.Context().Value(ctxkey.ClientRequestID) + require.NotNil(t, v) + id, ok := v.(string) + require.True(t, ok) + require.NotEmpty(t, id) + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) +} + +func TestClientRequestID_PreservesExisting(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(ClientRequestID()) + r.GET("/t", func(c *gin.Context) { + id, ok := c.Request.Context().Value(ctxkey.ClientRequestID).(string) + require.True(t, ok) + require.Equal(t, "keep", id) + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req = req.WithContext(context.WithValue(req.Context(), ctxkey.ClientRequestID, "keep")) + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) +} + +func TestRequestBodyLimit_LimitsBody(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(RequestBodyLimit(4)) + r.POST("/t", func(c *gin.Context) { + _, err := io.ReadAll(c.Request.Body) + require.Error(t, err) + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/t", bytes.NewBufferString("12345")) + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) +} + +func TestForcePlatform_SetsContextAndGinValue(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(ForcePlatform("anthropic")) + r.GET("/t", func(c *gin.Context) { + require.True(t, HasForcePlatform(c)) + v, ok := GetForcePlatformFromContext(c) + require.True(t, ok) + require.Equal(t, "anthropic", v) + + ctxV := c.Request.Context().Value(ctxkey.ForcePlatform) + require.Equal(t, "anthropic", ctxV) + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) +} + +func TestAuthSubjectHelpers_RoundTrip(t *testing.T) { + c := &gin.Context{} + c.Set(string(ContextKeyUser), AuthSubject{UserID: 1, Concurrency: 2}) + c.Set(string(ContextKeyUserRole), "admin") + + sub, ok := GetAuthSubjectFromContext(c) + require.True(t, ok) + require.Equal(t, int64(1), sub.UserID) + require.Equal(t, 2, sub.Concurrency) + + role, ok := GetUserRoleFromContext(c) + require.True(t, ok) + require.Equal(t, "admin", role) +} + +func TestAPIKeyAndSubscriptionFromContext(t *testing.T) { + c := &gin.Context{} + + key := &service.APIKey{ID: 1} + c.Set(string(ContextKeyAPIKey), key) + gotKey, ok := GetAPIKeyFromContext(c) + require.True(t, ok) + require.Equal(t, int64(1), gotKey.ID) + + sub := &service.UserSubscription{ID: 2} + c.Set(string(ContextKeySubscription), sub) + gotSub, ok := GetSubscriptionFromContext(c) + require.True(t, ok) + require.Equal(t, int64(2), gotSub.ID) +} diff --git a/backend/internal/service/subscription_maintenance_queue.go b/backend/internal/service/subscription_maintenance_queue.go new file mode 100644 index 00000000..52ad6472 --- /dev/null +++ b/backend/internal/service/subscription_maintenance_queue.go @@ -0,0 +1,75 @@ +package service + +import ( + "fmt" + "log" + "sync" +) + +// SubscriptionMaintenanceQueue 提供“有界队列 + 固定 worker”的后台执行器。 +// 用于从请求热路径触发维护动作时,避免无限 goroutine 膨胀。 +type SubscriptionMaintenanceQueue struct { + queue chan func() + wg sync.WaitGroup + stop sync.Once +} + +func NewSubscriptionMaintenanceQueue(workerCount, queueSize int) *SubscriptionMaintenanceQueue { + if workerCount <= 0 { + workerCount = 1 + } + if queueSize <= 0 { + queueSize = 1 + } + + q := &SubscriptionMaintenanceQueue{ + queue: make(chan func(), queueSize), + } + + q.wg.Add(workerCount) + for i := 0; i < workerCount; i++ { + go func(workerID int) { + defer q.wg.Done() + for fn := range q.queue { + func() { + defer func() { + if r := recover(); r != nil { + log.Printf("SubscriptionMaintenance worker panic: %v", r) + } + }() + fn() + }() + } + }(i) + } + + return q +} + +// TryEnqueue 尝试将任务入队。 +// 当队列已满时返回 error(调用方应该选择跳过并记录告警/限频日志)。 +func (q *SubscriptionMaintenanceQueue) TryEnqueue(task func()) error { + if q == nil { + return fmt.Errorf("maintenance queue is nil") + } + if task == nil { + return fmt.Errorf("maintenance task is nil") + } + + select { + case q.queue <- task: + return nil + default: + return fmt.Errorf("maintenance queue full") + } +} + +func (q *SubscriptionMaintenanceQueue) Stop() { + if q == nil { + return + } + q.stop.Do(func() { + close(q.queue) + q.wg.Wait() + }) +} diff --git a/backend/internal/service/subscription_maintenance_queue_test.go b/backend/internal/service/subscription_maintenance_queue_test.go new file mode 100644 index 00000000..69034bb9 --- /dev/null +++ b/backend/internal/service/subscription_maintenance_queue_test.go @@ -0,0 +1,54 @@ +//go:build unit + +package service + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSubscriptionMaintenanceQueue_TryEnqueue_QueueFull(t *testing.T) { + q := NewSubscriptionMaintenanceQueue(1, 1) + t.Cleanup(q.Stop) + + block := make(chan struct{}) + var started atomic.Int32 + + require.NoError(t, q.TryEnqueue(func() { + started.Store(1) + <-block + })) + + // Wait until worker started consuming the first task. + require.Eventually(t, func() bool { return started.Load() == 1 }, time.Second, 10*time.Millisecond) + + // Queue size is 1; with the worker blocked, enqueueing one more should fill it. + require.NoError(t, q.TryEnqueue(func() {})) + + // Now the queue is full; next enqueue must fail. + err := q.TryEnqueue(func() {}) + require.Error(t, err) + require.Contains(t, err.Error(), "full") + + close(block) +} + +func TestSubscriptionMaintenanceQueue_TryEnqueue_PanicDoesNotKillWorker(t *testing.T) { + q := NewSubscriptionMaintenanceQueue(1, 8) + t.Cleanup(q.Stop) + + require.NoError(t, q.TryEnqueue(func() { panic("boom") })) + + done := make(chan struct{}) + require.NoError(t, q.TryEnqueue(func() { close(done) })) + + select { + case <-done: + // ok + case <-time.After(time.Second): + t.Fatalf("worker did not continue after panic") + } +} diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index 4360b261..29ef3662 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -48,6 +48,8 @@ type SubscriptionService struct { subCacheGroup singleflight.Group subCacheTTL time.Duration subCacheJitter int // 抖动百分比 + + maintenanceQueue *SubscriptionMaintenanceQueue } // NewSubscriptionService 创建订阅服务 @@ -59,9 +61,31 @@ func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscript entClient: entClient, } svc.initSubCache(cfg) + svc.initMaintenanceQueue(cfg) return svc } +func (s *SubscriptionService) initMaintenanceQueue(cfg *config.Config) { + if cfg == nil { + return + } + mc := cfg.SubscriptionMaintenance + if mc.WorkerCount <= 0 || mc.QueueSize <= 0 { + return + } + s.maintenanceQueue = NewSubscriptionMaintenanceQueue(mc.WorkerCount, mc.QueueSize) +} + +// Stop stops the maintenance worker pool. +func (s *SubscriptionService) Stop() { + if s == nil { + return + } + if s.maintenanceQueue != nil { + s.maintenanceQueue.Stop() + } +} + // initSubCache 初始化订阅 L1 缓存 func (s *SubscriptionService) initSubCache(cfg *config.Config) { if cfg == nil { @@ -720,6 +744,23 @@ func (s *SubscriptionService) ValidateAndCheckLimits(sub *UserSubscription, grou // 而 IsExpired()=true 的订阅在 ValidateAndCheckLimits 中已被拦截返回错误, // 因此进入此方法的订阅一定未过期,无需处理过期状态同步。 func (s *SubscriptionService) DoWindowMaintenance(sub *UserSubscription) { + if s == nil { + return + } + if s.maintenanceQueue != nil { + err := s.maintenanceQueue.TryEnqueue(func() { + s.doWindowMaintenance(sub) + }) + if err != nil { + log.Printf("Subscription maintenance enqueue failed: %v", err) + } + return + } + + s.doWindowMaintenance(sub) +} + +func (s *SubscriptionService) doWindowMaintenance(sub *UserSubscription) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel()