feat(subscription): 有界队列执行维护并改进鉴权解析
This commit is contained in:
@@ -76,6 +76,7 @@ func provideCleanup(
|
|||||||
pricing *service.PricingService,
|
pricing *service.PricingService,
|
||||||
emailQueue *service.EmailQueueService,
|
emailQueue *service.EmailQueueService,
|
||||||
billingCache *service.BillingCacheService,
|
billingCache *service.BillingCacheService,
|
||||||
|
subscriptionService *service.SubscriptionService,
|
||||||
oauth *service.OAuthService,
|
oauth *service.OAuthService,
|
||||||
openaiOAuth *service.OpenAIOAuthService,
|
openaiOAuth *service.OpenAIOAuthService,
|
||||||
geminiOAuth *service.GeminiOAuthService,
|
geminiOAuth *service.GeminiOAuthService,
|
||||||
@@ -150,6 +151,12 @@ func provideCleanup(
|
|||||||
subscriptionExpiry.Stop()
|
subscriptionExpiry.Stop()
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"SubscriptionService", func() error {
|
||||||
|
if subscriptionService != nil {
|
||||||
|
subscriptionService.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
{"PricingService", func() error {
|
{"PricingService", func() error {
|
||||||
pricing.Stop()
|
pricing.Stop()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -204,7 +204,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
||||||
application := &Application{
|
application := &Application{
|
||||||
Server: httpServer,
|
Server: httpServer,
|
||||||
Cleanup: v,
|
Cleanup: v,
|
||||||
@@ -243,6 +243,7 @@ func provideCleanup(
|
|||||||
pricing *service.PricingService,
|
pricing *service.PricingService,
|
||||||
emailQueue *service.EmailQueueService,
|
emailQueue *service.EmailQueueService,
|
||||||
billingCache *service.BillingCacheService,
|
billingCache *service.BillingCacheService,
|
||||||
|
subscriptionService *service.SubscriptionService,
|
||||||
oauth *service.OAuthService,
|
oauth *service.OAuthService,
|
||||||
openaiOAuth *service.OpenAIOAuthService,
|
openaiOAuth *service.OpenAIOAuthService,
|
||||||
geminiOAuth *service.GeminiOAuthService,
|
geminiOAuth *service.GeminiOAuthService,
|
||||||
@@ -316,6 +317,12 @@ func provideCleanup(
|
|||||||
subscriptionExpiry.Stop()
|
subscriptionExpiry.Stop()
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"SubscriptionService", func() error {
|
||||||
|
if subscriptionService != nil {
|
||||||
|
subscriptionService.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
{"PricingService", func() error {
|
{"PricingService", func() error {
|
||||||
pricing.Stop()
|
pricing.Stop()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -38,33 +38,34 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Server ServerConfig `mapstructure:"server"`
|
Server ServerConfig `mapstructure:"server"`
|
||||||
CORS CORSConfig `mapstructure:"cors"`
|
CORS CORSConfig `mapstructure:"cors"`
|
||||||
Security SecurityConfig `mapstructure:"security"`
|
Security SecurityConfig `mapstructure:"security"`
|
||||||
Billing BillingConfig `mapstructure:"billing"`
|
Billing BillingConfig `mapstructure:"billing"`
|
||||||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||||||
Database DatabaseConfig `mapstructure:"database"`
|
Database DatabaseConfig `mapstructure:"database"`
|
||||||
Redis RedisConfig `mapstructure:"redis"`
|
Redis RedisConfig `mapstructure:"redis"`
|
||||||
Ops OpsConfig `mapstructure:"ops"`
|
Ops OpsConfig `mapstructure:"ops"`
|
||||||
JWT JWTConfig `mapstructure:"jwt"`
|
JWT JWTConfig `mapstructure:"jwt"`
|
||||||
Totp TotpConfig `mapstructure:"totp"`
|
Totp TotpConfig `mapstructure:"totp"`
|
||||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||||
Default DefaultConfig `mapstructure:"default"`
|
Default DefaultConfig `mapstructure:"default"`
|
||||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||||
Pricing PricingConfig `mapstructure:"pricing"`
|
Pricing PricingConfig `mapstructure:"pricing"`
|
||||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||||
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
|
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
|
||||||
SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"`
|
SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"`
|
||||||
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
|
SubscriptionMaintenance SubscriptionMaintenanceConfig `mapstructure:"subscription_maintenance"`
|
||||||
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
|
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
|
||||||
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
|
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
|
||||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
|
||||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||||
Sora SoraConfig `mapstructure:"sora"`
|
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
Sora SoraConfig `mapstructure:"sora"`
|
||||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||||
Update UpdateConfig `mapstructure:"update"`
|
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||||
|
Update UpdateConfig `mapstructure:"update"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GeminiConfig struct {
|
type GeminiConfig struct {
|
||||||
@@ -609,6 +610,13 @@ type SubscriptionCacheConfig struct {
|
|||||||
JitterPercent int `mapstructure:"jitter_percent"`
|
JitterPercent int `mapstructure:"jitter_percent"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SubscriptionMaintenanceConfig 订阅窗口维护后台任务配置。
|
||||||
|
// 用于将“请求路径触发的维护动作”有界化,避免高并发下 goroutine 膨胀。
|
||||||
|
type SubscriptionMaintenanceConfig struct {
|
||||||
|
WorkerCount int `mapstructure:"worker_count"`
|
||||||
|
QueueSize int `mapstructure:"queue_size"`
|
||||||
|
}
|
||||||
|
|
||||||
// DashboardCacheConfig 仪表盘统计缓存配置
|
// DashboardCacheConfig 仪表盘统计缓存配置
|
||||||
type DashboardCacheConfig struct {
|
type DashboardCacheConfig struct {
|
||||||
// Enabled: 是否启用仪表盘缓存
|
// Enabled: 是否启用仪表盘缓存
|
||||||
@@ -734,15 +742,6 @@ func Load() (*Config, error) {
|
|||||||
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
||||||
cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
|
cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
|
||||||
|
|
||||||
if cfg.JWT.Secret == "" {
|
|
||||||
secret, err := generateJWTSecret(64)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("generate jwt secret error: %w", err)
|
|
||||||
}
|
|
||||||
cfg.JWT.Secret = secret
|
|
||||||
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
|
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
|
||||||
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
|
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
|
||||||
if cfg.Totp.EncryptionKey == "" {
|
if cfg.Totp.EncryptionKey == "" {
|
||||||
@@ -1057,9 +1056,30 @@ func setDefaults() {
|
|||||||
// Security - proxy fallback
|
// Security - proxy fallback
|
||||||
viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false)
|
viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false)
|
||||||
|
|
||||||
|
// Subscription Maintenance (bounded queue + worker pool)
|
||||||
|
viper.SetDefault("subscription_maintenance.worker_count", 2)
|
||||||
|
viper.SetDefault("subscription_maintenance.queue_size", 1024)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) Validate() error {
|
func (c *Config) Validate() error {
|
||||||
|
jwtSecret := strings.TrimSpace(c.JWT.Secret)
|
||||||
|
if jwtSecret == "" {
|
||||||
|
return fmt.Errorf("jwt.secret is required")
|
||||||
|
}
|
||||||
|
// NOTE: 按 UTF-8 编码后的字节长度计算。
|
||||||
|
// 选择 bytes 而不是 rune 计数,确保二进制/随机串的长度语义更接近“熵”而非“字符数”。
|
||||||
|
if len([]byte(jwtSecret)) < 32 {
|
||||||
|
return fmt.Errorf("jwt.secret must be at least 32 bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.SubscriptionMaintenance.WorkerCount < 0 {
|
||||||
|
return fmt.Errorf("subscription_maintenance.worker_count must be non-negative")
|
||||||
|
}
|
||||||
|
if c.SubscriptionMaintenance.QueueSize < 0 {
|
||||||
|
return fmt.Errorf("subscription_maintenance.queue_size must be non-negative")
|
||||||
|
}
|
||||||
|
|
||||||
// Gemini OAuth 配置校验:client_id 与 client_secret 必须同时设置或同时留空。
|
// Gemini OAuth 配置校验:client_id 与 client_secret 必须同时设置或同时留空。
|
||||||
// 留空时表示使用内置的 Gemini CLI OAuth 客户端(其 client_secret 通过环境变量注入)。
|
// 留空时表示使用内置的 Gemini CLI OAuth 客户端(其 client_secret 通过环境变量注入)。
|
||||||
geminiClientID := strings.TrimSpace(c.Gemini.OAuth.ClientID)
|
geminiClientID := strings.TrimSpace(c.Gemini.OAuth.ClientID)
|
||||||
|
|||||||
@@ -8,6 +8,12 @@ import (
|
|||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func resetViperWithJWTSecret(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
viper.Reset()
|
||||||
|
t.Setenv("JWT_SECRET", strings.Repeat("x", 32))
|
||||||
|
}
|
||||||
|
|
||||||
func TestNormalizeRunMode(t *testing.T) {
|
func TestNormalizeRunMode(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
input string
|
input string
|
||||||
@@ -29,7 +35,7 @@ func TestNormalizeRunMode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadDefaultSchedulingConfig(t *testing.T) {
|
func TestLoadDefaultSchedulingConfig(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -57,7 +63,7 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadSchedulingConfigFromEnv(t *testing.T) {
|
func TestLoadSchedulingConfigFromEnv(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5")
|
t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5")
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
@@ -71,7 +77,7 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadDefaultSecurityToggles(t *testing.T) {
|
func TestLoadDefaultSecurityToggles(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -93,7 +99,7 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadDefaultServerMode(t *testing.T) {
|
func TestLoadDefaultServerMode(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -106,7 +112,7 @@ func TestLoadDefaultServerMode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadDefaultDatabaseSSLMode(t *testing.T) {
|
func TestLoadDefaultDatabaseSSLMode(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -119,7 +125,7 @@ func TestLoadDefaultDatabaseSSLMode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
|
func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -144,7 +150,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
|
func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -169,7 +175,7 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
|
func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -194,7 +200,7 @@ func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateDashboardCacheConfigEnabled(t *testing.T) {
|
func TestValidateDashboardCacheConfigEnabled(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -214,7 +220,7 @@ func TestValidateDashboardCacheConfigEnabled(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateDashboardCacheConfigDisabled(t *testing.T) {
|
func TestValidateDashboardCacheConfigDisabled(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -233,7 +239,7 @@ func TestValidateDashboardCacheConfigDisabled(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -270,7 +276,7 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateDashboardAggregationConfigDisabled(t *testing.T) {
|
func TestValidateDashboardAggregationConfigDisabled(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -289,7 +295,7 @@ func TestValidateDashboardAggregationConfigDisabled(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
|
func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -308,7 +314,7 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
|
func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -333,7 +339,7 @@ func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
|
func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -352,7 +358,7 @@ func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateUsageCleanupConfigDisabled(t *testing.T) {
|
func TestValidateUsageCleanupConfigDisabled(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -451,7 +457,7 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateServerFrontendURL(t *testing.T) {
|
func TestValidateServerFrontendURL(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -505,6 +511,7 @@ func TestValidateFrontendRedirectURL(t *testing.T) {
|
|||||||
func TestWarnIfInsecureURL(t *testing.T) {
|
func TestWarnIfInsecureURL(t *testing.T) {
|
||||||
warnIfInsecureURL("test", "http://example.com")
|
warnIfInsecureURL("test", "http://example.com")
|
||||||
warnIfInsecureURL("test", "bad://url")
|
warnIfInsecureURL("test", "bad://url")
|
||||||
|
warnIfInsecureURL("test", "://invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGenerateJWTSecretDefaultLength(t *testing.T) {
|
func TestGenerateJWTSecretDefaultLength(t *testing.T) {
|
||||||
@@ -518,7 +525,7 @@ func TestGenerateJWTSecretDefaultLength(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
|
func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -536,7 +543,7 @@ func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateConcurrencyPingInterval(t *testing.T) {
|
func TestValidateConcurrencyPingInterval(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -553,14 +560,14 @@ func TestValidateConcurrencyPingInterval(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProvideConfig(t *testing.T) {
|
func TestProvideConfig(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
if _, err := ProvideConfig(); err != nil {
|
if _, err := ProvideConfig(); err != nil {
|
||||||
t.Fatalf("ProvideConfig() error: %v", err)
|
t.Fatalf("ProvideConfig() error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateConfigWithLinuxDoEnabled(t *testing.T) {
|
func TestValidateConfigWithLinuxDoEnabled(t *testing.T) {
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -604,6 +611,24 @@ func TestGenerateJWTSecretWithLength(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDatabaseDSNWithTimezone_WithPassword(t *testing.T) {
|
||||||
|
d := &DatabaseConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
User: "u",
|
||||||
|
Password: "p",
|
||||||
|
DBName: "db",
|
||||||
|
SSLMode: "prefer",
|
||||||
|
}
|
||||||
|
got := d.DSNWithTimezone("UTC")
|
||||||
|
if !strings.Contains(got, "password=p") {
|
||||||
|
t.Fatalf("DSNWithTimezone should include password: %q", got)
|
||||||
|
}
|
||||||
|
if !strings.Contains(got, "TimeZone=UTC") {
|
||||||
|
t.Fatalf("DSNWithTimezone should include TimeZone=UTC: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) {
|
func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) {
|
||||||
if err := ValidateAbsoluteHTTPURL("https://"); err == nil {
|
if err := ValidateAbsoluteHTTPURL("https://"); err == nil {
|
||||||
t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host")
|
t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host")
|
||||||
@@ -626,10 +651,35 @@ func TestWarnIfInsecureURLHTTPS(t *testing.T) {
|
|||||||
warnIfInsecureURL("secure", "https://example.com")
|
warnIfInsecureURL("secure", "https://example.com")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateJWTSecret_UTF8Bytes(t *testing.T) {
|
||||||
|
resetViperWithJWTSecret(t)
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 31 bytes (< 32) even though it's 31 characters.
|
||||||
|
cfg.JWT.Secret = strings.Repeat("a", 31)
|
||||||
|
err = cfg.Validate()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Validate() should reject 31-byte secret")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "at least 32 bytes") {
|
||||||
|
t.Fatalf("Validate() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 32 bytes OK.
|
||||||
|
cfg.JWT.Secret = strings.Repeat("a", 32)
|
||||||
|
err = cfg.Validate()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Validate() should accept 32-byte secret: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestValidateConfigErrors(t *testing.T) {
|
func TestValidateConfigErrors(t *testing.T) {
|
||||||
buildValid := func(t *testing.T) *Config {
|
buildValid := func(t *testing.T) *Config {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
viper.Reset()
|
resetViperWithJWTSecret(t)
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Load() error: %v", err)
|
t.Fatalf("Load() error: %v", err)
|
||||||
@@ -642,6 +692,26 @@ func TestValidateConfigErrors(t *testing.T) {
|
|||||||
mutate func(*Config)
|
mutate func(*Config)
|
||||||
wantErr string
|
wantErr string
|
||||||
}{
|
}{
|
||||||
|
{
|
||||||
|
name: "jwt secret required",
|
||||||
|
mutate: func(c *Config) { c.JWT.Secret = "" },
|
||||||
|
wantErr: "jwt.secret is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "jwt secret min bytes",
|
||||||
|
mutate: func(c *Config) { c.JWT.Secret = strings.Repeat("a", 31) },
|
||||||
|
wantErr: "jwt.secret must be at least 32 bytes",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subscription maintenance worker_count non-negative",
|
||||||
|
mutate: func(c *Config) { c.SubscriptionMaintenance.WorkerCount = -1 },
|
||||||
|
wantErr: "subscription_maintenance.worker_count",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subscription maintenance queue_size non-negative",
|
||||||
|
mutate: func(c *Config) { c.SubscriptionMaintenance.QueueSize = -1 },
|
||||||
|
wantErr: "subscription_maintenance.queue_size",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "jwt expire hour positive",
|
name: "jwt expire hour positive",
|
||||||
mutate: func(c *Config) { c.JWT.ExpireHour = 0 },
|
mutate: func(c *Config) { c.JWT.ExpireHour = 0 },
|
||||||
|
|||||||
@@ -58,8 +58,13 @@ func adminAuth(
|
|||||||
authHeader := c.GetHeader("Authorization")
|
authHeader := c.GetHeader("Authorization")
|
||||||
if authHeader != "" {
|
if authHeader != "" {
|
||||||
parts := strings.SplitN(authHeader, " ", 2)
|
parts := strings.SplitN(authHeader, " ", 2)
|
||||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
|
||||||
if !validateJWTForAdmin(c, parts[1], authService, userService) {
|
token := strings.TrimSpace(parts[1])
|
||||||
|
if token == "" {
|
||||||
|
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !validateJWTForAdmin(c, token, authService, userService) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
|
|||||||
@@ -35,8 +35,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
|||||||
if authHeader != "" {
|
if authHeader != "" {
|
||||||
// 验证Bearer scheme
|
// 验证Bearer scheme
|
||||||
parts := strings.SplitN(authHeader, " ", 2)
|
parts := strings.SplitN(authHeader, " ", 2)
|
||||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
|
||||||
apiKeyString = parts[1]
|
apiKeyString = strings.TrimSpace(parts[1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,7 +166,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
|||||||
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
|
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
|
||||||
if needsMaintenance {
|
if needsMaintenance {
|
||||||
maintenanceCopy := *subscription
|
maintenanceCopy := *subscription
|
||||||
go subscriptionService.DoWindowMaintenance(&maintenanceCopy)
|
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 余额模式:检查用户余额
|
// 余额模式:检查用户余额
|
||||||
|
|||||||
@@ -57,6 +57,57 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Run("standard_mode_needs_maintenance_does_not_block_request", func(t *testing.T) {
|
||||||
|
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||||
|
cfg.SubscriptionMaintenance.WorkerCount = 1
|
||||||
|
cfg.SubscriptionMaintenance.QueueSize = 1
|
||||||
|
|
||||||
|
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||||
|
|
||||||
|
past := time.Now().Add(-48 * time.Hour)
|
||||||
|
sub := &service.UserSubscription{
|
||||||
|
ID: 55,
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: service.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
DailyWindowStart: &past,
|
||||||
|
DailyUsageUSD: 0,
|
||||||
|
}
|
||||||
|
maintenanceCalled := make(chan struct{}, 1)
|
||||||
|
subscriptionRepo := &stubUserSubscriptionRepo{
|
||||||
|
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||||
|
clone := *sub
|
||||||
|
return &clone, nil
|
||||||
|
},
|
||||||
|
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
|
||||||
|
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||||
|
resetDaily: func(ctx context.Context, id int64, start time.Time) error {
|
||||||
|
maintenanceCalled <- struct{}{}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||||
|
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||||
|
}
|
||||||
|
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg)
|
||||||
|
t.Cleanup(subscriptionService.Stop)
|
||||||
|
|
||||||
|
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||||
|
req.Header.Set("x-api-key", apiKey.Key)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
select {
|
||||||
|
case <-maintenanceCalled:
|
||||||
|
// ok
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("expected maintenance to be scheduled")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
|
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
|
||||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||||
@@ -71,6 +122,20 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
|||||||
require.Equal(t, http.StatusOK, w.Code)
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("simple_mode_accepts_lowercase_bearer", func(t *testing.T) {
|
||||||
|
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||||
|
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||||
|
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg)
|
||||||
|
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||||
|
req.Header.Set("Authorization", "bearer "+apiKey.Key)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
|
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
|
||||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||||
|
|||||||
@@ -26,12 +26,12 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
|
|||||||
|
|
||||||
// 验证Bearer scheme
|
// 验证Bearer scheme
|
||||||
parts := strings.SplitN(authHeader, " ", 2)
|
parts := strings.SplitN(authHeader, " ", 2)
|
||||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
||||||
AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'")
|
AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenString := parts[1]
|
tokenString := strings.TrimSpace(parts[1])
|
||||||
if tokenString == "" {
|
if tokenString == "" {
|
||||||
AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty")
|
AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -84,6 +84,28 @@ func TestJWTAuth_ValidToken(t *testing.T) {
|
|||||||
require.Equal(t, "user", body["role"])
|
require.Equal(t, "user", body["role"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) {
|
||||||
|
user := &service.User{
|
||||||
|
ID: 1,
|
||||||
|
Email: "test@example.com",
|
||||||
|
Role: "user",
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Concurrency: 5,
|
||||||
|
TokenVersion: 1,
|
||||||
|
}
|
||||||
|
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user})
|
||||||
|
|
||||||
|
token, err := authSvc.GenerateToken(user)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||||
|
req.Header.Set("Authorization", "bearer "+token)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) {
|
func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) {
|
||||||
router, _ := newJWTTestEnv(nil)
|
router, _ := newJWTTestEnv(nil)
|
||||||
|
|
||||||
|
|||||||
126
backend/internal/server/middleware/misc_coverage_test.go
Normal file
126
backend/internal/server/middleware/misc_coverage_test.go
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClientRequestID_GeneratesWhenMissing(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(ClientRequestID())
|
||||||
|
r.GET("/t", func(c *gin.Context) {
|
||||||
|
v := c.Request.Context().Value(ctxkey.ClientRequestID)
|
||||||
|
require.NotNil(t, v)
|
||||||
|
id, ok := v.(string)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.NotEmpty(t, id)
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientRequestID_PreservesExisting(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(ClientRequestID())
|
||||||
|
r.GET("/t", func(c *gin.Context) {
|
||||||
|
id, ok := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "keep", id)
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), ctxkey.ClientRequestID, "keep"))
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestBodyLimit_LimitsBody(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestBodyLimit(4))
|
||||||
|
r.POST("/t", func(c *gin.Context) {
|
||||||
|
_, err := io.ReadAll(c.Request.Body)
|
||||||
|
require.Error(t, err)
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/t", bytes.NewBufferString("12345"))
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForcePlatform_SetsContextAndGinValue(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(ForcePlatform("anthropic"))
|
||||||
|
r.GET("/t", func(c *gin.Context) {
|
||||||
|
require.True(t, HasForcePlatform(c))
|
||||||
|
v, ok := GetForcePlatformFromContext(c)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "anthropic", v)
|
||||||
|
|
||||||
|
ctxV := c.Request.Context().Value(ctxkey.ForcePlatform)
|
||||||
|
require.Equal(t, "anthropic", ctxV)
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthSubjectHelpers_RoundTrip(t *testing.T) {
|
||||||
|
c := &gin.Context{}
|
||||||
|
c.Set(string(ContextKeyUser), AuthSubject{UserID: 1, Concurrency: 2})
|
||||||
|
c.Set(string(ContextKeyUserRole), "admin")
|
||||||
|
|
||||||
|
sub, ok := GetAuthSubjectFromContext(c)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, int64(1), sub.UserID)
|
||||||
|
require.Equal(t, 2, sub.Concurrency)
|
||||||
|
|
||||||
|
role, ok := GetUserRoleFromContext(c)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "admin", role)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyAndSubscriptionFromContext(t *testing.T) {
|
||||||
|
c := &gin.Context{}
|
||||||
|
|
||||||
|
key := &service.APIKey{ID: 1}
|
||||||
|
c.Set(string(ContextKeyAPIKey), key)
|
||||||
|
gotKey, ok := GetAPIKeyFromContext(c)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, int64(1), gotKey.ID)
|
||||||
|
|
||||||
|
sub := &service.UserSubscription{ID: 2}
|
||||||
|
c.Set(string(ContextKeySubscription), sub)
|
||||||
|
gotSub, ok := GetSubscriptionFromContext(c)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, int64(2), gotSub.ID)
|
||||||
|
}
|
||||||
75
backend/internal/service/subscription_maintenance_queue.go
Normal file
75
backend/internal/service/subscription_maintenance_queue.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SubscriptionMaintenanceQueue 提供“有界队列 + 固定 worker”的后台执行器。
|
||||||
|
// 用于从请求热路径触发维护动作时,避免无限 goroutine 膨胀。
|
||||||
|
type SubscriptionMaintenanceQueue struct {
|
||||||
|
queue chan func()
|
||||||
|
wg sync.WaitGroup
|
||||||
|
stop sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSubscriptionMaintenanceQueue(workerCount, queueSize int) *SubscriptionMaintenanceQueue {
|
||||||
|
if workerCount <= 0 {
|
||||||
|
workerCount = 1
|
||||||
|
}
|
||||||
|
if queueSize <= 0 {
|
||||||
|
queueSize = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
q := &SubscriptionMaintenanceQueue{
|
||||||
|
queue: make(chan func(), queueSize),
|
||||||
|
}
|
||||||
|
|
||||||
|
q.wg.Add(workerCount)
|
||||||
|
for i := 0; i < workerCount; i++ {
|
||||||
|
go func(workerID int) {
|
||||||
|
defer q.wg.Done()
|
||||||
|
for fn := range q.queue {
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Printf("SubscriptionMaintenance worker panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
fn()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// TryEnqueue 尝试将任务入队。
|
||||||
|
// 当队列已满时返回 error(调用方应该选择跳过并记录告警/限频日志)。
|
||||||
|
func (q *SubscriptionMaintenanceQueue) TryEnqueue(task func()) error {
|
||||||
|
if q == nil {
|
||||||
|
return fmt.Errorf("maintenance queue is nil")
|
||||||
|
}
|
||||||
|
if task == nil {
|
||||||
|
return fmt.Errorf("maintenance task is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case q.queue <- task:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("maintenance queue full")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *SubscriptionMaintenanceQueue) Stop() {
|
||||||
|
if q == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
q.stop.Do(func() {
|
||||||
|
close(q.queue)
|
||||||
|
q.wg.Wait()
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSubscriptionMaintenanceQueue_TryEnqueue_QueueFull(t *testing.T) {
|
||||||
|
q := NewSubscriptionMaintenanceQueue(1, 1)
|
||||||
|
t.Cleanup(q.Stop)
|
||||||
|
|
||||||
|
block := make(chan struct{})
|
||||||
|
var started atomic.Int32
|
||||||
|
|
||||||
|
require.NoError(t, q.TryEnqueue(func() {
|
||||||
|
started.Store(1)
|
||||||
|
<-block
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Wait until worker started consuming the first task.
|
||||||
|
require.Eventually(t, func() bool { return started.Load() == 1 }, time.Second, 10*time.Millisecond)
|
||||||
|
|
||||||
|
// Queue size is 1; with the worker blocked, enqueueing one more should fill it.
|
||||||
|
require.NoError(t, q.TryEnqueue(func() {}))
|
||||||
|
|
||||||
|
// Now the queue is full; next enqueue must fail.
|
||||||
|
err := q.TryEnqueue(func() {})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "full")
|
||||||
|
|
||||||
|
close(block)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriptionMaintenanceQueue_TryEnqueue_PanicDoesNotKillWorker(t *testing.T) {
|
||||||
|
q := NewSubscriptionMaintenanceQueue(1, 8)
|
||||||
|
t.Cleanup(q.Stop)
|
||||||
|
|
||||||
|
require.NoError(t, q.TryEnqueue(func() { panic("boom") }))
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
require.NoError(t, q.TryEnqueue(func() { close(done) }))
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// ok
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("worker did not continue after panic")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -48,6 +48,8 @@ type SubscriptionService struct {
|
|||||||
subCacheGroup singleflight.Group
|
subCacheGroup singleflight.Group
|
||||||
subCacheTTL time.Duration
|
subCacheTTL time.Duration
|
||||||
subCacheJitter int // 抖动百分比
|
subCacheJitter int // 抖动百分比
|
||||||
|
|
||||||
|
maintenanceQueue *SubscriptionMaintenanceQueue
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSubscriptionService 创建订阅服务
|
// NewSubscriptionService 创建订阅服务
|
||||||
@@ -59,9 +61,31 @@ func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscript
|
|||||||
entClient: entClient,
|
entClient: entClient,
|
||||||
}
|
}
|
||||||
svc.initSubCache(cfg)
|
svc.initSubCache(cfg)
|
||||||
|
svc.initMaintenanceQueue(cfg)
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SubscriptionService) initMaintenanceQueue(cfg *config.Config) {
|
||||||
|
if cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mc := cfg.SubscriptionMaintenance
|
||||||
|
if mc.WorkerCount <= 0 || mc.QueueSize <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.maintenanceQueue = NewSubscriptionMaintenanceQueue(mc.WorkerCount, mc.QueueSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the maintenance worker pool.
|
||||||
|
func (s *SubscriptionService) Stop() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.maintenanceQueue != nil {
|
||||||
|
s.maintenanceQueue.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// initSubCache 初始化订阅 L1 缓存
|
// initSubCache 初始化订阅 L1 缓存
|
||||||
func (s *SubscriptionService) initSubCache(cfg *config.Config) {
|
func (s *SubscriptionService) initSubCache(cfg *config.Config) {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
@@ -720,6 +744,23 @@ func (s *SubscriptionService) ValidateAndCheckLimits(sub *UserSubscription, grou
|
|||||||
// 而 IsExpired()=true 的订阅在 ValidateAndCheckLimits 中已被拦截返回错误,
|
// 而 IsExpired()=true 的订阅在 ValidateAndCheckLimits 中已被拦截返回错误,
|
||||||
// 因此进入此方法的订阅一定未过期,无需处理过期状态同步。
|
// 因此进入此方法的订阅一定未过期,无需处理过期状态同步。
|
||||||
func (s *SubscriptionService) DoWindowMaintenance(sub *UserSubscription) {
|
func (s *SubscriptionService) DoWindowMaintenance(sub *UserSubscription) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.maintenanceQueue != nil {
|
||||||
|
err := s.maintenanceQueue.TryEnqueue(func() {
|
||||||
|
s.doWindowMaintenance(sub)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Subscription maintenance enqueue failed: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.doWindowMaintenance(sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SubscriptionService) doWindowMaintenance(sub *UserSubscription) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user