feat(usage): 添加清理任务与统计过滤

This commit is contained in:
yangjianbo
2026-01-18 10:52:18 +08:00
parent 74a3c74514
commit ef5a41057f
44 changed files with 4478 additions and 46 deletions

View File

@@ -70,6 +70,7 @@ func provideCleanup(
schedulerSnapshot *service.SchedulerSnapshotService, schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService, tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService, accountExpiry *service.AccountExpiryService,
usageCleanup *service.UsageCleanupService,
pricing *service.PricingService, pricing *service.PricingService,
emailQueue *service.EmailQueueService, emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService, billingCache *service.BillingCacheService,
@@ -123,6 +124,12 @@ func provideCleanup(
} }
return nil return nil
}}, }},
{"UsageCleanupService", func() error {
if usageCleanup != nil {
usageCleanup.Stop()
}
return nil
}},
{"TokenRefreshService", func() error { {"TokenRefreshService", func() error {
tokenRefresh.Stop() tokenRefresh.Stop()
return nil return nil

View File

@@ -153,7 +153,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo) updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
systemHandler := handler.ProvideSystemHandler(updateService) systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService) usageCleanupRepository := repository.NewUsageCleanupRepository(db)
usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService, usageCleanupService)
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client) userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
@@ -175,7 +177,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository) accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{ application := &Application{
Server: httpServer, Server: httpServer,
Cleanup: v, Cleanup: v,
@@ -208,6 +210,7 @@ func provideCleanup(
schedulerSnapshot *service.SchedulerSnapshotService, schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService, tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService, accountExpiry *service.AccountExpiryService,
usageCleanup *service.UsageCleanupService,
pricing *service.PricingService, pricing *service.PricingService,
emailQueue *service.EmailQueueService, emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService, billingCache *service.BillingCacheService,
@@ -260,6 +263,12 @@ func provideCleanup(
} }
return nil return nil
}}, }},
{"UsageCleanupService", func() error {
if usageCleanup != nil {
usageCleanup.Stop()
}
return nil
}},
{"TokenRefreshService", func() error { {"TokenRefreshService", func() error {
tokenRefresh.Stop() tokenRefresh.Stop()
return nil return nil

View File

@@ -31,6 +31,7 @@ require (
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect
dario.cat/mergo v1.0.2 // indirect dario.cat/mergo v1.0.2 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/agext/levenshtein v1.2.3 // indirect github.com/agext/levenshtein v1.2.3 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect github.com/andybalholm/brotli v1.2.0 // indirect

View File

@@ -141,6 +141,7 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=

View File

@@ -55,6 +55,7 @@ type Config struct {
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"` Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"` RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
@@ -489,6 +490,20 @@ type DashboardAggregationRetentionConfig struct {
DailyDays int `mapstructure:"daily_days"` DailyDays int `mapstructure:"daily_days"`
} }
// UsageCleanupConfig 使用记录清理任务配置
type UsageCleanupConfig struct {
// Enabled: 是否启用清理任务执行器
Enabled bool `mapstructure:"enabled"`
// MaxRangeDays: 单次任务允许的最大时间跨度(天)
MaxRangeDays int `mapstructure:"max_range_days"`
// BatchSize: 单批删除数量
BatchSize int `mapstructure:"batch_size"`
// WorkerIntervalSeconds: 后台任务轮询间隔(秒)
WorkerIntervalSeconds int `mapstructure:"worker_interval_seconds"`
// TaskTimeoutSeconds: 单次任务最大执行时长(秒)
TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"`
}
func NormalizeRunMode(value string) string { func NormalizeRunMode(value string) string {
normalized := strings.ToLower(strings.TrimSpace(value)) normalized := strings.ToLower(strings.TrimSpace(value))
switch normalized { switch normalized {
@@ -749,6 +764,13 @@ func setDefaults() {
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730) viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
viper.SetDefault("dashboard_aggregation.recompute_days", 2) viper.SetDefault("dashboard_aggregation.recompute_days", 2)
// Usage cleanup task
viper.SetDefault("usage_cleanup.enabled", true)
viper.SetDefault("usage_cleanup.max_range_days", 31)
viper.SetDefault("usage_cleanup.batch_size", 5000)
viper.SetDefault("usage_cleanup.worker_interval_seconds", 10)
viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800)
// Gateway // Gateway
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头LLM高负载时可能排队较久 viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", true) viper.SetDefault("gateway.log_upstream_error_body", true)
@@ -985,6 +1007,33 @@ func (c *Config) Validate() error {
return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative") return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative")
} }
} }
if c.UsageCleanup.Enabled {
if c.UsageCleanup.MaxRangeDays <= 0 {
return fmt.Errorf("usage_cleanup.max_range_days must be positive")
}
if c.UsageCleanup.BatchSize <= 0 {
return fmt.Errorf("usage_cleanup.batch_size must be positive")
}
if c.UsageCleanup.WorkerIntervalSeconds <= 0 {
return fmt.Errorf("usage_cleanup.worker_interval_seconds must be positive")
}
if c.UsageCleanup.TaskTimeoutSeconds <= 0 {
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be positive")
}
} else {
if c.UsageCleanup.MaxRangeDays < 0 {
return fmt.Errorf("usage_cleanup.max_range_days must be non-negative")
}
if c.UsageCleanup.BatchSize < 0 {
return fmt.Errorf("usage_cleanup.batch_size must be non-negative")
}
if c.UsageCleanup.WorkerIntervalSeconds < 0 {
return fmt.Errorf("usage_cleanup.worker_interval_seconds must be non-negative")
}
if c.UsageCleanup.TaskTimeoutSeconds < 0 {
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative")
}
}
if c.Gateway.MaxBodySize <= 0 { if c.Gateway.MaxBodySize <= 0 {
return fmt.Errorf("gateway.max_body_size must be positive") return fmt.Errorf("gateway.max_body_size must be positive")
} }

View File

@@ -280,3 +280,573 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
t.Fatalf("Validate() expected backfill_max_days error, got: %v", err) t.Fatalf("Validate() expected backfill_max_days error, got: %v", err)
} }
} }
func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if !cfg.UsageCleanup.Enabled {
t.Fatalf("UsageCleanup.Enabled = false, want true")
}
if cfg.UsageCleanup.MaxRangeDays != 31 {
t.Fatalf("UsageCleanup.MaxRangeDays = %d, want 31", cfg.UsageCleanup.MaxRangeDays)
}
if cfg.UsageCleanup.BatchSize != 5000 {
t.Fatalf("UsageCleanup.BatchSize = %d, want 5000", cfg.UsageCleanup.BatchSize)
}
if cfg.UsageCleanup.WorkerIntervalSeconds != 10 {
t.Fatalf("UsageCleanup.WorkerIntervalSeconds = %d, want 10", cfg.UsageCleanup.WorkerIntervalSeconds)
}
if cfg.UsageCleanup.TaskTimeoutSeconds != 1800 {
t.Fatalf("UsageCleanup.TaskTimeoutSeconds = %d, want 1800", cfg.UsageCleanup.TaskTimeoutSeconds)
}
}
func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.UsageCleanup.Enabled = true
cfg.UsageCleanup.MaxRangeDays = 0
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for usage_cleanup.max_range_days, got nil")
}
if !strings.Contains(err.Error(), "usage_cleanup.max_range_days") {
t.Fatalf("Validate() expected max_range_days error, got: %v", err)
}
}
func TestValidateUsageCleanupConfigDisabled(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.UsageCleanup.Enabled = false
cfg.UsageCleanup.BatchSize = -1
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for usage_cleanup.batch_size, got nil")
}
if !strings.Contains(err.Error(), "usage_cleanup.batch_size") {
t.Fatalf("Validate() expected batch_size error, got: %v", err)
}
}
func TestConfigAddressHelpers(t *testing.T) {
server := ServerConfig{Host: "127.0.0.1", Port: 9000}
if server.Address() != "127.0.0.1:9000" {
t.Fatalf("ServerConfig.Address() = %q", server.Address())
}
dbCfg := DatabaseConfig{
Host: "localhost",
Port: 5432,
User: "postgres",
Password: "",
DBName: "sub2api",
SSLMode: "disable",
}
if !strings.Contains(dbCfg.DSN(), "password=") {
} else {
t.Fatalf("DatabaseConfig.DSN() should not include password when empty")
}
dbCfg.Password = "secret"
if !strings.Contains(dbCfg.DSN(), "password=secret") {
t.Fatalf("DatabaseConfig.DSN() missing password")
}
dbCfg.Password = ""
if strings.Contains(dbCfg.DSNWithTimezone("UTC"), "password=") {
t.Fatalf("DatabaseConfig.DSNWithTimezone() should omit password when empty")
}
if !strings.Contains(dbCfg.DSNWithTimezone(""), "TimeZone=Asia/Shanghai") {
t.Fatalf("DatabaseConfig.DSNWithTimezone() should use default timezone")
}
if !strings.Contains(dbCfg.DSNWithTimezone("UTC"), "TimeZone=UTC") {
t.Fatalf("DatabaseConfig.DSNWithTimezone() should use provided timezone")
}
redis := RedisConfig{Host: "redis", Port: 6379}
if redis.Address() != "redis:6379" {
t.Fatalf("RedisConfig.Address() = %q", redis.Address())
}
}
func TestNormalizeStringSlice(t *testing.T) {
values := normalizeStringSlice([]string{" a ", "", "b", " ", "c"})
if len(values) != 3 || values[0] != "a" || values[1] != "b" || values[2] != "c" {
t.Fatalf("normalizeStringSlice() unexpected result: %#v", values)
}
if normalizeStringSlice(nil) != nil {
t.Fatalf("normalizeStringSlice(nil) expected nil slice")
}
}
func TestGetServerAddressFromEnv(t *testing.T) {
t.Setenv("SERVER_HOST", "127.0.0.1")
t.Setenv("SERVER_PORT", "9090")
address := GetServerAddress()
if address != "127.0.0.1:9090" {
t.Fatalf("GetServerAddress() = %q", address)
}
}
func TestValidateAbsoluteHTTPURL(t *testing.T) {
if err := ValidateAbsoluteHTTPURL("https://example.com/path"); err != nil {
t.Fatalf("ValidateAbsoluteHTTPURL valid url error: %v", err)
}
if err := ValidateAbsoluteHTTPURL(""); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject empty url")
}
if err := ValidateAbsoluteHTTPURL("/relative"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject relative url")
}
if err := ValidateAbsoluteHTTPURL("ftp://example.com"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject ftp scheme")
}
if err := ValidateAbsoluteHTTPURL("https://example.com/#frag"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject fragment")
}
}
func TestValidateFrontendRedirectURL(t *testing.T) {
if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil {
t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err)
}
if err := ValidateFrontendRedirectURL("https://example.com/auth"); err != nil {
t.Fatalf("ValidateFrontendRedirectURL absolute error: %v", err)
}
if err := ValidateFrontendRedirectURL("example.com/path"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject non-absolute url")
}
if err := ValidateFrontendRedirectURL("//evil.com"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject // prefix")
}
if err := ValidateFrontendRedirectURL("javascript:alert(1)"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject javascript scheme")
}
}
func TestWarnIfInsecureURL(t *testing.T) {
warnIfInsecureURL("test", "http://example.com")
warnIfInsecureURL("test", "bad://url")
}
func TestGenerateJWTSecretDefaultLength(t *testing.T) {
secret, err := generateJWTSecret(0)
if err != nil {
t.Fatalf("generateJWTSecret error: %v", err)
}
if len(secret) == 0 {
t.Fatalf("generateJWTSecret returned empty string")
}
}
func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Ops.Cleanup.Enabled = true
cfg.Ops.Cleanup.Schedule = ""
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for ops.cleanup.schedule")
}
if !strings.Contains(err.Error(), "ops.cleanup.schedule") {
t.Fatalf("Validate() expected ops.cleanup.schedule error, got: %v", err)
}
}
func TestValidateConcurrencyPingInterval(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Concurrency.PingInterval = 3
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for concurrency.ping_interval")
}
if !strings.Contains(err.Error(), "concurrency.ping_interval") {
t.Fatalf("Validate() expected concurrency.ping_interval error, got: %v", err)
}
}
func TestProvideConfig(t *testing.T) {
viper.Reset()
if _, err := ProvideConfig(); err != nil {
t.Fatalf("ProvideConfig() error: %v", err)
}
}
func TestValidateConfigWithLinuxDoEnabled(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Security.CSP.Enabled = true
cfg.Security.CSP.Policy = "default-src 'self'"
cfg.LinuxDo.Enabled = true
cfg.LinuxDo.ClientID = "client"
cfg.LinuxDo.ClientSecret = "secret"
cfg.LinuxDo.AuthorizeURL = "https://example.com/oauth2/authorize"
cfg.LinuxDo.TokenURL = "https://example.com/oauth2/token"
cfg.LinuxDo.UserInfoURL = "https://example.com/oauth2/userinfo"
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() unexpected error: %v", err)
}
}
func TestValidateJWTSecretStrength(t *testing.T) {
if !isWeakJWTSecret("change-me-in-production") {
t.Fatalf("isWeakJWTSecret should detect weak secret")
}
if isWeakJWTSecret("StrongSecretValue") {
t.Fatalf("isWeakJWTSecret should accept strong secret")
}
}
func TestGenerateJWTSecretWithLength(t *testing.T) {
secret, err := generateJWTSecret(16)
if err != nil {
t.Fatalf("generateJWTSecret error: %v", err)
}
if len(secret) == 0 {
t.Fatalf("generateJWTSecret returned empty string")
}
}
func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) {
if err := ValidateAbsoluteHTTPURL("https://"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host")
}
}
func TestValidateFrontendRedirectURLInvalidChars(t *testing.T) {
if err := ValidateFrontendRedirectURL("/auth/\ncallback"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject invalid chars")
}
if err := ValidateFrontendRedirectURL("http://"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject missing host")
}
if err := ValidateFrontendRedirectURL("mailto:user@example.com"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject mailto")
}
}
func TestWarnIfInsecureURLHTTPS(t *testing.T) {
warnIfInsecureURL("secure", "https://example.com")
}
func TestValidateConfigErrors(t *testing.T) {
buildValid := func(t *testing.T) *Config {
t.Helper()
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
return cfg
}
cases := []struct {
name string
mutate func(*Config)
wantErr string
}{
{
name: "jwt expire hour positive",
mutate: func(c *Config) { c.JWT.ExpireHour = 0 },
wantErr: "jwt.expire_hour must be positive",
},
{
name: "jwt expire hour max",
mutate: func(c *Config) { c.JWT.ExpireHour = 200 },
wantErr: "jwt.expire_hour must be <= 168",
},
{
name: "csp policy required",
mutate: func(c *Config) { c.Security.CSP.Enabled = true; c.Security.CSP.Policy = "" },
wantErr: "security.csp.policy",
},
{
name: "linuxdo client id required",
mutate: func(c *Config) {
c.LinuxDo.Enabled = true
c.LinuxDo.ClientID = ""
},
wantErr: "linuxdo_connect.client_id",
},
{
name: "linuxdo token auth method",
mutate: func(c *Config) {
c.LinuxDo.Enabled = true
c.LinuxDo.ClientID = "client"
c.LinuxDo.ClientSecret = "secret"
c.LinuxDo.AuthorizeURL = "https://example.com/authorize"
c.LinuxDo.TokenURL = "https://example.com/token"
c.LinuxDo.UserInfoURL = "https://example.com/userinfo"
c.LinuxDo.RedirectURL = "https://example.com/callback"
c.LinuxDo.FrontendRedirectURL = "/auth/callback"
c.LinuxDo.TokenAuthMethod = "invalid"
},
wantErr: "linuxdo_connect.token_auth_method",
},
{
name: "billing circuit breaker threshold",
mutate: func(c *Config) { c.Billing.CircuitBreaker.FailureThreshold = 0 },
wantErr: "billing.circuit_breaker.failure_threshold",
},
{
name: "billing circuit breaker reset",
mutate: func(c *Config) { c.Billing.CircuitBreaker.ResetTimeoutSeconds = 0 },
wantErr: "billing.circuit_breaker.reset_timeout_seconds",
},
{
name: "billing circuit breaker half open",
mutate: func(c *Config) { c.Billing.CircuitBreaker.HalfOpenRequests = 0 },
wantErr: "billing.circuit_breaker.half_open_requests",
},
{
name: "database max open conns",
mutate: func(c *Config) { c.Database.MaxOpenConns = 0 },
wantErr: "database.max_open_conns",
},
{
name: "database max lifetime",
mutate: func(c *Config) { c.Database.ConnMaxLifetimeMinutes = -1 },
wantErr: "database.conn_max_lifetime_minutes",
},
{
name: "database idle exceeds open",
mutate: func(c *Config) { c.Database.MaxIdleConns = c.Database.MaxOpenConns + 1 },
wantErr: "database.max_idle_conns cannot exceed",
},
{
name: "redis dial timeout",
mutate: func(c *Config) { c.Redis.DialTimeoutSeconds = 0 },
wantErr: "redis.dial_timeout_seconds",
},
{
name: "redis read timeout",
mutate: func(c *Config) { c.Redis.ReadTimeoutSeconds = 0 },
wantErr: "redis.read_timeout_seconds",
},
{
name: "redis write timeout",
mutate: func(c *Config) { c.Redis.WriteTimeoutSeconds = 0 },
wantErr: "redis.write_timeout_seconds",
},
{
name: "redis pool size",
mutate: func(c *Config) { c.Redis.PoolSize = 0 },
wantErr: "redis.pool_size",
},
{
name: "redis idle exceeds pool",
mutate: func(c *Config) { c.Redis.MinIdleConns = c.Redis.PoolSize + 1 },
wantErr: "redis.min_idle_conns cannot exceed",
},
{
name: "dashboard cache disabled negative",
mutate: func(c *Config) { c.Dashboard.Enabled = false; c.Dashboard.StatsTTLSeconds = -1 },
wantErr: "dashboard_cache.stats_ttl_seconds",
},
{
name: "dashboard cache fresh ttl positive",
mutate: func(c *Config) { c.Dashboard.Enabled = true; c.Dashboard.StatsFreshTTLSeconds = 0 },
wantErr: "dashboard_cache.stats_fresh_ttl_seconds",
},
{
name: "dashboard aggregation enabled interval",
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.IntervalSeconds = 0 },
wantErr: "dashboard_aggregation.interval_seconds",
},
{
name: "dashboard aggregation backfill positive",
mutate: func(c *Config) {
c.DashboardAgg.Enabled = true
c.DashboardAgg.BackfillEnabled = true
c.DashboardAgg.BackfillMaxDays = 0
},
wantErr: "dashboard_aggregation.backfill_max_days",
},
{
name: "dashboard aggregation retention",
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
wantErr: "dashboard_aggregation.retention.usage_logs_days",
},
{
name: "dashboard aggregation disabled interval",
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
wantErr: "dashboard_aggregation.interval_seconds",
},
{
name: "usage cleanup max range",
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.MaxRangeDays = 0 },
wantErr: "usage_cleanup.max_range_days",
},
{
name: "usage cleanup worker interval",
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.WorkerIntervalSeconds = 0 },
wantErr: "usage_cleanup.worker_interval_seconds",
},
{
name: "usage cleanup batch size",
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.BatchSize = 0 },
wantErr: "usage_cleanup.batch_size",
},
{
name: "usage cleanup disabled negative",
mutate: func(c *Config) { c.UsageCleanup.Enabled = false; c.UsageCleanup.BatchSize = -1 },
wantErr: "usage_cleanup.batch_size",
},
{
name: "gateway max body size",
mutate: func(c *Config) { c.Gateway.MaxBodySize = 0 },
wantErr: "gateway.max_body_size",
},
{
name: "gateway max idle conns",
mutate: func(c *Config) { c.Gateway.MaxIdleConns = 0 },
wantErr: "gateway.max_idle_conns",
},
{
name: "gateway max idle conns per host",
mutate: func(c *Config) { c.Gateway.MaxIdleConnsPerHost = 0 },
wantErr: "gateway.max_idle_conns_per_host",
},
{
name: "gateway idle timeout",
mutate: func(c *Config) { c.Gateway.IdleConnTimeoutSeconds = 0 },
wantErr: "gateway.idle_conn_timeout_seconds",
},
{
name: "gateway max upstream clients",
mutate: func(c *Config) { c.Gateway.MaxUpstreamClients = 0 },
wantErr: "gateway.max_upstream_clients",
},
{
name: "gateway client idle ttl",
mutate: func(c *Config) { c.Gateway.ClientIdleTTLSeconds = 0 },
wantErr: "gateway.client_idle_ttl_seconds",
},
{
name: "gateway concurrency slot ttl",
mutate: func(c *Config) { c.Gateway.ConcurrencySlotTTLMinutes = 0 },
wantErr: "gateway.concurrency_slot_ttl_minutes",
},
{
name: "gateway max conns per host",
mutate: func(c *Config) { c.Gateway.MaxConnsPerHost = -1 },
wantErr: "gateway.max_conns_per_host",
},
{
name: "gateway connection isolation",
mutate: func(c *Config) { c.Gateway.ConnectionPoolIsolation = "invalid" },
wantErr: "gateway.connection_pool_isolation",
},
{
name: "gateway stream keepalive range",
mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 },
wantErr: "gateway.stream_keepalive_interval",
},
{
name: "gateway stream data interval range",
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 },
wantErr: "gateway.stream_data_interval_timeout",
},
{
name: "gateway stream data interval negative",
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = -1 },
wantErr: "gateway.stream_data_interval_timeout must be non-negative",
},
{
name: "gateway max line size",
mutate: func(c *Config) { c.Gateway.MaxLineSize = 1024 },
wantErr: "gateway.max_line_size must be at least",
},
{
name: "gateway max line size negative",
mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 },
wantErr: "gateway.max_line_size must be non-negative",
},
{
name: "gateway scheduling sticky waiting",
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
wantErr: "gateway.scheduling.sticky_session_max_waiting",
},
{
name: "gateway scheduling outbox poll",
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxPollIntervalSeconds = 0 },
wantErr: "gateway.scheduling.outbox_poll_interval_seconds",
},
{
name: "gateway scheduling outbox failures",
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxLagRebuildFailures = 0 },
wantErr: "gateway.scheduling.outbox_lag_rebuild_failures",
},
{
name: "gateway outbox lag rebuild",
mutate: func(c *Config) {
c.Gateway.Scheduling.OutboxLagWarnSeconds = 10
c.Gateway.Scheduling.OutboxLagRebuildSeconds = 5
},
wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds",
},
{
name: "ops metrics collector ttl",
mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 },
wantErr: "ops.metrics_collector_cache.ttl",
},
{
name: "ops cleanup retention",
mutate: func(c *Config) { c.Ops.Cleanup.ErrorLogRetentionDays = -1 },
wantErr: "ops.cleanup.error_log_retention_days",
},
{
name: "ops cleanup minute retention",
mutate: func(c *Config) { c.Ops.Cleanup.MinuteMetricsRetentionDays = -1 },
wantErr: "ops.cleanup.minute_metrics_retention_days",
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
cfg := buildValid(t)
tt.mutate(cfg)
err := cfg.Validate()
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("Validate() error = %v, want %q", err, tt.wantErr)
}
})
}
}

