feat(usage): 添加清理任务与统计过滤

This commit is contained in:
yangjianbo
2026-01-18 10:52:18 +08:00
parent 74a3c74514
commit ef5a41057f
44 changed files with 4478 additions and 46 deletions

View File

@@ -55,6 +55,7 @@ type Config struct {
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
@@ -489,6 +490,20 @@ type DashboardAggregationRetentionConfig struct {
DailyDays int `mapstructure:"daily_days"`
}
// UsageCleanupConfig 使用记录清理任务配置
type UsageCleanupConfig struct {
// Enabled: 是否启用清理任务执行器
Enabled bool `mapstructure:"enabled"`
// MaxRangeDays: 单次任务允许的最大时间跨度(天)
MaxRangeDays int `mapstructure:"max_range_days"`
// BatchSize: 单批删除数量
BatchSize int `mapstructure:"batch_size"`
// WorkerIntervalSeconds: 后台任务轮询间隔(秒)
WorkerIntervalSeconds int `mapstructure:"worker_interval_seconds"`
// TaskTimeoutSeconds: 单次任务最大执行时长(秒)
TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"`
}
func NormalizeRunMode(value string) string {
normalized := strings.ToLower(strings.TrimSpace(value))
switch normalized {
@@ -749,6 +764,13 @@ func setDefaults() {
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
// Usage cleanup task
viper.SetDefault("usage_cleanup.enabled", true)
viper.SetDefault("usage_cleanup.max_range_days", 31)
viper.SetDefault("usage_cleanup.batch_size", 5000)
viper.SetDefault("usage_cleanup.worker_interval_seconds", 10)
viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800)
// Gateway
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", true)
@@ -985,6 +1007,33 @@ func (c *Config) Validate() error {
return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative")
}
}
if c.UsageCleanup.Enabled {
if c.UsageCleanup.MaxRangeDays <= 0 {
return fmt.Errorf("usage_cleanup.max_range_days must be positive")
}
if c.UsageCleanup.BatchSize <= 0 {
return fmt.Errorf("usage_cleanup.batch_size must be positive")
}
if c.UsageCleanup.WorkerIntervalSeconds <= 0 {
return fmt.Errorf("usage_cleanup.worker_interval_seconds must be positive")
}
if c.UsageCleanup.TaskTimeoutSeconds <= 0 {
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be positive")
}
} else {
if c.UsageCleanup.MaxRangeDays < 0 {
return fmt.Errorf("usage_cleanup.max_range_days must be non-negative")
}
if c.UsageCleanup.BatchSize < 0 {
return fmt.Errorf("usage_cleanup.batch_size must be non-negative")
}
if c.UsageCleanup.WorkerIntervalSeconds < 0 {
return fmt.Errorf("usage_cleanup.worker_interval_seconds must be non-negative")
}
if c.UsageCleanup.TaskTimeoutSeconds < 0 {
return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative")
}
}
if c.Gateway.MaxBodySize <= 0 {
return fmt.Errorf("gateway.max_body_size must be positive")
}

View File

@@ -280,3 +280,573 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
t.Fatalf("Validate() expected backfill_max_days error, got: %v", err)
}
}
func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if !cfg.UsageCleanup.Enabled {
t.Fatalf("UsageCleanup.Enabled = false, want true")
}
if cfg.UsageCleanup.MaxRangeDays != 31 {
t.Fatalf("UsageCleanup.MaxRangeDays = %d, want 31", cfg.UsageCleanup.MaxRangeDays)
}
if cfg.UsageCleanup.BatchSize != 5000 {
t.Fatalf("UsageCleanup.BatchSize = %d, want 5000", cfg.UsageCleanup.BatchSize)
}
if cfg.UsageCleanup.WorkerIntervalSeconds != 10 {
t.Fatalf("UsageCleanup.WorkerIntervalSeconds = %d, want 10", cfg.UsageCleanup.WorkerIntervalSeconds)
}
if cfg.UsageCleanup.TaskTimeoutSeconds != 1800 {
t.Fatalf("UsageCleanup.TaskTimeoutSeconds = %d, want 1800", cfg.UsageCleanup.TaskTimeoutSeconds)
}
}
func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.UsageCleanup.Enabled = true
cfg.UsageCleanup.MaxRangeDays = 0
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for usage_cleanup.max_range_days, got nil")
}
if !strings.Contains(err.Error(), "usage_cleanup.max_range_days") {
t.Fatalf("Validate() expected max_range_days error, got: %v", err)
}
}
func TestValidateUsageCleanupConfigDisabled(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.UsageCleanup.Enabled = false
cfg.UsageCleanup.BatchSize = -1
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for usage_cleanup.batch_size, got nil")
}
if !strings.Contains(err.Error(), "usage_cleanup.batch_size") {
t.Fatalf("Validate() expected batch_size error, got: %v", err)
}
}
func TestConfigAddressHelpers(t *testing.T) {
server := ServerConfig{Host: "127.0.0.1", Port: 9000}
if server.Address() != "127.0.0.1:9000" {
t.Fatalf("ServerConfig.Address() = %q", server.Address())
}
dbCfg := DatabaseConfig{
Host: "localhost",
Port: 5432,
User: "postgres",
Password: "",
DBName: "sub2api",
SSLMode: "disable",
}
if !strings.Contains(dbCfg.DSN(), "password=") {
} else {
t.Fatalf("DatabaseConfig.DSN() should not include password when empty")
}
dbCfg.Password = "secret"
if !strings.Contains(dbCfg.DSN(), "password=secret") {
t.Fatalf("DatabaseConfig.DSN() missing password")
}
dbCfg.Password = ""
if strings.Contains(dbCfg.DSNWithTimezone("UTC"), "password=") {
t.Fatalf("DatabaseConfig.DSNWithTimezone() should omit password when empty")
}
if !strings.Contains(dbCfg.DSNWithTimezone(""), "TimeZone=Asia/Shanghai") {
t.Fatalf("DatabaseConfig.DSNWithTimezone() should use default timezone")
}
if !strings.Contains(dbCfg.DSNWithTimezone("UTC"), "TimeZone=UTC") {
t.Fatalf("DatabaseConfig.DSNWithTimezone() should use provided timezone")
}
redis := RedisConfig{Host: "redis", Port: 6379}
if redis.Address() != "redis:6379" {
t.Fatalf("RedisConfig.Address() = %q", redis.Address())
}
}
func TestNormalizeStringSlice(t *testing.T) {
values := normalizeStringSlice([]string{" a ", "", "b", " ", "c"})
if len(values) != 3 || values[0] != "a" || values[1] != "b" || values[2] != "c" {
t.Fatalf("normalizeStringSlice() unexpected result: %#v", values)
}
if normalizeStringSlice(nil) != nil {
t.Fatalf("normalizeStringSlice(nil) expected nil slice")
}
}
func TestGetServerAddressFromEnv(t *testing.T) {
t.Setenv("SERVER_HOST", "127.0.0.1")
t.Setenv("SERVER_PORT", "9090")
address := GetServerAddress()
if address != "127.0.0.1:9090" {
t.Fatalf("GetServerAddress() = %q", address)
}
}
func TestValidateAbsoluteHTTPURL(t *testing.T) {
if err := ValidateAbsoluteHTTPURL("https://example.com/path"); err != nil {
t.Fatalf("ValidateAbsoluteHTTPURL valid url error: %v", err)
}
if err := ValidateAbsoluteHTTPURL(""); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject empty url")
}
if err := ValidateAbsoluteHTTPURL("/relative"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject relative url")
}
if err := ValidateAbsoluteHTTPURL("ftp://example.com"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject ftp scheme")
}
if err := ValidateAbsoluteHTTPURL("https://example.com/#frag"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject fragment")
}
}
func TestValidateFrontendRedirectURL(t *testing.T) {
if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil {
t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err)
}
if err := ValidateFrontendRedirectURL("https://example.com/auth"); err != nil {
t.Fatalf("ValidateFrontendRedirectURL absolute error: %v", err)
}
if err := ValidateFrontendRedirectURL("example.com/path"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject non-absolute url")
}
if err := ValidateFrontendRedirectURL("//evil.com"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject // prefix")
}
if err := ValidateFrontendRedirectURL("javascript:alert(1)"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject javascript scheme")
}
}
func TestWarnIfInsecureURL(t *testing.T) {
warnIfInsecureURL("test", "http://example.com")
warnIfInsecureURL("test", "bad://url")
}
func TestGenerateJWTSecretDefaultLength(t *testing.T) {
secret, err := generateJWTSecret(0)
if err != nil {
t.Fatalf("generateJWTSecret error: %v", err)
}
if len(secret) == 0 {
t.Fatalf("generateJWTSecret returned empty string")
}
}
func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Ops.Cleanup.Enabled = true
cfg.Ops.Cleanup.Schedule = ""
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for ops.cleanup.schedule")
}
if !strings.Contains(err.Error(), "ops.cleanup.schedule") {
t.Fatalf("Validate() expected ops.cleanup.schedule error, got: %v", err)
}
}
func TestValidateConcurrencyPingInterval(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Concurrency.PingInterval = 3
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for concurrency.ping_interval")
}
if !strings.Contains(err.Error(), "concurrency.ping_interval") {
t.Fatalf("Validate() expected concurrency.ping_interval error, got: %v", err)
}
}
func TestProvideConfig(t *testing.T) {
viper.Reset()
if _, err := ProvideConfig(); err != nil {
t.Fatalf("ProvideConfig() error: %v", err)
}
}
func TestValidateConfigWithLinuxDoEnabled(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.Security.CSP.Enabled = true
cfg.Security.CSP.Policy = "default-src 'self'"
cfg.LinuxDo.Enabled = true
cfg.LinuxDo.ClientID = "client"
cfg.LinuxDo.ClientSecret = "secret"
cfg.LinuxDo.AuthorizeURL = "https://example.com/oauth2/authorize"
cfg.LinuxDo.TokenURL = "https://example.com/oauth2/token"
cfg.LinuxDo.UserInfoURL = "https://example.com/oauth2/userinfo"
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() unexpected error: %v", err)
}
}
func TestValidateJWTSecretStrength(t *testing.T) {
if !isWeakJWTSecret("change-me-in-production") {
t.Fatalf("isWeakJWTSecret should detect weak secret")
}
if isWeakJWTSecret("StrongSecretValue") {
t.Fatalf("isWeakJWTSecret should accept strong secret")
}
}
func TestGenerateJWTSecretWithLength(t *testing.T) {
secret, err := generateJWTSecret(16)
if err != nil {
t.Fatalf("generateJWTSecret error: %v", err)
}
if len(secret) == 0 {
t.Fatalf("generateJWTSecret returned empty string")
}
}
func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) {
if err := ValidateAbsoluteHTTPURL("https://"); err == nil {
t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host")
}
}
func TestValidateFrontendRedirectURLInvalidChars(t *testing.T) {
if err := ValidateFrontendRedirectURL("/auth/\ncallback"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject invalid chars")
}
if err := ValidateFrontendRedirectURL("http://"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject missing host")
}
if err := ValidateFrontendRedirectURL("mailto:user@example.com"); err == nil {
t.Fatalf("ValidateFrontendRedirectURL should reject mailto")
}
}
func TestWarnIfInsecureURLHTTPS(t *testing.T) {
warnIfInsecureURL("secure", "https://example.com")
}
func TestValidateConfigErrors(t *testing.T) {
buildValid := func(t *testing.T) *Config {
t.Helper()
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
return cfg
}
cases := []struct {
name string
mutate func(*Config)
wantErr string
}{
{
name: "jwt expire hour positive",
mutate: func(c *Config) { c.JWT.ExpireHour = 0 },
wantErr: "jwt.expire_hour must be positive",
},
{
name: "jwt expire hour max",
mutate: func(c *Config) { c.JWT.ExpireHour = 200 },
wantErr: "jwt.expire_hour must be <= 168",
},
{
name: "csp policy required",
mutate: func(c *Config) { c.Security.CSP.Enabled = true; c.Security.CSP.Policy = "" },
wantErr: "security.csp.policy",
},
{
name: "linuxdo client id required",
mutate: func(c *Config) {
c.LinuxDo.Enabled = true
c.LinuxDo.ClientID = ""
},
wantErr: "linuxdo_connect.client_id",
},
{
name: "linuxdo token auth method",
mutate: func(c *Config) {
c.LinuxDo.Enabled = true
c.LinuxDo.ClientID = "client"
c.LinuxDo.ClientSecret = "secret"
c.LinuxDo.AuthorizeURL = "https://example.com/authorize"
c.LinuxDo.TokenURL = "https://example.com/token"
c.LinuxDo.UserInfoURL = "https://example.com/userinfo"
c.LinuxDo.RedirectURL = "https://example.com/callback"
c.LinuxDo.FrontendRedirectURL = "/auth/callback"
c.LinuxDo.TokenAuthMethod = "invalid"
},
wantErr: "linuxdo_connect.token_auth_method",
},
{
name: "billing circuit breaker threshold",
mutate: func(c *Config) { c.Billing.CircuitBreaker.FailureThreshold = 0 },
wantErr: "billing.circuit_breaker.failure_threshold",
},
{
name: "billing circuit breaker reset",
mutate: func(c *Config) { c.Billing.CircuitBreaker.ResetTimeoutSeconds = 0 },
wantErr: "billing.circuit_breaker.reset_timeout_seconds",
},
{
name: "billing circuit breaker half open",
mutate: func(c *Config) { c.Billing.CircuitBreaker.HalfOpenRequests = 0 },
wantErr: "billing.circuit_breaker.half_open_requests",
},
{
name: "database max open conns",
mutate: func(c *Config) { c.Database.MaxOpenConns = 0 },
wantErr: "database.max_open_conns",
},
{
name: "database max lifetime",
mutate: func(c *Config) { c.Database.ConnMaxLifetimeMinutes = -1 },
wantErr: "database.conn_max_lifetime_minutes",
},
{
name: "database idle exceeds open",
mutate: func(c *Config) { c.Database.MaxIdleConns = c.Database.MaxOpenConns + 1 },
wantErr: "database.max_idle_conns cannot exceed",
},
{
name: "redis dial timeout",
mutate: func(c *Config) { c.Redis.DialTimeoutSeconds = 0 },
wantErr: "redis.dial_timeout_seconds",
},
{
name: "redis read timeout",
mutate: func(c *Config) { c.Redis.ReadTimeoutSeconds = 0 },
wantErr: "redis.read_timeout_seconds",
},
{
name: "redis write timeout",
mutate: func(c *Config) { c.Redis.WriteTimeoutSeconds = 0 },
wantErr: "redis.write_timeout_seconds",
},
{
name: "redis pool size",
mutate: func(c *Config) { c.Redis.PoolSize = 0 },
wantErr: "redis.pool_size",
},
{
name: "redis idle exceeds pool",
mutate: func(c *Config) { c.Redis.MinIdleConns = c.Redis.PoolSize + 1 },
wantErr: "redis.min_idle_conns cannot exceed",
},
{
name: "dashboard cache disabled negative",
mutate: func(c *Config) { c.Dashboard.Enabled = false; c.Dashboard.StatsTTLSeconds = -1 },
wantErr: "dashboard_cache.stats_ttl_seconds",
},
{
name: "dashboard cache fresh ttl positive",
mutate: func(c *Config) { c.Dashboard.Enabled = true; c.Dashboard.StatsFreshTTLSeconds = 0 },
wantErr: "dashboard_cache.stats_fresh_ttl_seconds",
},
{
name: "dashboard aggregation enabled interval",
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.IntervalSeconds = 0 },
wantErr: "dashboard_aggregation.interval_seconds",
},
{
name: "dashboard aggregation backfill positive",
mutate: func(c *Config) {
c.DashboardAgg.Enabled = true
c.DashboardAgg.BackfillEnabled = true
c.DashboardAgg.BackfillMaxDays = 0
},
wantErr: "dashboard_aggregation.backfill_max_days",
},
{
name: "dashboard aggregation retention",
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
wantErr: "dashboard_aggregation.retention.usage_logs_days",
},
{
name: "dashboard aggregation disabled interval",
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
wantErr: "dashboard_aggregation.interval_seconds",
},
{
name: "usage cleanup max range",
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.MaxRangeDays = 0 },
wantErr: "usage_cleanup.max_range_days",
},
{
name: "usage cleanup worker interval",
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.WorkerIntervalSeconds = 0 },
wantErr: "usage_cleanup.worker_interval_seconds",
},
{
name: "usage cleanup batch size",
mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.BatchSize = 0 },
wantErr: "usage_cleanup.batch_size",
},
{
name: "usage cleanup disabled negative",
mutate: func(c *Config) { c.UsageCleanup.Enabled = false; c.UsageCleanup.BatchSize = -1 },
wantErr: "usage_cleanup.batch_size",
},
{
name: "gateway max body size",
mutate: func(c *Config) { c.Gateway.MaxBodySize = 0 },
wantErr: "gateway.max_body_size",
},
{
name: "gateway max idle conns",
mutate: func(c *Config) { c.Gateway.MaxIdleConns = 0 },
wantErr: "gateway.max_idle_conns",
},
{
name: "gateway max idle conns per host",
mutate: func(c *Config) { c.Gateway.MaxIdleConnsPerHost = 0 },
wantErr: "gateway.max_idle_conns_per_host",
},
{
name: "gateway idle timeout",
mutate: func(c *Config) { c.Gateway.IdleConnTimeoutSeconds = 0 },
wantErr: "gateway.idle_conn_timeout_seconds",
},
{
name: "gateway max upstream clients",
mutate: func(c *Config) { c.Gateway.MaxUpstreamClients = 0 },
wantErr: "gateway.max_upstream_clients",
},
{
name: "gateway client idle ttl",
mutate: func(c *Config) { c.Gateway.ClientIdleTTLSeconds = 0 },
wantErr: "gateway.client_idle_ttl_seconds",
},
{
name: "gateway concurrency slot ttl",
mutate: func(c *Config) { c.Gateway.ConcurrencySlotTTLMinutes = 0 },
wantErr: "gateway.concurrency_slot_ttl_minutes",
},
{
name: "gateway max conns per host",
mutate: func(c *Config) { c.Gateway.MaxConnsPerHost = -1 },
wantErr: "gateway.max_conns_per_host",
},
{
name: "gateway connection isolation",
mutate: func(c *Config) { c.Gateway.ConnectionPoolIsolation = "invalid" },
wantErr: "gateway.connection_pool_isolation",
},
{
name: "gateway stream keepalive range",
mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 },
wantErr: "gateway.stream_keepalive_interval",
},
{
name: "gateway stream data interval range",
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 },
wantErr: "gateway.stream_data_interval_timeout",
},
{
name: "gateway stream data interval negative",
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = -1 },
wantErr: "gateway.stream_data_interval_timeout must be non-negative",
},
{
name: "gateway max line size",
mutate: func(c *Config) { c.Gateway.MaxLineSize = 1024 },
wantErr: "gateway.max_line_size must be at least",
},
{
name: "gateway max line size negative",
mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 },
wantErr: "gateway.max_line_size must be non-negative",
},
{
name: "gateway scheduling sticky waiting",
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
wantErr: "gateway.scheduling.sticky_session_max_waiting",
},
{
name: "gateway scheduling outbox poll",
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxPollIntervalSeconds = 0 },
wantErr: "gateway.scheduling.outbox_poll_interval_seconds",
},
{
name: "gateway scheduling outbox failures",
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxLagRebuildFailures = 0 },
wantErr: "gateway.scheduling.outbox_lag_rebuild_failures",
},
{
name: "gateway outbox lag rebuild",
mutate: func(c *Config) {
c.Gateway.Scheduling.OutboxLagWarnSeconds = 10
c.Gateway.Scheduling.OutboxLagRebuildSeconds = 5
},
wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds",
},
{
name: "ops metrics collector ttl",
mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 },
wantErr: "ops.metrics_collector_cache.ttl",
},
{
name: "ops cleanup retention",
mutate: func(c *Config) { c.Ops.Cleanup.ErrorLogRetentionDays = -1 },
wantErr: "ops.cleanup.error_log_retention_days",
},
{
name: "ops cleanup minute retention",
mutate: func(c *Config) { c.Ops.Cleanup.MinuteMetricsRetentionDays = -1 },
wantErr: "ops.cleanup.minute_metrics_retention_days",
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
cfg := buildValid(t)
tt.mutate(cfg)
err := cfg.Validate()
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("Validate() error = %v, want %q", err, tt.wantErr)
}
})
}
}

