merge: 合并main分支最新改动
解决冲突: - backend/internal/config/config.go: 合并Ops和Dashboard配置 - backend/internal/server/api_contract_test.go: 合并handler初始化 - backend/internal/service/openai_gateway_service.go: 保留Ops错误追踪逻辑 - backend/internal/service/wire.go: 合并Ops和APIKeyAuth provider 主要合并内容: - Dashboard缓存和预聚合功能 - API Key认证缓存优化 - Codex转换支持 - 使用日志分区表
This commit is contained in:
@@ -55,31 +55,36 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
billingCache := repository.NewBillingCache(redisClient)
|
billingCache := repository.NewBillingCache(redisClient)
|
||||||
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
|
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
|
||||||
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
|
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
|
||||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client)
|
|
||||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
|
||||||
userService := service.NewUserService(userRepository)
|
|
||||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
|
|
||||||
userHandler := handler.NewUserHandler(userService)
|
|
||||||
apiKeyRepository := repository.NewAPIKeyRepository(client)
|
apiKeyRepository := repository.NewAPIKeyRepository(client)
|
||||||
groupRepository := repository.NewGroupRepository(client, db)
|
groupRepository := repository.NewGroupRepository(client, db)
|
||||||
apiKeyCache := repository.NewAPIKeyCache(redisClient)
|
apiKeyCache := repository.NewAPIKeyCache(redisClient)
|
||||||
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
||||||
|
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||||
|
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||||
|
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
||||||
|
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
|
||||||
|
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
|
||||||
|
userHandler := handler.NewUserHandler(userService)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client)
|
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
|
||||||
|
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
|
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
|
||||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
|
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
|
||||||
redeemCache := repository.NewRedeemCache(redisClient)
|
redeemCache := repository.NewRedeemCache(redisClient)
|
||||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client)
|
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||||
redeemHandler := handler.NewRedeemHandler(redeemService)
|
redeemHandler := handler.NewRedeemHandler(redeemService)
|
||||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||||
dashboardService := service.NewDashboardService(usageLogRepository)
|
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
|
||||||
dashboardHandler := admin.NewDashboardHandler(dashboardService)
|
timingWheelService := service.ProvideTimingWheelService()
|
||||||
|
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
|
||||||
|
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
|
||||||
|
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
|
||||||
accountRepository := repository.NewAccountRepository(client, db)
|
accountRepository := repository.NewAccountRepository(client, db)
|
||||||
proxyRepository := repository.NewProxyRepository(client, db)
|
proxyRepository := repository.NewProxyRepository(client, db)
|
||||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
|
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, apiKeyAuthCacheInvalidator)
|
||||||
adminUserHandler := admin.NewUserHandler(adminService)
|
adminUserHandler := admin.NewUserHandler(adminService)
|
||||||
groupHandler := admin.NewGroupHandler(adminService)
|
groupHandler := admin.NewGroupHandler(adminService)
|
||||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||||
@@ -124,7 +129,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
billingService := service.NewBillingService(configConfig, pricingService)
|
billingService := service.NewBillingService(configConfig, pricingService)
|
||||||
identityCache := repository.NewIdentityCache(redisClient)
|
identityCache := repository.NewIdentityCache(redisClient)
|
||||||
identityService := service.NewIdentityService(identityCache)
|
identityService := service.NewIdentityService(identityCache)
|
||||||
timingWheelService := service.ProvideTimingWheelService()
|
|
||||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
|
||||||
|
|||||||
@@ -46,11 +46,13 @@ require (
|
|||||||
github.com/containerd/platforms v0.2.1 // indirect
|
github.com/containerd/platforms v0.2.1 // indirect
|
||||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||||
|
github.com/dgraph-io/ristretto v0.2.0 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/distribution/reference v0.6.0 // indirect
|
github.com/distribution/reference v0.6.0 // indirect
|
||||||
github.com/docker/docker v28.5.1+incompatible // indirect
|
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||||
github.com/docker/go-connections v0.6.0 // indirect
|
github.com/docker/go-connections v0.6.0 // indirect
|
||||||
github.com/docker/go-units v0.5.0 // indirect
|
github.com/docker/go-units v0.5.0 // indirect
|
||||||
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/ebitengine/purego v0.8.4 // indirect
|
github.com/ebitengine/purego v0.8.4 // indirect
|
||||||
github.com/fatih/color v1.18.0 // indirect
|
github.com/fatih/color v1.18.0 // indirect
|
||||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
|
|||||||
@@ -51,6 +51,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
|
|||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE=
|
||||||
|
github.com/dgraph-io/ristretto v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
@@ -61,6 +63,8 @@ github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pM
|
|||||||
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
|
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
|
||||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||||
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
|
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
|
||||||
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||||
|
|||||||
@@ -36,26 +36,29 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Server ServerConfig `mapstructure:"server"`
|
Server ServerConfig `mapstructure:"server"`
|
||||||
CORS CORSConfig `mapstructure:"cors"`
|
CORS CORSConfig `mapstructure:"cors"`
|
||||||
Security SecurityConfig `mapstructure:"security"`
|
Security SecurityConfig `mapstructure:"security"`
|
||||||
Billing BillingConfig `mapstructure:"billing"`
|
Billing BillingConfig `mapstructure:"billing"`
|
||||||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||||||
Database DatabaseConfig `mapstructure:"database"`
|
Database DatabaseConfig `mapstructure:"database"`
|
||||||
Redis RedisConfig `mapstructure:"redis"`
|
Redis RedisConfig `mapstructure:"redis"`
|
||||||
Ops OpsConfig `mapstructure:"ops"`
|
Ops OpsConfig `mapstructure:"ops"`
|
||||||
JWT JWTConfig `mapstructure:"jwt"`
|
JWT JWTConfig `mapstructure:"jwt"`
|
||||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||||
Default DefaultConfig `mapstructure:"default"`
|
Default DefaultConfig `mapstructure:"default"`
|
||||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||||
Pricing PricingConfig `mapstructure:"pricing"`
|
Pricing PricingConfig `mapstructure:"pricing"`
|
||||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
|
||||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
|
||||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
|
||||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||||
Update UpdateConfig `mapstructure:"update"`
|
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||||
|
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||||
|
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||||
|
Update UpdateConfig `mapstructure:"update"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GeminiConfig struct {
|
type GeminiConfig struct {
|
||||||
@@ -412,6 +415,55 @@ type RateLimitConfig struct {
|
|||||||
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
|
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// APIKeyAuthCacheConfig API Key 认证缓存配置
|
||||||
|
type APIKeyAuthCacheConfig struct {
|
||||||
|
L1Size int `mapstructure:"l1_size"`
|
||||||
|
L1TTLSeconds int `mapstructure:"l1_ttl_seconds"`
|
||||||
|
L2TTLSeconds int `mapstructure:"l2_ttl_seconds"`
|
||||||
|
NegativeTTLSeconds int `mapstructure:"negative_ttl_seconds"`
|
||||||
|
JitterPercent int `mapstructure:"jitter_percent"`
|
||||||
|
Singleflight bool `mapstructure:"singleflight"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DashboardCacheConfig 仪表盘统计缓存配置
|
||||||
|
type DashboardCacheConfig struct {
|
||||||
|
// Enabled: 是否启用仪表盘缓存
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
// KeyPrefix: Redis key 前缀,用于多环境隔离
|
||||||
|
KeyPrefix string `mapstructure:"key_prefix"`
|
||||||
|
// StatsFreshTTLSeconds: 缓存命中认为“新鲜”的时间窗口(秒)
|
||||||
|
StatsFreshTTLSeconds int `mapstructure:"stats_fresh_ttl_seconds"`
|
||||||
|
// StatsTTLSeconds: Redis 缓存总 TTL(秒)
|
||||||
|
StatsTTLSeconds int `mapstructure:"stats_ttl_seconds"`
|
||||||
|
// StatsRefreshTimeoutSeconds: 异步刷新超时(秒)
|
||||||
|
StatsRefreshTimeoutSeconds int `mapstructure:"stats_refresh_timeout_seconds"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DashboardAggregationConfig 仪表盘预聚合配置
|
||||||
|
type DashboardAggregationConfig struct {
|
||||||
|
// Enabled: 是否启用预聚合作业
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
// IntervalSeconds: 聚合刷新间隔(秒)
|
||||||
|
IntervalSeconds int `mapstructure:"interval_seconds"`
|
||||||
|
// LookbackSeconds: 回看窗口(秒)
|
||||||
|
LookbackSeconds int `mapstructure:"lookback_seconds"`
|
||||||
|
// BackfillEnabled: 是否允许全量回填
|
||||||
|
BackfillEnabled bool `mapstructure:"backfill_enabled"`
|
||||||
|
// BackfillMaxDays: 回填最大跨度(天)
|
||||||
|
BackfillMaxDays int `mapstructure:"backfill_max_days"`
|
||||||
|
// Retention: 各表保留窗口(天)
|
||||||
|
Retention DashboardAggregationRetentionConfig `mapstructure:"retention"`
|
||||||
|
// RecomputeDays: 启动时重算最近 N 天
|
||||||
|
RecomputeDays int `mapstructure:"recompute_days"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DashboardAggregationRetentionConfig 预聚合保留窗口
|
||||||
|
type DashboardAggregationRetentionConfig struct {
|
||||||
|
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||||
|
HourlyDays int `mapstructure:"hourly_days"`
|
||||||
|
DailyDays int `mapstructure:"daily_days"`
|
||||||
|
}
|
||||||
|
|
||||||
func NormalizeRunMode(value string) string {
|
func NormalizeRunMode(value string) string {
|
||||||
normalized := strings.ToLower(strings.TrimSpace(value))
|
normalized := strings.ToLower(strings.TrimSpace(value))
|
||||||
switch normalized {
|
switch normalized {
|
||||||
@@ -465,6 +517,19 @@ func Load() (*Config, error) {
|
|||||||
cfg.Server.Mode = "debug"
|
cfg.Server.Mode = "debug"
|
||||||
}
|
}
|
||||||
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
||||||
|
cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID)
|
||||||
|
cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret)
|
||||||
|
cfg.LinuxDo.AuthorizeURL = strings.TrimSpace(cfg.LinuxDo.AuthorizeURL)
|
||||||
|
cfg.LinuxDo.TokenURL = strings.TrimSpace(cfg.LinuxDo.TokenURL)
|
||||||
|
cfg.LinuxDo.UserInfoURL = strings.TrimSpace(cfg.LinuxDo.UserInfoURL)
|
||||||
|
cfg.LinuxDo.Scopes = strings.TrimSpace(cfg.LinuxDo.Scopes)
|
||||||
|
cfg.LinuxDo.RedirectURL = strings.TrimSpace(cfg.LinuxDo.RedirectURL)
|
||||||
|
cfg.LinuxDo.FrontendRedirectURL = strings.TrimSpace(cfg.LinuxDo.FrontendRedirectURL)
|
||||||
|
cfg.LinuxDo.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.LinuxDo.TokenAuthMethod))
|
||||||
|
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
|
||||||
|
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
|
||||||
|
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
|
||||||
|
cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
|
||||||
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
|
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
|
||||||
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
|
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
|
||||||
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
||||||
@@ -633,6 +698,32 @@ func setDefaults() {
|
|||||||
// Timezone (default to Asia/Shanghai for Chinese users)
|
// Timezone (default to Asia/Shanghai for Chinese users)
|
||||||
viper.SetDefault("timezone", "Asia/Shanghai")
|
viper.SetDefault("timezone", "Asia/Shanghai")
|
||||||
|
|
||||||
|
// API Key auth cache
|
||||||
|
viper.SetDefault("api_key_auth_cache.l1_size", 65535)
|
||||||
|
viper.SetDefault("api_key_auth_cache.l1_ttl_seconds", 15)
|
||||||
|
viper.SetDefault("api_key_auth_cache.l2_ttl_seconds", 300)
|
||||||
|
viper.SetDefault("api_key_auth_cache.negative_ttl_seconds", 30)
|
||||||
|
viper.SetDefault("api_key_auth_cache.jitter_percent", 10)
|
||||||
|
viper.SetDefault("api_key_auth_cache.singleflight", true)
|
||||||
|
|
||||||
|
// Dashboard cache
|
||||||
|
viper.SetDefault("dashboard_cache.enabled", true)
|
||||||
|
viper.SetDefault("dashboard_cache.key_prefix", "sub2api:")
|
||||||
|
viper.SetDefault("dashboard_cache.stats_fresh_ttl_seconds", 15)
|
||||||
|
viper.SetDefault("dashboard_cache.stats_ttl_seconds", 30)
|
||||||
|
viper.SetDefault("dashboard_cache.stats_refresh_timeout_seconds", 30)
|
||||||
|
|
||||||
|
// Dashboard aggregation
|
||||||
|
viper.SetDefault("dashboard_aggregation.enabled", true)
|
||||||
|
viper.SetDefault("dashboard_aggregation.interval_seconds", 60)
|
||||||
|
viper.SetDefault("dashboard_aggregation.lookback_seconds", 120)
|
||||||
|
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
|
||||||
|
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
|
||||||
|
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
|
||||||
|
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
|
||||||
|
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
||||||
|
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
||||||
|
|
||||||
// Gateway
|
// Gateway
|
||||||
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久
|
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久
|
||||||
viper.SetDefault("gateway.log_upstream_error_body", true)
|
viper.SetDefault("gateway.log_upstream_error_body", true)
|
||||||
@@ -788,6 +879,78 @@ func (c *Config) Validate() error {
|
|||||||
if c.Redis.MinIdleConns > c.Redis.PoolSize {
|
if c.Redis.MinIdleConns > c.Redis.PoolSize {
|
||||||
return fmt.Errorf("redis.min_idle_conns cannot exceed redis.pool_size")
|
return fmt.Errorf("redis.min_idle_conns cannot exceed redis.pool_size")
|
||||||
}
|
}
|
||||||
|
if c.Dashboard.Enabled {
|
||||||
|
if c.Dashboard.StatsFreshTTLSeconds <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be positive")
|
||||||
|
}
|
||||||
|
if c.Dashboard.StatsTTLSeconds <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_cache.stats_ttl_seconds must be positive")
|
||||||
|
}
|
||||||
|
if c.Dashboard.StatsRefreshTimeoutSeconds <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_cache.stats_refresh_timeout_seconds must be positive")
|
||||||
|
}
|
||||||
|
if c.Dashboard.StatsFreshTTLSeconds > c.Dashboard.StatsTTLSeconds {
|
||||||
|
return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be <= dashboard_cache.stats_ttl_seconds")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if c.Dashboard.StatsFreshTTLSeconds < 0 {
|
||||||
|
return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Dashboard.StatsTTLSeconds < 0 {
|
||||||
|
return fmt.Errorf("dashboard_cache.stats_ttl_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Dashboard.StatsRefreshTimeoutSeconds < 0 {
|
||||||
|
return fmt.Errorf("dashboard_cache.stats_refresh_timeout_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Enabled {
|
||||||
|
if c.DashboardAgg.IntervalSeconds <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.interval_seconds must be positive")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.LookbackSeconds < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.lookback_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.BackfillMaxDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.backfill_max_days must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.BackfillEnabled && c.DashboardAgg.BackfillMaxDays == 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.backfill_max_days must be positive")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.HourlyDays <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.DailyDays <= 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.daily_days must be positive")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.RecomputeDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if c.DashboardAgg.IntervalSeconds < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.interval_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.LookbackSeconds < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.lookback_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.BackfillMaxDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.backfill_max_days must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.HourlyDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.Retention.DailyDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.retention.daily_days must be non-negative")
|
||||||
|
}
|
||||||
|
if c.DashboardAgg.RecomputeDays < 0 {
|
||||||
|
return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative")
|
||||||
|
}
|
||||||
|
}
|
||||||
if c.Gateway.MaxBodySize <= 0 {
|
if c.Gateway.MaxBodySize <= 0 {
|
||||||
return fmt.Errorf("gateway.max_body_size must be positive")
|
return fmt.Errorf("gateway.max_body_size must be positive")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -141,3 +141,142 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
|
|||||||
t.Fatalf("Validate() expected use_pkce error, got: %v", err)
|
t.Fatalf("Validate() expected use_pkce error, got: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.Dashboard.Enabled {
|
||||||
|
t.Fatalf("Dashboard.Enabled = false, want true")
|
||||||
|
}
|
||||||
|
if cfg.Dashboard.KeyPrefix != "sub2api:" {
|
||||||
|
t.Fatalf("Dashboard.KeyPrefix = %q, want %q", cfg.Dashboard.KeyPrefix, "sub2api:")
|
||||||
|
}
|
||||||
|
if cfg.Dashboard.StatsFreshTTLSeconds != 15 {
|
||||||
|
t.Fatalf("Dashboard.StatsFreshTTLSeconds = %d, want 15", cfg.Dashboard.StatsFreshTTLSeconds)
|
||||||
|
}
|
||||||
|
if cfg.Dashboard.StatsTTLSeconds != 30 {
|
||||||
|
t.Fatalf("Dashboard.StatsTTLSeconds = %d, want 30", cfg.Dashboard.StatsTTLSeconds)
|
||||||
|
}
|
||||||
|
if cfg.Dashboard.StatsRefreshTimeoutSeconds != 30 {
|
||||||
|
t.Fatalf("Dashboard.StatsRefreshTimeoutSeconds = %d, want 30", cfg.Dashboard.StatsRefreshTimeoutSeconds)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateDashboardCacheConfigEnabled(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Dashboard.Enabled = true
|
||||||
|
cfg.Dashboard.StatsFreshTTLSeconds = 10
|
||||||
|
cfg.Dashboard.StatsTTLSeconds = 5
|
||||||
|
err = cfg.Validate()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Validate() expected error for stats_fresh_ttl_seconds > stats_ttl_seconds, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "dashboard_cache.stats_fresh_ttl_seconds") {
|
||||||
|
t.Fatalf("Validate() expected stats_fresh_ttl_seconds error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateDashboardCacheConfigDisabled(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Dashboard.Enabled = false
|
||||||
|
cfg.Dashboard.StatsTTLSeconds = -1
|
||||||
|
err = cfg.Validate()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Validate() expected error for negative stats_ttl_seconds, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "dashboard_cache.stats_ttl_seconds") {
|
||||||
|
t.Fatalf("Validate() expected stats_ttl_seconds error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.DashboardAgg.Enabled {
|
||||||
|
t.Fatalf("DashboardAgg.Enabled = false, want true")
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.IntervalSeconds != 60 {
|
||||||
|
t.Fatalf("DashboardAgg.IntervalSeconds = %d, want 60", cfg.DashboardAgg.IntervalSeconds)
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.LookbackSeconds != 120 {
|
||||||
|
t.Fatalf("DashboardAgg.LookbackSeconds = %d, want 120", cfg.DashboardAgg.LookbackSeconds)
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.BackfillEnabled {
|
||||||
|
t.Fatalf("DashboardAgg.BackfillEnabled = true, want false")
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.BackfillMaxDays != 31 {
|
||||||
|
t.Fatalf("DashboardAgg.BackfillMaxDays = %d, want 31", cfg.DashboardAgg.BackfillMaxDays)
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
|
||||||
|
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
|
||||||
|
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.Retention.DailyDays != 730 {
|
||||||
|
t.Fatalf("DashboardAgg.Retention.DailyDays = %d, want 730", cfg.DashboardAgg.Retention.DailyDays)
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.RecomputeDays != 2 {
|
||||||
|
t.Fatalf("DashboardAgg.RecomputeDays = %d, want 2", cfg.DashboardAgg.RecomputeDays)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateDashboardAggregationConfigDisabled(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.DashboardAgg.Enabled = false
|
||||||
|
cfg.DashboardAgg.IntervalSeconds = -1
|
||||||
|
err = cfg.Validate()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Validate() expected error for negative dashboard_aggregation.interval_seconds, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "dashboard_aggregation.interval_seconds") {
|
||||||
|
t.Fatalf("Validate() expected interval_seconds error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.DashboardAgg.BackfillEnabled = true
|
||||||
|
cfg.DashboardAgg.BackfillMaxDays = 0
|
||||||
|
err = cfg.Validate()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Validate() expected error for dashboard_aggregation.backfill_max_days, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "dashboard_aggregation.backfill_max_days") {
|
||||||
|
t.Fatalf("Validate() expected backfill_max_days error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -13,15 +14,17 @@ import (
|
|||||||
|
|
||||||
// DashboardHandler handles admin dashboard statistics
|
// DashboardHandler handles admin dashboard statistics
|
||||||
type DashboardHandler struct {
|
type DashboardHandler struct {
|
||||||
dashboardService *service.DashboardService
|
dashboardService *service.DashboardService
|
||||||
startTime time.Time // Server start time for uptime calculation
|
aggregationService *service.DashboardAggregationService
|
||||||
|
startTime time.Time // Server start time for uptime calculation
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDashboardHandler creates a new admin dashboard handler
|
// NewDashboardHandler creates a new admin dashboard handler
|
||||||
func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardHandler {
|
func NewDashboardHandler(dashboardService *service.DashboardService, aggregationService *service.DashboardAggregationService) *DashboardHandler {
|
||||||
return &DashboardHandler{
|
return &DashboardHandler{
|
||||||
dashboardService: dashboardService,
|
dashboardService: dashboardService,
|
||||||
startTime: time.Now(),
|
aggregationService: aggregationService,
|
||||||
|
startTime: time.Now(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -114,6 +117,58 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
|
|||||||
// 性能指标
|
// 性能指标
|
||||||
"rpm": stats.Rpm,
|
"rpm": stats.Rpm,
|
||||||
"tpm": stats.Tpm,
|
"tpm": stats.Tpm,
|
||||||
|
|
||||||
|
// 预聚合新鲜度
|
||||||
|
"hourly_active_users": stats.HourlyActiveUsers,
|
||||||
|
"stats_updated_at": stats.StatsUpdatedAt,
|
||||||
|
"stats_stale": stats.StatsStale,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type DashboardAggregationBackfillRequest struct {
|
||||||
|
Start string `json:"start"`
|
||||||
|
End string `json:"end"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BackfillAggregation handles triggering aggregation backfill
|
||||||
|
// POST /api/v1/admin/dashboard/aggregation/backfill
|
||||||
|
func (h *DashboardHandler) BackfillAggregation(c *gin.Context) {
|
||||||
|
if h.aggregationService == nil {
|
||||||
|
response.InternalError(c, "Aggregation service not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req DashboardAggregationBackfillRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
start, err := time.Parse(time.RFC3339, req.Start)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid start time")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
end, err := time.Parse(time.RFC3339, req.End)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid end time")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.aggregationService.TriggerBackfill(start, end); err != nil {
|
||||||
|
if errors.Is(err, service.ErrDashboardBackfillDisabled) {
|
||||||
|
response.Forbidden(c, "Backfill is disabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errors.Is(err, service.ErrDashboardBackfillTooLarge) {
|
||||||
|
response.BadRequest(c, "Backfill range too large")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.InternalError(c, "Failed to trigger backfill")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"status": "accepted",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
@@ -94,15 +95,19 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// For non-Codex CLI requests, set default instructions
|
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
if !openai.IsCodexCLIRequest(userAgent) {
|
if !openai.IsCodexCLIRequest(userAgent) {
|
||||||
reqBody["instructions"] = openai.DefaultInstructions
|
existingInstructions, _ := reqBody["instructions"].(string)
|
||||||
// Re-serialize body
|
if strings.TrimSpace(existingInstructions) == "" {
|
||||||
body, err = json.Marshal(reqBody)
|
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
|
||||||
if err != nil {
|
reqBody["instructions"] = instructions
|
||||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
// Re-serialize body
|
||||||
return
|
body, err = json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,12 @@ type DashboardStats struct {
|
|||||||
TotalUsers int64 `json:"total_users"`
|
TotalUsers int64 `json:"total_users"`
|
||||||
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
|
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
|
||||||
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
|
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
|
||||||
|
// 小时活跃用户数(UTC 当前小时)
|
||||||
|
HourlyActiveUsers int64 `json:"hourly_active_users"`
|
||||||
|
|
||||||
|
// 预聚合新鲜度
|
||||||
|
StatsUpdatedAt string `json:"stats_updated_at"`
|
||||||
|
StatsStale bool `json:"stats_stale"`
|
||||||
|
|
||||||
// API Key 统计
|
// API Key 统计
|
||||||
TotalAPIKeys int64 `json:"total_api_keys"`
|
TotalAPIKeys int64 `json:"total_api_keys"`
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
@@ -13,6 +14,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||||
apiKeyRateLimitDuration = 24 * time.Hour
|
apiKeyRateLimitDuration = 24 * time.Hour
|
||||||
|
apiKeyAuthCachePrefix = "apikey:auth:"
|
||||||
)
|
)
|
||||||
|
|
||||||
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
||||||
@@ -20,6 +22,10 @@ func apiKeyRateLimitKey(userID int64) string {
|
|||||||
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func apiKeyAuthCacheKey(key string) string {
|
||||||
|
return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key)
|
||||||
|
}
|
||||||
|
|
||||||
type apiKeyCache struct {
|
type apiKeyCache struct {
|
||||||
rdb *redis.Client
|
rdb *redis.Client
|
||||||
}
|
}
|
||||||
@@ -58,3 +64,30 @@ func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) er
|
|||||||
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||||
return c.rdb.Expire(ctx, apiKey, ttl).Err()
|
return c.rdb.Expire(ctx, apiKey, ttl).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
|
||||||
|
val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var entry service.APIKeyAuthCacheEntry
|
||||||
|
if err := json.Unmarshal(val, &entry); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &entry, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||||
|
if entry == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
payload, err := json.Marshal(entry)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
|
||||||
|
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,7 +6,9 @@ import (
|
|||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
@@ -64,23 +66,23 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK
|
|||||||
return apiKeyEntityToService(m), nil
|
return apiKeyEntityToService(m), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
|
// GetKeyAndOwnerID 根据 API Key ID 获取其 key 与所有者(用户)ID。
|
||||||
// 相比 GetByID,此方法性能更优,因为:
|
// 相比 GetByID,此方法性能更优,因为:
|
||||||
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
|
// - 使用 Select() 只查询必要字段,减少数据传输量
|
||||||
// - 不加载完整的 API Key 实体及其关联数据(User、Group 等)
|
// - 不加载完整的 API Key 实体及其关联数据(User、Group 等)
|
||||||
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
|
// - 适用于删除等只需 key 与用户 ID 的场景
|
||||||
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
m, err := r.activeQuery().
|
m, err := r.activeQuery().
|
||||||
Where(apikey.IDEQ(id)).
|
Where(apikey.IDEQ(id)).
|
||||||
Select(apikey.FieldUserID).
|
Select(apikey.FieldKey, apikey.FieldUserID).
|
||||||
Only(ctx)
|
Only(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if dbent.IsNotFound(err) {
|
if dbent.IsNotFound(err) {
|
||||||
return 0, service.ErrAPIKeyNotFound
|
return "", 0, service.ErrAPIKeyNotFound
|
||||||
}
|
}
|
||||||
return 0, err
|
return "", 0, err
|
||||||
}
|
}
|
||||||
return m.UserID, nil
|
return m.Key, m.UserID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
@@ -98,6 +100,54 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
|
|||||||
return apiKeyEntityToService(m), nil
|
return apiKeyEntityToService(m), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
|
m, err := r.activeQuery().
|
||||||
|
Where(apikey.KeyEQ(key)).
|
||||||
|
Select(
|
||||||
|
apikey.FieldID,
|
||||||
|
apikey.FieldUserID,
|
||||||
|
apikey.FieldGroupID,
|
||||||
|
apikey.FieldStatus,
|
||||||
|
apikey.FieldIPWhitelist,
|
||||||
|
apikey.FieldIPBlacklist,
|
||||||
|
).
|
||||||
|
WithUser(func(q *dbent.UserQuery) {
|
||||||
|
q.Select(
|
||||||
|
user.FieldID,
|
||||||
|
user.FieldStatus,
|
||||||
|
user.FieldRole,
|
||||||
|
user.FieldBalance,
|
||||||
|
user.FieldConcurrency,
|
||||||
|
)
|
||||||
|
}).
|
||||||
|
WithGroup(func(q *dbent.GroupQuery) {
|
||||||
|
q.Select(
|
||||||
|
group.FieldID,
|
||||||
|
group.FieldName,
|
||||||
|
group.FieldPlatform,
|
||||||
|
group.FieldStatus,
|
||||||
|
group.FieldSubscriptionType,
|
||||||
|
group.FieldRateMultiplier,
|
||||||
|
group.FieldDailyLimitUsd,
|
||||||
|
group.FieldWeeklyLimitUsd,
|
||||||
|
group.FieldMonthlyLimitUsd,
|
||||||
|
group.FieldImagePrice1k,
|
||||||
|
group.FieldImagePrice2k,
|
||||||
|
group.FieldImagePrice4k,
|
||||||
|
group.FieldClaudeCodeOnly,
|
||||||
|
group.FieldFallbackGroupID,
|
||||||
|
)
|
||||||
|
}).
|
||||||
|
Only(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if dbent.IsNotFound(err) {
|
||||||
|
return nil, service.ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return apiKeyEntityToService(m), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
|
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
|
||||||
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
|
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
|
||||||
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
|
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
|
||||||
@@ -283,6 +333,28 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
|
|||||||
return int64(count), err
|
return int64(count), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
keys, err := r.activeQuery().
|
||||||
|
Where(apikey.UserIDEQ(userID)).
|
||||||
|
Select(apikey.FieldKey).
|
||||||
|
Strings(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
keys, err := r.activeQuery().
|
||||||
|
Where(apikey.GroupIDEQ(groupID)).
|
||||||
|
Select(apikey.FieldKey).
|
||||||
|
Strings(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
363
backend/internal/repository/dashboard_aggregation_repo.go
Normal file
363
backend/internal/repository/dashboard_aggregation_repo.go
Normal file
@@ -0,0 +1,363 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/lib/pq"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dashboardAggregationRepository struct {
|
||||||
|
sql sqlExecutor
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
|
||||||
|
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
|
||||||
|
return newDashboardAggregationRepositoryWithSQL(sqlDB)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDashboardAggregationRepositoryWithSQL(sqlq sqlExecutor) *dashboardAggregationRepository {
|
||||||
|
return &dashboardAggregationRepository{sql: sqlq}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||||
|
startUTC := start.UTC()
|
||||||
|
endUTC := end.UTC()
|
||||||
|
if !endUTC.After(startUTC) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hourStart := startUTC.Truncate(time.Hour)
|
||||||
|
hourEnd := endUTC.Truncate(time.Hour)
|
||||||
|
if endUTC.After(hourEnd) {
|
||||||
|
hourEnd = hourEnd.Add(time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
dayStart := truncateToDayUTC(startUTC)
|
||||||
|
dayEnd := truncateToDayUTC(endUTC)
|
||||||
|
if endUTC.After(dayEnd) {
|
||||||
|
dayEnd = dayEnd.Add(24 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
|
||||||
|
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||||
|
var ts time.Time
|
||||||
|
query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1"
|
||||||
|
if err := scanSingleRow(ctx, r.sql, query, nil, &ts); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return time.Unix(0, 0).UTC(), nil
|
||||||
|
}
|
||||||
|
return time.Time{}, err
|
||||||
|
}
|
||||||
|
return ts.UTC(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||||
|
query := `
|
||||||
|
INSERT INTO usage_dashboard_aggregation_watermark (id, last_aggregated_at, updated_at)
|
||||||
|
VALUES (1, $1, NOW())
|
||||||
|
ON CONFLICT (id)
|
||||||
|
DO UPDATE SET last_aggregated_at = EXCLUDED.last_aggregated_at, updated_at = EXCLUDED.updated_at
|
||||||
|
`
|
||||||
|
_, err := r.sql.ExecContext(ctx, query, aggregatedAt.UTC())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||||
|
_, err := r.sql.ExecContext(ctx, `
|
||||||
|
DELETE FROM usage_dashboard_hourly WHERE bucket_start < $1;
|
||||||
|
DELETE FROM usage_dashboard_hourly_users WHERE bucket_start < $1;
|
||||||
|
DELETE FROM usage_dashboard_daily WHERE bucket_date < $2::date;
|
||||||
|
DELETE FROM usage_dashboard_daily_users WHERE bucket_date < $2::date;
|
||||||
|
`, hourlyCutoff.UTC(), dailyCutoff.UTC())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||||
|
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if isPartitioned {
|
||||||
|
return r.dropUsageLogsPartitions(ctx, cutoff)
|
||||||
|
}
|
||||||
|
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||||
|
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
|
||||||
|
if err != nil || !isPartitioned {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
monthStart := truncateToMonthUTC(now)
|
||||||
|
prevMonth := monthStart.AddDate(0, -1, 0)
|
||||||
|
nextMonth := monthStart.AddDate(0, 1, 0)
|
||||||
|
|
||||||
|
for _, m := range []time.Time{prevMonth, monthStart, nextMonth} {
|
||||||
|
if err := r.createUsageLogsPartition(ctx, m); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error {
|
||||||
|
query := `
|
||||||
|
INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
|
||||||
|
SELECT DISTINCT
|
||||||
|
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
|
||||||
|
user_id
|
||||||
|
FROM usage_logs
|
||||||
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
|
ON CONFLICT DO NOTHING
|
||||||
|
`
|
||||||
|
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error {
|
||||||
|
query := `
|
||||||
|
INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
|
||||||
|
SELECT DISTINCT
|
||||||
|
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
|
||||||
|
user_id
|
||||||
|
FROM usage_dashboard_hourly_users
|
||||||
|
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||||
|
ON CONFLICT DO NOTHING
|
||||||
|
`
|
||||||
|
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error {
|
||||||
|
query := `
|
||||||
|
WITH hourly AS (
|
||||||
|
SELECT
|
||||||
|
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
|
||||||
|
COUNT(*) AS total_requests,
|
||||||
|
COALESCE(SUM(input_tokens), 0) AS input_tokens,
|
||||||
|
COALESCE(SUM(output_tokens), 0) AS output_tokens,
|
||||||
|
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
|
||||||
|
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
|
||||||
|
COALESCE(SUM(total_cost), 0) AS total_cost,
|
||||||
|
COALESCE(SUM(actual_cost), 0) AS actual_cost,
|
||||||
|
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) AS total_duration_ms
|
||||||
|
FROM usage_logs
|
||||||
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
|
GROUP BY 1
|
||||||
|
),
|
||||||
|
user_counts AS (
|
||||||
|
SELECT bucket_start, COUNT(*) AS active_users
|
||||||
|
FROM usage_dashboard_hourly_users
|
||||||
|
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||||
|
GROUP BY bucket_start
|
||||||
|
)
|
||||||
|
INSERT INTO usage_dashboard_hourly (
|
||||||
|
bucket_start,
|
||||||
|
total_requests,
|
||||||
|
input_tokens,
|
||||||
|
output_tokens,
|
||||||
|
cache_creation_tokens,
|
||||||
|
cache_read_tokens,
|
||||||
|
total_cost,
|
||||||
|
actual_cost,
|
||||||
|
total_duration_ms,
|
||||||
|
active_users,
|
||||||
|
computed_at
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
hourly.bucket_start,
|
||||||
|
hourly.total_requests,
|
||||||
|
hourly.input_tokens,
|
||||||
|
hourly.output_tokens,
|
||||||
|
hourly.cache_creation_tokens,
|
||||||
|
hourly.cache_read_tokens,
|
||||||
|
hourly.total_cost,
|
||||||
|
hourly.actual_cost,
|
||||||
|
hourly.total_duration_ms,
|
||||||
|
COALESCE(user_counts.active_users, 0) AS active_users,
|
||||||
|
NOW()
|
||||||
|
FROM hourly
|
||||||
|
LEFT JOIN user_counts ON user_counts.bucket_start = hourly.bucket_start
|
||||||
|
ON CONFLICT (bucket_start)
|
||||||
|
DO UPDATE SET
|
||||||
|
total_requests = EXCLUDED.total_requests,
|
||||||
|
input_tokens = EXCLUDED.input_tokens,
|
||||||
|
output_tokens = EXCLUDED.output_tokens,
|
||||||
|
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
|
||||||
|
cache_read_tokens = EXCLUDED.cache_read_tokens,
|
||||||
|
total_cost = EXCLUDED.total_cost,
|
||||||
|
actual_cost = EXCLUDED.actual_cost,
|
||||||
|
total_duration_ms = EXCLUDED.total_duration_ms,
|
||||||
|
active_users = EXCLUDED.active_users,
|
||||||
|
computed_at = EXCLUDED.computed_at
|
||||||
|
`
|
||||||
|
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error {
|
||||||
|
query := `
|
||||||
|
WITH daily AS (
|
||||||
|
SELECT
|
||||||
|
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
|
||||||
|
COALESCE(SUM(total_requests), 0) AS total_requests,
|
||||||
|
COALESCE(SUM(input_tokens), 0) AS input_tokens,
|
||||||
|
COALESCE(SUM(output_tokens), 0) AS output_tokens,
|
||||||
|
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
|
||||||
|
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
|
||||||
|
COALESCE(SUM(total_cost), 0) AS total_cost,
|
||||||
|
COALESCE(SUM(actual_cost), 0) AS actual_cost,
|
||||||
|
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
|
||||||
|
FROM usage_dashboard_hourly
|
||||||
|
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||||
|
GROUP BY (bucket_start AT TIME ZONE 'UTC')::date
|
||||||
|
),
|
||||||
|
user_counts AS (
|
||||||
|
SELECT bucket_date, COUNT(*) AS active_users
|
||||||
|
FROM usage_dashboard_daily_users
|
||||||
|
WHERE bucket_date >= $3::date AND bucket_date < $4::date
|
||||||
|
GROUP BY bucket_date
|
||||||
|
)
|
||||||
|
INSERT INTO usage_dashboard_daily (
|
||||||
|
bucket_date,
|
||||||
|
total_requests,
|
||||||
|
input_tokens,
|
||||||
|
output_tokens,
|
||||||
|
cache_creation_tokens,
|
||||||
|
cache_read_tokens,
|
||||||
|
total_cost,
|
||||||
|
actual_cost,
|
||||||
|
total_duration_ms,
|
||||||
|
active_users,
|
||||||
|
computed_at
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
daily.bucket_date,
|
||||||
|
daily.total_requests,
|
||||||
|
daily.input_tokens,
|
||||||
|
daily.output_tokens,
|
||||||
|
daily.cache_creation_tokens,
|
||||||
|
daily.cache_read_tokens,
|
||||||
|
daily.total_cost,
|
||||||
|
daily.actual_cost,
|
||||||
|
daily.total_duration_ms,
|
||||||
|
COALESCE(user_counts.active_users, 0) AS active_users,
|
||||||
|
NOW()
|
||||||
|
FROM daily
|
||||||
|
LEFT JOIN user_counts ON user_counts.bucket_date = daily.bucket_date
|
||||||
|
ON CONFLICT (bucket_date)
|
||||||
|
DO UPDATE SET
|
||||||
|
total_requests = EXCLUDED.total_requests,
|
||||||
|
input_tokens = EXCLUDED.input_tokens,
|
||||||
|
output_tokens = EXCLUDED.output_tokens,
|
||||||
|
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
|
||||||
|
cache_read_tokens = EXCLUDED.cache_read_tokens,
|
||||||
|
total_cost = EXCLUDED.total_cost,
|
||||||
|
actual_cost = EXCLUDED.actual_cost,
|
||||||
|
total_duration_ms = EXCLUDED.total_duration_ms,
|
||||||
|
active_users = EXCLUDED.active_users,
|
||||||
|
computed_at = EXCLUDED.computed_at
|
||||||
|
`
|
||||||
|
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC(), start.UTC(), end.UTC())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) isUsageLogsPartitioned(ctx context.Context) (bool, error) {
|
||||||
|
query := `
|
||||||
|
SELECT EXISTS(
|
||||||
|
SELECT 1
|
||||||
|
FROM pg_partitioned_table pt
|
||||||
|
JOIN pg_class c ON c.oid = pt.partrelid
|
||||||
|
WHERE c.relname = 'usage_logs'
|
||||||
|
)
|
||||||
|
`
|
||||||
|
var partitioned bool
|
||||||
|
if err := scanSingleRow(ctx, r.sql, query, nil, &partitioned); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return partitioned, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) dropUsageLogsPartitions(ctx context.Context, cutoff time.Time) error {
|
||||||
|
rows, err := r.sql.QueryContext(ctx, `
|
||||||
|
SELECT c.relname
|
||||||
|
FROM pg_inherits
|
||||||
|
JOIN pg_class c ON c.oid = pg_inherits.inhrelid
|
||||||
|
JOIN pg_class p ON p.oid = pg_inherits.inhparent
|
||||||
|
WHERE p.relname = 'usage_logs'
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = rows.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
cutoffMonth := truncateToMonthUTC(cutoff)
|
||||||
|
for rows.Next() {
|
||||||
|
var name string
|
||||||
|
if err := rows.Scan(&name); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(name, "usage_logs_") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
suffix := strings.TrimPrefix(name, "usage_logs_")
|
||||||
|
month, err := time.Parse("200601", suffix)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
month = month.UTC()
|
||||||
|
if month.Before(cutoffMonth) {
|
||||||
|
if _, err := r.sql.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(name))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Context, month time.Time) error {
|
||||||
|
monthStart := truncateToMonthUTC(month)
|
||||||
|
nextMonth := monthStart.AddDate(0, 1, 0)
|
||||||
|
name := fmt.Sprintf("usage_logs_%s", monthStart.Format("200601"))
|
||||||
|
query := fmt.Sprintf(
|
||||||
|
"CREATE TABLE IF NOT EXISTS %s PARTITION OF usage_logs FOR VALUES FROM (%s) TO (%s)",
|
||||||
|
pq.QuoteIdentifier(name),
|
||||||
|
pq.QuoteLiteral(monthStart.Format("2006-01-02")),
|
||||||
|
pq.QuoteLiteral(nextMonth.Format("2006-01-02")),
|
||||||
|
)
|
||||||
|
_, err := r.sql.ExecContext(ctx, query)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateToDayUTC(t time.Time) time.Time {
|
||||||
|
t = t.UTC()
|
||||||
|
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateToMonthUTC(t time.Time) time.Time {
|
||||||
|
t = t.UTC()
|
||||||
|
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
}
|
||||||
58
backend/internal/repository/dashboard_cache.go
Normal file
58
backend/internal/repository/dashboard_cache.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const dashboardStatsCacheKey = "dashboard:stats:v1"
|
||||||
|
|
||||||
|
type dashboardCache struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
keyPrefix string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDashboardCache(rdb *redis.Client, cfg *config.Config) service.DashboardStatsCache {
|
||||||
|
prefix := "sub2api:"
|
||||||
|
if cfg != nil {
|
||||||
|
prefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
|
||||||
|
}
|
||||||
|
if prefix != "" && !strings.HasSuffix(prefix, ":") {
|
||||||
|
prefix += ":"
|
||||||
|
}
|
||||||
|
return &dashboardCache{
|
||||||
|
rdb: rdb,
|
||||||
|
keyPrefix: prefix,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dashboardCache) GetDashboardStats(ctx context.Context) (string, error) {
|
||||||
|
val, err := c.rdb.Get(ctx, c.buildKey()).Result()
|
||||||
|
if err != nil {
|
||||||
|
if err == redis.Nil {
|
||||||
|
return "", service.ErrDashboardStatsCacheMiss
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dashboardCache) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
|
||||||
|
return c.rdb.Set(ctx, c.buildKey(), data, ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dashboardCache) buildKey() string {
|
||||||
|
if c.keyPrefix == "" {
|
||||||
|
return dashboardStatsCacheKey
|
||||||
|
}
|
||||||
|
return c.keyPrefix + dashboardStatsCacheKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dashboardCache) DeleteDashboardStats(ctx context.Context) error {
|
||||||
|
return c.rdb.Del(ctx, c.buildKey()).Err()
|
||||||
|
}
|
||||||
28
backend/internal/repository/dashboard_cache_test.go
Normal file
28
backend/internal/repository/dashboard_cache_test.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewDashboardCacheKeyPrefix(t *testing.T) {
|
||||||
|
cache := NewDashboardCache(nil, &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{
|
||||||
|
KeyPrefix: "prod",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
impl, ok := cache.(*dashboardCache)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "prod:", impl.keyPrefix)
|
||||||
|
|
||||||
|
cache = NewDashboardCache(nil, &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{
|
||||||
|
KeyPrefix: "staging:",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
impl, ok = cache.(*dashboardCache)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "staging:", impl.keyPrefix)
|
||||||
|
}
|
||||||
@@ -269,16 +269,60 @@ func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
|
|||||||
type DashboardStats = usagestats.DashboardStats
|
type DashboardStats = usagestats.DashboardStats
|
||||||
|
|
||||||
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||||
var stats DashboardStats
|
stats := &DashboardStats{}
|
||||||
today := timezone.Today()
|
now := time.Now().UTC()
|
||||||
now := time.Now()
|
todayUTC := truncateToDayUTC(now)
|
||||||
|
|
||||||
// 合并用户统计查询
|
if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayUTC, now); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stats.Rpm = rpm
|
||||||
|
stats.Tpm = tpm
|
||||||
|
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageLogRepository) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*DashboardStats, error) {
|
||||||
|
startUTC := start.UTC()
|
||||||
|
endUTC := end.UTC()
|
||||||
|
if !endUTC.After(startUTC) {
|
||||||
|
return nil, errors.New("统计时间范围无效")
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := &DashboardStats{}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
todayUTC := truncateToDayUTC(now)
|
||||||
|
|
||||||
|
if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayUTC, now); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stats.Rpm = rpm
|
||||||
|
stats.Tpm = tpm
|
||||||
|
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageLogRepository) fillDashboardEntityStats(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error {
|
||||||
userStatsQuery := `
|
userStatsQuery := `
|
||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as total_users,
|
COUNT(*) as total_users,
|
||||||
COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users,
|
COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users
|
||||||
(SELECT COUNT(DISTINCT user_id) FROM usage_logs WHERE created_at >= $2) as active_users
|
|
||||||
FROM users
|
FROM users
|
||||||
WHERE deleted_at IS NULL
|
WHERE deleted_at IS NULL
|
||||||
`
|
`
|
||||||
@@ -286,15 +330,13 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
ctx,
|
ctx,
|
||||||
r.sql,
|
r.sql,
|
||||||
userStatsQuery,
|
userStatsQuery,
|
||||||
[]any{today, today},
|
[]any{todayUTC},
|
||||||
&stats.TotalUsers,
|
&stats.TotalUsers,
|
||||||
&stats.TodayNewUsers,
|
&stats.TodayNewUsers,
|
||||||
&stats.ActiveUsers,
|
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 合并API Key统计查询
|
|
||||||
apiKeyStatsQuery := `
|
apiKeyStatsQuery := `
|
||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as total_api_keys,
|
COUNT(*) as total_api_keys,
|
||||||
@@ -310,10 +352,9 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
&stats.TotalAPIKeys,
|
&stats.TotalAPIKeys,
|
||||||
&stats.ActiveAPIKeys,
|
&stats.ActiveAPIKeys,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 合并账户统计查询
|
|
||||||
accountStatsQuery := `
|
accountStatsQuery := `
|
||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as total_accounts,
|
COUNT(*) as total_accounts,
|
||||||
@@ -335,22 +376,26 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
&stats.RateLimitAccounts,
|
&stats.RateLimitAccounts,
|
||||||
&stats.OverloadAccounts,
|
&stats.OverloadAccounts,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 累计 Token 统计
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error {
|
||||||
totalStatsQuery := `
|
totalStatsQuery := `
|
||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as total_requests,
|
COALESCE(SUM(total_requests), 0) as total_requests,
|
||||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||||||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
COALESCE(SUM(total_duration_ms), 0) as total_duration_ms
|
||||||
FROM usage_logs
|
FROM usage_dashboard_daily
|
||||||
`
|
`
|
||||||
|
var totalDurationMs int64
|
||||||
if err := scanSingleRow(
|
if err := scanSingleRow(
|
||||||
ctx,
|
ctx,
|
||||||
r.sql,
|
r.sql,
|
||||||
@@ -363,13 +408,100 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
&stats.TotalCacheReadTokens,
|
&stats.TotalCacheReadTokens,
|
||||||
&stats.TotalCost,
|
&stats.TotalCost,
|
||||||
&stats.TotalActualCost,
|
&stats.TotalActualCost,
|
||||||
&stats.AverageDurationMs,
|
&totalDurationMs,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||||
|
if stats.TotalRequests > 0 {
|
||||||
|
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
|
||||||
|
}
|
||||||
|
|
||||||
// 今日 Token 统计
|
todayStatsQuery := `
|
||||||
|
SELECT
|
||||||
|
total_requests as today_requests,
|
||||||
|
input_tokens as today_input_tokens,
|
||||||
|
output_tokens as today_output_tokens,
|
||||||
|
cache_creation_tokens as today_cache_creation_tokens,
|
||||||
|
cache_read_tokens as today_cache_read_tokens,
|
||||||
|
total_cost as today_cost,
|
||||||
|
actual_cost as today_actual_cost,
|
||||||
|
active_users as active_users
|
||||||
|
FROM usage_dashboard_daily
|
||||||
|
WHERE bucket_date = $1::date
|
||||||
|
`
|
||||||
|
if err := scanSingleRow(
|
||||||
|
ctx,
|
||||||
|
r.sql,
|
||||||
|
todayStatsQuery,
|
||||||
|
[]any{todayUTC},
|
||||||
|
&stats.TodayRequests,
|
||||||
|
&stats.TodayInputTokens,
|
||||||
|
&stats.TodayOutputTokens,
|
||||||
|
&stats.TodayCacheCreationTokens,
|
||||||
|
&stats.TodayCacheReadTokens,
|
||||||
|
&stats.TodayCost,
|
||||||
|
&stats.TodayActualCost,
|
||||||
|
&stats.ActiveUsers,
|
||||||
|
); err != nil {
|
||||||
|
if err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||||
|
|
||||||
|
hourlyActiveQuery := `
|
||||||
|
SELECT active_users
|
||||||
|
FROM usage_dashboard_hourly
|
||||||
|
WHERE bucket_start = $1
|
||||||
|
`
|
||||||
|
hourStart := now.UTC().Truncate(time.Hour)
|
||||||
|
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil {
|
||||||
|
if err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error {
|
||||||
|
totalStatsQuery := `
|
||||||
|
SELECT
|
||||||
|
COUNT(*) as total_requests,
|
||||||
|
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||||
|
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||||
|
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||||||
|
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||||||
|
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||||
|
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||||
|
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms
|
||||||
|
FROM usage_logs
|
||||||
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
|
`
|
||||||
|
var totalDurationMs int64
|
||||||
|
if err := scanSingleRow(
|
||||||
|
ctx,
|
||||||
|
r.sql,
|
||||||
|
totalStatsQuery,
|
||||||
|
[]any{startUTC, endUTC},
|
||||||
|
&stats.TotalRequests,
|
||||||
|
&stats.TotalInputTokens,
|
||||||
|
&stats.TotalOutputTokens,
|
||||||
|
&stats.TotalCacheCreationTokens,
|
||||||
|
&stats.TotalCacheReadTokens,
|
||||||
|
&stats.TotalCost,
|
||||||
|
&stats.TotalActualCost,
|
||||||
|
&totalDurationMs,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||||
|
if stats.TotalRequests > 0 {
|
||||||
|
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
todayEnd := todayUTC.Add(24 * time.Hour)
|
||||||
todayStatsQuery := `
|
todayStatsQuery := `
|
||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as today_requests,
|
COUNT(*) as today_requests,
|
||||||
@@ -380,13 +512,13 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
COALESCE(SUM(total_cost), 0) as today_cost,
|
COALESCE(SUM(total_cost), 0) as today_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as today_actual_cost
|
COALESCE(SUM(actual_cost), 0) as today_actual_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE created_at >= $1
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
`
|
`
|
||||||
if err := scanSingleRow(
|
if err := scanSingleRow(
|
||||||
ctx,
|
ctx,
|
||||||
r.sql,
|
r.sql,
|
||||||
todayStatsQuery,
|
todayStatsQuery,
|
||||||
[]any{today},
|
[]any{todayUTC, todayEnd},
|
||||||
&stats.TodayRequests,
|
&stats.TodayRequests,
|
||||||
&stats.TodayInputTokens,
|
&stats.TodayInputTokens,
|
||||||
&stats.TodayOutputTokens,
|
&stats.TodayOutputTokens,
|
||||||
@@ -395,19 +527,31 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
|||||||
&stats.TodayCost,
|
&stats.TodayCost,
|
||||||
&stats.TodayActualCost,
|
&stats.TodayActualCost,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||||
|
|
||||||
// 性能指标:RPM 和 TPM(最近1分钟,全局)
|
activeUsersQuery := `
|
||||||
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
|
SELECT COUNT(DISTINCT user_id) as active_users
|
||||||
if err != nil {
|
FROM usage_logs
|
||||||
return nil, err
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
|
`
|
||||||
|
if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
stats.Rpm = rpm
|
|
||||||
stats.Tpm = tpm
|
|
||||||
|
|
||||||
return &stats, nil
|
hourStart := now.UTC().Truncate(time.Hour)
|
||||||
|
hourEnd := hourStart.Add(time.Hour)
|
||||||
|
hourlyActiveQuery := `
|
||||||
|
SELECT COUNT(DISTINCT user_id) as active_users
|
||||||
|
FROM usage_logs
|
||||||
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
|
`
|
||||||
|
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
@@ -198,8 +197,8 @@ func (s *UsageLogRepoSuite) TestListWithFilters() {
|
|||||||
// --- GetDashboardStats ---
|
// --- GetDashboardStats ---
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||||
now := time.Now()
|
now := time.Now().UTC()
|
||||||
todayStart := timezone.Today()
|
todayStart := truncateToDayUTC(now)
|
||||||
baseStats, err := s.repo.GetDashboardStats(s.ctx)
|
baseStats, err := s.repo.GetDashboardStats(s.ctx)
|
||||||
s.Require().NoError(err, "GetDashboardStats base")
|
s.Require().NoError(err, "GetDashboardStats base")
|
||||||
|
|
||||||
@@ -273,6 +272,11 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
|||||||
_, err = s.repo.Create(s.ctx, logPerf)
|
_, err = s.repo.Create(s.ctx, logPerf)
|
||||||
s.Require().NoError(err, "Create logPerf")
|
s.Require().NoError(err, "Create logPerf")
|
||||||
|
|
||||||
|
aggRepo := newDashboardAggregationRepositoryWithSQL(s.tx)
|
||||||
|
aggStart := todayStart.Add(-2 * time.Hour)
|
||||||
|
aggEnd := now.Add(2 * time.Minute)
|
||||||
|
s.Require().NoError(aggRepo.AggregateRange(s.ctx, aggStart, aggEnd), "AggregateRange")
|
||||||
|
|
||||||
stats, err := s.repo.GetDashboardStats(s.ctx)
|
stats, err := s.repo.GetDashboardStats(s.ctx)
|
||||||
s.Require().NoError(err, "GetDashboardStats")
|
s.Require().NoError(err, "GetDashboardStats")
|
||||||
|
|
||||||
@@ -303,6 +307,80 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
|||||||
s.Require().Equal(wantTpm, stats.Tpm, "Tpm mismatch")
|
s.Require().Equal(wantTpm, stats.Tpm, "Tpm mismatch")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestDashboardStatsWithRange_Fallback() {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
todayStart := truncateToDayUTC(now)
|
||||||
|
rangeStart := todayStart.Add(-24 * time.Hour)
|
||||||
|
rangeEnd := now.Add(1 * time.Second)
|
||||||
|
|
||||||
|
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "range-u1@test.com"})
|
||||||
|
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "range-u2@test.com"})
|
||||||
|
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-range-1", Name: "k1"})
|
||||||
|
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-range-2", Name: "k2"})
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-range"})
|
||||||
|
|
||||||
|
d1, d2, d3 := 100, 200, 300
|
||||||
|
logOutside := &service.UsageLog{
|
||||||
|
UserID: user1.ID,
|
||||||
|
APIKeyID: apiKey1.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 7,
|
||||||
|
OutputTokens: 8,
|
||||||
|
TotalCost: 0.8,
|
||||||
|
ActualCost: 0.7,
|
||||||
|
DurationMs: &d3,
|
||||||
|
CreatedAt: rangeStart.Add(-1 * time.Hour),
|
||||||
|
}
|
||||||
|
_, err := s.repo.Create(s.ctx, logOutside)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
logRange := &service.UsageLog{
|
||||||
|
UserID: user1.ID,
|
||||||
|
APIKeyID: apiKey1.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
CacheCreationTokens: 1,
|
||||||
|
CacheReadTokens: 2,
|
||||||
|
TotalCost: 1.0,
|
||||||
|
ActualCost: 0.9,
|
||||||
|
DurationMs: &d1,
|
||||||
|
CreatedAt: rangeStart.Add(2 * time.Hour),
|
||||||
|
}
|
||||||
|
_, err = s.repo.Create(s.ctx, logRange)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
logToday := &service.UsageLog{
|
||||||
|
UserID: user2.ID,
|
||||||
|
APIKeyID: apiKey2.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 5,
|
||||||
|
OutputTokens: 6,
|
||||||
|
CacheReadTokens: 1,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
DurationMs: &d2,
|
||||||
|
CreatedAt: now,
|
||||||
|
}
|
||||||
|
_, err = s.repo.Create(s.ctx, logToday)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
stats, err := s.repo.GetDashboardStatsWithRange(s.ctx, rangeStart, rangeEnd)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(int64(2), stats.TotalRequests)
|
||||||
|
s.Require().Equal(int64(15), stats.TotalInputTokens)
|
||||||
|
s.Require().Equal(int64(26), stats.TotalOutputTokens)
|
||||||
|
s.Require().Equal(int64(1), stats.TotalCacheCreationTokens)
|
||||||
|
s.Require().Equal(int64(3), stats.TotalCacheReadTokens)
|
||||||
|
s.Require().Equal(int64(45), stats.TotalTokens)
|
||||||
|
s.Require().Equal(1.5, stats.TotalCost)
|
||||||
|
s.Require().Equal(1.4, stats.TotalActualCost)
|
||||||
|
s.Require().InEpsilon(150.0, stats.AverageDurationMs, 0.0001)
|
||||||
|
}
|
||||||
|
|
||||||
// --- GetUserDashboardStats ---
|
// --- GetUserDashboardStats ---
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
||||||
@@ -333,6 +411,151 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
|||||||
s.Require().Equal(int64(30), stats.Tokens)
|
s.Require().Equal(int64(30), stats.Tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
|
||||||
|
now := time.Now().UTC().Truncate(time.Second)
|
||||||
|
hour1 := now.Add(-90 * time.Minute).Truncate(time.Hour)
|
||||||
|
hour2 := now.Add(-30 * time.Minute).Truncate(time.Hour)
|
||||||
|
dayStart := truncateToDayUTC(now)
|
||||||
|
|
||||||
|
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "agg-u1@test.com"})
|
||||||
|
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "agg-u2@test.com"})
|
||||||
|
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-agg-1", Name: "k1"})
|
||||||
|
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-agg-2", Name: "k2"})
|
||||||
|
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-agg"})
|
||||||
|
|
||||||
|
d1, d2, d3 := 100, 200, 150
|
||||||
|
log1 := &service.UsageLog{
|
||||||
|
UserID: user1.ID,
|
||||||
|
APIKeyID: apiKey1.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
CacheCreationTokens: 2,
|
||||||
|
CacheReadTokens: 1,
|
||||||
|
TotalCost: 1.0,
|
||||||
|
ActualCost: 0.9,
|
||||||
|
DurationMs: &d1,
|
||||||
|
CreatedAt: hour1.Add(5 * time.Minute),
|
||||||
|
}
|
||||||
|
_, err := s.repo.Create(s.ctx, log1)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
log2 := &service.UsageLog{
|
||||||
|
UserID: user1.ID,
|
||||||
|
APIKeyID: apiKey1.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 5,
|
||||||
|
OutputTokens: 5,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
DurationMs: &d2,
|
||||||
|
CreatedAt: hour1.Add(20 * time.Minute),
|
||||||
|
}
|
||||||
|
_, err = s.repo.Create(s.ctx, log2)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
log3 := &service.UsageLog{
|
||||||
|
UserID: user2.ID,
|
||||||
|
APIKeyID: apiKey2.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 7,
|
||||||
|
OutputTokens: 8,
|
||||||
|
TotalCost: 0.7,
|
||||||
|
ActualCost: 0.7,
|
||||||
|
DurationMs: &d3,
|
||||||
|
CreatedAt: hour2.Add(10 * time.Minute),
|
||||||
|
}
|
||||||
|
_, err = s.repo.Create(s.ctx, log3)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
aggRepo := newDashboardAggregationRepositoryWithSQL(s.tx)
|
||||||
|
aggStart := hour1.Add(-5 * time.Minute)
|
||||||
|
aggEnd := now.Add(5 * time.Minute)
|
||||||
|
s.Require().NoError(aggRepo.AggregateRange(s.ctx, aggStart, aggEnd))
|
||||||
|
|
||||||
|
type hourlyRow struct {
|
||||||
|
totalRequests int64
|
||||||
|
inputTokens int64
|
||||||
|
outputTokens int64
|
||||||
|
cacheCreationTokens int64
|
||||||
|
cacheReadTokens int64
|
||||||
|
totalCost float64
|
||||||
|
actualCost float64
|
||||||
|
totalDurationMs int64
|
||||||
|
activeUsers int64
|
||||||
|
}
|
||||||
|
fetchHourly := func(bucketStart time.Time) hourlyRow {
|
||||||
|
var row hourlyRow
|
||||||
|
err := scanSingleRow(s.ctx, s.tx, `
|
||||||
|
SELECT total_requests, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens,
|
||||||
|
total_cost, actual_cost, total_duration_ms, active_users
|
||||||
|
FROM usage_dashboard_hourly
|
||||||
|
WHERE bucket_start = $1
|
||||||
|
`, []any{bucketStart}, &row.totalRequests, &row.inputTokens, &row.outputTokens,
|
||||||
|
&row.cacheCreationTokens, &row.cacheReadTokens, &row.totalCost, &row.actualCost,
|
||||||
|
&row.totalDurationMs, &row.activeUsers,
|
||||||
|
)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
return row
|
||||||
|
}
|
||||||
|
|
||||||
|
hour1Row := fetchHourly(hour1)
|
||||||
|
s.Require().Equal(int64(2), hour1Row.totalRequests)
|
||||||
|
s.Require().Equal(int64(15), hour1Row.inputTokens)
|
||||||
|
s.Require().Equal(int64(25), hour1Row.outputTokens)
|
||||||
|
s.Require().Equal(int64(2), hour1Row.cacheCreationTokens)
|
||||||
|
s.Require().Equal(int64(1), hour1Row.cacheReadTokens)
|
||||||
|
s.Require().Equal(1.5, hour1Row.totalCost)
|
||||||
|
s.Require().Equal(1.4, hour1Row.actualCost)
|
||||||
|
s.Require().Equal(int64(300), hour1Row.totalDurationMs)
|
||||||
|
s.Require().Equal(int64(1), hour1Row.activeUsers)
|
||||||
|
|
||||||
|
hour2Row := fetchHourly(hour2)
|
||||||
|
s.Require().Equal(int64(1), hour2Row.totalRequests)
|
||||||
|
s.Require().Equal(int64(7), hour2Row.inputTokens)
|
||||||
|
s.Require().Equal(int64(8), hour2Row.outputTokens)
|
||||||
|
s.Require().Equal(int64(0), hour2Row.cacheCreationTokens)
|
||||||
|
s.Require().Equal(int64(0), hour2Row.cacheReadTokens)
|
||||||
|
s.Require().Equal(0.7, hour2Row.totalCost)
|
||||||
|
s.Require().Equal(0.7, hour2Row.actualCost)
|
||||||
|
s.Require().Equal(int64(150), hour2Row.totalDurationMs)
|
||||||
|
s.Require().Equal(int64(1), hour2Row.activeUsers)
|
||||||
|
|
||||||
|
var daily struct {
|
||||||
|
totalRequests int64
|
||||||
|
inputTokens int64
|
||||||
|
outputTokens int64
|
||||||
|
cacheCreationTokens int64
|
||||||
|
cacheReadTokens int64
|
||||||
|
totalCost float64
|
||||||
|
actualCost float64
|
||||||
|
totalDurationMs int64
|
||||||
|
activeUsers int64
|
||||||
|
}
|
||||||
|
err = scanSingleRow(s.ctx, s.tx, `
|
||||||
|
SELECT total_requests, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens,
|
||||||
|
total_cost, actual_cost, total_duration_ms, active_users
|
||||||
|
FROM usage_dashboard_daily
|
||||||
|
WHERE bucket_date = $1::date
|
||||||
|
`, []any{dayStart}, &daily.totalRequests, &daily.inputTokens, &daily.outputTokens,
|
||||||
|
&daily.cacheCreationTokens, &daily.cacheReadTokens, &daily.totalCost, &daily.actualCost,
|
||||||
|
&daily.totalDurationMs, &daily.activeUsers,
|
||||||
|
)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(int64(3), daily.totalRequests)
|
||||||
|
s.Require().Equal(int64(22), daily.inputTokens)
|
||||||
|
s.Require().Equal(int64(33), daily.outputTokens)
|
||||||
|
s.Require().Equal(int64(2), daily.cacheCreationTokens)
|
||||||
|
s.Require().Equal(int64(1), daily.cacheReadTokens)
|
||||||
|
s.Require().Equal(2.2, daily.totalCost)
|
||||||
|
s.Require().Equal(2.1, daily.actualCost)
|
||||||
|
s.Require().Equal(int64(450), daily.totalDurationMs)
|
||||||
|
s.Require().Equal(int64(2), daily.activeUsers)
|
||||||
|
}
|
||||||
|
|
||||||
// --- GetBatchUserUsageStats ---
|
// --- GetBatchUserUsageStats ---
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewRedeemCodeRepository,
|
NewRedeemCodeRepository,
|
||||||
NewPromoCodeRepository,
|
NewPromoCodeRepository,
|
||||||
NewUsageLogRepository,
|
NewUsageLogRepository,
|
||||||
|
NewDashboardAggregationRepository,
|
||||||
NewSettingRepository,
|
NewSettingRepository,
|
||||||
NewOpsRepository,
|
NewOpsRepository,
|
||||||
NewUserSubscriptionRepository,
|
NewUserSubscriptionRepository,
|
||||||
@@ -59,6 +60,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewAPIKeyCache,
|
NewAPIKeyCache,
|
||||||
NewTempUnschedCache,
|
NewTempUnschedCache,
|
||||||
ProvideConcurrencyCache,
|
ProvideConcurrencyCache,
|
||||||
|
NewDashboardCache,
|
||||||
NewEmailCache,
|
NewEmailCache,
|
||||||
NewIdentityCache,
|
NewIdentityCache,
|
||||||
NewRedeemCache,
|
NewRedeemCache,
|
||||||
|
|||||||
@@ -331,6 +331,30 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}`,
|
}`,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "POST /api/v1/admin/accounts/bulk-update",
|
||||||
|
method: http.MethodPost,
|
||||||
|
path: "/api/v1/admin/accounts/bulk-update",
|
||||||
|
body: `{"account_ids":[101,102],"schedulable":false}`,
|
||||||
|
headers: map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
wantJSON: `{
|
||||||
|
"code": 0,
|
||||||
|
"message": "success",
|
||||||
|
"data": {
|
||||||
|
"success": 2,
|
||||||
|
"failed": 0,
|
||||||
|
"success_ids": [101, 102],
|
||||||
|
"failed_ids": [],
|
||||||
|
"results": [
|
||||||
|
{"account_id": 101, "success": true},
|
||||||
|
{"account_id": 102, "success": true}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -382,6 +406,9 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
apiKeyCache := stubApiKeyCache{}
|
apiKeyCache := stubApiKeyCache{}
|
||||||
groupRepo := stubGroupRepo{}
|
groupRepo := stubGroupRepo{}
|
||||||
userSubRepo := stubUserSubscriptionRepo{}
|
userSubRepo := stubUserSubscriptionRepo{}
|
||||||
|
accountRepo := stubAccountRepo{}
|
||||||
|
proxyRepo := stubProxyRepo{}
|
||||||
|
redeemRepo := stubRedeemCodeRepo{}
|
||||||
|
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
Default: config.DefaultConfig{
|
Default: config.DefaultConfig{
|
||||||
@@ -390,19 +417,21 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
RunMode: config.RunModeStandard,
|
RunMode: config.RunModeStandard,
|
||||||
}
|
}
|
||||||
|
|
||||||
userService := service.NewUserService(userRepo)
|
userService := service.NewUserService(userRepo, nil)
|
||||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
|
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
|
||||||
|
|
||||||
usageRepo := newStubUsageLogRepo()
|
usageRepo := newStubUsageLogRepo()
|
||||||
usageService := service.NewUsageService(usageRepo, userRepo, nil)
|
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
|
||||||
|
|
||||||
settingRepo := newStubSettingRepo()
|
settingRepo := newStubSettingRepo()
|
||||||
settingService := service.NewSettingService(settingRepo, cfg)
|
settingService := service.NewSettingService(settingRepo, cfg)
|
||||||
|
|
||||||
|
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil)
|
||||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
|
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
|
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
|
||||||
|
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
jwtAuth := func(c *gin.Context) {
|
jwtAuth := func(c *gin.Context) {
|
||||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
||||||
@@ -442,6 +471,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
v1Admin := v1.Group("/admin")
|
v1Admin := v1.Group("/admin")
|
||||||
v1Admin.Use(adminAuth)
|
v1Admin.Use(adminAuth)
|
||||||
v1Admin.GET("/settings", adminSettingHandler.GetSettings)
|
v1Admin.GET("/settings", adminSettingHandler.GetSettings)
|
||||||
|
v1Admin.POST("/accounts/bulk-update", adminAccountHandler.BulkUpdate)
|
||||||
|
|
||||||
return &contractDeps{
|
return &contractDeps{
|
||||||
now: now,
|
now: now,
|
||||||
@@ -566,6 +596,18 @@ func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, t
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (stubApiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubApiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type stubGroupRepo struct{}
|
type stubGroupRepo struct{}
|
||||||
|
|
||||||
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
|
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
|
||||||
@@ -620,6 +662,235 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i
|
|||||||
return 0, errors.New("not implemented")
|
return 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type stubAccountRepo struct {
|
||||||
|
bulkUpdateIDs []int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) Create(ctx context.Context, account *service.Account) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
||||||
|
return nil, service.ErrAccountNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
||||||
|
return false, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) Delete(ctx context.Context, id int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListActive(ctx context.Context) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||||
|
return 0, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListSchedulable(ctx context.Context) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ClearRateLimit(ctx context.Context, id int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||||
|
s.bulkUpdateIDs = append([]int64{}, ids...)
|
||||||
|
return int64(len(ids)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubProxyRepo struct{}
|
||||||
|
|
||||||
|
func (stubProxyRepo) Create(ctx context.Context, proxy *service.Proxy) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) GetByID(ctx context.Context, id int64) (*service.Proxy, error) {
|
||||||
|
return nil, service.ErrProxyNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) Update(ctx context.Context, proxy *service.Proxy) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) Delete(ctx context.Context, id int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) ListActive(ctx context.Context) ([]service.Proxy, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||||
|
return false, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubProxyRepo) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||||
|
return 0, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubRedeemCodeRepo struct{}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) CreateBatch(ctx context.Context, codes []service.RedeemCode) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) GetByID(ctx context.Context, id int64) (*service.RedeemCode, error) {
|
||||||
|
return nil, service.ErrRedeemCodeNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
|
||||||
|
return nil, service.ErrRedeemCodeNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) Delete(ctx context.Context, id int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) Use(ctx context.Context, id, userID int64) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type stubUserSubscriptionRepo struct{}
|
type stubUserSubscriptionRepo struct{}
|
||||||
|
|
||||||
func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
|
func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
|
||||||
@@ -738,12 +1009,12 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey
|
|||||||
return &clone, nil
|
return &clone, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
key, ok := r.byID[id]
|
key, ok := r.byID[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, service.ErrAPIKeyNotFound
|
return "", 0, service.ErrAPIKeyNotFound
|
||||||
}
|
}
|
||||||
return key.UserID, nil
|
return key.Key, key.UserID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
@@ -755,6 +1026,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API
|
|||||||
return &clone, nil
|
return &clone, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
|
return r.GetByKey(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||||
if key == nil {
|
if key == nil {
|
||||||
return errors.New("nil key")
|
return errors.New("nil key")
|
||||||
@@ -869,6 +1144,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int
|
|||||||
return 0, errors.New("not implemented")
|
return 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type stubUsageLogRepo struct {
|
type stubUsageLogRepo struct {
|
||||||
userLogs map[int64][]service.UsageLog
|
userLogs map[int64][]service.UsageLog
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,8 +27,8 @@ func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
|
|||||||
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
func (f fakeAPIKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
return 0, errors.New("not implemented")
|
return "", 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
if f.getByKey == nil {
|
if f.getByKey == nil {
|
||||||
@@ -36,6 +36,9 @@ func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIK
|
|||||||
}
|
}
|
||||||
return f.getByKey(ctx, key)
|
return f.getByKey(ctx, key)
|
||||||
}
|
}
|
||||||
|
func (f fakeAPIKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
|
return f.GetByKey(ctx, key)
|
||||||
|
}
|
||||||
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
@@ -66,6 +69,12 @@ func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64
|
|||||||
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
return 0, errors.New("not implemented")
|
return 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type googleErrorResponse struct {
|
type googleErrorResponse struct {
|
||||||
Error struct {
|
Error struct {
|
||||||
|
|||||||
@@ -256,8 +256,8 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
return 0, errors.New("not implemented")
|
return "", 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
@@ -267,6 +267,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
|
return r.GetByKey(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
@@ -307,6 +311,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int
|
|||||||
return 0, errors.New("not implemented")
|
return 0, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
type stubUserSubscriptionRepo struct {
|
type stubUserSubscriptionRepo struct {
|
||||||
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
||||||
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
||||||
|
|||||||
@@ -130,6 +130,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
||||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
||||||
|
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -186,9 +186,11 @@ type BulkUpdateAccountResult struct {
|
|||||||
|
|
||||||
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
|
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
|
||||||
type BulkUpdateAccountsResult struct {
|
type BulkUpdateAccountsResult struct {
|
||||||
Success int `json:"success"`
|
Success int `json:"success"`
|
||||||
Failed int `json:"failed"`
|
Failed int `json:"failed"`
|
||||||
Results []BulkUpdateAccountResult `json:"results"`
|
SuccessIDs []int64 `json:"success_ids"`
|
||||||
|
FailedIDs []int64 `json:"failed_ids"`
|
||||||
|
Results []BulkUpdateAccountResult `json:"results"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateProxyInput struct {
|
type CreateProxyInput struct {
|
||||||
@@ -244,14 +246,15 @@ type ProxyExitInfoProber interface {
|
|||||||
|
|
||||||
// adminServiceImpl implements AdminService
|
// adminServiceImpl implements AdminService
|
||||||
type adminServiceImpl struct {
|
type adminServiceImpl struct {
|
||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
groupRepo GroupRepository
|
groupRepo GroupRepository
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
proxyRepo ProxyRepository
|
proxyRepo ProxyRepository
|
||||||
apiKeyRepo APIKeyRepository
|
apiKeyRepo APIKeyRepository
|
||||||
redeemCodeRepo RedeemCodeRepository
|
redeemCodeRepo RedeemCodeRepository
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
proxyProber ProxyExitInfoProber
|
proxyProber ProxyExitInfoProber
|
||||||
|
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAdminService creates a new AdminService
|
// NewAdminService creates a new AdminService
|
||||||
@@ -264,16 +267,18 @@ func NewAdminService(
|
|||||||
redeemCodeRepo RedeemCodeRepository,
|
redeemCodeRepo RedeemCodeRepository,
|
||||||
billingCacheService *BillingCacheService,
|
billingCacheService *BillingCacheService,
|
||||||
proxyProber ProxyExitInfoProber,
|
proxyProber ProxyExitInfoProber,
|
||||||
|
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
||||||
) AdminService {
|
) AdminService {
|
||||||
return &adminServiceImpl{
|
return &adminServiceImpl{
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
proxyRepo: proxyRepo,
|
proxyRepo: proxyRepo,
|
||||||
apiKeyRepo: apiKeyRepo,
|
apiKeyRepo: apiKeyRepo,
|
||||||
redeemCodeRepo: redeemCodeRepo,
|
redeemCodeRepo: redeemCodeRepo,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
proxyProber: proxyProber,
|
proxyProber: proxyProber,
|
||||||
|
authCacheInvalidator: authCacheInvalidator,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -323,6 +328,8 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
}
|
}
|
||||||
|
|
||||||
oldConcurrency := user.Concurrency
|
oldConcurrency := user.Concurrency
|
||||||
|
oldStatus := user.Status
|
||||||
|
oldRole := user.Role
|
||||||
|
|
||||||
if input.Email != "" {
|
if input.Email != "" {
|
||||||
user.Email = input.Email
|
user.Email = input.Email
|
||||||
@@ -355,6 +362,11 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
concurrencyDiff := user.Concurrency - oldConcurrency
|
concurrencyDiff := user.Concurrency - oldConcurrency
|
||||||
if concurrencyDiff != 0 {
|
if concurrencyDiff != 0 {
|
||||||
@@ -393,6 +405,9 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
|
|||||||
log.Printf("delete user failed: user_id=%d err=%v", id, err)
|
log.Printf("delete user failed: user_id=%d err=%v", id, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -420,6 +435,10 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
|||||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
balanceDiff := user.Balance - oldBalance
|
||||||
|
if s.authCacheInvalidator != nil && balanceDiff != 0 {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
if s.billingCacheService != nil {
|
if s.billingCacheService != nil {
|
||||||
go func() {
|
go func() {
|
||||||
@@ -431,7 +450,6 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
balanceDiff := user.Balance - oldBalance
|
|
||||||
if balanceDiff != 0 {
|
if balanceDiff != 0 {
|
||||||
code, err := GenerateRedeemCode()
|
code, err := GenerateRedeemCode()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -675,10 +693,21 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
|||||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||||
|
}
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||||
|
var groupKeys []string
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id)
|
||||||
|
if err == nil {
|
||||||
|
groupKeys = keys
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
|
affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -697,6 +726,11 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
for _, key := range groupKeys {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -885,7 +919,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
|||||||
// It merges credentials/extra keys instead of overwriting the whole object.
|
// It merges credentials/extra keys instead of overwriting the whole object.
|
||||||
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
|
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
|
||||||
result := &BulkUpdateAccountsResult{
|
result := &BulkUpdateAccountsResult{
|
||||||
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
|
SuccessIDs: make([]int64, 0, len(input.AccountIDs)),
|
||||||
|
FailedIDs: make([]int64, 0, len(input.AccountIDs)),
|
||||||
|
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(input.AccountIDs) == 0 {
|
if len(input.AccountIDs) == 0 {
|
||||||
@@ -949,6 +985,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
|||||||
entry.Success = false
|
entry.Success = false
|
||||||
entry.Error = err.Error()
|
entry.Error = err.Error()
|
||||||
result.Failed++
|
result.Failed++
|
||||||
|
result.FailedIDs = append(result.FailedIDs, accountID)
|
||||||
result.Results = append(result.Results, entry)
|
result.Results = append(result.Results, entry)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -958,6 +995,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
|||||||
entry.Success = false
|
entry.Success = false
|
||||||
entry.Error = err.Error()
|
entry.Error = err.Error()
|
||||||
result.Failed++
|
result.Failed++
|
||||||
|
result.FailedIDs = append(result.FailedIDs, accountID)
|
||||||
result.Results = append(result.Results, entry)
|
result.Results = append(result.Results, entry)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -967,6 +1005,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
|||||||
entry.Success = false
|
entry.Success = false
|
||||||
entry.Error = err.Error()
|
entry.Error = err.Error()
|
||||||
result.Failed++
|
result.Failed++
|
||||||
|
result.FailedIDs = append(result.FailedIDs, accountID)
|
||||||
result.Results = append(result.Results, entry)
|
result.Results = append(result.Results, entry)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -974,6 +1013,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
|||||||
|
|
||||||
entry.Success = true
|
entry.Success = true
|
||||||
result.Success++
|
result.Success++
|
||||||
|
result.SuccessIDs = append(result.SuccessIDs, accountID)
|
||||||
result.Results = append(result.Results, entry)
|
result.Results = append(result.Results, entry)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
80
backend/internal/service/admin_service_bulk_update_test.go
Normal file
80
backend/internal/service/admin_service_bulk_update_test.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type accountRepoStubForBulkUpdate struct {
|
||||||
|
accountRepoStub
|
||||||
|
bulkUpdateErr error
|
||||||
|
bulkUpdateIDs []int64
|
||||||
|
bindGroupErrByID map[int64]error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
|
||||||
|
s.bulkUpdateIDs = append([]int64{}, ids...)
|
||||||
|
if s.bulkUpdateErr != nil {
|
||||||
|
return 0, s.bulkUpdateErr
|
||||||
|
}
|
||||||
|
return int64(len(ids)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error {
|
||||||
|
if err, ok := s.bindGroupErrByID[accountID]; ok {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
||||||
|
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
||||||
|
repo := &accountRepoStubForBulkUpdate{}
|
||||||
|
svc := &adminServiceImpl{accountRepo: repo}
|
||||||
|
|
||||||
|
schedulable := true
|
||||||
|
input := &BulkUpdateAccountsInput{
|
||||||
|
AccountIDs: []int64{1, 2, 3},
|
||||||
|
Schedulable: &schedulable,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 3, result.Success)
|
||||||
|
require.Equal(t, 0, result.Failed)
|
||||||
|
require.ElementsMatch(t, []int64{1, 2, 3}, result.SuccessIDs)
|
||||||
|
require.Empty(t, result.FailedIDs)
|
||||||
|
require.Len(t, result.Results, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAdminService_BulkUpdateAccounts_PartialFailureIDs 验证部分失败时 success_ids/failed_ids 正确。
|
||||||
|
func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
|
||||||
|
repo := &accountRepoStubForBulkUpdate{
|
||||||
|
bindGroupErrByID: map[int64]error{
|
||||||
|
2: errors.New("bind failed"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &adminServiceImpl{accountRepo: repo}
|
||||||
|
|
||||||
|
groupIDs := []int64{10}
|
||||||
|
schedulable := false
|
||||||
|
input := &BulkUpdateAccountsInput{
|
||||||
|
AccountIDs: []int64{1, 2, 3},
|
||||||
|
GroupIDs: &groupIDs,
|
||||||
|
Schedulable: &schedulable,
|
||||||
|
SkipMixedChannelCheck: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 2, result.Success)
|
||||||
|
require.Equal(t, 1, result.Failed)
|
||||||
|
require.ElementsMatch(t, []int64{1, 3}, result.SuccessIDs)
|
||||||
|
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
|
||||||
|
require.Len(t, result.Results, 3)
|
||||||
|
}
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type balanceUserRepoStub struct {
|
||||||
|
*userRepoStub
|
||||||
|
updateErr error
|
||||||
|
updated []*User
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *balanceUserRepoStub) Update(ctx context.Context, user *User) error {
|
||||||
|
if s.updateErr != nil {
|
||||||
|
return s.updateErr
|
||||||
|
}
|
||||||
|
if user == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
clone := *user
|
||||||
|
s.updated = append(s.updated, &clone)
|
||||||
|
if s.userRepoStub != nil {
|
||||||
|
s.userRepoStub.user = &clone
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type balanceRedeemRepoStub struct {
|
||||||
|
*redeemRepoStub
|
||||||
|
created []*RedeemCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *balanceRedeemRepoStub) Create(ctx context.Context, code *RedeemCode) error {
|
||||||
|
if code == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
clone := *code
|
||||||
|
s.created = append(s.created, &clone)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type authCacheInvalidatorStub struct {
|
||||||
|
userIDs []int64
|
||||||
|
groupIDs []int64
|
||||||
|
keys []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByKey(ctx context.Context, key string) {
|
||||||
|
s.keys = append(s.keys, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
|
||||||
|
s.userIDs = append(s.userIDs, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
|
||||||
|
s.groupIDs = append(s.groupIDs, groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateUserBalance_InvalidatesAuthCache(t *testing.T) {
|
||||||
|
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
|
||||||
|
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
|
||||||
|
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
|
||||||
|
invalidator := &authCacheInvalidatorStub{}
|
||||||
|
svc := &adminServiceImpl{
|
||||||
|
userRepo: repo,
|
||||||
|
redeemCodeRepo: redeemRepo,
|
||||||
|
authCacheInvalidator: invalidator,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.UpdateUserBalance(context.Background(), 7, 5, "add", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []int64{7}, invalidator.userIDs)
|
||||||
|
require.Len(t, redeemRepo.created, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminService_UpdateUserBalance_NoChangeNoInvalidate(t *testing.T) {
|
||||||
|
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
|
||||||
|
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
|
||||||
|
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
|
||||||
|
invalidator := &authCacheInvalidatorStub{}
|
||||||
|
svc := &adminServiceImpl{
|
||||||
|
userRepo: repo,
|
||||||
|
redeemCodeRepo: redeemRepo,
|
||||||
|
authCacheInvalidator: invalidator,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.UpdateUserBalance(context.Background(), 7, 10, "set", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, invalidator.userIDs)
|
||||||
|
require.Empty(t, redeemRepo.created)
|
||||||
|
}
|
||||||
46
backend/internal/service/api_key_auth_cache.go
Normal file
46
backend/internal/service/api_key_auth_cache.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
|
||||||
|
type APIKeyAuthSnapshot struct {
|
||||||
|
APIKeyID int64 `json:"api_key_id"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
GroupID *int64 `json:"group_id,omitempty"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
IPWhitelist []string `json:"ip_whitelist,omitempty"`
|
||||||
|
IPBlacklist []string `json:"ip_blacklist,omitempty"`
|
||||||
|
User APIKeyAuthUserSnapshot `json:"user"`
|
||||||
|
Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIKeyAuthUserSnapshot 用户快照
|
||||||
|
type APIKeyAuthUserSnapshot struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
Balance float64 `json:"balance"`
|
||||||
|
Concurrency int `json:"concurrency"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIKeyAuthGroupSnapshot 分组快照
|
||||||
|
type APIKeyAuthGroupSnapshot struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Platform string `json:"platform"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
SubscriptionType string `json:"subscription_type"`
|
||||||
|
RateMultiplier float64 `json:"rate_multiplier"`
|
||||||
|
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
||||||
|
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
||||||
|
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
||||||
|
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||||
|
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||||
|
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||||
|
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||||
|
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
|
||||||
|
type APIKeyAuthCacheEntry struct {
|
||||||
|
NotFound bool `json:"not_found"`
|
||||||
|
Snapshot *APIKeyAuthSnapshot `json:"snapshot,omitempty"`
|
||||||
|
}
|
||||||
269
backend/internal/service/api_key_auth_cache_impl.go
Normal file
269
backend/internal/service/api_key_auth_cache_impl.go
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/dgraph-io/ristretto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type apiKeyAuthCacheConfig struct {
|
||||||
|
l1Size int
|
||||||
|
l1TTL time.Duration
|
||||||
|
l2TTL time.Duration
|
||||||
|
negativeTTL time.Duration
|
||||||
|
jitterPercent int
|
||||||
|
singleflight bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
jitterRandMu sync.Mutex
|
||||||
|
// 认证缓存抖动使用独立随机源,避免全局 Seed
|
||||||
|
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
|
||||||
|
if cfg == nil {
|
||||||
|
return apiKeyAuthCacheConfig{}
|
||||||
|
}
|
||||||
|
auth := cfg.APIKeyAuth
|
||||||
|
return apiKeyAuthCacheConfig{
|
||||||
|
l1Size: auth.L1Size,
|
||||||
|
l1TTL: time.Duration(auth.L1TTLSeconds) * time.Second,
|
||||||
|
l2TTL: time.Duration(auth.L2TTLSeconds) * time.Second,
|
||||||
|
negativeTTL: time.Duration(auth.NegativeTTLSeconds) * time.Second,
|
||||||
|
jitterPercent: auth.JitterPercent,
|
||||||
|
singleflight: auth.Singleflight,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c apiKeyAuthCacheConfig) l1Enabled() bool {
|
||||||
|
return c.l1Size > 0 && c.l1TTL > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c apiKeyAuthCacheConfig) l2Enabled() bool {
|
||||||
|
return c.l2TTL > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
|
||||||
|
return c.negativeTTL > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
|
||||||
|
if ttl <= 0 {
|
||||||
|
return ttl
|
||||||
|
}
|
||||||
|
if c.jitterPercent <= 0 {
|
||||||
|
return ttl
|
||||||
|
}
|
||||||
|
percent := c.jitterPercent
|
||||||
|
if percent > 100 {
|
||||||
|
percent = 100
|
||||||
|
}
|
||||||
|
delta := float64(percent) / 100
|
||||||
|
jitterRandMu.Lock()
|
||||||
|
randVal := jitterRand.Float64()
|
||||||
|
jitterRandMu.Unlock()
|
||||||
|
factor := 1 - delta + randVal*(2*delta)
|
||||||
|
if factor <= 0 {
|
||||||
|
return ttl
|
||||||
|
}
|
||||||
|
return time.Duration(float64(ttl) * factor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) initAuthCache(cfg *config.Config) {
|
||||||
|
s.authCfg = newAPIKeyAuthCacheConfig(cfg)
|
||||||
|
if !s.authCfg.l1Enabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cache, err := ristretto.NewCache(&ristretto.Config{
|
||||||
|
NumCounters: int64(s.authCfg.l1Size) * 10,
|
||||||
|
MaxCost: int64(s.authCfg.l1Size),
|
||||||
|
BufferItems: 64,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.authCacheL1 = cache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) authCacheKey(key string) string {
|
||||||
|
sum := sha256.Sum256([]byte(key))
|
||||||
|
return hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) {
|
||||||
|
if s.authCacheL1 != nil {
|
||||||
|
if val, ok := s.authCacheL1.Get(cacheKey); ok {
|
||||||
|
if entry, ok := val.(*APIKeyAuthCacheEntry); ok {
|
||||||
|
return entry, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.cache == nil || !s.authCfg.l2Enabled() {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
entry, err := s.cache.GetAuthCache(ctx, cacheKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
s.setAuthCacheL1(cacheKey, entry)
|
||||||
|
return entry, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) {
|
||||||
|
if s.authCacheL1 == nil || entry == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ttl := s.authCfg.l1TTL
|
||||||
|
if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl {
|
||||||
|
ttl = s.authCfg.negativeTTL
|
||||||
|
}
|
||||||
|
ttl = s.authCfg.jitterTTL(ttl)
|
||||||
|
_ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) {
|
||||||
|
if entry == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.setAuthCacheL1(cacheKey, entry)
|
||||||
|
if s.cache == nil || !s.authCfg.l2Enabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
|
||||||
|
if s.authCacheL1 != nil {
|
||||||
|
s.authCacheL1.Del(cacheKey)
|
||||||
|
}
|
||||||
|
if s.cache == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = s.cache.DeleteAuthCache(ctx, cacheKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrAPIKeyNotFound) {
|
||||||
|
entry := &APIKeyAuthCacheEntry{NotFound: true}
|
||||||
|
if s.authCfg.negativeEnabled() {
|
||||||
|
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL)
|
||||||
|
}
|
||||||
|
return entry, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("get api key: %w", err)
|
||||||
|
}
|
||||||
|
apiKey.Key = key
|
||||||
|
snapshot := s.snapshotFromAPIKey(apiKey)
|
||||||
|
if snapshot == nil {
|
||||||
|
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
|
||||||
|
}
|
||||||
|
entry := &APIKeyAuthCacheEntry{Snapshot: snapshot}
|
||||||
|
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL)
|
||||||
|
return entry, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) {
|
||||||
|
if entry == nil {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
if entry.NotFound {
|
||||||
|
return nil, true, ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
if entry.Snapshot == nil {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||||
|
if apiKey == nil || apiKey.User == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
snapshot := &APIKeyAuthSnapshot{
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
UserID: apiKey.UserID,
|
||||||
|
GroupID: apiKey.GroupID,
|
||||||
|
Status: apiKey.Status,
|
||||||
|
IPWhitelist: apiKey.IPWhitelist,
|
||||||
|
IPBlacklist: apiKey.IPBlacklist,
|
||||||
|
User: APIKeyAuthUserSnapshot{
|
||||||
|
ID: apiKey.User.ID,
|
||||||
|
Status: apiKey.User.Status,
|
||||||
|
Role: apiKey.User.Role,
|
||||||
|
Balance: apiKey.User.Balance,
|
||||||
|
Concurrency: apiKey.User.Concurrency,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if apiKey.Group != nil {
|
||||||
|
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
||||||
|
ID: apiKey.Group.ID,
|
||||||
|
Name: apiKey.Group.Name,
|
||||||
|
Platform: apiKey.Group.Platform,
|
||||||
|
Status: apiKey.Group.Status,
|
||||||
|
SubscriptionType: apiKey.Group.SubscriptionType,
|
||||||
|
RateMultiplier: apiKey.Group.RateMultiplier,
|
||||||
|
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||||
|
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||||
|
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||||
|
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||||
|
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||||
|
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||||
|
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||||
|
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return snapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey {
|
||||||
|
if snapshot == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
apiKey := &APIKey{
|
||||||
|
ID: snapshot.APIKeyID,
|
||||||
|
UserID: snapshot.UserID,
|
||||||
|
GroupID: snapshot.GroupID,
|
||||||
|
Key: key,
|
||||||
|
Status: snapshot.Status,
|
||||||
|
IPWhitelist: snapshot.IPWhitelist,
|
||||||
|
IPBlacklist: snapshot.IPBlacklist,
|
||||||
|
User: &User{
|
||||||
|
ID: snapshot.User.ID,
|
||||||
|
Status: snapshot.User.Status,
|
||||||
|
Role: snapshot.User.Role,
|
||||||
|
Balance: snapshot.User.Balance,
|
||||||
|
Concurrency: snapshot.User.Concurrency,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if snapshot.Group != nil {
|
||||||
|
apiKey.Group = &Group{
|
||||||
|
ID: snapshot.Group.ID,
|
||||||
|
Name: snapshot.Group.Name,
|
||||||
|
Platform: snapshot.Group.Platform,
|
||||||
|
Status: snapshot.Group.Status,
|
||||||
|
Hydrated: true,
|
||||||
|
SubscriptionType: snapshot.Group.SubscriptionType,
|
||||||
|
RateMultiplier: snapshot.Group.RateMultiplier,
|
||||||
|
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||||
|
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||||
|
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||||
|
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||||
|
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||||
|
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||||
|
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||||
|
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return apiKey
|
||||||
|
}
|
||||||
48
backend/internal/service/api_key_auth_cache_invalidate.go
Normal file
48
backend/internal/service/api_key_auth_cache_invalidate.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// InvalidateAuthCacheByKey 清除指定 API Key 的认证缓存
|
||||||
|
func (s *APIKeyService) InvalidateAuthCacheByKey(ctx context.Context, key string) {
|
||||||
|
if key == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cacheKey := s.authCacheKey(key)
|
||||||
|
s.deleteAuthCache(ctx, cacheKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateAuthCacheByUserID 清除用户相关的 API Key 认证缓存
|
||||||
|
func (s *APIKeyService) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
|
||||||
|
if userID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
keys, err := s.apiKeyRepo.ListKeysByUserID(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.deleteAuthCacheByKeys(ctx, keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateAuthCacheByGroupID 清除分组相关的 API Key 认证缓存
|
||||||
|
func (s *APIKeyService) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
|
||||||
|
if groupID <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, groupID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.deleteAuthCacheByKeys(ctx, keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyService) deleteAuthCacheByKeys(ctx context.Context, keys []string) {
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, key := range keys {
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.deleteAuthCache(ctx, s.authCacheKey(key))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,6 +12,8 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
|
"github.com/dgraph-io/ristretto"
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -31,9 +33,11 @@ const (
|
|||||||
type APIKeyRepository interface {
|
type APIKeyRepository interface {
|
||||||
Create(ctx context.Context, key *APIKey) error
|
Create(ctx context.Context, key *APIKey) error
|
||||||
GetByID(ctx context.Context, id int64) (*APIKey, error)
|
GetByID(ctx context.Context, id int64) (*APIKey, error)
|
||||||
// GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证
|
// GetKeyAndOwnerID 仅获取 API Key 的 key 与所有者 ID,用于删除等轻量场景
|
||||||
GetOwnerID(ctx context.Context, id int64) (int64, error)
|
GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error)
|
||||||
GetByKey(ctx context.Context, key string) (*APIKey, error)
|
GetByKey(ctx context.Context, key string) (*APIKey, error)
|
||||||
|
// GetByKeyForAuth 认证专用查询,返回最小字段集
|
||||||
|
GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error)
|
||||||
Update(ctx context.Context, key *APIKey) error
|
Update(ctx context.Context, key *APIKey) error
|
||||||
Delete(ctx context.Context, id int64) error
|
Delete(ctx context.Context, id int64) error
|
||||||
|
|
||||||
@@ -45,6 +49,8 @@ type APIKeyRepository interface {
|
|||||||
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
|
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
|
||||||
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||||
|
ListKeysByUserID(ctx context.Context, userID int64) ([]string, error)
|
||||||
|
ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeyCache defines cache operations for API key service
|
// APIKeyCache defines cache operations for API key service
|
||||||
@@ -55,6 +61,17 @@ type APIKeyCache interface {
|
|||||||
|
|
||||||
IncrementDailyUsage(ctx context.Context, apiKey string) error
|
IncrementDailyUsage(ctx context.Context, apiKey string) error
|
||||||
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
||||||
|
|
||||||
|
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||||
|
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
|
||||||
|
DeleteAuthCache(ctx context.Context, key string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
|
||||||
|
type APIKeyAuthCacheInvalidator interface {
|
||||||
|
InvalidateAuthCacheByKey(ctx context.Context, key string)
|
||||||
|
InvalidateAuthCacheByUserID(ctx context.Context, userID int64)
|
||||||
|
InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAPIKeyRequest 创建API Key请求
|
// CreateAPIKeyRequest 创建API Key请求
|
||||||
@@ -83,6 +100,9 @@ type APIKeyService struct {
|
|||||||
userSubRepo UserSubscriptionRepository
|
userSubRepo UserSubscriptionRepository
|
||||||
cache APIKeyCache
|
cache APIKeyCache
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
|
authCacheL1 *ristretto.Cache
|
||||||
|
authCfg apiKeyAuthCacheConfig
|
||||||
|
authGroup singleflight.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAPIKeyService 创建API Key服务实例
|
// NewAPIKeyService 创建API Key服务实例
|
||||||
@@ -94,7 +114,7 @@ func NewAPIKeyService(
|
|||||||
cache APIKeyCache,
|
cache APIKeyCache,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
) *APIKeyService {
|
) *APIKeyService {
|
||||||
return &APIKeyService{
|
svc := &APIKeyService{
|
||||||
apiKeyRepo: apiKeyRepo,
|
apiKeyRepo: apiKeyRepo,
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
@@ -102,6 +122,8 @@ func NewAPIKeyService(
|
|||||||
cache: cache,
|
cache: cache,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
}
|
}
|
||||||
|
svc.initAuthCache(cfg)
|
||||||
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateKey 生成随机API Key
|
// GenerateKey 生成随机API Key
|
||||||
@@ -269,6 +291,8 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
|||||||
return nil, fmt.Errorf("create api key: %w", err)
|
return nil, fmt.Errorf("create api key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||||
|
|
||||||
return apiKey, nil
|
return apiKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -304,21 +328,49 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
|
|||||||
|
|
||||||
// GetByKey 根据Key字符串获取API Key(用于认证)
|
// GetByKey 根据Key字符串获取API Key(用于认证)
|
||||||
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||||
// 尝试从Redis缓存获取
|
cacheKey := s.authCacheKey(key)
|
||||||
cacheKey := fmt.Sprintf("apikey:%s", key)
|
|
||||||
|
|
||||||
// 这里可以添加Redis缓存逻辑,暂时直接查询数据库
|
if entry, ok := s.getAuthCacheEntry(ctx, cacheKey); ok {
|
||||||
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
|
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get api key: %w", err)
|
||||||
|
}
|
||||||
|
return apiKey, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.authCfg.singleflight {
|
||||||
|
value, err, _ := s.authGroup.Do(cacheKey, func() (any, error) {
|
||||||
|
return s.loadAuthCacheEntry(ctx, key, cacheKey)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
entry, _ := value.(*APIKeyAuthCacheEntry)
|
||||||
|
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get api key: %w", err)
|
||||||
|
}
|
||||||
|
return apiKey, nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
entry, err := s.loadAuthCacheEntry(ctx, key, cacheKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get api key: %w", err)
|
||||||
|
}
|
||||||
|
return apiKey, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get api key: %w", err)
|
return nil, fmt.Errorf("get api key: %w", err)
|
||||||
}
|
}
|
||||||
|
apiKey.Key = key
|
||||||
// 缓存到Redis(可选,TTL设置为5分钟)
|
|
||||||
if s.cache != nil {
|
|
||||||
// 这里可以序列化并缓存API Key
|
|
||||||
_ = cacheKey // 使用变量避免未使用错误
|
|
||||||
}
|
|
||||||
|
|
||||||
return apiKey, nil
|
return apiKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -388,15 +440,14 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
|||||||
return nil, fmt.Errorf("update api key: %w", err)
|
return nil, fmt.Errorf("update api key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||||
|
|
||||||
return apiKey, nil
|
return apiKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete 删除API Key
|
// Delete 删除API Key
|
||||||
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
|
|
||||||
// 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能
|
|
||||||
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||||||
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
|
key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id)
|
||||||
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get api key: %w", err)
|
return fmt.Errorf("get api key: %w", err)
|
||||||
}
|
}
|
||||||
@@ -406,10 +457,11 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro
|
|||||||
return ErrInsufficientPerms
|
return ErrInsufficientPerms
|
||||||
}
|
}
|
||||||
|
|
||||||
// 清除Redis缓存(使用 ownerID 而非 apiKey.UserID)
|
// 清除Redis缓存(使用 userID 而非 apiKey.UserID)
|
||||||
if s.cache != nil {
|
if s.cache != nil {
|
||||||
_ = s.cache.DeleteCreateAttemptCount(ctx, ownerID)
|
_ = s.cache.DeleteCreateAttemptCount(ctx, userID)
|
||||||
}
|
}
|
||||||
|
s.InvalidateAuthCacheByKey(ctx, key)
|
||||||
|
|
||||||
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
||||||
return fmt.Errorf("delete api key: %w", err)
|
return fmt.Errorf("delete api key: %w", err)
|
||||||
|
|||||||
417
backend/internal/service/api_key_service_cache_test.go
Normal file
417
backend/internal/service/api_key_service_cache_test.go
Normal file
@@ -0,0 +1,417 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type authRepoStub struct {
|
||||||
|
getByKeyForAuth func(ctx context.Context, key string) (*APIKey, error)
|
||||||
|
listKeysByUserID func(ctx context.Context, userID int64) ([]string, error)
|
||||||
|
listKeysByGroupID func(ctx context.Context, groupID int64) ([]string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) Create(ctx context.Context, key *APIKey) error {
|
||||||
|
panic("unexpected Create call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||||
|
panic("unexpected GetByID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
|
panic("unexpected GetKeyAndOwnerID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
panic("unexpected GetByKey call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
if s.getByKeyForAuth == nil {
|
||||||
|
panic("unexpected GetByKeyForAuth call")
|
||||||
|
}
|
||||||
|
return s.getByKeyForAuth(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) Update(ctx context.Context, key *APIKey) error {
|
||||||
|
panic("unexpected Update call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
|
||||||
|
panic("unexpected Delete call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected ListByUserID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||||
|
panic("unexpected VerifyOwnership call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||||
|
panic("unexpected CountByUserID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||||
|
panic("unexpected ExistsByKey call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||||
|
panic("unexpected ListByGroupID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||||
|
panic("unexpected SearchAPIKeys call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
|
panic("unexpected ClearGroupIDByGroupID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
|
panic("unexpected CountByGroupID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
if s.listKeysByUserID == nil {
|
||||||
|
panic("unexpected ListKeysByUserID call")
|
||||||
|
}
|
||||||
|
return s.listKeysByUserID(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
if s.listKeysByGroupID == nil {
|
||||||
|
panic("unexpected ListKeysByGroupID call")
|
||||||
|
}
|
||||||
|
return s.listKeysByGroupID(ctx, groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
type authCacheStub struct {
|
||||||
|
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||||
|
setAuthKeys []string
|
||||||
|
deleteAuthKeys []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
if s.getAuthCache == nil {
|
||||||
|
return nil, redis.Nil
|
||||||
|
}
|
||||||
|
return s.getAuthCache(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||||
|
s.setAuthKeys = append(s.setAuthKeys, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
|
||||||
|
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
return nil, errors.New("unexpected repo call")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
NegativeTTLSeconds: 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
|
||||||
|
groupID := int64(9)
|
||||||
|
cacheEntry := &APIKeyAuthCacheEntry{
|
||||||
|
Snapshot: &APIKeyAuthSnapshot{
|
||||||
|
APIKeyID: 1,
|
||||||
|
UserID: 2,
|
||||||
|
GroupID: &groupID,
|
||||||
|
Status: StatusActive,
|
||||||
|
User: APIKeyAuthUserSnapshot{
|
||||||
|
ID: 2,
|
||||||
|
Status: StatusActive,
|
||||||
|
Role: RoleUser,
|
||||||
|
Balance: 10,
|
||||||
|
Concurrency: 3,
|
||||||
|
},
|
||||||
|
Group: &APIKeyAuthGroupSnapshot{
|
||||||
|
ID: groupID,
|
||||||
|
Name: "g",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Status: StatusActive,
|
||||||
|
SubscriptionType: SubscriptionTypeStandard,
|
||||||
|
RateMultiplier: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return cacheEntry, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey, err := svc.GetByKey(context.Background(), "k1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(1), apiKey.ID)
|
||||||
|
require.Equal(t, int64(2), apiKey.User.ID)
|
||||||
|
require.Equal(t, groupID, apiKey.Group.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
return nil, errors.New("unexpected repo call")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
NegativeTTLSeconds: 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return &APIKeyAuthCacheEntry{NotFound: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.GetByKey(context.Background(), "missing")
|
||||||
|
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
return &APIKey{
|
||||||
|
ID: 5,
|
||||||
|
UserID: 7,
|
||||||
|
Status: StatusActive,
|
||||||
|
User: &User{
|
||||||
|
ID: 7,
|
||||||
|
Status: StatusActive,
|
||||||
|
Role: RoleUser,
|
||||||
|
Balance: 12,
|
||||||
|
Concurrency: 2,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
NegativeTTLSeconds: 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return nil, redis.Nil
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey, err := svc.GetByKey(context.Background(), "k2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(5), apiKey.ID)
|
||||||
|
require.Len(t, cache.setAuthKeys, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
|
||||||
|
var calls int32
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
atomic.AddInt32(&calls, 1)
|
||||||
|
return &APIKey{
|
||||||
|
ID: 21,
|
||||||
|
UserID: 3,
|
||||||
|
Status: StatusActive,
|
||||||
|
User: &User{
|
||||||
|
ID: 3,
|
||||||
|
Status: StatusActive,
|
||||||
|
Role: RoleUser,
|
||||||
|
Balance: 5,
|
||||||
|
Concurrency: 2,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L1Size: 1000,
|
||||||
|
L1TTLSeconds: 60,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
require.NotNil(t, svc.authCacheL1)
|
||||||
|
|
||||||
|
_, err := svc.GetByKey(context.Background(), "k-l1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
svc.authCacheL1.Wait()
|
||||||
|
cacheKey := svc.authCacheKey("k-l1")
|
||||||
|
_, ok := svc.authCacheL1.Get(cacheKey)
|
||||||
|
require.True(t, ok)
|
||||||
|
_, err = svc.GetByKey(context.Background(), "k-l1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
return []string{"k1", "k2"}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
NegativeTTLSeconds: 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
|
||||||
|
svc.InvalidateAuthCacheByUserID(context.Background(), 7)
|
||||||
|
require.Len(t, cache.deleteAuthKeys, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
listKeysByGroupID: func(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
return []string{"k1", "k2"}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
|
||||||
|
svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
|
||||||
|
require.Len(t, cache.deleteAuthKeys, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
|
||||||
|
svc.InvalidateAuthCacheByKey(context.Background(), "k1")
|
||||||
|
require.Len(t, cache.deleteAuthKeys, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
return nil, ErrAPIKeyNotFound
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
L2TTLSeconds: 60,
|
||||||
|
NegativeTTLSeconds: 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return nil, redis.Nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.GetByKey(context.Background(), "missing")
|
||||||
|
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||||
|
require.Len(t, cache.setAuthKeys, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
|
||||||
|
var calls int32
|
||||||
|
cache := &authCacheStub{}
|
||||||
|
repo := &authRepoStub{
|
||||||
|
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
atomic.AddInt32(&calls, 1)
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
return &APIKey{
|
||||||
|
ID: 11,
|
||||||
|
UserID: 2,
|
||||||
|
Status: StatusActive,
|
||||||
|
User: &User{
|
||||||
|
ID: 2,
|
||||||
|
Status: StatusActive,
|
||||||
|
Role: RoleUser,
|
||||||
|
Balance: 1,
|
||||||
|
Concurrency: 1,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||||
|
Singleflight: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||||
|
|
||||||
|
start := make(chan struct{})
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
errs := make([]error, 5)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
<-start
|
||||||
|
_, err := svc.GetByKey(context.Background(), "k1")
|
||||||
|
errs[idx] = err
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
close(start)
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
for _, err := range errs {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
|
||||||
|
}
|
||||||
@@ -20,13 +20,12 @@ import (
|
|||||||
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
|
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
|
||||||
//
|
//
|
||||||
// 设计说明:
|
// 设计说明:
|
||||||
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
|
// - apiKey/getByIDErr: 模拟 GetKeyAndOwnerID 返回的记录与错误
|
||||||
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound)
|
|
||||||
// - deleteErr: 模拟 Delete 返回的错误
|
// - deleteErr: 模拟 Delete 返回的错误
|
||||||
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
|
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
|
||||||
type apiKeyRepoStub struct {
|
type apiKeyRepoStub struct {
|
||||||
ownerID int64 // GetOwnerID 的返回值
|
apiKey *APIKey // GetKeyAndOwnerID 的返回值
|
||||||
ownerErr error // GetOwnerID 的错误返回值
|
getByIDErr error // GetKeyAndOwnerID 的错误返回值
|
||||||
deleteErr error // Delete 的错误返回值
|
deleteErr error // Delete 的错误返回值
|
||||||
deletedIDs []int64 // 记录已删除的 API Key ID 列表
|
deletedIDs []int64 // 记录已删除的 API Key ID 列表
|
||||||
}
|
}
|
||||||
@@ -38,19 +37,34 @@ func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||||
|
if s.getByIDErr != nil {
|
||||||
|
return nil, s.getByIDErr
|
||||||
|
}
|
||||||
|
if s.apiKey != nil {
|
||||||
|
clone := *s.apiKey
|
||||||
|
return &clone, nil
|
||||||
|
}
|
||||||
panic("unexpected GetByID call")
|
panic("unexpected GetByID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOwnerID 返回预设的所有者 ID 或错误。
|
func (s *apiKeyRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||||
// 这是 Delete 方法调用的第一个仓储方法,用于验证调用者是否为 API Key 的所有者。
|
if s.getByIDErr != nil {
|
||||||
func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
return "", 0, s.getByIDErr
|
||||||
return s.ownerID, s.ownerErr
|
}
|
||||||
|
if s.apiKey != nil {
|
||||||
|
return s.apiKey.Key, s.apiKey.UserID, nil
|
||||||
|
}
|
||||||
|
return "", 0, ErrAPIKeyNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||||
panic("unexpected GetByKey call")
|
panic("unexpected GetByKey call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
|
||||||
|
panic("unexpected GetByKeyForAuth call")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
|
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
|
||||||
panic("unexpected Update call")
|
panic("unexpected Update call")
|
||||||
}
|
}
|
||||||
@@ -96,13 +110,22 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
|
|||||||
panic("unexpected CountByGroupID call")
|
panic("unexpected CountByGroupID call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
panic("unexpected ListKeysByUserID call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||||
|
panic("unexpected ListKeysByGroupID call")
|
||||||
|
}
|
||||||
|
|
||||||
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
||||||
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
||||||
//
|
//
|
||||||
// 设计说明:
|
// 设计说明:
|
||||||
// - invalidated: 记录被清除缓存的用户 ID 列表
|
// - invalidated: 记录被清除缓存的用户 ID 列表
|
||||||
type apiKeyCacheStub struct {
|
type apiKeyCacheStub struct {
|
||||||
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
|
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
|
||||||
|
deleteAuthKeys []string // 记录调用 DeleteAuthCache 时传入的缓存 key
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCreateAttemptCount 返回 0,表示用户未超过创建次数限制
|
// GetCreateAttemptCount 返回 0,表示用户未超过创建次数限制
|
||||||
@@ -132,15 +155,30 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
|
||||||
|
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
|
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
|
||||||
// 预期行为:
|
// 预期行为:
|
||||||
// - GetOwnerID 返回所有者 ID 为 1
|
// - GetKeyAndOwnerID 返回所有者 ID 为 1
|
||||||
// - 调用者 userID 为 2(不匹配)
|
// - 调用者 userID 为 2(不匹配)
|
||||||
// - 返回 ErrInsufficientPerms 错误
|
// - 返回 ErrInsufficientPerms 错误
|
||||||
// - Delete 方法不被调用
|
// - Delete 方法不被调用
|
||||||
// - 缓存不被清除
|
// - 缓存不被清除
|
||||||
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
||||||
repo := &apiKeyRepoStub{ownerID: 1}
|
repo := &apiKeyRepoStub{
|
||||||
|
apiKey: &APIKey{ID: 10, UserID: 1, Key: "k"},
|
||||||
|
}
|
||||||
cache := &apiKeyCacheStub{}
|
cache := &apiKeyCacheStub{}
|
||||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||||
|
|
||||||
@@ -148,17 +186,20 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
|||||||
require.ErrorIs(t, err, ErrInsufficientPerms)
|
require.ErrorIs(t, err, ErrInsufficientPerms)
|
||||||
require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用
|
require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用
|
||||||
require.Empty(t, cache.invalidated) // 验证缓存未被清除
|
require.Empty(t, cache.invalidated) // 验证缓存未被清除
|
||||||
|
require.Empty(t, cache.deleteAuthKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
|
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
|
||||||
// 预期行为:
|
// 预期行为:
|
||||||
// - GetOwnerID 返回所有者 ID 为 7
|
// - GetKeyAndOwnerID 返回所有者 ID 为 7
|
||||||
// - 调用者 userID 为 7(匹配)
|
// - 调用者 userID 为 7(匹配)
|
||||||
// - Delete 成功执行
|
// - Delete 成功执行
|
||||||
// - 缓存被正确清除(使用 ownerID)
|
// - 缓存被正确清除(使用 ownerID)
|
||||||
// - 返回 nil 错误
|
// - 返回 nil 错误
|
||||||
func TestApiKeyService_Delete_Success(t *testing.T) {
|
func TestApiKeyService_Delete_Success(t *testing.T) {
|
||||||
repo := &apiKeyRepoStub{ownerID: 7}
|
repo := &apiKeyRepoStub{
|
||||||
|
apiKey: &APIKey{ID: 42, UserID: 7, Key: "k"},
|
||||||
|
}
|
||||||
cache := &apiKeyCacheStub{}
|
cache := &apiKeyCacheStub{}
|
||||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||||
|
|
||||||
@@ -166,16 +207,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除
|
require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除
|
||||||
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
|
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
|
||||||
|
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
|
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
|
||||||
// 预期行为:
|
// 预期行为:
|
||||||
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
|
// - GetKeyAndOwnerID 返回 ErrAPIKeyNotFound 错误
|
||||||
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
|
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
|
||||||
// - Delete 方法不被调用
|
// - Delete 方法不被调用
|
||||||
// - 缓存不被清除
|
// - 缓存不被清除
|
||||||
func TestApiKeyService_Delete_NotFound(t *testing.T) {
|
func TestApiKeyService_Delete_NotFound(t *testing.T) {
|
||||||
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
|
repo := &apiKeyRepoStub{getByIDErr: ErrAPIKeyNotFound}
|
||||||
cache := &apiKeyCacheStub{}
|
cache := &apiKeyCacheStub{}
|
||||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||||
|
|
||||||
@@ -183,18 +225,19 @@ func TestApiKeyService_Delete_NotFound(t *testing.T) {
|
|||||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||||
require.Empty(t, repo.deletedIDs)
|
require.Empty(t, repo.deletedIDs)
|
||||||
require.Empty(t, cache.invalidated)
|
require.Empty(t, cache.invalidated)
|
||||||
|
require.Empty(t, cache.deleteAuthKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
|
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
|
||||||
// 预期行为:
|
// 预期行为:
|
||||||
// - GetOwnerID 返回正确的所有者 ID
|
// - GetKeyAndOwnerID 返回正确的所有者 ID
|
||||||
// - 所有权验证通过
|
// - 所有权验证通过
|
||||||
// - 缓存被清除(在删除之前)
|
// - 缓存被清除(在删除之前)
|
||||||
// - Delete 被调用但返回错误
|
// - Delete 被调用但返回错误
|
||||||
// - 返回包含 "delete api key" 的错误信息
|
// - 返回包含 "delete api key" 的错误信息
|
||||||
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
|
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
|
||||||
repo := &apiKeyRepoStub{
|
repo := &apiKeyRepoStub{
|
||||||
ownerID: 3,
|
apiKey: &APIKey{ID: 42, UserID: 3, Key: "k"},
|
||||||
deleteErr: errors.New("delete failed"),
|
deleteErr: errors.New("delete failed"),
|
||||||
}
|
}
|
||||||
cache := &apiKeyCacheStub{}
|
cache := &apiKeyCacheStub{}
|
||||||
@@ -205,4 +248,5 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
|
|||||||
require.ErrorContains(t, err, "delete api key")
|
require.ErrorContains(t, err, "delete api key")
|
||||||
require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用
|
require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用
|
||||||
require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败)
|
require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败)
|
||||||
|
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
|
||||||
}
|
}
|
||||||
|
|||||||
33
backend/internal/service/auth_cache_invalidation_test.go
Normal file
33
backend/internal/service/auth_cache_invalidation_test.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUsageService_InvalidateUsageCaches(t *testing.T) {
|
||||||
|
invalidator := &authCacheInvalidatorStub{}
|
||||||
|
svc := &UsageService{authCacheInvalidator: invalidator}
|
||||||
|
|
||||||
|
svc.invalidateUsageCaches(context.Background(), 7, false)
|
||||||
|
require.Empty(t, invalidator.userIDs)
|
||||||
|
|
||||||
|
svc.invalidateUsageCaches(context.Background(), 7, true)
|
||||||
|
require.Equal(t, []int64{7}, invalidator.userIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedeemService_InvalidateRedeemCaches_AuthCache(t *testing.T) {
|
||||||
|
invalidator := &authCacheInvalidatorStub{}
|
||||||
|
svc := &RedeemService{authCacheInvalidator: invalidator}
|
||||||
|
|
||||||
|
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeBalance})
|
||||||
|
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeConcurrency})
|
||||||
|
groupID := int64(3)
|
||||||
|
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeSubscription, GroupID: &groupID})
|
||||||
|
|
||||||
|
require.Equal(t, []int64{11, 11, 11}, invalidator.userIDs)
|
||||||
|
}
|
||||||
242
backend/internal/service/dashboard_aggregation_service.go
Normal file
242
backend/internal/service/dashboard_aggregation_service.go
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"log"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultDashboardAggregationTimeout = 2 * time.Minute
|
||||||
|
defaultDashboardAggregationBackfillTimeout = 30 * time.Minute
|
||||||
|
dashboardAggregationRetentionInterval = 6 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrDashboardBackfillDisabled 当配置禁用回填时返回。
|
||||||
|
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
|
||||||
|
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
|
||||||
|
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
|
||||||
|
)
|
||||||
|
|
||||||
|
// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。
|
||||||
|
type DashboardAggregationRepository interface {
|
||||||
|
AggregateRange(ctx context.Context, start, end time.Time) error
|
||||||
|
GetAggregationWatermark(ctx context.Context) (time.Time, error)
|
||||||
|
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
|
||||||
|
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
|
||||||
|
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
|
||||||
|
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DashboardAggregationService 负责定时聚合与回填。
|
||||||
|
type DashboardAggregationService struct {
|
||||||
|
repo DashboardAggregationRepository
|
||||||
|
timingWheel *TimingWheelService
|
||||||
|
cfg config.DashboardAggregationConfig
|
||||||
|
running int32
|
||||||
|
lastRetentionCleanup atomic.Value // time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDashboardAggregationService 创建聚合服务。
|
||||||
|
func NewDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
|
||||||
|
var aggCfg config.DashboardAggregationConfig
|
||||||
|
if cfg != nil {
|
||||||
|
aggCfg = cfg.DashboardAgg
|
||||||
|
}
|
||||||
|
return &DashboardAggregationService{
|
||||||
|
repo: repo,
|
||||||
|
timingWheel: timingWheel,
|
||||||
|
cfg: aggCfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start 启动定时聚合作业(重启生效配置)。
|
||||||
|
func (s *DashboardAggregationService) Start() {
|
||||||
|
if s == nil || s.repo == nil || s.timingWheel == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !s.cfg.Enabled {
|
||||||
|
log.Printf("[DashboardAggregation] 聚合作业已禁用")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
interval := time.Duration(s.cfg.IntervalSeconds) * time.Second
|
||||||
|
if interval <= 0 {
|
||||||
|
interval = time.Minute
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.cfg.RecomputeDays > 0 {
|
||||||
|
go s.recomputeRecentDays()
|
||||||
|
}
|
||||||
|
|
||||||
|
s.timingWheel.ScheduleRecurring("dashboard:aggregation", interval, func() {
|
||||||
|
s.runScheduledAggregation()
|
||||||
|
})
|
||||||
|
log.Printf("[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds)
|
||||||
|
if !s.cfg.BackfillEnabled {
|
||||||
|
log.Printf("[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TriggerBackfill 触发回填(异步)。
|
||||||
|
func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) error {
|
||||||
|
if s == nil || s.repo == nil {
|
||||||
|
return errors.New("聚合服务未初始化")
|
||||||
|
}
|
||||||
|
if !s.cfg.BackfillEnabled {
|
||||||
|
log.Printf("[DashboardAggregation] 回填被拒绝: backfill_enabled=false")
|
||||||
|
return ErrDashboardBackfillDisabled
|
||||||
|
}
|
||||||
|
if !end.After(start) {
|
||||||
|
return errors.New("回填时间范围无效")
|
||||||
|
}
|
||||||
|
if s.cfg.BackfillMaxDays > 0 {
|
||||||
|
maxRange := time.Duration(s.cfg.BackfillMaxDays) * 24 * time.Hour
|
||||||
|
if end.Sub(start) > maxRange {
|
||||||
|
return ErrDashboardBackfillTooLarge
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
|
||||||
|
defer cancel()
|
||||||
|
if err := s.backfillRange(ctx, start, end); err != nil {
|
||||||
|
log.Printf("[DashboardAggregation] 回填失败: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardAggregationService) recomputeRecentDays() {
|
||||||
|
days := s.cfg.RecomputeDays
|
||||||
|
if days <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
start := now.AddDate(0, 0, -days)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
|
||||||
|
defer cancel()
|
||||||
|
if err := s.backfillRange(ctx, start, now); err != nil {
|
||||||
|
log.Printf("[DashboardAggregation] 启动重算失败: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardAggregationService) runScheduledAggregation() {
|
||||||
|
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer atomic.StoreInt32(&s.running, 0)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
last, err := s.repo.GetAggregationWatermark(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[DashboardAggregation] 读取水位失败: %v", err)
|
||||||
|
last = time.Unix(0, 0).UTC()
|
||||||
|
}
|
||||||
|
|
||||||
|
lookback := time.Duration(s.cfg.LookbackSeconds) * time.Second
|
||||||
|
epoch := time.Unix(0, 0).UTC()
|
||||||
|
start := last.Add(-lookback)
|
||||||
|
if !last.After(epoch) {
|
||||||
|
retentionDays := s.cfg.Retention.UsageLogsDays
|
||||||
|
if retentionDays <= 0 {
|
||||||
|
retentionDays = 1
|
||||||
|
}
|
||||||
|
start = truncateToDayUTC(now.AddDate(0, 0, -retentionDays))
|
||||||
|
} else if start.After(now) {
|
||||||
|
start = now.Add(-lookback)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.aggregateRange(ctx, start, now); err != nil {
|
||||||
|
log.Printf("[DashboardAggregation] 聚合失败: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.repo.UpdateAggregationWatermark(ctx, now); err != nil {
|
||||||
|
log.Printf("[DashboardAggregation] 更新水位失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.maybeCleanupRetention(ctx, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error {
|
||||||
|
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||||
|
return errors.New("聚合作业正在运行")
|
||||||
|
}
|
||||||
|
defer atomic.StoreInt32(&s.running, 0)
|
||||||
|
|
||||||
|
startUTC := start.UTC()
|
||||||
|
endUTC := end.UTC()
|
||||||
|
if !endUTC.After(startUTC) {
|
||||||
|
return errors.New("回填时间范围无效")
|
||||||
|
}
|
||||||
|
|
||||||
|
cursor := truncateToDayUTC(startUTC)
|
||||||
|
for cursor.Before(endUTC) {
|
||||||
|
windowEnd := cursor.Add(24 * time.Hour)
|
||||||
|
if windowEnd.After(endUTC) {
|
||||||
|
windowEnd = endUTC
|
||||||
|
}
|
||||||
|
if err := s.aggregateRange(ctx, cursor, windowEnd); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cursor = windowEnd
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.repo.UpdateAggregationWatermark(ctx, endUTC); err != nil {
|
||||||
|
log.Printf("[DashboardAggregation] 更新水位失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.maybeCleanupRetention(ctx, endUTC)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start, end time.Time) error {
|
||||||
|
if !end.After(start) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := s.repo.EnsureUsageLogsPartitions(ctx, end); err != nil {
|
||||||
|
log.Printf("[DashboardAggregation] 分区检查失败: %v", err)
|
||||||
|
}
|
||||||
|
return s.repo.AggregateRange(ctx, start, end)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, now time.Time) {
|
||||||
|
lastAny := s.lastRetentionCleanup.Load()
|
||||||
|
if lastAny != nil {
|
||||||
|
if last, ok := lastAny.(time.Time); ok && now.Sub(last) < dashboardAggregationRetentionInterval {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
|
||||||
|
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
|
||||||
|
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
|
||||||
|
|
||||||
|
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
|
||||||
|
if aggErr != nil {
|
||||||
|
log.Printf("[DashboardAggregation] 聚合保留清理失败: %v", aggErr)
|
||||||
|
}
|
||||||
|
usageErr := s.repo.CleanupUsageLogs(ctx, usageCutoff)
|
||||||
|
if usageErr != nil {
|
||||||
|
log.Printf("[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
|
||||||
|
}
|
||||||
|
if aggErr == nil && usageErr == nil {
|
||||||
|
s.lastRetentionCleanup.Store(now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateToDayUTC(t time.Time) time.Time {
|
||||||
|
t = t.UTC()
|
||||||
|
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||||
|
}
|
||||||
106
backend/internal/service/dashboard_aggregation_service_test.go
Normal file
106
backend/internal/service/dashboard_aggregation_service_test.go
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dashboardAggregationRepoTestStub struct {
|
||||||
|
aggregateCalls int
|
||||||
|
lastStart time.Time
|
||||||
|
lastEnd time.Time
|
||||||
|
watermark time.Time
|
||||||
|
aggregateErr error
|
||||||
|
cleanupAggregatesErr error
|
||||||
|
cleanupUsageErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||||
|
s.aggregateCalls++
|
||||||
|
s.lastStart = start
|
||||||
|
s.lastEnd = end
|
||||||
|
return s.aggregateErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||||
|
return s.watermark, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoTestStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||||
|
return s.cleanupAggregatesErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||||
|
return s.cleanupUsageErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
|
||||||
|
repo := &dashboardAggregationRepoTestStub{watermark: time.Unix(0, 0).UTC()}
|
||||||
|
svc := &DashboardAggregationService{
|
||||||
|
repo: repo,
|
||||||
|
cfg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
IntervalSeconds: 60,
|
||||||
|
LookbackSeconds: 120,
|
||||||
|
Retention: config.DashboardAggregationRetentionConfig{
|
||||||
|
UsageLogsDays: 1,
|
||||||
|
HourlyDays: 1,
|
||||||
|
DailyDays: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc.runScheduledAggregation()
|
||||||
|
|
||||||
|
require.Equal(t, 1, repo.aggregateCalls)
|
||||||
|
require.False(t, repo.lastEnd.IsZero())
|
||||||
|
require.Equal(t, truncateToDayUTC(repo.lastEnd.AddDate(0, 0, -1)), repo.lastStart)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *testing.T) {
|
||||||
|
repo := &dashboardAggregationRepoTestStub{cleanupAggregatesErr: errors.New("清理失败")}
|
||||||
|
svc := &DashboardAggregationService{
|
||||||
|
repo: repo,
|
||||||
|
cfg: config.DashboardAggregationConfig{
|
||||||
|
Retention: config.DashboardAggregationRetentionConfig{
|
||||||
|
UsageLogsDays: 1,
|
||||||
|
HourlyDays: 1,
|
||||||
|
DailyDays: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||||
|
|
||||||
|
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
|
||||||
|
repo := &dashboardAggregationRepoTestStub{}
|
||||||
|
svc := &DashboardAggregationService{
|
||||||
|
repo: repo,
|
||||||
|
cfg: config.DashboardAggregationConfig{
|
||||||
|
BackfillEnabled: true,
|
||||||
|
BackfillMaxDays: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now().AddDate(0, 0, -3)
|
||||||
|
end := time.Now()
|
||||||
|
err := svc.TriggerBackfill(start, end)
|
||||||
|
require.ErrorIs(t, err, ErrDashboardBackfillTooLarge)
|
||||||
|
require.Equal(t, 0, repo.aggregateCalls)
|
||||||
|
}
|
||||||
@@ -2,25 +2,119 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DashboardService provides aggregated statistics for admin dashboard.
|
const (
|
||||||
type DashboardService struct {
|
defaultDashboardStatsFreshTTL = 15 * time.Second
|
||||||
usageRepo UsageLogRepository
|
defaultDashboardStatsCacheTTL = 30 * time.Second
|
||||||
|
defaultDashboardStatsRefreshTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrDashboardStatsCacheMiss 标记仪表盘缓存未命中。
|
||||||
|
var ErrDashboardStatsCacheMiss = errors.New("仪表盘缓存未命中")
|
||||||
|
|
||||||
|
// DashboardStatsCache 定义仪表盘统计缓存接口。
|
||||||
|
type DashboardStatsCache interface {
|
||||||
|
GetDashboardStats(ctx context.Context) (string, error)
|
||||||
|
SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error
|
||||||
|
DeleteDashboardStats(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDashboardService(usageRepo UsageLogRepository) *DashboardService {
|
type dashboardStatsRangeFetcher interface {
|
||||||
|
GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type dashboardStatsCacheEntry struct {
|
||||||
|
Stats *usagestats.DashboardStats `json:"stats"`
|
||||||
|
UpdatedAt int64 `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DashboardService 提供管理员仪表盘统计服务。
|
||||||
|
type DashboardService struct {
|
||||||
|
usageRepo UsageLogRepository
|
||||||
|
aggRepo DashboardAggregationRepository
|
||||||
|
cache DashboardStatsCache
|
||||||
|
cacheFreshTTL time.Duration
|
||||||
|
cacheTTL time.Duration
|
||||||
|
refreshTimeout time.Duration
|
||||||
|
refreshing int32
|
||||||
|
aggEnabled bool
|
||||||
|
aggInterval time.Duration
|
||||||
|
aggLookback time.Duration
|
||||||
|
aggUsageDays int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDashboardService(usageRepo UsageLogRepository, aggRepo DashboardAggregationRepository, cache DashboardStatsCache, cfg *config.Config) *DashboardService {
|
||||||
|
freshTTL := defaultDashboardStatsFreshTTL
|
||||||
|
cacheTTL := defaultDashboardStatsCacheTTL
|
||||||
|
refreshTimeout := defaultDashboardStatsRefreshTimeout
|
||||||
|
aggEnabled := true
|
||||||
|
aggInterval := time.Minute
|
||||||
|
aggLookback := 2 * time.Minute
|
||||||
|
aggUsageDays := 90
|
||||||
|
if cfg != nil {
|
||||||
|
if !cfg.Dashboard.Enabled {
|
||||||
|
cache = nil
|
||||||
|
}
|
||||||
|
if cfg.Dashboard.StatsFreshTTLSeconds > 0 {
|
||||||
|
freshTTL = time.Duration(cfg.Dashboard.StatsFreshTTLSeconds) * time.Second
|
||||||
|
}
|
||||||
|
if cfg.Dashboard.StatsTTLSeconds > 0 {
|
||||||
|
cacheTTL = time.Duration(cfg.Dashboard.StatsTTLSeconds) * time.Second
|
||||||
|
}
|
||||||
|
if cfg.Dashboard.StatsRefreshTimeoutSeconds > 0 {
|
||||||
|
refreshTimeout = time.Duration(cfg.Dashboard.StatsRefreshTimeoutSeconds) * time.Second
|
||||||
|
}
|
||||||
|
aggEnabled = cfg.DashboardAgg.Enabled
|
||||||
|
if cfg.DashboardAgg.IntervalSeconds > 0 {
|
||||||
|
aggInterval = time.Duration(cfg.DashboardAgg.IntervalSeconds) * time.Second
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.LookbackSeconds > 0 {
|
||||||
|
aggLookback = time.Duration(cfg.DashboardAgg.LookbackSeconds) * time.Second
|
||||||
|
}
|
||||||
|
if cfg.DashboardAgg.Retention.UsageLogsDays > 0 {
|
||||||
|
aggUsageDays = cfg.DashboardAgg.Retention.UsageLogsDays
|
||||||
|
}
|
||||||
|
}
|
||||||
return &DashboardService{
|
return &DashboardService{
|
||||||
usageRepo: usageRepo,
|
usageRepo: usageRepo,
|
||||||
|
aggRepo: aggRepo,
|
||||||
|
cache: cache,
|
||||||
|
cacheFreshTTL: freshTTL,
|
||||||
|
cacheTTL: cacheTTL,
|
||||||
|
refreshTimeout: refreshTimeout,
|
||||||
|
aggEnabled: aggEnabled,
|
||||||
|
aggInterval: aggInterval,
|
||||||
|
aggLookback: aggLookback,
|
||||||
|
aggUsageDays: aggUsageDays,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||||
stats, err := s.usageRepo.GetDashboardStats(ctx)
|
if s.cache != nil {
|
||||||
|
cached, fresh, err := s.getCachedDashboardStats(ctx)
|
||||||
|
if err == nil && cached != nil {
|
||||||
|
s.refreshAggregationStaleness(cached)
|
||||||
|
if !fresh {
|
||||||
|
s.refreshDashboardStatsAsync()
|
||||||
|
}
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
if err != nil && !errors.Is(err, ErrDashboardStatsCacheMiss) {
|
||||||
|
log.Printf("[Dashboard] 仪表盘缓存读取失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := s.refreshDashboardStats(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get dashboard stats: %w", err)
|
return nil, fmt.Errorf("get dashboard stats: %w", err)
|
||||||
}
|
}
|
||||||
@@ -43,6 +137,169 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
|
|||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
|
||||||
|
data, err := s.cache.GetDashboardStats(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var entry dashboardStatsCacheEntry
|
||||||
|
if err := json.Unmarshal([]byte(data), &entry); err != nil {
|
||||||
|
s.evictDashboardStatsCache(err)
|
||||||
|
return nil, false, ErrDashboardStatsCacheMiss
|
||||||
|
}
|
||||||
|
if entry.Stats == nil {
|
||||||
|
s.evictDashboardStatsCache(errors.New("仪表盘缓存缺少统计数据"))
|
||||||
|
return nil, false, ErrDashboardStatsCacheMiss
|
||||||
|
}
|
||||||
|
|
||||||
|
age := time.Since(time.Unix(entry.UpdatedAt, 0))
|
||||||
|
return entry.Stats, age <= s.cacheFreshTTL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) refreshDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||||
|
stats, err := s.fetchDashboardStats(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.applyAggregationStatus(ctx, stats)
|
||||||
|
cacheCtx, cancel := s.cacheOperationContext()
|
||||||
|
defer cancel()
|
||||||
|
s.saveDashboardStatsCache(cacheCtx, stats)
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) refreshDashboardStatsAsync() {
|
||||||
|
if s.cache == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !atomic.CompareAndSwapInt32(&s.refreshing, 0, 1) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer atomic.StoreInt32(&s.refreshing, 0)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), s.refreshTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
stats, err := s.fetchDashboardStats(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Dashboard] 仪表盘缓存异步刷新失败: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.applyAggregationStatus(ctx, stats)
|
||||||
|
cacheCtx, cancel := s.cacheOperationContext()
|
||||||
|
defer cancel()
|
||||||
|
s.saveDashboardStatsCache(cacheCtx, stats)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) fetchDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||||
|
if !s.aggEnabled {
|
||||||
|
if fetcher, ok := s.usageRepo.(dashboardStatsRangeFetcher); ok {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
start := truncateToDayUTC(now.AddDate(0, 0, -s.aggUsageDays))
|
||||||
|
return fetcher.GetDashboardStatsWithRange(ctx, start, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.usageRepo.GetDashboardStats(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *usagestats.DashboardStats) {
|
||||||
|
if s.cache == nil || stats == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := dashboardStatsCacheEntry{
|
||||||
|
Stats: stats,
|
||||||
|
UpdatedAt: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(entry)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Dashboard] 仪表盘缓存序列化失败: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.cache.SetDashboardStats(ctx, string(data), s.cacheTTL); err != nil {
|
||||||
|
log.Printf("[Dashboard] 仪表盘缓存写入失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) evictDashboardStatsCache(reason error) {
|
||||||
|
if s.cache == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cacheCtx, cancel := s.cacheOperationContext()
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := s.cache.DeleteDashboardStats(cacheCtx); err != nil {
|
||||||
|
log.Printf("[Dashboard] 仪表盘缓存清理失败: %v", err)
|
||||||
|
}
|
||||||
|
if reason != nil {
|
||||||
|
log.Printf("[Dashboard] 仪表盘缓存异常,已清理: %v", reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) cacheOperationContext() (context.Context, context.CancelFunc) {
|
||||||
|
return context.WithTimeout(context.Background(), s.refreshTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) applyAggregationStatus(ctx context.Context, stats *usagestats.DashboardStats) {
|
||||||
|
if stats == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updatedAt := s.fetchAggregationUpdatedAt(ctx)
|
||||||
|
stats.StatsUpdatedAt = updatedAt.UTC().Format(time.RFC3339)
|
||||||
|
stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) refreshAggregationStaleness(stats *usagestats.DashboardStats) {
|
||||||
|
if stats == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updatedAt := parseStatsUpdatedAt(stats.StatsUpdatedAt)
|
||||||
|
stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.Time {
|
||||||
|
if s.aggRepo == nil {
|
||||||
|
return time.Unix(0, 0).UTC()
|
||||||
|
}
|
||||||
|
updatedAt, err := s.aggRepo.GetAggregationWatermark(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[Dashboard] 读取聚合水位失败: %v", err)
|
||||||
|
return time.Unix(0, 0).UTC()
|
||||||
|
}
|
||||||
|
if updatedAt.IsZero() {
|
||||||
|
return time.Unix(0, 0).UTC()
|
||||||
|
}
|
||||||
|
return updatedAt.UTC()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DashboardService) isAggregationStale(updatedAt, now time.Time) bool {
|
||||||
|
if !s.aggEnabled {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
epoch := time.Unix(0, 0).UTC()
|
||||||
|
if !updatedAt.After(epoch) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
threshold := s.aggInterval + s.aggLookback
|
||||||
|
return now.Sub(updatedAt) > threshold
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseStatsUpdatedAt(raw string) time.Time {
|
||||||
|
if raw == "" {
|
||||||
|
return time.Unix(0, 0).UTC()
|
||||||
|
}
|
||||||
|
parsed, err := time.Parse(time.RFC3339, raw)
|
||||||
|
if err != nil {
|
||||||
|
return time.Unix(0, 0).UTC()
|
||||||
|
}
|
||||||
|
return parsed.UTC()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||||
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
387
backend/internal/service/dashboard_service_test.go
Normal file
387
backend/internal/service/dashboard_service_test.go
Normal file
@@ -0,0 +1,387 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type usageRepoStub struct {
|
||||||
|
UsageLogRepository
|
||||||
|
stats *usagestats.DashboardStats
|
||||||
|
rangeStats *usagestats.DashboardStats
|
||||||
|
err error
|
||||||
|
rangeErr error
|
||||||
|
calls int32
|
||||||
|
rangeCalls int32
|
||||||
|
rangeStart time.Time
|
||||||
|
rangeEnd time.Time
|
||||||
|
onCall chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *usageRepoStub) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||||
|
atomic.AddInt32(&s.calls, 1)
|
||||||
|
if s.onCall != nil {
|
||||||
|
select {
|
||||||
|
case s.onCall <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.err != nil {
|
||||||
|
return nil, s.err
|
||||||
|
}
|
||||||
|
return s.stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *usageRepoStub) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error) {
|
||||||
|
atomic.AddInt32(&s.rangeCalls, 1)
|
||||||
|
s.rangeStart = start
|
||||||
|
s.rangeEnd = end
|
||||||
|
if s.rangeErr != nil {
|
||||||
|
return nil, s.rangeErr
|
||||||
|
}
|
||||||
|
if s.rangeStats != nil {
|
||||||
|
return s.rangeStats, nil
|
||||||
|
}
|
||||||
|
return s.stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type dashboardCacheStub struct {
|
||||||
|
get func(ctx context.Context) (string, error)
|
||||||
|
set func(ctx context.Context, data string, ttl time.Duration) error
|
||||||
|
del func(ctx context.Context) error
|
||||||
|
getCalls int32
|
||||||
|
setCalls int32
|
||||||
|
delCalls int32
|
||||||
|
lastSetMu sync.Mutex
|
||||||
|
lastSet string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dashboardCacheStub) GetDashboardStats(ctx context.Context) (string, error) {
|
||||||
|
atomic.AddInt32(&c.getCalls, 1)
|
||||||
|
if c.get != nil {
|
||||||
|
return c.get(ctx)
|
||||||
|
}
|
||||||
|
return "", ErrDashboardStatsCacheMiss
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dashboardCacheStub) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
|
||||||
|
atomic.AddInt32(&c.setCalls, 1)
|
||||||
|
c.lastSetMu.Lock()
|
||||||
|
c.lastSet = data
|
||||||
|
c.lastSetMu.Unlock()
|
||||||
|
if c.set != nil {
|
||||||
|
return c.set(ctx, data, ttl)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dashboardCacheStub) DeleteDashboardStats(ctx context.Context) error {
|
||||||
|
atomic.AddInt32(&c.delCalls, 1)
|
||||||
|
if c.del != nil {
|
||||||
|
return c.del(ctx)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type dashboardAggregationRepoStub struct {
|
||||||
|
watermark time.Time
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||||
|
if s.err != nil {
|
||||||
|
return time.Time{}, s.err
|
||||||
|
}
|
||||||
|
return s.watermark, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dashboardCacheStub) readLastEntry(t *testing.T) dashboardStatsCacheEntry {
|
||||||
|
t.Helper()
|
||||||
|
c.lastSetMu.Lock()
|
||||||
|
data := c.lastSet
|
||||||
|
c.lastSetMu.Unlock()
|
||||||
|
|
||||||
|
var entry dashboardStatsCacheEntry
|
||||||
|
err := json.Unmarshal([]byte(data), &entry)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_CacheHitFresh(t *testing.T) {
|
||||||
|
stats := &usagestats.DashboardStats{
|
||||||
|
TotalUsers: 10,
|
||||||
|
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||||
|
StatsStale: true,
|
||||||
|
}
|
||||||
|
entry := dashboardStatsCacheEntry{
|
||||||
|
Stats: stats,
|
||||||
|
UpdatedAt: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
payload, err := json.Marshal(entry)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cache := &dashboardCacheStub{
|
||||||
|
get: func(ctx context.Context) (string, error) {
|
||||||
|
return string(payload), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &usageRepoStub{
|
||||||
|
stats: &usagestats.DashboardStats{TotalUsers: 99},
|
||||||
|
}
|
||||||
|
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||||
|
DashboardAgg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||||
|
|
||||||
|
got, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, stats, got)
|
||||||
|
require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
|
||||||
|
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_CacheMiss_StoresCache(t *testing.T) {
|
||||||
|
stats := &usagestats.DashboardStats{
|
||||||
|
TotalUsers: 7,
|
||||||
|
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||||
|
StatsStale: true,
|
||||||
|
}
|
||||||
|
cache := &dashboardCacheStub{
|
||||||
|
get: func(ctx context.Context) (string, error) {
|
||||||
|
return "", ErrDashboardStatsCacheMiss
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &usageRepoStub{stats: stats}
|
||||||
|
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||||
|
DashboardAgg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||||
|
|
||||||
|
got, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, stats, got)
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&cache.setCalls))
|
||||||
|
entry := cache.readLastEntry(t)
|
||||||
|
require.Equal(t, stats, entry.Stats)
|
||||||
|
require.WithinDuration(t, time.Now(), time.Unix(entry.UpdatedAt, 0), time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_CacheDisabled_SkipsCache(t *testing.T) {
|
||||||
|
stats := &usagestats.DashboardStats{
|
||||||
|
TotalUsers: 3,
|
||||||
|
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||||
|
StatsStale: true,
|
||||||
|
}
|
||||||
|
cache := &dashboardCacheStub{
|
||||||
|
get: func(ctx context.Context) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &usageRepoStub{stats: stats}
|
||||||
|
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{Enabled: false},
|
||||||
|
DashboardAgg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||||
|
|
||||||
|
got, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, stats, got)
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
|
||||||
|
require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalls))
|
||||||
|
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_CacheHitStale_TriggersAsyncRefresh(t *testing.T) {
|
||||||
|
staleStats := &usagestats.DashboardStats{
|
||||||
|
TotalUsers: 11,
|
||||||
|
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||||
|
StatsStale: true,
|
||||||
|
}
|
||||||
|
entry := dashboardStatsCacheEntry{
|
||||||
|
Stats: staleStats,
|
||||||
|
UpdatedAt: time.Now().Add(-defaultDashboardStatsFreshTTL * 2).Unix(),
|
||||||
|
}
|
||||||
|
payload, err := json.Marshal(entry)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cache := &dashboardCacheStub{
|
||||||
|
get: func(ctx context.Context) (string, error) {
|
||||||
|
return string(payload), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
refreshCh := make(chan struct{}, 1)
|
||||||
|
repo := &usageRepoStub{
|
||||||
|
stats: &usagestats.DashboardStats{TotalUsers: 22},
|
||||||
|
onCall: refreshCh,
|
||||||
|
}
|
||||||
|
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||||
|
DashboardAgg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||||
|
|
||||||
|
got, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, staleStats, got)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-refreshCh:
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
t.Fatal("等待异步刷新超时")
|
||||||
|
}
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return atomic.LoadInt32(&cache.setCalls) >= 1
|
||||||
|
}, 1*time.Second, 10*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_CacheParseError_EvictsAndRefetches(t *testing.T) {
|
||||||
|
cache := &dashboardCacheStub{
|
||||||
|
get: func(ctx context.Context) (string, error) {
|
||||||
|
return "not-json", nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
stats := &usagestats.DashboardStats{TotalUsers: 9}
|
||||||
|
repo := &usageRepoStub{stats: stats}
|
||||||
|
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||||
|
DashboardAgg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||||
|
|
||||||
|
got, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, stats, got)
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_CacheParseError_RepoFailure(t *testing.T) {
|
||||||
|
cache := &dashboardCacheStub{
|
||||||
|
get: func(ctx context.Context) (string, error) {
|
||||||
|
return "not-json", nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := &usageRepoStub{err: errors.New("db down")}
|
||||||
|
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||||
|
DashboardAgg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||||
|
|
||||||
|
_, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_StatsUpdatedAtEpochWhenMissing(t *testing.T) {
|
||||||
|
stats := &usagestats.DashboardStats{}
|
||||||
|
repo := &usageRepoStub{stats: stats}
|
||||||
|
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||||
|
cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: false}}
|
||||||
|
svc := NewDashboardService(repo, aggRepo, nil, cfg)
|
||||||
|
|
||||||
|
got, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "1970-01-01T00:00:00Z", got.StatsUpdatedAt)
|
||||||
|
require.True(t, got.StatsStale)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_StatsStaleFalseWhenFresh(t *testing.T) {
|
||||||
|
aggNow := time.Now().UTC().Truncate(time.Second)
|
||||||
|
stats := &usagestats.DashboardStats{}
|
||||||
|
repo := &usageRepoStub{stats: stats}
|
||||||
|
aggRepo := &dashboardAggregationRepoStub{watermark: aggNow}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{Enabled: false},
|
||||||
|
DashboardAgg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: true,
|
||||||
|
IntervalSeconds: 60,
|
||||||
|
LookbackSeconds: 120,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewDashboardService(repo, aggRepo, nil, cfg)
|
||||||
|
|
||||||
|
got, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, aggNow.Format(time.RFC3339), got.StatsUpdatedAt)
|
||||||
|
require.False(t, got.StatsStale)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashboardService_AggDisabled_UsesUsageLogsFallback(t *testing.T) {
|
||||||
|
expected := &usagestats.DashboardStats{TotalUsers: 42}
|
||||||
|
repo := &usageRepoStub{
|
||||||
|
rangeStats: expected,
|
||||||
|
err: errors.New("should not call aggregated stats"),
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Dashboard: config.DashboardCacheConfig{Enabled: false},
|
||||||
|
DashboardAgg: config.DashboardAggregationConfig{
|
||||||
|
Enabled: false,
|
||||||
|
Retention: config.DashboardAggregationRetentionConfig{
|
||||||
|
UsageLogsDays: 7,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewDashboardService(repo, nil, nil, cfg)
|
||||||
|
|
||||||
|
got, err := svc.GetDashboardStats(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(42), got.TotalUsers)
|
||||||
|
require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&repo.rangeCalls))
|
||||||
|
require.False(t, repo.rangeEnd.IsZero())
|
||||||
|
require.Equal(t, truncateToDayUTC(repo.rangeEnd.AddDate(0, 0, -7)), repo.rangeStart)
|
||||||
|
}
|
||||||
@@ -50,13 +50,15 @@ type UpdateGroupRequest struct {
|
|||||||
|
|
||||||
// GroupService 分组管理服务
|
// GroupService 分组管理服务
|
||||||
type GroupService struct {
|
type GroupService struct {
|
||||||
groupRepo GroupRepository
|
groupRepo GroupRepository
|
||||||
|
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGroupService 创建分组服务实例
|
// NewGroupService 创建分组服务实例
|
||||||
func NewGroupService(groupRepo GroupRepository) *GroupService {
|
func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *GroupService {
|
||||||
return &GroupService{
|
return &GroupService{
|
||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
|
authCacheInvalidator: authCacheInvalidator,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,6 +157,9 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ
|
|||||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||||
return nil, fmt.Errorf("update group: %w", err)
|
return nil, fmt.Errorf("update group: %w", err)
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
@@ -167,6 +172,9 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
|
|||||||
return fmt.Errorf("get group: %w", err)
|
return fmt.Errorf("get group: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||||
|
}
|
||||||
if err := s.groupRepo.Delete(ctx, id); err != nil {
|
if err := s.groupRepo.Delete(ctx, id); err != nil {
|
||||||
return fmt.Errorf("delete group: %w", err)
|
return fmt.Errorf("delete group: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
404
backend/internal/service/openai_codex_transform.go
Normal file
404
backend/internal/service/openai_codex_transform.go
Normal file
@@ -0,0 +1,404 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt"
|
||||||
|
codexCacheTTL = 15 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
var codexModelMap = map[string]string{
|
||||||
|
"gpt-5.1-codex": "gpt-5.1-codex",
|
||||||
|
"gpt-5.1-codex-low": "gpt-5.1-codex",
|
||||||
|
"gpt-5.1-codex-medium": "gpt-5.1-codex",
|
||||||
|
"gpt-5.1-codex-high": "gpt-5.1-codex",
|
||||||
|
"gpt-5.1-codex-max": "gpt-5.1-codex-max",
|
||||||
|
"gpt-5.1-codex-max-low": "gpt-5.1-codex-max",
|
||||||
|
"gpt-5.1-codex-max-medium": "gpt-5.1-codex-max",
|
||||||
|
"gpt-5.1-codex-max-high": "gpt-5.1-codex-max",
|
||||||
|
"gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max",
|
||||||
|
"gpt-5.2": "gpt-5.2",
|
||||||
|
"gpt-5.2-none": "gpt-5.2",
|
||||||
|
"gpt-5.2-low": "gpt-5.2",
|
||||||
|
"gpt-5.2-medium": "gpt-5.2",
|
||||||
|
"gpt-5.2-high": "gpt-5.2",
|
||||||
|
"gpt-5.2-xhigh": "gpt-5.2",
|
||||||
|
"gpt-5.2-codex": "gpt-5.2-codex",
|
||||||
|
"gpt-5.2-codex-low": "gpt-5.2-codex",
|
||||||
|
"gpt-5.2-codex-medium": "gpt-5.2-codex",
|
||||||
|
"gpt-5.2-codex-high": "gpt-5.2-codex",
|
||||||
|
"gpt-5.2-codex-xhigh": "gpt-5.2-codex",
|
||||||
|
"gpt-5.1-codex-mini": "gpt-5.1-codex-mini",
|
||||||
|
"gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini",
|
||||||
|
"gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini",
|
||||||
|
"gpt-5.1": "gpt-5.1",
|
||||||
|
"gpt-5.1-none": "gpt-5.1",
|
||||||
|
"gpt-5.1-low": "gpt-5.1",
|
||||||
|
"gpt-5.1-medium": "gpt-5.1",
|
||||||
|
"gpt-5.1-high": "gpt-5.1",
|
||||||
|
"gpt-5.1-chat-latest": "gpt-5.1",
|
||||||
|
"gpt-5-codex": "gpt-5.1-codex",
|
||||||
|
"codex-mini-latest": "gpt-5.1-codex-mini",
|
||||||
|
"gpt-5-codex-mini": "gpt-5.1-codex-mini",
|
||||||
|
"gpt-5-codex-mini-medium": "gpt-5.1-codex-mini",
|
||||||
|
"gpt-5-codex-mini-high": "gpt-5.1-codex-mini",
|
||||||
|
"gpt-5": "gpt-5.1",
|
||||||
|
"gpt-5-mini": "gpt-5.1",
|
||||||
|
"gpt-5-nano": "gpt-5.1",
|
||||||
|
}
|
||||||
|
|
||||||
|
type codexTransformResult struct {
|
||||||
|
Modified bool
|
||||||
|
NormalizedModel string
|
||||||
|
PromptCacheKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
type opencodeCacheMetadata struct {
|
||||||
|
ETag string `json:"etag"`
|
||||||
|
LastFetch string `json:"lastFetch,omitempty"`
|
||||||
|
LastChecked int64 `json:"lastChecked"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
||||||
|
result := codexTransformResult{}
|
||||||
|
|
||||||
|
model := ""
|
||||||
|
if v, ok := reqBody["model"].(string); ok {
|
||||||
|
model = v
|
||||||
|
}
|
||||||
|
normalizedModel := normalizeCodexModel(model)
|
||||||
|
if normalizedModel != "" {
|
||||||
|
if model != normalizedModel {
|
||||||
|
reqBody["model"] = normalizedModel
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
result.NormalizedModel = normalizedModel
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := reqBody["store"].(bool); !ok || v {
|
||||||
|
reqBody["store"] = false
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
if v, ok := reqBody["stream"].(bool); !ok || !v {
|
||||||
|
reqBody["stream"] = true
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := reqBody["max_output_tokens"]; ok {
|
||||||
|
delete(reqBody, "max_output_tokens")
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
if _, ok := reqBody["max_completion_tokens"]; ok {
|
||||||
|
delete(reqBody, "max_completion_tokens")
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if normalizeCodexTools(reqBody) {
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
||||||
|
result.PromptCacheKey = strings.TrimSpace(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
|
||||||
|
existingInstructions, _ := reqBody["instructions"].(string)
|
||||||
|
existingInstructions = strings.TrimSpace(existingInstructions)
|
||||||
|
|
||||||
|
if instructions != "" {
|
||||||
|
if existingInstructions != instructions {
|
||||||
|
reqBody["instructions"] = instructions
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if input, ok := reqBody["input"].([]any); ok {
|
||||||
|
input = filterCodexInput(input)
|
||||||
|
reqBody["input"] = input
|
||||||
|
result.Modified = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeCodexModel(model string) string {
|
||||||
|
if model == "" {
|
||||||
|
return "gpt-5.1"
|
||||||
|
}
|
||||||
|
|
||||||
|
modelID := model
|
||||||
|
if strings.Contains(modelID, "/") {
|
||||||
|
parts := strings.Split(modelID, "/")
|
||||||
|
modelID = parts[len(parts)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if mapped := getNormalizedCodexModel(modelID); mapped != "" {
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := strings.ToLower(modelID)
|
||||||
|
|
||||||
|
if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") {
|
||||||
|
return "gpt-5.2-codex"
|
||||||
|
}
|
||||||
|
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
|
||||||
|
return "gpt-5.2"
|
||||||
|
}
|
||||||
|
if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") {
|
||||||
|
return "gpt-5.1-codex-max"
|
||||||
|
}
|
||||||
|
if strings.Contains(normalized, "gpt-5.1-codex-mini") || strings.Contains(normalized, "gpt 5.1 codex mini") {
|
||||||
|
return "gpt-5.1-codex-mini"
|
||||||
|
}
|
||||||
|
if strings.Contains(normalized, "codex-mini-latest") ||
|
||||||
|
strings.Contains(normalized, "gpt-5-codex-mini") ||
|
||||||
|
strings.Contains(normalized, "gpt 5 codex mini") {
|
||||||
|
return "codex-mini-latest"
|
||||||
|
}
|
||||||
|
if strings.Contains(normalized, "gpt-5.1-codex") || strings.Contains(normalized, "gpt 5.1 codex") {
|
||||||
|
return "gpt-5.1-codex"
|
||||||
|
}
|
||||||
|
if strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1") {
|
||||||
|
return "gpt-5.1"
|
||||||
|
}
|
||||||
|
if strings.Contains(normalized, "codex") {
|
||||||
|
return "gpt-5.1-codex"
|
||||||
|
}
|
||||||
|
if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") {
|
||||||
|
return "gpt-5.1"
|
||||||
|
}
|
||||||
|
|
||||||
|
return "gpt-5.1"
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNormalizedCodexModel(modelID string) string {
|
||||||
|
if modelID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if mapped, ok := codexModelMap[modelID]; ok {
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
|
lower := strings.ToLower(modelID)
|
||||||
|
for key, value := range codexModelMap {
|
||||||
|
if strings.ToLower(key) == lower {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
|
||||||
|
cacheDir := codexCachePath("")
|
||||||
|
if cacheDir == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
cacheFile := filepath.Join(cacheDir, cacheFileName)
|
||||||
|
metaFile := filepath.Join(cacheDir, metaFileName)
|
||||||
|
|
||||||
|
var cachedContent string
|
||||||
|
if content, ok := readFile(cacheFile); ok {
|
||||||
|
cachedContent = content
|
||||||
|
}
|
||||||
|
|
||||||
|
var meta opencodeCacheMetadata
|
||||||
|
if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" {
|
||||||
|
if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL {
|
||||||
|
return cachedContent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
content, etag, status, err := fetchWithETag(url, meta.ETag)
|
||||||
|
if err == nil && status == http.StatusNotModified && cachedContent != "" {
|
||||||
|
return cachedContent
|
||||||
|
}
|
||||||
|
if err == nil && status >= 200 && status < 300 && content != "" {
|
||||||
|
_ = writeFile(cacheFile, content)
|
||||||
|
meta = opencodeCacheMetadata{
|
||||||
|
ETag: etag,
|
||||||
|
LastFetch: time.Now().UTC().Format(time.RFC3339),
|
||||||
|
LastChecked: time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
_ = writeJSON(metaFile, meta)
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
return cachedContent
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOpenCodeCodexHeader() string {
|
||||||
|
return getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetOpenCodeInstructions() string {
|
||||||
|
return getOpenCodeCodexHeader()
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterCodexInput(input []any) []any {
|
||||||
|
filtered := make([]any, 0, len(input))
|
||||||
|
for _, item := range input {
|
||||||
|
m, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
filtered = append(filtered, item)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if typ, ok := m["type"].(string); ok && typ == "item_reference" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(m, "id")
|
||||||
|
filtered = append(filtered, m)
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeCodexTools(reqBody map[string]any) bool {
|
||||||
|
rawTools, ok := reqBody["tools"]
|
||||||
|
if !ok || rawTools == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
tools, ok := rawTools.([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
modified := false
|
||||||
|
for idx, tool := range tools {
|
||||||
|
toolMap, ok := tool.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
toolType, _ := toolMap["type"].(string)
|
||||||
|
if strings.TrimSpace(toolType) != "function" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
function, ok := toolMap["function"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := toolMap["name"]; !ok {
|
||||||
|
if name, ok := function["name"].(string); ok && strings.TrimSpace(name) != "" {
|
||||||
|
toolMap["name"] = name
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, ok := toolMap["description"]; !ok {
|
||||||
|
if desc, ok := function["description"].(string); ok && strings.TrimSpace(desc) != "" {
|
||||||
|
toolMap["description"] = desc
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, ok := toolMap["parameters"]; !ok {
|
||||||
|
if params, ok := function["parameters"]; ok {
|
||||||
|
toolMap["parameters"] = params
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, ok := toolMap["strict"]; !ok {
|
||||||
|
if strict, ok := function["strict"]; ok {
|
||||||
|
toolMap["strict"] = strict
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tools[idx] = toolMap
|
||||||
|
}
|
||||||
|
|
||||||
|
if modified {
|
||||||
|
reqBody["tools"] = tools
|
||||||
|
}
|
||||||
|
|
||||||
|
return modified
|
||||||
|
}
|
||||||
|
|
||||||
|
func codexCachePath(filename string) string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
cacheDir := filepath.Join(home, ".opencode", "cache")
|
||||||
|
if filename == "" {
|
||||||
|
return cacheDir
|
||||||
|
}
|
||||||
|
return filepath.Join(cacheDir, filename)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readFile(path string) (string, bool) {
|
||||||
|
if path == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return string(data), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeFile(path, content string) error {
|
||||||
|
if path == "" {
|
||||||
|
return fmt.Errorf("empty cache path")
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.WriteFile(path, []byte(content), 0o644)
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadJSON(path string, target any) bool {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, target); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeJSON(path string, value any) error {
|
||||||
|
if path == "" {
|
||||||
|
return fmt.Errorf("empty json path")
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.WriteFile(path, data, 0o644)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchWithETag(url, etag string) (string, string, int, error) {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", 0, err
|
||||||
|
}
|
||||||
|
req.Header.Set("User-Agent", "sub2api-codex")
|
||||||
|
if etag != "" {
|
||||||
|
req.Header.Set("If-None-Match", etag)
|
||||||
|
}
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", 0, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", resp.StatusCode, err
|
||||||
|
}
|
||||||
|
return string(body), resp.Header.Get("etag"), resp.StatusCode, nil
|
||||||
|
}
|
||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -20,6 +21,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -528,33 +530,38 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
// Extract model and stream from parsed body
|
// Extract model and stream from parsed body
|
||||||
reqModel, _ := reqBody["model"].(string)
|
reqModel, _ := reqBody["model"].(string)
|
||||||
reqStream, _ := reqBody["stream"].(bool)
|
reqStream, _ := reqBody["stream"].(bool)
|
||||||
|
promptCacheKey := ""
|
||||||
|
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
||||||
|
promptCacheKey = strings.TrimSpace(v)
|
||||||
|
}
|
||||||
|
|
||||||
// Track if body needs re-serialization
|
// Track if body needs re-serialization
|
||||||
bodyModified := false
|
bodyModified := false
|
||||||
originalModel := reqModel
|
originalModel := reqModel
|
||||||
|
|
||||||
// Apply model mapping
|
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
|
||||||
mappedModel := account.GetMappedModel(reqModel)
|
|
||||||
if mappedModel != reqModel {
|
// Apply model mapping (skip for Codex CLI for transparent forwarding)
|
||||||
reqBody["model"] = mappedModel
|
mappedModel := reqModel
|
||||||
bodyModified = true
|
if !isCodexCLI {
|
||||||
|
mappedModel = account.GetMappedModel(reqModel)
|
||||||
|
if mappedModel != reqModel {
|
||||||
|
reqBody["model"] = mappedModel
|
||||||
|
bodyModified = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// For OAuth accounts using ChatGPT internal API:
|
if account.Type == AccountTypeOAuth && !isCodexCLI {
|
||||||
// 1. Add store: false
|
codexResult := applyCodexOAuthTransform(reqBody)
|
||||||
// 2. Normalize input format for Codex API compatibility
|
if codexResult.Modified {
|
||||||
if account.Type == AccountTypeOAuth {
|
|
||||||
reqBody["store"] = false
|
|
||||||
// Codex 上游不接受 max_output_tokens 参数,需要在转发前移除。
|
|
||||||
delete(reqBody, "max_output_tokens")
|
|
||||||
bodyModified = true
|
|
||||||
|
|
||||||
// Normalize input format: convert AI SDK multi-part content format to simplified format
|
|
||||||
// AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]}
|
|
||||||
// Codex API expects: {"content": "..."}
|
|
||||||
if normalizeInputForCodexAPI(reqBody) {
|
|
||||||
bodyModified = true
|
bodyModified = true
|
||||||
}
|
}
|
||||||
|
if codexResult.NormalizedModel != "" {
|
||||||
|
mappedModel = codexResult.NormalizedModel
|
||||||
|
}
|
||||||
|
if codexResult.PromptCacheKey != "" {
|
||||||
|
promptCacheKey = codexResult.PromptCacheKey
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Re-serialize body only if modified
|
// Re-serialize body only if modified
|
||||||
@@ -573,7 +580,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build upstream request
|
// Build upstream request
|
||||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream)
|
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -674,7 +681,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool) (*http.Request, error) {
|
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) {
|
||||||
// Determine target URL based on account type
|
// Determine target URL based on account type
|
||||||
var targetURL string
|
var targetURL string
|
||||||
switch account.Type {
|
switch account.Type {
|
||||||
@@ -714,12 +721,6 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
|||||||
if chatgptAccountID != "" {
|
if chatgptAccountID != "" {
|
||||||
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||||
}
|
}
|
||||||
// Set accept header based on stream mode
|
|
||||||
if isStream {
|
|
||||||
req.Header.Set("accept", "text/event-stream")
|
|
||||||
} else {
|
|
||||||
req.Header.Set("accept", "application/json")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Whitelist passthrough headers
|
// Whitelist passthrough headers
|
||||||
@@ -731,6 +732,22 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if account.Type == AccountTypeOAuth {
|
||||||
|
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||||
|
if isCodexCLI {
|
||||||
|
req.Header.Set("originator", "codex_cli_rs")
|
||||||
|
} else {
|
||||||
|
req.Header.Set("originator", "opencode")
|
||||||
|
}
|
||||||
|
req.Header.Set("accept", "text/event-stream")
|
||||||
|
if promptCacheKey != "" {
|
||||||
|
req.Header.Set("conversation_id", promptCacheKey)
|
||||||
|
req.Header.Set("session_id", promptCacheKey)
|
||||||
|
} else {
|
||||||
|
req.Header.Del("conversation_id")
|
||||||
|
req.Header.Del("session_id")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Apply custom User-Agent if configured
|
// Apply custom User-Agent if configured
|
||||||
customUA := account.GetOpenAIUserAgent()
|
customUA := account.GetOpenAIUserAgent()
|
||||||
@@ -1109,6 +1126,13 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if account.Type == AccountTypeOAuth {
|
||||||
|
bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:"))
|
||||||
|
if isEventStreamResponse(resp.Header) || bodyLooksLikeSSE {
|
||||||
|
return s.handleOAuthSSEToJSON(resp, c, body, originalModel, mappedModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Parse usage
|
// Parse usage
|
||||||
var response struct {
|
var response struct {
|
||||||
Usage struct {
|
Usage struct {
|
||||||
@@ -1148,6 +1172,110 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
|
|||||||
return usage, nil
|
return usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isEventStreamResponse(header http.Header) bool {
|
||||||
|
contentType := strings.ToLower(header.Get("Content-Type"))
|
||||||
|
return strings.Contains(contentType, "text/event-stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||||
|
bodyText := string(body)
|
||||||
|
finalResponse, ok := extractCodexFinalResponse(bodyText)
|
||||||
|
|
||||||
|
usage := &OpenAIUsage{}
|
||||||
|
if ok {
|
||||||
|
var response struct {
|
||||||
|
Usage struct {
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
InputTokenDetails struct {
|
||||||
|
CachedTokens int `json:"cached_tokens"`
|
||||||
|
} `json:"input_tokens_details"`
|
||||||
|
} `json:"usage"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(finalResponse, &response); err == nil {
|
||||||
|
usage.InputTokens = response.Usage.InputTokens
|
||||||
|
usage.OutputTokens = response.Usage.OutputTokens
|
||||||
|
usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens
|
||||||
|
}
|
||||||
|
body = finalResponse
|
||||||
|
if originalModel != mappedModel {
|
||||||
|
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
usage = s.parseSSEUsageFromBody(bodyText)
|
||||||
|
if originalModel != mappedModel {
|
||||||
|
bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
|
||||||
|
}
|
||||||
|
body = []byte(bodyText)
|
||||||
|
}
|
||||||
|
|
||||||
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||||
|
|
||||||
|
contentType := "application/json; charset=utf-8"
|
||||||
|
if !ok {
|
||||||
|
contentType = resp.Header.Get("Content-Type")
|
||||||
|
if contentType == "" {
|
||||||
|
contentType = "text/event-stream"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Data(resp.StatusCode, contentType, body)
|
||||||
|
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractCodexFinalResponse(body string) ([]byte, bool) {
|
||||||
|
lines := strings.Split(body, "\n")
|
||||||
|
for _, line := range lines {
|
||||||
|
if !openaiSSEDataRe.MatchString(line) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||||
|
if data == "" || data == "[DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var event struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Response json.RawMessage `json:"response"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal([]byte(data), &event) != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if event.Type == "response.done" || event.Type == "response.completed" {
|
||||||
|
if len(event.Response) > 0 {
|
||||||
|
return event.Response, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
|
||||||
|
usage := &OpenAIUsage{}
|
||||||
|
lines := strings.Split(body, "\n")
|
||||||
|
for _, line := range lines {
|
||||||
|
if !openaiSSEDataRe.MatchString(line) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||||
|
if data == "" || data == "[DONE]" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.parseSSEUsage(data, usage)
|
||||||
|
}
|
||||||
|
return usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
|
||||||
|
lines := strings.Split(body, "\n")
|
||||||
|
for i, line := range lines {
|
||||||
|
if !openaiSSEDataRe.MatchString(line) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
lines[i] = s.replaceModelInSSELine(line, fromModel, toModel)
|
||||||
|
}
|
||||||
|
return strings.Join(lines, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||||
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
||||||
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||||
@@ -1187,101 +1315,6 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
|
|||||||
return newBody
|
return newBody
|
||||||
}
|
}
|
||||||
|
|
||||||
// normalizeInputForCodexAPI converts AI SDK multi-part content format to simplified format
|
|
||||||
// that the ChatGPT internal Codex API expects.
|
|
||||||
//
|
|
||||||
// AI SDK sends content as an array of typed objects:
|
|
||||||
//
|
|
||||||
// {"content": [{"type": "input_text", "text": "hello"}]}
|
|
||||||
//
|
|
||||||
// ChatGPT Codex API expects content as a simple string:
|
|
||||||
//
|
|
||||||
// {"content": "hello"}
|
|
||||||
//
|
|
||||||
// This function modifies reqBody in-place and returns true if any modification was made.
|
|
||||||
func normalizeInputForCodexAPI(reqBody map[string]any) bool {
|
|
||||||
input, ok := reqBody["input"]
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle case where input is a simple string (already compatible)
|
|
||||||
if _, isString := input.(string); isString {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle case where input is an array of messages
|
|
||||||
inputArray, ok := input.([]any)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
modified := false
|
|
||||||
for _, item := range inputArray {
|
|
||||||
message, ok := item.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
content, ok := message["content"]
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// If content is already a string, no conversion needed
|
|
||||||
if _, isString := content.(string); isString {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// If content is an array (AI SDK format), convert to string
|
|
||||||
contentArray, ok := content.([]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract text from content array
|
|
||||||
var textParts []string
|
|
||||||
for _, part := range contentArray {
|
|
||||||
partMap, ok := part.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle different content types
|
|
||||||
partType, _ := partMap["type"].(string)
|
|
||||||
switch partType {
|
|
||||||
case "input_text", "text":
|
|
||||||
// Extract text from input_text or text type
|
|
||||||
if text, ok := partMap["text"].(string); ok {
|
|
||||||
textParts = append(textParts, text)
|
|
||||||
}
|
|
||||||
case "input_image", "image":
|
|
||||||
// For images, we need to preserve the original format
|
|
||||||
// as ChatGPT Codex API may support images in a different way
|
|
||||||
// For now, skip image parts (they will be lost in conversion)
|
|
||||||
// TODO: Consider preserving image data or handling it separately
|
|
||||||
continue
|
|
||||||
case "input_file", "file":
|
|
||||||
// Similar to images, file inputs may need special handling
|
|
||||||
continue
|
|
||||||
default:
|
|
||||||
// For unknown types, try to extract text if available
|
|
||||||
if text, ok := partMap["text"].(string); ok {
|
|
||||||
textParts = append(textParts, text)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert content array to string
|
|
||||||
if len(textParts) > 0 {
|
|
||||||
message["content"] = strings.Join(textParts, "\n")
|
|
||||||
modified = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return modified
|
|
||||||
}
|
|
||||||
|
|
||||||
// OpenAIRecordUsageInput input for recording usage
|
// OpenAIRecordUsageInput input for recording usage
|
||||||
type OpenAIRecordUsageInput struct {
|
type OpenAIRecordUsageInput struct {
|
||||||
Result *OpenAIForwardResult
|
Result *OpenAIForwardResult
|
||||||
|
|||||||
@@ -220,7 +220,7 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
|
|||||||
Credentials: map[string]any{"base_url": "://invalid-url"},
|
Credentials: map[string]any{"base_url": "://invalid-url"},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false)
|
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false, "", false)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("expected error for invalid base_url when allowlist disabled")
|
t.Fatalf("expected error for invalid base_url when allowlist disabled")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,10 +24,11 @@ var (
|
|||||||
|
|
||||||
// PromoService 优惠码服务
|
// PromoService 优惠码服务
|
||||||
type PromoService struct {
|
type PromoService struct {
|
||||||
promoRepo PromoCodeRepository
|
promoRepo PromoCodeRepository
|
||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
entClient *dbent.Client
|
entClient *dbent.Client
|
||||||
|
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPromoService 创建优惠码服务实例
|
// NewPromoService 创建优惠码服务实例
|
||||||
@@ -36,12 +37,14 @@ func NewPromoService(
|
|||||||
userRepo UserRepository,
|
userRepo UserRepository,
|
||||||
billingCacheService *BillingCacheService,
|
billingCacheService *BillingCacheService,
|
||||||
entClient *dbent.Client,
|
entClient *dbent.Client,
|
||||||
|
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
||||||
) *PromoService {
|
) *PromoService {
|
||||||
return &PromoService{
|
return &PromoService{
|
||||||
promoRepo: promoRepo,
|
promoRepo: promoRepo,
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
entClient: entClient,
|
entClient: entClient,
|
||||||
|
authCacheInvalidator: authCacheInvalidator,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -145,6 +148,8 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st
|
|||||||
return fmt.Errorf("commit transaction: %w", err)
|
return fmt.Errorf("commit transaction: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.invalidatePromoCaches(ctx, userID, promoCode.BonusAmount)
|
||||||
|
|
||||||
// 失效余额缓存
|
// 失效余额缓存
|
||||||
if s.billingCacheService != nil {
|
if s.billingCacheService != nil {
|
||||||
go func() {
|
go func() {
|
||||||
@@ -157,6 +162,13 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *PromoService) invalidatePromoCaches(ctx context.Context, userID int64, bonusAmount float64) {
|
||||||
|
if bonusAmount == 0 || s.authCacheInvalidator == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
// GenerateRandomCode 生成随机优惠码
|
// GenerateRandomCode 生成随机优惠码
|
||||||
func (s *PromoService) GenerateRandomCode() (string, error) {
|
func (s *PromoService) GenerateRandomCode() (string, error) {
|
||||||
bytes := make([]byte, 8)
|
bytes := make([]byte, 8)
|
||||||
|
|||||||
122
backend/internal/service/prompts/codex_opencode_bridge.txt
Normal file
122
backend/internal/service/prompts/codex_opencode_bridge.txt
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
# Codex Running in OpenCode
|
||||||
|
|
||||||
|
You are running Codex through OpenCode, an open-source terminal coding assistant. OpenCode provides different tools but follows Codex operating principles.
|
||||||
|
|
||||||
|
## CRITICAL: Tool Replacements
|
||||||
|
|
||||||
|
<critical_rule priority="0">
|
||||||
|
❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD
|
||||||
|
- NEVER use: apply_patch, applyPatch
|
||||||
|
- ALWAYS use: edit tool for ALL file modifications
|
||||||
|
- Before modifying files: Verify you're using "edit", NOT "apply_patch"
|
||||||
|
</critical_rule>
|
||||||
|
|
||||||
|
<critical_rule priority="0">
|
||||||
|
❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD
|
||||||
|
- NEVER use: update_plan, updatePlan, read_plan, readPlan
|
||||||
|
- ALWAYS use: todowrite for task/plan updates, todoread to read plans
|
||||||
|
- Before plan operations: Verify you're using "todowrite", NOT "update_plan"
|
||||||
|
</critical_rule>
|
||||||
|
|
||||||
|
## Available OpenCode Tools
|
||||||
|
|
||||||
|
**File Operations:**
|
||||||
|
- `write` - Create new files
|
||||||
|
- Overwriting existing files requires a prior Read in this session; default to ASCII unless the file already uses Unicode.
|
||||||
|
- `edit` - Modify existing files (REPLACES apply_patch)
|
||||||
|
- Requires a prior Read in this session; preserve exact indentation; ensure `oldString` uniquely matches or use `replaceAll`; edit fails if ambiguous or missing.
|
||||||
|
- `read` - Read file contents
|
||||||
|
|
||||||
|
**Search/Discovery:**
|
||||||
|
- `grep` - Search file contents (tool, not bash grep); use `include` to filter patterns; set `path` only when not searching workspace root; for cross-file match counts use bash with `rg`.
|
||||||
|
- `glob` - Find files by pattern; defaults to workspace cwd unless `path` is set.
|
||||||
|
- `list` - List directories (requires absolute paths)
|
||||||
|
|
||||||
|
**Execution:**
|
||||||
|
- `bash` - Run shell commands
|
||||||
|
- No workdir parameter; do not include it in tool calls.
|
||||||
|
- Always include a short description for the command.
|
||||||
|
- Do not use cd; use absolute paths in commands.
|
||||||
|
- Quote paths containing spaces with double quotes.
|
||||||
|
- Chain multiple commands with ';' or '&&'; avoid newlines.
|
||||||
|
- Use Grep/Glob tools for searches; only use bash with `rg` when you need counts or advanced features.
|
||||||
|
- Do not use `ls`/`cat` in bash; use `list`/`read` tools instead.
|
||||||
|
- For deletions (rm), verify by listing parent dir with `list`.
|
||||||
|
|
||||||
|
**Network:**
|
||||||
|
- `webfetch` - Fetch web content
|
||||||
|
- Use fully-formed URLs (http/https; http auto-upgrades to https).
|
||||||
|
- Always set `format` to one of: text | markdown | html; prefer markdown unless otherwise required.
|
||||||
|
- Read-only; short cache window.
|
||||||
|
|
||||||
|
**Task Management:**
|
||||||
|
- `todowrite` - Manage tasks/plans (REPLACES update_plan)
|
||||||
|
- `todoread` - Read current plan
|
||||||
|
|
||||||
|
## Substitution Rules
|
||||||
|
|
||||||
|
Base instruction says: You MUST use instead:
|
||||||
|
apply_patch → edit
|
||||||
|
update_plan → todowrite
|
||||||
|
read_plan → todoread
|
||||||
|
|
||||||
|
**Path Usage:** Use per-tool conventions to avoid conflicts:
|
||||||
|
- Tool calls: `read`, `edit`, `write`, `list` require absolute paths.
|
||||||
|
- Searches: `grep`/`glob` default to the workspace cwd; prefer relative include patterns; set `path` only when a different root is needed.
|
||||||
|
- Presentation: In assistant messages, show workspace-relative paths; use absolute paths only inside tool calls.
|
||||||
|
- Tool schema overrides general path preferences—do not convert required absolute paths to relative.
|
||||||
|
|
||||||
|
## Verification Checklist
|
||||||
|
|
||||||
|
Before file/plan modifications:
|
||||||
|
1. Am I using "edit" NOT "apply_patch"?
|
||||||
|
2. Am I using "todowrite" NOT "update_plan"?
|
||||||
|
3. Is this tool in the approved list above?
|
||||||
|
4. Am I following each tool's path requirements?
|
||||||
|
|
||||||
|
If ANY answer is NO → STOP and correct before proceeding.
|
||||||
|
|
||||||
|
## OpenCode Working Style
|
||||||
|
|
||||||
|
**Communication:**
|
||||||
|
- Send brief preambles (8-12 words) before tool calls, building on prior context
|
||||||
|
- Provide progress updates during longer tasks
|
||||||
|
|
||||||
|
**Execution:**
|
||||||
|
- Keep working autonomously until query is fully resolved before yielding
|
||||||
|
- Don't return to user with partial solutions
|
||||||
|
|
||||||
|
**Code Approach:**
|
||||||
|
- New projects: Be ambitious and creative
|
||||||
|
- Existing codebases: Surgical precision - modify only what's requested unless explicitly instructed to do otherwise
|
||||||
|
|
||||||
|
**Testing:**
|
||||||
|
- If tests exist: Start specific to your changes, then broader validation
|
||||||
|
|
||||||
|
## Advanced Tools
|
||||||
|
|
||||||
|
**Task Tool (Sub-Agents):**
|
||||||
|
- Use the Task tool (functions.task) to launch sub-agents
|
||||||
|
- Check the Task tool description for current agent types and their capabilities
|
||||||
|
- Useful for complex analysis, specialized workflows, or tasks requiring isolated context
|
||||||
|
- The agent list is dynamically generated - refer to tool schema for available agents
|
||||||
|
|
||||||
|
**Parallelization:**
|
||||||
|
- When multiple independent tool calls are needed, use multi_tool_use.parallel to run them concurrently.
|
||||||
|
- Reserve sequential calls for ordered or data-dependent steps.
|
||||||
|
|
||||||
|
**MCP Tools:**
|
||||||
|
- Model Context Protocol servers provide additional capabilities
|
||||||
|
- MCP tools are prefixed: `mcp__<server-name>__<tool-name>`
|
||||||
|
- Check your available tools for MCP integrations
|
||||||
|
- Use when the tool's functionality matches your task needs
|
||||||
|
|
||||||
|
## What Remains from Codex
|
||||||
|
|
||||||
|
Sandbox policies, approval mechanisms, final answer formatting, git commit protocols, and file reference formats all follow Codex instructions. In approval policy "never", never request escalations.
|
||||||
|
|
||||||
|
## Approvals & Safety
|
||||||
|
- Assume workspace-write filesystem, network enabled, approval on-failure unless explicitly stated otherwise.
|
||||||
|
- When a command fails due to sandboxing or permissions, retry with escalated permissions if allowed by policy, including a one-line justification.
|
||||||
|
- Treat destructive commands (e.g., `rm`, `git reset --hard`) as requiring explicit user request or approval.
|
||||||
|
- When uncertain, prefer non-destructive verification first (e.g., confirm file existence with `list`, then delete with `bash`).
|
||||||
63
backend/internal/service/prompts/tool_remap_message.txt
Normal file
63
backend/internal/service/prompts/tool_remap_message.txt
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
<user_instructions priority="0">
|
||||||
|
<environment_override priority="0">
|
||||||
|
YOU ARE IN A DIFFERENT ENVIRONMENT. These instructions override ALL previous tool references.
|
||||||
|
</environment_override>
|
||||||
|
|
||||||
|
<tool_replacements priority="0">
|
||||||
|
<critical_rule priority="0">
|
||||||
|
❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD
|
||||||
|
- NEVER use: apply_patch, applyPatch
|
||||||
|
- ALWAYS use: edit tool for ALL file modifications
|
||||||
|
- Before modifying files: Verify you're using "edit", NOT "apply_patch"
|
||||||
|
</critical_rule>
|
||||||
|
|
||||||
|
<critical_rule priority="0">
|
||||||
|
❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD
|
||||||
|
- NEVER use: update_plan, updatePlan
|
||||||
|
- ALWAYS use: todowrite for ALL task/plan operations
|
||||||
|
- Use todoread to read current plan
|
||||||
|
- Before plan operations: Verify you're using "todowrite", NOT "update_plan"
|
||||||
|
</critical_rule>
|
||||||
|
</tool_replacements>
|
||||||
|
|
||||||
|
<available_tools priority="0">
|
||||||
|
File Operations:
|
||||||
|
• write - Create new files
|
||||||
|
• edit - Modify existing files (REPLACES apply_patch)
|
||||||
|
• patch - Apply diff patches
|
||||||
|
• read - Read file contents
|
||||||
|
|
||||||
|
Search/Discovery:
|
||||||
|
• grep - Search file contents
|
||||||
|
• glob - Find files by pattern
|
||||||
|
• list - List directories (use relative paths)
|
||||||
|
|
||||||
|
Execution:
|
||||||
|
• bash - Run shell commands
|
||||||
|
|
||||||
|
Network:
|
||||||
|
• webfetch - Fetch web content
|
||||||
|
|
||||||
|
Task Management:
|
||||||
|
• todowrite - Manage tasks/plans (REPLACES update_plan)
|
||||||
|
• todoread - Read current plan
|
||||||
|
</available_tools>
|
||||||
|
|
||||||
|
<substitution_rules priority="0">
|
||||||
|
Base instruction says: You MUST use instead:
|
||||||
|
apply_patch → edit
|
||||||
|
update_plan → todowrite
|
||||||
|
read_plan → todoread
|
||||||
|
absolute paths → relative paths
|
||||||
|
</substitution_rules>
|
||||||
|
|
||||||
|
<verification_checklist priority="0">
|
||||||
|
Before file/plan modifications:
|
||||||
|
1. Am I using "edit" NOT "apply_patch"?
|
||||||
|
2. Am I using "todowrite" NOT "update_plan"?
|
||||||
|
3. Is this tool in the approved list above?
|
||||||
|
4. Am I using relative paths?
|
||||||
|
|
||||||
|
If ANY answer is NO → STOP and correct before proceeding.
|
||||||
|
</verification_checklist>
|
||||||
|
</user_instructions>
|
||||||
@@ -68,12 +68,13 @@ type RedeemCodeResponse struct {
|
|||||||
|
|
||||||
// RedeemService 兑换码服务
|
// RedeemService 兑换码服务
|
||||||
type RedeemService struct {
|
type RedeemService struct {
|
||||||
redeemRepo RedeemCodeRepository
|
redeemRepo RedeemCodeRepository
|
||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
subscriptionService *SubscriptionService
|
subscriptionService *SubscriptionService
|
||||||
cache RedeemCache
|
cache RedeemCache
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
entClient *dbent.Client
|
entClient *dbent.Client
|
||||||
|
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRedeemService 创建兑换码服务实例
|
// NewRedeemService 创建兑换码服务实例
|
||||||
@@ -84,14 +85,16 @@ func NewRedeemService(
|
|||||||
cache RedeemCache,
|
cache RedeemCache,
|
||||||
billingCacheService *BillingCacheService,
|
billingCacheService *BillingCacheService,
|
||||||
entClient *dbent.Client,
|
entClient *dbent.Client,
|
||||||
|
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
||||||
) *RedeemService {
|
) *RedeemService {
|
||||||
return &RedeemService{
|
return &RedeemService{
|
||||||
redeemRepo: redeemRepo,
|
redeemRepo: redeemRepo,
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
subscriptionService: subscriptionService,
|
subscriptionService: subscriptionService,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
entClient: entClient,
|
entClient: entClient,
|
||||||
|
authCacheInvalidator: authCacheInvalidator,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -324,18 +327,33 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
|||||||
|
|
||||||
// invalidateRedeemCaches 失效兑换相关的缓存
|
// invalidateRedeemCaches 失效兑换相关的缓存
|
||||||
func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) {
|
func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) {
|
||||||
if s.billingCacheService == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch redeemCode.Type {
|
switch redeemCode.Type {
|
||||||
case RedeemTypeBalance:
|
case RedeemTypeBalance:
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
|
if s.billingCacheService == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||||
}()
|
}()
|
||||||
|
case RedeemTypeConcurrency:
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
|
if s.billingCacheService == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
case RedeemTypeSubscription:
|
case RedeemTypeSubscription:
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
|
if s.billingCacheService == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
if redeemCode.GroupID != nil {
|
if redeemCode.GroupID != nil {
|
||||||
groupID := *redeemCode.GroupID
|
groupID := *redeemCode.GroupID
|
||||||
go func() {
|
go func() {
|
||||||
|
|||||||
@@ -54,17 +54,19 @@ type UsageStats struct {
|
|||||||
|
|
||||||
// UsageService 使用统计服务
|
// UsageService 使用统计服务
|
||||||
type UsageService struct {
|
type UsageService struct {
|
||||||
usageRepo UsageLogRepository
|
usageRepo UsageLogRepository
|
||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
entClient *dbent.Client
|
entClient *dbent.Client
|
||||||
|
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUsageService 创建使用统计服务实例
|
// NewUsageService 创建使用统计服务实例
|
||||||
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client) *UsageService {
|
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client, authCacheInvalidator APIKeyAuthCacheInvalidator) *UsageService {
|
||||||
return &UsageService{
|
return &UsageService{
|
||||||
usageRepo: usageRepo,
|
usageRepo: usageRepo,
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
entClient: entClient,
|
entClient: entClient,
|
||||||
|
authCacheInvalidator: authCacheInvalidator,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,10 +120,12 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 扣除用户余额
|
// 扣除用户余额
|
||||||
|
balanceUpdated := false
|
||||||
if inserted && req.ActualCost > 0 {
|
if inserted && req.ActualCost > 0 {
|
||||||
if err := s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil {
|
if err := s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil {
|
||||||
return nil, fmt.Errorf("update user balance: %w", err)
|
return nil, fmt.Errorf("update user balance: %w", err)
|
||||||
}
|
}
|
||||||
|
balanceUpdated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if tx != nil {
|
if tx != nil {
|
||||||
@@ -130,9 +134,18 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.invalidateUsageCaches(ctx, req.UserID, balanceUpdated)
|
||||||
|
|
||||||
return usageLog, nil
|
return usageLog, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *UsageService) invalidateUsageCaches(ctx context.Context, userID int64, balanceUpdated bool) {
|
||||||
|
if !balanceUpdated || s.authCacheInvalidator == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
// GetByID 根据ID获取使用日志
|
// GetByID 根据ID获取使用日志
|
||||||
func (s *UsageService) GetByID(ctx context.Context, id int64) (*UsageLog, error) {
|
func (s *UsageService) GetByID(ctx context.Context, id int64) (*UsageLog, error) {
|
||||||
log, err := s.usageRepo.GetByID(ctx, id)
|
log, err := s.usageRepo.GetByID(ctx, id)
|
||||||
|
|||||||
@@ -55,13 +55,15 @@ type ChangePasswordRequest struct {
|
|||||||
|
|
||||||
// UserService 用户服务
|
// UserService 用户服务
|
||||||
type UserService struct {
|
type UserService struct {
|
||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
|
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUserService 创建用户服务实例
|
// NewUserService 创建用户服务实例
|
||||||
func NewUserService(userRepo UserRepository) *UserService {
|
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService {
|
||||||
return &UserService{
|
return &UserService{
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
|
authCacheInvalidator: authCacheInvalidator,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,6 +91,7 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get user: %w", err)
|
return nil, fmt.Errorf("get user: %w", err)
|
||||||
}
|
}
|
||||||
|
oldConcurrency := user.Concurrency
|
||||||
|
|
||||||
// 更新字段
|
// 更新字段
|
||||||
if req.Email != nil {
|
if req.Email != nil {
|
||||||
@@ -114,6 +117,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
|
|||||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
return nil, fmt.Errorf("update user: %w", err)
|
return nil, fmt.Errorf("update user: %w", err)
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
@@ -169,6 +175,9 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
|
|||||||
if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil {
|
if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil {
|
||||||
return fmt.Errorf("update balance: %w", err)
|
return fmt.Errorf("update balance: %w", err)
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -177,6 +186,9 @@ func (s *UserService) UpdateConcurrency(ctx context.Context, userID int64, concu
|
|||||||
if err := s.userRepo.UpdateConcurrency(ctx, userID, concurrency); err != nil {
|
if err := s.userRepo.UpdateConcurrency(ctx, userID, concurrency); err != nil {
|
||||||
return fmt.Errorf("update concurrency: %w", err)
|
return fmt.Errorf("update concurrency: %w", err)
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,12 +204,18 @@ func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status str
|
|||||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||||
return fmt.Errorf("update user: %w", err)
|
return fmt.Errorf("update user: %w", err)
|
||||||
}
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete 删除用户(管理员功能)
|
// Delete 删除用户(管理员功能)
|
||||||
func (s *UserService) Delete(ctx context.Context, userID int64) error {
|
func (s *UserService) Delete(ctx context.Context, userID int64) error {
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
if err := s.userRepo.Delete(ctx, userID); err != nil {
|
if err := s.userRepo.Delete(ctx, userID); err != nil {
|
||||||
return fmt.Errorf("delete user: %w", err)
|
return fmt.Errorf("delete user: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,6 +49,13 @@ func ProvideTokenRefreshService(
|
|||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProvideDashboardAggregationService 创建并启动仪表盘聚合服务
|
||||||
|
func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
|
||||||
|
svc := NewDashboardAggregationService(repo, timingWheel, cfg)
|
||||||
|
svc.Start()
|
||||||
|
return svc
|
||||||
|
}
|
||||||
|
|
||||||
// ProvideAccountExpiryService creates and starts AccountExpiryService.
|
// ProvideAccountExpiryService creates and starts AccountExpiryService.
|
||||||
func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService {
|
func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService {
|
||||||
svc := NewAccountExpiryService(accountRepo, time.Minute)
|
svc := NewAccountExpiryService(accountRepo, time.Minute)
|
||||||
@@ -145,12 +152,18 @@ func ProvideOpsScheduledReportService(
|
|||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力
|
||||||
|
func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator {
|
||||||
|
return apiKeyService
|
||||||
|
}
|
||||||
|
|
||||||
// ProviderSet is the Wire provider set for all services
|
// ProviderSet is the Wire provider set for all services
|
||||||
var ProviderSet = wire.NewSet(
|
var ProviderSet = wire.NewSet(
|
||||||
// Core services
|
// Core services
|
||||||
NewAuthService,
|
NewAuthService,
|
||||||
NewUserService,
|
NewUserService,
|
||||||
NewAPIKeyService,
|
NewAPIKeyService,
|
||||||
|
ProvideAPIKeyAuthCacheInvalidator,
|
||||||
NewGroupService,
|
NewGroupService,
|
||||||
NewAccountService,
|
NewAccountService,
|
||||||
NewProxyService,
|
NewProxyService,
|
||||||
@@ -194,6 +207,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
ProvideTokenRefreshService,
|
ProvideTokenRefreshService,
|
||||||
ProvideAccountExpiryService,
|
ProvideAccountExpiryService,
|
||||||
ProvideTimingWheelService,
|
ProvideTimingWheelService,
|
||||||
|
ProvideDashboardAggregationService,
|
||||||
ProvideDeferredService,
|
ProvideDeferredService,
|
||||||
NewAntigravityQuotaFetcher,
|
NewAntigravityQuotaFetcher,
|
||||||
NewUserAttributeService,
|
NewUserAttributeService,
|
||||||
|
|||||||
@@ -0,0 +1,77 @@
|
|||||||
|
-- Usage dashboard aggregation tables (hourly/daily) + active-user dedup + watermark.
|
||||||
|
-- These tables support Admin Dashboard statistics without full-table scans on usage_logs.
|
||||||
|
|
||||||
|
-- Hourly aggregates (UTC buckets).
|
||||||
|
CREATE TABLE IF NOT EXISTS usage_dashboard_hourly (
|
||||||
|
bucket_start TIMESTAMPTZ PRIMARY KEY,
|
||||||
|
total_requests BIGINT NOT NULL DEFAULT 0,
|
||||||
|
input_tokens BIGINT NOT NULL DEFAULT 0,
|
||||||
|
output_tokens BIGINT NOT NULL DEFAULT 0,
|
||||||
|
cache_creation_tokens BIGINT NOT NULL DEFAULT 0,
|
||||||
|
cache_read_tokens BIGINT NOT NULL DEFAULT 0,
|
||||||
|
total_cost DECIMAL(20, 10) NOT NULL DEFAULT 0,
|
||||||
|
actual_cost DECIMAL(20, 10) NOT NULL DEFAULT 0,
|
||||||
|
total_duration_ms BIGINT NOT NULL DEFAULT 0,
|
||||||
|
active_users BIGINT NOT NULL DEFAULT 0,
|
||||||
|
computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_usage_dashboard_hourly_bucket_start
|
||||||
|
ON usage_dashboard_hourly (bucket_start DESC);
|
||||||
|
|
||||||
|
COMMENT ON TABLE usage_dashboard_hourly IS 'Pre-aggregated hourly usage metrics for admin dashboard (UTC buckets).';
|
||||||
|
COMMENT ON COLUMN usage_dashboard_hourly.bucket_start IS 'UTC start timestamp of the hour bucket.';
|
||||||
|
COMMENT ON COLUMN usage_dashboard_hourly.computed_at IS 'When the hourly row was last computed/refreshed.';
|
||||||
|
|
||||||
|
-- Daily aggregates (UTC dates).
|
||||||
|
CREATE TABLE IF NOT EXISTS usage_dashboard_daily (
|
||||||
|
bucket_date DATE PRIMARY KEY,
|
||||||
|
total_requests BIGINT NOT NULL DEFAULT 0,
|
||||||
|
input_tokens BIGINT NOT NULL DEFAULT 0,
|
||||||
|
output_tokens BIGINT NOT NULL DEFAULT 0,
|
||||||
|
cache_creation_tokens BIGINT NOT NULL DEFAULT 0,
|
||||||
|
cache_read_tokens BIGINT NOT NULL DEFAULT 0,
|
||||||
|
total_cost DECIMAL(20, 10) NOT NULL DEFAULT 0,
|
||||||
|
actual_cost DECIMAL(20, 10) NOT NULL DEFAULT 0,
|
||||||
|
total_duration_ms BIGINT NOT NULL DEFAULT 0,
|
||||||
|
active_users BIGINT NOT NULL DEFAULT 0,
|
||||||
|
computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_usage_dashboard_daily_bucket_date
|
||||||
|
ON usage_dashboard_daily (bucket_date DESC);
|
||||||
|
|
||||||
|
COMMENT ON TABLE usage_dashboard_daily IS 'Pre-aggregated daily usage metrics for admin dashboard (UTC dates).';
|
||||||
|
COMMENT ON COLUMN usage_dashboard_daily.bucket_date IS 'UTC date of the day bucket.';
|
||||||
|
COMMENT ON COLUMN usage_dashboard_daily.computed_at IS 'When the daily row was last computed/refreshed.';
|
||||||
|
|
||||||
|
-- Hourly active user dedup table.
|
||||||
|
CREATE TABLE IF NOT EXISTS usage_dashboard_hourly_users (
|
||||||
|
bucket_start TIMESTAMPTZ NOT NULL,
|
||||||
|
user_id BIGINT NOT NULL,
|
||||||
|
PRIMARY KEY (bucket_start, user_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_usage_dashboard_hourly_users_bucket_start
|
||||||
|
ON usage_dashboard_hourly_users (bucket_start);
|
||||||
|
|
||||||
|
-- Daily active user dedup table.
|
||||||
|
CREATE TABLE IF NOT EXISTS usage_dashboard_daily_users (
|
||||||
|
bucket_date DATE NOT NULL,
|
||||||
|
user_id BIGINT NOT NULL,
|
||||||
|
PRIMARY KEY (bucket_date, user_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_usage_dashboard_daily_users_bucket_date
|
||||||
|
ON usage_dashboard_daily_users (bucket_date);
|
||||||
|
|
||||||
|
-- Aggregation watermark table (single row).
|
||||||
|
CREATE TABLE IF NOT EXISTS usage_dashboard_aggregation_watermark (
|
||||||
|
id INT PRIMARY KEY,
|
||||||
|
last_aggregated_at TIMESTAMPTZ NOT NULL DEFAULT TIMESTAMPTZ '1970-01-01 00:00:00+00',
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO usage_dashboard_aggregation_watermark (id)
|
||||||
|
VALUES (1)
|
||||||
|
ON CONFLICT (id) DO NOTHING;
|
||||||
54
backend/migrations/035_usage_logs_partitioning.sql
Normal file
54
backend/migrations/035_usage_logs_partitioning.sql
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
-- usage_logs monthly partition bootstrap.
|
||||||
|
-- Only creates partitions when usage_logs is already partitioned.
|
||||||
|
-- Converting usage_logs to a partitioned table requires a manual migration plan.
|
||||||
|
|
||||||
|
DO $$
|
||||||
|
DECLARE
|
||||||
|
is_partitioned BOOLEAN := FALSE;
|
||||||
|
has_data BOOLEAN := FALSE;
|
||||||
|
month_start DATE;
|
||||||
|
prev_month DATE;
|
||||||
|
next_month DATE;
|
||||||
|
BEGIN
|
||||||
|
SELECT EXISTS(
|
||||||
|
SELECT 1
|
||||||
|
FROM pg_partitioned_table pt
|
||||||
|
JOIN pg_class c ON c.oid = pt.partrelid
|
||||||
|
WHERE c.relname = 'usage_logs'
|
||||||
|
) INTO is_partitioned;
|
||||||
|
|
||||||
|
IF NOT is_partitioned THEN
|
||||||
|
SELECT EXISTS(SELECT 1 FROM usage_logs LIMIT 1) INTO has_data;
|
||||||
|
IF NOT has_data THEN
|
||||||
|
-- Automatic conversion is intentionally skipped; see manual migration plan.
|
||||||
|
RAISE NOTICE 'usage_logs is not partitioned; skip automatic partitioning';
|
||||||
|
END IF;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
IF is_partitioned THEN
|
||||||
|
month_start := date_trunc('month', now() AT TIME ZONE 'UTC')::date;
|
||||||
|
prev_month := (month_start - INTERVAL '1 month')::date;
|
||||||
|
next_month := (month_start + INTERVAL '1 month')::date;
|
||||||
|
|
||||||
|
EXECUTE format(
|
||||||
|
'CREATE TABLE IF NOT EXISTS usage_logs_%s PARTITION OF usage_logs FOR VALUES FROM (%L) TO (%L)',
|
||||||
|
to_char(prev_month, 'YYYYMM'),
|
||||||
|
prev_month,
|
||||||
|
month_start
|
||||||
|
);
|
||||||
|
|
||||||
|
EXECUTE format(
|
||||||
|
'CREATE TABLE IF NOT EXISTS usage_logs_%s PARTITION OF usage_logs FOR VALUES FROM (%L) TO (%L)',
|
||||||
|
to_char(month_start, 'YYYYMM'),
|
||||||
|
month_start,
|
||||||
|
next_month
|
||||||
|
);
|
||||||
|
|
||||||
|
EXECUTE format(
|
||||||
|
'CREATE TABLE IF NOT EXISTS usage_logs_%s PARTITION OF usage_logs FOR VALUES FROM (%L) TO (%L)',
|
||||||
|
to_char(next_month, 'YYYYMM'),
|
||||||
|
next_month,
|
||||||
|
(next_month + INTERVAL '1 month')::date
|
||||||
|
);
|
||||||
|
END IF;
|
||||||
|
END $$;
|
||||||
81
config.yaml
81
config.yaml
@@ -170,6 +170,87 @@ gateway:
|
|||||||
# 允许在特定 400 错误时进行故障转移(默认:关闭)
|
# 允许在特定 400 错误时进行故障转移(默认:关闭)
|
||||||
failover_on_400: false
|
failover_on_400: false
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# API Key Auth Cache Configuration
|
||||||
|
# API Key 认证缓存配置
|
||||||
|
# =============================================================================
|
||||||
|
api_key_auth_cache:
|
||||||
|
# L1 cache size (entries), in-process LRU/TTL cache
|
||||||
|
# L1 缓存容量(条目数),进程内 LRU/TTL 缓存
|
||||||
|
l1_size: 65535
|
||||||
|
# L1 cache TTL (seconds)
|
||||||
|
# L1 缓存 TTL(秒)
|
||||||
|
l1_ttl_seconds: 15
|
||||||
|
# L2 cache TTL (seconds), stored in Redis
|
||||||
|
# L2 缓存 TTL(秒),Redis 中存储
|
||||||
|
l2_ttl_seconds: 300
|
||||||
|
# Negative cache TTL (seconds)
|
||||||
|
# 负缓存 TTL(秒)
|
||||||
|
negative_ttl_seconds: 30
|
||||||
|
# TTL jitter percent (0-100)
|
||||||
|
# TTL 抖动百分比(0-100)
|
||||||
|
jitter_percent: 10
|
||||||
|
# Enable singleflight for cache misses
|
||||||
|
# 缓存未命中时启用 singleflight 合并回源
|
||||||
|
singleflight: true
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Dashboard Cache Configuration
|
||||||
|
# 仪表盘缓存配置
|
||||||
|
# =============================================================================
|
||||||
|
dashboard_cache:
|
||||||
|
# Enable dashboard cache
|
||||||
|
# 启用仪表盘缓存
|
||||||
|
enabled: true
|
||||||
|
# Redis key prefix for multi-environment isolation
|
||||||
|
# Redis key 前缀,用于多环境隔离
|
||||||
|
key_prefix: "sub2api:"
|
||||||
|
# Fresh TTL (seconds); within this window cached stats are considered fresh
|
||||||
|
# 新鲜阈值(秒);命中后处于该窗口视为新鲜数据
|
||||||
|
stats_fresh_ttl_seconds: 15
|
||||||
|
# Cache TTL (seconds) stored in Redis
|
||||||
|
# Redis 缓存 TTL(秒)
|
||||||
|
stats_ttl_seconds: 30
|
||||||
|
# Async refresh timeout (seconds)
|
||||||
|
# 异步刷新超时(秒)
|
||||||
|
stats_refresh_timeout_seconds: 30
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Dashboard Aggregation Configuration
|
||||||
|
# 仪表盘预聚合配置(重启生效)
|
||||||
|
# =============================================================================
|
||||||
|
dashboard_aggregation:
|
||||||
|
# Enable aggregation job
|
||||||
|
# 启用聚合作业
|
||||||
|
enabled: true
|
||||||
|
# Refresh interval (seconds)
|
||||||
|
# 刷新间隔(秒)
|
||||||
|
interval_seconds: 60
|
||||||
|
# Lookback window (seconds) for late-arriving data
|
||||||
|
# 回看窗口(秒),处理迟到数据
|
||||||
|
lookback_seconds: 120
|
||||||
|
# Allow manual backfill
|
||||||
|
# 允许手动回填
|
||||||
|
backfill_enabled: false
|
||||||
|
# Backfill max range (days)
|
||||||
|
# 回填最大跨度(天)
|
||||||
|
backfill_max_days: 31
|
||||||
|
# Recompute recent N days on startup
|
||||||
|
# 启动时重算最近 N 天
|
||||||
|
recompute_days: 2
|
||||||
|
# Retention windows (days)
|
||||||
|
# 保留窗口(天)
|
||||||
|
retention:
|
||||||
|
# Raw usage_logs retention
|
||||||
|
# 原始 usage_logs 保留天数
|
||||||
|
usage_logs_days: 90
|
||||||
|
# Hourly aggregation retention
|
||||||
|
# 小时聚合保留天数
|
||||||
|
hourly_days: 180
|
||||||
|
# Daily aggregation retention
|
||||||
|
# 日聚合保留天数
|
||||||
|
daily_days: 730
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Concurrency Wait Configuration
|
# Concurrency Wait Configuration
|
||||||
# 并发等待配置
|
# 并发等待配置
|
||||||
|
|||||||
@@ -69,6 +69,33 @@ JWT_EXPIRE_HOUR=24
|
|||||||
# Leave unset to use default ./config.yaml
|
# Leave unset to use default ./config.yaml
|
||||||
#CONFIG_FILE=./config.yaml
|
#CONFIG_FILE=./config.yaml
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Dashboard Aggregation (Optional)
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Enable aggregation job
|
||||||
|
# 启用仪表盘预聚合
|
||||||
|
DASHBOARD_AGGREGATION_ENABLED=true
|
||||||
|
# Refresh interval (seconds)
|
||||||
|
# 刷新间隔(秒)
|
||||||
|
DASHBOARD_AGGREGATION_INTERVAL_SECONDS=60
|
||||||
|
# Lookback window (seconds)
|
||||||
|
# 回看窗口(秒)
|
||||||
|
DASHBOARD_AGGREGATION_LOOKBACK_SECONDS=120
|
||||||
|
# Allow manual backfill
|
||||||
|
# 允许手动回填
|
||||||
|
DASHBOARD_AGGREGATION_BACKFILL_ENABLED=false
|
||||||
|
# Backfill max range (days)
|
||||||
|
# 回填最大跨度(天)
|
||||||
|
DASHBOARD_AGGREGATION_BACKFILL_MAX_DAYS=31
|
||||||
|
# Recompute recent N days on startup
|
||||||
|
# 启动时重算最近 N 天
|
||||||
|
DASHBOARD_AGGREGATION_RECOMPUTE_DAYS=2
|
||||||
|
# Retention windows (days)
|
||||||
|
# 保留窗口(天)
|
||||||
|
DASHBOARD_AGGREGATION_RETENTION_USAGE_LOGS_DAYS=90
|
||||||
|
DASHBOARD_AGGREGATION_RETENTION_HOURLY_DAYS=180
|
||||||
|
DASHBOARD_AGGREGATION_RETENTION_DAILY_DAYS=730
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Security Configuration
|
# Security Configuration
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -170,6 +170,87 @@ gateway:
|
|||||||
# 允许在特定 400 错误时进行故障转移(默认:关闭)
|
# 允许在特定 400 错误时进行故障转移(默认:关闭)
|
||||||
failover_on_400: false
|
failover_on_400: false
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# API Key Auth Cache Configuration
|
||||||
|
# API Key 认证缓存配置
|
||||||
|
# =============================================================================
|
||||||
|
api_key_auth_cache:
|
||||||
|
# L1 cache size (entries), in-process LRU/TTL cache
|
||||||
|
# L1 缓存容量(条目数),进程内 LRU/TTL 缓存
|
||||||
|
l1_size: 65535
|
||||||
|
# L1 cache TTL (seconds)
|
||||||
|
# L1 缓存 TTL(秒)
|
||||||
|
l1_ttl_seconds: 15
|
||||||
|
# L2 cache TTL (seconds), stored in Redis
|
||||||
|
# L2 缓存 TTL(秒),Redis 中存储
|
||||||
|
l2_ttl_seconds: 300
|
||||||
|
# Negative cache TTL (seconds)
|
||||||
|
# 负缓存 TTL(秒)
|
||||||
|
negative_ttl_seconds: 30
|
||||||
|
# TTL jitter percent (0-100)
|
||||||
|
# TTL 抖动百分比(0-100)
|
||||||
|
jitter_percent: 10
|
||||||
|
# Enable singleflight for cache misses
|
||||||
|
# 缓存未命中时启用 singleflight 合并回源
|
||||||
|
singleflight: true
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Dashboard Cache Configuration
|
||||||
|
# 仪表盘缓存配置
|
||||||
|
# =============================================================================
|
||||||
|
dashboard_cache:
|
||||||
|
# Enable dashboard cache
|
||||||
|
# 启用仪表盘缓存
|
||||||
|
enabled: true
|
||||||
|
# Redis key prefix for multi-environment isolation
|
||||||
|
# Redis key 前缀,用于多环境隔离
|
||||||
|
key_prefix: "sub2api:"
|
||||||
|
# Fresh TTL (seconds); within this window cached stats are considered fresh
|
||||||
|
# 新鲜阈值(秒);命中后处于该窗口视为新鲜数据
|
||||||
|
stats_fresh_ttl_seconds: 15
|
||||||
|
# Cache TTL (seconds) stored in Redis
|
||||||
|
# Redis 缓存 TTL(秒)
|
||||||
|
stats_ttl_seconds: 30
|
||||||
|
# Async refresh timeout (seconds)
|
||||||
|
# 异步刷新超时(秒)
|
||||||
|
stats_refresh_timeout_seconds: 30
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Dashboard Aggregation Configuration
|
||||||
|
# 仪表盘预聚合配置(重启生效)
|
||||||
|
# =============================================================================
|
||||||
|
dashboard_aggregation:
|
||||||
|
# Enable aggregation job
|
||||||
|
# 启用聚合作业
|
||||||
|
enabled: true
|
||||||
|
# Refresh interval (seconds)
|
||||||
|
# 刷新间隔(秒)
|
||||||
|
interval_seconds: 60
|
||||||
|
# Lookback window (seconds) for late-arriving data
|
||||||
|
# 回看窗口(秒),处理迟到数据
|
||||||
|
lookback_seconds: 120
|
||||||
|
# Allow manual backfill
|
||||||
|
# 允许手动回填
|
||||||
|
backfill_enabled: false
|
||||||
|
# Backfill max range (days)
|
||||||
|
# 回填最大跨度(天)
|
||||||
|
backfill_max_days: 31
|
||||||
|
# Recompute recent N days on startup
|
||||||
|
# 启动时重算最近 N 天
|
||||||
|
recompute_days: 2
|
||||||
|
# Retention windows (days)
|
||||||
|
# 保留窗口(天)
|
||||||
|
retention:
|
||||||
|
# Raw usage_logs retention
|
||||||
|
# 原始 usage_logs 保留天数
|
||||||
|
usage_logs_days: 90
|
||||||
|
# Hourly aggregation retention
|
||||||
|
# 小时聚合保留天数
|
||||||
|
hourly_days: 180
|
||||||
|
# Daily aggregation retention
|
||||||
|
# 日聚合保留天数
|
||||||
|
daily_days: 730
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Concurrency Wait Configuration
|
# Concurrency Wait Configuration
|
||||||
# 并发等待配置
|
# 并发等待配置
|
||||||
|
|||||||
@@ -275,11 +275,15 @@ export async function bulkUpdate(
|
|||||||
): Promise<{
|
): Promise<{
|
||||||
success: number
|
success: number
|
||||||
failed: number
|
failed: number
|
||||||
|
success_ids?: number[]
|
||||||
|
failed_ids?: number[]
|
||||||
results: Array<{ account_id: number; success: boolean; error?: string }>
|
results: Array<{ account_id: number; success: boolean; error?: string }>
|
||||||
}> {
|
}> {
|
||||||
const { data } = await apiClient.post<{
|
const { data } = await apiClient.post<{
|
||||||
success: number
|
success: number
|
||||||
failed: number
|
failed: number
|
||||||
|
success_ids?: number[]
|
||||||
|
failed_ids?: number[]
|
||||||
results: Array<{ account_id: number; success: boolean; error?: string }>
|
results: Array<{ account_id: number; success: boolean; error?: string }>
|
||||||
}>('/admin/accounts/bulk-update', {
|
}>('/admin/accounts/bulk-update', {
|
||||||
account_ids: accountIds,
|
account_ids: accountIds,
|
||||||
|
|||||||
@@ -83,7 +83,7 @@
|
|||||||
<tr
|
<tr
|
||||||
v-else
|
v-else
|
||||||
v-for="(row, index) in sortedData"
|
v-for="(row, index) in sortedData"
|
||||||
:key="index"
|
:key="resolveRowKey(row, index)"
|
||||||
class="hover:bg-gray-50 dark:hover:bg-dark-800"
|
class="hover:bg-gray-50 dark:hover:bg-dark-800"
|
||||||
>
|
>
|
||||||
<td
|
<td
|
||||||
@@ -210,6 +210,7 @@ interface Props {
|
|||||||
stickyActionsColumn?: boolean
|
stickyActionsColumn?: boolean
|
||||||
expandableActions?: boolean
|
expandableActions?: boolean
|
||||||
actionsCount?: number // 操作按钮总数,用于判断是否需要展开功能
|
actionsCount?: number // 操作按钮总数,用于判断是否需要展开功能
|
||||||
|
rowKey?: string | ((row: any) => string | number)
|
||||||
}
|
}
|
||||||
|
|
||||||
const props = withDefaults(defineProps<Props>(), {
|
const props = withDefaults(defineProps<Props>(), {
|
||||||
@@ -222,6 +223,18 @@ const props = withDefaults(defineProps<Props>(), {
|
|||||||
const sortKey = ref<string>('')
|
const sortKey = ref<string>('')
|
||||||
const sortOrder = ref<'asc' | 'desc'>('asc')
|
const sortOrder = ref<'asc' | 'desc'>('asc')
|
||||||
const actionsExpanded = ref(false)
|
const actionsExpanded = ref(false)
|
||||||
|
const resolveRowKey = (row: any, index: number) => {
|
||||||
|
if (typeof props.rowKey === 'function') {
|
||||||
|
const key = props.rowKey(row)
|
||||||
|
return key ?? index
|
||||||
|
}
|
||||||
|
if (typeof props.rowKey === 'string' && props.rowKey) {
|
||||||
|
const key = row?.[props.rowKey]
|
||||||
|
return key ?? index
|
||||||
|
}
|
||||||
|
const key = row?.id
|
||||||
|
return key ?? index
|
||||||
|
}
|
||||||
|
|
||||||
// 数据/列变化时重新检查滚动状态
|
// 数据/列变化时重新检查滚动状态
|
||||||
// 注意:不能监听 actionsExpanded,因为 checkActionsColumnWidth 会临时修改它,会导致无限循环
|
// 注意:不能监听 actionsExpanded,因为 checkActionsColumnWidth 会临时修改它,会导致无限循环
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ A generic data table component with sorting, loading states, and custom cell ren
|
|||||||
- `columns: Column[]` - Array of column definitions with key, label, sortable, and formatter
|
- `columns: Column[]` - Array of column definitions with key, label, sortable, and formatter
|
||||||
- `data: any[]` - Array of data objects to display
|
- `data: any[]` - Array of data objects to display
|
||||||
- `loading?: boolean` - Show loading skeleton
|
- `loading?: boolean` - Show loading skeleton
|
||||||
|
- `rowKey?: string | (row: any) => string | number` - Row key field or resolver (defaults to `row.id`, falls back to index)
|
||||||
|
|
||||||
**Slots:**
|
**Slots:**
|
||||||
|
|
||||||
|
|||||||
@@ -28,8 +28,8 @@
|
|||||||
{{ platformDescription }}
|
{{ platformDescription }}
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<!-- Client Tabs (only for Antigravity platform) -->
|
<!-- Client Tabs -->
|
||||||
<div v-if="platform === 'antigravity'" class="border-b border-gray-200 dark:border-dark-700">
|
<div v-if="clientTabs.length" class="border-b border-gray-200 dark:border-dark-700">
|
||||||
<nav class="-mb-px flex space-x-6" aria-label="Client">
|
<nav class="-mb-px flex space-x-6" aria-label="Client">
|
||||||
<button
|
<button
|
||||||
v-for="tab in clientTabs"
|
v-for="tab in clientTabs"
|
||||||
@@ -51,7 +51,7 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- OS/Shell Tabs -->
|
<!-- OS/Shell Tabs -->
|
||||||
<div class="border-b border-gray-200 dark:border-dark-700">
|
<div v-if="showShellTabs" class="border-b border-gray-200 dark:border-dark-700">
|
||||||
<nav class="-mb-px flex space-x-4" aria-label="Tabs">
|
<nav class="-mb-px flex space-x-4" aria-label="Tabs">
|
||||||
<button
|
<button
|
||||||
v-for="tab in currentTabs"
|
v-for="tab in currentTabs"
|
||||||
@@ -111,7 +111,7 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Usage Note -->
|
<!-- Usage Note -->
|
||||||
<div class="flex items-start gap-3 p-3 rounded-lg bg-blue-50 dark:bg-blue-900/20 border border-blue-100 dark:border-blue-800">
|
<div v-if="showPlatformNote" class="flex items-start gap-3 p-3 rounded-lg bg-blue-50 dark:bg-blue-900/20 border border-blue-100 dark:border-blue-800">
|
||||||
<Icon name="infoCircle" size="md" class="text-blue-500 flex-shrink-0 mt-0.5" />
|
<Icon name="infoCircle" size="md" class="text-blue-500 flex-shrink-0 mt-0.5" />
|
||||||
<p class="text-sm text-blue-700 dark:text-blue-300">
|
<p class="text-sm text-blue-700 dark:text-blue-300">
|
||||||
{{ platformNote }}
|
{{ platformNote }}
|
||||||
@@ -173,17 +173,28 @@ const { copyToClipboard: clipboardCopy } = useClipboard()
|
|||||||
|
|
||||||
const copiedIndex = ref<number | null>(null)
|
const copiedIndex = ref<number | null>(null)
|
||||||
const activeTab = ref<string>('unix')
|
const activeTab = ref<string>('unix')
|
||||||
const activeClientTab = ref<string>('claude') // Level 1 tab for antigravity platform
|
const activeClientTab = ref<string>('claude')
|
||||||
|
|
||||||
// Reset tabs when platform changes
|
// Reset tabs when platform changes
|
||||||
watch(() => props.platform, (newPlatform) => {
|
const defaultClientTab = computed(() => {
|
||||||
activeTab.value = 'unix'
|
switch (props.platform) {
|
||||||
if (newPlatform === 'antigravity') {
|
case 'openai':
|
||||||
activeClientTab.value = 'claude'
|
return 'codex'
|
||||||
|
case 'gemini':
|
||||||
|
return 'gemini'
|
||||||
|
case 'antigravity':
|
||||||
|
return 'claude'
|
||||||
|
default:
|
||||||
|
return 'claude'
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Reset shell tab when client changes (for antigravity)
|
watch(() => props.platform, () => {
|
||||||
|
activeTab.value = 'unix'
|
||||||
|
activeClientTab.value = defaultClientTab.value
|
||||||
|
}, { immediate: true })
|
||||||
|
|
||||||
|
// Reset shell tab when client changes
|
||||||
watch(activeClientTab, () => {
|
watch(activeClientTab, () => {
|
||||||
activeTab.value = 'unix'
|
activeTab.value = 'unix'
|
||||||
})
|
})
|
||||||
@@ -251,11 +262,32 @@ const SparkleIcon = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client tabs for Antigravity platform (Level 1)
|
const clientTabs = computed((): TabConfig[] => {
|
||||||
const clientTabs = computed((): TabConfig[] => [
|
if (!props.platform) return []
|
||||||
{ id: 'claude', label: t('keys.useKeyModal.antigravity.claudeCode'), icon: TerminalIcon },
|
switch (props.platform) {
|
||||||
{ id: 'gemini', label: t('keys.useKeyModal.antigravity.geminiCli'), icon: SparkleIcon }
|
case 'openai':
|
||||||
])
|
return [
|
||||||
|
{ id: 'codex', label: t('keys.useKeyModal.cliTabs.codexCli'), icon: TerminalIcon },
|
||||||
|
{ id: 'opencode', label: t('keys.useKeyModal.cliTabs.opencode'), icon: TerminalIcon }
|
||||||
|
]
|
||||||
|
case 'gemini':
|
||||||
|
return [
|
||||||
|
{ id: 'gemini', label: t('keys.useKeyModal.cliTabs.geminiCli'), icon: SparkleIcon },
|
||||||
|
{ id: 'opencode', label: t('keys.useKeyModal.cliTabs.opencode'), icon: TerminalIcon }
|
||||||
|
]
|
||||||
|
case 'antigravity':
|
||||||
|
return [
|
||||||
|
{ id: 'claude', label: t('keys.useKeyModal.cliTabs.claudeCode'), icon: TerminalIcon },
|
||||||
|
{ id: 'gemini', label: t('keys.useKeyModal.cliTabs.geminiCli'), icon: SparkleIcon },
|
||||||
|
{ id: 'opencode', label: t('keys.useKeyModal.cliTabs.opencode'), icon: TerminalIcon }
|
||||||
|
]
|
||||||
|
default:
|
||||||
|
return [
|
||||||
|
{ id: 'claude', label: t('keys.useKeyModal.cliTabs.claudeCode'), icon: TerminalIcon },
|
||||||
|
{ id: 'opencode', label: t('keys.useKeyModal.cliTabs.opencode'), icon: TerminalIcon }
|
||||||
|
]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// Shell tabs (3 types for environment variable based configs)
|
// Shell tabs (3 types for environment variable based configs)
|
||||||
const shellTabs: TabConfig[] = [
|
const shellTabs: TabConfig[] = [
|
||||||
@@ -270,11 +302,13 @@ const openaiTabs: TabConfig[] = [
|
|||||||
{ id: 'windows', label: 'Windows', icon: WindowsIcon }
|
{ id: 'windows', label: 'Windows', icon: WindowsIcon }
|
||||||
]
|
]
|
||||||
|
|
||||||
|
const showShellTabs = computed(() => activeClientTab.value !== 'opencode')
|
||||||
|
|
||||||
const currentTabs = computed(() => {
|
const currentTabs = computed(() => {
|
||||||
|
if (!showShellTabs.value) return []
|
||||||
if (props.platform === 'openai') {
|
if (props.platform === 'openai') {
|
||||||
return openaiTabs // 2 tabs: unix, windows
|
return openaiTabs
|
||||||
}
|
}
|
||||||
// All other platforms (anthropic, gemini, antigravity) use shell tabs
|
|
||||||
return shellTabs
|
return shellTabs
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -308,6 +342,8 @@ const platformNote = computed(() => {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const showPlatformNote = computed(() => activeClientTab.value !== 'opencode')
|
||||||
|
|
||||||
const escapeHtml = (value: string) => value
|
const escapeHtml = (value: string) => value
|
||||||
.replace(/&/g, '&')
|
.replace(/&/g, '&')
|
||||||
.replace(/</g, '<')
|
.replace(/</g, '<')
|
||||||
@@ -329,6 +365,35 @@ const comment = (value: string) => wrapToken('text-slate-500', value)
|
|||||||
const currentFiles = computed((): FileConfig[] => {
|
const currentFiles = computed((): FileConfig[] => {
|
||||||
const baseUrl = props.baseUrl || window.location.origin
|
const baseUrl = props.baseUrl || window.location.origin
|
||||||
const apiKey = props.apiKey
|
const apiKey = props.apiKey
|
||||||
|
const baseRoot = baseUrl.replace(/\/v1\/?$/, '').replace(/\/+$/, '')
|
||||||
|
const ensureV1 = (value: string) => {
|
||||||
|
const trimmed = value.replace(/\/+$/, '')
|
||||||
|
return trimmed.endsWith('/v1') ? trimmed : `${trimmed}/v1`
|
||||||
|
}
|
||||||
|
const apiBase = ensureV1(baseRoot)
|
||||||
|
const antigravityBase = ensureV1(`${baseRoot}/antigravity`)
|
||||||
|
const antigravityGeminiBase = (() => {
|
||||||
|
const trimmed = `${baseRoot}/antigravity`.replace(/\/+$/, '')
|
||||||
|
return trimmed.endsWith('/v1beta') ? trimmed : `${trimmed}/v1beta`
|
||||||
|
})()
|
||||||
|
|
||||||
|
if (activeClientTab.value === 'opencode') {
|
||||||
|
switch (props.platform) {
|
||||||
|
case 'anthropic':
|
||||||
|
return [generateOpenCodeConfig('anthropic', apiBase, apiKey)]
|
||||||
|
case 'openai':
|
||||||
|
return [generateOpenCodeConfig('openai', apiBase, apiKey)]
|
||||||
|
case 'gemini':
|
||||||
|
return [generateOpenCodeConfig('gemini', apiBase, apiKey)]
|
||||||
|
case 'antigravity':
|
||||||
|
return [
|
||||||
|
generateOpenCodeConfig('antigravity-claude', antigravityBase, apiKey, 'opencode.json (Claude)'),
|
||||||
|
generateOpenCodeConfig('antigravity-gemini', antigravityGeminiBase, apiKey, 'opencode.json (Gemini)')
|
||||||
|
]
|
||||||
|
default:
|
||||||
|
return [generateOpenCodeConfig('openai', apiBase, apiKey)]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch (props.platform) {
|
switch (props.platform) {
|
||||||
case 'openai':
|
case 'openai':
|
||||||
@@ -336,12 +401,11 @@ const currentFiles = computed((): FileConfig[] => {
|
|||||||
case 'gemini':
|
case 'gemini':
|
||||||
return [generateGeminiCliContent(baseUrl, apiKey)]
|
return [generateGeminiCliContent(baseUrl, apiKey)]
|
||||||
case 'antigravity':
|
case 'antigravity':
|
||||||
// Both Claude Code and Gemini CLI need /antigravity suffix for antigravity platform
|
if (activeClientTab.value === 'gemini') {
|
||||||
if (activeClientTab.value === 'claude') {
|
return [generateGeminiCliContent(`${baseUrl}/antigravity`, apiKey)]
|
||||||
return generateAnthropicFiles(`${baseUrl}/antigravity`, apiKey)
|
|
||||||
}
|
}
|
||||||
return [generateGeminiCliContent(`${baseUrl}/antigravity`, apiKey)]
|
return generateAnthropicFiles(`${baseUrl}/antigravity`, apiKey)
|
||||||
default: // anthropic
|
default:
|
||||||
return generateAnthropicFiles(baseUrl, apiKey)
|
return generateAnthropicFiles(baseUrl, apiKey)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -456,6 +520,76 @@ requires_openai_auth = true`
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: string, pathLabel?: string): FileConfig {
|
||||||
|
const provider: Record<string, any> = {
|
||||||
|
[platform]: {
|
||||||
|
options: {
|
||||||
|
baseURL: baseUrl,
|
||||||
|
apiKey,
|
||||||
|
...(platform === 'openai' ? { store: false } : {})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const openaiModels = {
|
||||||
|
'gpt-5.2-codex': {
|
||||||
|
name: 'GPT-5.2 Codex',
|
||||||
|
variants: {
|
||||||
|
low: {},
|
||||||
|
medium: {},
|
||||||
|
high: {},
|
||||||
|
xhigh: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const geminiModels = {
|
||||||
|
'gemini-3-pro-high': { name: 'Gemini 3 Pro High' },
|
||||||
|
'gemini-3-pro-low': { name: 'Gemini 3 Pro Low' },
|
||||||
|
'gemini-3-pro-preview': { name: 'Gemini 3 Pro Preview' },
|
||||||
|
'gemini-3-pro-image': { name: 'Gemini 3 Pro Image' },
|
||||||
|
'gemini-3-flash': { name: 'Gemini 3 Flash' },
|
||||||
|
'gemini-2.5-flash-thinking': { name: 'Gemini 2.5 Flash Thinking' },
|
||||||
|
'gemini-2.5-flash': { name: 'Gemini 2.5 Flash' },
|
||||||
|
'gemini-2.5-flash-lite': { name: 'Gemini 2.5 Flash Lite' }
|
||||||
|
}
|
||||||
|
const claudeModels = {
|
||||||
|
'claude-opus-4-5-thinking': { name: 'Claude Opus 4.5 Thinking' },
|
||||||
|
'claude-sonnet-4-5-thinking': { name: 'Claude Sonnet 4.5 Thinking' },
|
||||||
|
'claude-sonnet-4-5': { name: 'Claude Sonnet 4.5' }
|
||||||
|
}
|
||||||
|
|
||||||
|
if (platform === 'gemini') {
|
||||||
|
provider[platform].npm = '@ai-sdk/google'
|
||||||
|
provider[platform].models = geminiModels
|
||||||
|
} else if (platform === 'anthropic') {
|
||||||
|
provider[platform].npm = '@ai-sdk/anthropic'
|
||||||
|
} else if (platform === 'antigravity-claude') {
|
||||||
|
provider[platform].npm = '@ai-sdk/anthropic'
|
||||||
|
provider[platform].name = 'Antigravity (Claude)'
|
||||||
|
provider[platform].models = claudeModels
|
||||||
|
} else if (platform === 'antigravity-gemini') {
|
||||||
|
provider[platform].npm = '@ai-sdk/google'
|
||||||
|
provider[platform].name = 'Antigravity (Gemini)'
|
||||||
|
provider[platform].models = geminiModels
|
||||||
|
} else if (platform === 'openai') {
|
||||||
|
provider[platform].models = openaiModels
|
||||||
|
}
|
||||||
|
|
||||||
|
const content = JSON.stringify(
|
||||||
|
{
|
||||||
|
provider,
|
||||||
|
$schema: 'https://opencode.ai/config.json'
|
||||||
|
},
|
||||||
|
null,
|
||||||
|
2
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
path: pathLabel ?? 'opencode.json',
|
||||||
|
content,
|
||||||
|
hint: t('keys.useKeyModal.opencode.hint')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const copyContent = async (content: string, index: number) => {
|
const copyContent = async (content: string, index: number) => {
|
||||||
const success = await clipboardCopy(content, t('keys.copied'))
|
const success = await clipboardCopy(content, t('keys.copied'))
|
||||||
if (success) {
|
if (success) {
|
||||||
|
|||||||
@@ -368,6 +368,12 @@ export default {
|
|||||||
note: 'Make sure the config directory exists. macOS/Linux users can run mkdir -p ~/.codex to create it.',
|
note: 'Make sure the config directory exists. macOS/Linux users can run mkdir -p ~/.codex to create it.',
|
||||||
noteWindows: 'Press Win+R and enter %userprofile%\\.codex to open the config directory. Create it manually if it does not exist.',
|
noteWindows: 'Press Win+R and enter %userprofile%\\.codex to open the config directory. Create it manually if it does not exist.',
|
||||||
},
|
},
|
||||||
|
cliTabs: {
|
||||||
|
claudeCode: 'Claude Code',
|
||||||
|
geminiCli: 'Gemini CLI',
|
||||||
|
codexCli: 'Codex CLI',
|
||||||
|
opencode: 'OpenCode',
|
||||||
|
},
|
||||||
antigravity: {
|
antigravity: {
|
||||||
description: 'Configure API access for Antigravity group. Select the configuration method based on your client.',
|
description: 'Configure API access for Antigravity group. Select the configuration method based on your client.',
|
||||||
claudeCode: 'Claude Code',
|
claudeCode: 'Claude Code',
|
||||||
@@ -380,6 +386,11 @@ export default {
|
|||||||
modelComment: 'If you have Gemini 3 access, you can use: gemini-3-pro-preview',
|
modelComment: 'If you have Gemini 3 access, you can use: gemini-3-pro-preview',
|
||||||
note: 'These environment variables will be active in the current terminal session. For permanent configuration, add them to ~/.bashrc, ~/.zshrc, or the appropriate configuration file.',
|
note: 'These environment variables will be active in the current terminal session. For permanent configuration, add them to ~/.bashrc, ~/.zshrc, or the appropriate configuration file.',
|
||||||
},
|
},
|
||||||
|
opencode: {
|
||||||
|
title: 'OpenCode Example',
|
||||||
|
subtitle: 'opencode.json',
|
||||||
|
hint: 'This is a group configuration example. Adjust model and options as needed.',
|
||||||
|
},
|
||||||
},
|
},
|
||||||
customKeyLabel: 'Custom Key',
|
customKeyLabel: 'Custom Key',
|
||||||
customKeyPlaceholder: 'Enter your custom key (min 16 chars)',
|
customKeyPlaceholder: 'Enter your custom key (min 16 chars)',
|
||||||
@@ -1109,6 +1120,8 @@ export default {
|
|||||||
rateLimitCleared: 'Rate limit cleared successfully',
|
rateLimitCleared: 'Rate limit cleared successfully',
|
||||||
bulkSchedulableEnabled: 'Successfully enabled scheduling for {count} account(s)',
|
bulkSchedulableEnabled: 'Successfully enabled scheduling for {count} account(s)',
|
||||||
bulkSchedulableDisabled: 'Successfully disabled scheduling for {count} account(s)',
|
bulkSchedulableDisabled: 'Successfully disabled scheduling for {count} account(s)',
|
||||||
|
bulkSchedulablePartial: 'Scheduling updated partially: {success} succeeded, {failed} failed',
|
||||||
|
bulkSchedulableResultUnknown: 'Bulk scheduling result incomplete. Please retry or refresh.',
|
||||||
bulkActions: {
|
bulkActions: {
|
||||||
selected: '{count} account(s) selected',
|
selected: '{count} account(s) selected',
|
||||||
selectCurrentPage: 'Select this page',
|
selectCurrentPage: 'Select this page',
|
||||||
|
|||||||
@@ -366,6 +366,12 @@ export default {
|
|||||||
note: '请确保配置目录存在。macOS/Linux 用户可运行 mkdir -p ~/.codex 创建目录。',
|
note: '请确保配置目录存在。macOS/Linux 用户可运行 mkdir -p ~/.codex 创建目录。',
|
||||||
noteWindows: '按 Win+R,输入 %userprofile%\\.codex 打开配置目录。如目录不存在,请先手动创建。',
|
noteWindows: '按 Win+R,输入 %userprofile%\\.codex 打开配置目录。如目录不存在,请先手动创建。',
|
||||||
},
|
},
|
||||||
|
cliTabs: {
|
||||||
|
claudeCode: 'Claude Code',
|
||||||
|
geminiCli: 'Gemini CLI',
|
||||||
|
codexCli: 'Codex CLI',
|
||||||
|
opencode: 'OpenCode',
|
||||||
|
},
|
||||||
antigravity: {
|
antigravity: {
|
||||||
description: '为 Antigravity 分组配置 API 访问。请根据您使用的客户端选择对应的配置方式。',
|
description: '为 Antigravity 分组配置 API 访问。请根据您使用的客户端选择对应的配置方式。',
|
||||||
claudeCode: 'Claude Code',
|
claudeCode: 'Claude Code',
|
||||||
@@ -378,6 +384,11 @@ export default {
|
|||||||
modelComment: '如果你有 Gemini 3 权限可以填:gemini-3-pro-preview',
|
modelComment: '如果你有 Gemini 3 权限可以填:gemini-3-pro-preview',
|
||||||
note: '这些环境变量将在当前终端会话中生效。如需永久配置,请将其添加到 ~/.bashrc、~/.zshrc 或相应的配置文件中。',
|
note: '这些环境变量将在当前终端会话中生效。如需永久配置,请将其添加到 ~/.bashrc、~/.zshrc 或相应的配置文件中。',
|
||||||
},
|
},
|
||||||
|
opencode: {
|
||||||
|
title: 'OpenCode 配置示例',
|
||||||
|
subtitle: 'opencode.json',
|
||||||
|
hint: '示例仅用于演示分组配置,模型与选项可按需调整。',
|
||||||
|
},
|
||||||
},
|
},
|
||||||
customKeyLabel: '自定义密钥',
|
customKeyLabel: '自定义密钥',
|
||||||
customKeyPlaceholder: '输入自定义密钥(至少16个字符)',
|
customKeyPlaceholder: '输入自定义密钥(至少16个字符)',
|
||||||
@@ -1246,6 +1257,8 @@ export default {
|
|||||||
accountDeletedSuccess: '账号删除成功',
|
accountDeletedSuccess: '账号删除成功',
|
||||||
bulkSchedulableEnabled: '成功启用 {count} 个账号的调度',
|
bulkSchedulableEnabled: '成功启用 {count} 个账号的调度',
|
||||||
bulkSchedulableDisabled: '成功停止 {count} 个账号的调度',
|
bulkSchedulableDisabled: '成功停止 {count} 个账号的调度',
|
||||||
|
bulkSchedulablePartial: '部分调度更新成功:成功 {success} 个,失败 {failed} 个',
|
||||||
|
bulkSchedulableResultUnknown: '批量调度结果不完整,请稍后重试或刷新列表',
|
||||||
bulkActions: {
|
bulkActions: {
|
||||||
selected: '已选择 {count} 个账号',
|
selected: '已选择 {count} 个账号',
|
||||||
selectCurrentPage: '本页全选',
|
selectCurrentPage: '本页全选',
|
||||||
|
|||||||
@@ -652,6 +652,9 @@ export interface DashboardStats {
|
|||||||
total_users: number
|
total_users: number
|
||||||
today_new_users: number // 今日新增用户数
|
today_new_users: number // 今日新增用户数
|
||||||
active_users: number // 今日有请求的用户数
|
active_users: number // 今日有请求的用户数
|
||||||
|
hourly_active_users: number // 当前小时活跃用户数(UTC)
|
||||||
|
stats_updated_at: string // 统计更新时间(UTC RFC3339)
|
||||||
|
stats_stale: boolean // 统计是否过期
|
||||||
|
|
||||||
// API Key 统计
|
// API Key 统计
|
||||||
total_api_keys: number
|
total_api_keys: number
|
||||||
|
|||||||
@@ -20,7 +20,7 @@
|
|||||||
</template>
|
</template>
|
||||||
<template #table>
|
<template #table>
|
||||||
<AccountBulkActionsBar :selected-ids="selIds" @delete="handleBulkDelete" @edit="showBulkEdit = true" @clear="selIds = []" @select-page="selectPage" @toggle-schedulable="handleBulkToggleSchedulable" />
|
<AccountBulkActionsBar :selected-ids="selIds" @delete="handleBulkDelete" @edit="showBulkEdit = true" @clear="selIds = []" @select-page="selectPage" @toggle-schedulable="handleBulkToggleSchedulable" />
|
||||||
<DataTable :columns="cols" :data="accounts" :loading="loading">
|
<DataTable :columns="cols" :data="accounts" :loading="loading" row-key="id">
|
||||||
<template #cell-select="{ row }">
|
<template #cell-select="{ row }">
|
||||||
<input type="checkbox" :checked="selIds.includes(row.id)" @change="toggleSel(row.id)" class="rounded border-gray-300 text-primary-600 focus:ring-primary-500" />
|
<input type="checkbox" :checked="selIds.includes(row.id)" @change="toggleSel(row.id)" class="rounded border-gray-300 text-primary-600 focus:ring-primary-500" />
|
||||||
</template>
|
</template>
|
||||||
@@ -209,18 +209,107 @@ const openMenu = (a: Account, e: MouseEvent) => { menu.acc = a; menu.pos = { top
|
|||||||
const toggleSel = (id: number) => { const i = selIds.value.indexOf(id); if(i === -1) selIds.value.push(id); else selIds.value.splice(i, 1) }
|
const toggleSel = (id: number) => { const i = selIds.value.indexOf(id); if(i === -1) selIds.value.push(id); else selIds.value.splice(i, 1) }
|
||||||
const selectPage = () => { selIds.value = [...new Set([...selIds.value, ...accounts.value.map(a => a.id)])] }
|
const selectPage = () => { selIds.value = [...new Set([...selIds.value, ...accounts.value.map(a => a.id)])] }
|
||||||
const handleBulkDelete = async () => { if(!confirm(t('common.confirm'))) return; try { await Promise.all(selIds.value.map(id => adminAPI.accounts.delete(id))); selIds.value = []; reload() } catch (error) { console.error('Failed to bulk delete accounts:', error) } }
|
const handleBulkDelete = async () => { if(!confirm(t('common.confirm'))) return; try { await Promise.all(selIds.value.map(id => adminAPI.accounts.delete(id))); selIds.value = []; reload() } catch (error) { console.error('Failed to bulk delete accounts:', error) } }
|
||||||
|
const updateSchedulableInList = (accountIds: number[], schedulable: boolean) => {
|
||||||
|
if (accountIds.length === 0) return
|
||||||
|
const idSet = new Set(accountIds)
|
||||||
|
accounts.value = accounts.value.map((account) => (idSet.has(account.id) ? { ...account, schedulable } : account))
|
||||||
|
}
|
||||||
|
const normalizeBulkSchedulableResult = (
|
||||||
|
result: {
|
||||||
|
success?: number
|
||||||
|
failed?: number
|
||||||
|
success_ids?: number[]
|
||||||
|
failed_ids?: number[]
|
||||||
|
results?: Array<{ account_id: number; success: boolean }>
|
||||||
|
},
|
||||||
|
accountIds: number[]
|
||||||
|
) => {
|
||||||
|
const responseSuccessIds = Array.isArray(result.success_ids) ? result.success_ids : []
|
||||||
|
const responseFailedIds = Array.isArray(result.failed_ids) ? result.failed_ids : []
|
||||||
|
if (responseSuccessIds.length > 0 || responseFailedIds.length > 0) {
|
||||||
|
return {
|
||||||
|
successIds: responseSuccessIds,
|
||||||
|
failedIds: responseFailedIds,
|
||||||
|
successCount: typeof result.success === 'number' ? result.success : responseSuccessIds.length,
|
||||||
|
failedCount: typeof result.failed === 'number' ? result.failed : responseFailedIds.length,
|
||||||
|
hasIds: true,
|
||||||
|
hasCounts: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const results = Array.isArray(result.results) ? result.results : []
|
||||||
|
if (results.length > 0) {
|
||||||
|
const successIds = results.filter(item => item.success).map(item => item.account_id)
|
||||||
|
const failedIds = results.filter(item => !item.success).map(item => item.account_id)
|
||||||
|
return {
|
||||||
|
successIds,
|
||||||
|
failedIds,
|
||||||
|
successCount: typeof result.success === 'number' ? result.success : successIds.length,
|
||||||
|
failedCount: typeof result.failed === 'number' ? result.failed : failedIds.length,
|
||||||
|
hasIds: true,
|
||||||
|
hasCounts: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const hasExplicitCounts = typeof result.success === 'number' || typeof result.failed === 'number'
|
||||||
|
const successCount = typeof result.success === 'number' ? result.success : 0
|
||||||
|
const failedCount = typeof result.failed === 'number' ? result.failed : 0
|
||||||
|
if (hasExplicitCounts && failedCount === 0 && successCount === accountIds.length && accountIds.length > 0) {
|
||||||
|
return {
|
||||||
|
successIds: accountIds,
|
||||||
|
failedIds: [],
|
||||||
|
successCount,
|
||||||
|
failedCount,
|
||||||
|
hasIds: true,
|
||||||
|
hasCounts: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
successIds: [],
|
||||||
|
failedIds: [],
|
||||||
|
successCount,
|
||||||
|
failedCount,
|
||||||
|
hasIds: false,
|
||||||
|
hasCounts: hasExplicitCounts
|
||||||
|
}
|
||||||
|
}
|
||||||
const handleBulkToggleSchedulable = async (schedulable: boolean) => {
|
const handleBulkToggleSchedulable = async (schedulable: boolean) => {
|
||||||
const count = selIds.value.length
|
const accountIds = [...selIds.value]
|
||||||
try {
|
try {
|
||||||
const result = await adminAPI.accounts.bulkUpdate(selIds.value, { schedulable });
|
const result = await adminAPI.accounts.bulkUpdate(accountIds, { schedulable })
|
||||||
const message = schedulable
|
const { successIds, failedIds, successCount, failedCount, hasIds, hasCounts } = normalizeBulkSchedulableResult(result, accountIds)
|
||||||
? t('admin.accounts.bulkSchedulableEnabled', { count: result.success || count })
|
if (!hasIds && !hasCounts) {
|
||||||
: t('admin.accounts.bulkSchedulableDisabled', { count: result.success || count });
|
appStore.showError(t('admin.accounts.bulkSchedulableResultUnknown'))
|
||||||
appStore.showSuccess(message);
|
selIds.value = accountIds
|
||||||
selIds.value = [];
|
load().catch((error) => {
|
||||||
reload()
|
console.error('Failed to refresh accounts:', error)
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if (successIds.length > 0) {
|
||||||
|
updateSchedulableInList(successIds, schedulable)
|
||||||
|
}
|
||||||
|
if (successCount > 0 && failedCount === 0) {
|
||||||
|
const message = schedulable
|
||||||
|
? t('admin.accounts.bulkSchedulableEnabled', { count: successCount })
|
||||||
|
: t('admin.accounts.bulkSchedulableDisabled', { count: successCount })
|
||||||
|
appStore.showSuccess(message)
|
||||||
|
}
|
||||||
|
if (failedCount > 0) {
|
||||||
|
const message = hasCounts || hasIds
|
||||||
|
? t('admin.accounts.bulkSchedulablePartial', { success: successCount, failed: failedCount })
|
||||||
|
: t('admin.accounts.bulkSchedulableResultUnknown')
|
||||||
|
appStore.showError(message)
|
||||||
|
selIds.value = failedIds.length > 0 ? failedIds : accountIds
|
||||||
|
} else {
|
||||||
|
selIds.value = hasIds ? [] : accountIds
|
||||||
|
}
|
||||||
|
load().catch((error) => {
|
||||||
|
console.error('Failed to refresh accounts:', error)
|
||||||
|
})
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Failed to bulk toggle schedulable:', error);
|
console.error('Failed to bulk toggle schedulable:', error)
|
||||||
appStore.showError(t('common.error'))
|
appStore.showError(t('common.error'))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -236,7 +325,22 @@ const handleResetStatus = async (a: Account) => { try { await adminAPI.accounts.
|
|||||||
const handleClearRateLimit = async (a: Account) => { try { await adminAPI.accounts.clearRateLimit(a.id); appStore.showSuccess(t('common.success')); load() } catch (error) { console.error('Failed to clear rate limit:', error) } }
|
const handleClearRateLimit = async (a: Account) => { try { await adminAPI.accounts.clearRateLimit(a.id); appStore.showSuccess(t('common.success')); load() } catch (error) { console.error('Failed to clear rate limit:', error) } }
|
||||||
const handleDelete = (a: Account) => { deletingAcc.value = a; showDeleteDialog.value = true }
|
const handleDelete = (a: Account) => { deletingAcc.value = a; showDeleteDialog.value = true }
|
||||||
const confirmDelete = async () => { if(!deletingAcc.value) return; try { await adminAPI.accounts.delete(deletingAcc.value.id); showDeleteDialog.value = false; deletingAcc.value = null; reload() } catch (error) { console.error('Failed to delete account:', error) } }
|
const confirmDelete = async () => { if(!deletingAcc.value) return; try { await adminAPI.accounts.delete(deletingAcc.value.id); showDeleteDialog.value = false; deletingAcc.value = null; reload() } catch (error) { console.error('Failed to delete account:', error) } }
|
||||||
const handleToggleSchedulable = async (a: Account) => { togglingSchedulable.value = a.id; try { await adminAPI.accounts.setSchedulable(a.id, !a.schedulable); load() } finally { togglingSchedulable.value = null } }
|
const handleToggleSchedulable = async (a: Account) => {
|
||||||
|
const nextSchedulable = !a.schedulable
|
||||||
|
togglingSchedulable.value = a.id
|
||||||
|
try {
|
||||||
|
const updated = await adminAPI.accounts.setSchedulable(a.id, nextSchedulable)
|
||||||
|
updateSchedulableInList([a.id], updated?.schedulable ?? nextSchedulable)
|
||||||
|
load().catch((error) => {
|
||||||
|
console.error('Failed to refresh accounts:', error)
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to toggle schedulable:', error)
|
||||||
|
appStore.showError(t('admin.accounts.failedToToggleSchedulable'))
|
||||||
|
} finally {
|
||||||
|
togglingSchedulable.value = null
|
||||||
|
}
|
||||||
|
}
|
||||||
const handleShowTempUnsched = (a: Account) => { tempUnschedAcc.value = a; showTempUnsched.value = true }
|
const handleShowTempUnsched = (a: Account) => { tempUnschedAcc.value = a; showTempUnsched.value = true }
|
||||||
const handleTempUnschedReset = async () => { if(!tempUnschedAcc.value) return; try { await adminAPI.accounts.clearError(tempUnschedAcc.value.id); showTempUnsched.value = false; tempUnschedAcc.value = null; load() } catch (error) { console.error('Failed to reset temp unscheduled:', error) } }
|
const handleTempUnschedReset = async () => { if(!tempUnschedAcc.value) return; try { await adminAPI.accounts.clearError(tempUnschedAcc.value.id); showTempUnsched.value = false; tempUnschedAcc.value = null; load() } catch (error) { console.error('Failed to reset temp unscheduled:', error) } }
|
||||||
const formatExpiresAt = (value: number | null) => {
|
const formatExpiresAt = (value: number | null) => {
|
||||||
|
|||||||
Reference in New Issue
Block a user