diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 0a5f9744..5ef04a66 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -70,6 +70,7 @@ func provideCleanup( schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, accountExpiry *service.AccountExpiryService, + usageCleanup *service.UsageCleanupService, pricing *service.PricingService, emailQueue *service.EmailQueueService, billingCache *service.BillingCacheService, @@ -123,6 +124,12 @@ func provideCleanup( } return nil }}, + {"UsageCleanupService", func() error { + if usageCleanup != nil { + usageCleanup.Stop() + } + return nil + }}, {"TokenRefreshService", func() error { tokenRefresh.Stop() return nil diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 27404b02..509cf13a 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -153,7 +153,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo) systemHandler := handler.ProvideSystemHandler(updateService) adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) - adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService) + usageCleanupRepository := repository.NewUsageCleanupRepository(db) + usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig) + adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService, usageCleanupService) userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client) userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) @@ -175,7 +177,7 @@ 
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) application := &Application{ Server: httpServer, Cleanup: v, @@ -208,6 +210,7 @@ func provideCleanup( schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, accountExpiry *service.AccountExpiryService, + usageCleanup *service.UsageCleanupService, pricing *service.PricingService, emailQueue *service.EmailQueueService, billingCache *service.BillingCacheService, @@ -260,6 +263,12 @@ func provideCleanup( } return nil }}, + {"UsageCleanupService", func() error { + if usageCleanup != nil { + usageCleanup.Stop() + } + return nil + }}, {"TokenRefreshService", func() error { tokenRefresh.Stop() return nil diff --git a/backend/go.mod b/backend/go.mod index 4ac6ba14..9ebae69e 100644 --- a/backend/go.mod +++ 
b/backend/go.mod @@ -31,6 +31,7 @@ require ( ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect dario.cat/mergo v1.0.2 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect + github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/agext/levenshtein v1.2.3 // indirect github.com/andybalholm/brotli v1.2.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 415e73a7..4496603d 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -141,6 +141,7 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 5dc6ad19..d616e44b 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -55,6 +55,7 @@ type Config struct { APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` + UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` Concurrency ConcurrencyConfig `mapstructure:"concurrency"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` RunMode string `mapstructure:"run_mode" yaml:"run_mode"` @@ -489,6 
+490,20 @@ type DashboardAggregationRetentionConfig struct { DailyDays int `mapstructure:"daily_days"` } +// UsageCleanupConfig 使用记录清理任务配置 +type UsageCleanupConfig struct { + // Enabled: 是否启用清理任务执行器 + Enabled bool `mapstructure:"enabled"` + // MaxRangeDays: 单次任务允许的最大时间跨度(天) + MaxRangeDays int `mapstructure:"max_range_days"` + // BatchSize: 单批删除数量 + BatchSize int `mapstructure:"batch_size"` + // WorkerIntervalSeconds: 后台任务轮询间隔(秒) + WorkerIntervalSeconds int `mapstructure:"worker_interval_seconds"` + // TaskTimeoutSeconds: 单次任务最大执行时长(秒) + TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"` +} + func NormalizeRunMode(value string) string { normalized := strings.ToLower(strings.TrimSpace(value)) switch normalized { @@ -749,6 +764,13 @@ func setDefaults() { viper.SetDefault("dashboard_aggregation.retention.daily_days", 730) viper.SetDefault("dashboard_aggregation.recompute_days", 2) + // Usage cleanup task + viper.SetDefault("usage_cleanup.enabled", true) + viper.SetDefault("usage_cleanup.max_range_days", 31) + viper.SetDefault("usage_cleanup.batch_size", 5000) + viper.SetDefault("usage_cleanup.worker_interval_seconds", 10) + viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800) + // Gateway viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.log_upstream_error_body", true) @@ -985,6 +1007,33 @@ func (c *Config) Validate() error { return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative") } } + if c.UsageCleanup.Enabled { + if c.UsageCleanup.MaxRangeDays <= 0 { + return fmt.Errorf("usage_cleanup.max_range_days must be positive") + } + if c.UsageCleanup.BatchSize <= 0 { + return fmt.Errorf("usage_cleanup.batch_size must be positive") + } + if c.UsageCleanup.WorkerIntervalSeconds <= 0 { + return fmt.Errorf("usage_cleanup.worker_interval_seconds must be positive") + } + if c.UsageCleanup.TaskTimeoutSeconds <= 0 { + return fmt.Errorf("usage_cleanup.task_timeout_seconds must 
be positive") + } + } else { + if c.UsageCleanup.MaxRangeDays < 0 { + return fmt.Errorf("usage_cleanup.max_range_days must be non-negative") + } + if c.UsageCleanup.BatchSize < 0 { + return fmt.Errorf("usage_cleanup.batch_size must be non-negative") + } + if c.UsageCleanup.WorkerIntervalSeconds < 0 { + return fmt.Errorf("usage_cleanup.worker_interval_seconds must be non-negative") + } + if c.UsageCleanup.TaskTimeoutSeconds < 0 { + return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative") + } + } if c.Gateway.MaxBodySize <= 0 { return fmt.Errorf("gateway.max_body_size must be positive") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 4637989e..f734619f 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -280,3 +280,573 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) { t.Fatalf("Validate() expected backfill_max_days error, got: %v", err) } } + +func TestLoadDefaultUsageCleanupConfig(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.UsageCleanup.Enabled { + t.Fatalf("UsageCleanup.Enabled = false, want true") + } + if cfg.UsageCleanup.MaxRangeDays != 31 { + t.Fatalf("UsageCleanup.MaxRangeDays = %d, want 31", cfg.UsageCleanup.MaxRangeDays) + } + if cfg.UsageCleanup.BatchSize != 5000 { + t.Fatalf("UsageCleanup.BatchSize = %d, want 5000", cfg.UsageCleanup.BatchSize) + } + if cfg.UsageCleanup.WorkerIntervalSeconds != 10 { + t.Fatalf("UsageCleanup.WorkerIntervalSeconds = %d, want 10", cfg.UsageCleanup.WorkerIntervalSeconds) + } + if cfg.UsageCleanup.TaskTimeoutSeconds != 1800 { + t.Fatalf("UsageCleanup.TaskTimeoutSeconds = %d, want 1800", cfg.UsageCleanup.TaskTimeoutSeconds) + } +} + +func TestValidateUsageCleanupConfigEnabled(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + 
cfg.UsageCleanup.Enabled = true + cfg.UsageCleanup.MaxRangeDays = 0 + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for usage_cleanup.max_range_days, got nil") + } + if !strings.Contains(err.Error(), "usage_cleanup.max_range_days") { + t.Fatalf("Validate() expected max_range_days error, got: %v", err) + } +} + +func TestValidateUsageCleanupConfigDisabled(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.UsageCleanup.Enabled = false + cfg.UsageCleanup.BatchSize = -1 + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for usage_cleanup.batch_size, got nil") + } + if !strings.Contains(err.Error(), "usage_cleanup.batch_size") { + t.Fatalf("Validate() expected batch_size error, got: %v", err) + } +} + +func TestConfigAddressHelpers(t *testing.T) { + server := ServerConfig{Host: "127.0.0.1", Port: 9000} + if server.Address() != "127.0.0.1:9000" { + t.Fatalf("ServerConfig.Address() = %q", server.Address()) + } + + dbCfg := DatabaseConfig{ + Host: "localhost", + Port: 5432, + User: "postgres", + Password: "", + DBName: "sub2api", + SSLMode: "disable", + } + if !strings.Contains(dbCfg.DSN(), "password=") { + } else { + t.Fatalf("DatabaseConfig.DSN() should not include password when empty") + } + + dbCfg.Password = "secret" + if !strings.Contains(dbCfg.DSN(), "password=secret") { + t.Fatalf("DatabaseConfig.DSN() missing password") + } + + dbCfg.Password = "" + if strings.Contains(dbCfg.DSNWithTimezone("UTC"), "password=") { + t.Fatalf("DatabaseConfig.DSNWithTimezone() should omit password when empty") + } + + if !strings.Contains(dbCfg.DSNWithTimezone(""), "TimeZone=Asia/Shanghai") { + t.Fatalf("DatabaseConfig.DSNWithTimezone() should use default timezone") + } + if !strings.Contains(dbCfg.DSNWithTimezone("UTC"), "TimeZone=UTC") { + t.Fatalf("DatabaseConfig.DSNWithTimezone() should use provided timezone") + } + + redis := RedisConfig{Host: "redis", 
Port: 6379} + if redis.Address() != "redis:6379" { + t.Fatalf("RedisConfig.Address() = %q", redis.Address()) + } +} + +func TestNormalizeStringSlice(t *testing.T) { + values := normalizeStringSlice([]string{" a ", "", "b", " ", "c"}) + if len(values) != 3 || values[0] != "a" || values[1] != "b" || values[2] != "c" { + t.Fatalf("normalizeStringSlice() unexpected result: %#v", values) + } + if normalizeStringSlice(nil) != nil { + t.Fatalf("normalizeStringSlice(nil) expected nil slice") + } +} + +func TestGetServerAddressFromEnv(t *testing.T) { + t.Setenv("SERVER_HOST", "127.0.0.1") + t.Setenv("SERVER_PORT", "9090") + + address := GetServerAddress() + if address != "127.0.0.1:9090" { + t.Fatalf("GetServerAddress() = %q", address) + } +} + +func TestValidateAbsoluteHTTPURL(t *testing.T) { + if err := ValidateAbsoluteHTTPURL("https://example.com/path"); err != nil { + t.Fatalf("ValidateAbsoluteHTTPURL valid url error: %v", err) + } + if err := ValidateAbsoluteHTTPURL(""); err == nil { + t.Fatalf("ValidateAbsoluteHTTPURL should reject empty url") + } + if err := ValidateAbsoluteHTTPURL("/relative"); err == nil { + t.Fatalf("ValidateAbsoluteHTTPURL should reject relative url") + } + if err := ValidateAbsoluteHTTPURL("ftp://example.com"); err == nil { + t.Fatalf("ValidateAbsoluteHTTPURL should reject ftp scheme") + } + if err := ValidateAbsoluteHTTPURL("https://example.com/#frag"); err == nil { + t.Fatalf("ValidateAbsoluteHTTPURL should reject fragment") + } +} + +func TestValidateFrontendRedirectURL(t *testing.T) { + if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil { + t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err) + } + if err := ValidateFrontendRedirectURL("https://example.com/auth"); err != nil { + t.Fatalf("ValidateFrontendRedirectURL absolute error: %v", err) + } + if err := ValidateFrontendRedirectURL("example.com/path"); err == nil { + t.Fatalf("ValidateFrontendRedirectURL should reject non-absolute url") + } + if err := 
ValidateFrontendRedirectURL("//evil.com"); err == nil { + t.Fatalf("ValidateFrontendRedirectURL should reject // prefix") + } + if err := ValidateFrontendRedirectURL("javascript:alert(1)"); err == nil { + t.Fatalf("ValidateFrontendRedirectURL should reject javascript scheme") + } +} + +func TestWarnIfInsecureURL(t *testing.T) { + warnIfInsecureURL("test", "http://example.com") + warnIfInsecureURL("test", "bad://url") +} + +func TestGenerateJWTSecretDefaultLength(t *testing.T) { + secret, err := generateJWTSecret(0) + if err != nil { + t.Fatalf("generateJWTSecret error: %v", err) + } + if len(secret) == 0 { + t.Fatalf("generateJWTSecret returned empty string") + } +} + +func TestValidateOpsCleanupScheduleRequired(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + cfg.Ops.Cleanup.Enabled = true + cfg.Ops.Cleanup.Schedule = "" + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for ops.cleanup.schedule") + } + if !strings.Contains(err.Error(), "ops.cleanup.schedule") { + t.Fatalf("Validate() expected ops.cleanup.schedule error, got: %v", err) + } +} + +func TestValidateConcurrencyPingInterval(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + cfg.Concurrency.PingInterval = 3 + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for concurrency.ping_interval") + } + if !strings.Contains(err.Error(), "concurrency.ping_interval") { + t.Fatalf("Validate() expected concurrency.ping_interval error, got: %v", err) + } +} + +func TestProvideConfig(t *testing.T) { + viper.Reset() + if _, err := ProvideConfig(); err != nil { + t.Fatalf("ProvideConfig() error: %v", err) + } +} + +func TestValidateConfigWithLinuxDoEnabled(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Security.CSP.Enabled = true + cfg.Security.CSP.Policy = 
"default-src 'self'" + + cfg.LinuxDo.Enabled = true + cfg.LinuxDo.ClientID = "client" + cfg.LinuxDo.ClientSecret = "secret" + cfg.LinuxDo.AuthorizeURL = "https://example.com/oauth2/authorize" + cfg.LinuxDo.TokenURL = "https://example.com/oauth2/token" + cfg.LinuxDo.UserInfoURL = "https://example.com/oauth2/userinfo" + cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback" + cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback" + cfg.LinuxDo.TokenAuthMethod = "client_secret_post" + + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() unexpected error: %v", err) + } +} + +func TestValidateJWTSecretStrength(t *testing.T) { + if !isWeakJWTSecret("change-me-in-production") { + t.Fatalf("isWeakJWTSecret should detect weak secret") + } + if isWeakJWTSecret("StrongSecretValue") { + t.Fatalf("isWeakJWTSecret should accept strong secret") + } +} + +func TestGenerateJWTSecretWithLength(t *testing.T) { + secret, err := generateJWTSecret(16) + if err != nil { + t.Fatalf("generateJWTSecret error: %v", err) + } + if len(secret) == 0 { + t.Fatalf("generateJWTSecret returned empty string") + } +} + +func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) { + if err := ValidateAbsoluteHTTPURL("https://"); err == nil { + t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host") + } +} + +func TestValidateFrontendRedirectURLInvalidChars(t *testing.T) { + if err := ValidateFrontendRedirectURL("/auth/\ncallback"); err == nil { + t.Fatalf("ValidateFrontendRedirectURL should reject invalid chars") + } + if err := ValidateFrontendRedirectURL("http://"); err == nil { + t.Fatalf("ValidateFrontendRedirectURL should reject missing host") + } + if err := ValidateFrontendRedirectURL("mailto:user@example.com"); err == nil { + t.Fatalf("ValidateFrontendRedirectURL should reject mailto") + } +} + +func TestWarnIfInsecureURLHTTPS(t *testing.T) { + warnIfInsecureURL("secure", "https://example.com") +} + +func TestValidateConfigErrors(t *testing.T) { 
+ buildValid := func(t *testing.T) *Config { + t.Helper() + viper.Reset() + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + return cfg + } + + cases := []struct { + name string + mutate func(*Config) + wantErr string + }{ + { + name: "jwt expire hour positive", + mutate: func(c *Config) { c.JWT.ExpireHour = 0 }, + wantErr: "jwt.expire_hour must be positive", + }, + { + name: "jwt expire hour max", + mutate: func(c *Config) { c.JWT.ExpireHour = 200 }, + wantErr: "jwt.expire_hour must be <= 168", + }, + { + name: "csp policy required", + mutate: func(c *Config) { c.Security.CSP.Enabled = true; c.Security.CSP.Policy = "" }, + wantErr: "security.csp.policy", + }, + { + name: "linuxdo client id required", + mutate: func(c *Config) { + c.LinuxDo.Enabled = true + c.LinuxDo.ClientID = "" + }, + wantErr: "linuxdo_connect.client_id", + }, + { + name: "linuxdo token auth method", + mutate: func(c *Config) { + c.LinuxDo.Enabled = true + c.LinuxDo.ClientID = "client" + c.LinuxDo.ClientSecret = "secret" + c.LinuxDo.AuthorizeURL = "https://example.com/authorize" + c.LinuxDo.TokenURL = "https://example.com/token" + c.LinuxDo.UserInfoURL = "https://example.com/userinfo" + c.LinuxDo.RedirectURL = "https://example.com/callback" + c.LinuxDo.FrontendRedirectURL = "/auth/callback" + c.LinuxDo.TokenAuthMethod = "invalid" + }, + wantErr: "linuxdo_connect.token_auth_method", + }, + { + name: "billing circuit breaker threshold", + mutate: func(c *Config) { c.Billing.CircuitBreaker.FailureThreshold = 0 }, + wantErr: "billing.circuit_breaker.failure_threshold", + }, + { + name: "billing circuit breaker reset", + mutate: func(c *Config) { c.Billing.CircuitBreaker.ResetTimeoutSeconds = 0 }, + wantErr: "billing.circuit_breaker.reset_timeout_seconds", + }, + { + name: "billing circuit breaker half open", + mutate: func(c *Config) { c.Billing.CircuitBreaker.HalfOpenRequests = 0 }, + wantErr: "billing.circuit_breaker.half_open_requests", + }, + { + name: "database max 
open conns", + mutate: func(c *Config) { c.Database.MaxOpenConns = 0 }, + wantErr: "database.max_open_conns", + }, + { + name: "database max lifetime", + mutate: func(c *Config) { c.Database.ConnMaxLifetimeMinutes = -1 }, + wantErr: "database.conn_max_lifetime_minutes", + }, + { + name: "database idle exceeds open", + mutate: func(c *Config) { c.Database.MaxIdleConns = c.Database.MaxOpenConns + 1 }, + wantErr: "database.max_idle_conns cannot exceed", + }, + { + name: "redis dial timeout", + mutate: func(c *Config) { c.Redis.DialTimeoutSeconds = 0 }, + wantErr: "redis.dial_timeout_seconds", + }, + { + name: "redis read timeout", + mutate: func(c *Config) { c.Redis.ReadTimeoutSeconds = 0 }, + wantErr: "redis.read_timeout_seconds", + }, + { + name: "redis write timeout", + mutate: func(c *Config) { c.Redis.WriteTimeoutSeconds = 0 }, + wantErr: "redis.write_timeout_seconds", + }, + { + name: "redis pool size", + mutate: func(c *Config) { c.Redis.PoolSize = 0 }, + wantErr: "redis.pool_size", + }, + { + name: "redis idle exceeds pool", + mutate: func(c *Config) { c.Redis.MinIdleConns = c.Redis.PoolSize + 1 }, + wantErr: "redis.min_idle_conns cannot exceed", + }, + { + name: "dashboard cache disabled negative", + mutate: func(c *Config) { c.Dashboard.Enabled = false; c.Dashboard.StatsTTLSeconds = -1 }, + wantErr: "dashboard_cache.stats_ttl_seconds", + }, + { + name: "dashboard cache fresh ttl positive", + mutate: func(c *Config) { c.Dashboard.Enabled = true; c.Dashboard.StatsFreshTTLSeconds = 0 }, + wantErr: "dashboard_cache.stats_fresh_ttl_seconds", + }, + { + name: "dashboard aggregation enabled interval", + mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.IntervalSeconds = 0 }, + wantErr: "dashboard_aggregation.interval_seconds", + }, + { + name: "dashboard aggregation backfill positive", + mutate: func(c *Config) { + c.DashboardAgg.Enabled = true + c.DashboardAgg.BackfillEnabled = true + c.DashboardAgg.BackfillMaxDays = 0 + }, + wantErr: 
"dashboard_aggregation.backfill_max_days", + }, + { + name: "dashboard aggregation retention", + mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 }, + wantErr: "dashboard_aggregation.retention.usage_logs_days", + }, + { + name: "dashboard aggregation disabled interval", + mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 }, + wantErr: "dashboard_aggregation.interval_seconds", + }, + { + name: "usage cleanup max range", + mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.MaxRangeDays = 0 }, + wantErr: "usage_cleanup.max_range_days", + }, + { + name: "usage cleanup worker interval", + mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.WorkerIntervalSeconds = 0 }, + wantErr: "usage_cleanup.worker_interval_seconds", + }, + { + name: "usage cleanup batch size", + mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.BatchSize = 0 }, + wantErr: "usage_cleanup.batch_size", + }, + { + name: "usage cleanup disabled negative", + mutate: func(c *Config) { c.UsageCleanup.Enabled = false; c.UsageCleanup.BatchSize = -1 }, + wantErr: "usage_cleanup.batch_size", + }, + { + name: "gateway max body size", + mutate: func(c *Config) { c.Gateway.MaxBodySize = 0 }, + wantErr: "gateway.max_body_size", + }, + { + name: "gateway max idle conns", + mutate: func(c *Config) { c.Gateway.MaxIdleConns = 0 }, + wantErr: "gateway.max_idle_conns", + }, + { + name: "gateway max idle conns per host", + mutate: func(c *Config) { c.Gateway.MaxIdleConnsPerHost = 0 }, + wantErr: "gateway.max_idle_conns_per_host", + }, + { + name: "gateway idle timeout", + mutate: func(c *Config) { c.Gateway.IdleConnTimeoutSeconds = 0 }, + wantErr: "gateway.idle_conn_timeout_seconds", + }, + { + name: "gateway max upstream clients", + mutate: func(c *Config) { c.Gateway.MaxUpstreamClients = 0 }, + wantErr: "gateway.max_upstream_clients", + }, + { + name: "gateway 
client idle ttl", + mutate: func(c *Config) { c.Gateway.ClientIdleTTLSeconds = 0 }, + wantErr: "gateway.client_idle_ttl_seconds", + }, + { + name: "gateway concurrency slot ttl", + mutate: func(c *Config) { c.Gateway.ConcurrencySlotTTLMinutes = 0 }, + wantErr: "gateway.concurrency_slot_ttl_minutes", + }, + { + name: "gateway max conns per host", + mutate: func(c *Config) { c.Gateway.MaxConnsPerHost = -1 }, + wantErr: "gateway.max_conns_per_host", + }, + { + name: "gateway connection isolation", + mutate: func(c *Config) { c.Gateway.ConnectionPoolIsolation = "invalid" }, + wantErr: "gateway.connection_pool_isolation", + }, + { + name: "gateway stream keepalive range", + mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 }, + wantErr: "gateway.stream_keepalive_interval", + }, + { + name: "gateway stream data interval range", + mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 }, + wantErr: "gateway.stream_data_interval_timeout", + }, + { + name: "gateway stream data interval negative", + mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = -1 }, + wantErr: "gateway.stream_data_interval_timeout must be non-negative", + }, + { + name: "gateway max line size", + mutate: func(c *Config) { c.Gateway.MaxLineSize = 1024 }, + wantErr: "gateway.max_line_size must be at least", + }, + { + name: "gateway max line size negative", + mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 }, + wantErr: "gateway.max_line_size must be non-negative", + }, + { + name: "gateway scheduling sticky waiting", + mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 }, + wantErr: "gateway.scheduling.sticky_session_max_waiting", + }, + { + name: "gateway scheduling outbox poll", + mutate: func(c *Config) { c.Gateway.Scheduling.OutboxPollIntervalSeconds = 0 }, + wantErr: "gateway.scheduling.outbox_poll_interval_seconds", + }, + { + name: "gateway scheduling outbox failures", + mutate: func(c *Config) { 
c.Gateway.Scheduling.OutboxLagRebuildFailures = 0 }, + wantErr: "gateway.scheduling.outbox_lag_rebuild_failures", + }, + { + name: "gateway outbox lag rebuild", + mutate: func(c *Config) { + c.Gateway.Scheduling.OutboxLagWarnSeconds = 10 + c.Gateway.Scheduling.OutboxLagRebuildSeconds = 5 + }, + wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds", + }, + { + name: "ops metrics collector ttl", + mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 }, + wantErr: "ops.metrics_collector_cache.ttl", + }, + { + name: "ops cleanup retention", + mutate: func(c *Config) { c.Ops.Cleanup.ErrorLogRetentionDays = -1 }, + wantErr: "ops.cleanup.error_log_retention_days", + }, + { + name: "ops cleanup minute retention", + mutate: func(c *Config) { c.Ops.Cleanup.MinuteMetricsRetentionDays = -1 }, + wantErr: "ops.cleanup.minute_metrics_retention_days", + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + cfg := buildValid(t) + tt.mutate(cfg) + err := cfg.Validate() + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("Validate() error = %v, want %q", err, tt.wantErr) + } + }) + } +} diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go new file mode 100644 index 00000000..e0f731e1 --- /dev/null +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -0,0 +1,262 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func setupAdminRouter() (*gin.Engine, *stubAdminService) { + gin.SetMode(gin.TestMode) + router := gin.New() + adminSvc := newStubAdminService() + + userHandler := NewUserHandler(adminSvc) + groupHandler := NewGroupHandler(adminSvc) + proxyHandler := NewProxyHandler(adminSvc) + redeemHandler := NewRedeemHandler(adminSvc) + + router.GET("/api/v1/admin/users", userHandler.List) + 
router.GET("/api/v1/admin/users/:id", userHandler.GetByID) + router.POST("/api/v1/admin/users", userHandler.Create) + router.PUT("/api/v1/admin/users/:id", userHandler.Update) + router.DELETE("/api/v1/admin/users/:id", userHandler.Delete) + router.POST("/api/v1/admin/users/:id/balance", userHandler.UpdateBalance) + router.GET("/api/v1/admin/users/:id/api-keys", userHandler.GetUserAPIKeys) + router.GET("/api/v1/admin/users/:id/usage", userHandler.GetUserUsage) + + router.GET("/api/v1/admin/groups", groupHandler.List) + router.GET("/api/v1/admin/groups/all", groupHandler.GetAll) + router.GET("/api/v1/admin/groups/:id", groupHandler.GetByID) + router.POST("/api/v1/admin/groups", groupHandler.Create) + router.PUT("/api/v1/admin/groups/:id", groupHandler.Update) + router.DELETE("/api/v1/admin/groups/:id", groupHandler.Delete) + router.GET("/api/v1/admin/groups/:id/stats", groupHandler.GetStats) + router.GET("/api/v1/admin/groups/:id/api-keys", groupHandler.GetGroupAPIKeys) + + router.GET("/api/v1/admin/proxies", proxyHandler.List) + router.GET("/api/v1/admin/proxies/all", proxyHandler.GetAll) + router.GET("/api/v1/admin/proxies/:id", proxyHandler.GetByID) + router.POST("/api/v1/admin/proxies", proxyHandler.Create) + router.PUT("/api/v1/admin/proxies/:id", proxyHandler.Update) + router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete) + router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete) + router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test) + router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats) + router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts) + + router.GET("/api/v1/admin/redeem-codes", redeemHandler.List) + router.GET("/api/v1/admin/redeem-codes/:id", redeemHandler.GetByID) + router.POST("/api/v1/admin/redeem-codes", redeemHandler.Generate) + router.DELETE("/api/v1/admin/redeem-codes/:id", redeemHandler.Delete) + router.POST("/api/v1/admin/redeem-codes/batch-delete", 
redeemHandler.BatchDelete) + router.POST("/api/v1/admin/redeem-codes/:id/expire", redeemHandler.Expire) + router.GET("/api/v1/admin/redeem-codes/:id/stats", redeemHandler.GetStats) + + return router, adminSvc +} + +func TestUserHandlerEndpoints(t *testing.T) { + router, _ := setupAdminRouter() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/users?page=1&page_size=20", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2} + body, _ := json.Marshal(createBody) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + updateBody := map[string]any{"email": "updated@example.com"} + body, _ = json.Marshal(updateBody) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/users/1", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/users/1", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/balance", bytes.NewBufferString(`{"balance":1,"operation":"add"}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, 
"/api/v1/admin/users/1/api-keys", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/usage?period=today", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) +} + +func TestGroupHandlerEndpoints(t *testing.T) { + router, _ := setupAdminRouter() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/all", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + body, _ := json.Marshal(map[string]any{"name": "new", "platform": "anthropic", "subscription_type": "standard"}) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/groups", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + body, _ = json.Marshal(map[string]any{"name": "update"}) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/groups/2", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/groups/2", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/stats", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec 
= httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/api-keys", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) +} + +func TestProxyHandlerEndpoints(t *testing.T) { + router, _ := setupAdminRouter() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/all", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + body, _ := json.Marshal(map[string]any{"name": "proxy", "protocol": "http", "host": "localhost", "port": 8080}) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + body, _ = json.Marshal(map[string]any{"name": "proxy2"}) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/proxies/4", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/proxies/4", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + 
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/test", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/accounts", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) +} + +func TestRedeemHandlerEndpoints(t *testing.T) { + router, _ := setupAdminRouter() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + body, _ := json.Marshal(map[string]any{"count": 1, "type": "balance", "value": 10}) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/redeem-codes/5", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/5/expire", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, 
rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5/stats", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) +} diff --git a/backend/internal/handler/admin/admin_helpers_test.go b/backend/internal/handler/admin/admin_helpers_test.go new file mode 100644 index 00000000..863c755c --- /dev/null +++ b/backend/internal/handler/admin/admin_helpers_test.go @@ -0,0 +1,134 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/netip" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestParseTimeRange(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + req := httptest.NewRequest(http.MethodGet, "/?start_date=2024-01-01&end_date=2024-01-02&timezone=UTC", nil) + c.Request = req + + start, end := parseTimeRange(c) + require.Equal(t, time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), start) + require.Equal(t, time.Date(2024, 1, 3, 0, 0, 0, 0, time.UTC), end) + + req = httptest.NewRequest(http.MethodGet, "/?start_date=bad&timezone=UTC", nil) + c.Request = req + start, end = parseTimeRange(c) + require.False(t, start.IsZero()) + require.False(t, end.IsZero()) +} + +func TestParseOpsViewParam(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/?view=excluded", nil) + require.Equal(t, opsListViewExcluded, parseOpsViewParam(c)) + + c2, _ := gin.CreateTestContext(w) + c2.Request = httptest.NewRequest(http.MethodGet, "/?view=all", nil) + require.Equal(t, opsListViewAll, parseOpsViewParam(c2)) + + c3, _ := gin.CreateTestContext(w) + c3.Request = httptest.NewRequest(http.MethodGet, "/?view=unknown", nil) + require.Equal(t, opsListViewErrors, parseOpsViewParam(c3)) + + require.Equal(t, "", 
parseOpsViewParam(nil)) +} + +func TestParseOpsDuration(t *testing.T) { + dur, ok := parseOpsDuration("1h") + require.True(t, ok) + require.Equal(t, time.Hour, dur) + + _, ok = parseOpsDuration("invalid") + require.False(t, ok) +} + +func TestParseOpsTimeRange(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + now := time.Now().UTC() + startStr := now.Add(-time.Hour).Format(time.RFC3339) + endStr := now.Format(time.RFC3339) + c.Request = httptest.NewRequest(http.MethodGet, "/?start_time="+startStr+"&end_time="+endStr, nil) + start, end, err := parseOpsTimeRange(c, "1h") + require.NoError(t, err) + require.True(t, start.Before(end)) + + c2, _ := gin.CreateTestContext(w) + c2.Request = httptest.NewRequest(http.MethodGet, "/?start_time=bad", nil) + _, _, err = parseOpsTimeRange(c2, "1h") + require.Error(t, err) +} + +func TestParseOpsRealtimeWindow(t *testing.T) { + dur, label, ok := parseOpsRealtimeWindow("5m") + require.True(t, ok) + require.Equal(t, 5*time.Minute, dur) + require.Equal(t, "5min", label) + + _, _, ok = parseOpsRealtimeWindow("invalid") + require.False(t, ok) +} + +func TestPickThroughputBucketSeconds(t *testing.T) { + require.Equal(t, 60, pickThroughputBucketSeconds(30*time.Minute)) + require.Equal(t, 300, pickThroughputBucketSeconds(6*time.Hour)) + require.Equal(t, 3600, pickThroughputBucketSeconds(48*time.Hour)) +} + +func TestParseOpsQueryMode(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/?mode=raw", nil) + require.Equal(t, service.ParseOpsQueryMode("raw"), parseOpsQueryMode(c)) + require.Equal(t, service.OpsQueryMode(""), parseOpsQueryMode(nil)) +} + +func TestOpsAlertRuleValidation(t *testing.T) { + raw := map[string]json.RawMessage{ + "name": json.RawMessage(`"High error rate"`), + "metric_type": json.RawMessage(`"error_rate"`), + "operator": json.RawMessage(`">"`), + 
"threshold": json.RawMessage(`90`), + } + + validated, err := validateOpsAlertRulePayload(raw) + require.NoError(t, err) + require.Equal(t, "High error rate", validated.Name) + + _, err = validateOpsAlertRulePayload(map[string]json.RawMessage{}) + require.Error(t, err) + + require.True(t, isPercentOrRateMetric("error_rate")) + require.False(t, isPercentOrRateMetric("concurrency_queue_depth")) +} + +func TestOpsWSHelpers(t *testing.T) { + prefixes, invalid := parseTrustedProxyList("10.0.0.0/8,invalid") + require.Len(t, prefixes, 1) + require.Len(t, invalid, 1) + + host := hostWithoutPort("example.com:443") + require.Equal(t, "example.com", host) + + addr := netip.MustParseAddr("10.0.0.1") + require.True(t, isAddrInTrustedProxies(addr, prefixes)) + require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes)) +} diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go new file mode 100644 index 00000000..457d52fc --- /dev/null +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -0,0 +1,290 @@ +package admin + +import ( + "context" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type stubAdminService struct { + users []service.User + apiKeys []service.APIKey + groups []service.Group + accounts []service.Account + proxies []service.Proxy + proxyCounts []service.ProxyWithAccountCount + redeems []service.RedeemCode +} + +func newStubAdminService() *stubAdminService { + now := time.Now().UTC() + user := service.User{ + ID: 1, + Email: "user@example.com", + Role: service.RoleUser, + Status: service.StatusActive, + CreatedAt: now, + UpdatedAt: now, + } + apiKey := service.APIKey{ + ID: 10, + UserID: user.ID, + Key: "sk-test", + Name: "test", + Status: service.StatusActive, + CreatedAt: now, + UpdatedAt: now, + } + group := service.Group{ + ID: 2, + Name: "group", + Platform: service.PlatformAnthropic, + Status: service.StatusActive, + 
CreatedAt: now, + UpdatedAt: now, + } + account := service.Account{ + ID: 3, + Name: "account", + Platform: service.PlatformAnthropic, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + CreatedAt: now, + UpdatedAt: now, + } + proxy := service.Proxy{ + ID: 4, + Name: "proxy", + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + Status: service.StatusActive, + CreatedAt: now, + UpdatedAt: now, + } + redeem := service.RedeemCode{ + ID: 5, + Code: "R-TEST", + Type: service.RedeemTypeBalance, + Value: 10, + Status: service.StatusUnused, + CreatedAt: now, + } + return &stubAdminService{ + users: []service.User{user}, + apiKeys: []service.APIKey{apiKey}, + groups: []service.Group{group}, + accounts: []service.Account{account}, + proxies: []service.Proxy{proxy}, + proxyCounts: []service.ProxyWithAccountCount{{Proxy: proxy, AccountCount: 1}}, + redeems: []service.RedeemCode{redeem}, + } +} + +func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters) ([]service.User, int64, error) { + return s.users, int64(len(s.users)), nil +} + +func (s *stubAdminService) GetUser(ctx context.Context, id int64) (*service.User, error) { + for i := range s.users { + if s.users[i].ID == id { + return &s.users[i], nil + } + } + user := service.User{ID: id, Email: "user@example.com", Status: service.StatusActive} + return &user, nil +} + +func (s *stubAdminService) CreateUser(ctx context.Context, input *service.CreateUserInput) (*service.User, error) { + user := service.User{ID: 100, Email: input.Email, Status: service.StatusActive} + return &user, nil +} + +func (s *stubAdminService) UpdateUser(ctx context.Context, id int64, input *service.UpdateUserInput) (*service.User, error) { + user := service.User{ID: id, Email: "updated@example.com", Status: service.StatusActive} + return &user, nil +} + +func (s *stubAdminService) DeleteUser(ctx context.Context, id int64) error { + return nil +} + +func (s *stubAdminService) 
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*service.User, error) { + user := service.User{ID: userID, Balance: balance, Status: service.StatusActive} + return &user, nil +} + +func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]service.APIKey, int64, error) { + return s.apiKeys, int64(len(s.apiKeys)), nil +} + +func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) { + return map[string]any{"user_id": userID}, nil +} + +func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]service.Group, int64, error) { + return s.groups, int64(len(s.groups)), nil +} + +func (s *stubAdminService) GetAllGroups(ctx context.Context) ([]service.Group, error) { + return s.groups, nil +} + +func (s *stubAdminService) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]service.Group, error) { + return s.groups, nil +} + +func (s *stubAdminService) GetGroup(ctx context.Context, id int64) (*service.Group, error) { + group := service.Group{ID: id, Name: "group", Status: service.StatusActive} + return &group, nil +} + +func (s *stubAdminService) CreateGroup(ctx context.Context, input *service.CreateGroupInput) (*service.Group, error) { + group := service.Group{ID: 200, Name: input.Name, Status: service.StatusActive} + return &group, nil +} + +func (s *stubAdminService) UpdateGroup(ctx context.Context, id int64, input *service.UpdateGroupInput) (*service.Group, error) { + group := service.Group{ID: id, Name: input.Name, Status: service.StatusActive} + return &group, nil +} + +func (s *stubAdminService) DeleteGroup(ctx context.Context, id int64) error { + return nil +} + +func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]service.APIKey, int64, error) { + return s.apiKeys, int64(len(s.apiKeys)), 
nil +} + +func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) { + return s.accounts, int64(len(s.accounts)), nil +} + +func (s *stubAdminService) GetAccount(ctx context.Context, id int64) (*service.Account, error) { + account := service.Account{ID: id, Name: "account", Status: service.StatusActive} + return &account, nil +} + +func (s *stubAdminService) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) { + out := make([]*service.Account, 0, len(ids)) + for _, id := range ids { + account := service.Account{ID: id, Name: "account", Status: service.StatusActive} + out = append(out, &account) + } + return out, nil +} + +func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.CreateAccountInput) (*service.Account, error) { + account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive} + return &account, nil +} + +func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) { + account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive} + return &account, nil +} + +func (s *stubAdminService) DeleteAccount(ctx context.Context, id int64) error { + return nil +} + +func (s *stubAdminService) RefreshAccountCredentials(ctx context.Context, id int64) (*service.Account, error) { + account := service.Account{ID: id, Name: "account", Status: service.StatusActive} + return &account, nil +} + +func (s *stubAdminService) ClearAccountError(ctx context.Context, id int64) (*service.Account, error) { + account := service.Account{ID: id, Name: "account", Status: service.StatusActive} + return &account, nil +} + +func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*service.Account, error) { + account := service.Account{ID: id, Name: "account", Status: 
service.StatusActive, Schedulable: schedulable} + return &account, nil +} + +func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) { + return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil +} + +func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) { + return s.proxies, int64(len(s.proxies)), nil +} + +func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) { + return s.proxyCounts, int64(len(s.proxyCounts)), nil +} + +func (s *stubAdminService) GetAllProxies(ctx context.Context) ([]service.Proxy, error) { + return s.proxies, nil +} + +func (s *stubAdminService) GetAllProxiesWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) { + return s.proxyCounts, nil +} + +func (s *stubAdminService) GetProxy(ctx context.Context, id int64) (*service.Proxy, error) { + proxy := service.Proxy{ID: id, Name: "proxy", Status: service.StatusActive} + return &proxy, nil +} + +func (s *stubAdminService) CreateProxy(ctx context.Context, input *service.CreateProxyInput) (*service.Proxy, error) { + proxy := service.Proxy{ID: 400, Name: input.Name, Status: service.StatusActive} + return &proxy, nil +} + +func (s *stubAdminService) UpdateProxy(ctx context.Context, id int64, input *service.UpdateProxyInput) (*service.Proxy, error) { + proxy := service.Proxy{ID: id, Name: input.Name, Status: service.StatusActive} + return &proxy, nil +} + +func (s *stubAdminService) DeleteProxy(ctx context.Context, id int64) error { + return nil +} + +func (s *stubAdminService) BatchDeleteProxies(ctx context.Context, ids []int64) (*service.ProxyBatchDeleteResult, error) { + return &service.ProxyBatchDeleteResult{DeletedIDs: ids}, nil +} + 
+func (s *stubAdminService) GetProxyAccounts(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) { + return []service.ProxyAccountSummary{{ID: 1, Name: "account"}}, nil +} + +func (s *stubAdminService) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) { + return false, nil +} + +func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.ProxyTestResult, error) { + return &service.ProxyTestResult{Success: true, Message: "ok"}, nil +} + +func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) { + return s.redeems, int64(len(s.redeems)), nil +} + +func (s *stubAdminService) GetRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) { + code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUnused} + return &code, nil +} + +func (s *stubAdminService) GenerateRedeemCodes(ctx context.Context, input *service.GenerateRedeemCodesInput) ([]service.RedeemCode, error) { + return s.redeems, nil +} + +func (s *stubAdminService) DeleteRedeemCode(ctx context.Context, id int64) error { + return nil +} + +func (s *stubAdminService) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) { + return int64(len(ids)), nil +} + +func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) { + code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUsed} + return &code, nil +} + +// Ensure stub implements interface. 
+var _ service.AdminService = (*stubAdminService)(nil) diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 3f07403d..18365186 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -186,7 +186,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) { // GetUsageTrend handles getting usage trend data // GET /api/v1/admin/dashboard/trend -// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream +// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { startTime, endTime := parseTimeRange(c) granularity := c.DefaultQuery("granularity", "day") @@ -195,6 +195,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { var userID, apiKeyID, accountID, groupID int64 var model string var stream *bool + var billingType *int8 if userIDStr := c.Query("user_id"); userIDStr != "" { if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil { @@ -224,8 +225,17 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { stream = &streamVal } } + if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { + if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil { + bt := int8(v) + billingType = &bt + } else { + response.BadRequest(c, "Invalid billing_type") + return + } + } - trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream) + trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get usage trend") 
return @@ -241,13 +251,14 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { // GetModelStats handles getting model usage statistics // GET /api/v1/admin/dashboard/models -// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream +// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type func (h *DashboardHandler) GetModelStats(c *gin.Context) { startTime, endTime := parseTimeRange(c) // Parse optional filter params var userID, apiKeyID, accountID, groupID int64 var stream *bool + var billingType *int8 if userIDStr := c.Query("user_id"); userIDStr != "" { if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil { @@ -274,8 +285,17 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { stream = &streamVal } } + if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { + if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil { + bt := int8(v) + billingType = &bt + } else { + response.BadRequest(c, "Invalid billing_type") + return + } + } - stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream) + stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get model statistics") return diff --git a/backend/internal/handler/admin/usage_cleanup_handler_test.go b/backend/internal/handler/admin/usage_cleanup_handler_test.go new file mode 100644 index 00000000..d8684c39 --- /dev/null +++ b/backend/internal/handler/admin/usage_cleanup_handler_test.go @@ -0,0 +1,377 @@ +package admin + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + 
"github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type cleanupRepoStub struct { + mu sync.Mutex + created []*service.UsageCleanupTask + listTasks []service.UsageCleanupTask + listResult *pagination.PaginationResult + listErr error + statusByID map[int64]string +} + +func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error { + if task == nil { + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + if task.ID == 0 { + task.ID = int64(len(s.created) + 1) + } + if task.CreatedAt.IsZero() { + task.CreatedAt = time.Now().UTC() + } + task.UpdatedAt = task.CreatedAt + clone := *task + s.created = append(s.created, &clone) + return nil +} + +func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + return s.listTasks, s.listResult, s.listErr +} + +func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) { + return nil, nil +} + +func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.statusByID == nil { + return "", sql.ErrNoRows + } + status, ok := s.statusByID[taskID] + if !ok { + return "", sql.ErrNoRows + } + return status, nil +} + +func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error { + return nil +} + +func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.statusByID == nil { + s.statusByID = 
map[int64]string{} + } + status := s.statusByID[taskID] + if status != service.UsageCleanupStatusPending && status != service.UsageCleanupStatusRunning { + return false, nil + } + s.statusByID[taskID] = service.UsageCleanupStatusCanceled + return true, nil +} + +func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error { + return nil +} + +func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error { + return nil +} + +func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) { + return 0, nil +} + +var _ service.UsageCleanupRepository = (*cleanupRepoStub)(nil) + +func setupCleanupRouter(cleanupService *service.UsageCleanupService, userID int64) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + if userID > 0 { + router.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID}) + c.Next() + }) + } + + handler := NewUsageHandler(nil, nil, nil, cleanupService) + router.POST("/api/v1/admin/usage/cleanup-tasks", handler.CreateCleanupTask) + router.GET("/api/v1/admin/usage/cleanup-tasks", handler.ListCleanupTasks) + router.POST("/api/v1/admin/usage/cleanup-tasks/:id/cancel", handler.CancelCleanupTask) + return router +} + +func TestUsageHandlerCreateCleanupTaskUnauthorized(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 0) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusUnauthorized, recorder.Code) +} + 
+func TestUsageHandlerCreateCleanupTaskUnavailable(t *testing.T) { + router := setupCleanupRouter(nil, 1) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusServiceUnavailable, recorder.Code) +} + +func TestUsageHandlerCreateCleanupTaskBindError(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 88) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString("{bad-json")) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusBadRequest, recorder.Code) +} + +func TestUsageHandlerCreateCleanupTaskMissingRange(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 88) + + payload := map[string]any{ + "start_date": "2024-01-01", + "timezone": "UTC", + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusBadRequest, recorder.Code) +} + +func TestUsageHandlerCreateCleanupTaskInvalidDate(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + 
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 88) + + payload := map[string]any{ + "start_date": "2024-13-01", + "end_date": "2024-01-02", + "timezone": "UTC", + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusBadRequest, recorder.Code) +} + +func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 88) + + payload := map[string]any{ + "start_date": "2024-01-01", + "end_date": "2024-02-40", + "timezone": "UTC", + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusBadRequest, recorder.Code) +} + +func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 99) + + payload := map[string]any{ + "start_date": " 2024-01-01 ", + "end_date": "2024-01-02", + "timezone": "UTC", + "model": "gpt-4", + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", 
bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp response.Response + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.created, 1) + created := repo.created[0] + require.Equal(t, int64(99), created.CreatedBy) + require.NotNil(t, created.Filters.Model) + require.Equal(t, "gpt-4", *created.Filters.Model) + + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC).Add(24*time.Hour - time.Nanosecond) + require.True(t, created.Filters.StartTime.Equal(start)) + require.True(t, created.Filters.EndTime.Equal(end)) +} + +func TestUsageHandlerListCleanupTasksUnavailable(t *testing.T) { + router := setupCleanupRouter(nil, 0) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusServiceUnavailable, recorder.Code) +} + +func TestUsageHandlerListCleanupTasksSuccess(t *testing.T) { + repo := &cleanupRepoStub{} + repo.listTasks = []service.UsageCleanupTask{ + { + ID: 7, + Status: service.UsageCleanupStatusSucceeded, + CreatedBy: 4, + }, + } + repo.listResult = &pagination.PaginationResult{Total: 1, Page: 1, PageSize: 20, Pages: 1} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 1) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + Items 
[]dto.UsageCleanupTask `json:"items"` + Total int64 `json:"total"` + Page int `json:"page"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Len(t, resp.Data.Items, 1) + require.Equal(t, int64(7), resp.Data.Items[0].ID) + require.Equal(t, int64(1), resp.Data.Total) + require.Equal(t, 1, resp.Data.Page) +} + +func TestUsageHandlerListCleanupTasksError(t *testing.T) { + repo := &cleanupRepoStub{listErr: errors.New("boom")} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 1) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusInternalServerError, recorder.Code) +} + +func TestUsageHandlerCancelCleanupTaskUnauthorized(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 0) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/1/cancel", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestUsageHandlerCancelCleanupTaskNotFound(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 1) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/999/cancel", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusNotFound, rec.Code) +} + 
+func TestUsageHandlerCancelCleanupTaskConflict(t *testing.T) { + repo := &cleanupRepoStub{statusByID: map[int64]string{2: service.UsageCleanupStatusSucceeded}} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 1) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/2/cancel", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusConflict, rec.Code) +} + +func TestUsageHandlerCancelCleanupTaskSuccess(t *testing.T) { + repo := &cleanupRepoStub{statusByID: map[int64]string{3: service.UsageCleanupStatusPending}} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 1) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/3/cancel", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) +} diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go index c7b983f1..81aa78e1 100644 --- a/backend/internal/handler/admin/usage_handler.go +++ b/backend/internal/handler/admin/usage_handler.go @@ -1,7 +1,10 @@ package admin import ( + "log" + "net/http" "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -9,6 +12,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -16,9 +20,10 @@ import ( // UsageHandler handles admin usage-related requests type UsageHandler struct { - usageService 
*service.UsageService - apiKeyService *service.APIKeyService - adminService service.AdminService + usageService *service.UsageService + apiKeyService *service.APIKeyService + adminService service.AdminService + cleanupService *service.UsageCleanupService } // NewUsageHandler creates a new admin usage handler @@ -26,14 +31,30 @@ func NewUsageHandler( usageService *service.UsageService, apiKeyService *service.APIKeyService, adminService service.AdminService, + cleanupService *service.UsageCleanupService, ) *UsageHandler { return &UsageHandler{ - usageService: usageService, - apiKeyService: apiKeyService, - adminService: adminService, + usageService: usageService, + apiKeyService: apiKeyService, + adminService: adminService, + cleanupService: cleanupService, } } +// CreateUsageCleanupTaskRequest represents cleanup task creation request +type CreateUsageCleanupTaskRequest struct { + StartDate string `json:"start_date"` + EndDate string `json:"end_date"` + UserID *int64 `json:"user_id"` + APIKeyID *int64 `json:"api_key_id"` + AccountID *int64 `json:"account_id"` + GroupID *int64 `json:"group_id"` + Model *string `json:"model"` + Stream *bool `json:"stream"` + BillingType *int8 `json:"billing_type"` + Timezone string `json:"timezone"` +} + // List handles listing all usage records with filters // GET /api/v1/admin/usage func (h *UsageHandler) List(c *gin.Context) { @@ -344,3 +365,162 @@ func (h *UsageHandler) SearchAPIKeys(c *gin.Context) { response.Success(c, result) } + +// ListCleanupTasks handles listing usage cleanup tasks +// GET /api/v1/admin/usage/cleanup-tasks +func (h *UsageHandler) ListCleanupTasks(c *gin.Context) { + if h.cleanupService == nil { + response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable") + return + } + operator := int64(0) + if subject, ok := middleware.GetAuthSubjectFromContext(c); ok { + operator = subject.UserID + } + page, pageSize := response.ParsePagination(c) + log.Printf("[UsageCleanup] 请求清理任务列表: 
operator=%d page=%d page_size=%d", operator, page, pageSize) + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params) + if err != nil { + log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err) + response.ErrorFrom(c, err) + return + } + out := make([]dto.UsageCleanupTask, 0, len(tasks)) + for i := range tasks { + out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i])) + } + log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize) + response.Paginated(c, out, result.Total, page, pageSize) +} + +// CreateCleanupTask handles creating a usage cleanup task +// POST /api/v1/admin/usage/cleanup-tasks +func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { + if h.cleanupService == nil { + response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable") + return + } + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Unauthorized(c, "Unauthorized") + return + } + + var req CreateUsageCleanupTaskRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + req.StartDate = strings.TrimSpace(req.StartDate) + req.EndDate = strings.TrimSpace(req.EndDate) + if req.StartDate == "" || req.EndDate == "" { + response.BadRequest(c, "start_date and end_date are required") + return + } + + startTime, err := timezone.ParseInUserLocation("2006-01-02", req.StartDate, req.Timezone) + if err != nil { + response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD") + return + } + endTime, err := timezone.ParseInUserLocation("2006-01-02", req.EndDate, req.Timezone) + if err != nil { + response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD") + return + } + endTime = endTime.Add(24*time.Hour - 
time.Nanosecond) + + filters := service.UsageCleanupFilters{ + StartTime: startTime, + EndTime: endTime, + UserID: req.UserID, + APIKeyID: req.APIKeyID, + AccountID: req.AccountID, + GroupID: req.GroupID, + Model: req.Model, + Stream: req.Stream, + BillingType: req.BillingType, + } + + var userID any + if filters.UserID != nil { + userID = *filters.UserID + } + var apiKeyID any + if filters.APIKeyID != nil { + apiKeyID = *filters.APIKeyID + } + var accountID any + if filters.AccountID != nil { + accountID = *filters.AccountID + } + var groupID any + if filters.GroupID != nil { + groupID = *filters.GroupID + } + var model any + if filters.Model != nil { + model = *filters.Model + } + var stream any + if filters.Stream != nil { + stream = *filters.Stream + } + var billingType any + if filters.BillingType != nil { + billingType = *filters.BillingType + } + + log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q", + subject.UserID, + filters.StartTime.Format(time.RFC3339), + filters.EndTime.Format(time.RFC3339), + userID, + apiKeyID, + accountID, + groupID, + model, + stream, + billingType, + req.Timezone, + ) + + task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID) + if err != nil { + log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err) + response.ErrorFrom(c, err) + return + } + + log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status) + response.Success(c, dto.UsageCleanupTaskFromService(task)) +} + +// CancelCleanupTask handles canceling a usage cleanup task +// POST /api/v1/admin/usage/cleanup-tasks/:id/cancel +func (h *UsageHandler) CancelCleanupTask(c *gin.Context) { + if h.cleanupService == nil { + response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable") + return + } + subject, ok := middleware.GetAuthSubjectFromContext(c) 
+ if !ok || subject.UserID <= 0 { + response.Unauthorized(c, "Unauthorized") + return + } + idStr := strings.TrimSpace(c.Param("id")) + taskID, err := strconv.ParseInt(idStr, 10, 64) + if err != nil || taskID <= 0 { + response.BadRequest(c, "Invalid task id") + return + } + log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID) + if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil { + log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err) + response.ErrorFrom(c, err) + return + } + log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID) + response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled}) +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 4d59ddff..f43fac27 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -340,6 +340,36 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog { return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true) } +func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTask { + if task == nil { + return nil + } + return &UsageCleanupTask{ + ID: task.ID, + Status: task.Status, + Filters: UsageCleanupFilters{ + StartTime: task.Filters.StartTime, + EndTime: task.Filters.EndTime, + UserID: task.Filters.UserID, + APIKeyID: task.Filters.APIKeyID, + AccountID: task.Filters.AccountID, + GroupID: task.Filters.GroupID, + Model: task.Filters.Model, + Stream: task.Filters.Stream, + BillingType: task.Filters.BillingType, + }, + CreatedBy: task.CreatedBy, + DeletedRows: task.DeletedRows, + ErrorMessage: task.ErrorMsg, + CanceledBy: task.CanceledBy, + CanceledAt: task.CanceledAt, + StartedAt: task.StartedAt, + FinishedAt: task.FinishedAt, + CreatedAt: task.CreatedAt, + UpdatedAt: task.UpdatedAt, + } +} + func SettingFromService(s 
*service.Setting) *Setting { if s == nil { return nil diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 914f2b23..5fa5a3fd 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -223,6 +223,33 @@ type UsageLog struct { Subscription *UserSubscription `json:"subscription,omitempty"` } +type UsageCleanupFilters struct { + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` + UserID *int64 `json:"user_id,omitempty"` + APIKeyID *int64 `json:"api_key_id,omitempty"` + AccountID *int64 `json:"account_id,omitempty"` + GroupID *int64 `json:"group_id,omitempty"` + Model *string `json:"model,omitempty"` + Stream *bool `json:"stream,omitempty"` + BillingType *int8 `json:"billing_type,omitempty"` +} + +type UsageCleanupTask struct { + ID int64 `json:"id"` + Status string `json:"status"` + Filters UsageCleanupFilters `json:"filters"` + CreatedBy int64 `json:"created_by"` + DeletedRows int64 `json:"deleted_rows"` + ErrorMessage *string `json:"error_message,omitempty"` + CanceledBy *int64 `json:"canceled_by,omitempty"` + CanceledAt *time.Time `json:"canceled_at,omitempty"` + StartedAt *time.Time `json:"started_at,omitempty"` + FinishedAt *time.Time `json:"finished_at,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + // AccountSummary is a minimal account info for usage log display. // It intentionally excludes sensitive fields like Credentials, Proxy, etc. 
type AccountSummary struct { diff --git a/backend/internal/repository/dashboard_aggregation_repo.go b/backend/internal/repository/dashboard_aggregation_repo.go index 3543e061..59bbd6a3 100644 --- a/backend/internal/repository/dashboard_aggregation_repo.go +++ b/backend/internal/repository/dashboard_aggregation_repo.go @@ -77,6 +77,75 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta return nil } +func (r *dashboardAggregationRepository) RecomputeRange(ctx context.Context, start, end time.Time) error { + if r == nil || r.sql == nil { + return nil + } + loc := timezone.Location() + startLocal := start.In(loc) + endLocal := end.In(loc) + if !endLocal.After(startLocal) { + return nil + } + + hourStart := startLocal.Truncate(time.Hour) + hourEnd := endLocal.Truncate(time.Hour) + if endLocal.After(hourEnd) { + hourEnd = hourEnd.Add(time.Hour) + } + + dayStart := truncateToDay(startLocal) + dayEnd := truncateToDay(endLocal) + if endLocal.After(dayEnd) { + dayEnd = dayEnd.Add(24 * time.Hour) + } + + // 尽量使用事务保证范围内的一致性(允许在非 *sql.DB 的情况下退化为非事务执行)。 + if db, ok := r.sql.(*sql.DB); ok { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + txRepo := newDashboardAggregationRepositoryWithSQL(tx) + if err := txRepo.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil { + _ = tx.Rollback() + return err + } + return tx.Commit() + } + return r.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd) +} + +func (r *dashboardAggregationRepository) recomputeRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error { + // 先清空范围内桶,再重建(避免仅增量插入导致活跃用户等指标无法回退)。 + if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil { + return err + } + if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly_users WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil 
{ + return err + } + if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil { + return err + } + if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily_users WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil { + return err + } + + if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil { + return err + } + if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil { + return err + } + if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil { + return err + } + if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil { + return err + } + return nil +} + func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) { var ts time.Time query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1" diff --git a/backend/internal/repository/usage_cleanup_repo.go b/backend/internal/repository/usage_cleanup_repo.go new file mode 100644 index 00000000..b703cc9f --- /dev/null +++ b/backend/internal/repository/usage_cleanup_repo.go @@ -0,0 +1,363 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type usageCleanupRepository struct { + sql sqlExecutor +} + +func NewUsageCleanupRepository(sqlDB *sql.DB) service.UsageCleanupRepository { + return &usageCleanupRepository{sql: sqlDB} +} + +func (r *usageCleanupRepository) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error { + if task == nil { + return nil + } + filtersJSON, err := json.Marshal(task.Filters) + if err != nil { + return fmt.Errorf("marshal cleanup filters: %w", err) + } + query := ` + INSERT INTO usage_cleanup_tasks ( + status, + filters, + 
created_by, + deleted_rows + ) VALUES ($1, $2, $3, $4) + RETURNING id, created_at, updated_at + ` + if err := scanSingleRow(ctx, r.sql, query, []any{task.Status, filtersJSON, task.CreatedBy, task.DeletedRows}, &task.ID, &task.CreatedAt, &task.UpdatedAt); err != nil { + return err + } + return nil +} + +func (r *usageCleanupRepository) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) { + var total int64 + if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM usage_cleanup_tasks", nil, &total); err != nil { + return nil, nil, err + } + if total == 0 { + return []service.UsageCleanupTask{}, paginationResultFromTotal(0, params), nil + } + + query := ` + SELECT id, status, filters, created_by, deleted_rows, error_message, + canceled_by, canceled_at, + started_at, finished_at, created_at, updated_at + FROM usage_cleanup_tasks + ORDER BY created_at DESC + LIMIT $1 OFFSET $2 + ` + rows, err := r.sql.QueryContext(ctx, query, params.Limit(), params.Offset()) + if err != nil { + return nil, nil, err + } + defer rows.Close() + + tasks := make([]service.UsageCleanupTask, 0) + for rows.Next() { + var task service.UsageCleanupTask + var filtersJSON []byte + var errMsg sql.NullString + var canceledBy sql.NullInt64 + var canceledAt sql.NullTime + var startedAt sql.NullTime + var finishedAt sql.NullTime + if err := rows.Scan( + &task.ID, + &task.Status, + &filtersJSON, + &task.CreatedBy, + &task.DeletedRows, + &errMsg, + &canceledBy, + &canceledAt, + &startedAt, + &finishedAt, + &task.CreatedAt, + &task.UpdatedAt, + ); err != nil { + return nil, nil, err + } + if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil { + return nil, nil, fmt.Errorf("parse cleanup filters: %w", err) + } + if errMsg.Valid { + task.ErrorMsg = &errMsg.String + } + if canceledBy.Valid { + v := canceledBy.Int64 + task.CanceledBy = &v + } + if canceledAt.Valid { + task.CanceledAt = &canceledAt.Time + } + if 
startedAt.Valid { + task.StartedAt = &startedAt.Time + } + if finishedAt.Valid { + task.FinishedAt = &finishedAt.Time + } + tasks = append(tasks, task) + } + if err := rows.Err(); err != nil { + return nil, nil, err + } + return tasks, paginationResultFromTotal(total, params), nil +} + +func (r *usageCleanupRepository) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) { + if staleRunningAfterSeconds <= 0 { + staleRunningAfterSeconds = 1800 + } + query := ` + WITH next AS ( + SELECT id + FROM usage_cleanup_tasks + WHERE status = $1 + OR ( + status = $2 + AND started_at IS NOT NULL + AND started_at < NOW() - ($3 * interval '1 second') + ) + ORDER BY created_at ASC + LIMIT 1 + FOR UPDATE SKIP LOCKED + ) + UPDATE usage_cleanup_tasks + SET status = $4, + started_at = NOW(), + finished_at = NULL, + error_message = NULL, + updated_at = NOW() + FROM next + WHERE usage_cleanup_tasks.id = next.id + RETURNING id, status, filters, created_by, deleted_rows, error_message, + started_at, finished_at, created_at, updated_at + ` + var task service.UsageCleanupTask + var filtersJSON []byte + var errMsg sql.NullString + var startedAt sql.NullTime + var finishedAt sql.NullTime + if err := scanSingleRow( + ctx, + r.sql, + query, + []any{ + service.UsageCleanupStatusPending, + service.UsageCleanupStatusRunning, + staleRunningAfterSeconds, + service.UsageCleanupStatusRunning, + }, + &task.ID, + &task.Status, + &filtersJSON, + &task.CreatedBy, + &task.DeletedRows, + &errMsg, + &startedAt, + &finishedAt, + &task.CreatedAt, + &task.UpdatedAt, + ); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil { + return nil, fmt.Errorf("parse cleanup filters: %w", err) + } + if errMsg.Valid { + task.ErrorMsg = &errMsg.String + } + if startedAt.Valid { + task.StartedAt = &startedAt.Time + } + if finishedAt.Valid { + task.FinishedAt = 
&finishedAt.Time + } + return &task, nil +} + +func (r *usageCleanupRepository) GetTaskStatus(ctx context.Context, taskID int64) (string, error) { + var status string + if err := scanSingleRow(ctx, r.sql, "SELECT status FROM usage_cleanup_tasks WHERE id = $1", []any{taskID}, &status); err != nil { + return "", err + } + return status, nil +} + +func (r *usageCleanupRepository) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error { + query := ` + UPDATE usage_cleanup_tasks + SET deleted_rows = $1, + updated_at = NOW() + WHERE id = $2 + ` + _, err := r.sql.ExecContext(ctx, query, deletedRows, taskID) + return err +} + +func (r *usageCleanupRepository) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) { + query := ` + UPDATE usage_cleanup_tasks + SET status = $1, + canceled_by = $3, + canceled_at = NOW(), + finished_at = NOW(), + error_message = NULL, + updated_at = NOW() + WHERE id = $2 + AND status IN ($4, $5) + RETURNING id + ` + var id int64 + err := scanSingleRow(ctx, r.sql, query, []any{ + service.UsageCleanupStatusCanceled, + taskID, + canceledBy, + service.UsageCleanupStatusPending, + service.UsageCleanupStatusRunning, + }, &id) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + if err != nil { + return false, err + } + return true, nil +} + +func (r *usageCleanupRepository) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error { + query := ` + UPDATE usage_cleanup_tasks + SET status = $1, + deleted_rows = $2, + finished_at = NOW(), + updated_at = NOW() + WHERE id = $3 + ` + _, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusSucceeded, deletedRows, taskID) + return err +} + +func (r *usageCleanupRepository) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error { + query := ` + UPDATE usage_cleanup_tasks + SET status = $1, + deleted_rows = $2, + error_message = $3, + finished_at = NOW(), + updated_at = NOW() + WHERE id = 
$4 + ` + _, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusFailed, deletedRows, errorMsg, taskID) + return err +} + +func (r *usageCleanupRepository) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) { + if filters.StartTime.IsZero() || filters.EndTime.IsZero() { + return 0, fmt.Errorf("cleanup filters missing time range") + } + whereClause, args := buildUsageCleanupWhere(filters) + if whereClause == "" { + return 0, fmt.Errorf("cleanup filters missing time range") + } + args = append(args, limit) + query := fmt.Sprintf(` + WITH target AS ( + SELECT id + FROM usage_logs + WHERE %s + ORDER BY created_at ASC, id ASC + LIMIT $%d + ) + DELETE FROM usage_logs + WHERE id IN (SELECT id FROM target) + RETURNING id + `, whereClause, len(args)) + + rows, err := r.sql.QueryContext(ctx, query, args...) + if err != nil { + return 0, err + } + defer rows.Close() + + var deleted int64 + for rows.Next() { + deleted++ + } + if err := rows.Err(); err != nil { + return 0, err + } + return deleted, nil +} + +func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any) { + conditions := make([]string, 0, 8) + args := make([]any, 0, 8) + idx := 1 + if !filters.StartTime.IsZero() { + conditions = append(conditions, fmt.Sprintf("created_at >= $%d", idx)) + args = append(args, filters.StartTime) + idx++ + } + if !filters.EndTime.IsZero() { + conditions = append(conditions, fmt.Sprintf("created_at <= $%d", idx)) + args = append(args, filters.EndTime) + idx++ + } + if filters.UserID != nil { + conditions = append(conditions, fmt.Sprintf("user_id = $%d", idx)) + args = append(args, *filters.UserID) + idx++ + } + if filters.APIKeyID != nil { + conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", idx)) + args = append(args, *filters.APIKeyID) + idx++ + } + if filters.AccountID != nil { + conditions = append(conditions, fmt.Sprintf("account_id = $%d", idx)) + args = append(args, 
*filters.AccountID) + idx++ + } + if filters.GroupID != nil { + conditions = append(conditions, fmt.Sprintf("group_id = $%d", idx)) + args = append(args, *filters.GroupID) + idx++ + } + if filters.Model != nil { + model := strings.TrimSpace(*filters.Model) + if model != "" { + conditions = append(conditions, fmt.Sprintf("model = $%d", idx)) + args = append(args, model) + idx++ + } + } + if filters.Stream != nil { + conditions = append(conditions, fmt.Sprintf("stream = $%d", idx)) + args = append(args, *filters.Stream) + idx++ + } + if filters.BillingType != nil { + conditions = append(conditions, fmt.Sprintf("billing_type = $%d", idx)) + args = append(args, *filters.BillingType) + idx++ + } + return strings.Join(conditions, " AND "), args +} diff --git a/backend/internal/repository/usage_cleanup_repo_test.go b/backend/internal/repository/usage_cleanup_repo_test.go new file mode 100644 index 00000000..e5582709 --- /dev/null +++ b/backend/internal/repository/usage_cleanup_repo_test.go @@ -0,0 +1,440 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func newSQLMock(t *testing.T) (*sql.DB, sqlmock.Sqlmock) { + t.Helper() + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherRegexp)) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + return db, mock +} + +func TestNewUsageCleanupRepository(t *testing.T) { + db, _ := newSQLMock(t) + repo := NewUsageCleanupRepository(db) + require.NotNil(t, repo) +} + +func TestUsageCleanupRepositoryCreateTask(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + task := &service.UsageCleanupTask{ + Status: 
service.UsageCleanupStatusPending, + Filters: service.UsageCleanupFilters{StartTime: start, EndTime: end}, + CreatedBy: 12, + } + now := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC) + + mock.ExpectQuery("INSERT INTO usage_cleanup_tasks"). + WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows). + WillReturnRows(sqlmock.NewRows([]string{"id", "created_at", "updated_at"}).AddRow(int64(1), now, now)) + + err := repo.CreateTask(context.Background(), task) + require.NoError(t, err) + require.Equal(t, int64(1), task.ID) + require.Equal(t, now, task.CreatedAt) + require.Equal(t, now, task.UpdatedAt) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryCreateTaskNil(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + err := repo.CreateTask(context.Background(), nil) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryCreateTaskQueryError(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + task := &service.UsageCleanupTask{ + Status: service.UsageCleanupStatusPending, + Filters: service.UsageCleanupFilters{StartTime: time.Now(), EndTime: time.Now().Add(time.Hour)}, + CreatedBy: 1, + } + + mock.ExpectQuery("INSERT INTO usage_cleanup_tasks"). + WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows). + WillReturnError(sql.ErrConnDone) + + err := repo.CreateTask(context.Background(), task) + require.Error(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryListTasksEmpty(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks"). 
+ WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0))) + + tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}) + require.NoError(t, err) + require.Empty(t, tasks) + require.Equal(t, int64(0), result.Total) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryListTasks(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(2 * time.Hour) + filters := service.UsageCleanupFilters{StartTime: start, EndTime: end} + filtersJSON, err := json.Marshal(filters) + require.NoError(t, err) + + createdAt := time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC) + updatedAt := createdAt.Add(time.Minute) + rows := sqlmock.NewRows([]string{ + "id", "status", "filters", "created_by", "deleted_rows", "error_message", + "canceled_by", "canceled_at", + "started_at", "finished_at", "created_at", "updated_at", + }).AddRow( + int64(1), + service.UsageCleanupStatusSucceeded, + filtersJSON, + int64(2), + int64(9), + "error", + nil, + nil, + start, + end, + createdAt, + updatedAt, + ) + + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1))) + mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message"). + WithArgs(20, 0). 
+ WillReturnRows(rows) + + tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}) + require.NoError(t, err) + require.Len(t, tasks, 1) + require.Equal(t, int64(1), tasks[0].ID) + require.Equal(t, service.UsageCleanupStatusSucceeded, tasks[0].Status) + require.Equal(t, int64(2), tasks[0].CreatedBy) + require.Equal(t, int64(9), tasks[0].DeletedRows) + require.NotNil(t, tasks[0].ErrorMsg) + require.Equal(t, "error", *tasks[0].ErrorMsg) + require.NotNil(t, tasks[0].StartedAt) + require.NotNil(t, tasks[0].FinishedAt) + require.Equal(t, int64(1), result.Total) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryListTasksInvalidFilters(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + rows := sqlmock.NewRows([]string{ + "id", "status", "filters", "created_by", "deleted_rows", "error_message", + "canceled_by", "canceled_at", + "started_at", "finished_at", "created_at", "updated_at", + }).AddRow( + int64(1), + service.UsageCleanupStatusSucceeded, + []byte("not-json"), + int64(2), + int64(9), + nil, + nil, + nil, + nil, + nil, + time.Now().UTC(), + time.Now().UTC(), + ) + + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1))) + mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message"). + WithArgs(20, 0). + WillReturnRows(rows) + + _, _, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}) + require.Error(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryClaimNextPendingTaskNone(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectQuery("UPDATE usage_cleanup_tasks"). + WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning). 
+ WillReturnRows(sqlmock.NewRows([]string{ + "id", "status", "filters", "created_by", "deleted_rows", "error_message", + "started_at", "finished_at", "created_at", "updated_at", + })) + + task, err := repo.ClaimNextPendingTask(context.Background(), 1800) + require.NoError(t, err) + require.Nil(t, task) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryClaimNextPendingTask(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + filters := service.UsageCleanupFilters{StartTime: start, EndTime: end} + filtersJSON, err := json.Marshal(filters) + require.NoError(t, err) + + rows := sqlmock.NewRows([]string{ + "id", "status", "filters", "created_by", "deleted_rows", "error_message", + "started_at", "finished_at", "created_at", "updated_at", + }).AddRow( + int64(4), + service.UsageCleanupStatusRunning, + filtersJSON, + int64(7), + int64(0), + nil, + start, + nil, + start, + start, + ) + + mock.ExpectQuery("UPDATE usage_cleanup_tasks"). + WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning). + WillReturnRows(rows) + + task, err := repo.ClaimNextPendingTask(context.Background(), 1800) + require.NoError(t, err) + require.NotNil(t, task) + require.Equal(t, int64(4), task.ID) + require.Equal(t, service.UsageCleanupStatusRunning, task.Status) + require.Equal(t, int64(7), task.CreatedBy) + require.NotNil(t, task.StartedAt) + require.Nil(t, task.ErrorMsg) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryClaimNextPendingTaskError(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectQuery("UPDATE usage_cleanup_tasks"). + WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning). 
+ WillReturnError(sql.ErrConnDone) + + _, err := repo.ClaimNextPendingTask(context.Background(), 1800) + require.Error(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryClaimNextPendingTaskInvalidFilters(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + rows := sqlmock.NewRows([]string{ + "id", "status", "filters", "created_by", "deleted_rows", "error_message", + "started_at", "finished_at", "created_at", "updated_at", + }).AddRow( + int64(4), + service.UsageCleanupStatusRunning, + []byte("invalid"), + int64(7), + int64(0), + nil, + nil, + nil, + time.Now().UTC(), + time.Now().UTC(), + ) + + mock.ExpectQuery("UPDATE usage_cleanup_tasks"). + WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning). + WillReturnRows(rows) + + _, err := repo.ClaimNextPendingTask(context.Background(), 1800) + require.Error(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryMarkTaskSucceeded(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectExec("UPDATE usage_cleanup_tasks"). + WithArgs(service.UsageCleanupStatusSucceeded, int64(12), int64(9)). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err := repo.MarkTaskSucceeded(context.Background(), 9, 12) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryMarkTaskFailed(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectExec("UPDATE usage_cleanup_tasks"). + WithArgs(service.UsageCleanupStatusFailed, int64(4), "boom", int64(2)). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + + err := repo.MarkTaskFailed(context.Background(), 2, 4, "boom") + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryGetTaskStatus(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectQuery("SELECT status FROM usage_cleanup_tasks"). + WithArgs(int64(9)). + WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow(service.UsageCleanupStatusPending)) + + status, err := repo.GetTaskStatus(context.Background(), 9) + require.NoError(t, err) + require.Equal(t, service.UsageCleanupStatusPending, status) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryUpdateTaskProgress(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectExec("UPDATE usage_cleanup_tasks"). + WithArgs(int64(123), int64(8)). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err := repo.UpdateTaskProgress(context.Background(), 8, 123) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryCancelTask(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectQuery("UPDATE usage_cleanup_tasks"). + WithArgs(service.UsageCleanupStatusCanceled, int64(6), int64(9), service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning). 
+ WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(6))) + + ok, err := repo.CancelTask(context.Background(), 6, 9) + require.NoError(t, err) + require.True(t, ok) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryDeleteUsageLogsBatchMissingRange(t *testing.T) { + db, _ := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + _, err := repo.DeleteUsageLogsBatch(context.Background(), service.UsageCleanupFilters{}, 10) + require.Error(t, err) +} + +func TestUsageCleanupRepositoryDeleteUsageLogsBatch(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + userID := int64(3) + model := " gpt-4 " + filters := service.UsageCleanupFilters{ + StartTime: start, + EndTime: end, + UserID: &userID, + Model: &model, + } + + mock.ExpectQuery("DELETE FROM usage_logs"). + WithArgs(start, end, userID, "gpt-4", 2). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(1)).AddRow(int64(2))) + + deleted, err := repo.DeleteUsageLogsBatch(context.Background(), filters, 2) + require.NoError(t, err) + require.Equal(t, int64(2), deleted) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryDeleteUsageLogsBatchQueryError(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + filters := service.UsageCleanupFilters{StartTime: start, EndTime: end} + + mock.ExpectQuery("DELETE FROM usage_logs"). + WithArgs(start, end, 5). 
+ WillReturnError(sql.ErrConnDone) + + _, err := repo.DeleteUsageLogsBatch(context.Background(), filters, 5) + require.Error(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestBuildUsageCleanupWhere(t *testing.T) { + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + userID := int64(1) + apiKeyID := int64(2) + accountID := int64(3) + groupID := int64(4) + model := " gpt-4 " + stream := true + billingType := int8(2) + + where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{ + StartTime: start, + EndTime: end, + UserID: &userID, + APIKeyID: &apiKeyID, + AccountID: &accountID, + GroupID: &groupID, + Model: &model, + Stream: &stream, + BillingType: &billingType, + }) + + require.Equal(t, "created_at >= $1 AND created_at <= $2 AND user_id = $3 AND api_key_id = $4 AND account_id = $5 AND group_id = $6 AND model = $7 AND stream = $8 AND billing_type = $9", where) + require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args) +} + +func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) { + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + model := " " + + where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{ + StartTime: start, + EndTime: end, + Model: &model, + }) + + require.Equal(t, "created_at >= $1 AND created_at <= $2", where) + require.Equal(t, []any{start, end}, args) +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 4a2aaade..963db7ba 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -1411,7 +1411,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe } // GetUsageTrendWithFilters returns usage trend data with optional filters -func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, 
granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) (results []TrendDataPoint, err error) { +func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) { dateFormat := "YYYY-MM-DD" if granularity == "hour" { dateFormat = "YYYY-MM-DD HH24:00" @@ -1456,6 +1456,10 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND stream = $%d", len(args)+1) args = append(args, *stream) } + if billingType != nil { + query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) + args = append(args, int16(*billingType)) + } query += " GROUP BY date ORDER BY date ASC" rows, err := r.sql.QueryContext(ctx, query, args...) @@ -1479,7 +1483,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start } // GetModelStatsWithFilters returns model statistics with optional filters -func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) (results []ModelStat, err error) { +func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) { actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" // 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。 if accountID > 0 && userID == 0 && apiKeyID == 0 { @@ -1520,6 +1524,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND stream = $%d", len(args)+1) args = append(args, *stream) } + if billingType != nil { + query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) + args = append(args, 
int16(*billingType)) + } query += " GROUP BY model ORDER BY total_tokens DESC" rows, err := r.sql.QueryContext(ctx, query, args...) @@ -1825,7 +1833,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID } } - models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil) + models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil) if err != nil { models = []ModelStat{} } diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 7174be18..eb220f22 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -944,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { endTime := base.Add(48 * time.Hour) // Test with user filter - trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil) + trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters user filter") s.Require().Len(trend, 2) // Test with apiKey filter - trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil) + trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter") s.Require().Len(trend, 2) // Test with both filters - trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil) + trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters both filters") s.Require().Len(trend, 2) } @@ -971,7 +971,7 @@ func (s *UsageLogRepoSuite) 
TestGetUsageTrendWithFilters_HourlyGranularity() { startTime := base.Add(-1 * time.Hour) endTime := base.Add(3 * time.Hour) - trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil) + trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters hourly") s.Require().Len(trend, 2) } @@ -1017,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { endTime := base.Add(2 * time.Hour) // Test with user filter - stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil) + stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters user filter") s.Require().Len(stats, 2) // Test with apiKey filter - stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil) + stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter") s.Require().Len(stats, 2) // Test with account filter - stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil) + stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters account filter") s.Require().Len(stats, 2) } diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 91ef9413..9dc91eca 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -47,6 +47,7 @@ var ProviderSet = wire.NewSet( NewRedeemCodeRepository, NewPromoCodeRepository, NewUsageLogRepository, + NewUsageCleanupRepository, NewDashboardAggregationRepository, NewSettingRepository, NewOpsRepository, diff --git 
a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 7971c65f..7076f8c5 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -1242,11 +1242,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) { +func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) { +func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { return nil, errors.New("not implemented") } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index ff05b32a..050e724d 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -354,6 +354,9 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) { usage.GET("/stats", h.Admin.Usage.Stats) usage.GET("/search-users", h.Admin.Usage.SearchUsers) usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys) + usage.GET("/cleanup-tasks", h.Admin.Usage.ListCleanupTasks) + usage.POST("/cleanup-tasks", h.Admin.Usage.CreateCleanupTask) + usage.POST("/cleanup-tasks/:id/cancel", 
h.Admin.Usage.CancelCleanupTask) } } diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index d9ed5609..f1c07d5e 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -32,8 +32,8 @@ type UsageLogRepository interface { // Admin dashboard stats GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) - GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) - GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) + GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) + GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) @@ -272,7 +272,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou } dayStart := geminiDailyWindowStart(now) - stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil) + stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil) if err != 
nil { return nil, fmt.Errorf("get gemini usage stats failed: %w", err) } @@ -294,7 +294,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou // Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m) minuteStart := now.Truncate(time.Minute) minuteResetAt := minuteStart.Add(time.Minute) - minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil) + minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil) if err != nil { return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err) } diff --git a/backend/internal/service/dashboard_aggregation_service.go b/backend/internal/service/dashboard_aggregation_service.go index da5c0e7d..8f7e8144 100644 --- a/backend/internal/service/dashboard_aggregation_service.go +++ b/backend/internal/service/dashboard_aggregation_service.go @@ -21,11 +21,15 @@ var ( ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用") // ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。 ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大") + errDashboardAggregationRunning = errors.New("聚合作业正在运行") ) // DashboardAggregationRepository 定义仪表盘预聚合仓储接口。 type DashboardAggregationRepository interface { AggregateRange(ctx context.Context, start, end time.Time) error + // RecomputeRange 重新计算指定时间范围内的聚合数据(包含活跃用户等派生表)。 + // 设计目的:当 usage_logs 被批量删除/回滚后,确保聚合表可恢复一致性。 + RecomputeRange(ctx context.Context, start, end time.Time) error GetAggregationWatermark(ctx context.Context) (time.Time, error) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error @@ -112,6 +116,41 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro return nil } +// TriggerRecomputeRange 触发指定范围的重新计算(异步)。 +// 与 TriggerBackfill 不同: +// - 不依赖 backfill_enabled(这是内部一致性修复) +// - 不更新 
watermark(避免影响正常增量聚合游标) +func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time) error { + if s == nil || s.repo == nil { + return errors.New("聚合服务未初始化") + } + if !s.cfg.Enabled { + return errors.New("聚合服务已禁用") + } + if !end.After(start) { + return errors.New("重新计算时间范围无效") + } + + go func() { + const maxRetries = 3 + for i := 0; i < maxRetries; i++ { + ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout) + err := s.recomputeRange(ctx, start, end) + cancel() + if err == nil { + return + } + if !errors.Is(err, errDashboardAggregationRunning) { + log.Printf("[DashboardAggregation] 重新计算失败: %v", err) + return + } + time.Sleep(5 * time.Second) + } + log.Printf("[DashboardAggregation] 重新计算放弃: 聚合作业持续占用") + }() + return nil +} + func (s *DashboardAggregationService) recomputeRecentDays() { days := s.cfg.RecomputeDays if days <= 0 { @@ -128,6 +167,24 @@ func (s *DashboardAggregationService) recomputeRecentDays() { } } +func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start, end time.Time) error { + if !atomic.CompareAndSwapInt32(&s.running, 0, 1) { + return errDashboardAggregationRunning + } + defer atomic.StoreInt32(&s.running, 0) + + jobStart := time.Now().UTC() + if err := s.repo.RecomputeRange(ctx, start, end); err != nil { + return err + } + log.Printf("[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)", + start.UTC().Format(time.RFC3339), + end.UTC().Format(time.RFC3339), + time.Since(jobStart).String(), + ) + return nil +} + func (s *DashboardAggregationService) runScheduledAggregation() { if !atomic.CompareAndSwapInt32(&s.running, 0, 1) { return @@ -179,7 +236,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() { func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error { if !atomic.CompareAndSwapInt32(&s.running, 0, 1) { - return errors.New("聚合作业正在运行") + return errDashboardAggregationRunning } defer 
atomic.StoreInt32(&s.running, 0) diff --git a/backend/internal/service/dashboard_aggregation_service_test.go b/backend/internal/service/dashboard_aggregation_service_test.go index 2fc22105..a7058985 100644 --- a/backend/internal/service/dashboard_aggregation_service_test.go +++ b/backend/internal/service/dashboard_aggregation_service_test.go @@ -27,6 +27,10 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s return s.aggregateErr } +func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error { + return s.AggregateRange(ctx, start, end) +} + func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) { return s.watermark, nil } diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index a9811919..cd11923e 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D return stats, nil } -func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) { - trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream) +func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { + trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType) if err != nil { return nil, fmt.Errorf("get usage trend with filters: %w", err) } 
return trend, nil } -func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) { - stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream) +func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType) if err != nil { return nil, fmt.Errorf("get model stats with filters: %w", err) } diff --git a/backend/internal/service/dashboard_service_test.go b/backend/internal/service/dashboard_service_test.go index db3c78c3..59b83e66 100644 --- a/backend/internal/service/dashboard_service_test.go +++ b/backend/internal/service/dashboard_service_test.go @@ -101,6 +101,10 @@ func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start return nil } +func (s *dashboardAggregationRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error { + return nil +} + func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) { if s.err != nil { return time.Time{}, s.err diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 47a04cf5..2d75dd5a 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -190,7 +190,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, start := geminiDailyWindowStart(now) totals, ok := s.getGeminiUsageTotals(account.ID, start, now) if !ok { - stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil) + stats, err := 
s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil) if err != nil { return true, err } @@ -237,7 +237,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, if limit > 0 { start := now.Truncate(time.Minute) - stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil) + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil) if err != nil { return true, err } diff --git a/backend/internal/service/usage_cleanup.go b/backend/internal/service/usage_cleanup.go new file mode 100644 index 00000000..7e3ffbb9 --- /dev/null +++ b/backend/internal/service/usage_cleanup.go @@ -0,0 +1,74 @@ +package service + +import ( + "context" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +const ( + UsageCleanupStatusPending = "pending" + UsageCleanupStatusRunning = "running" + UsageCleanupStatusSucceeded = "succeeded" + UsageCleanupStatusFailed = "failed" + UsageCleanupStatusCanceled = "canceled" +) + +// UsageCleanupFilters 定义清理任务过滤条件 +// 时间范围为必填,其他字段可选 +// JSON 序列化用于存储任务参数 +// +// start_time/end_time 使用 RFC3339 时间格式 +// 以 UTC 或用户时区解析后的时间为准 +// +// 说明: +// - nil 表示未设置该过滤条件 +// - 过滤条件均为精确匹配 +type UsageCleanupFilters struct { + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` + UserID *int64 `json:"user_id,omitempty"` + APIKeyID *int64 `json:"api_key_id,omitempty"` + AccountID *int64 `json:"account_id,omitempty"` + GroupID *int64 `json:"group_id,omitempty"` + Model *string `json:"model,omitempty"` + Stream *bool `json:"stream,omitempty"` + BillingType *int8 `json:"billing_type,omitempty"` +} + +// UsageCleanupTask 表示使用记录清理任务 +// 状态包含 pending/running/succeeded/failed/canceled +type UsageCleanupTask struct { + ID int64 + Status string + Filters UsageCleanupFilters + CreatedBy int64 + DeletedRows int64 + ErrorMsg *string + CanceledBy *int64 + CanceledAt *time.Time + StartedAt *time.Time + FinishedAt 
*time.Time + CreatedAt time.Time + UpdatedAt time.Time +} + +// UsageCleanupRepository 定义清理任务持久层接口 +type UsageCleanupRepository interface { + CreateTask(ctx context.Context, task *UsageCleanupTask) error + ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error) + // ClaimNextPendingTask 抢占下一条可执行任务: + // - 优先 pending + // - 若 running 超过 staleRunningAfterSeconds(可能由于进程退出/崩溃/超时),允许重新抢占继续执行 + ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*UsageCleanupTask, error) + // GetTaskStatus 查询任务状态;若不存在返回 sql.ErrNoRows + GetTaskStatus(ctx context.Context, taskID int64) (string, error) + // UpdateTaskProgress 更新任务进度(deleted_rows)用于断点续跑/展示 + UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error + // CancelTask 将任务标记为 canceled(仅允许 pending/running) + CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) + MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error + MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error + DeleteUsageLogsBatch(ctx context.Context, filters UsageCleanupFilters, limit int) (int64, error) +} diff --git a/backend/internal/service/usage_cleanup_service.go b/backend/internal/service/usage_cleanup_service.go new file mode 100644 index 00000000..8ca02cfc --- /dev/null +++ b/backend/internal/service/usage_cleanup_service.go @@ -0,0 +1,400 @@ +package service + +import ( + "context" + "database/sql" + "errors" + "fmt" + "log" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +const ( + usageCleanupWorkerName = "usage_cleanup_worker" +) + +// UsageCleanupService 负责创建与执行使用记录清理任务 +type UsageCleanupService struct { + repo UsageCleanupRepository + timingWheel *TimingWheelService + 
dashboard *DashboardAggregationService + cfg *config.Config + + running int32 + startOnce sync.Once + stopOnce sync.Once + + workerCtx context.Context + workerCancel context.CancelFunc +} + +func NewUsageCleanupService(repo UsageCleanupRepository, timingWheel *TimingWheelService, dashboard *DashboardAggregationService, cfg *config.Config) *UsageCleanupService { + workerCtx, workerCancel := context.WithCancel(context.Background()) + return &UsageCleanupService{ + repo: repo, + timingWheel: timingWheel, + dashboard: dashboard, + cfg: cfg, + workerCtx: workerCtx, + workerCancel: workerCancel, + } +} + +func describeUsageCleanupFilters(filters UsageCleanupFilters) string { + var parts []string + parts = append(parts, "start="+filters.StartTime.UTC().Format(time.RFC3339)) + parts = append(parts, "end="+filters.EndTime.UTC().Format(time.RFC3339)) + if filters.UserID != nil { + parts = append(parts, fmt.Sprintf("user_id=%d", *filters.UserID)) + } + if filters.APIKeyID != nil { + parts = append(parts, fmt.Sprintf("api_key_id=%d", *filters.APIKeyID)) + } + if filters.AccountID != nil { + parts = append(parts, fmt.Sprintf("account_id=%d", *filters.AccountID)) + } + if filters.GroupID != nil { + parts = append(parts, fmt.Sprintf("group_id=%d", *filters.GroupID)) + } + if filters.Model != nil { + parts = append(parts, "model="+strings.TrimSpace(*filters.Model)) + } + if filters.Stream != nil { + parts = append(parts, fmt.Sprintf("stream=%t", *filters.Stream)) + } + if filters.BillingType != nil { + parts = append(parts, fmt.Sprintf("billing_type=%d", *filters.BillingType)) + } + return strings.Join(parts, " ") +} + +func (s *UsageCleanupService) Start() { + if s == nil { + return + } + if s.cfg != nil && !s.cfg.UsageCleanup.Enabled { + log.Printf("[UsageCleanup] not started (disabled)") + return + } + if s.repo == nil || s.timingWheel == nil { + log.Printf("[UsageCleanup] not started (missing deps)") + return + } + + interval := s.workerInterval() + s.startOnce.Do(func() { + 
s.timingWheel.ScheduleRecurring(usageCleanupWorkerName, interval, s.runOnce) + log.Printf("[UsageCleanup] started (interval=%s max_range_days=%d batch_size=%d task_timeout=%s)", interval, s.maxRangeDays(), s.batchSize(), s.taskTimeout()) + }) +} + +func (s *UsageCleanupService) Stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + if s.workerCancel != nil { + s.workerCancel() + } + if s.timingWheel != nil { + s.timingWheel.Cancel(usageCleanupWorkerName) + } + log.Printf("[UsageCleanup] stopped") + }) +} + +func (s *UsageCleanupService) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error) { + if s == nil || s.repo == nil { + return nil, nil, fmt.Errorf("cleanup service not ready") + } + return s.repo.ListTasks(ctx, params) +} + +func (s *UsageCleanupService) CreateTask(ctx context.Context, filters UsageCleanupFilters, createdBy int64) (*UsageCleanupTask, error) { + if s == nil || s.repo == nil { + return nil, fmt.Errorf("cleanup service not ready") + } + if s.cfg != nil && !s.cfg.UsageCleanup.Enabled { + return nil, infraerrors.New(http.StatusServiceUnavailable, "USAGE_CLEANUP_DISABLED", "usage cleanup is disabled") + } + if createdBy <= 0 { + return nil, infraerrors.BadRequest("USAGE_CLEANUP_INVALID_CREATOR", "invalid creator") + } + + log.Printf("[UsageCleanup] create_task requested: operator=%d %s", createdBy, describeUsageCleanupFilters(filters)) + sanitizeUsageCleanupFilters(&filters) + if err := s.validateFilters(filters); err != nil { + log.Printf("[UsageCleanup] create_task rejected: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters)) + return nil, err + } + + task := &UsageCleanupTask{ + Status: UsageCleanupStatusPending, + Filters: filters, + CreatedBy: createdBy, + } + if err := s.repo.CreateTask(ctx, task); err != nil { + log.Printf("[UsageCleanup] create_task persist failed: operator=%d err=%v %s", createdBy, err, 
describeUsageCleanupFilters(filters)) + return nil, fmt.Errorf("create cleanup task: %w", err) + } + log.Printf("[UsageCleanup] create_task persisted: task=%d operator=%d status=%s deleted_rows=%d %s", task.ID, createdBy, task.Status, task.DeletedRows, describeUsageCleanupFilters(filters)) + go s.runOnce() + return task, nil +} + +func (s *UsageCleanupService) runOnce() { + if !atomic.CompareAndSwapInt32(&s.running, 0, 1) { + log.Printf("[UsageCleanup] run_once skipped: already_running=true") + return + } + defer atomic.StoreInt32(&s.running, 0) + + parent := context.Background() + if s != nil && s.workerCtx != nil { + parent = s.workerCtx + } + ctx, cancel := context.WithTimeout(parent, s.taskTimeout()) + defer cancel() + + task, err := s.repo.ClaimNextPendingTask(ctx, int64(s.taskTimeout().Seconds())) + if err != nil { + log.Printf("[UsageCleanup] claim pending task failed: %v", err) + return + } + if task == nil { + log.Printf("[UsageCleanup] run_once done: no_task=true") + return + } + + log.Printf("[UsageCleanup] task claimed: task=%d status=%s created_by=%d deleted_rows=%d %s", task.ID, task.Status, task.CreatedBy, task.DeletedRows, describeUsageCleanupFilters(task.Filters)) + s.executeTask(ctx, task) +} + +func (s *UsageCleanupService) executeTask(ctx context.Context, task *UsageCleanupTask) { + if task == nil { + return + } + + batchSize := s.batchSize() + deletedTotal := task.DeletedRows + start := time.Now() + log.Printf("[UsageCleanup] task started: task=%d batch_size=%d deleted_rows=%d %s", task.ID, batchSize, deletedTotal, describeUsageCleanupFilters(task.Filters)) + var batchNum int + + for { + if ctx != nil && ctx.Err() != nil { + log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, ctx.Err()) + return + } + canceled, err := s.isTaskCanceled(ctx, task.ID) + if err != nil { + s.markTaskFailed(task.ID, deletedTotal, err) + return + } + if canceled { + log.Printf("[UsageCleanup] task canceled: task=%d deleted_rows=%d duration=%s", 
task.ID, deletedTotal, time.Since(start)) + return + } + + batchNum++ + deleted, err := s.repo.DeleteUsageLogsBatch(ctx, task.Filters, batchSize) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + // 任务被中断(例如服务停止/超时),保持 running 状态,后续通过 stale reclaim 续跑。 + log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, err) + return + } + s.markTaskFailed(task.ID, deletedTotal, err) + return + } + deletedTotal += deleted + if deleted > 0 { + updateCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + if err := s.repo.UpdateTaskProgress(updateCtx, task.ID, deletedTotal); err != nil { + log.Printf("[UsageCleanup] task progress update failed: task=%d deleted_rows=%d err=%v", task.ID, deletedTotal, err) + } + cancel() + } + if batchNum <= 3 || batchNum%20 == 0 || deleted < int64(batchSize) { + log.Printf("[UsageCleanup] task batch done: task=%d batch=%d deleted=%d deleted_total=%d", task.ID, batchNum, deleted, deletedTotal) + } + if deleted == 0 || deleted < int64(batchSize) { + break + } + } + + updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.repo.MarkTaskSucceeded(updateCtx, task.ID, deletedTotal); err != nil { + log.Printf("[UsageCleanup] update task succeeded failed: task=%d err=%v", task.ID, err) + } else { + log.Printf("[UsageCleanup] task succeeded: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start)) + } + + if s.dashboard != nil { + if err := s.dashboard.TriggerRecomputeRange(task.Filters.StartTime, task.Filters.EndTime); err != nil { + log.Printf("[UsageCleanup] trigger dashboard recompute failed: task=%d err=%v", task.ID, err) + } else { + log.Printf("[UsageCleanup] trigger dashboard recompute: task=%d start=%s end=%s", task.ID, task.Filters.StartTime.UTC().Format(time.RFC3339), task.Filters.EndTime.UTC().Format(time.RFC3339)) + } + } +} + +func (s *UsageCleanupService) markTaskFailed(taskID 
int64, deletedRows int64, err error) { + msg := strings.TrimSpace(err.Error()) + if len(msg) > 500 { + msg = msg[:500] + } + log.Printf("[UsageCleanup] task failed: task=%d deleted_rows=%d err=%s", taskID, deletedRows, msg) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if updateErr := s.repo.MarkTaskFailed(ctx, taskID, deletedRows, msg); updateErr != nil { + log.Printf("[UsageCleanup] update task failed failed: task=%d err=%v", taskID, updateErr) + } +} + +func (s *UsageCleanupService) isTaskCanceled(ctx context.Context, taskID int64) (bool, error) { + if s == nil || s.repo == nil { + return false, fmt.Errorf("cleanup service not ready") + } + checkCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + status, err := s.repo.GetTaskStatus(checkCtx, taskID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + return false, err + } + if status == UsageCleanupStatusCanceled { + log.Printf("[UsageCleanup] task cancel detected: task=%d", taskID) + } + return status == UsageCleanupStatusCanceled, nil +} + +func (s *UsageCleanupService) validateFilters(filters UsageCleanupFilters) error { + if filters.StartTime.IsZero() || filters.EndTime.IsZero() { + return infraerrors.BadRequest("USAGE_CLEANUP_MISSING_RANGE", "start_date and end_date are required") + } + if filters.EndTime.Before(filters.StartTime) { + return infraerrors.BadRequest("USAGE_CLEANUP_INVALID_RANGE", "end_date must be after start_date") + } + maxDays := s.maxRangeDays() + if maxDays > 0 { + delta := filters.EndTime.Sub(filters.StartTime) + if delta > time.Duration(maxDays)*24*time.Hour { + return infraerrors.BadRequest("USAGE_CLEANUP_RANGE_TOO_LARGE", fmt.Sprintf("date range exceeds %d days", maxDays)) + } + } + return nil +} + +func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canceledBy int64) error { + if s == nil || s.repo == nil { + return fmt.Errorf("cleanup service 
not ready") + } + if s.cfg != nil && !s.cfg.UsageCleanup.Enabled { + return infraerrors.New(http.StatusServiceUnavailable, "USAGE_CLEANUP_DISABLED", "usage cleanup is disabled") + } + if canceledBy <= 0 { + return infraerrors.BadRequest("USAGE_CLEANUP_INVALID_CANCELLER", "invalid canceller") + } + status, err := s.repo.GetTaskStatus(ctx, taskID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return infraerrors.New(http.StatusNotFound, "USAGE_CLEANUP_TASK_NOT_FOUND", "cleanup task not found") + } + return err + } + log.Printf("[UsageCleanup] cancel_task requested: task=%d operator=%d status=%s", taskID, canceledBy, status) + if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning { + return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status") + } + ok, err := s.repo.CancelTask(ctx, taskID, canceledBy) + if err != nil { + return err + } + if !ok { + // 状态可能并发改变 + return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status") + } + log.Printf("[UsageCleanup] cancel_task done: task=%d operator=%d", taskID, canceledBy) + return nil +} + +func sanitizeUsageCleanupFilters(filters *UsageCleanupFilters) { + if filters == nil { + return + } + if filters.UserID != nil && *filters.UserID <= 0 { + filters.UserID = nil + } + if filters.APIKeyID != nil && *filters.APIKeyID <= 0 { + filters.APIKeyID = nil + } + if filters.AccountID != nil && *filters.AccountID <= 0 { + filters.AccountID = nil + } + if filters.GroupID != nil && *filters.GroupID <= 0 { + filters.GroupID = nil + } + if filters.Model != nil { + model := strings.TrimSpace(*filters.Model) + if model == "" { + filters.Model = nil + } else { + filters.Model = &model + } + } + if filters.BillingType != nil && *filters.BillingType < 0 { + filters.BillingType = nil + } +} + +func (s *UsageCleanupService) maxRangeDays() int { + if s == nil || s.cfg == nil { + 
return 31 + } + if s.cfg.UsageCleanup.MaxRangeDays > 0 { + return s.cfg.UsageCleanup.MaxRangeDays + } + return 31 +} + +func (s *UsageCleanupService) batchSize() int { + if s == nil || s.cfg == nil { + return 5000 + } + if s.cfg.UsageCleanup.BatchSize > 0 { + return s.cfg.UsageCleanup.BatchSize + } + return 5000 +} + +func (s *UsageCleanupService) workerInterval() time.Duration { + if s == nil || s.cfg == nil { + return 10 * time.Second + } + if s.cfg.UsageCleanup.WorkerIntervalSeconds > 0 { + return time.Duration(s.cfg.UsageCleanup.WorkerIntervalSeconds) * time.Second + } + return 10 * time.Second +} + +func (s *UsageCleanupService) taskTimeout() time.Duration { + if s == nil || s.cfg == nil { + return 30 * time.Minute + } + if s.cfg.UsageCleanup.TaskTimeoutSeconds > 0 { + return time.Duration(s.cfg.UsageCleanup.TaskTimeoutSeconds) * time.Second + } + return 30 * time.Minute +} diff --git a/backend/internal/service/usage_cleanup_service_test.go b/backend/internal/service/usage_cleanup_service_test.go new file mode 100644 index 00000000..37d3eb19 --- /dev/null +++ b/backend/internal/service/usage_cleanup_service_test.go @@ -0,0 +1,420 @@ +package service + +import ( + "context" + "database/sql" + "errors" + "net/http" + "strings" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type cleanupDeleteResponse struct { + deleted int64 + err error +} + +type cleanupDeleteCall struct { + filters UsageCleanupFilters + limit int +} + +type cleanupMarkCall struct { + taskID int64 + deletedRows int64 + errMsg string +} + +type cleanupRepoStub struct { + mu sync.Mutex + created []*UsageCleanupTask + createErr error + listTasks []UsageCleanupTask + listResult *pagination.PaginationResult + listErr error + claimQueue []*UsageCleanupTask + claimErr error + deleteQueue 
[]cleanupDeleteResponse + deleteCalls []cleanupDeleteCall + markSucceeded []cleanupMarkCall + markFailed []cleanupMarkCall + statusByID map[int64]string + progressCalls []cleanupMarkCall + cancelCalls []int64 +} + +func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *UsageCleanupTask) error { + if task == nil { + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + if s.createErr != nil { + return s.createErr + } + if task.ID == 0 { + task.ID = int64(len(s.created) + 1) + } + if task.CreatedAt.IsZero() { + task.CreatedAt = time.Now().UTC() + } + if task.UpdatedAt.IsZero() { + task.UpdatedAt = task.CreatedAt + } + clone := *task + s.created = append(s.created, &clone) + return nil +} + +func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + return s.listTasks, s.listResult, s.listErr +} + +func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*UsageCleanupTask, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.claimErr != nil { + return nil, s.claimErr + } + if len(s.claimQueue) == 0 { + return nil, nil + } + task := s.claimQueue[0] + s.claimQueue = s.claimQueue[1:] + if s.statusByID == nil { + s.statusByID = map[int64]string{} + } + s.statusByID[task.ID] = UsageCleanupStatusRunning + return task, nil +} + +func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.statusByID == nil { + return "", sql.ErrNoRows + } + status, ok := s.statusByID[taskID] + if !ok { + return "", sql.ErrNoRows + } + return status, nil +} + +func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error { + s.mu.Lock() + defer s.mu.Unlock() + s.progressCalls = append(s.progressCalls, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows}) + return nil +} + +func (s 
*cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.cancelCalls = append(s.cancelCalls, taskID) + if s.statusByID == nil { + s.statusByID = map[int64]string{} + } + status := s.statusByID[taskID] + if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning { + return false, nil + } + s.statusByID[taskID] = UsageCleanupStatusCanceled + return true, nil +} + +func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error { + s.mu.Lock() + defer s.mu.Unlock() + s.markSucceeded = append(s.markSucceeded, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows}) + if s.statusByID == nil { + s.statusByID = map[int64]string{} + } + s.statusByID[taskID] = UsageCleanupStatusSucceeded + return nil +} + +func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.markFailed = append(s.markFailed, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows, errMsg: errorMsg}) + if s.statusByID == nil { + s.statusByID = map[int64]string{} + } + s.statusByID[taskID] = UsageCleanupStatusFailed + return nil +} + +func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters UsageCleanupFilters, limit int) (int64, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.deleteCalls = append(s.deleteCalls, cleanupDeleteCall{filters: filters, limit: limit}) + if len(s.deleteQueue) == 0 { + return 0, nil + } + resp := s.deleteQueue[0] + s.deleteQueue = s.deleteQueue[1:] + return resp.deleted, resp.err +} + +func TestUsageCleanupServiceCreateTaskSanitizeFilters(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * 
time.Hour) + userID := int64(-1) + apiKeyID := int64(10) + model := " gpt-4 " + billingType := int8(-2) + filters := UsageCleanupFilters{ + StartTime: start, + EndTime: end, + UserID: &userID, + APIKeyID: &apiKeyID, + Model: &model, + BillingType: &billingType, + } + + task, err := svc.CreateTask(context.Background(), filters, 9) + require.NoError(t, err) + require.Equal(t, UsageCleanupStatusPending, task.Status) + require.Nil(t, task.Filters.UserID) + require.NotNil(t, task.Filters.APIKeyID) + require.Equal(t, apiKeyID, *task.Filters.APIKeyID) + require.NotNil(t, task.Filters.Model) + require.Equal(t, "gpt-4", *task.Filters.Model) + require.Nil(t, task.Filters.BillingType) + require.Equal(t, int64(9), task.CreatedBy) +} + +func TestUsageCleanupServiceCreateTaskInvalidCreator(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + filters := UsageCleanupFilters{ + StartTime: time.Now(), + EndTime: time.Now().Add(24 * time.Hour), + } + _, err := svc.CreateTask(context.Background(), filters, 0) + require.Error(t, err) + require.Equal(t, "USAGE_CLEANUP_INVALID_CREATOR", infraerrors.Reason(err)) +} + +func TestUsageCleanupServiceCreateTaskDisabled(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + filters := UsageCleanupFilters{ + StartTime: time.Now(), + EndTime: time.Now().Add(24 * time.Hour), + } + _, err := svc.CreateTask(context.Background(), filters, 1) + require.Error(t, err) + require.Equal(t, http.StatusServiceUnavailable, infraerrors.Code(err)) + require.Equal(t, "USAGE_CLEANUP_DISABLED", infraerrors.Reason(err)) +} + +func TestUsageCleanupServiceCreateTaskRangeTooLarge(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 
1}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(48 * time.Hour) + filters := UsageCleanupFilters{StartTime: start, EndTime: end} + + _, err := svc.CreateTask(context.Background(), filters, 1) + require.Error(t, err) + require.Equal(t, "USAGE_CLEANUP_RANGE_TOO_LARGE", infraerrors.Reason(err)) +} + +func TestUsageCleanupServiceCreateTaskMissingRange(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + _, err := svc.CreateTask(context.Background(), UsageCleanupFilters{}, 1) + require.Error(t, err) + require.Equal(t, "USAGE_CLEANUP_MISSING_RANGE", infraerrors.Reason(err)) +} + +func TestUsageCleanupServiceCreateTaskRepoError(t *testing.T) { + repo := &cleanupRepoStub{createErr: errors.New("db down")} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + filters := UsageCleanupFilters{ + StartTime: time.Now(), + EndTime: time.Now().Add(24 * time.Hour), + } + _, err := svc.CreateTask(context.Background(), filters, 1) + require.Error(t, err) + require.Contains(t, err.Error(), "create cleanup task") +} + +func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) { + repo := &cleanupRepoStub{ + claimQueue: []*UsageCleanupTask{ + {ID: 5, Filters: UsageCleanupFilters{StartTime: time.Now(), EndTime: time.Now().Add(2 * time.Hour)}}, + }, + deleteQueue: []cleanupDeleteResponse{ + {deleted: 2}, + {deleted: 2}, + {deleted: 1}, + }, + } + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2, TaskTimeoutSeconds: 30}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + svc.runOnce() + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.deleteCalls, 3) + require.Len(t, repo.markSucceeded, 1) + require.Empty(t, repo.markFailed) + 
require.Equal(t, int64(5), repo.markSucceeded[0].taskID) + require.Equal(t, int64(5), repo.markSucceeded[0].deletedRows) +} + +func TestUsageCleanupServiceRunOnceClaimError(t *testing.T) { + repo := &cleanupRepoStub{claimErr: errors.New("claim failed")} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + svc.runOnce() + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Empty(t, repo.markSucceeded) + require.Empty(t, repo.markFailed) +} + +func TestUsageCleanupServiceRunOnceAlreadyRunning(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + svc.running = 1 + svc.runOnce() +} + +func TestUsageCleanupServiceExecuteTaskFailed(t *testing.T) { + longMsg := strings.Repeat("x", 600) + repo := &cleanupRepoStub{ + deleteQueue: []cleanupDeleteResponse{ + {err: errors.New(longMsg)}, + }, + } + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 3}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + task := &UsageCleanupTask{ + ID: 11, + Filters: UsageCleanupFilters{ + StartTime: time.Now(), + EndTime: time.Now().Add(24 * time.Hour), + }, + } + + svc.executeTask(context.Background(), task) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.markFailed, 1) + require.Equal(t, int64(11), repo.markFailed[0].taskID) + require.Equal(t, 500, len(repo.markFailed[0].errMsg)) +} + +func TestUsageCleanupServiceListTasks(t *testing.T) { + repo := &cleanupRepoStub{ + listTasks: []UsageCleanupTask{{ID: 1}, {ID: 2}}, + listResult: &pagination.PaginationResult{ + Total: 2, + Page: 1, + PageSize: 20, + Pages: 1, + }, + } + svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}) + + tasks, result, err := svc.ListTasks(context.Background(), pagination.PaginationParams{Page: 
1, PageSize: 20}) + require.NoError(t, err) + require.Len(t, tasks, 2) + require.Equal(t, int64(2), result.Total) +} + +func TestUsageCleanupServiceListTasksNotReady(t *testing.T) { + var nilSvc *UsageCleanupService + _, _, err := nilSvc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}) + require.Error(t, err) + + svc := NewUsageCleanupService(nil, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}) + _, _, err = svc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}) + require.Error(t, err) +} + +func TestUsageCleanupServiceDefaultsAndLifecycle(t *testing.T) { + var nilSvc *UsageCleanupService + require.Equal(t, 31, nilSvc.maxRangeDays()) + require.Equal(t, 5000, nilSvc.batchSize()) + require.Equal(t, 10*time.Second, nilSvc.workerInterval()) + require.Equal(t, 30*time.Minute, nilSvc.taskTimeout()) + nilSvc.Start() + nilSvc.Stop() + + repo := &cleanupRepoStub{} + cfgDisabled := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}} + svcDisabled := NewUsageCleanupService(repo, nil, nil, cfgDisabled) + svcDisabled.Start() + svcDisabled.Stop() + + timingWheel, err := NewTimingWheelService() + require.NoError(t, err) + + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, WorkerIntervalSeconds: 5}} + svc := NewUsageCleanupService(repo, timingWheel, nil, cfg) + require.Equal(t, 5*time.Second, svc.workerInterval()) + svc.Start() + svc.Stop() + + cfgFallback := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svcFallback := NewUsageCleanupService(repo, timingWheel, nil, cfgFallback) + require.Equal(t, 31, svcFallback.maxRangeDays()) + require.Equal(t, 5000, svcFallback.batchSize()) + require.Equal(t, 10*time.Second, svcFallback.workerInterval()) + + svcMissingDeps := NewUsageCleanupService(nil, nil, nil, cfgFallback) + svcMissingDeps.Start() +} + +func TestSanitizeUsageCleanupFiltersModelEmpty(t *testing.T) { + 
model := " " + apiKeyID := int64(-5) + accountID := int64(-1) + groupID := int64(-2) + filters := UsageCleanupFilters{ + UserID: &apiKeyID, + APIKeyID: &apiKeyID, + AccountID: &accountID, + GroupID: &groupID, + Model: &model, + } + + sanitizeUsageCleanupFilters(&filters) + require.Nil(t, filters.UserID) + require.Nil(t, filters.APIKeyID) + require.Nil(t, filters.AccountID) + require.Nil(t, filters.GroupID) + require.Nil(t, filters.Model) +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index acc0a5fb..0b9bc20c 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -57,6 +57,13 @@ func ProvideDashboardAggregationService(repo DashboardAggregationRepository, tim return svc } +// ProvideUsageCleanupService 创建并启动使用记录清理任务服务 +func ProvideUsageCleanupService(repo UsageCleanupRepository, timingWheel *TimingWheelService, dashboardAgg *DashboardAggregationService, cfg *config.Config) *UsageCleanupService { + svc := NewUsageCleanupService(repo, timingWheel, dashboardAgg, cfg) + svc.Start() + return svc +} + // ProvideAccountExpiryService creates and starts AccountExpiryService. 
func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService { svc := NewAccountExpiryService(accountRepo, time.Minute) @@ -248,6 +255,7 @@ var ProviderSet = wire.NewSet( ProvideAccountExpiryService, ProvideTimingWheelService, ProvideDashboardAggregationService, + ProvideUsageCleanupService, ProvideDeferredService, NewAntigravityQuotaFetcher, NewUserAttributeService, diff --git a/backend/migrations/042_add_usage_cleanup_tasks.sql b/backend/migrations/042_add_usage_cleanup_tasks.sql new file mode 100644 index 00000000..ce4be91f --- /dev/null +++ b/backend/migrations/042_add_usage_cleanup_tasks.sql @@ -0,0 +1,21 @@ +-- 042_add_usage_cleanup_tasks.sql +-- 使用记录清理任务表 + +CREATE TABLE IF NOT EXISTS usage_cleanup_tasks ( + id BIGSERIAL PRIMARY KEY, + status VARCHAR(20) NOT NULL, + filters JSONB NOT NULL, + created_by BIGINT NOT NULL REFERENCES users(id) ON DELETE RESTRICT, + deleted_rows BIGINT NOT NULL DEFAULT 0, + error_message TEXT, + started_at TIMESTAMPTZ, + finished_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_usage_cleanup_tasks_status_created_at + ON usage_cleanup_tasks(status, created_at DESC); + +CREATE INDEX IF NOT EXISTS idx_usage_cleanup_tasks_created_at + ON usage_cleanup_tasks(created_at DESC); diff --git a/backend/migrations/043_add_usage_cleanup_cancel_audit.sql b/backend/migrations/043_add_usage_cleanup_cancel_audit.sql new file mode 100644 index 00000000..42ca6696 --- /dev/null +++ b/backend/migrations/043_add_usage_cleanup_cancel_audit.sql @@ -0,0 +1,10 @@ +-- 043_add_usage_cleanup_cancel_audit.sql +-- usage_cleanup_tasks 取消任务审计字段 + +ALTER TABLE usage_cleanup_tasks + ADD COLUMN IF NOT EXISTS canceled_by BIGINT REFERENCES users(id) ON DELETE SET NULL, + ADD COLUMN IF NOT EXISTS canceled_at TIMESTAMPTZ; + +CREATE INDEX IF NOT EXISTS idx_usage_cleanup_tasks_canceled_at + ON usage_cleanup_tasks(canceled_at DESC); + diff --git 
a/config.yaml b/config.yaml index 424ce9eb..5e7513fb 100644 --- a/config.yaml +++ b/config.yaml @@ -251,6 +251,27 @@ dashboard_aggregation: # 日聚合保留天数 daily_days: 730 +# ============================================================================= +# Usage Cleanup Task Configuration +# 使用记录清理任务配置(重启生效) +# ============================================================================= +usage_cleanup: + # Enable cleanup task worker + # 启用清理任务执行器 + enabled: true + # Max date range (days) per task + # 单次任务最大时间跨度(天) + max_range_days: 31 + # Batch delete size + # 单批删除数量 + batch_size: 5000 + # Worker interval (seconds) + # 执行器轮询间隔(秒) + worker_interval_seconds: 10 + # Task execution timeout (seconds) + # 单次任务最大执行时长(秒) + task_timeout_seconds: 1800 + # ============================================================================= # Concurrency Wait Configuration # 并发等待配置 diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 9e85d1ff..1f4aa266 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -292,6 +292,27 @@ dashboard_aggregation: # 日聚合保留天数 daily_days: 730 +# ============================================================================= +# Usage Cleanup Task Configuration +# 使用记录清理任务配置(重启生效) +# ============================================================================= +usage_cleanup: + # Enable cleanup task worker + # 启用清理任务执行器 + enabled: true + # Max date range (days) per task + # 单次任务最大时间跨度(天) + max_range_days: 31 + # Batch delete size + # 单批删除数量 + batch_size: 5000 + # Worker interval (seconds) + # 执行器轮询间隔(秒) + worker_interval_seconds: 10 + # Task execution timeout (seconds) + # 单次任务最大执行时长(秒) + task_timeout_seconds: 1800 + # ============================================================================= # Concurrency Wait Configuration # 并发等待配置 diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts index 9b338788..ae48bec2 100644 --- a/frontend/src/api/admin/dashboard.ts +++ 
b/frontend/src/api/admin/dashboard.ts @@ -50,6 +50,7 @@ export interface TrendParams { account_id?: number group_id?: number stream?: boolean + billing_type?: number | null } export interface TrendResponse { @@ -78,6 +79,7 @@ export interface ModelStatsParams { account_id?: number group_id?: number stream?: boolean + billing_type?: number | null } export interface ModelStatsResponse { diff --git a/frontend/src/api/admin/usage.ts b/frontend/src/api/admin/usage.ts index dd85fc24..c271a2d0 100644 --- a/frontend/src/api/admin/usage.ts +++ b/frontend/src/api/admin/usage.ts @@ -31,6 +31,46 @@ export interface SimpleApiKey { user_id: number } +export interface UsageCleanupFilters { + start_time: string + end_time: string + user_id?: number + api_key_id?: number + account_id?: number + group_id?: number + model?: string | null + stream?: boolean | null + billing_type?: number | null +} + +export interface UsageCleanupTask { + id: number + status: string + filters: UsageCleanupFilters + created_by: number + deleted_rows: number + error_message?: string | null + canceled_by?: number | null + canceled_at?: string | null + started_at?: string | null + finished_at?: string | null + created_at: string + updated_at: string +} + +export interface CreateUsageCleanupTaskRequest { + start_date: string + end_date: string + user_id?: number + api_key_id?: number + account_id?: number + group_id?: number + model?: string | null + stream?: boolean | null + billing_type?: number | null + timezone?: string +} + export interface AdminUsageQueryParams extends UsageQueryParams { user_id?: number } @@ -108,11 +148,51 @@ export async function searchApiKeys(userId?: number, keyword?: string): Promise< return data } +/** + * List usage cleanup tasks (admin only) + * @param params - Query parameters for pagination + * @returns Paginated list of cleanup tasks + */ +export async function listCleanupTasks( + params: { page?: number; page_size?: number }, + options?: { signal?: AbortSignal } +): 
Promise> { + const { data } = await apiClient.get>('/admin/usage/cleanup-tasks', { + params, + signal: options?.signal + }) + return data +} + +/** + * Create a usage cleanup task (admin only) + * @param payload - Cleanup task parameters + * @returns Created cleanup task + */ +export async function createCleanupTask(payload: CreateUsageCleanupTaskRequest): Promise { + const { data } = await apiClient.post('/admin/usage/cleanup-tasks', payload) + return data +} + +/** + * Cancel a usage cleanup task (admin only) + * @param taskId - Task ID to cancel + */ +export async function cancelCleanupTask(taskId: number): Promise<{ id: number; status: string }> { + const { data } = await apiClient.post<{ id: number; status: string }>( + `/admin/usage/cleanup-tasks/${taskId}/cancel` + ) + return data +} + export const adminUsageAPI = { list, getStats, searchUsers, - searchApiKeys + searchApiKeys, + listCleanupTasks, + createCleanupTask, + cancelCleanupTask } export default adminUsageAPI diff --git a/frontend/src/components/admin/usage/UsageCleanupDialog.vue b/frontend/src/components/admin/usage/UsageCleanupDialog.vue new file mode 100644 index 00000000..4cd562e8 --- /dev/null +++ b/frontend/src/components/admin/usage/UsageCleanupDialog.vue @@ -0,0 +1,339 @@ + + + diff --git a/frontend/src/components/admin/usage/UsageFilters.vue b/frontend/src/components/admin/usage/UsageFilters.vue index 0926d83c..b17e0fdc 100644 --- a/frontend/src/components/admin/usage/UsageFilters.vue +++ b/frontend/src/components/admin/usage/UsageFilters.vue @@ -127,6 +127,12 @@ + +
@@ -147,10 +153,13 @@
-
+
+ @@ -174,16 +183,20 @@ interface Props { exporting: boolean startDate: string endDate: string + showActions?: boolean } -const props = defineProps<Props>() +const props = withDefaults(defineProps<Props>(), { + showActions: true +}) const emit = defineEmits([ 'update:modelValue', 'update:startDate', 'update:endDate', 'change', 'reset', - 'export' + 'export', + 'cleanup' ]) const { t } = useI18n() @@ -221,6 +234,12 @@ const streamTypeOptions = ref([ { value: false, label: t('usage.sync') } ]) +const billingTypeOptions = ref([ + { value: null, label: t('admin.usage.allBillingTypes') }, + { value: 0, label: t('admin.usage.billingTypeBalance') }, + { value: 1, label: t('admin.usage.billingTypeSubscription') } +]) + const emitChange = () => emit('change') const updateStartDate = (value: string) => { diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index e4fe1bd1..2a000d0b 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1893,7 +1893,43 @@ export default { cacheCreationTokens: 'Cache Creation Tokens', cacheReadTokens: 'Cache Read Tokens', failedToLoad: 'Failed to load usage records', - ipAddress: 'IP' + billingType: 'Billing Type', + allBillingTypes: 'All Billing Types', + billingTypeBalance: 'Balance', + billingTypeSubscription: 'Subscription', + ipAddress: 'IP', + cleanup: { + button: 'Cleanup', + title: 'Cleanup Usage Records', + warning: 'Cleanup is irreversible and will affect historical stats.', + submit: 'Submit Cleanup', + submitting: 'Submitting...', + confirmTitle: 'Confirm Cleanup', + confirmMessage: 'Are you sure you want to submit this cleanup task? This action cannot be undone.', + confirmSubmit: 'Confirm Cleanup', + cancel: 'Cancel', + cancelConfirmTitle: 'Confirm Cancel', + cancelConfirmMessage: 'Are you sure you want to cancel this cleanup task?', + cancelConfirm: 'Confirm Cancel', + cancelSuccess: 'Cleanup task canceled', + cancelFailed: 'Failed to cancel cleanup task', + recentTasks: 'Recent Cleanup Tasks', + loadingTasks: 'Loading tasks...', + noTasks: 'No cleanup tasks yet', + range: 'Range', + deletedRows: 'Deleted', + missingRange: 'Please select a date range', + submitSuccess: 'Cleanup task created', + submitFailed: 'Failed to create cleanup task', + loadFailed: 'Failed to load cleanup tasks', + status: { + pending: 'Pending', + running: 'Running', + succeeded: 'Succeeded', + failed: 'Failed', + canceled: 'Canceled' + } + } }, // Ops Monitoring diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 35242c69..0c27f7a3 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -2041,7 +2041,43 @@ export default { cacheCreationTokens: '缓存创建 Token', cacheReadTokens: '缓存读取 Token', failedToLoad: '加载使用记录失败', - ipAddress: 'IP' + billingType: '计费类型', + allBillingTypes: '全部计费类型', + billingTypeBalance: '钱包余额', + billingTypeSubscription: '订阅套餐', + ipAddress: 'IP', + cleanup: { + button: '清理', + title: '清理使用记录', + warning: '清理不可恢复,且会影响历史统计回看。', + submit: '提交清理', + submitting: '提交中...', + confirmTitle: '确认清理', + confirmMessage: '确定要提交清理任务吗?清理不可恢复。', + confirmSubmit: '确认清理', + cancel: '取消任务', + cancelConfirmTitle: '确认取消', + cancelConfirmMessage: '确定要取消该清理任务吗?', + cancelConfirm: '确认取消', + cancelSuccess: '清理任务已取消', + cancelFailed: '取消清理任务失败', + recentTasks: '最近清理任务', + loadingTasks: '正在加载任务...', + noTasks: '暂无清理任务', + range: '时间范围', + deletedRows: '删除数量', + missingRange: '请选择时间范围', + submitSuccess: '清理任务已创建', + submitFailed: '创建清理任务失败', + loadFailed: '加载清理任务失败', + status: { + pending: '待执行', + running: '执行中', + succeeded: '已完成', + failed: '失败', + canceled: 
'已取消' + } + } }, // Ops Monitoring diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 523033c2..1bb6e5d6 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -618,6 +618,7 @@ export interface UsageLog { actual_cost: number rate_multiplier: number account_rate_multiplier?: number | null + billing_type: number stream: boolean duration_ms: number @@ -642,6 +643,33 @@ export interface UsageLog { subscription?: UserSubscription } +export interface UsageCleanupFilters { + start_time: string + end_time: string + user_id?: number + api_key_id?: number + account_id?: number + group_id?: number + model?: string | null + stream?: boolean | null + billing_type?: number | null +} + +export interface UsageCleanupTask { + id: number + status: string + filters: UsageCleanupFilters + created_by: number + deleted_rows: number + error_message?: string | null + canceled_by?: number | null + canceled_at?: string | null + started_at?: string | null + finished_at?: string | null + created_at: string + updated_at: string +} + export interface RedeemCode { id: number code: string @@ -865,6 +893,7 @@ export interface UsageQueryParams { group_id?: number model?: string stream?: boolean + billing_type?: number | null start_date?: string end_date?: string } diff --git a/frontend/src/views/admin/UsageView.vue b/frontend/src/views/admin/UsageView.vue index 6f62f59e..40b63ec3 100644 --- a/frontend/src/views/admin/UsageView.vue +++ b/frontend/src/views/admin/UsageView.vue @@ -17,12 +17,19 @@
- + +