diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 89c7175a..e66e0e05 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -55,31 +55,36 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingCache := repository.NewBillingCache(redisClient) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) - promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client) - authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) - userService := service.NewUserService(userRepository) - authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService) - userHandler := handler.NewUserHandler(userService) apiKeyRepository := repository.NewAPIKeyRepository(client) groupRepository := repository.NewGroupRepository(client, db) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) + apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) + promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) + authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) + userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) + authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService) + userHandler := handler.NewUserHandler(userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) - usageService := service.NewUsageService(usageLogRepository, userRepository, client) + dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db) + usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) redeemCodeRepository := repository.NewRedeemCodeRepository(client) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) redeemCache := repository.NewRedeemCache(redisClient) - redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client) + redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) redeemHandler := handler.NewRedeemHandler(redeemService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) - dashboardService := service.NewDashboardService(usageLogRepository) - dashboardHandler := admin.NewDashboardHandler(dashboardService) + dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig) + timingWheelService := service.ProvideTimingWheelService() + dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig) + dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig) + dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService) accountRepository := repository.NewAccountRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, apiKeyAuthCacheInvalidator) adminUserHandler := admin.NewUserHandler(adminService) groupHandler := admin.NewGroupHandler(adminService) claudeOAuthClient := repository.NewClaudeOAuthClient() @@ -124,7 +129,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingService := service.NewBillingService(configConfig, pricingService) identityCache := repository.NewIdentityCache(redisClient) identityService := service.NewIdentityService(identityCache) - timingWheelService := service.ProvideTimingWheelService() deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) diff --git a/backend/go.mod b/backend/go.mod index 97f599f8..4ac6ba14 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -46,11 +46,13 @@ require ( github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dgraph-io/ristretto v0.2.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/docker v28.5.1+incompatible // indirect github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-units v0.5.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/ebitengine/purego v0.8.4 // indirect github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect diff --git a/backend/go.sum b/backend/go.sum index 0adfa4de..415e73a7 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -51,6 +51,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE= +github.com/dgraph-io/ristretto v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= @@ -61,6 +63,8 @@ github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pM github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw= github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 25c6cb65..31fbeed8 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -36,26 +36,29 @@ const ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - CORS CORSConfig `mapstructure:"cors"` - Security SecurityConfig `mapstructure:"security"` - Billing BillingConfig `mapstructure:"billing"` - Turnstile TurnstileConfig `mapstructure:"turnstile"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - Ops OpsConfig `mapstructure:"ops"` - JWT JWTConfig `mapstructure:"jwt"` - LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - Concurrency ConcurrencyConfig `mapstructure:"concurrency"` - TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - RunMode string `mapstructure:"run_mode" yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" - Gemini GeminiConfig `mapstructure:"gemini"` - Update UpdateConfig `mapstructure:"update"` + Server ServerConfig `mapstructure:"server"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + Ops OpsConfig `mapstructure:"ops"` + JWT JWTConfig `mapstructure:"jwt"` + LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` + Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` + DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` + Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` + Update UpdateConfig `mapstructure:"update"` } type GeminiConfig struct { @@ -412,6 +415,55 @@ type RateLimitConfig struct { OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟) } +// APIKeyAuthCacheConfig API Key 认证缓存配置 +type APIKeyAuthCacheConfig struct { + L1Size int `mapstructure:"l1_size"` + L1TTLSeconds int `mapstructure:"l1_ttl_seconds"` + L2TTLSeconds int `mapstructure:"l2_ttl_seconds"` + NegativeTTLSeconds int `mapstructure:"negative_ttl_seconds"` + JitterPercent int `mapstructure:"jitter_percent"` + Singleflight bool `mapstructure:"singleflight"` +} + +// DashboardCacheConfig 仪表盘统计缓存配置 +type DashboardCacheConfig struct { + // Enabled: 是否启用仪表盘缓存 + Enabled bool `mapstructure:"enabled"` + // KeyPrefix: Redis key 前缀,用于多环境隔离 + KeyPrefix string `mapstructure:"key_prefix"` + // StatsFreshTTLSeconds: 缓存命中认为“新鲜”的时间窗口(秒) + StatsFreshTTLSeconds int `mapstructure:"stats_fresh_ttl_seconds"` + // StatsTTLSeconds: Redis 缓存总 TTL(秒) + StatsTTLSeconds int `mapstructure:"stats_ttl_seconds"` + // StatsRefreshTimeoutSeconds: 异步刷新超时(秒) + StatsRefreshTimeoutSeconds int `mapstructure:"stats_refresh_timeout_seconds"` +} + +// DashboardAggregationConfig 仪表盘预聚合配置 +type DashboardAggregationConfig struct { + // Enabled: 是否启用预聚合作业 + Enabled bool `mapstructure:"enabled"` + // IntervalSeconds: 聚合刷新间隔(秒) + IntervalSeconds int `mapstructure:"interval_seconds"` + // LookbackSeconds: 回看窗口(秒) + LookbackSeconds int `mapstructure:"lookback_seconds"` + // BackfillEnabled: 是否允许全量回填 + BackfillEnabled bool `mapstructure:"backfill_enabled"` + // BackfillMaxDays: 回填最大跨度(天) + BackfillMaxDays int `mapstructure:"backfill_max_days"` + // Retention: 各表保留窗口(天) + Retention DashboardAggregationRetentionConfig `mapstructure:"retention"` + // RecomputeDays: 启动时重算最近 N 天 + RecomputeDays int `mapstructure:"recompute_days"` +} + +// DashboardAggregationRetentionConfig 预聚合保留窗口 +type DashboardAggregationRetentionConfig struct { + UsageLogsDays int `mapstructure:"usage_logs_days"` + HourlyDays int `mapstructure:"hourly_days"` + DailyDays int `mapstructure:"daily_days"` +} + func NormalizeRunMode(value string) string { normalized := strings.ToLower(strings.TrimSpace(value)) switch normalized { @@ -465,6 +517,19 @@ func Load() (*Config, error) { cfg.Server.Mode = "debug" } cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) + cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID) + cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret) + cfg.LinuxDo.AuthorizeURL = strings.TrimSpace(cfg.LinuxDo.AuthorizeURL) + cfg.LinuxDo.TokenURL = strings.TrimSpace(cfg.LinuxDo.TokenURL) + cfg.LinuxDo.UserInfoURL = strings.TrimSpace(cfg.LinuxDo.UserInfoURL) + cfg.LinuxDo.Scopes = strings.TrimSpace(cfg.LinuxDo.Scopes) + cfg.LinuxDo.RedirectURL = strings.TrimSpace(cfg.LinuxDo.RedirectURL) + cfg.LinuxDo.FrontendRedirectURL = strings.TrimSpace(cfg.LinuxDo.FrontendRedirectURL) + cfg.LinuxDo.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.LinuxDo.TokenAuthMethod)) + cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath) + cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath) + cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath) + cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix) cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins) cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove) @@ -633,6 +698,32 @@ func setDefaults() { // Timezone (default to Asia/Shanghai for Chinese users) viper.SetDefault("timezone", "Asia/Shanghai") + // API Key auth cache + viper.SetDefault("api_key_auth_cache.l1_size", 65535) + viper.SetDefault("api_key_auth_cache.l1_ttl_seconds", 15) + viper.SetDefault("api_key_auth_cache.l2_ttl_seconds", 300) + viper.SetDefault("api_key_auth_cache.negative_ttl_seconds", 30) + viper.SetDefault("api_key_auth_cache.jitter_percent", 10) + viper.SetDefault("api_key_auth_cache.singleflight", true) + + // Dashboard cache + viper.SetDefault("dashboard_cache.enabled", true) + viper.SetDefault("dashboard_cache.key_prefix", "sub2api:") + viper.SetDefault("dashboard_cache.stats_fresh_ttl_seconds", 15) + viper.SetDefault("dashboard_cache.stats_ttl_seconds", 30) + viper.SetDefault("dashboard_cache.stats_refresh_timeout_seconds", 30) + + // Dashboard aggregation + viper.SetDefault("dashboard_aggregation.enabled", true) + viper.SetDefault("dashboard_aggregation.interval_seconds", 60) + viper.SetDefault("dashboard_aggregation.lookback_seconds", 120) + viper.SetDefault("dashboard_aggregation.backfill_enabled", false) + viper.SetDefault("dashboard_aggregation.backfill_max_days", 31) + viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90) + viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180) + viper.SetDefault("dashboard_aggregation.retention.daily_days", 730) + viper.SetDefault("dashboard_aggregation.recompute_days", 2) + // Gateway viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.log_upstream_error_body", true) @@ -788,6 +879,78 @@ func (c *Config) Validate() error { if c.Redis.MinIdleConns > c.Redis.PoolSize { return fmt.Errorf("redis.min_idle_conns cannot exceed redis.pool_size") } + if c.Dashboard.Enabled { + if c.Dashboard.StatsFreshTTLSeconds <= 0 { + return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be positive") + } + if c.Dashboard.StatsTTLSeconds <= 0 { + return fmt.Errorf("dashboard_cache.stats_ttl_seconds must be positive") + } + if c.Dashboard.StatsRefreshTimeoutSeconds <= 0 { + return fmt.Errorf("dashboard_cache.stats_refresh_timeout_seconds must be positive") + } + if c.Dashboard.StatsFreshTTLSeconds > c.Dashboard.StatsTTLSeconds { + return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be <= dashboard_cache.stats_ttl_seconds") + } + } else { + if c.Dashboard.StatsFreshTTLSeconds < 0 { + return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be non-negative") + } + if c.Dashboard.StatsTTLSeconds < 0 { + return fmt.Errorf("dashboard_cache.stats_ttl_seconds must be non-negative") + } + if c.Dashboard.StatsRefreshTimeoutSeconds < 0 { + return fmt.Errorf("dashboard_cache.stats_refresh_timeout_seconds must be non-negative") + } + } + if c.DashboardAgg.Enabled { + if c.DashboardAgg.IntervalSeconds <= 0 { + return fmt.Errorf("dashboard_aggregation.interval_seconds must be positive") + } + if c.DashboardAgg.LookbackSeconds < 0 { + return fmt.Errorf("dashboard_aggregation.lookback_seconds must be non-negative") + } + if c.DashboardAgg.BackfillMaxDays < 0 { + return fmt.Errorf("dashboard_aggregation.backfill_max_days must be non-negative") + } + if c.DashboardAgg.BackfillEnabled && c.DashboardAgg.BackfillMaxDays == 0 { + return fmt.Errorf("dashboard_aggregation.backfill_max_days must be positive") + } + if c.DashboardAgg.Retention.UsageLogsDays <= 0 { + return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive") + } + if c.DashboardAgg.Retention.HourlyDays <= 0 { + return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive") + } + if c.DashboardAgg.Retention.DailyDays <= 0 { + return fmt.Errorf("dashboard_aggregation.retention.daily_days must be positive") + } + if c.DashboardAgg.RecomputeDays < 0 { + return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative") + } + } else { + if c.DashboardAgg.IntervalSeconds < 0 { + return fmt.Errorf("dashboard_aggregation.interval_seconds must be non-negative") + } + if c.DashboardAgg.LookbackSeconds < 0 { + return fmt.Errorf("dashboard_aggregation.lookback_seconds must be non-negative") + } + if c.DashboardAgg.BackfillMaxDays < 0 { + return fmt.Errorf("dashboard_aggregation.backfill_max_days must be non-negative") + } + if c.DashboardAgg.Retention.UsageLogsDays < 0 { + return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative") + } + if c.DashboardAgg.Retention.HourlyDays < 0 { + return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative") + } + if c.DashboardAgg.Retention.DailyDays < 0 { + return fmt.Errorf("dashboard_aggregation.retention.daily_days must be non-negative") + } + if c.DashboardAgg.RecomputeDays < 0 { + return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative") + } + } if c.Gateway.MaxBodySize <= 0 { return fmt.Errorf("gateway.max_body_size must be positive") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index a39d41f9..1ba6d053 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -141,3 +141,142 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { t.Fatalf("Validate() expected use_pkce error, got: %v", err) } } + +func TestLoadDefaultDashboardCacheConfig(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.Dashboard.Enabled { + t.Fatalf("Dashboard.Enabled = false, want true") + } + if cfg.Dashboard.KeyPrefix != "sub2api:" { + t.Fatalf("Dashboard.KeyPrefix = %q, want %q", cfg.Dashboard.KeyPrefix, "sub2api:") + } + if cfg.Dashboard.StatsFreshTTLSeconds != 15 { + t.Fatalf("Dashboard.StatsFreshTTLSeconds = %d, want 15", cfg.Dashboard.StatsFreshTTLSeconds) + } + if cfg.Dashboard.StatsTTLSeconds != 30 { + t.Fatalf("Dashboard.StatsTTLSeconds = %d, want 30", cfg.Dashboard.StatsTTLSeconds) + } + if cfg.Dashboard.StatsRefreshTimeoutSeconds != 30 { + t.Fatalf("Dashboard.StatsRefreshTimeoutSeconds = %d, want 30", cfg.Dashboard.StatsRefreshTimeoutSeconds) + } +} + +func TestValidateDashboardCacheConfigEnabled(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Dashboard.Enabled = true + cfg.Dashboard.StatsFreshTTLSeconds = 10 + cfg.Dashboard.StatsTTLSeconds = 5 + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for stats_fresh_ttl_seconds > stats_ttl_seconds, got nil") + } + if !strings.Contains(err.Error(), "dashboard_cache.stats_fresh_ttl_seconds") { + t.Fatalf("Validate() expected stats_fresh_ttl_seconds error, got: %v", err) + } +} + +func TestValidateDashboardCacheConfigDisabled(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Dashboard.Enabled = false + cfg.Dashboard.StatsTTLSeconds = -1 + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for negative stats_ttl_seconds, got nil") + } + if !strings.Contains(err.Error(), "dashboard_cache.stats_ttl_seconds") { + t.Fatalf("Validate() expected stats_ttl_seconds error, got: %v", err) + } +} + +func TestLoadDefaultDashboardAggregationConfig(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.DashboardAgg.Enabled { + t.Fatalf("DashboardAgg.Enabled = false, want true") + } + if cfg.DashboardAgg.IntervalSeconds != 60 { + t.Fatalf("DashboardAgg.IntervalSeconds = %d, want 60", cfg.DashboardAgg.IntervalSeconds) + } + if cfg.DashboardAgg.LookbackSeconds != 120 { + t.Fatalf("DashboardAgg.LookbackSeconds = %d, want 120", cfg.DashboardAgg.LookbackSeconds) + } + if cfg.DashboardAgg.BackfillEnabled { + t.Fatalf("DashboardAgg.BackfillEnabled = true, want false") + } + if cfg.DashboardAgg.BackfillMaxDays != 31 { + t.Fatalf("DashboardAgg.BackfillMaxDays = %d, want 31", cfg.DashboardAgg.BackfillMaxDays) + } + if cfg.DashboardAgg.Retention.UsageLogsDays != 90 { + t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays) + } + if cfg.DashboardAgg.Retention.HourlyDays != 180 { + t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays) + } + if cfg.DashboardAgg.Retention.DailyDays != 730 { + t.Fatalf("DashboardAgg.Retention.DailyDays = %d, want 730", cfg.DashboardAgg.Retention.DailyDays) + } + if cfg.DashboardAgg.RecomputeDays != 2 { + t.Fatalf("DashboardAgg.RecomputeDays = %d, want 2", cfg.DashboardAgg.RecomputeDays) + } +} + +func TestValidateDashboardAggregationConfigDisabled(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.DashboardAgg.Enabled = false + cfg.DashboardAgg.IntervalSeconds = -1 + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for negative dashboard_aggregation.interval_seconds, got nil") + } + if !strings.Contains(err.Error(), "dashboard_aggregation.interval_seconds") { + t.Fatalf("Validate() expected interval_seconds error, got: %v", err) + } +} + +func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.DashboardAgg.BackfillEnabled = true + cfg.DashboardAgg.BackfillMaxDays = 0 + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for dashboard_aggregation.backfill_max_days, got nil") + } + if !strings.Contains(err.Error(), "dashboard_aggregation.backfill_max_days") { + t.Fatalf("Validate() expected backfill_max_days error, got: %v", err) + } +} diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 30cdd914..9b675974 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -1,6 +1,7 @@ package admin import ( + "errors" "strconv" "time" @@ -13,15 +14,17 @@ import ( // DashboardHandler handles admin dashboard statistics type DashboardHandler struct { - dashboardService *service.DashboardService - startTime time.Time // Server start time for uptime calculation + dashboardService *service.DashboardService + aggregationService *service.DashboardAggregationService + startTime time.Time // Server start time for uptime calculation } // NewDashboardHandler creates a new admin dashboard handler -func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardHandler { +func NewDashboardHandler(dashboardService *service.DashboardService, aggregationService *service.DashboardAggregationService) *DashboardHandler { return &DashboardHandler{ - dashboardService: dashboardService, - startTime: time.Now(), + dashboardService: dashboardService, + aggregationService: aggregationService, + startTime: time.Now(), } } @@ -114,6 +117,58 @@ func (h *DashboardHandler) GetStats(c *gin.Context) { // 性能指标 "rpm": stats.Rpm, "tpm": stats.Tpm, + + // 预聚合新鲜度 + "hourly_active_users": stats.HourlyActiveUsers, + "stats_updated_at": stats.StatsUpdatedAt, + "stats_stale": stats.StatsStale, + }) +} + +type DashboardAggregationBackfillRequest struct { + Start string `json:"start"` + End string `json:"end"` +} + +// BackfillAggregation handles triggering aggregation backfill +// POST /api/v1/admin/dashboard/aggregation/backfill +func (h *DashboardHandler) BackfillAggregation(c *gin.Context) { + if h.aggregationService == nil { + response.InternalError(c, "Aggregation service not available") + return + } + + var req DashboardAggregationBackfillRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request body") + return + } + start, err := time.Parse(time.RFC3339, req.Start) + if err != nil { + response.BadRequest(c, "Invalid start time") + return + } + end, err := time.Parse(time.RFC3339, req.End) + if err != nil { + response.BadRequest(c, "Invalid end time") + return + } + + if err := h.aggregationService.TriggerBackfill(start, end); err != nil { + if errors.Is(err, service.ErrDashboardBackfillDisabled) { + response.Forbidden(c, "Backfill is disabled") + return + } + if errors.Is(err, service.ErrDashboardBackfillTooLarge) { + response.BadRequest(c, "Backfill range too large") + return + } + response.InternalError(c, "Failed to trigger backfill") + return + } + + response.Success(c, gin.H{ + "status": "accepted", }) } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index eba69006..5f3474b0 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -8,6 +8,7 @@ import ( "io" "log" "net/http" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -94,15 +95,19 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } - // For non-Codex CLI requests, set default instructions userAgent := c.GetHeader("User-Agent") if !openai.IsCodexCLIRequest(userAgent) { - reqBody["instructions"] = openai.DefaultInstructions - // Re-serialize body - body, err = json.Marshal(reqBody) - if err != nil { - h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") - return + existingInstructions, _ := reqBody["instructions"].(string) + if strings.TrimSpace(existingInstructions) == "" { + if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" { + reqBody["instructions"] = instructions + // Re-serialize body + body, err = json.Marshal(reqBody) + if err != nil { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") + return + } + } } } diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 39314602..3952785b 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -9,6 +9,12 @@ type DashboardStats struct { TotalUsers int64 `json:"total_users"` TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数 ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数 + // 小时活跃用户数(UTC 当前小时) + HourlyActiveUsers int64 `json:"hourly_active_users"` + + // 预聚合新鲜度 + StatsUpdatedAt string `json:"stats_updated_at"` + StatsStale bool `json:"stats_stale"` // API Key 统计 TotalAPIKeys int64 `json:"total_api_keys"` diff --git a/backend/internal/repository/api_key_cache.go b/backend/internal/repository/api_key_cache.go index 73a929c5..6d834b40 100644 --- a/backend/internal/repository/api_key_cache.go +++ b/backend/internal/repository/api_key_cache.go @@ -2,6 +2,7 @@ package repository import ( "context" + "encoding/json" "errors" "fmt" "time" @@ -13,6 +14,7 @@ import ( const ( apiKeyRateLimitKeyPrefix = "apikey:ratelimit:" apiKeyRateLimitDuration = 24 * time.Hour + apiKeyAuthCachePrefix = "apikey:auth:" ) // apiKeyRateLimitKey generates the Redis key for API key creation rate limiting. @@ -20,6 +22,10 @@ func apiKeyRateLimitKey(userID int64) string { return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) } +func apiKeyAuthCacheKey(key string) string { + return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key) +} + type apiKeyCache struct { rdb *redis.Client } @@ -58,3 +64,30 @@ func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) er func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error { return c.rdb.Expire(ctx, apiKey, ttl).Err() } + +func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) { + val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes() + if err != nil { + return nil, err + } + var entry service.APIKeyAuthCacheEntry + if err := json.Unmarshal(val, &entry); err != nil { + return nil, err + } + return &entry, nil +} + +func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error { + if entry == nil { + return nil + } + payload, err := json.Marshal(entry) + if err != nil { + return err + } + return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err() +} + +func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error { + return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err() +} diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 6b8cd40d..77a3f233 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -6,7 +6,9 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" @@ -64,23 +66,23 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK return apiKeyEntityToService(m), nil } -// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。 +// GetKeyAndOwnerID 根据 API Key ID 获取其 key 与所有者(用户)ID。 // 相比 GetByID,此方法性能更优,因为: -// - 使用 Select() 只查询 user_id 字段,减少数据传输量 +// - 使用 Select() 只查询必要字段,减少数据传输量 // - 不加载完整的 API Key 实体及其关联数据(User、Group 等) -// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查) -func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) { +// - 适用于删除等只需 key 与用户 ID 的场景 +func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { m, err := r.activeQuery(). Where(apikey.IDEQ(id)). - Select(apikey.FieldUserID). + Select(apikey.FieldKey, apikey.FieldUserID). Only(ctx) if err != nil { if dbent.IsNotFound(err) { - return 0, service.ErrAPIKeyNotFound + return "", 0, service.ErrAPIKeyNotFound } - return 0, err + return "", 0, err } - return m.UserID, nil + return m.Key, m.UserID, nil } func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { @@ -98,6 +100,54 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A return apiKeyEntityToService(m), nil } +func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) { + m, err := r.activeQuery(). + Where(apikey.KeyEQ(key)). + Select( + apikey.FieldID, + apikey.FieldUserID, + apikey.FieldGroupID, + apikey.FieldStatus, + apikey.FieldIPWhitelist, + apikey.FieldIPBlacklist, + ). + WithUser(func(q *dbent.UserQuery) { + q.Select( + user.FieldID, + user.FieldStatus, + user.FieldRole, + user.FieldBalance, + user.FieldConcurrency, + ) + }). + WithGroup(func(q *dbent.GroupQuery) { + q.Select( + group.FieldID, + group.FieldName, + group.FieldPlatform, + group.FieldStatus, + group.FieldSubscriptionType, + group.FieldRateMultiplier, + group.FieldDailyLimitUsd, + group.FieldWeeklyLimitUsd, + group.FieldMonthlyLimitUsd, + group.FieldImagePrice1k, + group.FieldImagePrice2k, + group.FieldImagePrice4k, + group.FieldClaudeCodeOnly, + group.FieldFallbackGroupID, + ) + }). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, service.ErrAPIKeyNotFound + } + return nil, err + } + return apiKeyEntityToService(m), nil +} + func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error { // 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。 // 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除, @@ -283,6 +333,28 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i return int64(count), err } +func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + keys, err := r.activeQuery(). + Where(apikey.UserIDEQ(userID)). + Select(apikey.FieldKey). + Strings(ctx) + if err != nil { + return nil, err + } + return keys, nil +} + +func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + keys, err := r.activeQuery(). + Where(apikey.GroupIDEQ(groupID)). + Select(apikey.FieldKey). + Strings(ctx) + if err != nil { + return nil, err + } + return keys, nil +} + func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { if m == nil { return nil diff --git a/backend/internal/repository/dashboard_aggregation_repo.go b/backend/internal/repository/dashboard_aggregation_repo.go new file mode 100644 index 00000000..5241c468 --- /dev/null +++ b/backend/internal/repository/dashboard_aggregation_repo.go @@ -0,0 +1,363 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" +) + +type dashboardAggregationRepository struct { + sql sqlExecutor +} + +// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。 +func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository { + return newDashboardAggregationRepositoryWithSQL(sqlDB) +} + +func newDashboardAggregationRepositoryWithSQL(sqlq sqlExecutor) *dashboardAggregationRepository { + return &dashboardAggregationRepository{sql: sqlq} +} + +func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error { + startUTC := start.UTC() + endUTC := end.UTC() + if !endUTC.After(startUTC) { + return nil + } + + hourStart := startUTC.Truncate(time.Hour) + hourEnd := endUTC.Truncate(time.Hour) + if endUTC.After(hourEnd) { + hourEnd = hourEnd.Add(time.Hour) + } + + dayStart := truncateToDayUTC(startUTC) + dayEnd := truncateToDayUTC(endUTC) + if endUTC.After(dayEnd) { + dayEnd = dayEnd.Add(24 * time.Hour) + } + + // 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。 + if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil { + return err + } + if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil { + return err + } + if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil { + return err + } + if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil { + return err + } + return nil +} + +func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) { + var ts time.Time + query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1" + if err := scanSingleRow(ctx, r.sql, query, nil, &ts); err != nil { + if err == sql.ErrNoRows { + return time.Unix(0, 0).UTC(), nil + } + return time.Time{}, err + } + return ts.UTC(), nil +} + +func (r *dashboardAggregationRepository) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error { + query := ` + INSERT INTO usage_dashboard_aggregation_watermark (id, last_aggregated_at, updated_at) + VALUES (1, $1, NOW()) + ON CONFLICT (id) + DO UPDATE SET last_aggregated_at = EXCLUDED.last_aggregated_at, updated_at = EXCLUDED.updated_at + ` + _, err := r.sql.ExecContext(ctx, query, aggregatedAt.UTC()) + return err +} + +func (r *dashboardAggregationRepository) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error { + _, err := r.sql.ExecContext(ctx, ` + DELETE FROM usage_dashboard_hourly WHERE bucket_start < $1; + DELETE FROM usage_dashboard_hourly_users WHERE bucket_start < $1; + DELETE FROM usage_dashboard_daily WHERE bucket_date < $2::date; + DELETE FROM usage_dashboard_daily_users WHERE bucket_date < $2::date; + `, hourlyCutoff.UTC(), dailyCutoff.UTC()) + return err +} + +func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error { + isPartitioned, err := r.isUsageLogsPartitioned(ctx) + if err != nil { + return err + } + if isPartitioned { + return r.dropUsageLogsPartitions(ctx, cutoff) + } + _, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC()) + return err +} + +func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { + isPartitioned, err := r.isUsageLogsPartitioned(ctx) + if err != nil || !isPartitioned { + return err + } + monthStart := truncateToMonthUTC(now) + prevMonth := monthStart.AddDate(0, -1, 0) + nextMonth := monthStart.AddDate(0, 1, 0) + + for _, m := range []time.Time{prevMonth, monthStart, nextMonth} { + if err := r.createUsageLogsPartition(ctx, m); err != nil { + return err + } + } + return nil +} + +func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error { + query := ` + INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id) + SELECT DISTINCT + date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start, + user_id + FROM usage_logs + WHERE created_at >= $1 AND created_at < $2 + ON CONFLICT DO NOTHING + ` + _, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC()) + return err +} + +func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error { + query := ` + INSERT INTO usage_dashboard_daily_users (bucket_date, user_id) + SELECT DISTINCT + (bucket_start AT TIME ZONE 'UTC')::date AS bucket_date, + user_id + FROM usage_dashboard_hourly_users + WHERE bucket_start >= $1 AND bucket_start < $2 + ON CONFLICT DO NOTHING + ` + _, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC()) + return err +} + +func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error { + query := ` + WITH hourly AS ( + SELECT + date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start, + COUNT(*) AS total_requests, + COALESCE(SUM(input_tokens), 0) AS input_tokens, + COALESCE(SUM(output_tokens), 0) AS output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens, + COALESCE(SUM(total_cost), 0) AS total_cost, + COALESCE(SUM(actual_cost), 0) AS actual_cost, + COALESCE(SUM(COALESCE(duration_ms, 0)), 0) AS total_duration_ms + FROM usage_logs + WHERE created_at >= $1 AND created_at < $2 + GROUP BY 1 + ), + user_counts AS ( + SELECT bucket_start, COUNT(*) AS active_users + FROM usage_dashboard_hourly_users + WHERE bucket_start >= $1 AND bucket_start < $2 + GROUP BY bucket_start + ) + INSERT INTO usage_dashboard_hourly ( + bucket_start, + total_requests, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + total_cost, + actual_cost, + total_duration_ms, + active_users, + computed_at + ) + SELECT + hourly.bucket_start, + hourly.total_requests, + hourly.input_tokens, + hourly.output_tokens, + hourly.cache_creation_tokens, + hourly.cache_read_tokens, + hourly.total_cost, + hourly.actual_cost, + hourly.total_duration_ms, + COALESCE(user_counts.active_users, 0) AS active_users, + NOW() + FROM hourly + LEFT JOIN user_counts ON user_counts.bucket_start = hourly.bucket_start + ON CONFLICT (bucket_start) + DO UPDATE SET + total_requests = EXCLUDED.total_requests, + input_tokens = EXCLUDED.input_tokens, + output_tokens = EXCLUDED.output_tokens, + cache_creation_tokens = EXCLUDED.cache_creation_tokens, + cache_read_tokens = EXCLUDED.cache_read_tokens, + total_cost = EXCLUDED.total_cost, + actual_cost = EXCLUDED.actual_cost, + total_duration_ms = EXCLUDED.total_duration_ms, + active_users = EXCLUDED.active_users, + computed_at = EXCLUDED.computed_at + ` + _, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC()) + return err +} + +func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error { + query := ` + WITH daily AS ( + SELECT + (bucket_start AT TIME ZONE 'UTC')::date AS bucket_date, + COALESCE(SUM(total_requests), 0) AS total_requests, + COALESCE(SUM(input_tokens), 0) AS input_tokens, + COALESCE(SUM(output_tokens), 0) AS output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens, + COALESCE(SUM(total_cost), 0) AS total_cost, + COALESCE(SUM(actual_cost), 0) AS actual_cost, + COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms + FROM usage_dashboard_hourly + WHERE bucket_start >= $1 AND bucket_start < $2 + GROUP BY (bucket_start AT TIME ZONE 'UTC')::date + ), + user_counts AS ( + SELECT bucket_date, COUNT(*) AS active_users + FROM usage_dashboard_daily_users + WHERE bucket_date >= $3::date AND bucket_date < $4::date + GROUP BY bucket_date + ) + INSERT INTO usage_dashboard_daily ( + bucket_date, + total_requests, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + total_cost, + actual_cost, + total_duration_ms, + active_users, + computed_at + ) + SELECT + daily.bucket_date, + daily.total_requests, + daily.input_tokens, + daily.output_tokens, + daily.cache_creation_tokens, + daily.cache_read_tokens, + daily.total_cost, + daily.actual_cost, + daily.total_duration_ms, + COALESCE(user_counts.active_users, 0) AS active_users, + NOW() + FROM daily + LEFT JOIN user_counts ON user_counts.bucket_date = daily.bucket_date + ON CONFLICT (bucket_date) + DO UPDATE SET + total_requests = EXCLUDED.total_requests, + input_tokens = EXCLUDED.input_tokens, + output_tokens = EXCLUDED.output_tokens, + cache_creation_tokens = EXCLUDED.cache_creation_tokens, + cache_read_tokens = EXCLUDED.cache_read_tokens, + total_cost = EXCLUDED.total_cost, + actual_cost = EXCLUDED.actual_cost, + total_duration_ms = EXCLUDED.total_duration_ms, + active_users = EXCLUDED.active_users, + computed_at = EXCLUDED.computed_at + ` + _, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC(), start.UTC(), end.UTC()) + return err +} + +func (r *dashboardAggregationRepository) isUsageLogsPartitioned(ctx context.Context) (bool, error) { + query := ` + SELECT EXISTS( + SELECT 1 + FROM pg_partitioned_table pt + JOIN pg_class c ON c.oid = pt.partrelid + WHERE c.relname = 'usage_logs' + ) + ` + var partitioned bool + if err := scanSingleRow(ctx, r.sql, query, nil, &partitioned); err != nil { + return false, err + } + return partitioned, nil +} + +func (r *dashboardAggregationRepository) dropUsageLogsPartitions(ctx context.Context, cutoff time.Time) error { + rows, err := r.sql.QueryContext(ctx, ` + SELECT c.relname + FROM pg_inherits + JOIN pg_class c ON c.oid = pg_inherits.inhrelid + JOIN pg_class p ON p.oid = pg_inherits.inhparent + WHERE p.relname = 'usage_logs' + `) + if err != nil { + return err + } + defer func() { + _ = rows.Close() + }() + + cutoffMonth := truncateToMonthUTC(cutoff) + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return err + } + if !strings.HasPrefix(name, "usage_logs_") { + continue + } + suffix := strings.TrimPrefix(name, "usage_logs_") + month, err := time.Parse("200601", suffix) + if err != nil { + continue + } + month = month.UTC() + if month.Before(cutoffMonth) { + if _, err := r.sql.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(name))); err != nil { + return err + } + } + } + return rows.Err() +} + +func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Context, month time.Time) error { + monthStart := truncateToMonthUTC(month) + nextMonth := monthStart.AddDate(0, 1, 0) + name := fmt.Sprintf("usage_logs_%s", monthStart.Format("200601")) + query := fmt.Sprintf( + "CREATE TABLE IF NOT EXISTS %s PARTITION OF usage_logs FOR VALUES FROM (%s) TO (%s)", + pq.QuoteIdentifier(name), + pq.QuoteLiteral(monthStart.Format("2006-01-02")), + pq.QuoteLiteral(nextMonth.Format("2006-01-02")), + ) + _, err := r.sql.ExecContext(ctx, query) + return err +} + +func truncateToDayUTC(t time.Time) time.Time { + t = t.UTC() + return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC) +} + +func truncateToMonthUTC(t time.Time) time.Time { + t = t.UTC() + return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC) +} diff --git a/backend/internal/repository/dashboard_cache.go b/backend/internal/repository/dashboard_cache.go new file mode 100644 index 00000000..f996cd68 --- /dev/null +++ b/backend/internal/repository/dashboard_cache.go @@ -0,0 +1,58 @@ +package repository + +import ( + "context" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const dashboardStatsCacheKey = "dashboard:stats:v1" + +type dashboardCache struct { + rdb *redis.Client + keyPrefix string +} + +func NewDashboardCache(rdb *redis.Client, cfg *config.Config) service.DashboardStatsCache { + prefix := "sub2api:" + if cfg != nil { + prefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix) + } + if prefix != "" && !strings.HasSuffix(prefix, ":") { + prefix += ":" + } + return &dashboardCache{ + rdb: rdb, + keyPrefix: prefix, + } +} + +func (c *dashboardCache) GetDashboardStats(ctx context.Context) (string, error) { + val, err := c.rdb.Get(ctx, c.buildKey()).Result() + if err != nil { + if err == redis.Nil { + return "", service.ErrDashboardStatsCacheMiss + } + return "", err + } + return val, nil +} + +func (c *dashboardCache) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error { + return c.rdb.Set(ctx, c.buildKey(), data, ttl).Err() +} + +func (c *dashboardCache) buildKey() string { + if c.keyPrefix == "" { + return dashboardStatsCacheKey + } + return c.keyPrefix + dashboardStatsCacheKey +} + +func (c *dashboardCache) DeleteDashboardStats(ctx context.Context) error { + return c.rdb.Del(ctx, c.buildKey()).Err() +} diff --git a/backend/internal/repository/dashboard_cache_test.go b/backend/internal/repository/dashboard_cache_test.go new file mode 100644 index 00000000..3bb0da4f --- /dev/null +++ b/backend/internal/repository/dashboard_cache_test.go @@ -0,0 +1,28 @@ +package repository + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestNewDashboardCacheKeyPrefix(t *testing.T) { + cache := NewDashboardCache(nil, &config.Config{ + Dashboard: config.DashboardCacheConfig{ + KeyPrefix: "prod", + }, + }) + impl, ok := cache.(*dashboardCache) + require.True(t, ok) + require.Equal(t, "prod:", impl.keyPrefix) + + cache = NewDashboardCache(nil, &config.Config{ + Dashboard: config.DashboardCacheConfig{ + KeyPrefix: "staging:", + }, + }) + impl, ok = cache.(*dashboardCache) + require.True(t, ok) + require.Equal(t, "staging:", impl.keyPrefix) +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 6ed8910e..e483f89f 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -269,16 +269,60 @@ func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, sta type DashboardStats = usagestats.DashboardStats func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) { - var stats DashboardStats - today := timezone.Today() - now := time.Now() + stats := &DashboardStats{} + now := time.Now().UTC() + todayUTC := truncateToDayUTC(now) - // 合并用户统计查询 + if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil { + return nil, err + } + if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayUTC, now); err != nil { + return nil, err + } + + rpm, tpm, err := r.getPerformanceStats(ctx, 0) + if err != nil { + return nil, err + } + stats.Rpm = rpm + stats.Tpm = tpm + + return stats, nil +} + +func (r *usageLogRepository) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*DashboardStats, error) { + startUTC := start.UTC() + endUTC := end.UTC() + if !endUTC.After(startUTC) { + return nil, errors.New("统计时间范围无效") + } + + stats := &DashboardStats{} + now := time.Now().UTC() + todayUTC := truncateToDayUTC(now) + + if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil { + return nil, err + } + if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayUTC, now); err != nil { + return nil, err + } + + rpm, tpm, err := r.getPerformanceStats(ctx, 0) + if err != nil { + return nil, err + } + stats.Rpm = rpm + stats.Tpm = tpm + + return stats, nil +} + +func (r *usageLogRepository) fillDashboardEntityStats(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error { userStatsQuery := ` SELECT COUNT(*) as total_users, - COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users, - (SELECT COUNT(DISTINCT user_id) FROM usage_logs WHERE created_at >= $2) as active_users + COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users FROM users WHERE deleted_at IS NULL ` @@ -286,15 +330,13 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS ctx, r.sql, userStatsQuery, - []any{today, today}, + []any{todayUTC}, &stats.TotalUsers, &stats.TodayNewUsers, - &stats.ActiveUsers, ); err != nil { - return nil, err + return err } - // 合并API Key统计查询 apiKeyStatsQuery := ` SELECT COUNT(*) as total_api_keys, @@ -310,10 +352,9 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS &stats.TotalAPIKeys, &stats.ActiveAPIKeys, ); err != nil { - return nil, err + return err } - // 合并账户统计查询 accountStatsQuery := ` SELECT COUNT(*) as total_accounts, @@ -335,22 +376,26 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS &stats.RateLimitAccounts, &stats.OverloadAccounts, ); err != nil { - return nil, err + return err } - // 累计 Token 统计 + return nil +} + +func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error { totalStatsQuery := ` SELECT - COUNT(*) as total_requests, + COALESCE(SUM(total_requests), 0) as total_requests, COALESCE(SUM(input_tokens), 0) as total_input_tokens, COALESCE(SUM(output_tokens), 0) as total_output_tokens, COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens, COALESCE(SUM(total_cost), 0) as total_cost, COALESCE(SUM(actual_cost), 0) as total_actual_cost, - COALESCE(AVG(duration_ms), 0) as avg_duration_ms - FROM usage_logs + COALESCE(SUM(total_duration_ms), 0) as total_duration_ms + FROM usage_dashboard_daily ` + var totalDurationMs int64 if err := scanSingleRow( ctx, r.sql, @@ -363,13 +408,100 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS &stats.TotalCacheReadTokens, &stats.TotalCost, &stats.TotalActualCost, - &stats.AverageDurationMs, + &totalDurationMs, ); err != nil { - return nil, err + return err } stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens + if stats.TotalRequests > 0 { + stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests) + } - // 今日 Token 统计 + todayStatsQuery := ` + SELECT + total_requests as today_requests, + input_tokens as today_input_tokens, + output_tokens as today_output_tokens, + cache_creation_tokens as today_cache_creation_tokens, + cache_read_tokens as today_cache_read_tokens, + total_cost as today_cost, + actual_cost as today_actual_cost, + active_users as active_users + FROM usage_dashboard_daily + WHERE bucket_date = $1::date + ` + if err := scanSingleRow( + ctx, + r.sql, + todayStatsQuery, + []any{todayUTC}, + &stats.TodayRequests, + &stats.TodayInputTokens, + &stats.TodayOutputTokens, + &stats.TodayCacheCreationTokens, + &stats.TodayCacheReadTokens, + &stats.TodayCost, + &stats.TodayActualCost, + &stats.ActiveUsers, + ); err != nil { + if err != sql.ErrNoRows { + return err + } + } + stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens + + hourlyActiveQuery := ` + SELECT active_users + FROM usage_dashboard_hourly + WHERE bucket_start = $1 + ` + hourStart := now.UTC().Truncate(time.Hour) + if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil { + if err != sql.ErrNoRows { + return err + } + } + + return nil +} + +func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error { + totalStatsQuery := ` + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms + FROM usage_logs + WHERE created_at >= $1 AND created_at < $2 + ` + var totalDurationMs int64 + if err := scanSingleRow( + ctx, + r.sql, + totalStatsQuery, + []any{startUTC, endUTC}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheCreationTokens, + &stats.TotalCacheReadTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &totalDurationMs, + ); err != nil { + return err + } + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens + if stats.TotalRequests > 0 { + stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests) + } + + todayEnd := todayUTC.Add(24 * time.Hour) todayStatsQuery := ` SELECT COUNT(*) as today_requests, @@ -380,13 +512,13 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS COALESCE(SUM(total_cost), 0) as today_cost, COALESCE(SUM(actual_cost), 0) as today_actual_cost FROM usage_logs - WHERE created_at >= $1 + WHERE created_at >= $1 AND created_at < $2 ` if err := scanSingleRow( ctx, r.sql, todayStatsQuery, - []any{today}, + []any{todayUTC, todayEnd}, &stats.TodayRequests, &stats.TodayInputTokens, &stats.TodayOutputTokens, @@ -395,19 +527,31 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS &stats.TodayCost, &stats.TodayActualCost, ); err != nil { - return nil, err + return err } stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens - // 性能指标:RPM 和 TPM(最近1分钟,全局) - rpm, tpm, err := r.getPerformanceStats(ctx, 0) - if err != nil { - return nil, err + activeUsersQuery := ` + SELECT COUNT(DISTINCT user_id) as active_users + FROM usage_logs + WHERE created_at >= $1 AND created_at < $2 + ` + if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil { + return err } - stats.Rpm = rpm - stats.Tpm = tpm - return &stats, nil + hourStart := now.UTC().Truncate(time.Hour) + hourEnd := hourStart.Add(time.Hour) + hourlyActiveQuery := ` + SELECT COUNT(DISTINCT user_id) as active_users + FROM usage_logs + WHERE created_at >= $1 AND created_at < $2 + ` + if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil { + return err + } + + return nil } func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 7193718f..51964782 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -11,7 +11,6 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" - "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/suite" @@ -198,8 +197,8 @@ func (s *UsageLogRepoSuite) TestListWithFilters() { // --- GetDashboardStats --- func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { - now := time.Now() - todayStart := timezone.Today() + now := time.Now().UTC() + todayStart := truncateToDayUTC(now) baseStats, err := s.repo.GetDashboardStats(s.ctx) s.Require().NoError(err, "GetDashboardStats base") @@ -273,6 +272,11 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { _, err = s.repo.Create(s.ctx, logPerf) s.Require().NoError(err, "Create logPerf") + aggRepo := newDashboardAggregationRepositoryWithSQL(s.tx) + aggStart := todayStart.Add(-2 * time.Hour) + aggEnd := now.Add(2 * time.Minute) + s.Require().NoError(aggRepo.AggregateRange(s.ctx, aggStart, aggEnd), "AggregateRange") + stats, err := s.repo.GetDashboardStats(s.ctx) s.Require().NoError(err, "GetDashboardStats") @@ -303,6 +307,80 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { s.Require().Equal(wantTpm, stats.Tpm, "Tpm mismatch") } +func (s *UsageLogRepoSuite) TestDashboardStatsWithRange_Fallback() { + now := time.Now().UTC() + todayStart := truncateToDayUTC(now) + rangeStart := todayStart.Add(-24 * time.Hour) + rangeEnd := now.Add(1 * time.Second) + + user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "range-u1@test.com"}) + user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "range-u2@test.com"}) + apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-range-1", Name: "k1"}) + apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-range-2", Name: "k2"}) + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-range"}) + + d1, d2, d3 := 100, 200, 300 + logOutside := &service.UsageLog{ + UserID: user1.ID, + APIKeyID: apiKey1.ID, + AccountID: account.ID, + Model: "claude-3", + InputTokens: 7, + OutputTokens: 8, + TotalCost: 0.8, + ActualCost: 0.7, + DurationMs: &d3, + CreatedAt: rangeStart.Add(-1 * time.Hour), + } + _, err := s.repo.Create(s.ctx, logOutside) + s.Require().NoError(err) + + logRange := &service.UsageLog{ + UserID: user1.ID, + APIKeyID: apiKey1.ID, + AccountID: account.ID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + CacheCreationTokens: 1, + CacheReadTokens: 2, + TotalCost: 1.0, + ActualCost: 0.9, + DurationMs: &d1, + CreatedAt: rangeStart.Add(2 * time.Hour), + } + _, err = s.repo.Create(s.ctx, logRange) + s.Require().NoError(err) + + logToday := &service.UsageLog{ + UserID: user2.ID, + APIKeyID: apiKey2.ID, + AccountID: account.ID, + Model: "claude-3", + InputTokens: 5, + OutputTokens: 6, + CacheReadTokens: 1, + TotalCost: 0.5, + ActualCost: 0.5, + DurationMs: &d2, + CreatedAt: now, + } + _, err = s.repo.Create(s.ctx, logToday) + s.Require().NoError(err) + + stats, err := s.repo.GetDashboardStatsWithRange(s.ctx, rangeStart, rangeEnd) + s.Require().NoError(err) + s.Require().Equal(int64(2), stats.TotalRequests) + s.Require().Equal(int64(15), stats.TotalInputTokens) + s.Require().Equal(int64(26), stats.TotalOutputTokens) + s.Require().Equal(int64(1), stats.TotalCacheCreationTokens) + s.Require().Equal(int64(3), stats.TotalCacheReadTokens) + s.Require().Equal(int64(45), stats.TotalTokens) + s.Require().Equal(1.5, stats.TotalCost) + s.Require().Equal(1.4, stats.TotalActualCost) + s.Require().InEpsilon(150.0, stats.AverageDurationMs, 0.0001) +} + // --- GetUserDashboardStats --- func (s *UsageLogRepoSuite) TestGetUserDashboardStats() { @@ -333,6 +411,151 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() { s.Require().Equal(int64(30), stats.Tokens) } +func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() { + now := time.Now().UTC().Truncate(time.Second) + hour1 := now.Add(-90 * time.Minute).Truncate(time.Hour) + hour2 := now.Add(-30 * time.Minute).Truncate(time.Hour) + dayStart := truncateToDayUTC(now) + + user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "agg-u1@test.com"}) + user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "agg-u2@test.com"}) + apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-agg-1", Name: "k1"}) + apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-agg-2", Name: "k2"}) + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-agg"}) + + d1, d2, d3 := 100, 200, 150 + log1 := &service.UsageLog{ + UserID: user1.ID, + APIKeyID: apiKey1.ID, + AccountID: account.ID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + CacheCreationTokens: 2, + CacheReadTokens: 1, + TotalCost: 1.0, + ActualCost: 0.9, + DurationMs: &d1, + CreatedAt: hour1.Add(5 * time.Minute), + } + _, err := s.repo.Create(s.ctx, log1) + s.Require().NoError(err) + + log2 := &service.UsageLog{ + UserID: user1.ID, + APIKeyID: apiKey1.ID, + AccountID: account.ID, + Model: "claude-3", + InputTokens: 5, + OutputTokens: 5, + TotalCost: 0.5, + ActualCost: 0.5, + DurationMs: &d2, + CreatedAt: hour1.Add(20 * time.Minute), + } + _, err = s.repo.Create(s.ctx, log2) + s.Require().NoError(err) + + log3 := &service.UsageLog{ + UserID: user2.ID, + APIKeyID: apiKey2.ID, + AccountID: account.ID, + Model: "claude-3", + InputTokens: 7, + OutputTokens: 8, + TotalCost: 0.7, + ActualCost: 0.7, + DurationMs: &d3, + CreatedAt: hour2.Add(10 * time.Minute), + } + _, err = s.repo.Create(s.ctx, log3) + s.Require().NoError(err) + + aggRepo := newDashboardAggregationRepositoryWithSQL(s.tx) + aggStart := hour1.Add(-5 * time.Minute) + aggEnd := now.Add(5 * time.Minute) + s.Require().NoError(aggRepo.AggregateRange(s.ctx, aggStart, aggEnd)) + + type hourlyRow struct { + totalRequests int64 + inputTokens int64 + outputTokens int64 + cacheCreationTokens int64 + cacheReadTokens int64 + totalCost float64 + actualCost float64 + totalDurationMs int64 + activeUsers int64 + } + fetchHourly := func(bucketStart time.Time) hourlyRow { + var row hourlyRow + err := scanSingleRow(s.ctx, s.tx, ` + SELECT total_requests, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, + total_cost, actual_cost, total_duration_ms, active_users + FROM usage_dashboard_hourly + WHERE bucket_start = $1 + `, []any{bucketStart}, &row.totalRequests, &row.inputTokens, &row.outputTokens, + &row.cacheCreationTokens, &row.cacheReadTokens, &row.totalCost, &row.actualCost, + &row.totalDurationMs, &row.activeUsers, + ) + s.Require().NoError(err) + return row + } + + hour1Row := fetchHourly(hour1) + s.Require().Equal(int64(2), hour1Row.totalRequests) + s.Require().Equal(int64(15), hour1Row.inputTokens) + s.Require().Equal(int64(25), hour1Row.outputTokens) + s.Require().Equal(int64(2), hour1Row.cacheCreationTokens) + s.Require().Equal(int64(1), hour1Row.cacheReadTokens) + s.Require().Equal(1.5, hour1Row.totalCost) + s.Require().Equal(1.4, hour1Row.actualCost) + s.Require().Equal(int64(300), hour1Row.totalDurationMs) + s.Require().Equal(int64(1), hour1Row.activeUsers) + + hour2Row := fetchHourly(hour2) + s.Require().Equal(int64(1), hour2Row.totalRequests) + s.Require().Equal(int64(7), hour2Row.inputTokens) + s.Require().Equal(int64(8), hour2Row.outputTokens) + s.Require().Equal(int64(0), hour2Row.cacheCreationTokens) + s.Require().Equal(int64(0), hour2Row.cacheReadTokens) + s.Require().Equal(0.7, hour2Row.totalCost) + s.Require().Equal(0.7, hour2Row.actualCost) + s.Require().Equal(int64(150), hour2Row.totalDurationMs) + s.Require().Equal(int64(1), hour2Row.activeUsers) + + var daily struct { + totalRequests int64 + inputTokens int64 + outputTokens int64 + cacheCreationTokens int64 + cacheReadTokens int64 + totalCost float64 + actualCost float64 + totalDurationMs int64 + activeUsers int64 + } + err = scanSingleRow(s.ctx, s.tx, ` + SELECT total_requests, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, + total_cost, actual_cost, total_duration_ms, active_users + FROM usage_dashboard_daily + WHERE bucket_date = $1::date + `, []any{dayStart}, &daily.totalRequests, &daily.inputTokens, &daily.outputTokens, + &daily.cacheCreationTokens, &daily.cacheReadTokens, &daily.totalCost, &daily.actualCost, + &daily.totalDurationMs, &daily.activeUsers, + ) + s.Require().NoError(err) + s.Require().Equal(int64(3), daily.totalRequests) + s.Require().Equal(int64(22), daily.inputTokens) + s.Require().Equal(int64(33), daily.outputTokens) + s.Require().Equal(int64(2), daily.cacheCreationTokens) + s.Require().Equal(int64(1), daily.cacheReadTokens) + s.Require().Equal(2.2, daily.totalCost) + s.Require().Equal(2.1, daily.actualCost) + s.Require().Equal(int64(450), daily.totalDurationMs) + s.Require().Equal(int64(2), daily.activeUsers) +} + // --- GetBatchUserUsageStats --- func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 6c1f5851..e1c6c3d4 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -47,6 +47,7 @@ var ProviderSet = wire.NewSet( NewRedeemCodeRepository, NewPromoCodeRepository, NewUsageLogRepository, + NewDashboardAggregationRepository, NewSettingRepository, NewOpsRepository, NewUserSubscriptionRepository, @@ -59,6 +60,7 @@ var ProviderSet = wire.NewSet( NewAPIKeyCache, NewTempUnschedCache, ProvideConcurrencyCache, + NewDashboardCache, NewEmailCache, NewIdentityCache, NewRedeemCache, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 330f36ad..6d8f67e3 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -331,6 +331,30 @@ func TestAPIContracts(t *testing.T) { } }`, }, + { + name: "POST /api/v1/admin/accounts/bulk-update", + method: http.MethodPost, + path: "/api/v1/admin/accounts/bulk-update", + body: `{"account_ids":[101,102],"schedulable":false}`, + headers: map[string]string{ + "Content-Type": "application/json", + }, + wantStatus: http.StatusOK, + wantJSON: `{ + "code": 0, + "message": "success", + "data": { + "success": 2, + "failed": 0, + "success_ids": [101, 102], + "failed_ids": [], + "results": [ + {"account_id": 101, "success": true}, + {"account_id": 102, "success": true} + ] + } + }`, + }, } for _, tt := range tests { @@ -382,6 +406,9 @@ func newContractDeps(t *testing.T) *contractDeps { apiKeyCache := stubApiKeyCache{} groupRepo := stubGroupRepo{} userSubRepo := stubUserSubscriptionRepo{} + accountRepo := stubAccountRepo{} + proxyRepo := stubProxyRepo{} + redeemRepo := stubRedeemCodeRepo{} cfg := &config.Config{ Default: config.DefaultConfig{ @@ -390,19 +417,21 @@ func newContractDeps(t *testing.T) *contractDeps { RunMode: config.RunModeStandard, } - userService := service.NewUserService(userRepo) + userService := service.NewUserService(userRepo, nil) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() - usageService := service.NewUsageService(usageRepo, userRepo, nil) + usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) + adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil) jwtAuth := func(c *gin.Context) { c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ @@ -442,6 +471,7 @@ func newContractDeps(t *testing.T) *contractDeps { v1Admin := v1.Group("/admin") v1Admin.Use(adminAuth) v1Admin.GET("/settings", adminSettingHandler.GetSettings) + v1Admin.POST("/accounts/bulk-update", adminAccountHandler.BulkUpdate) return &contractDeps{ now: now, @@ -566,6 +596,18 @@ func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, t return nil } +func (stubApiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) { + return nil, nil +} + +func (stubApiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error { + return nil +} + +func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error { + return nil +} + type stubGroupRepo struct{} func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error { @@ -620,6 +662,235 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i return 0, errors.New("not implemented") } +type stubAccountRepo struct { + bulkUpdateIDs []int64 +} + +func (s *stubAccountRepo) Create(ctx context.Context, account *service.Account) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) GetByID(ctx context.Context, id int64) (*service.Account, error) { + return nil, service.ErrAccountNotFound +} + +func (s *stubAccountRepo) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ExistsByID(ctx context.Context, id int64) (bool, error) { + return false, errors.New("not implemented") +} + +func (s *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListActive(ctx context.Context) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) { + return 0, errors.New("not implemented") +} + +func (s *stubAccountRepo) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) ListSchedulable(ctx context.Context) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) ClearRateLimit(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { + s.bulkUpdateIDs = append([]int64{}, ids...) + return int64(len(ids)), nil +} + +type stubProxyRepo struct{} + +func (stubProxyRepo) Create(ctx context.Context, proxy *service.Proxy) error { + return errors.New("not implemented") +} + +func (stubProxyRepo) GetByID(ctx context.Context, id int64) (*service.Proxy, error) { + return nil, service.ErrProxyNotFound +} + +func (stubProxyRepo) Update(ctx context.Context, proxy *service.Proxy) error { + return errors.New("not implemented") +} + +func (stubProxyRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (stubProxyRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (stubProxyRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (stubProxyRepo) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (stubProxyRepo) ListActive(ctx context.Context) ([]service.Proxy, error) { + return nil, errors.New("not implemented") +} + +func (stubProxyRepo) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) { + return nil, errors.New("not implemented") +} + +func (stubProxyRepo) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { + return false, errors.New("not implemented") +} + +func (stubProxyRepo) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { + return 0, errors.New("not implemented") +} + +type stubRedeemCodeRepo struct{} + +func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error { + return errors.New("not implemented") +} + +func (stubRedeemCodeRepo) CreateBatch(ctx context.Context, codes []service.RedeemCode) error { + return errors.New("not implemented") +} + +func (stubRedeemCodeRepo) GetByID(ctx context.Context, id int64) (*service.RedeemCode, error) { + return nil, service.ErrRedeemCodeNotFound +} + +func (stubRedeemCodeRepo) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) { + return nil, service.ErrRedeemCodeNotFound +} + +func (stubRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode) error { + return errors.New("not implemented") +} + +func (stubRedeemCodeRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (stubRedeemCodeRepo) Use(ctx context.Context, id, userID int64) error { + return errors.New("not implemented") +} + +func (stubRedeemCodeRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (stubRedeemCodeRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]service.RedeemCode, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) { + return nil, errors.New("not implemented") +} + type stubUserSubscriptionRepo struct{} func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error { @@ -738,12 +1009,12 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey return &clone, nil } -func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { +func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { key, ok := r.byID[id] if !ok { - return 0, service.ErrAPIKeyNotFound + return "", 0, service.ErrAPIKeyNotFound } - return key.UserID, nil + return key.Key, key.UserID, nil } func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { @@ -755,6 +1026,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API return &clone, nil } +func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) { + return r.GetByKey(ctx, key) +} + func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error { if key == nil { return errors.New("nil key") @@ -869,6 +1144,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int return 0, errors.New("not implemented") } +func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + return nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + return nil, errors.New("not implemented") +} + type stubUsageLogRepo struct { userLogs map[int64][]service.UsageLog } diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index 07b8e370..6f09469b 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -27,8 +27,8 @@ func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error { func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) { return nil, errors.New("not implemented") } -func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { - return 0, errors.New("not implemented") +func (f fakeAPIKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { + return "", 0, errors.New("not implemented") } func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { if f.getByKey == nil { @@ -36,6 +36,9 @@ func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIK } return f.getByKey(ctx, key) } +func (f fakeAPIKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) { + return f.GetByKey(ctx, key) +} func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error { return errors.New("not implemented") } @@ -66,6 +69,12 @@ func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64 func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, errors.New("not implemented") } +func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + return nil, errors.New("not implemented") +} +func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + return nil, errors.New("not implemented") +} type googleErrorResponse struct { Error struct { diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 182ea5f8..84398093 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -256,8 +256,8 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey return nil, errors.New("not implemented") } -func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { - return 0, errors.New("not implemented") +func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { + return "", 0, errors.New("not implemented") } func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { @@ -267,6 +267,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API return nil, errors.New("not implemented") } +func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) { + return r.GetByKey(ctx, key) +} + func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error { return errors.New("not implemented") } @@ -307,6 +311,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int return 0, errors.New("not implemented") } +func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + return nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + return nil, errors.New("not implemented") +} + type stubUserSubscriptionRepo struct { getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) updateStatus func(ctx context.Context, subscriptionID int64, status string) error diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index f3e66d04..a2f1b8c7 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -130,6 +130,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) { dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend) dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage) dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage) + dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation) } } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 14bb6daf..1874c5c1 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -186,9 +186,11 @@ type BulkUpdateAccountResult struct { // BulkUpdateAccountsResult is the aggregated response for bulk updates. type BulkUpdateAccountsResult struct { - Success int `json:"success"` - Failed int `json:"failed"` - Results []BulkUpdateAccountResult `json:"results"` + Success int `json:"success"` + Failed int `json:"failed"` + SuccessIDs []int64 `json:"success_ids"` + FailedIDs []int64 `json:"failed_ids"` + Results []BulkUpdateAccountResult `json:"results"` } type CreateProxyInput struct { @@ -244,14 +246,15 @@ type ProxyExitInfoProber interface { // adminServiceImpl implements AdminService type adminServiceImpl struct { - userRepo UserRepository - groupRepo GroupRepository - accountRepo AccountRepository - proxyRepo ProxyRepository - apiKeyRepo APIKeyRepository - redeemCodeRepo RedeemCodeRepository - billingCacheService *BillingCacheService - proxyProber ProxyExitInfoProber + userRepo UserRepository + groupRepo GroupRepository + accountRepo AccountRepository + proxyRepo ProxyRepository + apiKeyRepo APIKeyRepository + redeemCodeRepo RedeemCodeRepository + billingCacheService *BillingCacheService + proxyProber ProxyExitInfoProber + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewAdminService creates a new AdminService @@ -264,16 +267,18 @@ func NewAdminService( redeemCodeRepo RedeemCodeRepository, billingCacheService *BillingCacheService, proxyProber ProxyExitInfoProber, + authCacheInvalidator APIKeyAuthCacheInvalidator, ) AdminService { return &adminServiceImpl{ - userRepo: userRepo, - groupRepo: groupRepo, - accountRepo: accountRepo, - proxyRepo: proxyRepo, - apiKeyRepo: apiKeyRepo, - redeemCodeRepo: redeemCodeRepo, - billingCacheService: billingCacheService, - proxyProber: proxyProber, + userRepo: userRepo, + groupRepo: groupRepo, + accountRepo: accountRepo, + proxyRepo: proxyRepo, + apiKeyRepo: apiKeyRepo, + redeemCodeRepo: redeemCodeRepo, + billingCacheService: billingCacheService, + proxyProber: proxyProber, + authCacheInvalidator: authCacheInvalidator, } } @@ -323,6 +328,8 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda } oldConcurrency := user.Concurrency + oldStatus := user.Status + oldRole := user.Role if input.Email != "" { user.Email = input.Email @@ -355,6 +362,11 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } + if s.authCacheInvalidator != nil { + if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) + } + } concurrencyDiff := user.Concurrency - oldConcurrency if concurrencyDiff != 0 { @@ -393,6 +405,9 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { log.Printf("delete user failed: user_id=%d err=%v", id, err) return err } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id) + } return nil } @@ -420,6 +435,10 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } + balanceDiff := user.Balance - oldBalance + if s.authCacheInvalidator != nil && balanceDiff != 0 { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } if s.billingCacheService != nil { go func() { @@ -431,7 +450,6 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, }() } - balanceDiff := user.Balance - oldBalance if balanceDiff != 0 { code, err := GenerateRedeemCode() if err != nil { @@ -675,10 +693,21 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) + } return group, nil } func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { + var groupKeys []string + if s.authCacheInvalidator != nil { + keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id) + if err == nil { + groupKeys = keys + } + } + affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id) if err != nil { return err @@ -697,6 +726,11 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { } }() } + if s.authCacheInvalidator != nil { + for _, key := range groupKeys { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key) + } + } return nil } @@ -885,7 +919,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U // It merges credentials/extra keys instead of overwriting the whole object. func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) { result := &BulkUpdateAccountsResult{ - Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)), + SuccessIDs: make([]int64, 0, len(input.AccountIDs)), + FailedIDs: make([]int64, 0, len(input.AccountIDs)), + Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)), } if len(input.AccountIDs) == 0 { @@ -949,6 +985,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp entry.Success = false entry.Error = err.Error() result.Failed++ + result.FailedIDs = append(result.FailedIDs, accountID) result.Results = append(result.Results, entry) continue } @@ -958,6 +995,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp entry.Success = false entry.Error = err.Error() result.Failed++ + result.FailedIDs = append(result.FailedIDs, accountID) result.Results = append(result.Results, entry) continue } @@ -967,6 +1005,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp entry.Success = false entry.Error = err.Error() result.Failed++ + result.FailedIDs = append(result.FailedIDs, accountID) result.Results = append(result.Results, entry) continue } @@ -974,6 +1013,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp entry.Success = true result.Success++ + result.SuccessIDs = append(result.SuccessIDs, accountID) result.Results = append(result.Results, entry) } diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go new file mode 100644 index 00000000..ef621213 --- /dev/null +++ b/backend/internal/service/admin_service_bulk_update_test.go @@ -0,0 +1,80 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +type accountRepoStubForBulkUpdate struct { + accountRepoStub + bulkUpdateErr error + bulkUpdateIDs []int64 + bindGroupErrByID map[int64]error +} + +func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) { + s.bulkUpdateIDs = append([]int64{}, ids...) + if s.bulkUpdateErr != nil { + return 0, s.bulkUpdateErr + } + return int64(len(ids)), nil +} + +func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error { + if err, ok := s.bindGroupErrByID[accountID]; ok { + return err + } + return nil +} + +// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。 +func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{} + svc := &adminServiceImpl{accountRepo: repo} + + schedulable := true + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1, 2, 3}, + Schedulable: &schedulable, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.NoError(t, err) + require.Equal(t, 3, result.Success) + require.Equal(t, 0, result.Failed) + require.ElementsMatch(t, []int64{1, 2, 3}, result.SuccessIDs) + require.Empty(t, result.FailedIDs) + require.Len(t, result.Results, 3) +} + +// TestAdminService_BulkUpdateAccounts_PartialFailureIDs 验证部分失败时 success_ids/failed_ids 正确。 +func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{ + bindGroupErrByID: map[int64]error{ + 2: errors.New("bind failed"), + }, + } + svc := &adminServiceImpl{accountRepo: repo} + + groupIDs := []int64{10} + schedulable := false + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1, 2, 3}, + GroupIDs: &groupIDs, + Schedulable: &schedulable, + SkipMixedChannelCheck: true, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.NoError(t, err) + require.Equal(t, 2, result.Success) + require.Equal(t, 1, result.Failed) + require.ElementsMatch(t, []int64{1, 3}, result.SuccessIDs) + require.ElementsMatch(t, []int64{2}, result.FailedIDs) + require.Len(t, result.Results, 3) +} diff --git a/backend/internal/service/admin_service_update_balance_test.go b/backend/internal/service/admin_service_update_balance_test.go new file mode 100644 index 00000000..d3b3c700 --- /dev/null +++ b/backend/internal/service/admin_service_update_balance_test.go @@ -0,0 +1,97 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +type balanceUserRepoStub struct { + *userRepoStub + updateErr error + updated []*User +} + +func (s *balanceUserRepoStub) Update(ctx context.Context, user *User) error { + if s.updateErr != nil { + return s.updateErr + } + if user == nil { + return nil + } + clone := *user + s.updated = append(s.updated, &clone) + if s.userRepoStub != nil { + s.userRepoStub.user = &clone + } + return nil +} + +type balanceRedeemRepoStub struct { + *redeemRepoStub + created []*RedeemCode +} + +func (s *balanceRedeemRepoStub) Create(ctx context.Context, code *RedeemCode) error { + if code == nil { + return nil + } + clone := *code + s.created = append(s.created, &clone) + return nil +} + +type authCacheInvalidatorStub struct { + userIDs []int64 + groupIDs []int64 + keys []string +} + +func (s *authCacheInvalidatorStub) InvalidateAuthCacheByKey(ctx context.Context, key string) { + s.keys = append(s.keys, key) +} + +func (s *authCacheInvalidatorStub) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) { + s.userIDs = append(s.userIDs, userID) +} + +func (s *authCacheInvalidatorStub) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) { + s.groupIDs = append(s.groupIDs, groupID) +} + +func TestAdminService_UpdateUserBalance_InvalidatesAuthCache(t *testing.T) { + baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}} + repo := &balanceUserRepoStub{userRepoStub: baseRepo} + redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + userRepo: repo, + redeemCodeRepo: redeemRepo, + authCacheInvalidator: invalidator, + } + + _, err := svc.UpdateUserBalance(context.Background(), 7, 5, "add", "") + require.NoError(t, err) + require.Equal(t, []int64{7}, invalidator.userIDs) + require.Len(t, redeemRepo.created, 1) +} + +func TestAdminService_UpdateUserBalance_NoChangeNoInvalidate(t *testing.T) { + baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}} + repo := &balanceUserRepoStub{userRepoStub: baseRepo} + redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + userRepo: repo, + redeemCodeRepo: redeemRepo, + authCacheInvalidator: invalidator, + } + + _, err := svc.UpdateUserBalance(context.Background(), 7, 10, "set", "") + require.NoError(t, err) + require.Empty(t, invalidator.userIDs) + require.Empty(t, redeemRepo.created) +} diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go new file mode 100644 index 00000000..7ce9a8a2 --- /dev/null +++ b/backend/internal/service/api_key_auth_cache.go @@ -0,0 +1,46 @@ +package service + +// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段) +type APIKeyAuthSnapshot struct { + APIKeyID int64 `json:"api_key_id"` + UserID int64 `json:"user_id"` + GroupID *int64 `json:"group_id,omitempty"` + Status string `json:"status"` + IPWhitelist []string `json:"ip_whitelist,omitempty"` + IPBlacklist []string `json:"ip_blacklist,omitempty"` + User APIKeyAuthUserSnapshot `json:"user"` + Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"` +} + +// APIKeyAuthUserSnapshot 用户快照 +type APIKeyAuthUserSnapshot struct { + ID int64 `json:"id"` + Status string `json:"status"` + Role string `json:"role"` + Balance float64 `json:"balance"` + Concurrency int `json:"concurrency"` +} + +// APIKeyAuthGroupSnapshot 分组快照 +type APIKeyAuthGroupSnapshot struct { + ID int64 `json:"id"` + Name string `json:"name"` + Platform string `json:"platform"` + Status string `json:"status"` + SubscriptionType string `json:"subscription_type"` + RateMultiplier float64 `json:"rate_multiplier"` + DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` + WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` + MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` + ImagePrice1K *float64 `json:"image_price_1k,omitempty"` + ImagePrice2K *float64 `json:"image_price_2k,omitempty"` + ImagePrice4K *float64 `json:"image_price_4k,omitempty"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` +} + +// APIKeyAuthCacheEntry 缓存条目,支持负缓存 +type APIKeyAuthCacheEntry struct { + NotFound bool `json:"not_found"` + Snapshot *APIKeyAuthSnapshot `json:"snapshot,omitempty"` +} diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go new file mode 100644 index 00000000..dfc55eeb --- /dev/null +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -0,0 +1,269 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "math/rand" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/dgraph-io/ristretto" +) + +type apiKeyAuthCacheConfig struct { + l1Size int + l1TTL time.Duration + l2TTL time.Duration + negativeTTL time.Duration + jitterPercent int + singleflight bool +} + +var ( + jitterRandMu sync.Mutex + // 认证缓存抖动使用独立随机源,避免全局 Seed + jitterRand = rand.New(rand.NewSource(time.Now().UnixNano())) +) + +func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig { + if cfg == nil { + return apiKeyAuthCacheConfig{} + } + auth := cfg.APIKeyAuth + return apiKeyAuthCacheConfig{ + l1Size: auth.L1Size, + l1TTL: time.Duration(auth.L1TTLSeconds) * time.Second, + l2TTL: time.Duration(auth.L2TTLSeconds) * time.Second, + negativeTTL: time.Duration(auth.NegativeTTLSeconds) * time.Second, + jitterPercent: auth.JitterPercent, + singleflight: auth.Singleflight, + } +} + +func (c apiKeyAuthCacheConfig) l1Enabled() bool { + return c.l1Size > 0 && c.l1TTL > 0 +} + +func (c apiKeyAuthCacheConfig) l2Enabled() bool { + return c.l2TTL > 0 +} + +func (c apiKeyAuthCacheConfig) negativeEnabled() bool { + return c.negativeTTL > 0 +} + +func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration { + if ttl <= 0 { + return ttl + } + if c.jitterPercent <= 0 { + return ttl + } + percent := c.jitterPercent + if percent > 100 { + percent = 100 + } + delta := float64(percent) / 100 + jitterRandMu.Lock() + randVal := jitterRand.Float64() + jitterRandMu.Unlock() + factor := 1 - delta + randVal*(2*delta) + if factor <= 0 { + return ttl + } + return time.Duration(float64(ttl) * factor) +} + +func (s *APIKeyService) initAuthCache(cfg *config.Config) { + s.authCfg = newAPIKeyAuthCacheConfig(cfg) + if !s.authCfg.l1Enabled() { + return + } + cache, err := ristretto.NewCache(&ristretto.Config{ + NumCounters: int64(s.authCfg.l1Size) * 10, + MaxCost: int64(s.authCfg.l1Size), + BufferItems: 64, + }) + if err != nil { + return + } + s.authCacheL1 = cache +} + +func (s *APIKeyService) authCacheKey(key string) string { + sum := sha256.Sum256([]byte(key)) + return hex.EncodeToString(sum[:]) +} + +func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) { + if s.authCacheL1 != nil { + if val, ok := s.authCacheL1.Get(cacheKey); ok { + if entry, ok := val.(*APIKeyAuthCacheEntry); ok { + return entry, true + } + } + } + if s.cache == nil || !s.authCfg.l2Enabled() { + return nil, false + } + entry, err := s.cache.GetAuthCache(ctx, cacheKey) + if err != nil { + return nil, false + } + s.setAuthCacheL1(cacheKey, entry) + return entry, true +} + +func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) { + if s.authCacheL1 == nil || entry == nil { + return + } + ttl := s.authCfg.l1TTL + if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl { + ttl = s.authCfg.negativeTTL + } + ttl = s.authCfg.jitterTTL(ttl) + _ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, ttl) +} + +func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) { + if entry == nil { + return + } + s.setAuthCacheL1(cacheKey, entry) + if s.cache == nil || !s.authCfg.l2Enabled() { + return + } + _ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl)) +} + +func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) { + if s.authCacheL1 != nil { + s.authCacheL1.Del(cacheKey) + } + if s.cache == nil { + return + } + _ = s.cache.DeleteAuthCache(ctx, cacheKey) +} + +func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) { + apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key) + if err != nil { + if errors.Is(err, ErrAPIKeyNotFound) { + entry := &APIKeyAuthCacheEntry{NotFound: true} + if s.authCfg.negativeEnabled() { + s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL) + } + return entry, nil + } + return nil, fmt.Errorf("get api key: %w", err) + } + apiKey.Key = key + snapshot := s.snapshotFromAPIKey(apiKey) + if snapshot == nil { + return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound) + } + entry := &APIKeyAuthCacheEntry{Snapshot: snapshot} + s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL) + return entry, nil +} + +func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) { + if entry == nil { + return nil, false, nil + } + if entry.NotFound { + return nil, true, ErrAPIKeyNotFound + } + if entry.Snapshot == nil { + return nil, false, nil + } + return s.snapshotToAPIKey(key, entry.Snapshot), true, nil +} + +func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { + if apiKey == nil || apiKey.User == nil { + return nil + } + snapshot := &APIKeyAuthSnapshot{ + APIKeyID: apiKey.ID, + UserID: apiKey.UserID, + GroupID: apiKey.GroupID, + Status: apiKey.Status, + IPWhitelist: apiKey.IPWhitelist, + IPBlacklist: apiKey.IPBlacklist, + User: APIKeyAuthUserSnapshot{ + ID: apiKey.User.ID, + Status: apiKey.User.Status, + Role: apiKey.User.Role, + Balance: apiKey.User.Balance, + Concurrency: apiKey.User.Concurrency, + }, + } + if apiKey.Group != nil { + snapshot.Group = &APIKeyAuthGroupSnapshot{ + ID: apiKey.Group.ID, + Name: apiKey.Group.Name, + Platform: apiKey.Group.Platform, + Status: apiKey.Group.Status, + SubscriptionType: apiKey.Group.SubscriptionType, + RateMultiplier: apiKey.Group.RateMultiplier, + DailyLimitUSD: apiKey.Group.DailyLimitUSD, + WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, + MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, + ImagePrice1K: apiKey.Group.ImagePrice1K, + ImagePrice2K: apiKey.Group.ImagePrice2K, + ImagePrice4K: apiKey.Group.ImagePrice4K, + ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, + FallbackGroupID: apiKey.Group.FallbackGroupID, + } + } + return snapshot +} + +func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey { + if snapshot == nil { + return nil + } + apiKey := &APIKey{ + ID: snapshot.APIKeyID, + UserID: snapshot.UserID, + GroupID: snapshot.GroupID, + Key: key, + Status: snapshot.Status, + IPWhitelist: snapshot.IPWhitelist, + IPBlacklist: snapshot.IPBlacklist, + User: &User{ + ID: snapshot.User.ID, + Status: snapshot.User.Status, + Role: snapshot.User.Role, + Balance: snapshot.User.Balance, + Concurrency: snapshot.User.Concurrency, + }, + } + if snapshot.Group != nil { + apiKey.Group = &Group{ + ID: snapshot.Group.ID, + Name: snapshot.Group.Name, + Platform: snapshot.Group.Platform, + Status: snapshot.Group.Status, + Hydrated: true, + SubscriptionType: snapshot.Group.SubscriptionType, + RateMultiplier: snapshot.Group.RateMultiplier, + DailyLimitUSD: snapshot.Group.DailyLimitUSD, + WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, + MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, + ImagePrice1K: snapshot.Group.ImagePrice1K, + ImagePrice2K: snapshot.Group.ImagePrice2K, + ImagePrice4K: snapshot.Group.ImagePrice4K, + ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, + FallbackGroupID: snapshot.Group.FallbackGroupID, + } + } + return apiKey +} diff --git a/backend/internal/service/api_key_auth_cache_invalidate.go b/backend/internal/service/api_key_auth_cache_invalidate.go new file mode 100644 index 00000000..aeb58bcc --- /dev/null +++ b/backend/internal/service/api_key_auth_cache_invalidate.go @@ -0,0 +1,48 @@ +package service + +import "context" + +// InvalidateAuthCacheByKey 清除指定 API Key 的认证缓存 +func (s *APIKeyService) InvalidateAuthCacheByKey(ctx context.Context, key string) { + if key == "" { + return + } + cacheKey := s.authCacheKey(key) + s.deleteAuthCache(ctx, cacheKey) +} + +// InvalidateAuthCacheByUserID 清除用户相关的 API Key 认证缓存 +func (s *APIKeyService) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) { + if userID <= 0 { + return + } + keys, err := s.apiKeyRepo.ListKeysByUserID(ctx, userID) + if err != nil { + return + } + s.deleteAuthCacheByKeys(ctx, keys) +} + +// InvalidateAuthCacheByGroupID 清除分组相关的 API Key 认证缓存 +func (s *APIKeyService) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) { + if groupID <= 0 { + return + } + keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, groupID) + if err != nil { + return + } + s.deleteAuthCacheByKeys(ctx, keys) +} + +func (s *APIKeyService) deleteAuthCacheByKeys(ctx context.Context, keys []string) { + if len(keys) == 0 { + return + } + for _, key := range keys { + if key == "" { + continue + } + s.deleteAuthCache(ctx, s.authCacheKey(key)) + } +} diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index 578afc1a..ecc570c7 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -12,6 +12,8 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + "github.com/dgraph-io/ristretto" + "golang.org/x/sync/singleflight" ) var ( @@ -31,9 +33,11 @@ const ( type APIKeyRepository interface { Create(ctx context.Context, key *APIKey) error GetByID(ctx context.Context, id int64) (*APIKey, error) - // GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证 - GetOwnerID(ctx context.Context, id int64) (int64, error) + // GetKeyAndOwnerID 仅获取 API Key 的 key 与所有者 ID,用于删除等轻量场景 + GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) GetByKey(ctx context.Context, key string) (*APIKey, error) + // GetByKeyForAuth 认证专用查询,返回最小字段集 + GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) Update(ctx context.Context, key *APIKey) error Delete(ctx context.Context, id int64) error @@ -45,6 +49,8 @@ type APIKeyRepository interface { SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) CountByGroupID(ctx context.Context, groupID int64) (int64, error) + ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) + ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) } // APIKeyCache defines cache operations for API key service @@ -55,6 +61,17 @@ type APIKeyCache interface { IncrementDailyUsage(ctx context.Context, apiKey string) error SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error + + GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) + SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error + DeleteAuthCache(ctx context.Context, key string) error +} + +// APIKeyAuthCacheInvalidator 提供认证缓存失效能力 +type APIKeyAuthCacheInvalidator interface { + InvalidateAuthCacheByKey(ctx context.Context, key string) + InvalidateAuthCacheByUserID(ctx context.Context, userID int64) + InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) } // CreateAPIKeyRequest 创建API Key请求 @@ -83,6 +100,9 @@ type APIKeyService struct { userSubRepo UserSubscriptionRepository cache APIKeyCache cfg *config.Config + authCacheL1 *ristretto.Cache + authCfg apiKeyAuthCacheConfig + authGroup singleflight.Group } // NewAPIKeyService 创建API Key服务实例 @@ -94,7 +114,7 @@ func NewAPIKeyService( cache APIKeyCache, cfg *config.Config, ) *APIKeyService { - return &APIKeyService{ + svc := &APIKeyService{ apiKeyRepo: apiKeyRepo, userRepo: userRepo, groupRepo: groupRepo, @@ -102,6 +122,8 @@ func NewAPIKeyService( cache: cache, cfg: cfg, } + svc.initAuthCache(cfg) + return svc } // GenerateKey 生成随机API Key @@ -269,6 +291,8 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK return nil, fmt.Errorf("create api key: %w", err) } + s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + return apiKey, nil } @@ -304,21 +328,49 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) // GetByKey 根据Key字符串获取API Key(用于认证) func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) { - // 尝试从Redis缓存获取 - cacheKey := fmt.Sprintf("apikey:%s", key) + cacheKey := s.authCacheKey(key) - // 这里可以添加Redis缓存逻辑,暂时直接查询数据库 - apiKey, err := s.apiKeyRepo.GetByKey(ctx, key) + if entry, ok := s.getAuthCacheEntry(ctx, cacheKey); ok { + if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used { + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + return apiKey, nil + } + } + + if s.authCfg.singleflight { + value, err, _ := s.authGroup.Do(cacheKey, func() (any, error) { + return s.loadAuthCacheEntry(ctx, key, cacheKey) + }) + if err != nil { + return nil, err + } + entry, _ := value.(*APIKeyAuthCacheEntry) + if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used { + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + return apiKey, nil + } + } else { + entry, err := s.loadAuthCacheEntry(ctx, key, cacheKey) + if err != nil { + return nil, err + } + if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used { + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + return apiKey, nil + } + } + + apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key) if err != nil { return nil, fmt.Errorf("get api key: %w", err) } - - // 缓存到Redis(可选,TTL设置为5分钟) - if s.cache != nil { - // 这里可以序列化并缓存API Key - _ = cacheKey // 使用变量避免未使用错误 - } - + apiKey.Key = key return apiKey, nil } @@ -388,15 +440,14 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req return nil, fmt.Errorf("update api key: %w", err) } + s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + return apiKey, nil } // Delete 删除API Key -// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证, -// 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能 func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error { - // 仅获取所有者 ID 用于权限验证,而非加载完整对象 - ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id) + key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id) if err != nil { return fmt.Errorf("get api key: %w", err) } @@ -406,10 +457,11 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro return ErrInsufficientPerms } - // 清除Redis缓存(使用 ownerID 而非 apiKey.UserID) + // 清除Redis缓存(使用 userID 而非 apiKey.UserID) if s.cache != nil { - _ = s.cache.DeleteCreateAttemptCount(ctx, ownerID) + _ = s.cache.DeleteCreateAttemptCount(ctx, userID) } + s.InvalidateAuthCacheByKey(ctx, key) if err := s.apiKeyRepo.Delete(ctx, id); err != nil { return fmt.Errorf("delete api key: %w", err) diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go new file mode 100644 index 00000000..3314ca8d --- /dev/null +++ b/backend/internal/service/api_key_service_cache_test.go @@ -0,0 +1,417 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +type authRepoStub struct { + getByKeyForAuth func(ctx context.Context, key string) (*APIKey, error) + listKeysByUserID func(ctx context.Context, userID int64) ([]string, error) + listKeysByGroupID func(ctx context.Context, groupID int64) ([]string, error) +} + +func (s *authRepoStub) Create(ctx context.Context, key *APIKey) error { + panic("unexpected Create call") +} + +func (s *authRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) { + panic("unexpected GetByID call") +} + +func (s *authRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { + panic("unexpected GetKeyAndOwnerID call") +} + +func (s *authRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) { + panic("unexpected GetByKey call") +} + +func (s *authRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) { + if s.getByKeyForAuth == nil { + panic("unexpected GetByKeyForAuth call") + } + return s.getByKeyForAuth(ctx, key) +} + +func (s *authRepoStub) Update(ctx context.Context, key *APIKey) error { + panic("unexpected Update call") +} + +func (s *authRepoStub) Delete(ctx context.Context, id int64) error { + panic("unexpected Delete call") +} + +func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByUserID call") +} + +func (s *authRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { + panic("unexpected VerifyOwnership call") +} + +func (s *authRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) { + panic("unexpected CountByUserID call") +} + +func (s *authRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) { + panic("unexpected ExistsByKey call") +} + +func (s *authRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByGroupID call") +} + +func (s *authRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) { + panic("unexpected SearchAPIKeys call") +} + +func (s *authRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { + panic("unexpected ClearGroupIDByGroupID call") +} + +func (s *authRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { + panic("unexpected CountByGroupID call") +} + +func (s *authRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + if s.listKeysByUserID == nil { + panic("unexpected ListKeysByUserID call") + } + return s.listKeysByUserID(ctx, userID) +} + +func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + if s.listKeysByGroupID == nil { + panic("unexpected ListKeysByGroupID call") + } + return s.listKeysByGroupID(ctx, groupID) +} + +type authCacheStub struct { + getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) + setAuthKeys []string + deleteAuthKeys []string +} + +func (s *authCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) { + return 0, nil +} + +func (s *authCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error { + return nil +} + +func (s *authCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error { + return nil +} + +func (s *authCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error { + return nil +} + +func (s *authCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error { + return nil +} + +func (s *authCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + if s.getAuthCache == nil { + return nil, redis.Nil + } + return s.getAuthCache(ctx, key) +} + +func (s *authCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error { + s.setAuthKeys = append(s.setAuthKeys, key) + return nil +} + +func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error { + s.deleteAuthKeys = append(s.deleteAuthKeys, key) + return nil +} + +func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + return nil, errors.New("unexpected repo call") + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + groupID := int64(9) + cacheEntry := &APIKeyAuthCacheEntry{ + Snapshot: &APIKeyAuthSnapshot{ + APIKeyID: 1, + UserID: 2, + GroupID: &groupID, + Status: StatusActive, + User: APIKeyAuthUserSnapshot{ + ID: 2, + Status: StatusActive, + Role: RoleUser, + Balance: 10, + Concurrency: 3, + }, + Group: &APIKeyAuthGroupSnapshot{ + ID: groupID, + Name: "g", + Platform: PlatformAnthropic, + Status: StatusActive, + SubscriptionType: SubscriptionTypeStandard, + RateMultiplier: 1, + }, + }, + } + cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return cacheEntry, nil + } + + apiKey, err := svc.GetByKey(context.Background(), "k1") + require.NoError(t, err) + require.Equal(t, int64(1), apiKey.ID) + require.Equal(t, int64(2), apiKey.User.ID) + require.Equal(t, groupID, apiKey.Group.ID) +} + +func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + return nil, errors.New("unexpected repo call") + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return &APIKeyAuthCacheEntry{NotFound: true}, nil + } + + _, err := svc.GetByKey(context.Background(), "missing") + require.ErrorIs(t, err, ErrAPIKeyNotFound) +} + +func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + return &APIKey{ + ID: 5, + UserID: 7, + Status: StatusActive, + User: &User{ + ID: 7, + Status: StatusActive, + Role: RoleUser, + Balance: 12, + Concurrency: 2, + }, + }, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return nil, redis.Nil + } + + apiKey, err := svc.GetByKey(context.Background(), "k2") + require.NoError(t, err) + require.Equal(t, int64(5), apiKey.ID) + require.Len(t, cache.setAuthKeys, 1) +} + +func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) { + var calls int32 + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + atomic.AddInt32(&calls, 1) + return &APIKey{ + ID: 21, + UserID: 3, + Status: StatusActive, + User: &User{ + ID: 3, + Status: StatusActive, + Role: RoleUser, + Balance: 5, + Concurrency: 2, + }, + }, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L1Size: 1000, + L1TTLSeconds: 60, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + require.NotNil(t, svc.authCacheL1) + + _, err := svc.GetByKey(context.Background(), "k-l1") + require.NoError(t, err) + svc.authCacheL1.Wait() + cacheKey := svc.authCacheKey("k-l1") + _, ok := svc.authCacheL1.Get(cacheKey) + require.True(t, ok) + _, err = svc.GetByKey(context.Background(), "k-l1") + require.NoError(t, err) + require.Equal(t, int32(1), atomic.LoadInt32(&calls)) +} + +func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) { + return []string{"k1", "k2"}, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + svc.InvalidateAuthCacheByUserID(context.Background(), 7) + require.Len(t, cache.deleteAuthKeys, 2) +} + +func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + listKeysByGroupID: func(ctx context.Context, groupID int64) ([]string, error) { + return []string{"k1", "k2"}, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + svc.InvalidateAuthCacheByGroupID(context.Background(), 9) + require.Len(t, cache.deleteAuthKeys, 2) +} + +func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) { + return nil, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + svc.InvalidateAuthCacheByKey(context.Background(), "k1") + require.Len(t, cache.deleteAuthKeys, 1) +} + +func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + return nil, ErrAPIKeyNotFound + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return nil, redis.Nil + } + + _, err := svc.GetByKey(context.Background(), "missing") + require.ErrorIs(t, err, ErrAPIKeyNotFound) + require.Len(t, cache.setAuthKeys, 1) +} + +func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) { + var calls int32 + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + atomic.AddInt32(&calls, 1) + time.Sleep(50 * time.Millisecond) + return &APIKey{ + ID: 11, + UserID: 2, + Status: StatusActive, + User: &User{ + ID: 2, + Status: StatusActive, + Role: RoleUser, + Balance: 1, + Concurrency: 1, + }, + }, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + Singleflight: true, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + start := make(chan struct{}) + wg := sync.WaitGroup{} + errs := make([]error, 5) + for i := 0; i < 5; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start + _, err := svc.GetByKey(context.Background(), "k1") + errs[idx] = err + }(i) + } + close(start) + wg.Wait() + + for _, err := range errs { + require.NoError(t, err) + } + require.Equal(t, int32(1), atomic.LoadInt32(&calls)) +} diff --git a/backend/internal/service/api_key_service_delete_test.go b/backend/internal/service/api_key_service_delete_test.go index 7d04c5ac..32ae884e 100644 --- a/backend/internal/service/api_key_service_delete_test.go +++ b/backend/internal/service/api_key_service_delete_test.go @@ -20,13 +20,12 @@ import ( // 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。 // // 设计说明: -// - ownerID: 模拟 GetOwnerID 返回的所有者 ID -// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound) +// - apiKey/getByIDErr: 模拟 GetKeyAndOwnerID 返回的记录与错误 // - deleteErr: 模拟 Delete 返回的错误 // - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证 type apiKeyRepoStub struct { - ownerID int64 // GetOwnerID 的返回值 - ownerErr error // GetOwnerID 的错误返回值 + apiKey *APIKey // GetKeyAndOwnerID 的返回值 + getByIDErr error // GetKeyAndOwnerID 的错误返回值 deleteErr error // Delete 的错误返回值 deletedIDs []int64 // 记录已删除的 API Key ID 列表 } @@ -38,19 +37,34 @@ func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error { } func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) { + if s.getByIDErr != nil { + return nil, s.getByIDErr + } + if s.apiKey != nil { + clone := *s.apiKey + return &clone, nil + } panic("unexpected GetByID call") } -// GetOwnerID 返回预设的所有者 ID 或错误。 -// 这是 Delete 方法调用的第一个仓储方法,用于验证调用者是否为 API Key 的所有者。 -func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error) { - return s.ownerID, s.ownerErr +func (s *apiKeyRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { + if s.getByIDErr != nil { + return "", 0, s.getByIDErr + } + if s.apiKey != nil { + return s.apiKey.Key, s.apiKey.UserID, nil + } + return "", 0, ErrAPIKeyNotFound } func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) { panic("unexpected GetByKey call") } +func (s *apiKeyRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) { + panic("unexpected GetByKeyForAuth call") +} + func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error { panic("unexpected Update call") } @@ -96,13 +110,22 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int panic("unexpected CountByGroupID call") } +func (s *apiKeyRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + panic("unexpected ListKeysByUserID call") +} + +func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + panic("unexpected ListKeysByGroupID call") +} + // apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。 // 用于验证删除操作时缓存清理逻辑是否被正确调用。 // // 设计说明: // - invalidated: 记录被清除缓存的用户 ID 列表 type apiKeyCacheStub struct { - invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID + invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID + deleteAuthKeys []string // 记录调用 DeleteAuthCache 时传入的缓存 key } // GetCreateAttemptCount 返回 0,表示用户未超过创建次数限制 @@ -132,15 +155,30 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string return nil } +func (s *apiKeyCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return nil, nil +} + +func (s *apiKeyCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error { + return nil +} + +func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error { + s.deleteAuthKeys = append(s.deleteAuthKeys, key) + return nil +} + // TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。 // 预期行为: -// - GetOwnerID 返回所有者 ID 为 1 +// - GetKeyAndOwnerID 返回所有者 ID 为 1 // - 调用者 userID 为 2(不匹配) // - 返回 ErrInsufficientPerms 错误 // - Delete 方法不被调用 // - 缓存不被清除 func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) { - repo := &apiKeyRepoStub{ownerID: 1} + repo := &apiKeyRepoStub{ + apiKey: &APIKey{ID: 10, UserID: 1, Key: "k"}, + } cache := &apiKeyCacheStub{} svc := &APIKeyService{apiKeyRepo: repo, cache: cache} @@ -148,17 +186,20 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) { require.ErrorIs(t, err, ErrInsufficientPerms) require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用 require.Empty(t, cache.invalidated) // 验证缓存未被清除 + require.Empty(t, cache.deleteAuthKeys) } // TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。 // 预期行为: -// - GetOwnerID 返回所有者 ID 为 7 +// - GetKeyAndOwnerID 返回所有者 ID 为 7 // - 调用者 userID 为 7(匹配) // - Delete 成功执行 // - 缓存被正确清除(使用 ownerID) // - 返回 nil 错误 func TestApiKeyService_Delete_Success(t *testing.T) { - repo := &apiKeyRepoStub{ownerID: 7} + repo := &apiKeyRepoStub{ + apiKey: &APIKey{ID: 42, UserID: 7, Key: "k"}, + } cache := &apiKeyCacheStub{} svc := &APIKeyService{apiKeyRepo: repo, cache: cache} @@ -166,16 +207,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) { require.NoError(t, err) require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除 require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除 + require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys) } // TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。 // 预期行为: -// - GetOwnerID 返回 ErrAPIKeyNotFound 错误 +// - GetKeyAndOwnerID 返回 ErrAPIKeyNotFound 错误 // - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装) // - Delete 方法不被调用 // - 缓存不被清除 func TestApiKeyService_Delete_NotFound(t *testing.T) { - repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound} + repo := &apiKeyRepoStub{getByIDErr: ErrAPIKeyNotFound} cache := &apiKeyCacheStub{} svc := &APIKeyService{apiKeyRepo: repo, cache: cache} @@ -183,18 +225,19 @@ func TestApiKeyService_Delete_NotFound(t *testing.T) { require.ErrorIs(t, err, ErrAPIKeyNotFound) require.Empty(t, repo.deletedIDs) require.Empty(t, cache.invalidated) + require.Empty(t, cache.deleteAuthKeys) } // TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。 // 预期行为: -// - GetOwnerID 返回正确的所有者 ID +// - GetKeyAndOwnerID 返回正确的所有者 ID // - 所有权验证通过 // - 缓存被清除(在删除之前) // - Delete 被调用但返回错误 // - 返回包含 "delete api key" 的错误信息 func TestApiKeyService_Delete_DeleteFails(t *testing.T) { repo := &apiKeyRepoStub{ - ownerID: 3, + apiKey: &APIKey{ID: 42, UserID: 3, Key: "k"}, deleteErr: errors.New("delete failed"), } cache := &apiKeyCacheStub{} @@ -205,4 +248,5 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) { require.ErrorContains(t, err, "delete api key") require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用 require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败) + require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys) } diff --git a/backend/internal/service/auth_cache_invalidation_test.go b/backend/internal/service/auth_cache_invalidation_test.go new file mode 100644 index 00000000..b6e56177 --- /dev/null +++ b/backend/internal/service/auth_cache_invalidation_test.go @@ -0,0 +1,33 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUsageService_InvalidateUsageCaches(t *testing.T) { + invalidator := &authCacheInvalidatorStub{} + svc := &UsageService{authCacheInvalidator: invalidator} + + svc.invalidateUsageCaches(context.Background(), 7, false) + require.Empty(t, invalidator.userIDs) + + svc.invalidateUsageCaches(context.Background(), 7, true) + require.Equal(t, []int64{7}, invalidator.userIDs) +} + +func TestRedeemService_InvalidateRedeemCaches_AuthCache(t *testing.T) { + invalidator := &authCacheInvalidatorStub{} + svc := &RedeemService{authCacheInvalidator: invalidator} + + svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeBalance}) + svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeConcurrency}) + groupID := int64(3) + svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeSubscription, GroupID: &groupID}) + + require.Equal(t, []int64{11, 11, 11}, invalidator.userIDs) +} diff --git a/backend/internal/service/dashboard_aggregation_service.go b/backend/internal/service/dashboard_aggregation_service.go new file mode 100644 index 00000000..0d1cec57 --- /dev/null +++ b/backend/internal/service/dashboard_aggregation_service.go @@ -0,0 +1,242 @@ +package service + +import ( + "context" + "errors" + "log" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +const ( + defaultDashboardAggregationTimeout = 2 * time.Minute + defaultDashboardAggregationBackfillTimeout = 30 * time.Minute + dashboardAggregationRetentionInterval = 6 * time.Hour +) + +var ( + // ErrDashboardBackfillDisabled 当配置禁用回填时返回。 + ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用") + // ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。 + ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大") +) + +// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。 +type DashboardAggregationRepository interface { + AggregateRange(ctx context.Context, start, end time.Time) error + GetAggregationWatermark(ctx context.Context) (time.Time, error) + UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error + CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error + CleanupUsageLogs(ctx context.Context, cutoff time.Time) error + EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error +} + +// DashboardAggregationService 负责定时聚合与回填。 +type DashboardAggregationService struct { + repo DashboardAggregationRepository + timingWheel *TimingWheelService + cfg config.DashboardAggregationConfig + running int32 + lastRetentionCleanup atomic.Value // time.Time +} + +// NewDashboardAggregationService 创建聚合服务。 +func NewDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService { + var aggCfg config.DashboardAggregationConfig + if cfg != nil { + aggCfg = cfg.DashboardAgg + } + return &DashboardAggregationService{ + repo: repo, + timingWheel: timingWheel, + cfg: aggCfg, + } +} + +// Start 启动定时聚合作业(重启生效配置)。 +func (s *DashboardAggregationService) Start() { + if s == nil || s.repo == nil || s.timingWheel == nil { + return + } + if !s.cfg.Enabled { + log.Printf("[DashboardAggregation] 聚合作业已禁用") + return + } + + interval := time.Duration(s.cfg.IntervalSeconds) * time.Second + if interval <= 0 { + interval = time.Minute + } + + if s.cfg.RecomputeDays > 0 { + go s.recomputeRecentDays() + } + + s.timingWheel.ScheduleRecurring("dashboard:aggregation", interval, func() { + s.runScheduledAggregation() + }) + log.Printf("[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds) + if !s.cfg.BackfillEnabled { + log.Printf("[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填") + } +} + +// TriggerBackfill 触发回填(异步)。 +func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) error { + if s == nil || s.repo == nil { + return errors.New("聚合服务未初始化") + } + if !s.cfg.BackfillEnabled { + log.Printf("[DashboardAggregation] 回填被拒绝: backfill_enabled=false") + return ErrDashboardBackfillDisabled + } + if !end.After(start) { + return errors.New("回填时间范围无效") + } + if s.cfg.BackfillMaxDays > 0 { + maxRange := time.Duration(s.cfg.BackfillMaxDays) * 24 * time.Hour + if end.Sub(start) > maxRange { + return ErrDashboardBackfillTooLarge + } + } + + go func() { + ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout) + defer cancel() + if err := s.backfillRange(ctx, start, end); err != nil { + log.Printf("[DashboardAggregation] 回填失败: %v", err) + } + }() + return nil +} + +func (s *DashboardAggregationService) recomputeRecentDays() { + days := s.cfg.RecomputeDays + if days <= 0 { + return + } + now := time.Now().UTC() + start := now.AddDate(0, 0, -days) + + ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout) + defer cancel() + if err := s.backfillRange(ctx, start, now); err != nil { + log.Printf("[DashboardAggregation] 启动重算失败: %v", err) + return + } +} + +func (s *DashboardAggregationService) runScheduledAggregation() { + if !atomic.CompareAndSwapInt32(&s.running, 0, 1) { + return + } + defer atomic.StoreInt32(&s.running, 0) + + ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationTimeout) + defer cancel() + + now := time.Now().UTC() + last, err := s.repo.GetAggregationWatermark(ctx) + if err != nil { + log.Printf("[DashboardAggregation] 读取水位失败: %v", err) + last = time.Unix(0, 0).UTC() + } + + lookback := time.Duration(s.cfg.LookbackSeconds) * time.Second + epoch := time.Unix(0, 0).UTC() + start := last.Add(-lookback) + if !last.After(epoch) { + retentionDays := s.cfg.Retention.UsageLogsDays + if retentionDays <= 0 { + retentionDays = 1 + } + start = truncateToDayUTC(now.AddDate(0, 0, -retentionDays)) + } else if start.After(now) { + start = now.Add(-lookback) + } + + if err := s.aggregateRange(ctx, start, now); err != nil { + log.Printf("[DashboardAggregation] 聚合失败: %v", err) + return + } + + if err := s.repo.UpdateAggregationWatermark(ctx, now); err != nil { + log.Printf("[DashboardAggregation] 更新水位失败: %v", err) + } + + s.maybeCleanupRetention(ctx, now) +} + +func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error { + if !atomic.CompareAndSwapInt32(&s.running, 0, 1) { + return errors.New("聚合作业正在运行") + } + defer atomic.StoreInt32(&s.running, 0) + + startUTC := start.UTC() + endUTC := end.UTC() + if !endUTC.After(startUTC) { + return errors.New("回填时间范围无效") + } + + cursor := truncateToDayUTC(startUTC) + for cursor.Before(endUTC) { + windowEnd := cursor.Add(24 * time.Hour) + if windowEnd.After(endUTC) { + windowEnd = endUTC + } + if err := s.aggregateRange(ctx, cursor, windowEnd); err != nil { + return err + } + cursor = windowEnd + } + + if err := s.repo.UpdateAggregationWatermark(ctx, endUTC); err != nil { + log.Printf("[DashboardAggregation] 更新水位失败: %v", err) + } + + s.maybeCleanupRetention(ctx, endUTC) + return nil +} + +func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start, end time.Time) error { + if !end.After(start) { + return nil + } + if err := s.repo.EnsureUsageLogsPartitions(ctx, end); err != nil { + log.Printf("[DashboardAggregation] 分区检查失败: %v", err) + } + return s.repo.AggregateRange(ctx, start, end) +} + +func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, now time.Time) { + lastAny := s.lastRetentionCleanup.Load() + if lastAny != nil { + if last, ok := lastAny.(time.Time); ok && now.Sub(last) < dashboardAggregationRetentionInterval { + return + } + } + + hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays) + dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays) + usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays) + + aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff) + if aggErr != nil { + log.Printf("[DashboardAggregation] 聚合保留清理失败: %v", aggErr) + } + usageErr := s.repo.CleanupUsageLogs(ctx, usageCutoff) + if usageErr != nil { + log.Printf("[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr) + } + if aggErr == nil && usageErr == nil { + s.lastRetentionCleanup.Store(now) + } +} + +func truncateToDayUTC(t time.Time) time.Time { + t = t.UTC() + return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC) +} diff --git a/backend/internal/service/dashboard_aggregation_service_test.go b/backend/internal/service/dashboard_aggregation_service_test.go new file mode 100644 index 00000000..2fc22105 --- /dev/null +++ b/backend/internal/service/dashboard_aggregation_service_test.go @@ -0,0 +1,106 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type dashboardAggregationRepoTestStub struct { + aggregateCalls int + lastStart time.Time + lastEnd time.Time + watermark time.Time + aggregateErr error + cleanupAggregatesErr error + cleanupUsageErr error +} + +func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error { + s.aggregateCalls++ + s.lastStart = start + s.lastEnd = end + return s.aggregateErr +} + +func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) { + return s.watermark, nil +} + +func (s *dashboardAggregationRepoTestStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error { + return nil +} + +func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error { + return s.cleanupAggregatesErr +} + +func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error { + return s.cleanupUsageErr +} + +func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { + return nil +} + +func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) { + repo := &dashboardAggregationRepoTestStub{watermark: time.Unix(0, 0).UTC()} + svc := &DashboardAggregationService{ + repo: repo, + cfg: config.DashboardAggregationConfig{ + Enabled: true, + IntervalSeconds: 60, + LookbackSeconds: 120, + Retention: config.DashboardAggregationRetentionConfig{ + UsageLogsDays: 1, + HourlyDays: 1, + DailyDays: 1, + }, + }, + } + + svc.runScheduledAggregation() + + require.Equal(t, 1, repo.aggregateCalls) + require.False(t, repo.lastEnd.IsZero()) + require.Equal(t, truncateToDayUTC(repo.lastEnd.AddDate(0, 0, -1)), repo.lastStart) +} + +func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *testing.T) { + repo := &dashboardAggregationRepoTestStub{cleanupAggregatesErr: errors.New("清理失败")} + svc := &DashboardAggregationService{ + repo: repo, + cfg: config.DashboardAggregationConfig{ + Retention: config.DashboardAggregationRetentionConfig{ + UsageLogsDays: 1, + HourlyDays: 1, + DailyDays: 1, + }, + }, + } + + svc.maybeCleanupRetention(context.Background(), time.Now().UTC()) + + require.Nil(t, svc.lastRetentionCleanup.Load()) +} + +func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) { + repo := &dashboardAggregationRepoTestStub{} + svc := &DashboardAggregationService{ + repo: repo, + cfg: config.DashboardAggregationConfig{ + BackfillEnabled: true, + BackfillMaxDays: 1, + }, + } + + start := time.Now().AddDate(0, 0, -3) + end := time.Now() + err := svc.TriggerBackfill(start, end) + require.ErrorIs(t, err, ErrDashboardBackfillTooLarge) + require.Equal(t, 0, repo.aggregateCalls) +} diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index f0b1f2a0..69d251cb 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -2,25 +2,119 @@ package service import ( "context" + "encoding/json" + "errors" "fmt" + "log" + "sync/atomic" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" ) -// DashboardService provides aggregated statistics for admin dashboard. -type DashboardService struct { - usageRepo UsageLogRepository +const ( + defaultDashboardStatsFreshTTL = 15 * time.Second + defaultDashboardStatsCacheTTL = 30 * time.Second + defaultDashboardStatsRefreshTimeout = 30 * time.Second +) + +// ErrDashboardStatsCacheMiss 标记仪表盘缓存未命中。 +var ErrDashboardStatsCacheMiss = errors.New("仪表盘缓存未命中") + +// DashboardStatsCache 定义仪表盘统计缓存接口。 +type DashboardStatsCache interface { + GetDashboardStats(ctx context.Context) (string, error) + SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error + DeleteDashboardStats(ctx context.Context) error } -func NewDashboardService(usageRepo UsageLogRepository) *DashboardService { +type dashboardStatsRangeFetcher interface { + GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error) +} + +type dashboardStatsCacheEntry struct { + Stats *usagestats.DashboardStats `json:"stats"` + UpdatedAt int64 `json:"updated_at"` +} + +// DashboardService 提供管理员仪表盘统计服务。 +type DashboardService struct { + usageRepo UsageLogRepository + aggRepo DashboardAggregationRepository + cache DashboardStatsCache + cacheFreshTTL time.Duration + cacheTTL time.Duration + refreshTimeout time.Duration + refreshing int32 + aggEnabled bool + aggInterval time.Duration + aggLookback time.Duration + aggUsageDays int +} + +func NewDashboardService(usageRepo UsageLogRepository, aggRepo DashboardAggregationRepository, cache DashboardStatsCache, cfg *config.Config) *DashboardService { + freshTTL := defaultDashboardStatsFreshTTL + cacheTTL := defaultDashboardStatsCacheTTL + refreshTimeout := defaultDashboardStatsRefreshTimeout + aggEnabled := true + aggInterval := time.Minute + aggLookback := 2 * time.Minute + aggUsageDays := 90 + if cfg != nil { + if !cfg.Dashboard.Enabled { + cache = nil + } + if cfg.Dashboard.StatsFreshTTLSeconds > 0 { + freshTTL = time.Duration(cfg.Dashboard.StatsFreshTTLSeconds) * time.Second + } + if cfg.Dashboard.StatsTTLSeconds > 0 { + cacheTTL = time.Duration(cfg.Dashboard.StatsTTLSeconds) * time.Second + } + if cfg.Dashboard.StatsRefreshTimeoutSeconds > 0 { + refreshTimeout = time.Duration(cfg.Dashboard.StatsRefreshTimeoutSeconds) * time.Second + } + aggEnabled = cfg.DashboardAgg.Enabled + if cfg.DashboardAgg.IntervalSeconds > 0 { + aggInterval = time.Duration(cfg.DashboardAgg.IntervalSeconds) * time.Second + } + if cfg.DashboardAgg.LookbackSeconds > 0 { + aggLookback = time.Duration(cfg.DashboardAgg.LookbackSeconds) * time.Second + } + if cfg.DashboardAgg.Retention.UsageLogsDays > 0 { + aggUsageDays = cfg.DashboardAgg.Retention.UsageLogsDays + } + } return &DashboardService{ - usageRepo: usageRepo, + usageRepo: usageRepo, + aggRepo: aggRepo, + cache: cache, + cacheFreshTTL: freshTTL, + cacheTTL: cacheTTL, + refreshTimeout: refreshTimeout, + aggEnabled: aggEnabled, + aggInterval: aggInterval, + aggLookback: aggLookback, + aggUsageDays: aggUsageDays, } } func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) { - stats, err := s.usageRepo.GetDashboardStats(ctx) + if s.cache != nil { + cached, fresh, err := s.getCachedDashboardStats(ctx) + if err == nil && cached != nil { + s.refreshAggregationStaleness(cached) + if !fresh { + s.refreshDashboardStatsAsync() + } + return cached, nil + } + if err != nil && !errors.Is(err, ErrDashboardStatsCacheMiss) { + log.Printf("[Dashboard] 仪表盘缓存读取失败: %v", err) + } + } + + stats, err := s.refreshDashboardStats(ctx) if err != nil { return nil, fmt.Errorf("get dashboard stats: %w", err) } @@ -43,6 +137,169 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi return stats, nil } +func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) { + data, err := s.cache.GetDashboardStats(ctx) + if err != nil { + return nil, false, err + } + + var entry dashboardStatsCacheEntry + if err := json.Unmarshal([]byte(data), &entry); err != nil { + s.evictDashboardStatsCache(err) + return nil, false, ErrDashboardStatsCacheMiss + } + if entry.Stats == nil { + s.evictDashboardStatsCache(errors.New("仪表盘缓存缺少统计数据")) + return nil, false, ErrDashboardStatsCacheMiss + } + + age := time.Since(time.Unix(entry.UpdatedAt, 0)) + return entry.Stats, age <= s.cacheFreshTTL, nil +} + +func (s *DashboardService) refreshDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) { + stats, err := s.fetchDashboardStats(ctx) + if err != nil { + return nil, err + } + s.applyAggregationStatus(ctx, stats) + cacheCtx, cancel := s.cacheOperationContext() + defer cancel() + s.saveDashboardStatsCache(cacheCtx, stats) + return stats, nil +} + +func (s *DashboardService) refreshDashboardStatsAsync() { + if s.cache == nil { + return + } + if !atomic.CompareAndSwapInt32(&s.refreshing, 0, 1) { + return + } + + go func() { + defer atomic.StoreInt32(&s.refreshing, 0) + + ctx, cancel := context.WithTimeout(context.Background(), s.refreshTimeout) + defer cancel() + + stats, err := s.fetchDashboardStats(ctx) + if err != nil { + log.Printf("[Dashboard] 仪表盘缓存异步刷新失败: %v", err) + return + } + s.applyAggregationStatus(ctx, stats) + cacheCtx, cancel := s.cacheOperationContext() + defer cancel() + s.saveDashboardStatsCache(cacheCtx, stats) + }() +} + +func (s *DashboardService) fetchDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) { + if !s.aggEnabled { + if fetcher, ok := s.usageRepo.(dashboardStatsRangeFetcher); ok { + now := time.Now().UTC() + start := truncateToDayUTC(now.AddDate(0, 0, -s.aggUsageDays)) + return fetcher.GetDashboardStatsWithRange(ctx, start, now) + } + } + return s.usageRepo.GetDashboardStats(ctx) +} + +func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *usagestats.DashboardStats) { + if s.cache == nil || stats == nil { + return + } + + entry := dashboardStatsCacheEntry{ + Stats: stats, + UpdatedAt: time.Now().Unix(), + } + data, err := json.Marshal(entry) + if err != nil { + log.Printf("[Dashboard] 仪表盘缓存序列化失败: %v", err) + return + } + + if err := s.cache.SetDashboardStats(ctx, string(data), s.cacheTTL); err != nil { + log.Printf("[Dashboard] 仪表盘缓存写入失败: %v", err) + } +} + +func (s *DashboardService) evictDashboardStatsCache(reason error) { + if s.cache == nil { + return + } + cacheCtx, cancel := s.cacheOperationContext() + defer cancel() + + if err := s.cache.DeleteDashboardStats(cacheCtx); err != nil { + log.Printf("[Dashboard] 仪表盘缓存清理失败: %v", err) + } + if reason != nil { + log.Printf("[Dashboard] 仪表盘缓存异常,已清理: %v", reason) + } +} + +func (s *DashboardService) cacheOperationContext() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), s.refreshTimeout) +} + +func (s *DashboardService) applyAggregationStatus(ctx context.Context, stats *usagestats.DashboardStats) { + if stats == nil { + return + } + updatedAt := s.fetchAggregationUpdatedAt(ctx) + stats.StatsUpdatedAt = updatedAt.UTC().Format(time.RFC3339) + stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC()) +} + +func (s *DashboardService) refreshAggregationStaleness(stats *usagestats.DashboardStats) { + if stats == nil { + return + } + updatedAt := parseStatsUpdatedAt(stats.StatsUpdatedAt) + stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC()) +} + +func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.Time { + if s.aggRepo == nil { + return time.Unix(0, 0).UTC() + } + updatedAt, err := s.aggRepo.GetAggregationWatermark(ctx) + if err != nil { + log.Printf("[Dashboard] 读取聚合水位失败: %v", err) + return time.Unix(0, 0).UTC() + } + if updatedAt.IsZero() { + return time.Unix(0, 0).UTC() + } + return updatedAt.UTC() +} + +func (s *DashboardService) isAggregationStale(updatedAt, now time.Time) bool { + if !s.aggEnabled { + return true + } + epoch := time.Unix(0, 0).UTC() + if !updatedAt.After(epoch) { + return true + } + threshold := s.aggInterval + s.aggLookback + return now.Sub(updatedAt) > threshold +} + +func parseStatsUpdatedAt(raw string) time.Time { + if raw == "" { + return time.Unix(0, 0).UTC() + } + parsed, err := time.Parse(time.RFC3339, raw) + if err != nil { + return time.Unix(0, 0).UTC() + } + return parsed.UTC() +} + func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit) if err != nil { diff --git a/backend/internal/service/dashboard_service_test.go b/backend/internal/service/dashboard_service_test.go new file mode 100644 index 00000000..db3c78c3 --- /dev/null +++ b/backend/internal/service/dashboard_service_test.go @@ -0,0 +1,387 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/stretchr/testify/require" +) + +type usageRepoStub struct { + UsageLogRepository + stats *usagestats.DashboardStats + rangeStats *usagestats.DashboardStats + err error + rangeErr error + calls int32 + rangeCalls int32 + rangeStart time.Time + rangeEnd time.Time + onCall chan struct{} +} + +func (s *usageRepoStub) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) { + atomic.AddInt32(&s.calls, 1) + if s.onCall != nil { + select { + case s.onCall <- struct{}{}: + default: + } + } + if s.err != nil { + return nil, s.err + } + return s.stats, nil +} + +func (s *usageRepoStub) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error) { + atomic.AddInt32(&s.rangeCalls, 1) + s.rangeStart = start + s.rangeEnd = end + if s.rangeErr != nil { + return nil, s.rangeErr + } + if s.rangeStats != nil { + return s.rangeStats, nil + } + return s.stats, nil +} + +type dashboardCacheStub struct { + get func(ctx context.Context) (string, error) + set func(ctx context.Context, data string, ttl time.Duration) error + del func(ctx context.Context) error + getCalls int32 + setCalls int32 + delCalls int32 + lastSetMu sync.Mutex + lastSet string +} + +func (c *dashboardCacheStub) GetDashboardStats(ctx context.Context) (string, error) { + atomic.AddInt32(&c.getCalls, 1) + if c.get != nil { + return c.get(ctx) + } + return "", ErrDashboardStatsCacheMiss +} + +func (c *dashboardCacheStub) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error { + atomic.AddInt32(&c.setCalls, 1) + c.lastSetMu.Lock() + c.lastSet = data + c.lastSetMu.Unlock() + if c.set != nil { + return c.set(ctx, data, ttl) + } + return nil +} + +func (c *dashboardCacheStub) DeleteDashboardStats(ctx context.Context) error { + atomic.AddInt32(&c.delCalls, 1) + if c.del != nil { + return c.del(ctx) + } + return nil +} + +type dashboardAggregationRepoStub struct { + watermark time.Time + err error +} + +func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error { + return nil +} + +func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) { + if s.err != nil { + return time.Time{}, s.err + } + return s.watermark, nil +} + +func (s *dashboardAggregationRepoStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error { + return nil +} + +func (s *dashboardAggregationRepoStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error { + return nil +} + +func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error { + return nil +} + +func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { + return nil +} + +func (c *dashboardCacheStub) readLastEntry(t *testing.T) dashboardStatsCacheEntry { + t.Helper() + c.lastSetMu.Lock() + data := c.lastSet + c.lastSetMu.Unlock() + + var entry dashboardStatsCacheEntry + err := json.Unmarshal([]byte(data), &entry) + require.NoError(t, err) + return entry +} + +func TestDashboardService_CacheHitFresh(t *testing.T) { + stats := &usagestats.DashboardStats{ + TotalUsers: 10, + StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339), + StatsStale: true, + } + entry := dashboardStatsCacheEntry{ + Stats: stats, + UpdatedAt: time.Now().Unix(), + } + payload, err := json.Marshal(entry) + require.NoError(t, err) + + cache := &dashboardCacheStub{ + get: func(ctx context.Context) (string, error) { + return string(payload), nil + }, + } + repo := &usageRepoStub{ + stats: &usagestats.DashboardStats{TotalUsers: 99}, + } + aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} + cfg := &config.Config{ + Dashboard: config.DashboardCacheConfig{Enabled: true}, + DashboardAgg: config.DashboardAggregationConfig{ + Enabled: true, + }, + } + svc := NewDashboardService(repo, aggRepo, cache, cfg) + + got, err := svc.GetDashboardStats(context.Background()) + require.NoError(t, err) + require.Equal(t, stats, got) + require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls)) + require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls)) +} + +func TestDashboardService_CacheMiss_StoresCache(t *testing.T) { + stats := &usagestats.DashboardStats{ + TotalUsers: 7, + StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339), + StatsStale: true, + } + cache := &dashboardCacheStub{ + get: func(ctx context.Context) (string, error) { + return "", ErrDashboardStatsCacheMiss + }, + } + repo := &usageRepoStub{stats: stats} + aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} + cfg := &config.Config{ + Dashboard: config.DashboardCacheConfig{Enabled: true}, + DashboardAgg: config.DashboardAggregationConfig{ + Enabled: true, + }, + } + svc := NewDashboardService(repo, aggRepo, cache, cfg) + + got, err := svc.GetDashboardStats(context.Background()) + require.NoError(t, err) + require.Equal(t, stats, got) + require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls)) + require.Equal(t, int32(1), atomic.LoadInt32(&cache.setCalls)) + entry := cache.readLastEntry(t) + require.Equal(t, stats, entry.Stats) + require.WithinDuration(t, time.Now(), time.Unix(entry.UpdatedAt, 0), time.Second) +} + +func TestDashboardService_CacheDisabled_SkipsCache(t *testing.T) { + stats := &usagestats.DashboardStats{ + TotalUsers: 3, + StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339), + StatsStale: true, + } + cache := &dashboardCacheStub{ + get: func(ctx context.Context) (string, error) { + return "", nil + }, + } + repo := &usageRepoStub{stats: stats} + aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} + cfg := &config.Config{ + Dashboard: config.DashboardCacheConfig{Enabled: false}, + DashboardAgg: config.DashboardAggregationConfig{ + Enabled: true, + }, + } + svc := NewDashboardService(repo, aggRepo, cache, cfg) + + got, err := svc.GetDashboardStats(context.Background()) + require.NoError(t, err) + require.Equal(t, stats, got) + require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls)) + require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalls)) + require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls)) +} + +func TestDashboardService_CacheHitStale_TriggersAsyncRefresh(t *testing.T) { + staleStats := &usagestats.DashboardStats{ + TotalUsers: 11, + StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339), + StatsStale: true, + } + entry := dashboardStatsCacheEntry{ + Stats: staleStats, + UpdatedAt: time.Now().Add(-defaultDashboardStatsFreshTTL * 2).Unix(), + } + payload, err := json.Marshal(entry) + require.NoError(t, err) + + cache := &dashboardCacheStub{ + get: func(ctx context.Context) (string, error) { + return string(payload), nil + }, + } + refreshCh := make(chan struct{}, 1) + repo := &usageRepoStub{ + stats: &usagestats.DashboardStats{TotalUsers: 22}, + onCall: refreshCh, + } + aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} + cfg := &config.Config{ + Dashboard: config.DashboardCacheConfig{Enabled: true}, + DashboardAgg: config.DashboardAggregationConfig{ + Enabled: true, + }, + } + svc := NewDashboardService(repo, aggRepo, cache, cfg) + + got, err := svc.GetDashboardStats(context.Background()) + require.NoError(t, err) + require.Equal(t, staleStats, got) + + select { + case <-refreshCh: + case <-time.After(1 * time.Second): + t.Fatal("等待异步刷新超时") + } + require.Eventually(t, func() bool { + return atomic.LoadInt32(&cache.setCalls) >= 1 + }, 1*time.Second, 10*time.Millisecond) +} + +func TestDashboardService_CacheParseError_EvictsAndRefetches(t *testing.T) { + cache := &dashboardCacheStub{ + get: func(ctx context.Context) (string, error) { + return "not-json", nil + }, + } + stats := &usagestats.DashboardStats{TotalUsers: 9} + repo := &usageRepoStub{stats: stats} + aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} + cfg := &config.Config{ + Dashboard: config.DashboardCacheConfig{Enabled: true}, + DashboardAgg: config.DashboardAggregationConfig{ + Enabled: true, + }, + } + svc := NewDashboardService(repo, aggRepo, cache, cfg) + + got, err := svc.GetDashboardStats(context.Background()) + require.NoError(t, err) + require.Equal(t, stats, got) + require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls)) + require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls)) +} + +func TestDashboardService_CacheParseError_RepoFailure(t *testing.T) { + cache := &dashboardCacheStub{ + get: func(ctx context.Context) (string, error) { + return "not-json", nil + }, + } + repo := &usageRepoStub{err: errors.New("db down")} + aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} + cfg := &config.Config{ + Dashboard: config.DashboardCacheConfig{Enabled: true}, + DashboardAgg: config.DashboardAggregationConfig{ + Enabled: true, + }, + } + svc := NewDashboardService(repo, aggRepo, cache, cfg) + + _, err := svc.GetDashboardStats(context.Background()) + require.Error(t, err) + require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls)) +} + +func TestDashboardService_StatsUpdatedAtEpochWhenMissing(t *testing.T) { + stats := &usagestats.DashboardStats{} + repo := &usageRepoStub{stats: stats} + aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} + cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: false}} + svc := NewDashboardService(repo, aggRepo, nil, cfg) + + got, err := svc.GetDashboardStats(context.Background()) + require.NoError(t, err) + require.Equal(t, "1970-01-01T00:00:00Z", got.StatsUpdatedAt) + require.True(t, got.StatsStale) +} + +func TestDashboardService_StatsStaleFalseWhenFresh(t *testing.T) { + aggNow := time.Now().UTC().Truncate(time.Second) + stats := &usagestats.DashboardStats{} + repo := &usageRepoStub{stats: stats} + aggRepo := &dashboardAggregationRepoStub{watermark: aggNow} + cfg := &config.Config{ + Dashboard: config.DashboardCacheConfig{Enabled: false}, + DashboardAgg: config.DashboardAggregationConfig{ + Enabled: true, + IntervalSeconds: 60, + LookbackSeconds: 120, + }, + } + svc := NewDashboardService(repo, aggRepo, nil, cfg) + + got, err := svc.GetDashboardStats(context.Background()) + require.NoError(t, err) + require.Equal(t, aggNow.Format(time.RFC3339), got.StatsUpdatedAt) + require.False(t, got.StatsStale) +} + +func TestDashboardService_AggDisabled_UsesUsageLogsFallback(t *testing.T) { + expected := &usagestats.DashboardStats{TotalUsers: 42} + repo := &usageRepoStub{ + rangeStats: expected, + err: errors.New("should not call aggregated stats"), + } + cfg := &config.Config{ + Dashboard: config.DashboardCacheConfig{Enabled: false}, + DashboardAgg: config.DashboardAggregationConfig{ + Enabled: false, + Retention: config.DashboardAggregationRetentionConfig{ + UsageLogsDays: 7, + }, + }, + } + svc := NewDashboardService(repo, nil, nil, cfg) + + got, err := svc.GetDashboardStats(context.Background()) + require.NoError(t, err) + require.Equal(t, int64(42), got.TotalUsers) + require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&repo.rangeCalls)) + require.False(t, repo.rangeEnd.IsZero()) + require.Equal(t, truncateToDayUTC(repo.rangeEnd.AddDate(0, 0, -7)), repo.rangeStart) +} diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 2f0f4975..324f347b 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -50,13 +50,15 @@ type UpdateGroupRequest struct { // GroupService 分组管理服务 type GroupService struct { - groupRepo GroupRepository + groupRepo GroupRepository + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewGroupService 创建分组服务实例 -func NewGroupService(groupRepo GroupRepository) *GroupService { +func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *GroupService { return &GroupService{ - groupRepo: groupRepo, + groupRepo: groupRepo, + authCacheInvalidator: authCacheInvalidator, } } @@ -155,6 +157,9 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ if err := s.groupRepo.Update(ctx, group); err != nil { return nil, fmt.Errorf("update group: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) + } return group, nil } @@ -167,6 +172,9 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error { return fmt.Errorf("get group: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) + } if err := s.groupRepo.Delete(ctx, id); err != nil { return fmt.Errorf("delete group: %w", err) } diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go new file mode 100644 index 00000000..94e74f22 --- /dev/null +++ b/backend/internal/service/openai_codex_transform.go @@ -0,0 +1,404 @@ +package service + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" +) + +const ( + opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt" + codexCacheTTL = 15 * time.Minute +) + +var codexModelMap = map[string]string{ + "gpt-5.1-codex": "gpt-5.1-codex", + "gpt-5.1-codex-low": "gpt-5.1-codex", + "gpt-5.1-codex-medium": "gpt-5.1-codex", + "gpt-5.1-codex-high": "gpt-5.1-codex", + "gpt-5.1-codex-max": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-low": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-medium": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-high": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max", + "gpt-5.2": "gpt-5.2", + "gpt-5.2-none": "gpt-5.2", + "gpt-5.2-low": "gpt-5.2", + "gpt-5.2-medium": "gpt-5.2", + "gpt-5.2-high": "gpt-5.2", + "gpt-5.2-xhigh": "gpt-5.2", + "gpt-5.2-codex": "gpt-5.2-codex", + "gpt-5.2-codex-low": "gpt-5.2-codex", + "gpt-5.2-codex-medium": "gpt-5.2-codex", + "gpt-5.2-codex-high": "gpt-5.2-codex", + "gpt-5.2-codex-xhigh": "gpt-5.2-codex", + "gpt-5.1-codex-mini": "gpt-5.1-codex-mini", + "gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini", + "gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini", + "gpt-5.1": "gpt-5.1", + "gpt-5.1-none": "gpt-5.1", + "gpt-5.1-low": "gpt-5.1", + "gpt-5.1-medium": "gpt-5.1", + "gpt-5.1-high": "gpt-5.1", + "gpt-5.1-chat-latest": "gpt-5.1", + "gpt-5-codex": "gpt-5.1-codex", + "codex-mini-latest": "gpt-5.1-codex-mini", + "gpt-5-codex-mini": "gpt-5.1-codex-mini", + "gpt-5-codex-mini-medium": "gpt-5.1-codex-mini", + "gpt-5-codex-mini-high": "gpt-5.1-codex-mini", + "gpt-5": "gpt-5.1", + "gpt-5-mini": "gpt-5.1", + "gpt-5-nano": "gpt-5.1", +} + +type codexTransformResult struct { + Modified bool + NormalizedModel string + PromptCacheKey string +} + +type opencodeCacheMetadata struct { + ETag string `json:"etag"` + LastFetch string `json:"lastFetch,omitempty"` + LastChecked int64 `json:"lastChecked"` +} + +func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { + result := codexTransformResult{} + + model := "" + if v, ok := reqBody["model"].(string); ok { + model = v + } + normalizedModel := normalizeCodexModel(model) + if normalizedModel != "" { + if model != normalizedModel { + reqBody["model"] = normalizedModel + result.Modified = true + } + result.NormalizedModel = normalizedModel + } + + if v, ok := reqBody["store"].(bool); !ok || v { + reqBody["store"] = false + result.Modified = true + } + if v, ok := reqBody["stream"].(bool); !ok || !v { + reqBody["stream"] = true + result.Modified = true + } + + if _, ok := reqBody["max_output_tokens"]; ok { + delete(reqBody, "max_output_tokens") + result.Modified = true + } + if _, ok := reqBody["max_completion_tokens"]; ok { + delete(reqBody, "max_completion_tokens") + result.Modified = true + } + + if normalizeCodexTools(reqBody) { + result.Modified = true + } + + if v, ok := reqBody["prompt_cache_key"].(string); ok { + result.PromptCacheKey = strings.TrimSpace(v) + } + + instructions := strings.TrimSpace(getOpenCodeCodexHeader()) + existingInstructions, _ := reqBody["instructions"].(string) + existingInstructions = strings.TrimSpace(existingInstructions) + + if instructions != "" { + if existingInstructions != instructions { + reqBody["instructions"] = instructions + result.Modified = true + } + } + + if input, ok := reqBody["input"].([]any); ok { + input = filterCodexInput(input) + reqBody["input"] = input + result.Modified = true + } + + return result +} + +func normalizeCodexModel(model string) string { + if model == "" { + return "gpt-5.1" + } + + modelID := model + if strings.Contains(modelID, "/") { + parts := strings.Split(modelID, "/") + modelID = parts[len(parts)-1] + } + + if mapped := getNormalizedCodexModel(modelID); mapped != "" { + return mapped + } + + normalized := strings.ToLower(modelID) + + if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") { + return "gpt-5.2-codex" + } + if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") { + return "gpt-5.2" + } + if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") { + return "gpt-5.1-codex-max" + } + if strings.Contains(normalized, "gpt-5.1-codex-mini") || strings.Contains(normalized, "gpt 5.1 codex mini") { + return "gpt-5.1-codex-mini" + } + if strings.Contains(normalized, "codex-mini-latest") || + strings.Contains(normalized, "gpt-5-codex-mini") || + strings.Contains(normalized, "gpt 5 codex mini") { + return "codex-mini-latest" + } + if strings.Contains(normalized, "gpt-5.1-codex") || strings.Contains(normalized, "gpt 5.1 codex") { + return "gpt-5.1-codex" + } + if strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1") { + return "gpt-5.1" + } + if strings.Contains(normalized, "codex") { + return "gpt-5.1-codex" + } + if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") { + return "gpt-5.1" + } + + return "gpt-5.1" +} + +func getNormalizedCodexModel(modelID string) string { + if modelID == "" { + return "" + } + if mapped, ok := codexModelMap[modelID]; ok { + return mapped + } + lower := strings.ToLower(modelID) + for key, value := range codexModelMap { + if strings.ToLower(key) == lower { + return value + } + } + return "" +} + +func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string { + cacheDir := codexCachePath("") + if cacheDir == "" { + return "" + } + cacheFile := filepath.Join(cacheDir, cacheFileName) + metaFile := filepath.Join(cacheDir, metaFileName) + + var cachedContent string + if content, ok := readFile(cacheFile); ok { + cachedContent = content + } + + var meta opencodeCacheMetadata + if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" { + if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL { + return cachedContent + } + } + + content, etag, status, err := fetchWithETag(url, meta.ETag) + if err == nil && status == http.StatusNotModified && cachedContent != "" { + return cachedContent + } + if err == nil && status >= 200 && status < 300 && content != "" { + _ = writeFile(cacheFile, content) + meta = opencodeCacheMetadata{ + ETag: etag, + LastFetch: time.Now().UTC().Format(time.RFC3339), + LastChecked: time.Now().UnixMilli(), + } + _ = writeJSON(metaFile, meta) + return content + } + + return cachedContent +} + +func getOpenCodeCodexHeader() string { + return getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json") +} + +func GetOpenCodeInstructions() string { + return getOpenCodeCodexHeader() +} + +func filterCodexInput(input []any) []any { + filtered := make([]any, 0, len(input)) + for _, item := range input { + m, ok := item.(map[string]any) + if !ok { + filtered = append(filtered, item) + continue + } + if typ, ok := m["type"].(string); ok && typ == "item_reference" { + continue + } + delete(m, "id") + filtered = append(filtered, m) + } + return filtered +} + +func normalizeCodexTools(reqBody map[string]any) bool { + rawTools, ok := reqBody["tools"] + if !ok || rawTools == nil { + return false + } + tools, ok := rawTools.([]any) + if !ok { + return false + } + + modified := false + for idx, tool := range tools { + toolMap, ok := tool.(map[string]any) + if !ok { + continue + } + + toolType, _ := toolMap["type"].(string) + if strings.TrimSpace(toolType) != "function" { + continue + } + + function, ok := toolMap["function"].(map[string]any) + if !ok { + continue + } + + if _, ok := toolMap["name"]; !ok { + if name, ok := function["name"].(string); ok && strings.TrimSpace(name) != "" { + toolMap["name"] = name + modified = true + } + } + if _, ok := toolMap["description"]; !ok { + if desc, ok := function["description"].(string); ok && strings.TrimSpace(desc) != "" { + toolMap["description"] = desc + modified = true + } + } + if _, ok := toolMap["parameters"]; !ok { + if params, ok := function["parameters"]; ok { + toolMap["parameters"] = params + modified = true + } + } + if _, ok := toolMap["strict"]; !ok { + if strict, ok := function["strict"]; ok { + toolMap["strict"] = strict + modified = true + } + } + + tools[idx] = toolMap + } + + if modified { + reqBody["tools"] = tools + } + + return modified +} + +func codexCachePath(filename string) string { + home, err := os.UserHomeDir() + if err != nil { + return "" + } + cacheDir := filepath.Join(home, ".opencode", "cache") + if filename == "" { + return cacheDir + } + return filepath.Join(cacheDir, filename) +} + +func readFile(path string) (string, bool) { + if path == "" { + return "", false + } + data, err := os.ReadFile(path) + if err != nil { + return "", false + } + return string(data), true +} + +func writeFile(path, content string) error { + if path == "" { + return fmt.Errorf("empty cache path") + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + return os.WriteFile(path, []byte(content), 0o644) +} + +func loadJSON(path string, target any) bool { + data, err := os.ReadFile(path) + if err != nil { + return false + } + if err := json.Unmarshal(data, target); err != nil { + return false + } + return true +} + +func writeJSON(path string, value any) error { + if path == "" { + return fmt.Errorf("empty json path") + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + data, err := json.Marshal(value) + if err != nil { + return err + } + return os.WriteFile(path, data, 0o644) +} + +func fetchWithETag(url, etag string) (string, string, int, error) { + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return "", "", 0, err + } + req.Header.Set("User-Agent", "sub2api-codex") + if etag != "" { + req.Header.Set("If-None-Match", etag) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", "", 0, err + } + defer func() { + _ = resp.Body.Close() + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", "", resp.StatusCode, err + } + return string(body), resp.Header.Get("etag"), resp.StatusCode, nil +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index d11cbdd9..86b35311 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -12,6 +12,7 @@ import ( "io" "log" "net/http" + "os" "regexp" "sort" "strconv" @@ -20,6 +21,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" @@ -528,33 +530,38 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco // Extract model and stream from parsed body reqModel, _ := reqBody["model"].(string) reqStream, _ := reqBody["stream"].(bool) + promptCacheKey := "" + if v, ok := reqBody["prompt_cache_key"].(string); ok { + promptCacheKey = strings.TrimSpace(v) + } // Track if body needs re-serialization bodyModified := false originalModel := reqModel - // Apply model mapping - mappedModel := account.GetMappedModel(reqModel) - if mappedModel != reqModel { - reqBody["model"] = mappedModel - bodyModified = true + isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) + + // Apply model mapping (skip for Codex CLI for transparent forwarding) + mappedModel := reqModel + if !isCodexCLI { + mappedModel = account.GetMappedModel(reqModel) + if mappedModel != reqModel { + reqBody["model"] = mappedModel + bodyModified = true + } } - // For OAuth accounts using ChatGPT internal API: - // 1. Add store: false - // 2. Normalize input format for Codex API compatibility - if account.Type == AccountTypeOAuth { - reqBody["store"] = false - // Codex 上游不接受 max_output_tokens 参数,需要在转发前移除。 - delete(reqBody, "max_output_tokens") - bodyModified = true - - // Normalize input format: convert AI SDK multi-part content format to simplified format - // AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]} - // Codex API expects: {"content": "..."} - if normalizeInputForCodexAPI(reqBody) { + if account.Type == AccountTypeOAuth && !isCodexCLI { + codexResult := applyCodexOAuthTransform(reqBody) + if codexResult.Modified { bodyModified = true } + if codexResult.NormalizedModel != "" { + mappedModel = codexResult.NormalizedModel + } + if codexResult.PromptCacheKey != "" { + promptCacheKey = codexResult.PromptCacheKey + } } // Re-serialize body only if modified @@ -573,7 +580,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } // Build upstream request - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream) + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) if err != nil { return nil, err } @@ -674,7 +681,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco }, nil } -func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool) (*http.Request, error) { +func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) { // Determine target URL based on account type var targetURL string switch account.Type { @@ -714,12 +721,6 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. if chatgptAccountID != "" { req.Header.Set("chatgpt-account-id", chatgptAccountID) } - // Set accept header based on stream mode - if isStream { - req.Header.Set("accept", "text/event-stream") - } else { - req.Header.Set("accept", "application/json") - } } // Whitelist passthrough headers @@ -731,6 +732,22 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. } } } + if account.Type == AccountTypeOAuth { + req.Header.Set("OpenAI-Beta", "responses=experimental") + if isCodexCLI { + req.Header.Set("originator", "codex_cli_rs") + } else { + req.Header.Set("originator", "opencode") + } + req.Header.Set("accept", "text/event-stream") + if promptCacheKey != "" { + req.Header.Set("conversation_id", promptCacheKey) + req.Header.Set("session_id", promptCacheKey) + } else { + req.Header.Del("conversation_id") + req.Header.Del("session_id") + } + } // Apply custom User-Agent if configured customUA := account.GetOpenAIUserAgent() @@ -1109,6 +1126,13 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r return nil, err } + if account.Type == AccountTypeOAuth { + bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:")) + if isEventStreamResponse(resp.Header) || bodyLooksLikeSSE { + return s.handleOAuthSSEToJSON(resp, c, body, originalModel, mappedModel) + } + } + // Parse usage var response struct { Usage struct { @@ -1148,6 +1172,110 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r return usage, nil } +func isEventStreamResponse(header http.Header) bool { + contentType := strings.ToLower(header.Get("Content-Type")) + return strings.Contains(contentType, "text/event-stream") +} + +func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) { + bodyText := string(body) + finalResponse, ok := extractCodexFinalResponse(bodyText) + + usage := &OpenAIUsage{} + if ok { + var response struct { + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + InputTokenDetails struct { + CachedTokens int `json:"cached_tokens"` + } `json:"input_tokens_details"` + } `json:"usage"` + } + if err := json.Unmarshal(finalResponse, &response); err == nil { + usage.InputTokens = response.Usage.InputTokens + usage.OutputTokens = response.Usage.OutputTokens + usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens + } + body = finalResponse + if originalModel != mappedModel { + body = s.replaceModelInResponseBody(body, mappedModel, originalModel) + } + } else { + usage = s.parseSSEUsageFromBody(bodyText) + if originalModel != mappedModel { + bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel) + } + body = []byte(bodyText) + } + + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + + contentType := "application/json; charset=utf-8" + if !ok { + contentType = resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "text/event-stream" + } + } + c.Data(resp.StatusCode, contentType, body) + + return usage, nil +} + +func extractCodexFinalResponse(body string) ([]byte, bool) { + lines := strings.Split(body, "\n") + for _, line := range lines { + if !openaiSSEDataRe.MatchString(line) { + continue + } + data := openaiSSEDataRe.ReplaceAllString(line, "") + if data == "" || data == "[DONE]" { + continue + } + var event struct { + Type string `json:"type"` + Response json.RawMessage `json:"response"` + } + if json.Unmarshal([]byte(data), &event) != nil { + continue + } + if event.Type == "response.done" || event.Type == "response.completed" { + if len(event.Response) > 0 { + return event.Response, true + } + } + } + return nil, false +} + +func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { + usage := &OpenAIUsage{} + lines := strings.Split(body, "\n") + for _, line := range lines { + if !openaiSSEDataRe.MatchString(line) { + continue + } + data := openaiSSEDataRe.ReplaceAllString(line, "") + if data == "" || data == "[DONE]" { + continue + } + s.parseSSEUsage(data, usage) + } + return usage +} + +func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string { + lines := strings.Split(body, "\n") + for i, line := range lines { + if !openaiSSEDataRe.MatchString(line) { + continue + } + lines[i] = s.replaceModelInSSELine(line, fromModel, toModel) + } + return strings.Join(lines, "\n") +} + func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) { if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) @@ -1187,101 +1315,6 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel return newBody } -// normalizeInputForCodexAPI converts AI SDK multi-part content format to simplified format -// that the ChatGPT internal Codex API expects. -// -// AI SDK sends content as an array of typed objects: -// -// {"content": [{"type": "input_text", "text": "hello"}]} -// -// ChatGPT Codex API expects content as a simple string: -// -// {"content": "hello"} -// -// This function modifies reqBody in-place and returns true if any modification was made. -func normalizeInputForCodexAPI(reqBody map[string]any) bool { - input, ok := reqBody["input"] - if !ok { - return false - } - - // Handle case where input is a simple string (already compatible) - if _, isString := input.(string); isString { - return false - } - - // Handle case where input is an array of messages - inputArray, ok := input.([]any) - if !ok { - return false - } - - modified := false - for _, item := range inputArray { - message, ok := item.(map[string]any) - if !ok { - continue - } - - content, ok := message["content"] - if !ok { - continue - } - - // If content is already a string, no conversion needed - if _, isString := content.(string); isString { - continue - } - - // If content is an array (AI SDK format), convert to string - contentArray, ok := content.([]any) - if !ok { - continue - } - - // Extract text from content array - var textParts []string - for _, part := range contentArray { - partMap, ok := part.(map[string]any) - if !ok { - continue - } - - // Handle different content types - partType, _ := partMap["type"].(string) - switch partType { - case "input_text", "text": - // Extract text from input_text or text type - if text, ok := partMap["text"].(string); ok { - textParts = append(textParts, text) - } - case "input_image", "image": - // For images, we need to preserve the original format - // as ChatGPT Codex API may support images in a different way - // For now, skip image parts (they will be lost in conversion) - // TODO: Consider preserving image data or handling it separately - continue - case "input_file", "file": - // Similar to images, file inputs may need special handling - continue - default: - // For unknown types, try to extract text if available - if text, ok := partMap["text"].(string); ok { - textParts = append(textParts, text) - } - } - } - - // Convert content array to string - if len(textParts) > 0 { - message["content"] = strings.Join(textParts, "\n") - modified = true - } - } - - return modified -} - // OpenAIRecordUsageInput input for recording usage type OpenAIRecordUsageInput struct { Result *OpenAIForwardResult diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 8562d940..55e11b30 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -220,7 +220,7 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) { Credentials: map[string]any{"base_url": "://invalid-url"}, } - _, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false) + _, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false, "", false) if err == nil { t.Fatalf("expected error for invalid base_url when allowlist disabled") } diff --git a/backend/internal/service/promo_service.go b/backend/internal/service/promo_service.go index 9acd5868..5ff63bdc 100644 --- a/backend/internal/service/promo_service.go +++ b/backend/internal/service/promo_service.go @@ -24,10 +24,11 @@ var ( // PromoService 优惠码服务 type PromoService struct { - promoRepo PromoCodeRepository - userRepo UserRepository - billingCacheService *BillingCacheService - entClient *dbent.Client + promoRepo PromoCodeRepository + userRepo UserRepository + billingCacheService *BillingCacheService + entClient *dbent.Client + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewPromoService 创建优惠码服务实例 @@ -36,12 +37,14 @@ func NewPromoService( userRepo UserRepository, billingCacheService *BillingCacheService, entClient *dbent.Client, + authCacheInvalidator APIKeyAuthCacheInvalidator, ) *PromoService { return &PromoService{ - promoRepo: promoRepo, - userRepo: userRepo, - billingCacheService: billingCacheService, - entClient: entClient, + promoRepo: promoRepo, + userRepo: userRepo, + billingCacheService: billingCacheService, + entClient: entClient, + authCacheInvalidator: authCacheInvalidator, } } @@ -145,6 +148,8 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st return fmt.Errorf("commit transaction: %w", err) } + s.invalidatePromoCaches(ctx, userID, promoCode.BonusAmount) + // 失效余额缓存 if s.billingCacheService != nil { go func() { @@ -157,6 +162,13 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st return nil } +func (s *PromoService) invalidatePromoCaches(ctx context.Context, userID int64, bonusAmount float64) { + if bonusAmount == 0 || s.authCacheInvalidator == nil { + return + } + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) +} + // GenerateRandomCode 生成随机优惠码 func (s *PromoService) GenerateRandomCode() (string, error) { bytes := make([]byte, 8) diff --git a/backend/internal/service/prompts/codex_opencode_bridge.txt b/backend/internal/service/prompts/codex_opencode_bridge.txt new file mode 100644 index 00000000..093aa0f2 --- /dev/null +++ b/backend/internal/service/prompts/codex_opencode_bridge.txt @@ -0,0 +1,122 @@ +# Codex Running in OpenCode + +You are running Codex through OpenCode, an open-source terminal coding assistant. OpenCode provides different tools but follows Codex operating principles. + +## CRITICAL: Tool Replacements + + +❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD +- NEVER use: apply_patch, applyPatch +- ALWAYS use: edit tool for ALL file modifications +- Before modifying files: Verify you're using "edit", NOT "apply_patch" + + + +❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD +- NEVER use: update_plan, updatePlan, read_plan, readPlan +- ALWAYS use: todowrite for task/plan updates, todoread to read plans +- Before plan operations: Verify you're using "todowrite", NOT "update_plan" + + +## Available OpenCode Tools + +**File Operations:** +- `write` - Create new files + - Overwriting existing files requires a prior Read in this session; default to ASCII unless the file already uses Unicode. +- `edit` - Modify existing files (REPLACES apply_patch) + - Requires a prior Read in this session; preserve exact indentation; ensure `oldString` uniquely matches or use `replaceAll`; edit fails if ambiguous or missing. +- `read` - Read file contents + +**Search/Discovery:** +- `grep` - Search file contents (tool, not bash grep); use `include` to filter patterns; set `path` only when not searching workspace root; for cross-file match counts use bash with `rg`. +- `glob` - Find files by pattern; defaults to workspace cwd unless `path` is set. +- `list` - List directories (requires absolute paths) + +**Execution:** +- `bash` - Run shell commands + - No workdir parameter; do not include it in tool calls. + - Always include a short description for the command. + - Do not use cd; use absolute paths in commands. + - Quote paths containing spaces with double quotes. + - Chain multiple commands with ';' or '&&'; avoid newlines. + - Use Grep/Glob tools for searches; only use bash with `rg` when you need counts or advanced features. + - Do not use `ls`/`cat` in bash; use `list`/`read` tools instead. + - For deletions (rm), verify by listing parent dir with `list`. + +**Network:** +- `webfetch` - Fetch web content + - Use fully-formed URLs (http/https; http auto-upgrades to https). + - Always set `format` to one of: text | markdown | html; prefer markdown unless otherwise required. + - Read-only; short cache window. + +**Task Management:** +- `todowrite` - Manage tasks/plans (REPLACES update_plan) +- `todoread` - Read current plan + +## Substitution Rules + +Base instruction says: You MUST use instead: +apply_patch → edit +update_plan → todowrite +read_plan → todoread + +**Path Usage:** Use per-tool conventions to avoid conflicts: +- Tool calls: `read`, `edit`, `write`, `list` require absolute paths. +- Searches: `grep`/`glob` default to the workspace cwd; prefer relative include patterns; set `path` only when a different root is needed. +- Presentation: In assistant messages, show workspace-relative paths; use absolute paths only inside tool calls. +- Tool schema overrides general path preferences—do not convert required absolute paths to relative. + +## Verification Checklist + +Before file/plan modifications: +1. Am I using "edit" NOT "apply_patch"? +2. Am I using "todowrite" NOT "update_plan"? +3. Is this tool in the approved list above? +4. Am I following each tool's path requirements? + +If ANY answer is NO → STOP and correct before proceeding. + +## OpenCode Working Style + +**Communication:** +- Send brief preambles (8-12 words) before tool calls, building on prior context +- Provide progress updates during longer tasks + +**Execution:** +- Keep working autonomously until query is fully resolved before yielding +- Don't return to user with partial solutions + +**Code Approach:** +- New projects: Be ambitious and creative +- Existing codebases: Surgical precision - modify only what's requested unless explicitly instructed to do otherwise + +**Testing:** +- If tests exist: Start specific to your changes, then broader validation + +## Advanced Tools + +**Task Tool (Sub-Agents):** +- Use the Task tool (functions.task) to launch sub-agents +- Check the Task tool description for current agent types and their capabilities +- Useful for complex analysis, specialized workflows, or tasks requiring isolated context +- The agent list is dynamically generated - refer to tool schema for available agents + +**Parallelization:** +- When multiple independent tool calls are needed, use multi_tool_use.parallel to run them concurrently. +- Reserve sequential calls for ordered or data-dependent steps. + +**MCP Tools:** +- Model Context Protocol servers provide additional capabilities +- MCP tools are prefixed: `mcp____` +- Check your available tools for MCP integrations +- Use when the tool's functionality matches your task needs + +## What Remains from Codex + +Sandbox policies, approval mechanisms, final answer formatting, git commit protocols, and file reference formats all follow Codex instructions. In approval policy "never", never request escalations. + +## Approvals & Safety +- Assume workspace-write filesystem, network enabled, approval on-failure unless explicitly stated otherwise. +- When a command fails due to sandboxing or permissions, retry with escalated permissions if allowed by policy, including a one-line justification. +- Treat destructive commands (e.g., `rm`, `git reset --hard`) as requiring explicit user request or approval. +- When uncertain, prefer non-destructive verification first (e.g., confirm file existence with `list`, then delete with `bash`). \ No newline at end of file diff --git a/backend/internal/service/prompts/tool_remap_message.txt b/backend/internal/service/prompts/tool_remap_message.txt new file mode 100644 index 00000000..4ff986e1 --- /dev/null +++ b/backend/internal/service/prompts/tool_remap_message.txt @@ -0,0 +1,63 @@ + + +YOU ARE IN A DIFFERENT ENVIRONMENT. These instructions override ALL previous tool references. + + + + +❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD +- NEVER use: apply_patch, applyPatch +- ALWAYS use: edit tool for ALL file modifications +- Before modifying files: Verify you're using "edit", NOT "apply_patch" + + + +❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD +- NEVER use: update_plan, updatePlan +- ALWAYS use: todowrite for ALL task/plan operations +- Use todoread to read current plan +- Before plan operations: Verify you're using "todowrite", NOT "update_plan" + + + + +File Operations: + • write - Create new files + • edit - Modify existing files (REPLACES apply_patch) + • patch - Apply diff patches + • read - Read file contents + +Search/Discovery: + • grep - Search file contents + • glob - Find files by pattern + • list - List directories (use relative paths) + +Execution: + • bash - Run shell commands + +Network: + • webfetch - Fetch web content + +Task Management: + • todowrite - Manage tasks/plans (REPLACES update_plan) + • todoread - Read current plan + + + +Base instruction says: You MUST use instead: +apply_patch → edit +update_plan → todowrite +read_plan → todoread +absolute paths → relative paths + + + +Before file/plan modifications: +1. Am I using "edit" NOT "apply_patch"? +2. Am I using "todowrite" NOT "update_plan"? +3. Is this tool in the approved list above? +4. Am I using relative paths? + +If ANY answer is NO → STOP and correct before proceeding. + + \ No newline at end of file diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index b6324235..ff52dc47 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -68,12 +68,13 @@ type RedeemCodeResponse struct { // RedeemService 兑换码服务 type RedeemService struct { - redeemRepo RedeemCodeRepository - userRepo UserRepository - subscriptionService *SubscriptionService - cache RedeemCache - billingCacheService *BillingCacheService - entClient *dbent.Client + redeemRepo RedeemCodeRepository + userRepo UserRepository + subscriptionService *SubscriptionService + cache RedeemCache + billingCacheService *BillingCacheService + entClient *dbent.Client + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewRedeemService 创建兑换码服务实例 @@ -84,14 +85,16 @@ func NewRedeemService( cache RedeemCache, billingCacheService *BillingCacheService, entClient *dbent.Client, + authCacheInvalidator APIKeyAuthCacheInvalidator, ) *RedeemService { return &RedeemService{ - redeemRepo: redeemRepo, - userRepo: userRepo, - subscriptionService: subscriptionService, - cache: cache, - billingCacheService: billingCacheService, - entClient: entClient, + redeemRepo: redeemRepo, + userRepo: userRepo, + subscriptionService: subscriptionService, + cache: cache, + billingCacheService: billingCacheService, + entClient: entClient, + authCacheInvalidator: authCacheInvalidator, } } @@ -324,18 +327,33 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( // invalidateRedeemCaches 失效兑换相关的缓存 func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) { - if s.billingCacheService == nil { - return - } - switch redeemCode.Type { case RedeemTypeBalance: + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + if s.billingCacheService == nil { + return + } go func() { cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) }() + case RedeemTypeConcurrency: + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + if s.billingCacheService == nil { + return + } case RedeemTypeSubscription: + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + if s.billingCacheService == nil { + return + } if redeemCode.GroupID != nil { groupID := *redeemCode.GroupID go func() { diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index 10a294ae..aa0a5b87 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -54,17 +54,19 @@ type UsageStats struct { // UsageService 使用统计服务 type UsageService struct { - usageRepo UsageLogRepository - userRepo UserRepository - entClient *dbent.Client + usageRepo UsageLogRepository + userRepo UserRepository + entClient *dbent.Client + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewUsageService 创建使用统计服务实例 -func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client) *UsageService { +func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client, authCacheInvalidator APIKeyAuthCacheInvalidator) *UsageService { return &UsageService{ - usageRepo: usageRepo, - userRepo: userRepo, - entClient: entClient, + usageRepo: usageRepo, + userRepo: userRepo, + entClient: entClient, + authCacheInvalidator: authCacheInvalidator, } } @@ -118,10 +120,12 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* } // 扣除用户余额 + balanceUpdated := false if inserted && req.ActualCost > 0 { if err := s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil { return nil, fmt.Errorf("update user balance: %w", err) } + balanceUpdated = true } if tx != nil { @@ -130,9 +134,18 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* } } + s.invalidateUsageCaches(ctx, req.UserID, balanceUpdated) + return usageLog, nil } +func (s *UsageService) invalidateUsageCaches(ctx context.Context, userID int64, balanceUpdated bool) { + if !balanceUpdated || s.authCacheInvalidator == nil { + return + } + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) +} + // GetByID 根据ID获取使用日志 func (s *UsageService) GetByID(ctx context.Context, id int64) (*UsageLog, error) { log, err := s.usageRepo.GetByID(ctx, id) diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 08fa40b5..1734914a 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -55,13 +55,15 @@ type ChangePasswordRequest struct { // UserService 用户服务 type UserService struct { - userRepo UserRepository + userRepo UserRepository + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewUserService 创建用户服务实例 -func NewUserService(userRepo UserRepository) *UserService { +func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService { return &UserService{ - userRepo: userRepo, + userRepo: userRepo, + authCacheInvalidator: authCacheInvalidator, } } @@ -89,6 +91,7 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat if err != nil { return nil, fmt.Errorf("get user: %w", err) } + oldConcurrency := user.Concurrency // 更新字段 if req.Email != nil { @@ -114,6 +117,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat if err := s.userRepo.Update(ctx, user); err != nil { return nil, fmt.Errorf("update user: %w", err) } + if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return user, nil } @@ -169,6 +175,9 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil { return fmt.Errorf("update balance: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return nil } @@ -177,6 +186,9 @@ func (s *UserService) UpdateConcurrency(ctx context.Context, userID int64, concu if err := s.userRepo.UpdateConcurrency(ctx, userID, concurrency); err != nil { return fmt.Errorf("update concurrency: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return nil } @@ -192,12 +204,18 @@ func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status str if err := s.userRepo.Update(ctx, user); err != nil { return fmt.Errorf("update user: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return nil } // Delete 删除用户(管理员功能) func (s *UserService) Delete(ctx context.Context, userID int64) error { + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } if err := s.userRepo.Delete(ctx, userID); err != nil { return fmt.Errorf("delete user: %w", err) } diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 62f69295..f2cb9c44 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -49,6 +49,13 @@ func ProvideTokenRefreshService( return svc } +// ProvideDashboardAggregationService 创建并启动仪表盘聚合服务 +func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService { + svc := NewDashboardAggregationService(repo, timingWheel, cfg) + svc.Start() + return svc +} + // ProvideAccountExpiryService creates and starts AccountExpiryService. func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService { svc := NewAccountExpiryService(accountRepo, time.Minute) @@ -145,12 +152,18 @@ func ProvideOpsScheduledReportService( return svc } +// ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力 +func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator { + return apiKeyService +} + // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services NewAuthService, NewUserService, NewAPIKeyService, + ProvideAPIKeyAuthCacheInvalidator, NewGroupService, NewAccountService, NewProxyService, @@ -194,6 +207,7 @@ var ProviderSet = wire.NewSet( ProvideTokenRefreshService, ProvideAccountExpiryService, ProvideTimingWheelService, + ProvideDashboardAggregationService, ProvideDeferredService, NewAntigravityQuotaFetcher, NewUserAttributeService, diff --git a/backend/migrations/034_usage_dashboard_aggregation_tables.sql b/backend/migrations/034_usage_dashboard_aggregation_tables.sql new file mode 100644 index 00000000..64b383d4 --- /dev/null +++ b/backend/migrations/034_usage_dashboard_aggregation_tables.sql @@ -0,0 +1,77 @@ +-- Usage dashboard aggregation tables (hourly/daily) + active-user dedup + watermark. +-- These tables support Admin Dashboard statistics without full-table scans on usage_logs. + +-- Hourly aggregates (UTC buckets). +CREATE TABLE IF NOT EXISTS usage_dashboard_hourly ( + bucket_start TIMESTAMPTZ PRIMARY KEY, + total_requests BIGINT NOT NULL DEFAULT 0, + input_tokens BIGINT NOT NULL DEFAULT 0, + output_tokens BIGINT NOT NULL DEFAULT 0, + cache_creation_tokens BIGINT NOT NULL DEFAULT 0, + cache_read_tokens BIGINT NOT NULL DEFAULT 0, + total_cost DECIMAL(20, 10) NOT NULL DEFAULT 0, + actual_cost DECIMAL(20, 10) NOT NULL DEFAULT 0, + total_duration_ms BIGINT NOT NULL DEFAULT 0, + active_users BIGINT NOT NULL DEFAULT 0, + computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_usage_dashboard_hourly_bucket_start + ON usage_dashboard_hourly (bucket_start DESC); + +COMMENT ON TABLE usage_dashboard_hourly IS 'Pre-aggregated hourly usage metrics for admin dashboard (UTC buckets).'; +COMMENT ON COLUMN usage_dashboard_hourly.bucket_start IS 'UTC start timestamp of the hour bucket.'; +COMMENT ON COLUMN usage_dashboard_hourly.computed_at IS 'When the hourly row was last computed/refreshed.'; + +-- Daily aggregates (UTC dates). +CREATE TABLE IF NOT EXISTS usage_dashboard_daily ( + bucket_date DATE PRIMARY KEY, + total_requests BIGINT NOT NULL DEFAULT 0, + input_tokens BIGINT NOT NULL DEFAULT 0, + output_tokens BIGINT NOT NULL DEFAULT 0, + cache_creation_tokens BIGINT NOT NULL DEFAULT 0, + cache_read_tokens BIGINT NOT NULL DEFAULT 0, + total_cost DECIMAL(20, 10) NOT NULL DEFAULT 0, + actual_cost DECIMAL(20, 10) NOT NULL DEFAULT 0, + total_duration_ms BIGINT NOT NULL DEFAULT 0, + active_users BIGINT NOT NULL DEFAULT 0, + computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_usage_dashboard_daily_bucket_date + ON usage_dashboard_daily (bucket_date DESC); + +COMMENT ON TABLE usage_dashboard_daily IS 'Pre-aggregated daily usage metrics for admin dashboard (UTC dates).'; +COMMENT ON COLUMN usage_dashboard_daily.bucket_date IS 'UTC date of the day bucket.'; +COMMENT ON COLUMN usage_dashboard_daily.computed_at IS 'When the daily row was last computed/refreshed.'; + +-- Hourly active user dedup table. +CREATE TABLE IF NOT EXISTS usage_dashboard_hourly_users ( + bucket_start TIMESTAMPTZ NOT NULL, + user_id BIGINT NOT NULL, + PRIMARY KEY (bucket_start, user_id) +); + +CREATE INDEX IF NOT EXISTS idx_usage_dashboard_hourly_users_bucket_start + ON usage_dashboard_hourly_users (bucket_start); + +-- Daily active user dedup table. +CREATE TABLE IF NOT EXISTS usage_dashboard_daily_users ( + bucket_date DATE NOT NULL, + user_id BIGINT NOT NULL, + PRIMARY KEY (bucket_date, user_id) +); + +CREATE INDEX IF NOT EXISTS idx_usage_dashboard_daily_users_bucket_date + ON usage_dashboard_daily_users (bucket_date); + +-- Aggregation watermark table (single row). +CREATE TABLE IF NOT EXISTS usage_dashboard_aggregation_watermark ( + id INT PRIMARY KEY, + last_aggregated_at TIMESTAMPTZ NOT NULL DEFAULT TIMESTAMPTZ '1970-01-01 00:00:00+00', + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +INSERT INTO usage_dashboard_aggregation_watermark (id) +VALUES (1) +ON CONFLICT (id) DO NOTHING; diff --git a/backend/migrations/035_usage_logs_partitioning.sql b/backend/migrations/035_usage_logs_partitioning.sql new file mode 100644 index 00000000..e25a105e --- /dev/null +++ b/backend/migrations/035_usage_logs_partitioning.sql @@ -0,0 +1,54 @@ +-- usage_logs monthly partition bootstrap. +-- Only creates partitions when usage_logs is already partitioned. +-- Converting usage_logs to a partitioned table requires a manual migration plan. + +DO $$ +DECLARE + is_partitioned BOOLEAN := FALSE; + has_data BOOLEAN := FALSE; + month_start DATE; + prev_month DATE; + next_month DATE; +BEGIN + SELECT EXISTS( + SELECT 1 + FROM pg_partitioned_table pt + JOIN pg_class c ON c.oid = pt.partrelid + WHERE c.relname = 'usage_logs' + ) INTO is_partitioned; + + IF NOT is_partitioned THEN + SELECT EXISTS(SELECT 1 FROM usage_logs LIMIT 1) INTO has_data; + IF NOT has_data THEN + -- Automatic conversion is intentionally skipped; see manual migration plan. + RAISE NOTICE 'usage_logs is not partitioned; skip automatic partitioning'; + END IF; + END IF; + + IF is_partitioned THEN + month_start := date_trunc('month', now() AT TIME ZONE 'UTC')::date; + prev_month := (month_start - INTERVAL '1 month')::date; + next_month := (month_start + INTERVAL '1 month')::date; + + EXECUTE format( + 'CREATE TABLE IF NOT EXISTS usage_logs_%s PARTITION OF usage_logs FOR VALUES FROM (%L) TO (%L)', + to_char(prev_month, 'YYYYMM'), + prev_month, + month_start + ); + + EXECUTE format( + 'CREATE TABLE IF NOT EXISTS usage_logs_%s PARTITION OF usage_logs FOR VALUES FROM (%L) TO (%L)', + to_char(month_start, 'YYYYMM'), + month_start, + next_month + ); + + EXECUTE format( + 'CREATE TABLE IF NOT EXISTS usage_logs_%s PARTITION OF usage_logs FOR VALUES FROM (%L) TO (%L)', + to_char(next_month, 'YYYYMM'), + next_month, + (next_month + INTERVAL '1 month')::date + ); + END IF; +END $$; diff --git a/config.yaml b/config.yaml index 13e7977c..424ce9eb 100644 --- a/config.yaml +++ b/config.yaml @@ -170,6 +170,87 @@ gateway: # 允许在特定 400 错误时进行故障转移(默认:关闭) failover_on_400: false +# ============================================================================= +# API Key Auth Cache Configuration +# API Key 认证缓存配置 +# ============================================================================= +api_key_auth_cache: + # L1 cache size (entries), in-process LRU/TTL cache + # L1 缓存容量(条目数),进程内 LRU/TTL 缓存 + l1_size: 65535 + # L1 cache TTL (seconds) + # L1 缓存 TTL(秒) + l1_ttl_seconds: 15 + # L2 cache TTL (seconds), stored in Redis + # L2 缓存 TTL(秒),Redis 中存储 + l2_ttl_seconds: 300 + # Negative cache TTL (seconds) + # 负缓存 TTL(秒) + negative_ttl_seconds: 30 + # TTL jitter percent (0-100) + # TTL 抖动百分比(0-100) + jitter_percent: 10 + # Enable singleflight for cache misses + # 缓存未命中时启用 singleflight 合并回源 + singleflight: true + +# ============================================================================= +# Dashboard Cache Configuration +# 仪表盘缓存配置 +# ============================================================================= +dashboard_cache: + # Enable dashboard cache + # 启用仪表盘缓存 + enabled: true + # Redis key prefix for multi-environment isolation + # Redis key 前缀,用于多环境隔离 + key_prefix: "sub2api:" + # Fresh TTL (seconds); within this window cached stats are considered fresh + # 新鲜阈值(秒);命中后处于该窗口视为新鲜数据 + stats_fresh_ttl_seconds: 15 + # Cache TTL (seconds) stored in Redis + # Redis 缓存 TTL(秒) + stats_ttl_seconds: 30 + # Async refresh timeout (seconds) + # 异步刷新超时(秒) + stats_refresh_timeout_seconds: 30 + +# ============================================================================= +# Dashboard Aggregation Configuration +# 仪表盘预聚合配置(重启生效) +# ============================================================================= +dashboard_aggregation: + # Enable aggregation job + # 启用聚合作业 + enabled: true + # Refresh interval (seconds) + # 刷新间隔(秒) + interval_seconds: 60 + # Lookback window (seconds) for late-arriving data + # 回看窗口(秒),处理迟到数据 + lookback_seconds: 120 + # Allow manual backfill + # 允许手动回填 + backfill_enabled: false + # Backfill max range (days) + # 回填最大跨度(天) + backfill_max_days: 31 + # Recompute recent N days on startup + # 启动时重算最近 N 天 + recompute_days: 2 + # Retention windows (days) + # 保留窗口(天) + retention: + # Raw usage_logs retention + # 原始 usage_logs 保留天数 + usage_logs_days: 90 + # Hourly aggregation retention + # 小时聚合保留天数 + hourly_days: 180 + # Daily aggregation retention + # 日聚合保留天数 + daily_days: 730 + # ============================================================================= # Concurrency Wait Configuration # 并发等待配置 diff --git a/deploy/.env.example b/deploy/.env.example index 4e77c720..27618284 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -69,6 +69,33 @@ JWT_EXPIRE_HOUR=24 # Leave unset to use default ./config.yaml #CONFIG_FILE=./config.yaml +# ----------------------------------------------------------------------------- +# Dashboard Aggregation (Optional) +# ----------------------------------------------------------------------------- +# Enable aggregation job +# 启用仪表盘预聚合 +DASHBOARD_AGGREGATION_ENABLED=true +# Refresh interval (seconds) +# 刷新间隔(秒) +DASHBOARD_AGGREGATION_INTERVAL_SECONDS=60 +# Lookback window (seconds) +# 回看窗口(秒) +DASHBOARD_AGGREGATION_LOOKBACK_SECONDS=120 +# Allow manual backfill +# 允许手动回填 +DASHBOARD_AGGREGATION_BACKFILL_ENABLED=false +# Backfill max range (days) +# 回填最大跨度(天) +DASHBOARD_AGGREGATION_BACKFILL_MAX_DAYS=31 +# Recompute recent N days on startup +# 启动时重算最近 N 天 +DASHBOARD_AGGREGATION_RECOMPUTE_DAYS=2 +# Retention windows (days) +# 保留窗口(天) +DASHBOARD_AGGREGATION_RETENTION_USAGE_LOGS_DAYS=90 +DASHBOARD_AGGREGATION_RETENTION_HOURLY_DAYS=180 +DASHBOARD_AGGREGATION_RETENTION_DAILY_DAYS=730 + # ----------------------------------------------------------------------------- # Security Configuration # ----------------------------------------------------------------------------- diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 46489799..b1fc9bbd 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -170,6 +170,87 @@ gateway: # 允许在特定 400 错误时进行故障转移(默认:关闭) failover_on_400: false +# ============================================================================= +# API Key Auth Cache Configuration +# API Key 认证缓存配置 +# ============================================================================= +api_key_auth_cache: + # L1 cache size (entries), in-process LRU/TTL cache + # L1 缓存容量(条目数),进程内 LRU/TTL 缓存 + l1_size: 65535 + # L1 cache TTL (seconds) + # L1 缓存 TTL(秒) + l1_ttl_seconds: 15 + # L2 cache TTL (seconds), stored in Redis + # L2 缓存 TTL(秒),Redis 中存储 + l2_ttl_seconds: 300 + # Negative cache TTL (seconds) + # 负缓存 TTL(秒) + negative_ttl_seconds: 30 + # TTL jitter percent (0-100) + # TTL 抖动百分比(0-100) + jitter_percent: 10 + # Enable singleflight for cache misses + # 缓存未命中时启用 singleflight 合并回源 + singleflight: true + +# ============================================================================= +# Dashboard Cache Configuration +# 仪表盘缓存配置 +# ============================================================================= +dashboard_cache: + # Enable dashboard cache + # 启用仪表盘缓存 + enabled: true + # Redis key prefix for multi-environment isolation + # Redis key 前缀,用于多环境隔离 + key_prefix: "sub2api:" + # Fresh TTL (seconds); within this window cached stats are considered fresh + # 新鲜阈值(秒);命中后处于该窗口视为新鲜数据 + stats_fresh_ttl_seconds: 15 + # Cache TTL (seconds) stored in Redis + # Redis 缓存 TTL(秒) + stats_ttl_seconds: 30 + # Async refresh timeout (seconds) + # 异步刷新超时(秒) + stats_refresh_timeout_seconds: 30 + +# ============================================================================= +# Dashboard Aggregation Configuration +# 仪表盘预聚合配置(重启生效) +# ============================================================================= +dashboard_aggregation: + # Enable aggregation job + # 启用聚合作业 + enabled: true + # Refresh interval (seconds) + # 刷新间隔(秒) + interval_seconds: 60 + # Lookback window (seconds) for late-arriving data + # 回看窗口(秒),处理迟到数据 + lookback_seconds: 120 + # Allow manual backfill + # 允许手动回填 + backfill_enabled: false + # Backfill max range (days) + # 回填最大跨度(天) + backfill_max_days: 31 + # Recompute recent N days on startup + # 启动时重算最近 N 天 + recompute_days: 2 + # Retention windows (days) + # 保留窗口(天) + retention: + # Raw usage_logs retention + # 原始 usage_logs 保留天数 + usage_logs_days: 90 + # Hourly aggregation retention + # 小时聚合保留天数 + hourly_days: 180 + # Daily aggregation retention + # 日聚合保留天数 + daily_days: 730 + # ============================================================================= # Concurrency Wait Configuration # 并发等待配置 diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 4e1f6cd3..54d0ad94 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -275,11 +275,15 @@ export async function bulkUpdate( ): Promise<{ success: number failed: number + success_ids?: number[] + failed_ids?: number[] results: Array<{ account_id: number; success: boolean; error?: string }> -}> { + }> { const { data } = await apiClient.post<{ success: number failed: number + success_ids?: number[] + failed_ids?: number[] results: Array<{ account_id: number; success: boolean; error?: string }> }>('/admin/accounts/bulk-update', { account_ids: accountIds, diff --git a/frontend/src/components/common/DataTable.vue b/frontend/src/components/common/DataTable.vue index 7ad31f7d..dc492d36 100644 --- a/frontend/src/components/common/DataTable.vue +++ b/frontend/src/components/common/DataTable.vue @@ -83,7 +83,7 @@ string | number) } const props = withDefaults(defineProps(), { @@ -222,6 +223,18 @@ const props = withDefaults(defineProps(), { const sortKey = ref('') const sortOrder = ref<'asc' | 'desc'>('asc') const actionsExpanded = ref(false) +const resolveRowKey = (row: any, index: number) => { + if (typeof props.rowKey === 'function') { + const key = props.rowKey(row) + return key ?? index + } + if (typeof props.rowKey === 'string' && props.rowKey) { + const key = row?.[props.rowKey] + return key ?? index + } + const key = row?.id + return key ?? index +} // 数据/列变化时重新检查滚动状态 // 注意:不能监听 actionsExpanded,因为 checkActionsColumnWidth 会临时修改它,会导致无限循环 diff --git a/frontend/src/components/common/README.md b/frontend/src/components/common/README.md index 640cdc0e..1733cfad 100644 --- a/frontend/src/components/common/README.md +++ b/frontend/src/components/common/README.md @@ -13,6 +13,7 @@ A generic data table component with sorting, loading states, and custom cell ren - `columns: Column[]` - Array of column definitions with key, label, sortable, and formatter - `data: any[]` - Array of data objects to display - `loading?: boolean` - Show loading skeleton +- `rowKey?: string | (row: any) => string | number` - Row key field or resolver (defaults to `row.id`, falls back to index) **Slots:** diff --git a/frontend/src/components/keys/UseKeyModal.vue b/frontend/src/components/keys/UseKeyModal.vue index 546a53ab..58f42ae6 100644 --- a/frontend/src/components/keys/UseKeyModal.vue +++ b/frontend/src/components/keys/UseKeyModal.vue @@ -28,8 +28,8 @@ {{ platformDescription }}

- -
+ +