View File

@@ -0,0 +1,262 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAdminRouter() (*gin.Engine, *stubAdminService) {
gin.SetMode(gin.TestMode)
router := gin.New()
adminSvc := newStubAdminService()
userHandler := NewUserHandler(adminSvc)
groupHandler := NewGroupHandler(adminSvc)
proxyHandler := NewProxyHandler(adminSvc)
redeemHandler := NewRedeemHandler(adminSvc)
router.GET("/api/v1/admin/users", userHandler.List)
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
router.POST("/api/v1/admin/users", userHandler.Create)
router.PUT("/api/v1/admin/users/:id", userHandler.Update)
router.DELETE("/api/v1/admin/users/:id", userHandler.Delete)
router.POST("/api/v1/admin/users/:id/balance", userHandler.UpdateBalance)
router.GET("/api/v1/admin/users/:id/api-keys", userHandler.GetUserAPIKeys)
router.GET("/api/v1/admin/users/:id/usage", userHandler.GetUserUsage)
router.GET("/api/v1/admin/groups", groupHandler.List)
router.GET("/api/v1/admin/groups/all", groupHandler.GetAll)
router.GET("/api/v1/admin/groups/:id", groupHandler.GetByID)
router.POST("/api/v1/admin/groups", groupHandler.Create)
router.PUT("/api/v1/admin/groups/:id", groupHandler.Update)
router.DELETE("/api/v1/admin/groups/:id", groupHandler.Delete)
router.GET("/api/v1/admin/groups/:id/stats", groupHandler.GetStats)
router.GET("/api/v1/admin/groups/:id/api-keys", groupHandler.GetGroupAPIKeys)
router.GET("/api/v1/admin/proxies", proxyHandler.List)
router.GET("/api/v1/admin/proxies/all", proxyHandler.GetAll)
router.GET("/api/v1/admin/proxies/:id", proxyHandler.GetByID)
router.POST("/api/v1/admin/proxies", proxyHandler.Create)
router.PUT("/api/v1/admin/proxies/:id", proxyHandler.Update)
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
router.GET("/api/v1/admin/redeem-codes", redeemHandler.List)
router.GET("/api/v1/admin/redeem-codes/:id", redeemHandler.GetByID)
router.POST("/api/v1/admin/redeem-codes", redeemHandler.Generate)
router.DELETE("/api/v1/admin/redeem-codes/:id", redeemHandler.Delete)
router.POST("/api/v1/admin/redeem-codes/batch-delete", redeemHandler.BatchDelete)
router.POST("/api/v1/admin/redeem-codes/:id/expire", redeemHandler.Expire)
router.GET("/api/v1/admin/redeem-codes/:id/stats", redeemHandler.GetStats)
return router, adminSvc
}
func TestUserHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/users?page=1&page_size=20", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2}
body, _ := json.Marshal(createBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
updateBody := map[string]any{"email": "updated@example.com"}
body, _ = json.Marshal(updateBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/users/1", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/users/1", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/balance", bytes.NewBufferString(`{"balance":1,"operation":"add"}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/api-keys", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/usage?period=today", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestGroupHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/all", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ := json.Marshal(map[string]any{"name": "new", "platform": "anthropic", "subscription_type": "standard"})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/groups", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ = json.Marshal(map[string]any{"name": "update"})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/groups/2", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/groups/2", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/stats", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/api-keys", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestProxyHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/all", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ := json.Marshal(map[string]any{"name": "proxy", "protocol": "http", "host": "localhost", "port": 8080})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ = json.Marshal(map[string]any{"name": "proxy2"})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/proxies/4", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/proxies/4", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/test", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/accounts", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestRedeemHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ := json.Marshal(map[string]any{"count": 1, "type": "balance", "value": 10})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/redeem-codes/5", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/5/expire", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5/stats", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}

View File

@@ -0,0 +1,134 @@
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestParseTimeRange(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req := httptest.NewRequest(http.MethodGet, "/?start_date=2024-01-01&end_date=2024-01-02&timezone=UTC", nil)
c.Request = req
start, end := parseTimeRange(c)
require.Equal(t, time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), start)
require.Equal(t, time.Date(2024, 1, 3, 0, 0, 0, 0, time.UTC), end)
req = httptest.NewRequest(http.MethodGet, "/?start_date=bad&timezone=UTC", nil)
c.Request = req
start, end = parseTimeRange(c)
require.False(t, start.IsZero())
require.False(t, end.IsZero())
}
func TestParseOpsViewParam(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/?view=excluded", nil)
require.Equal(t, opsListViewExcluded, parseOpsViewParam(c))
c2, _ := gin.CreateTestContext(w)
c2.Request = httptest.NewRequest(http.MethodGet, "/?view=all", nil)
require.Equal(t, opsListViewAll, parseOpsViewParam(c2))
c3, _ := gin.CreateTestContext(w)
c3.Request = httptest.NewRequest(http.MethodGet, "/?view=unknown", nil)
require.Equal(t, opsListViewErrors, parseOpsViewParam(c3))
require.Equal(t, "", parseOpsViewParam(nil))
}
func TestParseOpsDuration(t *testing.T) {
dur, ok := parseOpsDuration("1h")
require.True(t, ok)
require.Equal(t, time.Hour, dur)
_, ok = parseOpsDuration("invalid")
require.False(t, ok)
}
func TestParseOpsTimeRange(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
now := time.Now().UTC()
startStr := now.Add(-time.Hour).Format(time.RFC3339)
endStr := now.Format(time.RFC3339)
c.Request = httptest.NewRequest(http.MethodGet, "/?start_time="+startStr+"&end_time="+endStr, nil)
start, end, err := parseOpsTimeRange(c, "1h")
require.NoError(t, err)
require.True(t, start.Before(end))
c2, _ := gin.CreateTestContext(w)
c2.Request = httptest.NewRequest(http.MethodGet, "/?start_time=bad", nil)
_, _, err = parseOpsTimeRange(c2, "1h")
require.Error(t, err)
}
func TestParseOpsRealtimeWindow(t *testing.T) {
dur, label, ok := parseOpsRealtimeWindow("5m")
require.True(t, ok)
require.Equal(t, 5*time.Minute, dur)
require.Equal(t, "5min", label)
_, _, ok = parseOpsRealtimeWindow("invalid")
require.False(t, ok)
}
func TestPickThroughputBucketSeconds(t *testing.T) {
require.Equal(t, 60, pickThroughputBucketSeconds(30*time.Minute))
require.Equal(t, 300, pickThroughputBucketSeconds(6*time.Hour))
require.Equal(t, 3600, pickThroughputBucketSeconds(48*time.Hour))
}
func TestParseOpsQueryMode(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/?mode=raw", nil)
require.Equal(t, service.ParseOpsQueryMode("raw"), parseOpsQueryMode(c))
require.Equal(t, service.OpsQueryMode(""), parseOpsQueryMode(nil))
}
func TestOpsAlertRuleValidation(t *testing.T) {
raw := map[string]json.RawMessage{
"name": json.RawMessage(`"High error rate"`),
"metric_type": json.RawMessage(`"error_rate"`),
"operator": json.RawMessage(`">"`),
"threshold": json.RawMessage(`90`),
}
validated, err := validateOpsAlertRulePayload(raw)
require.NoError(t, err)
require.Equal(t, "High error rate", validated.Name)
_, err = validateOpsAlertRulePayload(map[string]json.RawMessage{})
require.Error(t, err)
require.True(t, isPercentOrRateMetric("error_rate"))
require.False(t, isPercentOrRateMetric("concurrency_queue_depth"))
}
func TestOpsWSHelpers(t *testing.T) {
prefixes, invalid := parseTrustedProxyList("10.0.0.0/8,invalid")
require.Len(t, prefixes, 1)
require.Len(t, invalid, 1)
host := hostWithoutPort("example.com:443")
require.Equal(t, "example.com", host)
addr := netip.MustParseAddr("10.0.0.1")
require.True(t, isAddrInTrustedProxies(addr, prefixes))
require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
}

View File

@@ -0,0 +1,290 @@
package admin
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type stubAdminService struct {
users []service.User
apiKeys []service.APIKey
groups []service.Group
accounts []service.Account
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
}
func newStubAdminService() *stubAdminService {
now := time.Now().UTC()
user := service.User{
ID: 1,
Email: "user@example.com",
Role: service.RoleUser,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
apiKey := service.APIKey{
ID: 10,
UserID: user.ID,
Key: "sk-test",
Name: "test",
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
group := service.Group{
ID: 2,
Name: "group",
Platform: service.PlatformAnthropic,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
account := service.Account{
ID: 3,
Name: "account",
Platform: service.PlatformAnthropic,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
proxy := service.Proxy{
ID: 4,
Name: "proxy",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
redeem := service.RedeemCode{
ID: 5,
Code: "R-TEST",
Type: service.RedeemTypeBalance,
Value: 10,
Status: service.StatusUnused,
CreatedAt: now,
}
return &stubAdminService{
users: []service.User{user},
apiKeys: []service.APIKey{apiKey},
groups: []service.Group{group},
accounts: []service.Account{account},
proxies: []service.Proxy{proxy},
proxyCounts: []service.ProxyWithAccountCount{{Proxy: proxy, AccountCount: 1}},
redeems: []service.RedeemCode{redeem},
}
}
func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters) ([]service.User, int64, error) {
return s.users, int64(len(s.users)), nil
}
func (s *stubAdminService) GetUser(ctx context.Context, id int64) (*service.User, error) {
for i := range s.users {
if s.users[i].ID == id {
return &s.users[i], nil
}
}
user := service.User{ID: id, Email: "user@example.com", Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) CreateUser(ctx context.Context, input *service.CreateUserInput) (*service.User, error) {
user := service.User{ID: 100, Email: input.Email, Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) UpdateUser(ctx context.Context, id int64, input *service.UpdateUserInput) (*service.User, error) {
user := service.User{ID: id, Email: "updated@example.com", Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) DeleteUser(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*service.User, error) {
user := service.User{ID: userID, Balance: balance, Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]service.APIKey, int64, error) {
return s.apiKeys, int64(len(s.apiKeys)), nil
}
func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
return map[string]any{"user_id": userID}, nil
}
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]service.Group, int64, error) {
return s.groups, int64(len(s.groups)), nil
}
func (s *stubAdminService) GetAllGroups(ctx context.Context) ([]service.Group, error) {
return s.groups, nil
}
func (s *stubAdminService) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
return s.groups, nil
}
func (s *stubAdminService) GetGroup(ctx context.Context, id int64) (*service.Group, error) {
group := service.Group{ID: id, Name: "group", Status: service.StatusActive}
return &group, nil
}
func (s *stubAdminService) CreateGroup(ctx context.Context, input *service.CreateGroupInput) (*service.Group, error) {
group := service.Group{ID: 200, Name: input.Name, Status: service.StatusActive}
return &group, nil
}
func (s *stubAdminService) UpdateGroup(ctx context.Context, id int64, input *service.UpdateGroupInput) (*service.Group, error) {
group := service.Group{ID: id, Name: input.Name, Status: service.StatusActive}
return &group, nil
}
func (s *stubAdminService) DeleteGroup(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]service.APIKey, int64, error) {
return s.apiKeys, int64(len(s.apiKeys)), nil
}
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) {
return s.accounts, int64(len(s.accounts)), nil
}
func (s *stubAdminService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
out := make([]*service.Account, 0, len(ids))
for _, id := range ids {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
out = append(out, &account)
}
return out, nil
}
func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.CreateAccountInput) (*service.Account, error) {
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) DeleteAccount(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) RefreshAccountCredentials(ctx context.Context, id int64) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) ClearAccountError(ctx context.Context, id int64) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive, Schedulable: schedulable}
return &account, nil
}
func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) {
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
}
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
return s.proxies, int64(len(s.proxies)), nil
}
func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) {
return s.proxyCounts, int64(len(s.proxyCounts)), nil
}
func (s *stubAdminService) GetAllProxies(ctx context.Context) ([]service.Proxy, error) {
return s.proxies, nil
}
func (s *stubAdminService) GetAllProxiesWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
return s.proxyCounts, nil
}
func (s *stubAdminService) GetProxy(ctx context.Context, id int64) (*service.Proxy, error) {
proxy := service.Proxy{ID: id, Name: "proxy", Status: service.StatusActive}
return &proxy, nil
}
func (s *stubAdminService) CreateProxy(ctx context.Context, input *service.CreateProxyInput) (*service.Proxy, error) {
proxy := service.Proxy{ID: 400, Name: input.Name, Status: service.StatusActive}
return &proxy, nil
}
func (s *stubAdminService) UpdateProxy(ctx context.Context, id int64, input *service.UpdateProxyInput) (*service.Proxy, error) {
proxy := service.Proxy{ID: id, Name: input.Name, Status: service.StatusActive}
return &proxy, nil
}
func (s *stubAdminService) DeleteProxy(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) BatchDeleteProxies(ctx context.Context, ids []int64) (*service.ProxyBatchDeleteResult, error) {
return &service.ProxyBatchDeleteResult{DeletedIDs: ids}, nil
}
func (s *stubAdminService) GetProxyAccounts(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
return []service.ProxyAccountSummary{{ID: 1, Name: "account"}}, nil
}
func (s *stubAdminService) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
return false, nil
}
func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.ProxyTestResult, error) {
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
}
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
return s.redeems, int64(len(s.redeems)), nil
}
func (s *stubAdminService) GetRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUnused}
return &code, nil
}
func (s *stubAdminService) GenerateRedeemCodes(ctx context.Context, input *service.GenerateRedeemCodesInput) ([]service.RedeemCode, error) {
return s.redeems, nil
}
func (s *stubAdminService) DeleteRedeemCode(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) {
return int64(len(ids)), nil
}
func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUsed}
return &code, nil
}
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)