View File

@@ -0,0 +1,262 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAdminRouter() (*gin.Engine, *stubAdminService) {
gin.SetMode(gin.TestMode)
router := gin.New()
adminSvc := newStubAdminService()
userHandler := NewUserHandler(adminSvc)
groupHandler := NewGroupHandler(adminSvc)
proxyHandler := NewProxyHandler(adminSvc)
redeemHandler := NewRedeemHandler(adminSvc)
router.GET("/api/v1/admin/users", userHandler.List)
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
router.POST("/api/v1/admin/users", userHandler.Create)
router.PUT("/api/v1/admin/users/:id", userHandler.Update)
router.DELETE("/api/v1/admin/users/:id", userHandler.Delete)
router.POST("/api/v1/admin/users/:id/balance", userHandler.UpdateBalance)
router.GET("/api/v1/admin/users/:id/api-keys", userHandler.GetUserAPIKeys)
router.GET("/api/v1/admin/users/:id/usage", userHandler.GetUserUsage)
router.GET("/api/v1/admin/groups", groupHandler.List)
router.GET("/api/v1/admin/groups/all", groupHandler.GetAll)
router.GET("/api/v1/admin/groups/:id", groupHandler.GetByID)
router.POST("/api/v1/admin/groups", groupHandler.Create)
router.PUT("/api/v1/admin/groups/:id", groupHandler.Update)
router.DELETE("/api/v1/admin/groups/:id", groupHandler.Delete)
router.GET("/api/v1/admin/groups/:id/stats", groupHandler.GetStats)
router.GET("/api/v1/admin/groups/:id/api-keys", groupHandler.GetGroupAPIKeys)
router.GET("/api/v1/admin/proxies", proxyHandler.List)
router.GET("/api/v1/admin/proxies/all", proxyHandler.GetAll)
router.GET("/api/v1/admin/proxies/:id", proxyHandler.GetByID)
router.POST("/api/v1/admin/proxies", proxyHandler.Create)
router.PUT("/api/v1/admin/proxies/:id", proxyHandler.Update)
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
router.GET("/api/v1/admin/redeem-codes", redeemHandler.List)
router.GET("/api/v1/admin/redeem-codes/:id", redeemHandler.GetByID)
router.POST("/api/v1/admin/redeem-codes", redeemHandler.Generate)
router.DELETE("/api/v1/admin/redeem-codes/:id", redeemHandler.Delete)
router.POST("/api/v1/admin/redeem-codes/batch-delete", redeemHandler.BatchDelete)
router.POST("/api/v1/admin/redeem-codes/:id/expire", redeemHandler.Expire)
router.GET("/api/v1/admin/redeem-codes/:id/stats", redeemHandler.GetStats)
return router, adminSvc
}
func TestUserHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/users?page=1&page_size=20", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2}
body, _ := json.Marshal(createBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
updateBody := map[string]any{"email": "updated@example.com"}
body, _ = json.Marshal(updateBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/users/1", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/users/1", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/balance", bytes.NewBufferString(`{"balance":1,"operation":"add"}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/api-keys", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/usage?period=today", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestGroupHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/all", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ := json.Marshal(map[string]any{"name": "new", "platform": "anthropic", "subscription_type": "standard"})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/groups", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ = json.Marshal(map[string]any{"name": "update"})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/groups/2", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/groups/2", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/stats", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/api-keys", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestProxyHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/all", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ := json.Marshal(map[string]any{"name": "proxy", "protocol": "http", "host": "localhost", "port": 8080})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ = json.Marshal(map[string]any{"name": "proxy2"})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/proxies/4", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/proxies/4", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/test", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/accounts", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestRedeemHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ := json.Marshal(map[string]any{"count": 1, "type": "balance", "value": 10})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/redeem-codes/5", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/5/expire", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5/stats", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}

View File

@@ -0,0 +1,134 @@
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestParseTimeRange(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req := httptest.NewRequest(http.MethodGet, "/?start_date=2024-01-01&end_date=2024-01-02&timezone=UTC", nil)
c.Request = req
start, end := parseTimeRange(c)
require.Equal(t, time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), start)
require.Equal(t, time.Date(2024, 1, 3, 0, 0, 0, 0, time.UTC), end)
req = httptest.NewRequest(http.MethodGet, "/?start_date=bad&timezone=UTC", nil)
c.Request = req
start, end = parseTimeRange(c)
require.False(t, start.IsZero())
require.False(t, end.IsZero())
}
func TestParseOpsViewParam(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/?view=excluded", nil)
require.Equal(t, opsListViewExcluded, parseOpsViewParam(c))
c2, _ := gin.CreateTestContext(w)
c2.Request = httptest.NewRequest(http.MethodGet, "/?view=all", nil)
require.Equal(t, opsListViewAll, parseOpsViewParam(c2))
c3, _ := gin.CreateTestContext(w)
c3.Request = httptest.NewRequest(http.MethodGet, "/?view=unknown", nil)
require.Equal(t, opsListViewErrors, parseOpsViewParam(c3))
require.Equal(t, "", parseOpsViewParam(nil))
}
func TestParseOpsDuration(t *testing.T) {
dur, ok := parseOpsDuration("1h")
require.True(t, ok)
require.Equal(t, time.Hour, dur)
_, ok = parseOpsDuration("invalid")
require.False(t, ok)
}
func TestParseOpsTimeRange(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
now := time.Now().UTC()
startStr := now.Add(-time.Hour).Format(time.RFC3339)
endStr := now.Format(time.RFC3339)
c.Request = httptest.NewRequest(http.MethodGet, "/?start_time="+startStr+"&end_time="+endStr, nil)
start, end, err := parseOpsTimeRange(c, "1h")
require.NoError(t, err)
require.True(t, start.Before(end))
c2, _ := gin.CreateTestContext(w)
c2.Request = httptest.NewRequest(http.MethodGet, "/?start_time=bad", nil)
_, _, err = parseOpsTimeRange(c2, "1h")
require.Error(t, err)
}
func TestParseOpsRealtimeWindow(t *testing.T) {
dur, label, ok := parseOpsRealtimeWindow("5m")
require.True(t, ok)
require.Equal(t, 5*time.Minute, dur)
require.Equal(t, "5min", label)
_, _, ok = parseOpsRealtimeWindow("invalid")
require.False(t, ok)
}
func TestPickThroughputBucketSeconds(t *testing.T) {
require.Equal(t, 60, pickThroughputBucketSeconds(30*time.Minute))
require.Equal(t, 300, pickThroughputBucketSeconds(6*time.Hour))
require.Equal(t, 3600, pickThroughputBucketSeconds(48*time.Hour))
}
func TestParseOpsQueryMode(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/?mode=raw", nil)
require.Equal(t, service.ParseOpsQueryMode("raw"), parseOpsQueryMode(c))
require.Equal(t, service.OpsQueryMode(""), parseOpsQueryMode(nil))
}
func TestOpsAlertRuleValidation(t *testing.T) {
raw := map[string]json.RawMessage{
"name": json.RawMessage(`"High error rate"`),
"metric_type": json.RawMessage(`"error_rate"`),
"operator": json.RawMessage(`">"`),
"threshold": json.RawMessage(`90`),
}
validated, err := validateOpsAlertRulePayload(raw)
require.NoError(t, err)
require.Equal(t, "High error rate", validated.Name)
_, err = validateOpsAlertRulePayload(map[string]json.RawMessage{})
require.Error(t, err)
require.True(t, isPercentOrRateMetric("error_rate"))
require.False(t, isPercentOrRateMetric("concurrency_queue_depth"))
}
func TestOpsWSHelpers(t *testing.T) {
prefixes, invalid := parseTrustedProxyList("10.0.0.0/8,invalid")
require.Len(t, prefixes, 1)
require.Len(t, invalid, 1)
host := hostWithoutPort("example.com:443")
require.Equal(t, "example.com", host)
addr := netip.MustParseAddr("10.0.0.1")
require.True(t, isAddrInTrustedProxies(addr, prefixes))
require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
}

View File

@@ -0,0 +1,290 @@
package admin
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type stubAdminService struct {
users []service.User
apiKeys []service.APIKey
groups []service.Group
accounts []service.Account
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
}
func newStubAdminService() *stubAdminService {
now := time.Now().UTC()
user := service.User{
ID: 1,
Email: "user@example.com",
Role: service.RoleUser,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
apiKey := service.APIKey{
ID: 10,
UserID: user.ID,
Key: "sk-test",
Name: "test",
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
group := service.Group{
ID: 2,
Name: "group",
Platform: service.PlatformAnthropic,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
account := service.Account{
ID: 3,
Name: "account",
Platform: service.PlatformAnthropic,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
proxy := service.Proxy{
ID: 4,
Name: "proxy",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
redeem := service.RedeemCode{
ID: 5,
Code: "R-TEST",
Type: service.RedeemTypeBalance,
Value: 10,
Status: service.StatusUnused,
CreatedAt: now,
}
return &stubAdminService{
users: []service.User{user},
apiKeys: []service.APIKey{apiKey},
groups: []service.Group{group},
accounts: []service.Account{account},
proxies: []service.Proxy{proxy},
proxyCounts: []service.ProxyWithAccountCount{{Proxy: proxy, AccountCount: 1}},
redeems: []service.RedeemCode{redeem},
}
}
func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters) ([]service.User, int64, error) {
return s.users, int64(len(s.users)), nil
}
func (s *stubAdminService) GetUser(ctx context.Context, id int64) (*service.User, error) {
for i := range s.users {
if s.users[i].ID == id {
return &s.users[i], nil
}
}
user := service.User{ID: id, Email: "user@example.com", Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) CreateUser(ctx context.Context, input *service.CreateUserInput) (*service.User, error) {
user := service.User{ID: 100, Email: input.Email, Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) UpdateUser(ctx context.Context, id int64, input *service.UpdateUserInput) (*service.User, error) {
user := service.User{ID: id, Email: "updated@example.com", Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) DeleteUser(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*service.User, error) {
user := service.User{ID: userID, Balance: balance, Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]service.APIKey, int64, error) {
return s.apiKeys, int64(len(s.apiKeys)), nil
}
func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
return map[string]any{"user_id": userID}, nil
}
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]service.Group, int64, error) {
return s.groups, int64(len(s.groups)), nil
}
func (s *stubAdminService) GetAllGroups(ctx context.Context) ([]service.Group, error) {
return s.groups, nil
}
func (s *stubAdminService) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
return s.groups, nil
}
func (s *stubAdminService) GetGroup(ctx context.Context, id int64) (*service.Group, error) {
group := service.Group{ID: id, Name: "group", Status: service.StatusActive}
return &group, nil
}
func (s *stubAdminService) CreateGroup(ctx context.Context, input *service.CreateGroupInput) (*service.Group, error) {
group := service.Group{ID: 200, Name: input.Name, Status: service.StatusActive}
return &group, nil
}
func (s *stubAdminService) UpdateGroup(ctx context.Context, id int64, input *service.UpdateGroupInput) (*service.Group, error) {
group := service.Group{ID: id, Name: input.Name, Status: service.StatusActive}
return &group, nil
}
func (s *stubAdminService) DeleteGroup(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]service.APIKey, int64, error) {
return s.apiKeys, int64(len(s.apiKeys)), nil
}
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) {
return s.accounts, int64(len(s.accounts)), nil
}
func (s *stubAdminService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
out := make([]*service.Account, 0, len(ids))
for _, id := range ids {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
out = append(out, &account)
}
return out, nil
}
func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.CreateAccountInput) (*service.Account, error) {
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) DeleteAccount(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) RefreshAccountCredentials(ctx context.Context, id int64) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) ClearAccountError(ctx context.Context, id int64) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive, Schedulable: schedulable}
return &account, nil
}
func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) {
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
}
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
return s.proxies, int64(len(s.proxies)), nil
}
func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) {
return s.proxyCounts, int64(len(s.proxyCounts)), nil
}
func (s *stubAdminService) GetAllProxies(ctx context.Context) ([]service.Proxy, error) {
return s.proxies, nil
}
func (s *stubAdminService) GetAllProxiesWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
return s.proxyCounts, nil
}
func (s *stubAdminService) GetProxy(ctx context.Context, id int64) (*service.Proxy, error) {
proxy := service.Proxy{ID: id, Name: "proxy", Status: service.StatusActive}
return &proxy, nil
}
func (s *stubAdminService) CreateProxy(ctx context.Context, input *service.CreateProxyInput) (*service.Proxy, error) {
proxy := service.Proxy{ID: 400, Name: input.Name, Status: service.StatusActive}
return &proxy, nil
}
func (s *stubAdminService) UpdateProxy(ctx context.Context, id int64, input *service.UpdateProxyInput) (*service.Proxy, error) {
proxy := service.Proxy{ID: id, Name: input.Name, Status: service.StatusActive}
return &proxy, nil
}
func (s *stubAdminService) DeleteProxy(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) BatchDeleteProxies(ctx context.Context, ids []int64) (*service.ProxyBatchDeleteResult, error) {
return &service.ProxyBatchDeleteResult{DeletedIDs: ids}, nil
}
func (s *stubAdminService) GetProxyAccounts(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
return []service.ProxyAccountSummary{{ID: 1, Name: "account"}}, nil
}
func (s *stubAdminService) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
return false, nil
}
func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.ProxyTestResult, error) {
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
}
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
return s.redeems, int64(len(s.redeems)), nil
}
func (s *stubAdminService) GetRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUnused}
return &code, nil
}
func (s *stubAdminService) GenerateRedeemCodes(ctx context.Context, input *service.GenerateRedeemCodesInput) ([]service.RedeemCode, error) {
return s.redeems, nil
}
func (s *stubAdminService) DeleteRedeemCode(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) {
return int64(len(ids)), nil
}
func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUsed}
return &code, nil
}
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)

View File

@@ -186,7 +186,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
@@ -195,6 +195,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
var userID, apiKeyID, accountID, groupID int64
var model string
var stream *bool
var billingType *int8
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
@@ -224,8 +225,17 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
stream = &streamVal
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
bt := int8(v)
billingType = &bt
} else {
response.BadRequest(c, "Invalid billing_type")
return
}
}
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
@@ -241,13 +251,14 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
var stream *bool
var billingType *int8
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
@@ -274,8 +285,17 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
stream = &streamVal
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
bt := int8(v)
billingType = &bt
} else {
response.BadRequest(c, "Invalid billing_type")
return
}
}
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return

View File

@@ -0,0 +1,377 @@
package admin
import (
"bytes"
"context"
"encoding/json"
"database/sql"
"errors"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type cleanupRepoStub struct {
mu sync.Mutex
created []*service.UsageCleanupTask
listTasks []service.UsageCleanupTask
listResult *pagination.PaginationResult
listErr error
statusByID map[int64]string
}
func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
if task == nil {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
if task.ID == 0 {
task.ID = int64(len(s.created) + 1)
}
if task.CreatedAt.IsZero() {
task.CreatedAt = time.Now().UTC()
}
task.UpdatedAt = task.CreatedAt
clone := *task
s.created = append(s.created, &clone)
return nil
}
func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
s.mu.Lock()
defer s.mu.Unlock()
return s.listTasks, s.listResult, s.listErr
}
func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
return nil, nil
}
func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.statusByID == nil {
return "", sql.ErrNoRows
}
status, ok := s.statusByID[taskID]
if !ok {
return "", sql.ErrNoRows
}
return status, nil
}
func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
return nil
}
func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
status := s.statusByID[taskID]
if status != service.UsageCleanupStatusPending && status != service.UsageCleanupStatusRunning {
return false, nil
}
s.statusByID[taskID] = service.UsageCleanupStatusCanceled
return true, nil
}
func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
return nil
}
func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
return nil
}
func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
return 0, nil
}
var _ service.UsageCleanupRepository = (*cleanupRepoStub)(nil)
func setupCleanupRouter(cleanupService *service.UsageCleanupService, userID int64) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
if userID > 0 {
router.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID})
c.Next()
})
}
handler := NewUsageHandler(nil, nil, nil, cleanupService)
router.POST("/api/v1/admin/usage/cleanup-tasks", handler.CreateCleanupTask)
router.GET("/api/v1/admin/usage/cleanup-tasks", handler.ListCleanupTasks)
router.POST("/api/v1/admin/usage/cleanup-tasks/:id/cancel", handler.CancelCleanupTask)
return router
}
func TestUsageHandlerCreateCleanupTaskUnauthorized(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 0)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusUnauthorized, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskUnavailable(t *testing.T) {
router := setupCleanupRouter(nil, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskBindError(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString("{bad-json"))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskMissingRange(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-01-01",
"timezone": "UTC",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskInvalidDate(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-13-01",
"end_date": "2024-01-02",
"timezone": "UTC",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-01-01",
"end_date": "2024-02-40",
"timezone": "UTC",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 99)
payload := map[string]any{
"start_date": " 2024-01-01 ",
"end_date": "2024-01-02",
"timezone": "UTC",
"model": "gpt-4",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.created, 1)
created := repo.created[0]
require.Equal(t, int64(99), created.CreatedBy)
require.NotNil(t, created.Filters.Model)
require.Equal(t, "gpt-4", *created.Filters.Model)
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC).Add(24*time.Hour - time.Nanosecond)
require.True(t, created.Filters.StartTime.Equal(start))
require.True(t, created.Filters.EndTime.Equal(end))
}
func TestUsageHandlerListCleanupTasksUnavailable(t *testing.T) {
router := setupCleanupRouter(nil, 0)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
}
func TestUsageHandlerListCleanupTasksSuccess(t *testing.T) {
repo := &cleanupRepoStub{}
repo.listTasks = []service.UsageCleanupTask{
{
ID: 7,
Status: service.UsageCleanupStatusSucceeded,
CreatedBy: 4,
},
}
repo.listResult = &pagination.PaginationResult{Total: 1, Page: 1, PageSize: 20, Pages: 1}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
Items []dto.UsageCleanupTask `json:"items"`
Total int64 `json:"total"`
Page int `json:"page"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Items, 1)
require.Equal(t, int64(7), resp.Data.Items[0].ID)
require.Equal(t, int64(1), resp.Data.Total)
require.Equal(t, 1, resp.Data.Page)
}
func TestUsageHandlerListCleanupTasksError(t *testing.T) {
repo := &cleanupRepoStub{listErr: errors.New("boom")}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusInternalServerError, recorder.Code)
}
func TestUsageHandlerCancelCleanupTaskUnauthorized(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 0)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/1/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestUsageHandlerCancelCleanupTaskNotFound(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/999/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestUsageHandlerCancelCleanupTaskConflict(t *testing.T) {
repo := &cleanupRepoStub{statusByID: map[int64]string{2: service.UsageCleanupStatusSucceeded}}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/2/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
}
func TestUsageHandlerCancelCleanupTaskSuccess(t *testing.T) {
repo := &cleanupRepoStub{statusByID: map[int64]string{3: service.UsageCleanupStatusPending}}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/3/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}

View File

@@ -1,7 +1,10 @@
package admin
import (
"log"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -9,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -16,9 +20,10 @@ import (
// UsageHandler handles admin usage-related requests
type UsageHandler struct {
usageService *service.UsageService
apiKeyService *service.APIKeyService
adminService service.AdminService
usageService *service.UsageService
apiKeyService *service.APIKeyService
adminService service.AdminService
cleanupService *service.UsageCleanupService
}
// NewUsageHandler creates a new admin usage handler
@@ -26,14 +31,30 @@ func NewUsageHandler(
usageService *service.UsageService,
apiKeyService *service.APIKeyService,
adminService service.AdminService,
cleanupService *service.UsageCleanupService,
) *UsageHandler {
return &UsageHandler{
usageService: usageService,
apiKeyService: apiKeyService,
adminService: adminService,
usageService: usageService,
apiKeyService: apiKeyService,
adminService: adminService,
cleanupService: cleanupService,
}
}
// CreateUsageCleanupTaskRequest represents cleanup task creation request
type CreateUsageCleanupTaskRequest struct {
StartDate string `json:"start_date"`
EndDate string `json:"end_date"`
UserID *int64 `json:"user_id"`
APIKeyID *int64 `json:"api_key_id"`
AccountID *int64 `json:"account_id"`
GroupID *int64 `json:"group_id"`
Model *string `json:"model"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
Timezone string `json:"timezone"`
}
// List handles listing all usage records with filters
// GET /api/v1/admin/usage
func (h *UsageHandler) List(c *gin.Context) {
@@ -344,3 +365,162 @@ func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
response.Success(c, result)
}
// ListCleanupTasks handles listing usage cleanup tasks
// GET /api/v1/admin/usage/cleanup-tasks
func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
if h.cleanupService == nil {
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
return
}
operator := int64(0)
if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
operator = subject.UserID
}
page, pageSize := response.ParsePagination(c)
log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
if err != nil {
log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
response.ErrorFrom(c, err)
return
}
out := make([]dto.UsageCleanupTask, 0, len(tasks))
for i := range tasks {
out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
}
log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
response.Paginated(c, out, result.Total, page, pageSize)
}
// CreateCleanupTask handles creating a usage cleanup task
// POST /api/v1/admin/usage/cleanup-tasks
func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
if h.cleanupService == nil {
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Unauthorized(c, "Unauthorized")
return
}
var req CreateUsageCleanupTaskRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
req.StartDate = strings.TrimSpace(req.StartDate)
req.EndDate = strings.TrimSpace(req.EndDate)
if req.StartDate == "" || req.EndDate == "" {
response.BadRequest(c, "start_date and end_date are required")
return
}
startTime, err := timezone.ParseInUserLocation("2006-01-02", req.StartDate, req.Timezone)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
endTime, err := timezone.ParseInUserLocation("2006-01-02", req.EndDate, req.Timezone)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
filters := service.UsageCleanupFilters{
StartTime: startTime,
EndTime: endTime,
UserID: req.UserID,
APIKeyID: req.APIKeyID,
AccountID: req.AccountID,
GroupID: req.GroupID,
Model: req.Model,
Stream: req.Stream,
BillingType: req.BillingType,
}
var userID any
if filters.UserID != nil {
userID = *filters.UserID
}
var apiKeyID any
if filters.APIKeyID != nil {
apiKeyID = *filters.APIKeyID
}
var accountID any
if filters.AccountID != nil {
accountID = *filters.AccountID
}
var groupID any
if filters.GroupID != nil {
groupID = *filters.GroupID
}
var model any
if filters.Model != nil {
model = *filters.Model
}
var stream any
if filters.Stream != nil {
stream = *filters.Stream
}
var billingType any
if filters.BillingType != nil {
billingType = *filters.BillingType
}
log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
subject.UserID,
filters.StartTime.Format(time.RFC3339),
filters.EndTime.Format(time.RFC3339),
userID,
apiKeyID,
accountID,
groupID,
model,
stream,
billingType,
req.Timezone,
)
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
if err != nil {
log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
response.ErrorFrom(c, err)
return
}
log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
response.Success(c, dto.UsageCleanupTaskFromService(task))
}
// CancelCleanupTask handles canceling a usage cleanup task
// POST /api/v1/admin/usage/cleanup-tasks/:id/cancel
func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
if h.cleanupService == nil {
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Unauthorized(c, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
taskID, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || taskID <= 0 {
response.BadRequest(c, "Invalid task id")
return
}
log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
response.ErrorFrom(c, err)
return
}
log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
}