View File

@@ -186,7 +186,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data // GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend // GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream // Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c) startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day") granularity := c.DefaultQuery("granularity", "day")
@@ -195,6 +195,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
var userID, apiKeyID, accountID, groupID int64 var userID, apiKeyID, accountID, groupID int64
var model string var model string
var stream *bool var stream *bool
var billingType *int8
if userIDStr := c.Query("user_id"); userIDStr != "" { if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil { if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
@@ -224,8 +225,17 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
stream = &streamVal stream = &streamVal
} }
} }
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
bt := int8(v)
billingType = &bt
} else {
response.BadRequest(c, "Invalid billing_type")
return
}
}
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream) trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get usage trend") response.Error(c, 500, "Failed to get usage trend")
return return
@@ -241,13 +251,14 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics // GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models // GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream // Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
func (h *DashboardHandler) GetModelStats(c *gin.Context) { func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c) startTime, endTime := parseTimeRange(c)
// Parse optional filter params // Parse optional filter params
var userID, apiKeyID, accountID, groupID int64 var userID, apiKeyID, accountID, groupID int64
var stream *bool var stream *bool
var billingType *int8
if userIDStr := c.Query("user_id"); userIDStr != "" { if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil { if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
@@ -274,8 +285,17 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
stream = &streamVal stream = &streamVal
} }
} }
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
bt := int8(v)
billingType = &bt
} else {
response.BadRequest(c, "Invalid billing_type")
return
}
}
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream) stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get model statistics") response.Error(c, 500, "Failed to get model statistics")
return return

View File

@@ -0,0 +1,377 @@
package admin
import (
"bytes"
"context"
"encoding/json"
"database/sql"
"errors"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type cleanupRepoStub struct {
mu sync.Mutex
created []*service.UsageCleanupTask
listTasks []service.UsageCleanupTask
listResult *pagination.PaginationResult
listErr error
statusByID map[int64]string
}
func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
if task == nil {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
if task.ID == 0 {
task.ID = int64(len(s.created) + 1)
}
if task.CreatedAt.IsZero() {
task.CreatedAt = time.Now().UTC()
}
task.UpdatedAt = task.CreatedAt
clone := *task
s.created = append(s.created, &clone)
return nil
}
func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
s.mu.Lock()
defer s.mu.Unlock()
return s.listTasks, s.listResult, s.listErr
}
func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
return nil, nil
}
func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.statusByID == nil {
return "", sql.ErrNoRows
}
status, ok := s.statusByID[taskID]
if !ok {
return "", sql.ErrNoRows
}
return status, nil
}
func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
return nil
}
func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
status := s.statusByID[taskID]
if status != service.UsageCleanupStatusPending && status != service.UsageCleanupStatusRunning {
return false, nil
}
s.statusByID[taskID] = service.UsageCleanupStatusCanceled
return true, nil
}
func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
return nil
}
func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
return nil
}
func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
return 0, nil
}
var _ service.UsageCleanupRepository = (*cleanupRepoStub)(nil)
func setupCleanupRouter(cleanupService *service.UsageCleanupService, userID int64) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
if userID > 0 {
router.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID})
c.Next()
})
}
handler := NewUsageHandler(nil, nil, nil, cleanupService)
router.POST("/api/v1/admin/usage/cleanup-tasks", handler.CreateCleanupTask)
router.GET("/api/v1/admin/usage/cleanup-tasks", handler.ListCleanupTasks)
router.POST("/api/v1/admin/usage/cleanup-tasks/:id/cancel", handler.CancelCleanupTask)
return router
}
func TestUsageHandlerCreateCleanupTaskUnauthorized(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 0)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusUnauthorized, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskUnavailable(t *testing.T) {
router := setupCleanupRouter(nil, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskBindError(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString("{bad-json"))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskMissingRange(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-01-01",
"timezone": "UTC",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskInvalidDate(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-13-01",
"end_date": "2024-01-02",
"timezone": "UTC",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-01-01",
"end_date": "2024-02-40",
"timezone": "UTC",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 99)
payload := map[string]any{
"start_date": " 2024-01-01 ",
"end_date": "2024-01-02",
"timezone": "UTC",
"model": "gpt-4",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.created, 1)
created := repo.created[0]
require.Equal(t, int64(99), created.CreatedBy)
require.NotNil(t, created.Filters.Model)
require.Equal(t, "gpt-4", *created.Filters.Model)
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC).Add(24*time.Hour - time.Nanosecond)
require.True(t, created.Filters.StartTime.Equal(start))
require.True(t, created.Filters.EndTime.Equal(end))
}
func TestUsageHandlerListCleanupTasksUnavailable(t *testing.T) {
router := setupCleanupRouter(nil, 0)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
}
func TestUsageHandlerListCleanupTasksSuccess(t *testing.T) {
repo := &cleanupRepoStub{}
repo.listTasks = []service.UsageCleanupTask{
{
ID: 7,
Status: service.UsageCleanupStatusSucceeded,
CreatedBy: 4,
},
}
repo.listResult = &pagination.PaginationResult{Total: 1, Page: 1, PageSize: 20, Pages: 1}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
Items []dto.UsageCleanupTask `json:"items"`
Total int64 `json:"total"`
Page int `json:"page"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Items, 1)
require.Equal(t, int64(7), resp.Data.Items[0].ID)
require.Equal(t, int64(1), resp.Data.Total)
require.Equal(t, 1, resp.Data.Page)
}
func TestUsageHandlerListCleanupTasksError(t *testing.T) {
repo := &cleanupRepoStub{listErr: errors.New("boom")}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusInternalServerError, recorder.Code)
}
func TestUsageHandlerCancelCleanupTaskUnauthorized(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 0)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/1/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestUsageHandlerCancelCleanupTaskNotFound(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/999/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestUsageHandlerCancelCleanupTaskConflict(t *testing.T) {
repo := &cleanupRepoStub{statusByID: map[int64]string{2: service.UsageCleanupStatusSucceeded}}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/2/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
}
func TestUsageHandlerCancelCleanupTaskSuccess(t *testing.T) {
repo := &cleanupRepoStub{statusByID: map[int64]string{3: service.UsageCleanupStatusPending}}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/3/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}

View File

@@ -1,7 +1,10 @@
package admin package admin
import ( import (
"log"
"net/http"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -9,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -16,9 +20,10 @@ import (
// UsageHandler handles admin usage-related requests // UsageHandler handles admin usage-related requests
type UsageHandler struct { type UsageHandler struct {
usageService *service.UsageService usageService *service.UsageService
apiKeyService *service.APIKeyService apiKeyService *service.APIKeyService
adminService service.AdminService adminService service.AdminService
cleanupService *service.UsageCleanupService
} }
// NewUsageHandler creates a new admin usage handler // NewUsageHandler creates a new admin usage handler
@@ -26,14 +31,30 @@ func NewUsageHandler(
usageService *service.UsageService, usageService *service.UsageService,
apiKeyService *service.APIKeyService, apiKeyService *service.APIKeyService,
adminService service.AdminService, adminService service.AdminService,
cleanupService *service.UsageCleanupService,
) *UsageHandler { ) *UsageHandler {
return &UsageHandler{ return &UsageHandler{
usageService: usageService, usageService: usageService,
apiKeyService: apiKeyService, apiKeyService: apiKeyService,
adminService: adminService, adminService: adminService,
cleanupService: cleanupService,
} }
} }
// CreateUsageCleanupTaskRequest represents cleanup task creation request
type CreateUsageCleanupTaskRequest struct {
StartDate string `json:"start_date"`
EndDate string `json:"end_date"`
UserID *int64 `json:"user_id"`
APIKeyID *int64 `json:"api_key_id"`
AccountID *int64 `json:"account_id"`
GroupID *int64 `json:"group_id"`
Model *string `json:"model"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
Timezone string `json:"timezone"`
}
// List handles listing all usage records with filters // List handles listing all usage records with filters
// GET /api/v1/admin/usage // GET /api/v1/admin/usage
func (h *UsageHandler) List(c *gin.Context) { func (h *UsageHandler) List(c *gin.Context) {
@@ -344,3 +365,162 @@ func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
response.Success(c, result) response.Success(c, result)
} }
// ListCleanupTasks handles listing usage cleanup tasks
// GET /api/v1/admin/usage/cleanup-tasks
func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
if h.cleanupService == nil {
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
return
}
operator := int64(0)
if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
operator = subject.UserID
}
page, pageSize := response.ParsePagination(c)
log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
if err != nil {
log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
response.ErrorFrom(c, err)
return
}
out := make([]dto.UsageCleanupTask, 0, len(tasks))
for i := range tasks {
out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
}
log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
response.Paginated(c, out, result.Total, page, pageSize)
}
// CreateCleanupTask handles creating a usage cleanup task
// POST /api/v1/admin/usage/cleanup-tasks
func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
if h.cleanupService == nil {
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Unauthorized(c, "Unauthorized")
return
}
var req CreateUsageCleanupTaskRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
req.StartDate = strings.TrimSpace(req.StartDate)
req.EndDate = strings.TrimSpace(req.EndDate)
if req.StartDate == "" || req.EndDate == "" {
response.BadRequest(c, "start_date and end_date are required")
return
}
startTime, err := timezone.ParseInUserLocation("2006-01-02", req.StartDate, req.Timezone)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
endTime, err := timezone.ParseInUserLocation("2006-01-02", req.EndDate, req.Timezone)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
filters := service.UsageCleanupFilters{
StartTime: startTime,
EndTime: endTime,
UserID: req.UserID,
APIKeyID: req.APIKeyID,
AccountID: req.AccountID,
GroupID: req.GroupID,
Model: req.Model,
Stream: req.Stream,
BillingType: req.BillingType,
}
var userID any
if filters.UserID != nil {
userID = *filters.UserID
}
var apiKeyID any
if filters.APIKeyID != nil {
apiKeyID = *filters.APIKeyID
}
var accountID any
if filters.AccountID != nil {
accountID = *filters.AccountID
}
var groupID any
if filters.GroupID != nil {
groupID = *filters.GroupID
}
var model any
if filters.Model != nil {
model = *filters.Model
}
var stream any
if filters.Stream != nil {
stream = *filters.Stream
}
var billingType any
if filters.BillingType != nil {
billingType = *filters.BillingType
}
log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
subject.UserID,
filters.StartTime.Format(time.RFC3339),
filters.EndTime.Format(time.RFC3339),
userID,
apiKeyID,
accountID,
groupID,
model,
stream,
billingType,
req.Timezone,
)
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
if err != nil {
log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
response.ErrorFrom(c, err)
return
}
log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
response.Success(c, dto.UsageCleanupTaskFromService(task))
}
// CancelCleanupTask handles canceling a usage cleanup task
// POST /api/v1/admin/usage/cleanup-tasks/:id/cancel
func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
if h.cleanupService == nil {
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Unauthorized(c, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
taskID, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || taskID <= 0 {
response.BadRequest(c, "Invalid task id")
return
}
log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
response.ErrorFrom(c, err)
return
}
log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
}

View File