View File

@@ -340,6 +340,36 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
}
func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTask {
if task == nil {
return nil
}
return &UsageCleanupTask{
ID: task.ID,
Status: task.Status,
Filters: UsageCleanupFilters{
StartTime: task.Filters.StartTime,
EndTime: task.Filters.EndTime,
UserID: task.Filters.UserID,
APIKeyID: task.Filters.APIKeyID,
AccountID: task.Filters.AccountID,
GroupID: task.Filters.GroupID,
Model: task.Filters.Model,
Stream: task.Filters.Stream,
BillingType: task.Filters.BillingType,
},
CreatedBy: task.CreatedBy,
DeletedRows: task.DeletedRows,
ErrorMessage: task.ErrorMsg,
CanceledBy: task.CanceledBy,
CanceledAt: task.CanceledAt,
StartedAt: task.StartedAt,
FinishedAt: task.FinishedAt,
CreatedAt: task.CreatedAt,
UpdatedAt: task.UpdatedAt,
}
}
func SettingFromService(s *service.Setting) *Setting {
if s == nil {
return nil

View File

@@ -223,6 +223,33 @@ type UsageLog struct {
Subscription *UserSubscription `json:"subscription,omitempty"`
}
type UsageCleanupFilters struct {
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
UserID *int64 `json:"user_id,omitempty"`
APIKeyID *int64 `json:"api_key_id,omitempty"`
AccountID *int64 `json:"account_id,omitempty"`
GroupID *int64 `json:"group_id,omitempty"`
Model *string `json:"model,omitempty"`
Stream *bool `json:"stream,omitempty"`
BillingType *int8 `json:"billing_type,omitempty"`
}
type UsageCleanupTask struct {
ID int64 `json:"id"`
Status string `json:"status"`
Filters UsageCleanupFilters `json:"filters"`
CreatedBy int64 `json:"created_by"`
DeletedRows int64 `json:"deleted_rows"`
ErrorMessage *string `json:"error_message,omitempty"`
CanceledBy *int64 `json:"canceled_by,omitempty"`
CanceledAt *time.Time `json:"canceled_at,omitempty"`
StartedAt *time.Time `json:"started_at,omitempty"`
FinishedAt *time.Time `json:"finished_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// AccountSummary is a minimal account info for usage log display.
// It intentionally excludes sensitive fields like Credentials, Proxy, etc.
type AccountSummary struct {

View File

@@ -77,6 +77,75 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
return nil
}
func (r *dashboardAggregationRepository) RecomputeRange(ctx context.Context, start, end time.Time) error {
if r == nil || r.sql == nil {
return nil
}
loc := timezone.Location()
startLocal := start.In(loc)
endLocal := end.In(loc)
if !endLocal.After(startLocal) {
return nil
}
hourStart := startLocal.Truncate(time.Hour)
hourEnd := endLocal.Truncate(time.Hour)
if endLocal.After(hourEnd) {
hourEnd = hourEnd.Add(time.Hour)
}
dayStart := truncateToDay(startLocal)
dayEnd := truncateToDay(endLocal)
if endLocal.After(dayEnd) {
dayEnd = dayEnd.Add(24 * time.Hour)
}
// 尽量使用事务保证范围内的一致性(允许在非 *sql.DB 的情况下退化为非事务执行)。
if db, ok := r.sql.(*sql.DB); ok {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
if err := txRepo.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
}
return r.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
}
func (r *dashboardAggregationRepository) recomputeRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
// 先清空范围内桶,再重建(避免仅增量插入导致活跃用户等指标无法回退)。
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly_users WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily_users WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil {
return err
}
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil {
return err
}
return nil
}
func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
var ts time.Time
query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1"

View File

@@ -0,0 +1,363 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type usageCleanupRepository struct {
sql sqlExecutor
}
func NewUsageCleanupRepository(sqlDB *sql.DB) service.UsageCleanupRepository {
return &usageCleanupRepository{sql: sqlDB}
}
func (r *usageCleanupRepository) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
if task == nil {
return nil
}
filtersJSON, err := json.Marshal(task.Filters)
if err != nil {
return fmt.Errorf("marshal cleanup filters: %w", err)
}
query := `
INSERT INTO usage_cleanup_tasks (
status,
filters,
created_by,
deleted_rows
) VALUES ($1, $2, $3, $4)
RETURNING id, created_at, updated_at
`
if err := scanSingleRow(ctx, r.sql, query, []any{task.Status, filtersJSON, task.CreatedBy, task.DeletedRows}, &task.ID, &task.CreatedAt, &task.UpdatedAt); err != nil {
return err
}
return nil
}
func (r *usageCleanupRepository) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
var total int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM usage_cleanup_tasks", nil, &total); err != nil {
return nil, nil, err
}
if total == 0 {
return []service.UsageCleanupTask{}, paginationResultFromTotal(0, params), nil
}
query := `
SELECT id, status, filters, created_by, deleted_rows, error_message,
canceled_by, canceled_at,
started_at, finished_at, created_at, updated_at
FROM usage_cleanup_tasks
ORDER BY created_at DESC
LIMIT $1 OFFSET $2
`
rows, err := r.sql.QueryContext(ctx, query, params.Limit(), params.Offset())
if err != nil {
return nil, nil, err
}
defer rows.Close()
tasks := make([]service.UsageCleanupTask, 0)
for rows.Next() {
var task service.UsageCleanupTask
var filtersJSON []byte
var errMsg sql.NullString
var canceledBy sql.NullInt64
var canceledAt sql.NullTime
var startedAt sql.NullTime
var finishedAt sql.NullTime
if err := rows.Scan(
&task.ID,
&task.Status,
&filtersJSON,
&task.CreatedBy,
&task.DeletedRows,
&errMsg,
&canceledBy,
&canceledAt,
&startedAt,
&finishedAt,
&task.CreatedAt,
&task.UpdatedAt,
); err != nil {
return nil, nil, err
}
if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
return nil, nil, fmt.Errorf("parse cleanup filters: %w", err)
}
if errMsg.Valid {
task.ErrorMsg = &errMsg.String
}
if canceledBy.Valid {
v := canceledBy.Int64
task.CanceledBy = &v
}
if canceledAt.Valid {
task.CanceledAt = &canceledAt.Time
}
if startedAt.Valid {
task.StartedAt = &startedAt.Time
}
if finishedAt.Valid {
task.FinishedAt = &finishedAt.Time
}
tasks = append(tasks, task)
}
if err := rows.Err(); err != nil {
return nil, nil, err
}
return tasks, paginationResultFromTotal(total, params), nil
}
func (r *usageCleanupRepository) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
if staleRunningAfterSeconds <= 0 {
staleRunningAfterSeconds = 1800
}
query := `
WITH next AS (
SELECT id
FROM usage_cleanup_tasks
WHERE status = $1
OR (
status = $2
AND started_at IS NOT NULL
AND started_at < NOW() - ($3 * interval '1 second')
)
ORDER BY created_at ASC
LIMIT 1
FOR UPDATE SKIP LOCKED
)
UPDATE usage_cleanup_tasks
SET status = $4,
started_at = NOW(),
finished_at = NULL,
error_message = NULL,
updated_at = NOW()
FROM next
WHERE usage_cleanup_tasks.id = next.id
RETURNING id, status, filters, created_by, deleted_rows, error_message,
started_at, finished_at, created_at, updated_at
`
var task service.UsageCleanupTask
var filtersJSON []byte
var errMsg sql.NullString
var startedAt sql.NullTime
var finishedAt sql.NullTime
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{
service.UsageCleanupStatusPending,
service.UsageCleanupStatusRunning,
staleRunningAfterSeconds,
service.UsageCleanupStatusRunning,
},
&task.ID,
&task.Status,
&filtersJSON,
&task.CreatedBy,
&task.DeletedRows,
&errMsg,
&startedAt,
&finishedAt,
&task.CreatedAt,
&task.UpdatedAt,
); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, err
}
if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
return nil, fmt.Errorf("parse cleanup filters: %w", err)
}
if errMsg.Valid {
task.ErrorMsg = &errMsg.String
}
if startedAt.Valid {
task.StartedAt = &startedAt.Time
}
if finishedAt.Valid {
task.FinishedAt = &finishedAt.Time
}
return &task, nil
}
func (r *usageCleanupRepository) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
var status string
if err := scanSingleRow(ctx, r.sql, "SELECT status FROM usage_cleanup_tasks WHERE id = $1", []any{taskID}, &status); err != nil {
return "", err
}
return status, nil
}
func (r *usageCleanupRepository) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
query := `
UPDATE usage_cleanup_tasks
SET deleted_rows = $1,
updated_at = NOW()
WHERE id = $2
`
_, err := r.sql.ExecContext(ctx, query, deletedRows, taskID)
return err
}
func (r *usageCleanupRepository) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
query := `
UPDATE usage_cleanup_tasks
SET status = $1,
canceled_by = $3,
canceled_at = NOW(),
finished_at = NOW(),
error_message = NULL,
updated_at = NOW()
WHERE id = $2
AND status IN ($4, $5)
RETURNING id
`
var id int64
err := scanSingleRow(ctx, r.sql, query, []any{
service.UsageCleanupStatusCanceled,
taskID,
canceledBy,
service.UsageCleanupStatusPending,
service.UsageCleanupStatusRunning,
}, &id)
if errors.Is(err, sql.ErrNoRows) {
return false, nil
}
if err != nil {
return false, err
}
return true, nil
}
func (r *usageCleanupRepository) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
query := `
UPDATE usage_cleanup_tasks
SET status = $1,
deleted_rows = $2,
finished_at = NOW(),
updated_at = NOW()
WHERE id = $3
`
_, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusSucceeded, deletedRows, taskID)
return err
}
func (r *usageCleanupRepository) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
query := `
UPDATE usage_cleanup_tasks
SET status = $1,
deleted_rows = $2,
error_message = $3,
finished_at = NOW(),
updated_at = NOW()
WHERE id = $4
`
_, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusFailed, deletedRows, errorMsg, taskID)
return err
}
func (r *usageCleanupRepository) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
if filters.StartTime.IsZero() || filters.EndTime.IsZero() {
return 0, fmt.Errorf("cleanup filters missing time range")
}
whereClause, args := buildUsageCleanupWhere(filters)
if whereClause == "" {
return 0, fmt.Errorf("cleanup filters missing time range")
}
args = append(args, limit)
query := fmt.Sprintf(`
WITH target AS (
SELECT id
FROM usage_logs
WHERE %s
ORDER BY created_at ASC, id ASC
LIMIT $%d
)
DELETE FROM usage_logs
WHERE id IN (SELECT id FROM target)
RETURNING id
`, whereClause, len(args))
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
defer rows.Close()
var deleted int64
for rows.Next() {
deleted++
}
if err := rows.Err(); err != nil {
return 0, err
}
return deleted, nil
}
func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any) {
conditions := make([]string, 0, 8)
args := make([]any, 0, 8)
idx := 1
if !filters.StartTime.IsZero() {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", idx))
args = append(args, filters.StartTime)
idx++
}
if !filters.EndTime.IsZero() {
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", idx))
args = append(args, filters.EndTime)
idx++
}
if filters.UserID != nil {
conditions = append(conditions, fmt.Sprintf("user_id = $%d", idx))
args = append(args, *filters.UserID)
idx++
}
if filters.APIKeyID != nil {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", idx))
args = append(args, *filters.APIKeyID)
idx++
}
if filters.AccountID != nil {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", idx))
args = append(args, *filters.AccountID)
idx++
}
if filters.GroupID != nil {
conditions = append(conditions, fmt.Sprintf("group_id = $%d", idx))
args = append(args, *filters.GroupID)
idx++
}
if filters.Model != nil {
model := strings.TrimSpace(*filters.Model)
if model != "" {
conditions = append(conditions, fmt.Sprintf("model = $%d", idx))
args = append(args, model)
idx++
}
}
if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", idx))
args = append(args, *filters.Stream)
idx++
}
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", idx))
args = append(args, *filters.BillingType)
idx++
}
return strings.Join(conditions, " AND "), args
}

View File

@@ -0,0 +1,440 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func newSQLMock(t *testing.T) (*sql.DB, sqlmock.Sqlmock) {
t.Helper()
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherRegexp))
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
return db, mock
}
func TestNewUsageCleanupRepository(t *testing.T) {
db, _ := newSQLMock(t)
repo := NewUsageCleanupRepository(db)
require.NotNil(t, repo)
}
func TestUsageCleanupRepositoryCreateTask(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
task := &service.UsageCleanupTask{
Status: service.UsageCleanupStatusPending,
Filters: service.UsageCleanupFilters{StartTime: start, EndTime: end},
CreatedBy: 12,
}
now := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
mock.ExpectQuery("INSERT INTO usage_cleanup_tasks").
WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at", "updated_at"}).AddRow(int64(1), now, now))
err := repo.CreateTask(context.Background(), task)
require.NoError(t, err)
require.Equal(t, int64(1), task.ID)
require.Equal(t, now, task.CreatedAt)
require.Equal(t, now, task.UpdatedAt)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryCreateTaskNil(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
err := repo.CreateTask(context.Background(), nil)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryCreateTaskQueryError(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
task := &service.UsageCleanupTask{
Status: service.UsageCleanupStatusPending,
Filters: service.UsageCleanupFilters{StartTime: time.Now(), EndTime: time.Now().Add(time.Hour)},
CreatedBy: 1,
}
mock.ExpectQuery("INSERT INTO usage_cleanup_tasks").
WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows).
WillReturnError(sql.ErrConnDone)
err := repo.CreateTask(context.Background(), task)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryListTasksEmpty(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0)))
tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.NoError(t, err)
require.Empty(t, tasks)
require.Equal(t, int64(0), result.Total)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryListTasks(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(2 * time.Hour)
filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
filtersJSON, err := json.Marshal(filters)
require.NoError(t, err)
createdAt := time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC)
updatedAt := createdAt.Add(time.Minute)
rows := sqlmock.NewRows([]string{
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
"canceled_by", "canceled_at",
"started_at", "finished_at", "created_at", "updated_at",
}).AddRow(
int64(1),
service.UsageCleanupStatusSucceeded,
filtersJSON,
int64(2),
int64(9),
"error",
nil,
nil,
start,
end,
createdAt,
updatedAt,
)
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
WithArgs(20, 0).
WillReturnRows(rows)
tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.NoError(t, err)
require.Len(t, tasks, 1)
require.Equal(t, int64(1), tasks[0].ID)
require.Equal(t, service.UsageCleanupStatusSucceeded, tasks[0].Status)
require.Equal(t, int64(2), tasks[0].CreatedBy)
require.Equal(t, int64(9), tasks[0].DeletedRows)
require.NotNil(t, tasks[0].ErrorMsg)
require.Equal(t, "error", *tasks[0].ErrorMsg)
require.NotNil(t, tasks[0].StartedAt)
require.NotNil(t, tasks[0].FinishedAt)
require.Equal(t, int64(1), result.Total)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryListTasksInvalidFilters(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
rows := sqlmock.NewRows([]string{
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
"canceled_by", "canceled_at",
"started_at", "finished_at", "created_at", "updated_at",
}).AddRow(
int64(1),
service.UsageCleanupStatusSucceeded,
[]byte("not-json"),
int64(2),
int64(9),
nil,
nil,
nil,
nil,
nil,
time.Now().UTC(),
time.Now().UTC(),
)
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
WithArgs(20, 0).
WillReturnRows(rows)
_, _, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryClaimNextPendingTaskNone(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
WillReturnRows(sqlmock.NewRows([]string{
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
"started_at", "finished_at", "created_at", "updated_at",
}))
task, err := repo.ClaimNextPendingTask(context.Background(), 1800)
require.NoError(t, err)
require.Nil(t, task)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryClaimNextPendingTask(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
filtersJSON, err := json.Marshal(filters)
require.NoError(t, err)
rows := sqlmock.NewRows([]string{
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
"started_at", "finished_at", "created_at", "updated_at",
}).AddRow(
int64(4),
service.UsageCleanupStatusRunning,
filtersJSON,
int64(7),
int64(0),
nil,
start,
nil,
start,
start,
)
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
WillReturnRows(rows)
task, err := repo.ClaimNextPendingTask(context.Background(), 1800)
require.NoError(t, err)
require.NotNil(t, task)
require.Equal(t, int64(4), task.ID)
require.Equal(t, service.UsageCleanupStatusRunning, task.Status)
require.Equal(t, int64(7), task.CreatedBy)
require.NotNil(t, task.StartedAt)
require.Nil(t, task.ErrorMsg)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryClaimNextPendingTaskError(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
WillReturnError(sql.ErrConnDone)
_, err := repo.ClaimNextPendingTask(context.Background(), 1800)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryClaimNextPendingTaskInvalidFilters(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
rows := sqlmock.NewRows([]string{
"id", "status", "filters", "created_by", "deleted_rows", "error_message",
"started_at", "finished_at", "created_at", "updated_at",
}).AddRow(
int64(4),
service.UsageCleanupStatusRunning,
[]byte("invalid"),
int64(7),
int64(0),
nil,
nil,
nil,
time.Now().UTC(),
time.Now().UTC(),
)
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
WillReturnRows(rows)
_, err := repo.ClaimNextPendingTask(context.Background(), 1800)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryMarkTaskSucceeded(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectExec("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusSucceeded, int64(12), int64(9)).
WillReturnResult(sqlmock.NewResult(0, 1))
err := repo.MarkTaskSucceeded(context.Background(), 9, 12)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryMarkTaskFailed(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectExec("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusFailed, int64(4), "boom", int64(2)).
WillReturnResult(sqlmock.NewResult(0, 1))
err := repo.MarkTaskFailed(context.Background(), 2, 4, "boom")
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryGetTaskStatus(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("SELECT status FROM usage_cleanup_tasks").
WithArgs(int64(9)).
WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow(service.UsageCleanupStatusPending))
status, err := repo.GetTaskStatus(context.Background(), 9)
require.NoError(t, err)
require.Equal(t, service.UsageCleanupStatusPending, status)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryUpdateTaskProgress(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectExec("UPDATE usage_cleanup_tasks").
WithArgs(int64(123), int64(8)).
WillReturnResult(sqlmock.NewResult(0, 1))
err := repo.UpdateTaskProgress(context.Background(), 8, 123)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryCancelTask(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
mock.ExpectQuery("UPDATE usage_cleanup_tasks").
WithArgs(service.UsageCleanupStatusCanceled, int64(6), int64(9), service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(6)))
ok, err := repo.CancelTask(context.Background(), 6, 9)
require.NoError(t, err)
require.True(t, ok)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryDeleteUsageLogsBatchMissingRange(t *testing.T) {
db, _ := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
_, err := repo.DeleteUsageLogsBatch(context.Background(), service.UsageCleanupFilters{}, 10)
require.Error(t, err)
}
func TestUsageCleanupRepositoryDeleteUsageLogsBatch(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
userID := int64(3)
model := " gpt-4 "
filters := service.UsageCleanupFilters{
StartTime: start,
EndTime: end,
UserID: &userID,
Model: &model,
}
mock.ExpectQuery("DELETE FROM usage_logs").
WithArgs(start, end, userID, "gpt-4", 2).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(1)).AddRow(int64(2)))
deleted, err := repo.DeleteUsageLogsBatch(context.Background(), filters, 2)
require.NoError(t, err)
require.Equal(t, int64(2), deleted)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestUsageCleanupRepositoryDeleteUsageLogsBatchQueryError(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
mock.ExpectQuery("DELETE FROM usage_logs").
WithArgs(start, end, 5).
WillReturnError(sql.ErrConnDone)
_, err := repo.DeleteUsageLogsBatch(context.Background(), filters, 5)
require.Error(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestBuildUsageCleanupWhere(t *testing.T) {
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
userID := int64(1)
apiKeyID := int64(2)
accountID := int64(3)
groupID := int64(4)
model := " gpt-4 "
stream := true
billingType := int8(2)
where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
StartTime: start,
EndTime: end,
UserID: &userID,
APIKeyID: &apiKeyID,
AccountID: &accountID,
GroupID: &groupID,
Model: &model,
Stream: &stream,
BillingType: &billingType,
})
require.Equal(t, "created_at >= $1 AND created_at <= $2 AND user_id = $3 AND api_key_id = $4 AND account_id = $5 AND group_id = $6 AND model = $7 AND stream = $8 AND billing_type = $9", where)
require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args)
}
func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) {
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
model := " "
where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
StartTime: start,
EndTime: end,
Model: &model,
})
require.Equal(t, "created_at >= $1 AND created_at <= $2", where)
require.Equal(t, []any{start, end}, args)
}

View File

@@ -1411,7 +1411,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
}
// GetUsageTrendWithFilters returns usage trend data with optional filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) (results []TrendDataPoint, err error) {
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
@@ -1456,6 +1456,10 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
}
query += " GROUP BY date ORDER BY date ASC"
rows, err := r.sql.QueryContext(ctx, query, args...)
@@ -1479,7 +1483,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
}
// GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) (results []ModelStat, err error) {
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时实际费用使用账号倍率total_cost * account_rate_multiplier
if accountID > 0 && userID == 0 && apiKeyID == 0 {
@@ -1520,6 +1524,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
}
query += " GROUP BY model ORDER BY total_tokens DESC"
rows, err := r.sql.QueryContext(ctx, query, args...)
@@ -1825,7 +1833,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
}
}
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil)
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil)
if err != nil {
models = []ModelStat{}
}

View File

@@ -944,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime := base.Add(48 * time.Hour)
// Test with user filter
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
s.Require().Len(trend, 2)
// Test with apiKey filter
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil)
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
s.Require().Len(trend, 2)
// Test with both filters
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil)
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
s.Require().Len(trend, 2)
}
@@ -971,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
s.Require().Len(trend, 2)
}
@@ -1017,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime := base.Add(2 * time.Hour)
// Test with user filter
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil)
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
s.Require().Len(stats, 2)
// Test with apiKey filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil)
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
s.Require().Len(stats, 2)
// Test with account filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil)
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
s.Require().Len(stats, 2)
}

View File

@@ -47,6 +47,7 @@ var ProviderSet = wire.NewSet(
NewRedeemCodeRepository,
NewPromoCodeRepository,
NewUsageLogRepository,
NewUsageCleanupRepository,
NewDashboardAggregationRepository,
NewSettingRepository,
NewOpsRepository,

View File

@@ -1242,11 +1242,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, errors.New("not implemented")
}

View File

@@ -354,6 +354,9 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys)
usage.GET("/cleanup-tasks", h.Admin.Usage.ListCleanupTasks)
usage.POST("/cleanup-tasks", h.Admin.Usage.CreateCleanupTask)
usage.POST("/cleanup-tasks/:id/cancel", h.Admin.Usage.CancelCleanupTask)
}
}

View File

@@ -32,8 +32,8 @@ type UsageLogRepository interface {
// Admin dashboard stats
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
@@ -272,7 +272,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
}
dayStart := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
@@ -294,7 +294,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart := now.Truncate(time.Minute)
minuteResetAt := minuteStart.Add(time.Minute)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil)
if err != nil {
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
}

View File

@@ -21,11 +21,15 @@ var (
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
errDashboardAggregationRunning = errors.New("聚合作业正在运行")
)
// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。
type DashboardAggregationRepository interface {
AggregateRange(ctx context.Context, start, end time.Time) error
// RecomputeRange 重新计算指定时间范围内的聚合数据(包含活跃用户等派生表)。
// 设计目的:当 usage_logs 被批量删除/回滚后,确保聚合表可恢复一致性。
RecomputeRange(ctx context.Context, start, end time.Time) error
GetAggregationWatermark(ctx context.Context) (time.Time, error)
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
@@ -112,6 +116,41 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
return nil
}
// TriggerRecomputeRange 触发指定范围的重新计算(异步)。
// 与 TriggerBackfill 不同:
// - 不依赖 backfill_enabled这是内部一致性修复
// - 不更新 watermark避免影响正常增量聚合游标
func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time) error {
if s == nil || s.repo == nil {
return errors.New("聚合服务未初始化")
}
if !s.cfg.Enabled {
return errors.New("聚合服务已禁用")
}
if !end.After(start) {
return errors.New("重新计算时间范围无效")
}
go func() {
const maxRetries = 3
for i := 0; i < maxRetries; i++ {
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
err := s.recomputeRange(ctx, start, end)
cancel()
if err == nil {
return
}
if !errors.Is(err, errDashboardAggregationRunning) {
log.Printf("[DashboardAggregation] 重新计算失败: %v", err)
return
}
time.Sleep(5 * time.Second)
}
log.Printf("[DashboardAggregation] 重新计算放弃: 聚合作业持续占用")
}()
return nil
}
func (s *DashboardAggregationService) recomputeRecentDays() {
days := s.cfg.RecomputeDays
if days <= 0 {
@@ -128,6 +167,24 @@ func (s *DashboardAggregationService) recomputeRecentDays() {
}
}
func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start, end time.Time) error {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return errDashboardAggregationRunning
}
defer atomic.StoreInt32(&s.running, 0)
jobStart := time.Now().UTC()
if err := s.repo.RecomputeRange(ctx, start, end); err != nil {
return err
}
log.Printf("[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)",
start.UTC().Format(time.RFC3339),
end.UTC().Format(time.RFC3339),
time.Since(jobStart).String(),
)
return nil
}
func (s *DashboardAggregationService) runScheduledAggregation() {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return
@@ -179,7 +236,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return errors.New("聚合作业正在运行")
return errDashboardAggregationRunning
}
defer atomic.StoreInt32(&s.running, 0)

View File

@@ -27,6 +27,10 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
return s.aggregateErr
}
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
return s.AggregateRange(ctx, start, end)
}
func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
return s.watermark, nil
}

View File

@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return stats, nil
}
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get usage trend with filters: %w", err)
}
return trend, nil
}
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get model stats with filters: %w", err)
}

View File

@@ -101,6 +101,10 @@ func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start
return nil
}
func (s *dashboardAggregationRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
if s.err != nil {
return time.Time{}, s.err

View File

@@ -190,7 +190,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
start := geminiDailyWindowStart(now)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
if err != nil {
return true, err
}
@@ -237,7 +237,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if limit > 0 {
start := now.Truncate(time.Minute)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
if err != nil {
return true, err
}

View File

@@ -0,0 +1,74 @@
package service
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
const (
UsageCleanupStatusPending = "pending"
UsageCleanupStatusRunning = "running"
UsageCleanupStatusSucceeded = "succeeded"
UsageCleanupStatusFailed = "failed"
UsageCleanupStatusCanceled = "canceled"
)
// UsageCleanupFilters 定义清理任务过滤条件
// 时间范围为必填,其他字段可选
// JSON 序列化用于存储任务参数
//
// start_time/end_time 使用 RFC3339 时间格式
// 以 UTC 或用户时区解析后的时间为准
//
// 说明:
// - nil 表示未设置该过滤条件
// - 过滤条件均为精确匹配
type UsageCleanupFilters struct {
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
UserID *int64 `json:"user_id,omitempty"`
APIKeyID *int64 `json:"api_key_id,omitempty"`
AccountID *int64 `json:"account_id,omitempty"`
GroupID *int64 `json:"group_id,omitempty"`
Model *string `json:"model,omitempty"`
Stream *bool `json:"stream,omitempty"`
BillingType *int8 `json:"billing_type,omitempty"`
}
// UsageCleanupTask 表示使用记录清理任务
// 状态包含 pending/running/succeeded/failed/canceled
type UsageCleanupTask struct {
ID int64
Status string
Filters UsageCleanupFilters
CreatedBy int64
DeletedRows int64
ErrorMsg *string
CanceledBy *int64
CanceledAt *time.Time
StartedAt *time.Time
FinishedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
}
// UsageCleanupRepository 定义清理任务持久层接口
type UsageCleanupRepository interface {
CreateTask(ctx context.Context, task *UsageCleanupTask) error
ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error)
// ClaimNextPendingTask 抢占下一条可执行任务:
// - 优先 pending
// - 若 running 超过 staleRunningAfterSeconds可能由于进程退出/崩溃/超时),允许重新抢占继续执行
ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*UsageCleanupTask, error)
// GetTaskStatus 查询任务状态;若不存在返回 sql.ErrNoRows
GetTaskStatus(ctx context.Context, taskID int64) (string, error)
// UpdateTaskProgress 更新任务进度deleted_rows用于断点续跑/展示
UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error
// CancelTask 将任务标记为 canceled仅允许 pending/running
CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error)
MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error
MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error
DeleteUsageLogsBatch(ctx context.Context, filters UsageCleanupFilters, limit int) (int64, error)
}

View File

@@ -0,0 +1,400 @@
package service
import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
const (
usageCleanupWorkerName = "usage_cleanup_worker"
)
// UsageCleanupService 负责创建与执行使用记录清理任务
type UsageCleanupService struct {
repo UsageCleanupRepository
timingWheel *TimingWheelService
dashboard *DashboardAggregationService
cfg *config.Config
running int32
startOnce sync.Once
stopOnce sync.Once
workerCtx context.Context
workerCancel context.CancelFunc
}
func NewUsageCleanupService(repo UsageCleanupRepository, timingWheel *TimingWheelService, dashboard *DashboardAggregationService, cfg *config.Config) *UsageCleanupService {
workerCtx, workerCancel := context.WithCancel(context.Background())
return &UsageCleanupService{
repo: repo,
timingWheel: timingWheel,
dashboard: dashboard,
cfg: cfg,
workerCtx: workerCtx,
workerCancel: workerCancel,
}
}
func describeUsageCleanupFilters(filters UsageCleanupFilters) string {
var parts []string
parts = append(parts, "start="+filters.StartTime.UTC().Format(time.RFC3339))
parts = append(parts, "end="+filters.EndTime.UTC().Format(time.RFC3339))
if filters.UserID != nil {
parts = append(parts, fmt.Sprintf("user_id=%d", *filters.UserID))
}
if filters.APIKeyID != nil {
parts = append(parts, fmt.Sprintf("api_key_id=%d", *filters.APIKeyID))
}
if filters.AccountID != nil {
parts = append(parts, fmt.Sprintf("account_id=%d", *filters.AccountID))
}
if filters.GroupID != nil {
parts = append(parts, fmt.Sprintf("group_id=%d", *filters.GroupID))
}
if filters.Model != nil {
parts = append(parts, "model="+strings.TrimSpace(*filters.Model))
}
if filters.Stream != nil {
parts = append(parts, fmt.Sprintf("stream=%t", *filters.Stream))
}
if filters.BillingType != nil {
parts = append(parts, fmt.Sprintf("billing_type=%d", *filters.BillingType))
}
return strings.Join(parts, " ")
}
func (s *UsageCleanupService) Start() {
if s == nil {
return
}
if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
log.Printf("[UsageCleanup] not started (disabled)")
return
}
if s.repo == nil || s.timingWheel == nil {
log.Printf("[UsageCleanup] not started (missing deps)")
return
}
interval := s.workerInterval()
s.startOnce.Do(func() {
s.timingWheel.ScheduleRecurring(usageCleanupWorkerName, interval, s.runOnce)
log.Printf("[UsageCleanup] started (interval=%s max_range_days=%d batch_size=%d task_timeout=%s)", interval, s.maxRangeDays(), s.batchSize(), s.taskTimeout())
})
}
func (s *UsageCleanupService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
if s.workerCancel != nil {
s.workerCancel()
}
if s.timingWheel != nil {
s.timingWheel.Cancel(usageCleanupWorkerName)
}
log.Printf("[UsageCleanup] stopped")
})
}
func (s *UsageCleanupService) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error) {
if s == nil || s.repo == nil {
return nil, nil, fmt.Errorf("cleanup service not ready")
}
return s.repo.ListTasks(ctx, params)
}
func (s *UsageCleanupService) CreateTask(ctx context.Context, filters UsageCleanupFilters, createdBy int64) (*UsageCleanupTask, error) {
if s == nil || s.repo == nil {
return nil, fmt.Errorf("cleanup service not ready")
}
if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
return nil, infraerrors.New(http.StatusServiceUnavailable, "USAGE_CLEANUP_DISABLED", "usage cleanup is disabled")
}
if createdBy <= 0 {
return nil, infraerrors.BadRequest("USAGE_CLEANUP_INVALID_CREATOR", "invalid creator")
}
log.Printf("[UsageCleanup] create_task requested: operator=%d %s", createdBy, describeUsageCleanupFilters(filters))
sanitizeUsageCleanupFilters(&filters)
if err := s.validateFilters(filters); err != nil {
log.Printf("[UsageCleanup] create_task rejected: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters))
return nil, err
}
task := &UsageCleanupTask{
Status: UsageCleanupStatusPending,
Filters: filters,
CreatedBy: createdBy,
}
if err := s.repo.CreateTask(ctx, task); err != nil {
log.Printf("[UsageCleanup] create_task persist failed: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters))
return nil, fmt.Errorf("create cleanup task: %w", err)
}
log.Printf("[UsageCleanup] create_task persisted: task=%d operator=%d status=%s deleted_rows=%d %s", task.ID, createdBy, task.Status, task.DeletedRows, describeUsageCleanupFilters(filters))
go s.runOnce()
return task, nil
}
func (s *UsageCleanupService) runOnce() {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
log.Printf("[UsageCleanup] run_once skipped: already_running=true")
return
}
defer atomic.StoreInt32(&s.running, 0)
parent := context.Background()
if s != nil && s.workerCtx != nil {
parent = s.workerCtx
}
ctx, cancel := context.WithTimeout(parent, s.taskTimeout())
defer cancel()
task, err := s.repo.ClaimNextPendingTask(ctx, int64(s.taskTimeout().Seconds()))
if err != nil {
log.Printf("[UsageCleanup] claim pending task failed: %v", err)
return
}
if task == nil {
log.Printf("[UsageCleanup] run_once done: no_task=true")
return
}
log.Printf("[UsageCleanup] task claimed: task=%d status=%s created_by=%d deleted_rows=%d %s", task.ID, task.Status, task.CreatedBy, task.DeletedRows, describeUsageCleanupFilters(task.Filters))
s.executeTask(ctx, task)
}
func (s *UsageCleanupService) executeTask(ctx context.Context, task *UsageCleanupTask) {
if task == nil {
return
}
batchSize := s.batchSize()
deletedTotal := task.DeletedRows
start := time.Now()
log.Printf("[UsageCleanup] task started: task=%d batch_size=%d deleted_rows=%d %s", task.ID, batchSize, deletedTotal, describeUsageCleanupFilters(task.Filters))
var batchNum int
for {
if ctx != nil && ctx.Err() != nil {
log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, ctx.Err())
return
}
canceled, err := s.isTaskCanceled(ctx, task.ID)
if err != nil {
s.markTaskFailed(task.ID, deletedTotal, err)
return
}
if canceled {
log.Printf("[UsageCleanup] task canceled: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start))
return
}
batchNum++
deleted, err := s.repo.DeleteUsageLogsBatch(ctx, task.Filters, batchSize)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
// 任务被中断(例如服务停止/超时),保持 running 状态,后续通过 stale reclaim 续跑。
log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, err)
return
}
s.markTaskFailed(task.ID, deletedTotal, err)
return
}
deletedTotal += deleted
if deleted > 0 {
updateCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
if err := s.repo.UpdateTaskProgress(updateCtx, task.ID, deletedTotal); err != nil {
log.Printf("[UsageCleanup] task progress update failed: task=%d deleted_rows=%d err=%v", task.ID, deletedTotal, err)
}
cancel()
}
if batchNum <= 3 || batchNum%20 == 0 || deleted < int64(batchSize) {
log.Printf("[UsageCleanup] task batch done: task=%d batch=%d deleted=%d deleted_total=%d", task.ID, batchNum, deleted, deletedTotal)
}
if deleted == 0 || deleted < int64(batchSize) {
break
}
}
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.repo.MarkTaskSucceeded(updateCtx, task.ID, deletedTotal); err != nil {
log.Printf("[UsageCleanup] update task succeeded failed: task=%d err=%v", task.ID, err)
} else {
log.Printf("[UsageCleanup] task succeeded: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start))
}
if s.dashboard != nil {
if err := s.dashboard.TriggerRecomputeRange(task.Filters.StartTime, task.Filters.EndTime); err != nil {
log.Printf("[UsageCleanup] trigger dashboard recompute failed: task=%d err=%v", task.ID, err)
} else {
log.Printf("[UsageCleanup] trigger dashboard recompute: task=%d start=%s end=%s", task.ID, task.Filters.StartTime.UTC().Format(time.RFC3339), task.Filters.EndTime.UTC().Format(time.RFC3339))
}
}
}
func (s *UsageCleanupService) markTaskFailed(taskID int64, deletedRows int64, err error) {
msg := strings.TrimSpace(err.Error())
if len(msg) > 500 {
msg = msg[:500]
}
log.Printf("[UsageCleanup] task failed: task=%d deleted_rows=%d err=%s", taskID, deletedRows, msg)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if updateErr := s.repo.MarkTaskFailed(ctx, taskID, deletedRows, msg); updateErr != nil {
log.Printf("[UsageCleanup] update task failed failed: task=%d err=%v", taskID, updateErr)
}
}
func (s *UsageCleanupService) isTaskCanceled(ctx context.Context, taskID int64) (bool, error) {
if s == nil || s.repo == nil {
return false, fmt.Errorf("cleanup service not ready")
}
checkCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
status, err := s.repo.GetTaskStatus(checkCtx, taskID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return false, nil
}
return false, err
}
if status == UsageCleanupStatusCanceled {
log.Printf("[UsageCleanup] task cancel detected: task=%d", taskID)
}
return status == UsageCleanupStatusCanceled, nil
}
func (s *UsageCleanupService) validateFilters(filters UsageCleanupFilters) error {
if filters.StartTime.IsZero() || filters.EndTime.IsZero() {
return infraerrors.BadRequest("USAGE_CLEANUP_MISSING_RANGE", "start_date and end_date are required")
}
if filters.EndTime.Before(filters.StartTime) {
return infraerrors.BadRequest("USAGE_CLEANUP_INVALID_RANGE", "end_date must be after start_date")
}
maxDays := s.maxRangeDays()
if maxDays > 0 {
delta := filters.EndTime.Sub(filters.StartTime)
if delta > time.Duration(maxDays)*24*time.Hour {
return infraerrors.BadRequest("USAGE_CLEANUP_RANGE_TOO_LARGE", fmt.Sprintf("date range exceeds %d days", maxDays))
}
}
return nil
}
func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canceledBy int64) error {
if s == nil || s.repo == nil {
return fmt.Errorf("cleanup service not ready")
}
if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
return infraerrors.New(http.StatusServiceUnavailable, "USAGE_CLEANUP_DISABLED", "usage cleanup is disabled")
}
if canceledBy <= 0 {
return infraerrors.BadRequest("USAGE_CLEANUP_INVALID_CANCELLER", "invalid canceller")
}
status, err := s.repo.GetTaskStatus(ctx, taskID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return infraerrors.New(http.StatusNotFound, "USAGE_CLEANUP_TASK_NOT_FOUND", "cleanup task not found")
}
return err
}
log.Printf("[UsageCleanup] cancel_task requested: task=%d operator=%d status=%s", taskID, canceledBy, status)
if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning {
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
}
ok, err := s.repo.CancelTask(ctx, taskID, canceledBy)
if err != nil {
return err
}
if !ok {
// 状态可能并发改变
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
}
log.Printf("[UsageCleanup] cancel_task done: task=%d operator=%d", taskID, canceledBy)
return nil
}
func sanitizeUsageCleanupFilters(filters *UsageCleanupFilters) {
if filters == nil {
return
}
if filters.UserID != nil && *filters.UserID <= 0 {
filters.UserID = nil
}
if filters.APIKeyID != nil && *filters.APIKeyID <= 0 {
filters.APIKeyID = nil
}
if filters.AccountID != nil && *filters.AccountID <= 0 {
filters.AccountID = nil
}
if filters.GroupID != nil && *filters.GroupID <= 0 {
filters.GroupID = nil
}
if filters.Model != nil {
model := strings.TrimSpace(*filters.Model)
if model == "" {
filters.Model = nil
} else {
filters.Model = &model
}
}
if filters.BillingType != nil && *filters.BillingType < 0 {
filters.BillingType = nil
}
}
func (s *UsageCleanupService) maxRangeDays() int {
if s == nil || s.cfg == nil {
return 31
}
if s.cfg.UsageCleanup.MaxRangeDays > 0 {
return s.cfg.UsageCleanup.MaxRangeDays
}
return 31
}
func (s *UsageCleanupService) batchSize() int {
if s == nil || s.cfg == nil {
return 5000
}
if s.cfg.UsageCleanup.BatchSize > 0 {
return s.cfg.UsageCleanup.BatchSize
}
return 5000
}
func (s *UsageCleanupService) workerInterval() time.Duration {
if s == nil || s.cfg == nil {
return 10 * time.Second
}
if s.cfg.UsageCleanup.WorkerIntervalSeconds > 0 {
return time.Duration(s.cfg.UsageCleanup.WorkerIntervalSeconds) * time.Second
}
return 10 * time.Second
}
func (s *UsageCleanupService) taskTimeout() time.Duration {
if s == nil || s.cfg == nil {
return 30 * time.Minute
}
if s.cfg.UsageCleanup.TaskTimeoutSeconds > 0 {
return time.Duration(s.cfg.UsageCleanup.TaskTimeoutSeconds) * time.Second
}
return 30 * time.Minute
}

View File

@@ -0,0 +1,420 @@
package service
import (
"context"
"database/sql"
"errors"
"net/http"
"strings"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type cleanupDeleteResponse struct {
deleted int64
err error
}
type cleanupDeleteCall struct {
filters UsageCleanupFilters
limit int
}
type cleanupMarkCall struct {
taskID int64
deletedRows int64
errMsg string
}
type cleanupRepoStub struct {
mu sync.Mutex
created []*UsageCleanupTask
createErr error
listTasks []UsageCleanupTask
listResult *pagination.PaginationResult
listErr error
claimQueue []*UsageCleanupTask
claimErr error
deleteQueue []cleanupDeleteResponse
deleteCalls []cleanupDeleteCall
markSucceeded []cleanupMarkCall
markFailed []cleanupMarkCall
statusByID map[int64]string
progressCalls []cleanupMarkCall
cancelCalls []int64
}
func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *UsageCleanupTask) error {
if task == nil {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
if s.createErr != nil {
return s.createErr
}
if task.ID == 0 {
task.ID = int64(len(s.created) + 1)
}
if task.CreatedAt.IsZero() {
task.CreatedAt = time.Now().UTC()
}
if task.UpdatedAt.IsZero() {
task.UpdatedAt = task.CreatedAt
}
clone := *task
s.created = append(s.created, &clone)
return nil
}
func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error) {
s.mu.Lock()
defer s.mu.Unlock()
return s.listTasks, s.listResult, s.listErr
}
func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*UsageCleanupTask, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.claimErr != nil {
return nil, s.claimErr
}
if len(s.claimQueue) == 0 {
return nil, nil
}
task := s.claimQueue[0]
s.claimQueue = s.claimQueue[1:]
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
s.statusByID[task.ID] = UsageCleanupStatusRunning
return task, nil
}
func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.statusByID == nil {
return "", sql.ErrNoRows
}
status, ok := s.statusByID[taskID]
if !ok {
return "", sql.ErrNoRows
}
return status, nil
}
func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
s.mu.Lock()
defer s.mu.Unlock()
s.progressCalls = append(s.progressCalls, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows})
return nil
}
func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.cancelCalls = append(s.cancelCalls, taskID)
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
status := s.statusByID[taskID]
if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning {
return false, nil
}
s.statusByID[taskID] = UsageCleanupStatusCanceled
return true, nil
}
func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
s.mu.Lock()
defer s.mu.Unlock()
s.markSucceeded = append(s.markSucceeded, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows})
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
s.statusByID[taskID] = UsageCleanupStatusSucceeded
return nil
}
func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
s.mu.Lock()
defer s.mu.Unlock()
s.markFailed = append(s.markFailed, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows, errMsg: errorMsg})
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
s.statusByID[taskID] = UsageCleanupStatusFailed
return nil
}
func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters UsageCleanupFilters, limit int) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.deleteCalls = append(s.deleteCalls, cleanupDeleteCall{filters: filters, limit: limit})
if len(s.deleteQueue) == 0 {
return 0, nil
}
resp := s.deleteQueue[0]
s.deleteQueue = s.deleteQueue[1:]
return resp.deleted, resp.err
}
func TestUsageCleanupServiceCreateTaskSanitizeFilters(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
userID := int64(-1)
apiKeyID := int64(10)
model := " gpt-4 "
billingType := int8(-2)
filters := UsageCleanupFilters{
StartTime: start,
EndTime: end,
UserID: &userID,
APIKeyID: &apiKeyID,
Model: &model,
BillingType: &billingType,
}
task, err := svc.CreateTask(context.Background(), filters, 9)
require.NoError(t, err)
require.Equal(t, UsageCleanupStatusPending, task.Status)
require.Nil(t, task.Filters.UserID)
require.NotNil(t, task.Filters.APIKeyID)
require.Equal(t, apiKeyID, *task.Filters.APIKeyID)
require.NotNil(t, task.Filters.Model)
require.Equal(t, "gpt-4", *task.Filters.Model)
require.Nil(t, task.Filters.BillingType)
require.Equal(t, int64(9), task.CreatedBy)
}
func TestUsageCleanupServiceCreateTaskInvalidCreator(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
filters := UsageCleanupFilters{
StartTime: time.Now(),
EndTime: time.Now().Add(24 * time.Hour),
}
_, err := svc.CreateTask(context.Background(), filters, 0)
require.Error(t, err)
require.Equal(t, "USAGE_CLEANUP_INVALID_CREATOR", infraerrors.Reason(err))
}
func TestUsageCleanupServiceCreateTaskDisabled(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
filters := UsageCleanupFilters{
StartTime: time.Now(),
EndTime: time.Now().Add(24 * time.Hour),
}
_, err := svc.CreateTask(context.Background(), filters, 1)
require.Error(t, err)
require.Equal(t, http.StatusServiceUnavailable, infraerrors.Code(err))
require.Equal(t, "USAGE_CLEANUP_DISABLED", infraerrors.Reason(err))
}
func TestUsageCleanupServiceCreateTaskRangeTooLarge(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 1}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(48 * time.Hour)
filters := UsageCleanupFilters{StartTime: start, EndTime: end}
_, err := svc.CreateTask(context.Background(), filters, 1)
require.Error(t, err)
require.Equal(t, "USAGE_CLEANUP_RANGE_TOO_LARGE", infraerrors.Reason(err))
}
func TestUsageCleanupServiceCreateTaskMissingRange(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
_, err := svc.CreateTask(context.Background(), UsageCleanupFilters{}, 1)
require.Error(t, err)
require.Equal(t, "USAGE_CLEANUP_MISSING_RANGE", infraerrors.Reason(err))
}
func TestUsageCleanupServiceCreateTaskRepoError(t *testing.T) {
repo := &cleanupRepoStub{createErr: errors.New("db down")}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
filters := UsageCleanupFilters{
StartTime: time.Now(),
EndTime: time.Now().Add(24 * time.Hour),
}
_, err := svc.CreateTask(context.Background(), filters, 1)
require.Error(t, err)
require.Contains(t, err.Error(), "create cleanup task")
}
func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
repo := &cleanupRepoStub{
claimQueue: []*UsageCleanupTask{
{ID: 5, Filters: UsageCleanupFilters{StartTime: time.Now(), EndTime: time.Now().Add(2 * time.Hour)}},
},
deleteQueue: []cleanupDeleteResponse{
{deleted: 2},
{deleted: 2},
{deleted: 1},
},
}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2, TaskTimeoutSeconds: 30}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
svc.runOnce()
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.deleteCalls, 3)
require.Len(t, repo.markSucceeded, 1)
require.Empty(t, repo.markFailed)
require.Equal(t, int64(5), repo.markSucceeded[0].taskID)
require.Equal(t, int64(5), repo.markSucceeded[0].deletedRows)
}
func TestUsageCleanupServiceRunOnceClaimError(t *testing.T) {
repo := &cleanupRepoStub{claimErr: errors.New("claim failed")}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
svc.runOnce()
repo.mu.Lock()
defer repo.mu.Unlock()
require.Empty(t, repo.markSucceeded)
require.Empty(t, repo.markFailed)
}
func TestUsageCleanupServiceRunOnceAlreadyRunning(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
svc.running = 1
svc.runOnce()
}
func TestUsageCleanupServiceExecuteTaskFailed(t *testing.T) {
longMsg := strings.Repeat("x", 600)
repo := &cleanupRepoStub{
deleteQueue: []cleanupDeleteResponse{
{err: errors.New(longMsg)},
},
}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 3}}
svc := NewUsageCleanupService(repo, nil, nil, cfg)
task := &UsageCleanupTask{
ID: 11,
Filters: UsageCleanupFilters{
StartTime: time.Now(),
EndTime: time.Now().Add(24 * time.Hour),
},
}
svc.executeTask(context.Background(), task)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.markFailed, 1)
require.Equal(t, int64(11), repo.markFailed[0].taskID)
require.Equal(t, 500, len(repo.markFailed[0].errMsg))
}
func TestUsageCleanupServiceListTasks(t *testing.T) {
repo := &cleanupRepoStub{
listTasks: []UsageCleanupTask{{ID: 1}, {ID: 2}},
listResult: &pagination.PaginationResult{
Total: 2,
Page: 1,
PageSize: 20,
Pages: 1,
},
}
svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
tasks, result, err := svc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.NoError(t, err)
require.Len(t, tasks, 2)
require.Equal(t, int64(2), result.Total)
}
func TestUsageCleanupServiceListTasksNotReady(t *testing.T) {
var nilSvc *UsageCleanupService
_, _, err := nilSvc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.Error(t, err)
svc := NewUsageCleanupService(nil, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
_, _, err = svc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
require.Error(t, err)
}
func TestUsageCleanupServiceDefaultsAndLifecycle(t *testing.T) {
var nilSvc *UsageCleanupService
require.Equal(t, 31, nilSvc.maxRangeDays())
require.Equal(t, 5000, nilSvc.batchSize())
require.Equal(t, 10*time.Second, nilSvc.workerInterval())
require.Equal(t, 30*time.Minute, nilSvc.taskTimeout())
nilSvc.Start()
nilSvc.Stop()
repo := &cleanupRepoStub{}
cfgDisabled := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}}
svcDisabled := NewUsageCleanupService(repo, nil, nil, cfgDisabled)
svcDisabled.Start()
svcDisabled.Stop()
timingWheel, err := NewTimingWheelService()
require.NoError(t, err)
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, WorkerIntervalSeconds: 5}}
svc := NewUsageCleanupService(repo, timingWheel, nil, cfg)
require.Equal(t, 5*time.Second, svc.workerInterval())
svc.Start()
svc.Stop()
cfgFallback := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
svcFallback := NewUsageCleanupService(repo, timingWheel, nil, cfgFallback)
require.Equal(t, 31, svcFallback.maxRangeDays())
require.Equal(t, 5000, svcFallback.batchSize())
require.Equal(t, 10*time.Second, svcFallback.workerInterval())
svcMissingDeps := NewUsageCleanupService(nil, nil, nil, cfgFallback)
svcMissingDeps.Start()
}
func TestSanitizeUsageCleanupFiltersModelEmpty(t *testing.T) {
model := " "
apiKeyID := int64(-5)
accountID := int64(-1)
groupID := int64(-2)
filters := UsageCleanupFilters{
UserID: &apiKeyID,
APIKeyID: &apiKeyID,
AccountID: &accountID,
GroupID: &groupID,
Model: &model,
}
sanitizeUsageCleanupFilters(&filters)
require.Nil(t, filters.UserID)
require.Nil(t, filters.APIKeyID)
require.Nil(t, filters.AccountID)
require.Nil(t, filters.GroupID)
require.Nil(t, filters.Model)
}

View File

@@ -57,6 +57,13 @@ func ProvideDashboardAggregationService(repo DashboardAggregationRepository, tim
return svc
}
// ProvideUsageCleanupService 创建并启动使用记录清理任务服务
func ProvideUsageCleanupService(repo UsageCleanupRepository, timingWheel *TimingWheelService, dashboardAgg *DashboardAggregationService, cfg *config.Config) *UsageCleanupService {
svc := NewUsageCleanupService(repo, timingWheel, dashboardAgg, cfg)
svc.Start()
return svc
}
// ProvideAccountExpiryService creates and starts AccountExpiryService.
func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService {
svc := NewAccountExpiryService(accountRepo, time.Minute)
@@ -248,6 +255,7 @@ var ProviderSet = wire.NewSet(
ProvideAccountExpiryService,
ProvideTimingWheelService,
ProvideDashboardAggregationService,
ProvideUsageCleanupService,
ProvideDeferredService,
NewAntigravityQuotaFetcher,
NewUserAttributeService,