@@ -340,6 +340,36 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true) return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
} }
func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTask {
if task == nil {
return nil
}
return &UsageCleanupTask{
ID: task.ID,
Status: task.Status,
Filters: UsageCleanupFilters{
StartTime: task.Filters.StartTime,
EndTime: task.Filters.EndTime,
UserID: task.Filters.UserID,
APIKeyID: task.Filters.APIKeyID,
AccountID: task.Filters.AccountID,
GroupID: task.Filters.GroupID,
Model: task.Filters.Model,
Stream: task.Filters.Stream,
BillingType: task.Filters.BillingType,
},
CreatedBy: task.CreatedBy,
DeletedRows: task.DeletedRows,
ErrorMessage: task.ErrorMsg,
CanceledBy: task.CanceledBy,
CanceledAt: task.CanceledAt,
StartedAt: task.StartedAt,
FinishedAt: task.FinishedAt,
CreatedAt: task.CreatedAt,
UpdatedAt: task.UpdatedAt,
}
}
func SettingFromService(s *service.Setting) *Setting { func SettingFromService(s *service.Setting) *Setting {
if s == nil { if s == nil {
return nil return nil

View File

@@ -223,6 +223,33 @@ type UsageLog struct {
Subscription *UserSubscription `json:"subscription,omitempty"` Subscription *UserSubscription `json:"subscription,omitempty"`
} }
type UsageCleanupFilters struct {
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
UserID *int64 `json:"user_id,omitempty"`
APIKeyID *int64 `json:"api_key_id,omitempty"`
AccountID *int64 `json:"account_id,omitempty"`
GroupID *int64 `json:"group_id,omitempty"`
Model *string `json:"model,omitempty"`
Stream *bool `json:"stream,omitempty"`
BillingType *int8 `json:"billing_type,omitempty"`
}
type UsageCleanupTask struct {
ID int64 `json:"id"`
Status string `json:"status"`
Filters UsageCleanupFilters `json:"filters"`
CreatedBy int64 `json:"created_by"`
DeletedRows int64 `json:"deleted_rows"`
ErrorMessage *string `json:"error_message,omitempty"`
CanceledBy *int64 `json:"canceled_by,omitempty"`
CanceledAt *time.Time `json:"canceled_at,omitempty"`
StartedAt *time.Time `json:"started_at,omitempty"`
FinishedAt *time.Time `json:"finished_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// AccountSummary is a minimal account info for usage log display. // AccountSummary is a minimal account info for usage log display.
// It intentionally excludes sensitive fields like Credentials, Proxy, etc. // It intentionally excludes sensitive fields like Credentials, Proxy, etc.
type AccountSummary struct { type AccountSummary struct {

View File

@@ -77,6 +77,75 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
return nil return nil
} }
func (r *dashboardAggregationRepository) RecomputeRange(ctx context.Context, start, end time.Time) error {
if r == nil || r.sql == nil {
return nil
}
loc := timezone.Location()
startLocal := start.In(loc)
endLocal := end.In(loc)
if !endLocal.After(startLocal) {
return nil
}
hourStart := startLocal.Truncate(time.Hour)
hourEnd := endLocal.Truncate(time.Hour)
if endLocal.After(hourEnd) {
hourEnd = hourEnd.Add(time.Hour)
}
dayStart := truncateToDay(startLocal)
dayEnd := truncateToDay(endLocal)
if endLocal.After(dayEnd) {
dayEnd = dayEnd.Add(24 * time.Hour)
}
// 尽量使用事务保证范围内的一致性(允许在非 *sql.DB 的情况下退化为非事务执行)。
if db, ok := r.sql.(*sql.DB); ok {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
if err := txRepo.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
}
return r.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
}
func (r *dashboardAggregationRepository) recomputeRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
// 先清空范围内桶,再重建(避免仅增量插入导致活跃用户等指标无法回退)。
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly_users WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily_users WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil {
return err
}
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil {
return err
}
return nil
}
func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) { func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
var ts time.Time var ts time.Time
query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1" query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1"

View File

@@ -0,0 +1,363 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type usageCleanupRepository struct {
sql sqlExecutor
}
func NewUsageCleanupRepository(sqlDB *sql.DB) service.UsageCleanupRepository {
return &usageCleanupRepository{sql: sqlDB}
}
func (r *usageCleanupRepository) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
if task == nil {
return nil
}
filtersJSON, err := json.Marshal(task.Filters)
if err != nil {
return fmt.Errorf("marshal cleanup filters: %w", err)
}
query := `
INSERT INTO usage_cleanup_tasks (
status,
filters,
created_by,
deleted_rows
) VALUES ($1, $2, $3, $4)
RETURNING id, created_at, updated_at
`
if err := scanSingleRow(ctx, r.sql, query, []any{task.Status, filtersJSON, task.CreatedBy, task.DeletedRows}, &task.ID, &task.CreatedAt, &task.UpdatedAt); err != nil {
return err
}
return nil
}
func (r *usageCleanupRepository) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
var total int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM usage_cleanup_tasks", nil, &total); err != nil {
return nil, nil, err
}
if total == 0 {
return []service.UsageCleanupTask{}, paginationResultFromTotal(0, params), nil
}
query := `
SELECT id, status, filters, created_by, deleted_rows, error_message,
canceled_by, canceled_at,
started_at, finished_at, created_at, updated_at
FROM usage_cleanup_tasks
ORDER BY created_at DESC
LIMIT $1 OFFSET $2
`
rows, err := r.sql.QueryContext(ctx, query, params.Limit(), params.Offset())
if err != nil {
return nil, nil, err
}
defer rows.Close()
tasks := make([]service.UsageCleanupTask, 0)
for rows.Next() {
var task service.UsageCleanupTask
var filtersJSON []byte
var errMsg sql.NullString
var canceledBy sql.NullInt64
var canceledAt sql.NullTime
var startedAt sql.NullTime
var finishedAt sql.NullTime
if err := rows.Scan(
&task.ID,
&task.Status,
&filtersJSON,
&task.CreatedBy,
&task.DeletedRows,
&errMsg,
&canceledBy,
&canceledAt,
&startedAt,
&finishedAt,
&task.CreatedAt,
&task.UpdatedAt,
); err != nil {
return nil, nil, err
}
if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
return nil, nil, fmt.Errorf("parse cleanup filters: %w", err)
}
if errMsg.Valid {
task.ErrorMsg = &errMsg.String
}
if canceledBy.Valid {
v := canceledBy.Int64
task.CanceledBy = &v
}
if canceledAt.Valid {
task.CanceledAt = &canceledAt.Time
}
if startedAt.Valid {
task.StartedAt = &startedAt.Time
}
if finishedAt.Valid {
task.FinishedAt = &finishedAt.Time
}
tasks = append(tasks, task)
}
if err := rows.Err(); err != nil {
return nil, nil, err
}
return tasks, paginationResultFromTotal(total, params), nil
}
func (r *usageCleanupRepository) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
if staleRunningAfterSeconds <= 0 {
staleRunningAfterSeconds = 1800
}
query := `
WITH next AS (
SELECT id
FROM usage_cleanup_tasks
WHERE status = $1
OR (
status = $2
AND started_at IS NOT NULL
AND started_at < NOW() - ($3 * interval '1 second')
)
ORDER BY created_at ASC
LIMIT 1
FOR UPDATE SKIP LOCKED
)
UPDATE usage_cleanup_tasks
SET status = $4,
started_at = NOW(),
finished_at = NULL,
error_message = NULL,
updated_at = NOW()
FROM next
WHERE usage_cleanup_tasks.id = next.id
RETURNING id, status, filters, created_by, deleted_rows, error_message,
started_at, finished_at, created_at, updated_at
`
var task service.UsageCleanupTask
var filtersJSON []byte
var errMsg sql.NullString
var startedAt sql.NullTime
var finishedAt sql.NullTime
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{
service.UsageCleanupStatusPending,
service.UsageCleanupStatusRunning,
staleRunningAfterSeconds,
service.UsageCleanupStatusRunning,
},
&task.ID,
&task.Status,
&filtersJSON,
&task.CreatedBy,
&task.DeletedRows,
&errMsg,
&startedAt,
&finishedAt,
&task.CreatedAt,
&task.UpdatedAt,
); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, err
}
if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
return nil, fmt.Errorf("parse cleanup filters: %w", err)
}
if errMsg.Valid {
task.ErrorMsg = &errMsg.String
}
if startedAt.Valid {
task.StartedAt = &startedAt.Time
}
if finishedAt.Valid {
task.FinishedAt = &finishedAt.Time
}
return &task, nil
}
func (r *usageCleanupRepository) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
var status string
if err := scanSingleRow(ctx, r.sql, "SELECT status FROM usage_cleanup_tasks WHERE id = $1", []any{taskID}, &status); err != nil {
return "", err
}
return status, nil
}
func (r *usageCleanupRepository) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
query := `
UPDATE usage_cleanup_tasks
SET deleted_rows = $1,
updated_at = NOW()
WHERE id = $2
`
_, err := r.sql.ExecContext(ctx, query, deletedRows, taskID)
return err
}
func (r *usageCleanupRepository) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
query := `
UPDATE usage_cleanup_tasks
SET status = $1,
canceled_by = $3,
canceled_at = NOW(),
finished_at = NOW(),
error_message = NULL,
updated_at = NOW()
WHERE id = $2
AND status IN ($4, $5)
RETURNING id
`
var id int64
err := scanSingleRow(ctx, r.sql, query, []any{
service.UsageCleanupStatusCanceled,
taskID,
canceledBy,
service.UsageCleanupStatusPending,
service.UsageCleanupStatusRunning,
}, &id)
if errors.Is(err, sql.ErrNoRows) {
return false, nil
}
if err != nil {
return false, err
}
return true, nil
}
func (r *usageCleanupRepository) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
query := `
UPDATE usage_cleanup_tasks
SET status = $1,
deleted_rows = $2,
finished_at = NOW(),
updated_at = NOW()
WHERE id = $3
`
_, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusSucceeded, deletedRows, taskID)
return err
}
func (r *usageCleanupRepository) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
query := `
UPDATE usage_cleanup_tasks
SET status = $1,
deleted_rows = $2,
error_message = $3,
finished_at = NOW(),
updated_at = NOW()
WHERE id = $4
`
_, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusFailed, deletedRows, errorMsg, taskID)
return err
}
func (r *usageCleanupRepository) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
if filters.StartTime.IsZero() || filters.EndTime.IsZero() {
return 0, fmt.Errorf("cleanup filters missing time range")
}
whereClause, args := buildUsageCleanupWhere(filters)
if whereClause == "" {
return 0, fmt.Errorf("cleanup filters missing time range")
}
args = append(args, limit)
query := fmt.Sprintf(`
WITH target AS (
SELECT id
FROM usage_logs
WHERE %s
ORDER BY created_at ASC, id ASC
LIMIT $%d
)
DELETE FROM usage_logs
WHERE id IN (SELECT id FROM target)
RETURNING id
`, whereClause, len(args))
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
defer rows.Close()
var deleted int64
for rows.Next() {
deleted++
}
if err := rows.Err(); err != nil {
return 0, err
}
return deleted, nil
}
func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any) {
conditions := make([]string, 0, 8)
args := make([]any, 0, 8)
idx := 1
if !filters.StartTime.IsZero() {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", idx))
args = append(args, filters.StartTime)
idx++
}
if !filters.EndTime.IsZero() {
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", idx))
args = append(args, filters.EndTime)
idx++
}
if filters.UserID != nil {
conditions = append(conditions, fmt.Sprintf("user_id = $%d", idx))
args = append(args, *filters.UserID)
idx++
}
if filters.APIKeyID != nil {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", idx))
args = append(args, *filters.APIKeyID)
idx++
}
if filters.AccountID != nil {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", idx))
args = append(args, *filters.AccountID)
idx++
}
if filters.GroupID != nil {
conditions = append(conditions, fmt.Sprintf("group_id = $%d", idx))
args = append(args, *filters.GroupID)
idx++
}
if filters.Model != nil {
model := strings.TrimSpace(*filters.Model)
if model != "" {
conditions = append(conditions, fmt.Sprintf("model = $%d", idx))
args = append(args, model)
idx++
}
}
if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", idx))
args = append(args, *filters.Stream)
idx++
}
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", idx))
args = append(args, *filters.BillingType)
idx++
}
return strings.Join(conditions, " AND "), args
}

View File

@@ -0,0 +1,440 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func newSQLMock(t *testing.T) (*sql.DB, sqlmock.Sqlmock) {
t.Helper()
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherRegexp))
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
return db, mock
}
func TestNewUsageCleanupRepository(t *testing.T) {
db, _ := newSQLMock(t)
repo := NewUsageCleanupRepository(db)
require.NotNil(t, repo)
}
func TestUsageCleanupRepositoryCreateTask(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
task := &service.UsageCleanupTask{
Status: service.UsageCleanupStatusPending,
Filters: service.UsageCleanupFilters{StartTime: start, EndTime: end},
CreatedBy: 12,
}
now := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
mock.ExpectQuery("INSERT INTO usage_cleanup_tasks").
WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at", "updated_at"}).AddRow(int64(1), now, now))
err := repo.CreateTask(context.Background(), task)
require.NoError(t, err)
require.Equal(t, int64(1), task.ID)
require.Equal(t, now, task.CreatedAt)
require.Equal(t, now, task.UpdatedAt)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryCreateTaskNil(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
err := repo.CreateTask(context.Background(), nil)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryCreateTaskQueryError(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
task := &service.UsageCleanupTask{
Status: service.UsageCleanupStatusPending,
Filters: service.UsageCleanupFilters{StartTime: time.Now(), EndTime: time.Now().Add(time.Hour)},
CreatedBy: 1,
}
mock.ExpectQuery("INSERT INTO usage_cleanup_tasks").
WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows).
WillReturnError(sql.ErrConnDone)
err := repo.CreateTask(context.Background(), task)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryListTasksEmpty(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0)))
tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.NoError(t, err)
require.Empty(t, tasks)
require.Equal(t, int64(0), result.Total)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryListTasks(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(2 * time.Hour)
filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
filtersJSON, err := json.Marshal(filters)
require.NoError(t, err)
createdAt := time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC)
updatedAt := createdAt.Add(time.Minute)
rows := sqlmock.NewRows([]string{
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
"canceled_by", "canceled_at",
"started_at", "finished_at", "created_at", "updated_at",
}).AddRow(
int64(1),
service.UsageCleanupStatusSucceeded,
filtersJSON,
int64(2),
int64(9),
"error",
nil,
nil,
start,
end,
createdAt,
updatedAt,
)
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
WithArgs(20, 0).
WillReturnRows(rows)
tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.NoError(t, err)
require.Len(t, tasks, 1)
require.Equal(t, int64(1), tasks[0].ID)
require.Equal(t, service.UsageCleanupStatusSucceeded, tasks[0].Status)
require.Equal(t, int64(2), tasks[0].CreatedBy)
require.Equal(t, int64(9), tasks[0].DeletedRows)
require.NotNil(t, tasks[0].ErrorMsg)
require.Equal(t, "error", *tasks[0].ErrorMsg)
require.NotNil(t, tasks[0].StartedAt)
require.NotNil(t, tasks[0].FinishedAt)
require.Equal(t, int64(1), result.Total)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryListTasksInvalidFilters(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
rows := sqlmock.NewRows([]string{
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
"canceled_by", "canceled_at",
"started_at", "finished_at", "created_at", "updated_at",
}).AddRow(
int64(1),
service.UsageCleanupStatusSucceeded,
[]byte("not-json"),
int64(2),
int64(9),
nil,
nil,
nil,
nil,
nil,
time.Now().UTC(),
time.Now().UTC(),
)
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
WithArgs(20, 0).
WillReturnRows(rows)
_, _, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryClaimNextPendingTaskNone(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
WillReturnRows(sqlmock.NewRows([]string{
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
"started_at", "finished_at", "created_at", "updated_at",
}))
task, err := repo.ClaimNextPendingTask(context.Background(), 1800)
require.NoError(t, err)
require.Nil(t, task)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryClaimNextPendingTask(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
filtersJSON, err := json.Marshal(filters)
require.NoError(t, err)
rows := sqlmock.NewRows([]string{
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
"started_at", "finished_at", "created_at", "updated_at",
}).AddRow(
int64(4),
service.UsageCleanupStatusRunning,
filtersJSON,
int64(7),
int64(0),
nil,
start,
nil,
start,
start,
)
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
WillReturnRows(rows)
task, err := repo.ClaimNextPendingTask(context.Background(), 1800)
require.NoError(t, err)
require.NotNil(t, task)
require.Equal(t, int64(4), task.ID)
require.Equal(t, service.UsageCleanupStatusRunning, task.Status)
require.Equal(t, int64(7), task.CreatedBy)
require.NotNil(t, task.StartedAt)
require.Nil(t, task.ErrorMsg)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryClaimNextPendingTaskError(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
WillReturnError(sql.ErrConnDone)
_, err := repo.ClaimNextPendingTask(context.Background(), 1800)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryClaimNextPendingTaskInvalidFilters(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
rows := sqlmock.NewRows([]string{
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
"started_at", "finished_at", "created_at", "updated_at",
}).AddRow(
int64(4),
service.UsageCleanupStatusRunning,
[]byte("invalid"),
int64(7),
int64(0),
nil,
nil,
nil,
time.Now().UTC(),
time.Now().UTC(),
)
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
WillReturnRows(rows)
_, err := repo.ClaimNextPendingTask(context.Background(), 1800)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryMarkTaskSucceeded(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectExec("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusSucceeded, int64(12), int64(9)).
WillReturnResult(sqlmock.NewResult(0, 1))
err := repo.MarkTaskSucceeded(context.Background(), 9, 12)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryMarkTaskFailed(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectExec("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusFailed, int64(4), "boom", int64(2)).
WillReturnResult(sqlmock.NewResult(0, 1))
err := repo.MarkTaskFailed(context.Background(), 2, 4, "boom")
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryGetTaskStatus(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("SELECT status FROM usage_cleanup_tasks").
WithArgs(int64(9)).
WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow(service.UsageCleanupStatusPending))
status, err := repo.GetTaskStatus(context.Background(), 9)
require.NoError(t, err)
require.Equal(t, service.UsageCleanupStatusPending, status)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryUpdateTaskProgress(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectExec("UPDATE usage_cleanup_tasks").
WithArgs(int64(123), int64(8)).
WillReturnResult(sqlmock.NewResult(0, 1))
err := repo.UpdateTaskProgress(context.Background(), 8, 123)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryCancelTask(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusCanceled, int64(6), int64(9), service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(6)))
ok, err := repo.CancelTask(context.Background(), 6, 9)
require.NoError(t, err)
require.True(t, ok)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryDeleteUsageLogsBatchMissingRange(t *testing.T) {
db, _ := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
_, err := repo.DeleteUsageLogsBatch(context.Background(), service.UsageCleanupFilters{}, 10)
require.Error(t, err)
}
func TestUsageCleanupRepositoryDeleteUsageLogsBatch(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
userID := int64(3)
model := " gpt-4 "
filters := service.UsageCleanupFilters{
StartTime: start,
EndTime: end,
UserID: &userID,
Model: &model,
}
mock.ExpectQuery("DELETE FROM usage_logs").
WithArgs(start, end, userID, "gpt-4", 2).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(1)).AddRow(int64(2)))
deleted, err := repo.DeleteUsageLogsBatch(context.Background(), filters, 2)
require.NoError(t, err)
require.Equal(t, int64(2), deleted)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryDeleteUsageLogsBatchQueryError(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
mock.ExpectQuery("DELETE FROM usage_logs").
WithArgs(start, end, 5).
WillReturnError(sql.ErrConnDone)
_, err := repo.DeleteUsageLogsBatch(context.Background(), filters, 5)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestBuildUsageCleanupWhere(t *testing.T) {
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
userID := int64(1)
apiKeyID := int64(2)
accountID := int64(3)
groupID := int64(4)
model := " gpt-4 "
stream := true
billingType := int8(2)
where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
StartTime: start,
EndTime: end,
UserID: &userID,
APIKeyID: &apiKeyID,
AccountID: &accountID,
GroupID: &groupID,
Model: &model,
Stream: &stream,
BillingType: &billingType,
})
require.Equal(t, "created_at >= $1 AND created_at <= $2 AND user_id = $3 AND api_key_id = $4 AND account_id = $5 AND group_id = $6 AND model = $7 AND stream = $8 AND billing_type = $9", where)
require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args)
}
func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) {
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
model := " "
where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
StartTime: start,
EndTime: end,
Model: &model,
})
require.Equal(t, "created_at >= $1 AND created_at <= $2", where)
require.Equal(t, []any{start, end}, args)
}

View File

@@ -1411,7 +1411,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
} }
// GetUsageTrendWithFilters returns usage trend data with optional filters // GetUsageTrendWithFilters returns usage trend data with optional filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) (results []TrendDataPoint, err error) { func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD" dateFormat := "YYYY-MM-DD"
if granularity == "hour" { if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00" dateFormat = "YYYY-MM-DD HH24:00"
@@ -1456,6 +1456,10 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND stream = $%d", len(args)+1) query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream) args = append(args, *stream)
} }
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
}
query += " GROUP BY date ORDER BY date ASC" query += " GROUP BY date ORDER BY date ASC"
rows, err := r.sql.QueryContext(ctx, query, args...) rows, err := r.sql.QueryContext(ctx, query, args...)
@@ -1479,7 +1483,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
} }
// GetModelStatsWithFilters returns model statistics with optional filters // GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) (results []ModelStat, err error) { func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时实际费用使用账号倍率total_cost * account_rate_multiplier // 当仅按 account_id 聚合时实际费用使用账号倍率total_cost * account_rate_multiplier
if accountID > 0 && userID == 0 && apiKeyID == 0 { if accountID > 0 && userID == 0 && apiKeyID == 0 {
@@ -1520,6 +1524,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND stream = $%d", len(args)+1) query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream) args = append(args, *stream)
} }
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
}
query += " GROUP BY model ORDER BY total_tokens DESC" query += " GROUP BY model ORDER BY total_tokens DESC"
rows, err := r.sql.QueryContext(ctx, query, args...) rows, err := r.sql.QueryContext(ctx, query, args...)
@@ -1825,7 +1833,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
} }
} }
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil) models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil)
if err != nil { if err != nil {
models = []ModelStat{} models = []ModelStat{}
} }

View File

@@ -944,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime := base.Add(48 * time.Hour) endTime := base.Add(48 * time.Hour)
// Test with user filter // Test with user filter
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil) trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters user filter") s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
// Test with apiKey filter // Test with apiKey filter
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil) trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter") s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
// Test with both filters // Test with both filters
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil) trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters both filters") s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
} }
@@ -971,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour) startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour) endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil) trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters hourly") s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
s.Require().Len(trend, 2) s.Require().Len(trend, 2)
} }
@@ -1017,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime := base.Add(2 * time.Hour) endTime := base.Add(2 * time.Hour)
// Test with user filter // Test with user filter
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil) stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters user filter") s.Require().NoError(err, "GetModelStatsWithFilters user filter")
s.Require().Len(stats, 2) s.Require().Len(stats, 2)
// Test with apiKey filter // Test with apiKey filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil) stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter") s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
s.Require().Len(stats, 2) s.Require().Len(stats, 2)
// Test with account filter // Test with account filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil) stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters account filter") s.Require().NoError(err, "GetModelStatsWithFilters account filter")
s.Require().Len(stats, 2) s.Require().Len(stats, 2)
} }

View File

@@ -47,6 +47,7 @@ var ProviderSet = wire.NewSet(
NewRedeemCodeRepository, NewRedeemCodeRepository,
NewPromoCodeRepository, NewPromoCodeRepository,
NewUsageLogRepository, NewUsageLogRepository,
NewUsageCleanupRepository,
NewDashboardAggregationRepository, NewDashboardAggregationRepository,
NewSettingRepository, NewSettingRepository,
NewOpsRepository, NewOpsRepository,

View File

@@ -1242,11 +1242,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) { func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) { func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }

View File

@@ -354,6 +354,9 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage.GET("/stats", h.Admin.Usage.Stats) usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers) usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys) usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys)
usage.GET("/cleanup-tasks", h.Admin.Usage.ListCleanupTasks)
usage.POST("/cleanup-tasks", h.Admin.Usage.CreateCleanupTask)
usage.POST("/cleanup-tasks/:id/cancel", h.Admin.Usage.CancelCleanupTask)
} }
} }

View File

@@ -32,8 +32,8 @@ type UsageLogRepository interface {
// Admin dashboard stats // Admin dashboard stats
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
@@ -272,7 +272,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
} }
dayStart := geminiDailyWindowStart(now) dayStart := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil) stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err) return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
} }
@@ -294,7 +294,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m) // Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart := now.Truncate(time.Minute) minuteStart := now.Truncate(time.Minute)
minuteResetAt := minuteStart.Add(time.Minute) minuteResetAt := minuteStart.Add(time.Minute)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil) minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err) return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
} }

View File

@@ -21,11 +21,15 @@ var (
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用") ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。 // ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大") ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
errDashboardAggregationRunning = errors.New("聚合作业正在运行")
) )
// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。 // DashboardAggregationRepository 定义仪表盘预聚合仓储接口。
type DashboardAggregationRepository interface { type DashboardAggregationRepository interface {
AggregateRange(ctx context.Context, start, end time.Time) error AggregateRange(ctx context.Context, start, end time.Time) error
// RecomputeRange 重新计算指定时间范围内的聚合数据(包含活跃用户等派生表)。
// 设计目的:当 usage_logs 被批量删除/回滚后,确保聚合表可恢复一致性。
RecomputeRange(ctx context.Context, start, end time.Time) error
GetAggregationWatermark(ctx context.Context) (time.Time, error) GetAggregationWatermark(ctx context.Context) (time.Time, error)
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
@@ -112,6 +116,41 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
return nil return nil
} }
// TriggerRecomputeRange 触发指定范围的重新计算(异步)。
// 与 TriggerBackfill 不同:
// - 不依赖 backfill_enabled这是内部一致性修复
// - 不更新 watermark避免影响正常增量聚合游标
func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time) error {
if s == nil || s.repo == nil {
return errors.New("聚合服务未初始化")
}
if !s.cfg.Enabled {
return errors.New("聚合服务已禁用")
}
if !end.After(start) {
return errors.New("重新计算时间范围无效")
}
go func() {
const maxRetries = 3
for i := 0; i < maxRetries; i++ {
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
err := s.recomputeRange(ctx, start, end)
cancel()
if err == nil {
return
}
if !errors.Is(err, errDashboardAggregationRunning) {
log.Printf("[DashboardAggregation] 重新计算失败: %v", err)
return
}
time.Sleep(5 * time.Second)
}
log.Printf("[DashboardAggregation] 重新计算放弃: 聚合作业持续占用")
}()
return nil
}
func (s *DashboardAggregationService) recomputeRecentDays() { func (s *DashboardAggregationService) recomputeRecentDays() {
days := s.cfg.RecomputeDays days := s.cfg.RecomputeDays
if days <= 0 { if days <= 0 {
@@ -128,6 +167,24 @@ func (s *DashboardAggregationService) recomputeRecentDays() {
} }
} }
func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start, end time.Time) error {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return errDashboardAggregationRunning
}
defer atomic.StoreInt32(&s.running, 0)
jobStart := time.Now().UTC()
if err := s.repo.RecomputeRange(ctx, start, end); err != nil {
return err
}
log.Printf("[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)",
start.UTC().Format(time.RFC3339),
end.UTC().Format(time.RFC3339),
time.Since(jobStart).String(),
)
return nil
}
func (s *DashboardAggregationService) runScheduledAggregation() { func (s *DashboardAggregationService) runScheduledAggregation() {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) { if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return return
@@ -179,7 +236,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error { func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) { if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return errors.New("聚合作业正在运行") return errDashboardAggregationRunning
} }
defer atomic.StoreInt32(&s.running, 0) defer atomic.StoreInt32(&s.running, 0)

View File

@@ -27,6 +27,10 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
return s.aggregateErr return s.aggregateErr
} }
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
return s.AggregateRange(ctx, start, end)
}
func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) { func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
return s.watermark, nil return s.watermark, nil
} }

View File

@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return stats, nil return stats, nil
} }
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) { func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream) trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
if err != nil { if err != nil {
return nil, fmt.Errorf("get usage trend with filters: %w", err) return nil, fmt.Errorf("get usage trend with filters: %w", err)
} }
return trend, nil return trend, nil
} }
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) { func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream) stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
if err != nil { if err != nil {
return nil, fmt.Errorf("get model stats with filters: %w", err) return nil, fmt.Errorf("get model stats with filters: %w", err)
} }

View File

@@ -101,6 +101,10 @@ func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start
return nil return nil
} }
func (s *dashboardAggregationRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) { func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
if s.err != nil { if s.err != nil {
return time.Time{}, s.err return time.Time{}, s.err

View File

@@ -190,7 +190,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
start := geminiDailyWindowStart(now) start := geminiDailyWindowStart(now)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now) totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok { if !ok {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil) stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
if err != nil { if err != nil {
return true, err return true, err
} }
@@ -237,7 +237,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if limit > 0 { if limit > 0 {
start := now.Truncate(time.Minute) start := now.Truncate(time.Minute)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil) stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
if err != nil { if err != nil {
return true, err return true, err
} }

View File

@@ -0,0 +1,74 @@
package service
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
const (
UsageCleanupStatusPending = "pending"
UsageCleanupStatusRunning = "running"
UsageCleanupStatusSucceeded = "succeeded"
UsageCleanupStatusFailed = "failed"
UsageCleanupStatusCanceled = "canceled"
)
// UsageCleanupFilters 定义清理任务过滤条件
// 时间范围为必填,其他字段可选
// JSON 序列化用于存储任务参数
//
// start_time/end_time 使用 RFC3339 时间格式
// 以 UTC 或用户时区解析后的时间为准
//
// 说明:
// - nil 表示未设置该过滤条件
// - 过滤条件均为精确匹配
type UsageCleanupFilters struct {
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
UserID *int64 `json:"user_id,omitempty"`
APIKeyID *int64 `json:"api_key_id,omitempty"`
AccountID *int64 `json:"account_id,omitempty"`
GroupID *int64 `json:"group_id,omitempty"`
Model *string `json:"model,omitempty"`
Stream *bool `json:"stream,omitempty"`
BillingType *int8 `json:"billing_type,omitempty"`
}
// UsageCleanupTask 表示使用记录清理任务
// 状态包含 pending/running/succeeded/failed/canceled
type UsageCleanupTask struct {
ID int64
Status string
Filters UsageCleanupFilters
CreatedBy int64
DeletedRows int64
ErrorMsg *string
CanceledBy *int64
CanceledAt *time.Time
StartedAt *time.Time
FinishedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
}
// UsageCleanupRepository 定义清理任务持久层接口
type UsageCleanupRepository interface {
CreateTask(ctx context.Context, task *UsageCleanupTask) error
ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error)
// ClaimNextPendingTask 抢占下一条可执行任务:
// - 优先 pending
// - 若 running 超过 staleRunningAfterSeconds可能由于进程退出/崩溃/超时),允许重新抢占继续执行
ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*UsageCleanupTask, error)
// GetTaskStatus 查询任务状态;若不存在返回 sql.ErrNoRows
GetTaskStatus(ctx context.Context, taskID int64) (string, error)
// UpdateTaskProgress 更新任务进度deleted_rows用于断点续跑/展示
UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error
// CancelTask 将任务标记为 canceled仅允许 pending/running
CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error)
MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error
MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error
DeleteUsageLogsBatch(ctx context.Context, filters UsageCleanupFilters, limit int) (int64, error)
}

View File

@@ -0,0 +1,400 @@
package service
import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
const (
usageCleanupWorkerName = "usage_cleanup_worker"
)
// UsageCleanupService 负责创建与执行使用记录清理任务
type UsageCleanupService struct {
repo UsageCleanupRepository
timingWheel *TimingWheelService
dashboard *DashboardAggregationService
cfg *config.Config
running int32
startOnce sync.Once
stopOnce sync.Once
workerCtx context.Context
workerCancel context.CancelFunc
}
func NewUsageCleanupService(repo UsageCleanupRepository, timingWheel *TimingWheelService, dashboard *DashboardAggregationService, cfg *config.Config) *UsageCleanupService {
workerCtx, workerCancel := context.WithCancel(context.Background())
return &UsageCleanupService{
repo: repo,
timingWheel: timingWheel,
dashboard: dashboard,
cfg: cfg,
workerCtx: workerCtx,
workerCancel: workerCancel,
}
}
func describeUsageCleanupFilters(filters UsageCleanupFilters) string {
var parts []string
parts = append(parts, "start="+filters.StartTime.UTC().Format(time.RFC3339))
parts = append(parts, "end="+filters.EndTime.UTC().Format(time.RFC3339))
if filters.UserID != nil {
parts = append(parts, fmt.Sprintf("user_id=%d", *filters.UserID))
}
if filters.APIKeyID != nil {
parts = append(parts, fmt.Sprintf("api_key_id=%d", *filters.APIKeyID))
}
if filters.AccountID != nil {
parts = append(parts, fmt.Sprintf("account_id=%d", *filters.AccountID))
}
if filters.GroupID != nil {
parts = append(parts, fmt.Sprintf("group_id=%d", *filters.GroupID))
}
if filters.Model != nil {
parts = append(parts, "model="+strings.TrimSpace(*filters.Model))
}
if filters.Stream != nil {
parts = append(parts, fmt.Sprintf("stream=%t", *filters.Stream))
}
if filters.BillingType != nil {
parts = append(parts, fmt.Sprintf("billing_type=%d", *filters.BillingType))
}
return strings.Join(parts, " ")
}
func (s *UsageCleanupService) Start() {
if s == nil {
return
}
if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
log.Printf("[UsageCleanup] not started (disabled)")
return
}
if s.repo == nil || s.timingWheel == nil {
log.Printf("[UsageCleanup] not started (missing deps)")
return
}
interval := s.workerInterval()
s.startOnce.Do(func() {
s.timingWheel.ScheduleRecurring(usageCleanupWorkerName, interval, s.runOnce)
log.Printf("[UsageCleanup] started (interval=%s max_range_days=%d batch_size=%d task_timeout=%s)", interval, s.maxRangeDays(), s.batchSize(), s.taskTimeout())
})
}
func (s *UsageCleanupService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
if s.workerCancel != nil {
s.workerCancel()
}
if s.timingWheel != nil {
s.timingWheel.Cancel(usageCleanupWorkerName)
}
log.Printf("[UsageCleanup] stopped")
})
}
func (s *UsageCleanupService) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error) {
if s == nil || s.repo == nil {
return nil, nil, fmt.Errorf("cleanup service not ready")
}
return s.repo.ListTasks(ctx, params)
}
func (s *UsageCleanupService) CreateTask(ctx context.Context, filters UsageCleanupFilters, createdBy int64) (*UsageCleanupTask, error) {
if s == nil || s.repo == nil {
return nil, fmt.Errorf("cleanup service not ready")
}
if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
return nil, infraerrors.New(http.StatusServiceUnavailable, "USAGE_CLEANUP_DISABLED", "usage cleanup is disabled")
}
if createdBy <= 0 {
return nil, infraerrors.BadRequest("USAGE_CLEANUP_INVALID_CREATOR", "invalid creator")
}
log.Printf("[UsageCleanup] create_task requested: operator=%d %s", createdBy, describeUsageCleanupFilters(filters))
sanitizeUsageCleanupFilters(&filters)
if err := s.validateFilters(filters); err != nil {
log.Printf("[UsageCleanup] create_task rejected: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters))
return nil, err
}
task := &UsageCleanupTask{
Status: UsageCleanupStatusPending,
Filters: filters,
CreatedBy: createdBy,
}
if err := s.repo.CreateTask(ctx, task); err != nil {
log.Printf("[UsageCleanup] create_task persist failed: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters))
return nil, fmt.Errorf("create cleanup task: %w", err)
}
log.Printf("[UsageCleanup] create_task persisted: task=%d operator=%d status=%s deleted_rows=%d %s", task.ID, createdBy, task.Status, task.DeletedRows, describeUsageCleanupFilters(filters))
go s.runOnce()
return task, nil
}
func (s *UsageCleanupService) runOnce() {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
log.Printf("[UsageCleanup] run_once skipped: already_running=true")
return
}
defer atomic.StoreInt32(&s.running, 0)
parent := context.Background()
if s != nil && s.workerCtx != nil {
parent = s.workerCtx
}
ctx, cancel := context.WithTimeout(parent, s.taskTimeout())
defer cancel()
task, err := s.repo.ClaimNextPendingTask(ctx, int64(s.taskTimeout().Seconds()))
if err != nil {
log.Printf("[UsageCleanup] claim pending task failed: %v", err)
return
}
if task == nil {
log.Printf("[UsageCleanup] run_once done: no_task=true")
return
}
log.Printf("[UsageCleanup] task claimed: task=%d status=%s created_by=%d deleted_rows=%d %s", task.ID, task.Status, task.CreatedBy, task.DeletedRows, describeUsageCleanupFilters(task.Filters))
s.executeTask(ctx, task)
}
func (s *UsageCleanupService) executeTask(ctx context.Context, task *UsageCleanupTask) {
if task == nil {
return
}
batchSize := s.batchSize()
deletedTotal := task.DeletedRows
start := time.Now()
log.Printf("[UsageCleanup] task started: task=%d batch_size=%d deleted_rows=%d %s", task.ID, batchSize, deletedTotal, describeUsageCleanupFilters(task.Filters))
var batchNum int
for {
if ctx != nil && ctx.Err() != nil {
log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, ctx.Err())
return
}
canceled, err := s.isTaskCanceled(ctx, task.ID)
if err != nil {
s.markTaskFailed(task.ID, deletedTotal, err)
return
}
if canceled {
log.Printf("[UsageCleanup] task canceled: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start))
return
}
batchNum++
deleted, err := s.repo.DeleteUsageLogsBatch(ctx, task.Filters, batchSize)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
// 任务被中断(例如服务停止/超时),保持 running 状态,后续通过 stale reclaim 续跑。
log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, err)
return
}
s.markTaskFailed(task.ID, deletedTotal, err)
return
}
deletedTotal += deleted
if deleted > 0 {
updateCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
if err := s.repo.UpdateTaskProgress(updateCtx, task.ID, deletedTotal); err != nil {
log.Printf("[UsageCleanup] task progress update failed: task=%d deleted_rows=%d err=%v", task.ID, deletedTotal, err)
}
cancel()
}
if batchNum <= 3 || batchNum%20 == 0 || deleted < int64(batchSize) {
log.Printf("[UsageCleanup] task batch done: task=%d batch=%d deleted=%d deleted_total=%d", task.ID, batchNum, deleted, deletedTotal)
}
if deleted == 0 || deleted < int64(batchSize) {
break
}
}
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.repo.MarkTaskSucceeded(updateCtx, task.ID, deletedTotal); err != nil {
log.Printf("[UsageCleanup] update task succeeded failed: task=%d err=%v", task.ID, err)
} else {
log.Printf("[UsageCleanup] task succeeded: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start))
}
if s.dashboard != nil {
if err := s.dashboard.TriggerRecomputeRange(task.Filters.StartTime, task.Filters.EndTime); err != nil {
log.Printf("[UsageCleanup] trigger dashboard recompute failed: task=%d err=%v", task.ID, err)
} else {
log.Printf("[UsageCleanup] trigger dashboard recompute: task=%d start=%s end=%s", task.ID, task.Filters.StartTime.UTC().Format(time.RFC3339), task.Filters.EndTime.UTC().Format(time.RFC3339))
}
}
}
func (s *UsageCleanupService) markTaskFailed(taskID int64, deletedRows int64, err error) {
msg := strings.TrimSpace(err.Error())
if len(msg) > 500 {
msg = msg[:500]
}
log.Printf("[UsageCleanup] task failed: task=%d deleted_rows=%d err=%s", taskID, deletedRows, msg)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if updateErr := s.repo.MarkTaskFailed(ctx, taskID, deletedRows, msg); updateErr != nil {
log.Printf("[UsageCleanup] update task failed failed: task=%d err=%v", taskID, updateErr)
}
}
func (s *UsageCleanupService) isTaskCanceled(ctx context.Context, taskID int64) (bool, error) {
if s == nil || s.repo == nil {
return false, fmt.Errorf("cleanup service not ready")
}
checkCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
status, err := s.repo.GetTaskStatus(checkCtx, taskID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return false, nil
}
return false, err
}
if status == UsageCleanupStatusCanceled {
log.Printf("[UsageCleanup] task cancel detected: task=%d", taskID)
}
return status == UsageCleanupStatusCanceled, nil
}
func (s *UsageCleanupService) validateFilters(filters UsageCleanupFilters) error {
if filters.StartTime.IsZero() || filters.EndTime.IsZero() {
return infraerrors.BadRequest("USAGE_CLEANUP_MISSING_RANGE", "start_date and end_date are required")
}
if filters.EndTime.Before(filters.StartTime) {
return infraerrors.BadRequest("USAGE_CLEANUP_INVALID_RANGE", "end_date must be after start_date")
}
maxDays := s.maxRangeDays()
if maxDays > 0 {
delta := filters.EndTime.Sub(filters.StartTime)
if delta > time.Duration(maxDays)*24*time.Hour {
return infraerrors.BadRequest("USAGE_CLEANUP_RANGE_TOO_LARGE", fmt.Sprintf("date range exceeds %d days", maxDays))
}
}
return nil
}
func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canceledBy int64) error {
if s == nil || s.repo == nil {
return fmt.Errorf("cleanup service not ready")
}
if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
return infraerrors.New(http.StatusServiceUnavailable, "USAGE_CLEANUP_DISABLED", "usage cleanup is disabled")
}
if canceledBy <= 0 {
return infraerrors.BadRequest("USAGE_CLEANUP_INVALID_CANCELLER", "invalid canceller")
}
status, err := s.repo.GetTaskStatus(ctx, taskID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return infraerrors.New(http.StatusNotFound, "USAGE_CLEANUP_TASK_NOT_FOUND", "cleanup task not found")
}
return err
}
log.Printf("[UsageCleanup] cancel_task requested: task=%d operator=%d status=%s", taskID, canceledBy, status)
if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning {
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
}
ok, err := s.repo.CancelTask(ctx, taskID, canceledBy)
if err != nil {
return err
}
if !ok {
// 状态可能并发改变
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
}
log.Printf("[UsageCleanup] cancel_task done: task=%d operator=%d", taskID, canceledBy)
return nil
}
func sanitizeUsageCleanupFilters(filters *UsageCleanupFilters) {
if filters == nil {
return
}
if filters.UserID != nil && *filters.UserID <= 0 {
filters.UserID = nil
}
if filters.APIKeyID != nil && *filters.APIKeyID <= 0 {
filters.APIKeyID = nil
}
if filters.AccountID != nil && *filters.AccountID <= 0 {
filters.AccountID = nil
}
if filters.GroupID != nil && *filters.GroupID <= 0 {
filters.GroupID = nil
}
if filters.Model != nil {
model := strings.TrimSpace(*filters.Model)
if model == "" {
filters.Model = nil
} else {
filters.Model = &model
}
}
if filters.BillingType != nil && *filters.BillingType < 0 {
filters.BillingType = nil
}
}
func (s *UsageCleanupService) maxRangeDays() int {
if s == nil || s.cfg == nil {
return 31
}
if s.cfg.UsageCleanup.MaxRangeDays > 0 {
return s.cfg.UsageCleanup.MaxRangeDays
}
return 31
}
func (s *UsageCleanupService) batchSize() int {
if s == nil || s.cfg == nil {
return 5000
}
if s.cfg.UsageCleanup.BatchSize > 0 {
return s.cfg.UsageCleanup.BatchSize
}
return 5000
}
func (s *UsageCleanupService) workerInterval() time.Duration {
if s == nil || s.cfg == nil {
return 10 * time.Second
}
if s.cfg.UsageCleanup.WorkerIntervalSeconds > 0 {
return time.Duration(s.cfg.UsageCleanup.WorkerIntervalSeconds) * time.Second
}
return 10 * time.Second
}
func (s *UsageCleanupService) taskTimeout() time.Duration {
if s == nil || s.cfg == nil {
return 30 * time.Minute
}
if s.cfg.UsageCleanup.TaskTimeoutSeconds > 0 {
return time.Duration(s.cfg.UsageCleanup.TaskTimeoutSeconds) * time.Second
}
return 30 * time.Minute
}

View File

@@ -0,0 +1,420 @@
package service
import (
"context"
"database/sql"
"errors"
"net/http"
"strings"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type cleanupDeleteResponse struct {
deleted int64
err error
}
type cleanupDeleteCall struct {
filters UsageCleanupFilters
limit int
}
type cleanupMarkCall struct {
taskID int64
deletedRows int64
errMsg string
}
type cleanupRepoStub struct {
mu sync.Mutex
created []*UsageCleanupTask
createErr error
listTasks []UsageCleanupTask
listResult *pagination.PaginationResult
listErr error
claimQueue []*UsageCleanupTask
claimErr error
deleteQueue []cleanupDeleteResponse
deleteCalls []cleanupDeleteCall
markSucceeded []cleanupMarkCall
markFailed []cleanupMarkCall
statusByID map[int64]string
progressCalls []cleanupMarkCall
cancelCalls []int64
}
func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *UsageCleanupTask) error {
if task == nil {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
if s.createErr != nil {
return s.createErr
}
if task.ID == 0 {
task.ID = int64(len(s.created) + 1)
}
if task.CreatedAt.IsZero() {
task.CreatedAt = time.Now().UTC()
}
if task.UpdatedAt.IsZero() {
task.UpdatedAt = task.CreatedAt
}
clone := *task
s.created = append(s.created, &clone)
return nil
}
func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error) {
s.mu.Lock()
defer s.mu.Unlock()
return s.listTasks, s.listResult, s.listErr
}
func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*UsageCleanupTask, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.claimErr != nil {
return nil, s.claimErr
}
if len(s.claimQueue) == 0 {
return nil, nil
}
task := s.claimQueue[0]
s.claimQueue = s.claimQueue[1:]
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
s.statusByID[task.ID] = UsageCleanupStatusRunning
return task, nil
}
func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.statusByID == nil {
return "", sql.ErrNoRows
}
status, ok := s.statusByID[taskID]
if !ok {
return "", sql.ErrNoRows
}
return status, nil
}
func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
s.mu.Lock()
defer s.mu.Unlock()
s.progressCalls = append(s.progressCalls, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows})
return nil
}
func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.cancelCalls = append(s.cancelCalls, taskID)
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
status := s.statusByID[taskID]
if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning {
return false, nil
}
s.statusByID[taskID] = UsageCleanupStatusCanceled
return true, nil
}
func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
s.mu.Lock()
defer s.mu.Unlock()
s.markSucceeded = append(s.markSucceeded, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows})
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
s.statusByID[taskID] = UsageCleanupStatusSucceeded
return nil
}
func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
s.mu.Lock()
defer s.mu.Unlock()
s.markFailed = append(s.markFailed, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows, errMsg: errorMsg})
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
s.statusByID[taskID] = UsageCleanupStatusFailed
return nil
}
func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters UsageCleanupFilters, limit int) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.deleteCalls = append(s.deleteCalls, cleanupDeleteCall{filters: filters, limit: limit})
if len(s.deleteQueue) == 0 {
return 0, nil
}
resp := s.deleteQueue[0]
s.deleteQueue = s.deleteQueue[1:]
return resp.deleted, resp.err
}
func TestUsageCleanupServiceCreateTaskSanitizeFilters(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
userID := int64(-1)
apiKeyID := int64(10)
model := " gpt-4 "
billingType := int8(-2)
filters := UsageCleanupFilters{
StartTime: start,
EndTime: end,
UserID: &userID,
APIKeyID: &apiKeyID,
Model: &model,
BillingType: &billingType,
}
task, err := svc.CreateTask(context.Background(), filters, 9)
require.NoError(t, err)
require.Equal(t, UsageCleanupStatusPending, task.Status)
require.Nil(t, task.Filters.UserID)
require.NotNil(t, task.Filters.APIKeyID)
require.Equal(t, apiKeyID, *task.Filters.APIKeyID)
require.NotNil(t, task.Filters.Model)
require.Equal(t, "gpt-4", *task.Filters.Model)
require.Nil(t, task.Filters.BillingType)
require.Equal(t, int64(9), task.CreatedBy)
}
func TestUsageCleanupServiceCreateTaskInvalidCreator(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
filters := UsageCleanupFilters{
StartTime: time.Now(),
EndTime: time.Now().Add(24 * time.Hour),
}
_, err := svc.CreateTask(context.Background(), filters, 0)
require.Error(t, err)
require.Equal(t, "USAGE_CLEANUP_INVALID_CREATOR", infraerrors.Reason(err))
}
func TestUsageCleanupServiceCreateTaskDisabled(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
filters := UsageCleanupFilters{
StartTime: time.Now(),
EndTime: time.Now().Add(24 * time.Hour),
}
_, err := svc.CreateTask(context.Background(), filters, 1)
require.Error(t, err)
require.Equal(t, http.StatusServiceUnavailable, infraerrors.Code(err))
require.Equal(t, "USAGE_CLEANUP_DISABLED", infraerrors.Reason(err))
}
func TestUsageCleanupServiceCreateTaskRangeTooLarge(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 1}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(48 * time.Hour)
filters := UsageCleanupFilters{StartTime: start, EndTime: end}
_, err := svc.CreateTask(context.Background(), filters, 1)
require.Error(t, err)
require.Equal(t, "USAGE_CLEANUP_RANGE_TOO_LARGE", infraerrors.Reason(err))
}
func TestUsageCleanupServiceCreateTaskMissingRange(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
_, err := svc.CreateTask(context.Background(), UsageCleanupFilters{}, 1)
require.Error(t, err)
require.Equal(t, "USAGE_CLEANUP_MISSING_RANGE", infraerrors.Reason(err))
}
func TestUsageCleanupServiceCreateTaskRepoError(t *testing.T) {
repo := &cleanupRepoStub{createErr: errors.New("db down")}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
filters := UsageCleanupFilters{
StartTime: time.Now(),
EndTime: time.Now().Add(24 * time.Hour),
}
_, err := svc.CreateTask(context.Background(), filters, 1)
require.Error(t, err)
require.Contains(t, err.Error(), "create cleanup task")
}
func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
repo := &cleanupRepoStub{
claimQueue: []*UsageCleanupTask{
{ID: 5, Filters: UsageCleanupFilters{StartTime: time.Now(), EndTime: time.Now().Add(2 * time.Hour)}},
},
deleteQueue: []cleanupDeleteResponse{
{deleted: 2},
{deleted: 2},
{deleted: 1},
},
}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2, TaskTimeoutSeconds: 30}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
svc.runOnce()
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.deleteCalls, 3)
require.Len(t, repo.markSucceeded, 1)
require.Empty(t, repo.markFailed)
require.Equal(t, int64(5), repo.markSucceeded[0].taskID)
require.Equal(t, int64(5), repo.markSucceeded[0].deletedRows)
}
func TestUsageCleanupServiceRunOnceClaimError(t *testing.T) {
repo := &cleanupRepoStub{claimErr: errors.New("claim failed")}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
svc.runOnce()
repo.mu.Lock()
defer repo.mu.Unlock()
require.Empty(t, repo.markSucceeded)
require.Empty(t, repo.markFailed)
}
func TestUsageCleanupServiceRunOnceAlreadyRunning(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
svc.running = 1
svc.runOnce()
}
func TestUsageCleanupServiceExecuteTaskFailed(t *testing.T) {
longMsg := strings.Repeat("x", 600)
repo := &cleanupRepoStub{
deleteQueue: []cleanupDeleteResponse{
{err: errors.New(longMsg)},
},
}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 3}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
task := &UsageCleanupTask{
ID: 11,
Filters: UsageCleanupFilters{
StartTime: time.Now(),
EndTime: time.Now().Add(24 * time.Hour),
},
}
svc.executeTask(context.Background(), task)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.markFailed, 1)
require.Equal(t, int64(11), repo.markFailed[0].taskID)
require.Equal(t, 500, len(repo.markFailed[0].errMsg))
}
func TestUsageCleanupServiceListTasks(t *testing.T) {
repo := &cleanupRepoStub{
listTasks: []UsageCleanupTask{{ID: 1}, {ID: 2}},
listResult: &pagination.PaginationResult{
Total: 2,
Page: 1,
PageSize: 20,
Pages: 1,
},
}
svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
tasks, result, err := svc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.NoError(t, err)
require.Len(t, tasks, 2)
require.Equal(t, int64(2), result.Total)
}
func TestUsageCleanupServiceListTasksNotReady(t *testing.T) {
var nilSvc *UsageCleanupService
_, _, err := nilSvc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.Error(t, err)
svc := NewUsageCleanupService(nil, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
_, _, err = svc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.Error(t, err)
}
func TestUsageCleanupServiceDefaultsAndLifecycle(t *testing.T) {
var nilSvc *UsageCleanupService
require.Equal(t, 31, nilSvc.maxRangeDays())
require.Equal(t, 5000, nilSvc.batchSize())
require.Equal(t, 10*time.Second, nilSvc.workerInterval())
require.Equal(t, 30*time.Minute, nilSvc.taskTimeout())
nilSvc.Start()
nilSvc.Stop()
repo := &cleanupRepoStub{}
cfgDisabled := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}}
svcDisabled := NewUsageCleanupService(repo, nil, nil, cfgDisabled)
svcDisabled.Start()
svcDisabled.Stop()
timingWheel, err := NewTimingWheelService()
require.NoError(t, err)
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, WorkerIntervalSeconds: 5}}
svc := NewUsageCleanupService(repo, timingWheel, nil, cfg)
require.Equal(t, 5*time.Second, svc.workerInterval())
svc.Start()
svc.Stop()
cfgFallback := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svcFallback := NewUsageCleanupService(repo, timingWheel, nil, cfgFallback)
require.Equal(t, 31, svcFallback.maxRangeDays())
require.Equal(t, 5000, svcFallback.batchSize())
require.Equal(t, 10*time.Second, svcFallback.workerInterval())
svcMissingDeps := NewUsageCleanupService(nil, nil, nil, cfgFallback)
svcMissingDeps.Start()
}
func TestSanitizeUsageCleanupFiltersModelEmpty(t *testing.T) {
model := " "
apiKeyID := int64(-5)
accountID := int64(-1)
groupID := int64(-2)
filters := UsageCleanupFilters{
UserID: &apiKeyID,
APIKeyID: &apiKeyID,
AccountID: &accountID,
GroupID: &groupID,
Model: &model,
}
sanitizeUsageCleanupFilters(&filters)
require.Nil(t, filters.UserID)
require.Nil(t, filters.APIKeyID)
require.Nil(t, filters.AccountID)
require.Nil(t, filters.GroupID)
require.Nil(t, filters.Model)
}

View File

@@ -57,6 +57,13 @@ func ProvideDashboardAggregationService(repo DashboardAggregationRepository, tim
return svc return svc
} }
// ProvideUsageCleanupService 创建并启动使用记录清理任务服务
func ProvideUsageCleanupService(repo UsageCleanupRepository, timingWheel *TimingWheelService, dashboardAgg *DashboardAggregationService, cfg *config.Config) *UsageCleanupService {
svc := NewUsageCleanupService(repo, timingWheel, dashboardAgg, cfg)
svc.Start()
return svc
}
// ProvideAccountExpiryService creates and starts AccountExpiryService. // ProvideAccountExpiryService creates and starts AccountExpiryService.
func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService { func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService {
svc := NewAccountExpiryService(accountRepo, time.Minute) svc := NewAccountExpiryService(accountRepo, time.Minute)
@@ -248,6 +255,7 @@ var ProviderSet = wire.NewSet(
ProvideAccountExpiryService, ProvideAccountExpiryService,
ProvideTimingWheelService, ProvideTimingWheelService,
ProvideDashboardAggregationService, ProvideDashboardAggregationService,
ProvideUsageCleanupService,
ProvideDeferredService, ProvideDeferredService,
NewAntigravityQuotaFetcher, NewAntigravityQuotaFetcher,
NewUserAttributeService, NewUserAttributeService,

View File

@@ -0,0 +1,21 @@
-- 042_add_usage_cleanup_tasks.sql
-- 使用记录清理任务表
CREATE TABLE IF NOT EXISTS usage_cleanup_tasks (
id BIGSERIAL PRIMARY KEY,
status VARCHAR(20) NOT NULL,
filters JSONB NOT NULL,
created_by BIGINT NOT NULL REFERENCES users(id) ON DELETE RESTRICT,
deleted_rows BIGINT NOT NULL DEFAULT 0,
error_message TEXT,
started_at TIMESTAMPTZ,
finished_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_usage_cleanup_tasks_status_created_at
ON usage_cleanup_tasks(status, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_usage_cleanup_tasks_created_at
ON usage_cleanup_tasks(created_at DESC);

View File

@@ -0,0 +1,10 @@
-- 043_add_usage_cleanup_cancel_audit.sql
-- usage_cleanup_tasks 取消任务审计字段
ALTER TABLE usage_cleanup_tasks
ADD COLUMN IF NOT EXISTS canceled_by BIGINT REFERENCES users(id) ON DELETE SET NULL,
ADD COLUMN IF NOT EXISTS canceled_at TIMESTAMPTZ;
CREATE INDEX IF NOT EXISTS idx_usage_cleanup_tasks_canceled_at
ON usage_cleanup_tasks(canceled_at DESC);

View File

@@ -251,6 +251,27 @@ dashboard_aggregation:
# 日聚合保留天数 # 日聚合保留天数
daily_days: 730 daily_days: 730
# =============================================================================
# Usage Cleanup Task Configuration
# 使用记录清理任务配置(重启生效)
# =============================================================================
usage_cleanup:
# Enable cleanup task worker
# 启用清理任务执行器
enabled: true
# Max date range (days) per task
# 单次任务最大时间跨度(天)
max_range_days: 31
# Batch delete size
# 单批删除数量
batch_size: 5000
# Worker interval (seconds)
# 执行器轮询间隔(秒)
worker_interval_seconds: 10
# Task execution timeout (seconds)
# 单次任务最大执行时长(秒)
task_timeout_seconds: 1800
# ============================================================================= # =============================================================================
# Concurrency Wait Configuration # Concurrency Wait Configuration
# 并发等待配置 # 并发等待配置

View File

@@ -292,6 +292,27 @@ dashboard_aggregation:
# 日聚合保留天数 # 日聚合保留天数
daily_days: 730 daily_days: 730
# =============================================================================
# Usage Cleanup Task Configuration
# 使用记录清理任务配置(重启生效)
# =============================================================================
usage_cleanup:
# Enable cleanup task worker
# 启用清理任务执行器
enabled: true
# Max date range (days) per task
# 单次任务最大时间跨度(天)
max_range_days: 31
# Batch delete size
# 单批删除数量
batch_size: 5000
# Worker interval (seconds)
# 执行器轮询间隔(秒)
worker_interval_seconds: 10
# Task execution timeout (seconds)
# 单次任务最大执行时长(秒)
task_timeout_seconds: 1800
# ============================================================================= # =============================================================================
# Concurrency Wait Configuration # Concurrency Wait Configuration
# 并发等待配置 # 并发等待配置

View File

@@ -50,6 +50,7 @@ export interface TrendParams {
account_id?: number account_id?: number
group_id?: number group_id?: number
stream?: boolean stream?: boolean
billing_type?: number | null
} }
export interface TrendResponse { export interface TrendResponse {
@@ -78,6 +79,7 @@ export interface ModelStatsParams {
account_id?: number account_id?: number
group_id?: number group_id?: number
stream?: boolean stream?: boolean
billing_type?: number | null
} }
export interface ModelStatsResponse { export interface ModelStatsResponse {

View File

@@ -31,6 +31,46 @@ export interface SimpleApiKey {
user_id: number user_id: number
} }
export interface UsageCleanupFilters {
start_time: string
end_time: string
user_id?: number
api_key_id?: number
account_id?: number
group_id?: number
model?: string | null
stream?: boolean | null
billing_type?: number | null
}
export interface UsageCleanupTask {
id: number
status: string
filters: UsageCleanupFilters
created_by: number
deleted_rows: number
error_message?: string | null
canceled_by?: number | null
canceled_at?: string | null
started_at?: string | null
finished_at?: string | null
created_at: string
updated_at: string
}
export interface CreateUsageCleanupTaskRequest {
start_date: string
end_date: string
user_id?: number
api_key_id?: number
account_id?: number
group_id?: number
model?: string | null
stream?: boolean | null
billing_type?: number | null
timezone?: string
}
export interface AdminUsageQueryParams extends UsageQueryParams { export interface AdminUsageQueryParams extends UsageQueryParams {
user_id?: number user_id?: number
} }
@@ -108,11 +148,51 @@ export async function searchApiKeys(userId?: number, keyword?: string): Promise<
return data return data
} }
/**
* List usage cleanup tasks (admin only)
* @param params - Query parameters for pagination
* @returns Paginated list of cleanup tasks
*/
export async function listCleanupTasks(
params: { page?: number; page_size?: number },
options?: { signal?: AbortSignal }
): Promise<PaginatedResponse<UsageCleanupTask>> {
const { data } = await apiClient.get<PaginatedResponse<UsageCleanupTask>>('/admin/usage/cleanup-tasks', {
params,
signal: options?.signal
})
return data
}
/**
* Create a usage cleanup task (admin only)
* @param payload - Cleanup task parameters
* @returns Created cleanup task
*/
export async function createCleanupTask(payload: CreateUsageCleanupTaskRequest): Promise<UsageCleanupTask> {
const { data } = await apiClient.post<UsageCleanupTask>('/admin/usage/cleanup-tasks', payload)
return data
}
/**
* Cancel a usage cleanup task (admin only)
* @param taskId - Task ID to cancel
*/
export async function cancelCleanupTask(taskId: number): Promise<{ id: number; status: string }> {
const { data } = await apiClient.post<{ id: number; status: string }>(
`/admin/usage/cleanup-tasks/${taskId}/cancel`
)
return data
}
export const adminUsageAPI = { export const adminUsageAPI = {
list, list,
getStats, getStats,
searchUsers, searchUsers,
searchApiKeys searchApiKeys,
listCleanupTasks,
createCleanupTask,
cancelCleanupTask
} }
export default adminUsageAPI export default adminUsageAPI

View File

@@ -0,0 +1,339 @@
<template>
<BaseDialog :show="show" :title="t('admin.usage.cleanup.title')" width="wide" @close="handleClose">
<div class="space-y-4">
<UsageFilters
v-model="localFilters"
v-model:startDate="localStartDate"
v-model:endDate="localEndDate"
:exporting="false"
:show-actions="false"
@change="noop"
/>
<div class="rounded-xl border border-amber-200 bg-amber-50 px-4 py-3 text-sm text-amber-700 dark:border-amber-500/30 dark:bg-amber-500/10 dark:text-amber-200">
{{ t('admin.usage.cleanup.warning') }}
</div>
<div class="rounded-xl border border-gray-200 p-4 dark:border-dark-700">
<div class="flex items-center justify-between">
<h4 class="text-sm font-semibold text-gray-700 dark:text-gray-200">
{{ t('admin.usage.cleanup.recentTasks') }}
</h4>
<button type="button" class="btn btn-ghost btn-sm" @click="loadTasks">
{{ t('common.refresh') }}
</button>
</div>
<div class="mt-3 space-y-2">
<div v-if="tasksLoading" class="text-sm text-gray-500 dark:text-gray-400">
{{ t('admin.usage.cleanup.loadingTasks') }}
</div>
<div v-else-if="tasks.length === 0" class="text-sm text-gray-500 dark:text-gray-400">
{{ t('admin.usage.cleanup.noTasks') }}
</div>
<div v-else class="space-y-2">
<div
v-for="task in tasks"
:key="task.id"
class="flex flex-col gap-2 rounded-lg border border-gray-100 px-3 py-2 text-sm text-gray-600 dark:border-dark-700 dark:text-gray-300"
>
<div class="flex flex-wrap items-center justify-between gap-2">
<div class="flex items-center gap-2">
<span :class="statusClass(task.status)" class="rounded-full px-2 py-0.5 text-xs font-semibold">
{{ statusLabel(task.status) }}
</span>
<span class="text-xs text-gray-400">#{{ task.id }}</span>
<button
v-if="canCancel(task)"
type="button"
class="btn btn-ghost btn-xs text-rose-600 hover:text-rose-700 dark:text-rose-300"
@click="openCancelConfirm(task)"
>
{{ t('admin.usage.cleanup.cancel') }}
</button>
</div>
<div class="text-xs text-gray-400">
{{ formatDateTime(task.created_at) }}
</div>
</div>
<div class="flex flex-wrap items-center gap-4 text-xs text-gray-500 dark:text-gray-400">
<span>{{ t('admin.usage.cleanup.range') }}: {{ formatRange(task) }}</span>
<span>{{ t('admin.usage.cleanup.deletedRows') }}: {{ task.deleted_rows.toLocaleString() }}</span>
</div>
<div v-if="task.error_message" class="text-xs text-rose-500">
{{ task.error_message }}
</div>
</div>
</div>
</div>
</div>
</div>
<template #footer>
<div class="flex justify-end gap-3">
<button type="button" class="btn btn-secondary" @click="handleClose">
{{ t('common.cancel') }}
</button>
<button type="button" class="btn btn-danger" :disabled="submitting" @click="openConfirm">
{{ submitting ? t('admin.usage.cleanup.submitting') : t('admin.usage.cleanup.submit') }}
</button>
</div>
</template>
</BaseDialog>
<ConfirmDialog
:show="confirmVisible"
:title="t('admin.usage.cleanup.confirmTitle')"
:message="t('admin.usage.cleanup.confirmMessage')"
:confirm-text="t('admin.usage.cleanup.confirmSubmit')"
danger
@confirm="submitCleanup"
@cancel="confirmVisible = false"
/>
<ConfirmDialog
:show="cancelConfirmVisible"
:title="t('admin.usage.cleanup.cancelConfirmTitle')"
:message="t('admin.usage.cleanup.cancelConfirmMessage')"
:confirm-text="t('admin.usage.cleanup.cancelConfirm')"
danger
@confirm="cancelTask"
@cancel="cancelConfirmVisible = false"
/>
</template>
<script setup lang="ts">
import { ref, watch, onUnmounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import UsageFilters from '@/components/admin/usage/UsageFilters.vue'
import { adminUsageAPI } from '@/api/admin/usage'
import type { AdminUsageQueryParams, UsageCleanupTask, CreateUsageCleanupTaskRequest } from '@/api/admin/usage'
interface Props {
show: boolean
filters: AdminUsageQueryParams
startDate: string
endDate: string
}
const props = defineProps<Props>()
const emit = defineEmits(['close'])
const { t } = useI18n()
const appStore = useAppStore()
const localFilters = ref<AdminUsageQueryParams>({})
const localStartDate = ref('')
const localEndDate = ref('')
const tasks = ref<UsageCleanupTask[]>([])
const tasksLoading = ref(false)
const submitting = ref(false)
const confirmVisible = ref(false)
const cancelConfirmVisible = ref(false)
const canceling = ref(false)
const cancelTarget = ref<UsageCleanupTask | null>(null)
let pollTimer: number | null = null
const noop = () => {}
const resetFilters = () => {
localFilters.value = { ...props.filters }
localStartDate.value = props.startDate
localEndDate.value = props.endDate
localFilters.value.start_date = localStartDate.value
localFilters.value.end_date = localEndDate.value
}
const startPolling = () => {
stopPolling()
pollTimer = window.setInterval(() => {
loadTasks()
}, 10000)
}
const stopPolling = () => {
if (pollTimer !== null) {
window.clearInterval(pollTimer)
pollTimer = null
}
}
const handleClose = () => {
stopPolling()
confirmVisible.value = false
cancelConfirmVisible.value = false
canceling.value = false
cancelTarget.value = null
submitting.value = false
emit('close')
}
const statusLabel = (status: string) => {
const map: Record<string, string> = {
pending: t('admin.usage.cleanup.status.pending'),
running: t('admin.usage.cleanup.status.running'),
succeeded: t('admin.usage.cleanup.status.succeeded'),
failed: t('admin.usage.cleanup.status.failed'),
canceled: t('admin.usage.cleanup.status.canceled')
}
return map[status] || status
}
const statusClass = (status: string) => {
const map: Record<string, string> = {
pending: 'bg-amber-100 text-amber-700 dark:bg-amber-500/20 dark:text-amber-200',
running: 'bg-blue-100 text-blue-700 dark:bg-blue-500/20 dark:text-blue-200',
succeeded: 'bg-emerald-100 text-emerald-700 dark:bg-emerald-500/20 dark:text-emerald-200',
failed: 'bg-rose-100 text-rose-700 dark:bg-rose-500/20 dark:text-rose-200',
canceled: 'bg-gray-200 text-gray-600 dark:bg-dark-600 dark:text-gray-300'
}
return map[status] || 'bg-gray-100 text-gray-600'
}
const formatDateTime = (value?: string | null) => {
if (!value) return '--'
const date = new Date(value)
if (Number.isNaN(date.getTime())) return value
return date.toLocaleString()
}
const formatRange = (task: UsageCleanupTask) => {
const start = formatDateTime(task.filters.start_time)
const end = formatDateTime(task.filters.end_time)
return `${start} ~ ${end}`
}
const getUserTimezone = () => {
try {
return Intl.DateTimeFormat().resolvedOptions().timeZone
} catch {
return 'UTC'
}
}
const loadTasks = async () => {
if (!props.show) return
tasksLoading.value = true
try {
const res = await adminUsageAPI.listCleanupTasks({ page: 1, page_size: 10 })
tasks.value = res.items || []
} catch (error) {
console.error('Failed to load cleanup tasks:', error)
appStore.showError(t('admin.usage.cleanup.loadFailed'))
} finally {
tasksLoading.value = false
}
}
const openConfirm = () => {
confirmVisible.value = true
}
const canCancel = (task: UsageCleanupTask) => {
return task.status === 'pending' || task.status === 'running'
}
const openCancelConfirm = (task: UsageCleanupTask) => {
cancelTarget.value = task
cancelConfirmVisible.value = true
}
const buildPayload = (): CreateUsageCleanupTaskRequest | null => {
if (!localStartDate.value || !localEndDate.value) {
appStore.showError(t('admin.usage.cleanup.missingRange'))
return null
}
const payload: CreateUsageCleanupTaskRequest = {
start_date: localStartDate.value,
end_date: localEndDate.value,
timezone: getUserTimezone()
}
if (localFilters.value.user_id && localFilters.value.user_id > 0) {
payload.user_id = localFilters.value.user_id
}
if (localFilters.value.api_key_id && localFilters.value.api_key_id > 0) {
payload.api_key_id = localFilters.value.api_key_id
}
if (localFilters.value.account_id && localFilters.value.account_id > 0) {
payload.account_id = localFilters.value.account_id
}
if (localFilters.value.group_id && localFilters.value.group_id > 0) {
payload.group_id = localFilters.value.group_id
}
if (localFilters.value.model) {
payload.model = localFilters.value.model
}
if (localFilters.value.stream !== null && localFilters.value.stream !== undefined) {
payload.stream = localFilters.value.stream
}
if (localFilters.value.billing_type !== null && localFilters.value.billing_type !== undefined) {
payload.billing_type = localFilters.value.billing_type
}
return payload
}
const submitCleanup = async () => {
const payload = buildPayload()
if (!payload) {
confirmVisible.value = false
return
}
submitting.value = true
confirmVisible.value = false
try {
await adminUsageAPI.createCleanupTask(payload)
appStore.showSuccess(t('admin.usage.cleanup.submitSuccess'))
loadTasks()
} catch (error) {
console.error('Failed to create cleanup task:', error)
appStore.showError(t('admin.usage.cleanup.submitFailed'))
} finally {
submitting.value = false
}
}
const cancelTask = async () => {
const task = cancelTarget.value
if (!task) {
cancelConfirmVisible.value = false
return
}
canceling.value = true
cancelConfirmVisible.value = false
try {
await adminUsageAPI.cancelCleanupTask(task.id)
appStore.showSuccess(t('admin.usage.cleanup.cancelSuccess'))
loadTasks()
} catch (error) {
console.error('Failed to cancel cleanup task:', error)
appStore.showError(t('admin.usage.cleanup.cancelFailed'))
} finally {
canceling.value = false
cancelTarget.value = null
}
}
watch(
() => props.show,
(show) => {
if (show) {
resetFilters()
loadTasks()
startPolling()
} else {
stopPolling()
}
}
)
onUnmounted(() => {
stopPolling()
})
</script>

View File

@@ -127,6 +127,12 @@
<Select v-model="filters.stream" :options="streamTypeOptions" @change="emitChange" /> <Select v-model="filters.stream" :options="streamTypeOptions" @change="emitChange" />
</div> </div>
<!-- Billing Type Filter -->
<div class="w-full sm:w-auto sm:min-w-[200px]">
<label class="input-label">{{ t('admin.usage.billingType') }}</label>
<Select v-model="filters.billing_type" :options="billingTypeOptions" @change="emitChange" />
</div>
<!-- Group Filter --> <!-- Group Filter -->
<div class="w-full sm:w-auto sm:min-w-[200px]"> <div class="w-full sm:w-auto sm:min-w-[200px]">
<label class="input-label">{{ t('admin.usage.group') }}</label> <label class="input-label">{{ t('admin.usage.group') }}</label>
@@ -147,10 +153,13 @@
</div> </div>
<!-- Right: actions --> <!-- Right: actions -->
<div class="flex w-full flex-wrap items-center justify-end gap-3 sm:w-auto"> <div v-if="showActions" class="flex w-full flex-wrap items-center justify-end gap-3 sm:w-auto">
<button type="button" @click="$emit('reset')" class="btn btn-secondary"> <button type="button" @click="$emit('reset')" class="btn btn-secondary">
{{ t('common.reset') }} {{ t('common.reset') }}
</button> </button>
<button type="button" @click="$emit('cleanup')" class="btn btn-danger">
{{ t('admin.usage.cleanup.button') }}
</button>
<button type="button" @click="$emit('export')" :disabled="exporting" class="btn btn-primary"> <button type="button" @click="$emit('export')" :disabled="exporting" class="btn btn-primary">
{{ t('usage.exportExcel') }} {{ t('usage.exportExcel') }}
</button> </button>
@@ -174,16 +183,20 @@ interface Props {
exporting: boolean exporting: boolean
startDate: string startDate: string
endDate: string endDate: string
showActions?: boolean
} }
const props = defineProps<Props>() const props = withDefaults(defineProps<Props>(), {
showActions: true
})
const emit = defineEmits([ const emit = defineEmits([
'update:modelValue', 'update:modelValue',
'update:startDate', 'update:startDate',
'update:endDate', 'update:endDate',
'change', 'change',
'reset', 'reset',
'export' 'export',
'cleanup'
]) ])
const { t } = useI18n() const { t } = useI18n()
@@ -221,6 +234,12 @@ const streamTypeOptions = ref<SelectOption[]>([
{ value: false, label: t('usage.sync') } { value: false, label: t('usage.sync') }
]) ])
const billingTypeOptions = ref<SelectOption[]>([
{ value: null, label: t('admin.usage.allBillingTypes') },
{ value: 0, label: t('admin.usage.billingTypeBalance') },
{ value: 1, label: t('admin.usage.billingTypeSubscription') }
])
const emitChange = () => emit('change') const emitChange = () => emit('change')
const updateStartDate = (value: string) => { const updateStartDate = (value: string) => {

View File

@@ -1893,7 +1893,43 @@ export default {
cacheCreationTokens: 'Cache Creation Tokens', cacheCreationTokens: 'Cache Creation Tokens',
cacheReadTokens: 'Cache Read Tokens', cacheReadTokens: 'Cache Read Tokens',
failedToLoad: 'Failed to load usage records', failedToLoad: 'Failed to load usage records',
ipAddress: 'IP' billingType: 'Billing Type',
allBillingTypes: 'All Billing Types',
billingTypeBalance: 'Balance',
billingTypeSubscription: 'Subscription',
ipAddress: 'IP',
cleanup: {
button: 'Cleanup',
title: 'Cleanup Usage Records',
warning: 'Cleanup is irreversible and will affect historical stats.',
submit: 'Submit Cleanup',
submitting: 'Submitting...',
confirmTitle: 'Confirm Cleanup',
confirmMessage: 'Are you sure you want to submit this cleanup task? This action cannot be undone.',
confirmSubmit: 'Confirm Cleanup',
cancel: 'Cancel',
cancelConfirmTitle: 'Confirm Cancel',
cancelConfirmMessage: 'Are you sure you want to cancel this cleanup task?',
cancelConfirm: 'Confirm Cancel',
cancelSuccess: 'Cleanup task canceled',
cancelFailed: 'Failed to cancel cleanup task',
recentTasks: 'Recent Cleanup Tasks',
loadingTasks: 'Loading tasks...',
noTasks: 'No cleanup tasks yet',
range: 'Range',
deletedRows: 'Deleted',
missingRange: 'Please select a date range',
submitSuccess: 'Cleanup task created',
submitFailed: 'Failed to create cleanup task',
loadFailed: 'Failed to load cleanup tasks',
status: {
pending: 'Pending',
running: 'Running',
succeeded: 'Succeeded',
failed: 'Failed',
canceled: 'Canceled'
}
}
}, },
// Ops Monitoring // Ops Monitoring

View File

@@ -2041,7 +2041,43 @@ export default {
cacheCreationTokens: '缓存创建 Token', cacheCreationTokens: '缓存创建 Token',
cacheReadTokens: '缓存读取 Token', cacheReadTokens: '缓存读取 Token',
failedToLoad: '加载使用记录失败', failedToLoad: '加载使用记录失败',
ipAddress: 'IP' billingType: '计费类型',
allBillingTypes: '全部计费类型',
billingTypeBalance: '钱包余额',
billingTypeSubscription: '订阅套餐',
ipAddress: 'IP',
cleanup: {
button: '清理',
title: '清理使用记录',
warning: '清理不可恢复,且会影响历史统计回看。',
submit: '提交清理',
submitting: '提交中...',
confirmTitle: '确认清理',
confirmMessage: '确定要提交清理任务吗?清理不可恢复。',
confirmSubmit: '确认清理',
cancel: '取消任务',
cancelConfirmTitle: '确认取消',
cancelConfirmMessage: '确定要取消该清理任务吗?',
cancelConfirm: '确认取消',
cancelSuccess: '清理任务已取消',
cancelFailed: '取消清理任务失败',
recentTasks: '最近清理任务',
loadingTasks: '正在加载任务...',
noTasks: '暂无清理任务',
range: '时间范围',
deletedRows: '删除数量',
missingRange: '请选择时间范围',
submitSuccess: '清理任务已创建',
submitFailed: '创建清理任务失败',
loadFailed: '加载清理任务失败',
status: {
pending: '待执行',
running: '执行中',
succeeded: '已完成',
failed: '失败',
canceled: '已取消'
}
}
}, },
// Ops Monitoring // Ops Monitoring

View File

@@ -618,6 +618,7 @@ export interface UsageLog {
actual_cost: number actual_cost: number
rate_multiplier: number rate_multiplier: number
account_rate_multiplier?: number | null account_rate_multiplier?: number | null
billing_type: number
stream: boolean stream: boolean
duration_ms: number duration_ms: number
@@ -642,6 +643,33 @@ export interface UsageLog {
subscription?: UserSubscription subscription?: UserSubscription
} }
export interface UsageCleanupFilters {
start_time: string
end_time: string
user_id?: number
api_key_id?: number
account_id?: number
group_id?: number
model?: string | null
stream?: boolean | null
billing_type?: number | null
}
export interface UsageCleanupTask {
id: number
status: string
filters: UsageCleanupFilters
created_by: number
deleted_rows: number
error_message?: string | null
canceled_by?: number | null
canceled_at?: string | null
started_at?: string | null
finished_at?: string | null
created_at: string
updated_at: string
}
export interface RedeemCode { export interface RedeemCode {
id: number id: number
code: string code: string
@@ -865,6 +893,7 @@ export interface UsageQueryParams {
group_id?: number group_id?: number
model?: string model?: string
stream?: boolean stream?: boolean
billing_type?: number | null
start_date?: string start_date?: string
end_date?: string end_date?: string
} }

View File

@@ -17,12 +17,19 @@
<TokenUsageTrend :trend-data="trendData" :loading="chartsLoading" /> <TokenUsageTrend :trend-data="trendData" :loading="chartsLoading" />
</div> </div>
</div> </div>
<UsageFilters v-model="filters" v-model:startDate="startDate" v-model:endDate="endDate" :exporting="exporting" @change="applyFilters" @reset="resetFilters" @export="exportToExcel" /> <UsageFilters v-model="filters" v-model:startDate="startDate" v-model:endDate="endDate" :exporting="exporting" @change="applyFilters" @reset="resetFilters" @cleanup="openCleanupDialog" @export="exportToExcel" />
<UsageTable :data="usageLogs" :loading="loading" /> <UsageTable :data="usageLogs" :loading="loading" />
<Pagination v-if="pagination.total > 0" :page="pagination.page" :total="pagination.total" :page-size="pagination.page_size" @update:page="handlePageChange" @update:pageSize="handlePageSizeChange" /> <Pagination v-if="pagination.total > 0" :page="pagination.page" :total="pagination.total" :page-size="pagination.page_size" @update:page="handlePageChange" @update:pageSize="handlePageSizeChange" />
</div> </div>
</AppLayout> </AppLayout>
<UsageExportProgress :show="exportProgress.show" :progress="exportProgress.progress" :current="exportProgress.current" :total="exportProgress.total" :estimated-time="exportProgress.estimatedTime" @cancel="cancelExport" /> <UsageExportProgress :show="exportProgress.show" :progress="exportProgress.progress" :current="exportProgress.current" :total="exportProgress.total" :estimated-time="exportProgress.estimatedTime" @cancel="cancelExport" />
<UsageCleanupDialog
:show="cleanupDialogVisible"
:filters="filters"
:start-date="startDate"
:end-date="endDate"
@close="cleanupDialogVisible = false"
/>
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
@@ -33,6 +40,7 @@ import { useAppStore } from '@/stores/app'; import { adminAPI } from '@/api/admi
import AppLayout from '@/components/layout/AppLayout.vue'; import Pagination from '@/components/common/Pagination.vue'; import Select from '@/components/common/Select.vue' import AppLayout from '@/components/layout/AppLayout.vue'; import Pagination from '@/components/common/Pagination.vue'; import Select from '@/components/common/Select.vue'
import UsageStatsCards from '@/components/admin/usage/UsageStatsCards.vue'; import UsageFilters from '@/components/admin/usage/UsageFilters.vue' import UsageStatsCards from '@/components/admin/usage/UsageStatsCards.vue'; import UsageFilters from '@/components/admin/usage/UsageFilters.vue'
import UsageTable from '@/components/admin/usage/UsageTable.vue'; import UsageExportProgress from '@/components/admin/usage/UsageExportProgress.vue' import UsageTable from '@/components/admin/usage/UsageTable.vue'; import UsageExportProgress from '@/components/admin/usage/UsageExportProgress.vue'
import UsageCleanupDialog from '@/components/admin/usage/UsageCleanupDialog.vue'
import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue'; import TokenUsageTrend from '@/components/charts/TokenUsageTrend.vue' import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue'; import TokenUsageTrend from '@/components/charts/TokenUsageTrend.vue'
import type { UsageLog, TrendDataPoint, ModelStat } from '@/types'; import type { AdminUsageStatsResponse, AdminUsageQueryParams } from '@/api/admin/usage' import type { UsageLog, TrendDataPoint, ModelStat } from '@/types'; import type { AdminUsageStatsResponse, AdminUsageQueryParams } from '@/api/admin/usage'
@@ -42,6 +50,7 @@ const usageStats = ref<AdminUsageStatsResponse | null>(null); const usageLogs =
const trendData = ref<TrendDataPoint[]>([]); const modelStats = ref<ModelStat[]>([]); const chartsLoading = ref(false); const granularity = ref<'day' | 'hour'>('day') const trendData = ref<TrendDataPoint[]>([]); const modelStats = ref<ModelStat[]>([]); const chartsLoading = ref(false); const granularity = ref<'day' | 'hour'>('day')
let abortController: AbortController | null = null; let exportAbortController: AbortController | null = null let abortController: AbortController | null = null; let exportAbortController: AbortController | null = null
const exportProgress = reactive({ show: false, progress: 0, current: 0, total: 0, estimatedTime: '' }) const exportProgress = reactive({ show: false, progress: 0, current: 0, total: 0, estimatedTime: '' })
const cleanupDialogVisible = ref(false)
const granularityOptions = computed(() => [{ value: 'day', label: t('admin.dashboard.day') }, { value: 'hour', label: t('admin.dashboard.hour') }]) const granularityOptions = computed(() => [{ value: 'day', label: t('admin.dashboard.day') }, { value: 'hour', label: t('admin.dashboard.hour') }])
// Use local timezone to avoid UTC timezone issues // Use local timezone to avoid UTC timezone issues
@@ -53,7 +62,7 @@ const formatLD = (d: Date) => {
} }
const now = new Date(); const weekAgo = new Date(); weekAgo.setDate(weekAgo.getDate() - 6) const now = new Date(); const weekAgo = new Date(); weekAgo.setDate(weekAgo.getDate() - 6)
const startDate = ref(formatLD(weekAgo)); const endDate = ref(formatLD(now)) const startDate = ref(formatLD(weekAgo)); const endDate = ref(formatLD(now))
const filters = ref<AdminUsageQueryParams>({ user_id: undefined, model: undefined, group_id: undefined, start_date: startDate.value, end_date: endDate.value }) const filters = ref<AdminUsageQueryParams>({ user_id: undefined, model: undefined, group_id: undefined, billing_type: null, start_date: startDate.value, end_date: endDate.value })
const pagination = reactive({ page: 1, page_size: 20, total: 0 }) const pagination = reactive({ page: 1, page_size: 20, total: 0 })
const loadLogs = async () => { const loadLogs = async () => {
@@ -67,16 +76,17 @@ const loadStats = async () => { try { const s = await adminAPI.usage.getStats(fi
const loadChartData = async () => { const loadChartData = async () => {
chartsLoading.value = true chartsLoading.value = true
try { try {
const params = { start_date: filters.value.start_date || startDate.value, end_date: filters.value.end_date || endDate.value, granularity: granularity.value, user_id: filters.value.user_id, model: filters.value.model, api_key_id: filters.value.api_key_id, account_id: filters.value.account_id, group_id: filters.value.group_id, stream: filters.value.stream } const params = { start_date: filters.value.start_date || startDate.value, end_date: filters.value.end_date || endDate.value, granularity: granularity.value, user_id: filters.value.user_id, model: filters.value.model, api_key_id: filters.value.api_key_id, account_id: filters.value.account_id, group_id: filters.value.group_id, stream: filters.value.stream, billing_type: filters.value.billing_type }
const [trendRes, modelRes] = await Promise.all([adminAPI.dashboard.getUsageTrend(params), adminAPI.dashboard.getModelStats({ start_date: params.start_date, end_date: params.end_date, user_id: params.user_id, model: params.model, api_key_id: params.api_key_id, account_id: params.account_id, group_id: params.group_id, stream: params.stream })]) const [trendRes, modelRes] = await Promise.all([adminAPI.dashboard.getUsageTrend(params), adminAPI.dashboard.getModelStats({ start_date: params.start_date, end_date: params.end_date, user_id: params.user_id, model: params.model, api_key_id: params.api_key_id, account_id: params.account_id, group_id: params.group_id, stream: params.stream, billing_type: params.billing_type })])
trendData.value = trendRes.trend || []; modelStats.value = modelRes.models || [] trendData.value = trendRes.trend || []; modelStats.value = modelRes.models || []
} catch (error) { console.error('Failed to load chart data:', error) } finally { chartsLoading.value = false } } catch (error) { console.error('Failed to load chart data:', error) } finally { chartsLoading.value = false }
} }
const applyFilters = () => { pagination.page = 1; loadLogs(); loadStats(); loadChartData() } const applyFilters = () => { pagination.page = 1; loadLogs(); loadStats(); loadChartData() }
const resetFilters = () => { startDate.value = formatLD(weekAgo); endDate.value = formatLD(now); filters.value = { start_date: startDate.value, end_date: endDate.value }; granularity.value = 'day'; applyFilters() } const resetFilters = () => { startDate.value = formatLD(weekAgo); endDate.value = formatLD(now); filters.value = { start_date: startDate.value, end_date: endDate.value, billing_type: null }; granularity.value = 'day'; applyFilters() }
const handlePageChange = (p: number) => { pagination.page = p; loadLogs() } const handlePageChange = (p: number) => { pagination.page = p; loadLogs() }
const handlePageSizeChange = (s: number) => { pagination.page_size = s; pagination.page = 1; loadLogs() } const handlePageSizeChange = (s: number) => { pagination.page_size = s; pagination.page = 1; loadLogs() }
const cancelExport = () => exportAbortController?.abort() const cancelExport = () => exportAbortController?.abort()
const openCleanupDialog = () => { cleanupDialogVisible.value = true }
const exportToExcel = async () => { const exportToExcel = async () => {
if (exporting.value) return; exporting.value = true; exportProgress.show = true if (exporting.value) return; exporting.value = true; exportProgress.show = true