diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index f0768f09..1efce7d5 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.70 +0.1.70.2 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index ab1831d8..1568dd6e 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -65,8 +65,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) - userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) - subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) + userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) + subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, configConfig) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) secretEncryptor, err := repository.NewAESEncryptor(configConfig) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 91437ba8..2fba69cb 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -38,31 +38,32 @@ const ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - CORS CORSConfig `mapstructure:"cors"` - Security SecurityConfig `mapstructure:"security"` - Billing BillingConfig `mapstructure:"billing"` - Turnstile TurnstileConfig `mapstructure:"turnstile"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - Ops OpsConfig `mapstructure:"ops"` - JWT JWTConfig `mapstructure:"jwt"` - Totp TotpConfig `mapstructure:"totp"` - LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` - Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` - DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` - UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` - Concurrency ConcurrencyConfig `mapstructure:"concurrency"` - TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - RunMode string `mapstructure:"run_mode" yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" - Gemini GeminiConfig `mapstructure:"gemini"` - Update UpdateConfig `mapstructure:"update"` + Server ServerConfig `mapstructure:"server"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + Ops OpsConfig `mapstructure:"ops"` + JWT JWTConfig `mapstructure:"jwt"` + Totp TotpConfig `mapstructure:"totp"` + LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` + SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"` + Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` + DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` + UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` + Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` + Update UpdateConfig `mapstructure:"update"` } type GeminiConfig struct { @@ -147,6 +148,7 @@ type ServerConfig struct { Host string `mapstructure:"host"` Port int `mapstructure:"port"` Mode string `mapstructure:"mode"` // debug/release + FrontendURL string `mapstructure:"frontend_url"` // 前端基础 URL,用于生成邮件中的外部链接 ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP) @@ -226,6 +228,9 @@ type GatewayConfig struct { MaxBodySize int64 `mapstructure:"max_body_size"` // ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy) ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"` + // ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。 + // 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。 + ForceCodexCLI bool `mapstructure:"force_codex_cli"` // HTTP 上游连接池配置(性能优化:支持高并发场景调优) // MaxIdleConns: 所有主机的最大空闲连接总数 @@ -525,6 +530,13 @@ type APIKeyAuthCacheConfig struct { Singleflight bool `mapstructure:"singleflight"` } +// SubscriptionCacheConfig 订阅认证 L1 缓存配置 +type SubscriptionCacheConfig struct { + L1Size int `mapstructure:"l1_size"` + L1TTLSeconds int `mapstructure:"l1_ttl_seconds"` + JitterPercent int `mapstructure:"jitter_percent"` +} + // DashboardCacheConfig 仪表盘统计缓存配置 type DashboardCacheConfig struct { // Enabled: 是否启用仪表盘缓存 @@ -630,6 +642,7 @@ func Load() (*Config, error) { if cfg.Server.Mode == "" { cfg.Server.Mode = "debug" } + cfg.Server.FrontendURL = strings.TrimSpace(cfg.Server.FrontendURL) cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID) cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret) @@ -702,7 +715,8 @@ func setDefaults() { // Server viper.SetDefault("server.host", "0.0.0.0") viper.SetDefault("server.port", 8080) - viper.SetDefault("server.mode", "debug") + viper.SetDefault("server.mode", "release") + viper.SetDefault("server.frontend_url", "") viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 viper.SetDefault("server.trusted_proxies", []string{}) @@ -737,7 +751,7 @@ func setDefaults() { viper.SetDefault("security.url_allowlist.crs_hosts", []string{}) viper.SetDefault("security.url_allowlist.allow_private_hosts", true) viper.SetDefault("security.url_allowlist.allow_insecure_http", true) - viper.SetDefault("security.response_headers.enabled", false) + viper.SetDefault("security.response_headers.enabled", true) viper.SetDefault("security.response_headers.additional_allowed", []string{}) viper.SetDefault("security.response_headers.force_remove", []string{}) viper.SetDefault("security.csp.enabled", true) @@ -775,9 +789,9 @@ func setDefaults() { viper.SetDefault("database.user", "postgres") viper.SetDefault("database.password", "postgres") viper.SetDefault("database.dbname", "sub2api") - viper.SetDefault("database.sslmode", "disable") - viper.SetDefault("database.max_open_conns", 50) - viper.SetDefault("database.max_idle_conns", 10) + viper.SetDefault("database.sslmode", "prefer") + viper.SetDefault("database.max_open_conns", 256) + viper.SetDefault("database.max_idle_conns", 128) viper.SetDefault("database.conn_max_lifetime_minutes", 30) viper.SetDefault("database.conn_max_idle_time_minutes", 5) @@ -789,8 +803,8 @@ func setDefaults() { viper.SetDefault("redis.dial_timeout_seconds", 5) viper.SetDefault("redis.read_timeout_seconds", 3) viper.SetDefault("redis.write_timeout_seconds", 3) - viper.SetDefault("redis.pool_size", 128) - viper.SetDefault("redis.min_idle_conns", 10) + viper.SetDefault("redis.pool_size", 1024) + viper.SetDefault("redis.min_idle_conns", 128) viper.SetDefault("redis.enable_tls", false) // Ops (vNext) @@ -849,6 +863,11 @@ func setDefaults() { viper.SetDefault("api_key_auth_cache.jitter_percent", 10) viper.SetDefault("api_key_auth_cache.singleflight", true) + // Subscription auth L1 cache + viper.SetDefault("subscription_cache.l1_size", 16384) + viper.SetDefault("subscription_cache.l1_ttl_seconds", 10) + viper.SetDefault("subscription_cache.jitter_percent", 10) + // Dashboard cache viper.SetDefault("dashboard_cache.enabled", true) viper.SetDefault("dashboard_cache.key_prefix", "sub2api:") @@ -882,13 +901,14 @@ func setDefaults() { viper.SetDefault("gateway.failover_on_400", false) viper.SetDefault("gateway.max_account_switches", 10) viper.SetDefault("gateway.max_account_switches_gemini", 3) + viper.SetDefault("gateway.force_codex_cli", false) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) // HTTP 上游连接池配置(针对 5000+ 并发用户优化) - viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认) + viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大) viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认) - viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认) + viper.SetDefault("gateway.max_conns_per_host", 1024) // 每主机最大连接数(含活跃;流式/HTTP1.1 场景可调大,如 2400+) viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒) viper.SetDefault("gateway.max_upstream_clients", 5000) viper.SetDefault("gateway.client_idle_ttl_seconds", 900) @@ -933,6 +953,22 @@ func setDefaults() { } func (c *Config) Validate() error { + if strings.TrimSpace(c.Server.FrontendURL) != "" { + if err := ValidateAbsoluteHTTPURL(c.Server.FrontendURL); err != nil { + return fmt.Errorf("server.frontend_url invalid: %w", err) + } + u, err := url.Parse(strings.TrimSpace(c.Server.FrontendURL)) + if err != nil { + return fmt.Errorf("server.frontend_url invalid: %w", err) + } + if u.RawQuery != "" || u.ForceQuery { + return fmt.Errorf("server.frontend_url invalid: must not include query") + } + if u.User != nil { + return fmt.Errorf("server.frontend_url invalid: must not include userinfo") + } + warnIfInsecureURL("server.frontend_url", c.Server.FrontendURL) + } if c.JWT.ExpireHour <= 0 { return fmt.Errorf("jwt.expire_hour must be positive") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index f734619f..a645d343 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -87,8 +87,34 @@ func TestLoadDefaultSecurityToggles(t *testing.T) { if !cfg.Security.URLAllowlist.AllowPrivateHosts { t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true") } - if cfg.Security.ResponseHeaders.Enabled { - t.Fatalf("ResponseHeaders.Enabled = true, want false") + if !cfg.Security.ResponseHeaders.Enabled { + t.Fatalf("ResponseHeaders.Enabled = false, want true") + } +} + +func TestLoadDefaultServerMode(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Server.Mode != "release" { + t.Fatalf("Server.Mode = %q, want %q", cfg.Server.Mode, "release") + } +} + +func TestLoadDefaultDatabaseSSLMode(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Database.SSLMode != "prefer" { + t.Fatalf("Database.SSLMode = %q, want %q", cfg.Database.SSLMode, "prefer") } } @@ -424,6 +450,40 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) { } } +func TestValidateServerFrontendURL(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Server.FrontendURL = "https://example.com" + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() frontend_url valid error: %v", err) + } + + cfg.Server.FrontendURL = "https://example.com/path" + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() frontend_url with path valid error: %v", err) + } + + cfg.Server.FrontendURL = "https://example.com?utm=1" + if err := cfg.Validate(); err == nil { + t.Fatalf("Validate() should reject server.frontend_url with query") + } + + cfg.Server.FrontendURL = "https://user:pass@example.com" + if err := cfg.Validate(); err == nil { + t.Fatalf("Validate() should reject server.frontend_url with userinfo") + } + + cfg.Server.FrontendURL = "/relative" + if err := cfg.Validate(); err == nil { + t.Fatalf("Validate() should reject relative server.frontend_url") + } +} + func TestValidateFrontendRedirectURL(t *testing.T) { if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil { t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err) diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 9a13b57c..f1c9f303 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -3,6 +3,7 @@ package admin import ( "errors" + "fmt" "strconv" "strings" "sync" @@ -789,57 +790,40 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) { } ctx := c.Request.Context() - success := 0 - failed := 0 - results := []gin.H{} + // 阶段一:预验证所有账号存在,收集 credentials + type accountUpdate struct { + ID int64 + Credentials map[string]any + } + updates := make([]accountUpdate, 0, len(req.AccountIDs)) for _, accountID := range req.AccountIDs { - // Get account account, err := h.adminService.GetAccount(ctx, accountID) if err != nil { - failed++ - results = append(results, gin.H{ - "account_id": accountID, - "success": false, - "error": "Account not found", - }) - continue + response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID)) + return } - - // Update credentials field if account.Credentials == nil { account.Credentials = make(map[string]any) } - account.Credentials[req.Field] = req.Value + updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials}) + } - // Update account + // 阶段二:依次更新,任何失败立即返回(避免部分成功部分失败) + for _, u := range updates { updateInput := &service.UpdateAccountInput{ - Credentials: account.Credentials, + Credentials: u.Credentials, } - - _, err = h.adminService.UpdateAccount(ctx, accountID, updateInput) - if err != nil { - failed++ - results = append(results, gin.H{ - "account_id": accountID, - "success": false, - "error": err.Error(), - }) - continue + if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil { + response.Error(c, 500, fmt.Sprintf("Failed to update account %d: %v", u.ID, err)) + return } - - success++ - results = append(results, gin.H{ - "account_id": accountID, - "success": true, - }) } response.Success(c, gin.H{ - "success": success, - "failed": failed, - "results": results, + "success": len(updates), + "failed": 0, }) } diff --git a/backend/internal/handler/admin/batch_update_credentials_test.go b/backend/internal/handler/admin/batch_update_credentials_test.go new file mode 100644 index 00000000..4c47fadb --- /dev/null +++ b/backend/internal/handler/admin/batch_update_credentials_test.go @@ -0,0 +1,200 @@ +//go:build unit + +package admin + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// failingAdminService 嵌入 stubAdminService,可配置 UpdateAccount 在指定 ID 时失败。 +type failingAdminService struct { + *stubAdminService + failOnAccountID int64 + updateCallCount atomic.Int64 +} + +func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) { + f.updateCallCount.Add(1) + if id == f.failOnAccountID { + return nil, errors.New("database error") + } + return f.stubAdminService.UpdateAccount(ctx, id, input) +} + +func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) { + gin.SetMode(gin.TestMode) + router := gin.New() + handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials) + return router, handler +} + +func TestBatchUpdateCredentials_AllSuccess(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(BatchUpdateCredentialsRequest{ + AccountIDs: []int64{1, 2, 3}, + Field: "account_uuid", + Value: "test-uuid", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200") + require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount") +} + +func TestBatchUpdateCredentials_FailFast(t *testing.T) { + // 让第 2 个账号(ID=2)更新时失败 + svc := &failingAdminService{ + stubAdminService: newStubAdminService(), + failOnAccountID: 2, + } + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(BatchUpdateCredentialsRequest{ + AccountIDs: []int64{1, 2, 3}, + Field: "org_uuid", + Value: "test-org", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusInternalServerError, w.Code, "ID=2 失败时应返回 500") + // 验证 fail-fast:ID=1 更新成功,ID=2 失败,ID=3 不应被调用 + require.Equal(t, int64(2), svc.updateCallCount.Load(), + "fail-fast: 应只调用 2 次 UpdateAccount(ID=1 成功、ID=2 失败后停止)") +} + +func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) { + // GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub + svc := &getAccountFailingService{ + stubAdminService: newStubAdminService(), + failOnAccountID: 1, + } + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(BatchUpdateCredentialsRequest{ + AccountIDs: []int64{1, 2, 3}, + Field: "account_uuid", + Value: "test", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404") +} + +// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。 +type getAccountFailingService struct { + *stubAdminService + failOnAccountID int64 +} + +func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) { + if id == f.failOnAccountID { + return nil, errors.New("not found") + } + return f.stubAdminService.GetAccount(ctx, id) +} + +func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + // intercept_warmup_requests 传入非 bool 类型(string),应返回 400 + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "intercept_warmup_requests", + "value": "not-a-bool", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code, + "intercept_warmup_requests 传入非 bool 值应返回 400") +} + +func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "intercept_warmup_requests", + "value": true, + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, + "intercept_warmup_requests 传入合法 bool 值应返回 200") +} + +func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + // account_uuid 传入非 string 类型(number),应返回 400 + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "account_uuid", + "value": 12345, + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code, + "account_uuid 传入非 string 值应返回 400") +} + +func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + // account_uuid 传入 null(设置为空),应正常通过 + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "account_uuid", + "value": nil, + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, + "account_uuid 传入 null 应返回 200") +} diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 18365186..fab66c04 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) { return } - stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs) + stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{}) if err != nil { response.Error(c, 500, "Failed to get user usage stats") return @@ -407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) { return } - stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs) + stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{}) if err != nil { response.Error(c, 500, "Failed to get API key usage stats") return diff --git a/backend/internal/handler/admin/search_truncate_test.go b/backend/internal/handler/admin/search_truncate_test.go new file mode 100644 index 00000000..ffd60e2a --- /dev/null +++ b/backend/internal/handler/admin/search_truncate_test.go @@ -0,0 +1,97 @@ +//go:build unit + +package admin + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑 +func truncateSearchByRune(search string, maxRunes int) string { + if runes := []rune(search); len(runes) > maxRunes { + return string(runes[:maxRunes]) + } + return search +} + +func TestTruncateSearchByRune(t *testing.T) { + tests := []struct { + name string + input string + maxRunes int + wantLen int // 期望的 rune 长度 + }{ + { + name: "纯中文超长", + input: string(make([]rune, 150)), + maxRunes: 100, + wantLen: 100, + }, + { + name: "纯 ASCII 超长", + input: string(make([]byte, 150)), + maxRunes: 100, + wantLen: 100, + }, + { + name: "空字符串", + input: "", + maxRunes: 100, + wantLen: 0, + }, + { + name: "恰好 100 个字符", + input: string(make([]rune, 100)), + maxRunes: 100, + wantLen: 100, + }, + { + name: "不足 100 字符不截断", + input: "hello世界", + maxRunes: 100, + wantLen: 7, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := truncateSearchByRune(tc.input, tc.maxRunes) + require.Equal(t, tc.wantLen, len([]rune(result))) + }) + } +} + +func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) { + // 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8 + input := "" + for i := 0; i < 101; i++ { + input += "中" + } + result := truncateSearchByRune(input, 100) + + require.Equal(t, 100, len([]rune(result))) + // 验证截断结果是有效的 UTF-8(每个中文字符 3 字节) + require.Equal(t, 300, len(result)) +} + +func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) { + // 50 个 ASCII + 51 个中文 = 101 个 rune + input := "" + for i := 0; i < 50; i++ { + input += "a" + } + for i := 0; i < 51; i++ { + input += "中" + } + result := truncateSearchByRune(input, 100) + + runes := []rune(result) + require.Equal(t, 100, len(runes)) + // 前 50 个应该是 'a',后 50 个应该是 '中' + require.Equal(t, 'a', runes[0]) + require.Equal(t, 'a', runes[49]) + require.Equal(t, '中', runes[50]) + require.Equal(t, '中', runes[99]) +} diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 1c772e7d..0427e77e 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -70,8 +70,8 @@ func (h *UserHandler) List(c *gin.Context) { search := c.Query("search") // 标准化和验证 search 参数 search = strings.TrimSpace(search) - if len(search) > 100 { - search = search[:100] + if runes := []rune(search); len(runes) > 100 { + search = string(runes[:100]) } filters := service.UserListFilters{ diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 34ed63bc..204af666 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -2,6 +2,7 @@ package handler import ( "log/slog" + "strings" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -448,17 +449,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) { return } - // Build frontend base URL from request - scheme := "https" - if c.Request.TLS == nil { - // Check X-Forwarded-Proto header (common in reverse proxy setups) - if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" { - scheme = proto - } else { - scheme = "http" - } + frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL) + if frontendBaseURL == "" { + slog.Error("server.frontend_url not configured; cannot build password reset link") + response.InternalError(c, "Password reset is not configured") + return } - frontendBaseURL := scheme + "://" + c.Request.Host // Request password reset (async) // Note: This returns success even if email doesn't exist (to prevent enumeration) diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index ca4442e4..80bb3539 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -236,7 +236,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 if err != nil { if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + log.Printf("[Gateway] SelectAccount failed: %v", err) + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) return } if lastFailoverErr != nil { @@ -284,12 +285,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if err == nil && canWait { accountWaitCounted = true } - // Ensure the wait counter is decremented if we exit before acquiring the slot. - defer func() { + releaseWait := func() { if accountWaitCounted { h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false } - }() + } accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( c, @@ -301,14 +302,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ) if err != nil { log.Printf("Account concurrency acquire failed: %v", err) + releaseWait() h.handleConcurrencyError(c, err, "account", streamStarted) return } // Slot acquired: no longer waiting in queue. - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } + releaseWait() if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } @@ -398,7 +397,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) if err != nil { if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + log.Printf("[Gateway] SelectAccount failed: %v", err) + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) return } if lastFailoverErr != nil { @@ -446,11 +446,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if err == nil && canWait { accountWaitCounted = true } - defer func() { + releaseWait := func() { if accountWaitCounted { h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false } - }() + } accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( c, @@ -462,13 +463,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ) if err != nil { log.Printf("Account concurrency acquire failed: %v", err) + releaseWait() h.handleConcurrencyError(c, err, "account", streamStarted) return } - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } + // Slot acquired: no longer waiting in queue. + releaseWait() if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } @@ -967,7 +967,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { // 选择支持该模型的账号 account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model) if err != nil { - h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) + log.Printf("[Gateway] SelectAccountForModel failed: %v", err) + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable") return } setOpsSelectedAccount(c, account.ID) @@ -1238,7 +1239,8 @@ func billingErrorDetails(err error) (status int, code, message string) { } msg := pkgerrors.Message(err) if msg == "" { - msg = err.Error() + log.Printf("[Gateway] billing error details: %v", err) + msg = "Billing error" } return http.StatusForbidden, "billing_error", msg } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 835297b8..dba7b70a 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -28,6 +28,7 @@ type OpenAIGatewayHandler struct { errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int + cfg *config.Config } // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler @@ -54,6 +55,7 @@ func NewOpenAIGatewayHandler( errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), maxAccountSwitches: maxAccountSwitches, + cfg: cfg, } } @@ -109,7 +111,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } userAgent := c.GetHeader("User-Agent") - if !openai.IsCodexCLIRequest(userAgent) { + isCodexCLI := openai.IsCodexCLIRequest(userAgent) || (h.cfg != nil && h.cfg.Gateway.ForceCodexCLI) + if !isCodexCLI { existingInstructions, _ := reqBody["instructions"].(string) if strings.TrimSpace(existingInstructions) == "" { if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" { @@ -218,7 +221,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { if err != nil { log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + log.Printf("[OpenAI Gateway] SelectAccount failed: %v", err) + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) return } if lastFailoverErr != nil { @@ -251,11 +255,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { if err == nil && canWait { accountWaitCounted = true } - defer func() { + releaseWait := func() { if accountWaitCounted { h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false } - }() + } accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( c, @@ -267,13 +272,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ) if err != nil { log.Printf("Account concurrency acquire failed: %v", err) + releaseWait() h.handleConcurrencyError(c, err, "account", streamStarted) return } - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } + // Slot acquired: no longer waiting in queue. + releaseWait() if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } diff --git a/backend/internal/handler/usage_handler.go b/backend/internal/handler/usage_handler.go index 129dbfa6..b8182dad 100644 --- a/backend/internal/handler/usage_handler.go +++ b/backend/internal/handler/usage_handler.go @@ -392,7 +392,7 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) { return } - stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs) + stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs, time.Time{}, time.Time{}) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go index eb16f09d..1f58eb8e 100644 --- a/backend/internal/pkg/antigravity/response_transformer.go +++ b/backend/internal/pkg/antigravity/response_transformer.go @@ -1,6 +1,7 @@ package antigravity import ( + "crypto/rand" "encoding/json" "fmt" "log" @@ -341,12 +342,16 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string { return builder.String() } -// generateRandomID 生成随机 ID +// generateRandomID 生成密码学安全的随机 ID func generateRandomID() string { const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" result := make([]byte, 12) - for i := range result { - result[i] = chars[i%len(chars)] + randBytes := make([]byte, 12) + if _, err := rand.Read(randBytes); err != nil { + panic("crypto/rand unavailable: " + err.Error()) + } + for i, b := range randBytes { + result[i] = chars[int(b)%len(chars)] } return string(result) } diff --git a/backend/internal/pkg/antigravity/response_transformer_test.go b/backend/internal/pkg/antigravity/response_transformer_test.go new file mode 100644 index 00000000..9731d906 --- /dev/null +++ b/backend/internal/pkg/antigravity/response_transformer_test.go @@ -0,0 +1,36 @@ +//go:build unit + +package antigravity + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGenerateRandomID_Uniqueness(t *testing.T) { + seen := make(map[string]struct{}, 100) + for i := 0; i < 100; i++ { + id := generateRandomID() + require.Len(t, id, 12, "ID 长度应为 12") + _, dup := seen[id] + require.False(t, dup, "第 %d 次调用生成了重复 ID: %s", i, id) + seen[id] = struct{}{} + } +} + +func TestGenerateRandomID_Charset(t *testing.T) { + const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + validSet := make(map[byte]struct{}, len(validChars)) + for i := 0; i < len(validChars); i++ { + validSet[validChars[i]] = struct{}{} + } + + for i := 0; i < 50; i++ { + id := generateRandomID() + for j := 0; j < len(id); j++ { + _, ok := validSet[id[j]] + require.True(t, ok, "ID 包含非法字符: %c (ID=%s)", id[j], id) + } + } +} diff --git a/backend/internal/pkg/ip/ip.go b/backend/internal/pkg/ip/ip.go index 97109c0c..6ab2ff72 100644 --- a/backend/internal/pkg/ip/ip.go +++ b/backend/internal/pkg/ip/ip.go @@ -54,29 +54,34 @@ func normalizeIP(ip string) string { return ip } -// isPrivateIP 检查 IP 是否为私有地址。 -func isPrivateIP(ipStr string) bool { - ip := net.ParseIP(ipStr) - if ip == nil { - return false - } +// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析 +var privateNets []*net.IPNet - // 私有 IP 范围 - privateBlocks := []string{ +func init() { + for _, cidr := range []string{ "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "127.0.0.0/8", "::1/128", "fc00::/7", - } - - for _, block := range privateBlocks { - _, cidr, err := net.ParseCIDR(block) + } { + _, block, err := net.ParseCIDR(cidr) if err != nil { - continue + panic("invalid CIDR: " + cidr) } - if cidr.Contains(ip) { + privateNets = append(privateNets, block) + } +} + +// isPrivateIP 检查 IP 是否为私有地址。 +func isPrivateIP(ipStr string) bool { + ip := net.ParseIP(ipStr) + if ip == nil { + return false + } + for _, block := range privateNets { + if block.Contains(ip) { return true } } diff --git a/backend/internal/pkg/ip/ip_test.go b/backend/internal/pkg/ip/ip_test.go new file mode 100644 index 00000000..c3c90c74 --- /dev/null +++ b/backend/internal/pkg/ip/ip_test.go @@ -0,0 +1,51 @@ +//go:build unit + +package ip + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsPrivateIP(t *testing.T) { + tests := []struct { + name string + ip string + expected bool + }{ + // 私有 IPv4 + {"10.x 私有地址", "10.0.0.1", true}, + {"10.x 私有地址段末", "10.255.255.255", true}, + {"172.16.x 私有地址", "172.16.0.1", true}, + {"172.31.x 私有地址", "172.31.255.255", true}, + {"192.168.x 私有地址", "192.168.1.1", true}, + {"127.0.0.1 本地回环", "127.0.0.1", true}, + {"127.x 回环段", "127.255.255.255", true}, + + // 公网 IPv4 + {"8.8.8.8 公网 DNS", "8.8.8.8", false}, + {"1.1.1.1 公网", "1.1.1.1", false}, + {"172.15.255.255 非私有", "172.15.255.255", false}, + {"172.32.0.0 非私有", "172.32.0.0", false}, + {"11.0.0.1 公网", "11.0.0.1", false}, + + // IPv6 + {"::1 IPv6 回环", "::1", true}, + {"fc00:: IPv6 私有", "fc00::1", true}, + {"fd00:: IPv6 私有", "fd00::1", true}, + {"2001:db8::1 IPv6 公网", "2001:db8::1", false}, + + // 无效输入 + {"空字符串", "", false}, + {"非法字符串", "not-an-ip", false}, + {"不完整 IP", "192.168", false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := isPrivateIP(tc.ip) + require.Equal(t, tc.expected, got, "isPrivateIP(%q)", tc.ip) + }) + } +} diff --git a/backend/internal/pkg/oauth/oauth.go b/backend/internal/pkg/oauth/oauth.go index 33caffd7..cfc91bee 100644 --- a/backend/internal/pkg/oauth/oauth.go +++ b/backend/internal/pkg/oauth/oauth.go @@ -50,6 +50,7 @@ type OAuthSession struct { type SessionStore struct { mu sync.RWMutex sessions map[string]*OAuthSession + stopOnce sync.Once stopCh chan struct{} } @@ -65,7 +66,9 @@ func NewSessionStore() *SessionStore { // Stop stops the cleanup goroutine func (s *SessionStore) Stop() { - close(s.stopCh) + s.stopOnce.Do(func() { + close(s.stopCh) + }) } // Set stores a session diff --git a/backend/internal/pkg/oauth/oauth_test.go b/backend/internal/pkg/oauth/oauth_test.go new file mode 100644 index 00000000..9e59f0f0 --- /dev/null +++ b/backend/internal/pkg/oauth/oauth_test.go @@ -0,0 +1,43 @@ +package oauth + +import ( + "sync" + "testing" + "time" +) + +func TestSessionStore_Stop_Idempotent(t *testing.T) { + store := NewSessionStore() + + store.Stop() + store.Stop() + + select { + case <-store.stopCh: + // ok + case <-time.After(time.Second): + t.Fatal("stopCh 未关闭") + } +} + +func TestSessionStore_Stop_Concurrent(t *testing.T) { + store := NewSessionStore() + + var wg sync.WaitGroup + for range 50 { + wg.Add(1) + go func() { + defer wg.Done() + store.Stop() + }() + } + + wg.Wait() + + select { + case <-store.stopCh: + // ok + case <-time.After(time.Second): + t.Fatal("stopCh 未关闭") + } +} diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go index df972a13..bb120b57 100644 --- a/backend/internal/pkg/openai/oauth.go +++ b/backend/internal/pkg/openai/oauth.go @@ -47,6 +47,7 @@ type OAuthSession struct { type SessionStore struct { mu sync.RWMutex sessions map[string]*OAuthSession + stopOnce sync.Once stopCh chan struct{} } @@ -92,7 +93,9 @@ func (s *SessionStore) Delete(sessionID string) { // Stop stops the cleanup goroutine func (s *SessionStore) Stop() { - close(s.stopCh) + s.stopOnce.Do(func() { + close(s.stopCh) + }) } // cleanup removes expired sessions periodically diff --git a/backend/internal/pkg/openai/oauth_test.go b/backend/internal/pkg/openai/oauth_test.go new file mode 100644 index 00000000..f1d616a6 --- /dev/null +++ b/backend/internal/pkg/openai/oauth_test.go @@ -0,0 +1,43 @@ +package openai + +import ( + "sync" + "testing" + "time" +) + +func TestSessionStore_Stop_Idempotent(t *testing.T) { + store := NewSessionStore() + + store.Stop() + store.Stop() + + select { + case <-store.stopCh: + // ok + case <-time.After(time.Second): + t.Fatal("stopCh 未关闭") + } +} + +func TestSessionStore_Stop_Concurrent(t *testing.T) { + store := NewSessionStore() + + var wg sync.WaitGroup + for range 50 { + wg.Add(1) + go func() { + defer wg.Done() + store.Stop() + }() + } + + wg.Wait() + + select { + case <-store.stopCh: + // ok + case <-time.After(time.Second): + t.Fatal("stopCh 未关闭") + } +} diff --git a/backend/internal/pkg/tlsfingerprint/dialer.go b/backend/internal/pkg/tlsfingerprint/dialer.go index 42510986..992f8b0a 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer.go +++ b/backend/internal/pkg/tlsfingerprint/dialer.go @@ -286,7 +286,7 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st return nil, fmt.Errorf("apply TLS preset: %w", err) } - if err := tlsConn.Handshake(); err != nil { + if err := tlsConn.HandshakeContext(ctx); err != nil { slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err) _ = conn.Close() return nil, fmt.Errorf("TLS handshake failed: %w", err) diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index c0cfd256..c86968b7 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -375,36 +375,19 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) return keys, nil } -// IncrementQuotaUsed atomically increments the quota_used field and returns the new value +// IncrementQuotaUsed 使用 Ent 原子递增 quota_used 字段并返回新值 func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { - // Use raw SQL for atomic increment to avoid race conditions - // First get current value - m, err := r.activeQuery(). - Where(apikey.IDEQ(id)). - Select(apikey.FieldQuotaUsed). - Only(ctx) + updated, err := r.client.APIKey.UpdateOneID(id). + Where(apikey.DeletedAtIsNil()). + AddQuotaUsed(amount). + Save(ctx) if err != nil { if dbent.IsNotFound(err) { return 0, service.ErrAPIKeyNotFound } return 0, err } - - newValue := m.QuotaUsed + amount - - // Update with new value - affected, err := r.client.APIKey.Update(). - Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()). - SetQuotaUsed(newValue). - Save(ctx) - if err != nil { - return 0, err - } - if affected == 0 { - return 0, service.ErrAPIKeyNotFound - } - - return newValue, nil + return updated.QuotaUsed, nil } func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go index 879a0576..303d7126 100644 --- a/backend/internal/repository/api_key_repo_integration_test.go +++ b/backend/internal/repository/api_key_repo_integration_test.go @@ -4,11 +4,14 @@ package repository import ( "context" + "sync" "testing" + "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -383,3 +386,87 @@ func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, group s.Require().NoError(s.repo.Create(s.ctx, k), "create api key") return k } + +// --- IncrementQuotaUsed --- + +func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_Basic() { + user := s.mustCreateUser("incr-basic@test.com") + key := s.mustCreateApiKey(user.ID, "sk-incr-basic", "Incr", nil) + + newQuota, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.5) + s.Require().NoError(err, "IncrementQuotaUsed") + s.Require().Equal(1.5, newQuota, "第一次递增后应为 1.5") + + newQuota, err = s.repo.IncrementQuotaUsed(s.ctx, key.ID, 2.5) + s.Require().NoError(err, "IncrementQuotaUsed second") + s.Require().Equal(4.0, newQuota, "第二次递增后应为 4.0") +} + +func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_NotFound() { + _, err := s.repo.IncrementQuotaUsed(s.ctx, 999999, 1.0) + s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "不存在的 key 应返回 ErrAPIKeyNotFound") +} + +func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() { + user := s.mustCreateUser("incr-deleted@test.com") + key := s.mustCreateApiKey(user.ID, "sk-incr-del", "Deleted", nil) + + s.Require().NoError(s.repo.Delete(s.ctx, key.ID), "Delete") + + _, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.0) + s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound") +} + +// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。 +// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。 +func TestIncrementQuotaUsed_Concurrent(t *testing.T) { + client := testEntClient(t) + repo := NewAPIKeyRepository(client).(*apiKeyRepository) + ctx := context.Background() + + // 创建测试用户和 API Key + u, err := client.User.Create(). + SetEmail("concurrent-incr-" + time.Now().Format(time.RFC3339Nano) + "@test.com"). + SetPasswordHash("hash"). + SetStatus(service.StatusActive). + SetRole(service.RoleUser). + Save(ctx) + require.NoError(t, err, "create user") + + k := &service.APIKey{ + UserID: u.ID, + Key: "sk-concurrent-" + time.Now().Format(time.RFC3339Nano), + Name: "Concurrent", + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, k), "create api key") + t.Cleanup(func() { + _ = client.APIKey.DeleteOneID(k.ID).Exec(ctx) + _ = client.User.DeleteOneID(u.ID).Exec(ctx) + }) + + // 10 个 goroutine 各递增 1.0,总计应为 10.0 + const goroutines = 10 + const increment = 1.0 + var wg sync.WaitGroup + errs := make([]error, goroutines) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + _, errs[idx] = repo.IncrementQuotaUsed(ctx, k.ID, increment) + }(i) + } + wg.Wait() + + for i, e := range errs { + require.NoError(t, e, "goroutine %d failed", i) + } + + // 验证最终结果 + got, err := repo.GetByID(ctx, k.ID) + require.NoError(t, err, "GetByID") + require.Equal(t, float64(goroutines)*increment, got.QuotaUsed, + "并发递增后总和应为 %v,实际为 %v", float64(goroutines)*increment, got.QuotaUsed) +} diff --git a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go index ac5803a1..50ea0da9 100644 --- a/backend/internal/repository/billing_cache.go +++ b/backend/internal/repository/billing_cache.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log" + "math/rand" "strconv" "time" @@ -16,8 +17,15 @@ const ( billingBalanceKeyPrefix = "billing:balance:" billingSubKeyPrefix = "billing:sub:" billingCacheTTL = 5 * time.Minute + billingCacheJitter = 30 * time.Second ) +// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩 +func jitteredTTL() time.Duration { + jitter := time.Duration(rand.Int63n(int64(2*billingCacheJitter))) - billingCacheJitter + return billingCacheTTL + jitter +} + // billingBalanceKey generates the Redis key for user balance cache. func billingBalanceKey(userID int64) string { return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) @@ -82,14 +90,15 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6 func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error { key := billingBalanceKey(userID) - return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err() + return c.rdb.Set(ctx, key, balance, jitteredTTL()).Err() } func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { key := billingBalanceKey(userID) - _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result() + _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result() if err != nil && !errors.Is(err, redis.Nil) { log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err) + return err } return nil } @@ -163,16 +172,17 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID pipe := c.rdb.Pipeline() pipe.HSet(ctx, key, fields) - pipe.Expire(ctx, key, billingCacheTTL) + pipe.Expire(ctx, key, jitteredTTL()) _, err := pipe.Exec(ctx) return err } func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { key := billingSubKey(userID, groupID) - _, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result() + _, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(jitteredTTL().Seconds())).Result() if err != nil && !errors.Is(err, redis.Nil) { log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err) + return err } return nil } diff --git a/backend/internal/repository/billing_cache_integration_test.go b/backend/internal/repository/billing_cache_integration_test.go index 2f7c69a7..4b7377b1 100644 --- a/backend/internal/repository/billing_cache_integration_test.go +++ b/backend/internal/repository/billing_cache_integration_test.go @@ -278,6 +278,90 @@ func (s *BillingCacheSuite) TestSubscriptionCache() { } } +// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复: +// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。 +func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() { + tests := []struct { + name string + fn func(ctx context.Context, cache service.BillingCache) + expectErr bool + }{ + { + name: "key_not_exists_returns_nil", + fn: func(ctx context.Context, cache service.BillingCache) { + // key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误 + err := cache.DeductUserBalance(ctx, 99999, 1.0) + require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil") + }, + }, + { + name: "existing_key_deducts_successfully", + fn: func(ctx context.Context, cache service.BillingCache) { + require.NoError(s.T(), cache.SetUserBalance(ctx, 200, 50.0)) + err := cache.DeductUserBalance(ctx, 200, 10.0) + require.NoError(s.T(), err, "DeductUserBalance should succeed") + + bal, err := cache.GetUserBalance(ctx, 200) + require.NoError(s.T(), err) + require.Equal(s.T(), 40.0, bal, "余额应为 40.0") + }, + }, + { + name: "cancelled_context_propagates_error", + fn: func(ctx context.Context, cache service.BillingCache) { + require.NoError(s.T(), cache.SetUserBalance(ctx, 201, 50.0)) + + cancelCtx, cancel := context.WithCancel(ctx) + cancel() // 立即取消 + + err := cache.DeductUserBalance(cancelCtx, 201, 10.0) + require.Error(s.T(), err, "cancelled context should propagate error") + }, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + rdb := testRedis(s.T()) + cache := NewBillingCache(rdb) + ctx := context.Background() + tt.fn(ctx, cache) + }) + } +} + +// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复: +// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。 +func (s *BillingCacheSuite) TestUpdateSubscriptionUsage_ErrorPropagation() { + s.Run("key_not_exists_returns_nil", func() { + rdb := testRedis(s.T()) + cache := NewBillingCache(rdb) + ctx := context.Background() + + err := cache.UpdateSubscriptionUsage(ctx, 88888, 77777, 1.0) + require.NoError(s.T(), err, "UpdateSubscriptionUsage on non-existent key should return nil") + }) + + s.Run("cancelled_context_propagates_error", func() { + rdb := testRedis(s.T()) + cache := NewBillingCache(rdb) + ctx := context.Background() + + data := &service.SubscriptionCacheData{ + Status: "active", + ExpiresAt: time.Now().Add(1 * time.Hour), + Version: 1, + } + require.NoError(s.T(), cache.SetSubscriptionCache(ctx, 301, 401, data)) + + cancelCtx, cancel := context.WithCancel(ctx) + cancel() + + err := cache.UpdateSubscriptionUsage(cancelCtx, 301, 401, 1.0) + require.Error(s.T(), err, "cancelled context should propagate error") + }) +} + func TestBillingCacheSuite(t *testing.T) { suite.Run(t, new(BillingCacheSuite)) } diff --git a/backend/internal/repository/billing_cache_test.go b/backend/internal/repository/billing_cache_test.go index 7d3fd19d..2de1da87 100644 --- a/backend/internal/repository/billing_cache_test.go +++ b/backend/internal/repository/billing_cache_test.go @@ -5,6 +5,7 @@ package repository import ( "math" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -85,3 +86,26 @@ func TestBillingSubKey(t *testing.T) { }) } } + +func TestJitteredTTL(t *testing.T) { + const ( + minTTL = 4*time.Minute + 30*time.Second // 270s = 5min - 30s + maxTTL = 5*time.Minute + 30*time.Second // 330s = 5min + 30s + ) + + for i := 0; i < 200; i++ { + ttl := jitteredTTL() + require.GreaterOrEqual(t, ttl, minTTL, "jitteredTTL() 返回值低于下限: %v", ttl) + require.LessOrEqual(t, ttl, maxTTL, "jitteredTTL() 返回值超过上限: %v", ttl) + } +} + +func TestJitteredTTL_HasVariation(t *testing.T) { + // 多次调用应该产生不同的值(验证抖动存在) + seen := make(map[time.Duration]struct{}, 50) + for i := 0; i < 50; i++ { + seen[jitteredTTL()] = struct{}{} + } + // 50 次调用中应该至少有 2 个不同的值 + require.Greater(t, len(seen), 1, "jitteredTTL() 应产生不同的 TTL 值") +} diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index d8cec491..412a8164 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -183,7 +183,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination q = q.Where(group.IsExclusiveEQ(*isExclusive)) } - total, err := q.Count(ctx) + total, err := q.Clone().Count(ctx) if err != nil { return nil, nil, err } diff --git a/backend/internal/repository/promo_code_repo.go b/backend/internal/repository/promo_code_repo.go index 98b422e0..95ce687a 100644 --- a/backend/internal/repository/promo_code_repo.go +++ b/backend/internal/repository/promo_code_repo.go @@ -132,7 +132,7 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina q = q.Where(promocode.CodeContainsFold(search)) } - total, err := q.Count(ctx) + total, err := q.Clone().Count(ctx) if err != nil { return nil, nil, err } @@ -187,7 +187,7 @@ func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCo q := r.client.PromoCodeUsage.Query(). Where(promocodeusage.PromoCodeIDEQ(promoCodeID)) - total, err := q.Count(ctx) + total, err := q.Clone().Count(ctx) if err != nil { return nil, nil, err } diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 2db1764f..d51669aa 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -24,6 +24,22 @@ import ( const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at" +// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL +var dateFormatWhitelist = map[string]string{ + "hour": "YYYY-MM-DD HH24:00", + "day": "YYYY-MM-DD", + "week": "IYYY-IW", + "month": "YYYY-MM", +} + +// safeDateFormat 根据白名单获取 dateFormat,未匹配时返回默认值 +func safeDateFormat(granularity string) string { + if f, ok := dateFormatWhitelist[granularity]; ok { + return f + } + return "YYYY-MM-DD" +} + type usageLogRepository struct { client *dbent.Client sql sqlExecutor @@ -564,7 +580,7 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, } func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime) return logs, nil, err } @@ -810,19 +826,19 @@ func resolveUsageStatsTimezone() string { } func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) return logs, nil, err } func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime) return logs, nil, err } func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime) return logs, nil, err } @@ -908,10 +924,7 @@ type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint // GetAPIKeyUsageTrend returns usage trend data grouped by API key and date func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` WITH top_keys AS ( @@ -966,10 +979,7 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, // GetUserUsageTrend returns usage trend data grouped by user and date func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` WITH top_users AS ( @@ -1228,10 +1238,7 @@ func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKey // GetUserUsageTrendByUserID 获取指定用户的使用趋势 func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` SELECT @@ -1369,13 +1376,22 @@ type UsageStats = usagestats.UsageStats // BatchUserUsageStats represents usage stats for a single user type BatchUserUsageStats = usagestats.BatchUserUsageStats -// GetBatchUserUsageStats gets today and total actual_cost for multiple users -func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) { +// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range. +// If startTime is zero, defaults to 30 days ago. +func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) { result := make(map[int64]*BatchUserUsageStats) if len(userIDs) == 0 { return result, nil } + // 默认最近 30 天 + if startTime.IsZero() { + startTime = time.Now().AddDate(0, 0, -30) + } + if endTime.IsZero() { + endTime = time.Now() + } + for _, id := range userIDs { result[id] = &BatchUserUsageStats{UserID: id} } @@ -1383,10 +1399,10 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs query := ` SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost FROM usage_logs - WHERE user_id = ANY($1) + WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3 GROUP BY user_id ` - rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs)) + rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime) if err != nil { return nil, err } @@ -1443,13 +1459,22 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs // BatchAPIKeyUsageStats represents usage stats for a single API key type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats -// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys -func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) { +// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys within a time range. +// If startTime is zero, defaults to 30 days ago. +func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) { result := make(map[int64]*BatchAPIKeyUsageStats) if len(apiKeyIDs) == 0 { return result, nil } + // 默认最近 30 天 + if startTime.IsZero() { + startTime = time.Now().AddDate(0, 0, -30) + } + if endTime.IsZero() { + endTime = time.Now() + } + for _, id := range apiKeyIDs { result[id] = &BatchAPIKeyUsageStats{APIKeyID: id} } @@ -1457,10 +1482,10 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe query := ` SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost FROM usage_logs - WHERE api_key_id = ANY($1) + WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3 GROUP BY api_key_id ` - rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs)) + rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime) if err != nil { return nil, err } @@ -1516,10 +1541,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe // GetUsageTrendWithFilters returns usage trend data with optional filters func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` SELECT diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index eb220f22..8cb3aab1 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -648,7 +648,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now()) - stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}) + stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}, time.Time{}, time.Time{}) s.Require().NoError(err, "GetBatchUserUsageStats") s.Require().Len(stats, 2) s.Require().NotNil(stats[user1.ID]) @@ -656,7 +656,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { } func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() { - stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}) + stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{}) s.Require().NoError(err) s.Require().Empty(stats) } @@ -672,13 +672,13 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() { s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now()) - stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}) + stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}, time.Time{}, time.Time{}) s.Require().NoError(err, "GetBatchAPIKeyUsageStats") s.Require().Len(stats, 2) } func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() { - stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}) + stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{}) s.Require().NoError(err) s.Require().Empty(stats) } diff --git a/backend/internal/repository/usage_log_repo_unit_test.go b/backend/internal/repository/usage_log_repo_unit_test.go new file mode 100644 index 00000000..d0e14ffd --- /dev/null +++ b/backend/internal/repository/usage_log_repo_unit_test.go @@ -0,0 +1,41 @@ +//go:build unit + +package repository + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSafeDateFormat(t *testing.T) { + tests := []struct { + name string + granularity string + expected string + }{ + // 合法值 + {"hour", "hour", "YYYY-MM-DD HH24:00"}, + {"day", "day", "YYYY-MM-DD"}, + {"week", "week", "IYYY-IW"}, + {"month", "month", "YYYY-MM"}, + + // 非法值回退到默认 + {"空字符串", "", "YYYY-MM-DD"}, + {"未知粒度 year", "year", "YYYY-MM-DD"}, + {"未知粒度 minute", "minute", "YYYY-MM-DD"}, + + // 恶意字符串 + {"SQL 注入尝试", "'; DROP TABLE users; --", "YYYY-MM-DD"}, + {"带引号", "day'", "YYYY-MM-DD"}, + {"带括号", "day)", "YYYY-MM-DD"}, + {"Unicode", "日", "YYYY-MM-DD"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := safeDateFormat(tc.granularity) + require.Equal(t, tc.expected, got, "safeDateFormat(%q)", tc.granularity) + }) + } +} diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index efef0452..804af548 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -592,13 +592,13 @@ func newContractDeps(t *testing.T) *contractDeps { RunMode: config.RunModeStandard, } - userService := service.NewUserService(userRepo, nil) + userService := service.NewUserService(userRepo, nil, nil) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) - subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil) + subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, cfg) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil) @@ -1602,11 +1602,11 @@ func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID i return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { +func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { +func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { return nil, errors.New("not implemented") } diff --git a/backend/internal/server/middleware/admin_auth.go b/backend/internal/server/middleware/admin_auth.go index 8f30107c..4167b7ab 100644 --- a/backend/internal/server/middleware/admin_auth.go +++ b/backend/internal/server/middleware/admin_auth.go @@ -176,6 +176,12 @@ func validateJWTForAdmin( return false } + // 校验 TokenVersion,确保管理员改密后旧 token 失效 + if claims.TokenVersion != user.TokenVersion { + AbortWithError(c, 401, "TOKEN_REVOKED", "Token has been revoked (password changed)") + return false + } + // 检查管理员权限 if !user.IsAdmin() { AbortWithError(c, 403, "FORBIDDEN", "Admin access required") diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go new file mode 100644 index 00000000..7b6d4ce8 --- /dev/null +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -0,0 +1,194 @@ +//go:build unit + +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}} + authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil) + + admin := &service.User{ + ID: 1, + Email: "admin@example.com", + Role: service.RoleAdmin, + Status: service.StatusActive, + TokenVersion: 2, + Concurrency: 1, + } + + userRepo := &stubUserRepo{ + getByID: func(ctx context.Context, id int64) (*service.User, error) { + if id != admin.ID { + return nil, service.ErrUserNotFound + } + clone := *admin + return &clone, nil + }, + } + userService := service.NewUserService(userRepo, nil, nil) + + router := gin.New() + router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil))) + router.GET("/t", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + t.Run("token_version_mismatch_rejected", func(t *testing.T) { + token, err := authService.GenerateToken(&service.User{ + ID: admin.ID, + Email: admin.Email, + Role: admin.Role, + TokenVersion: admin.TokenVersion - 1, + }) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + require.Contains(t, w.Body.String(), "TOKEN_REVOKED") + }) + + t.Run("token_version_match_allows", func(t *testing.T) { + token, err := authService.GenerateToken(&service.User{ + ID: admin.ID, + Email: admin.Email, + Role: admin.Role, + TokenVersion: admin.TokenVersion, + }) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("websocket_token_version_mismatch_rejected", func(t *testing.T) { + token, err := authService.GenerateToken(&service.User{ + ID: admin.ID, + Email: admin.Email, + Role: admin.Role, + TokenVersion: admin.TokenVersion - 1, + }) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + require.Contains(t, w.Body.String(), "TOKEN_REVOKED") + }) + + t.Run("websocket_token_version_match_allows", func(t *testing.T) { + token, err := authService.GenerateToken(&service.User{ + ID: admin.ID, + Email: admin.Email, + Role: admin.Role, + TokenVersion: admin.TokenVersion, + }) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + }) +} + +type stubUserRepo struct { + getByID func(ctx context.Context, id int64) (*service.User, error) +} + +func (s *stubUserRepo) Create(ctx context.Context, user *service.User) error { + panic("unexpected Create call") +} + +func (s *stubUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) { + if s.getByID == nil { + panic("GetByID not stubbed") + } + return s.getByID(ctx, id) +} + +func (s *stubUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) { + panic("unexpected GetByEmail call") +} + +func (s *stubUserRepo) GetFirstAdmin(ctx context.Context) (*service.User, error) { + panic("unexpected GetFirstAdmin call") +} + +func (s *stubUserRepo) Update(ctx context.Context, user *service.User) error { + panic("unexpected Update call") +} + +func (s *stubUserRepo) Delete(ctx context.Context, id int64) error { + panic("unexpected Delete call") +} + +func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error { + panic("unexpected UpdateBalance call") +} + +func (s *stubUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error { + panic("unexpected DeductBalance call") +} + +func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error { + panic("unexpected UpdateConcurrency call") +} + +func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { + panic("unexpected ExistsByEmail call") +} + +func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { + panic("unexpected RemoveGroupFromAllowedGroups call") +} + +func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { + panic("unexpected UpdateTotpSecret call") +} + +func (s *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error { + panic("unexpected EnableTotp call") +} + +func (s *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error { + panic("unexpected DisableTotp call") +} diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index 2f739357..4525aee7 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -3,7 +3,6 @@ package middleware import ( "context" "errors" - "log" "strings" "github.com/Wei-Shaw/sub2api/internal/config" @@ -134,7 +133,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType() if isSubscriptionType && subscriptionService != nil { - // 订阅模式:验证订阅 + // 订阅模式:获取订阅(L1 缓存 + singleflight) subscription, err := subscriptionService.GetActiveSubscription( c.Request.Context(), apiKey.User.ID, @@ -145,30 +144,30 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti return } - // 验证订阅状态(是否过期、暂停等) - if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil { - AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error()) - return - } - - // 激活滑动窗口(首次使用时) - if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil { - log.Printf("Failed to activate subscription windows: %v", err) - } - - // 检查并重置过期窗口 - if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil { - log.Printf("Failed to reset subscription windows: %v", err) - } - - // 预检查用量限制(使用0作为额外费用进行预检查) - if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil { - AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error()) + // 合并验证 + 限额检查(纯内存操作) + needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group) + if err != nil { + code := "SUBSCRIPTION_INVALID" + status := 403 + if errors.Is(err, service.ErrDailyLimitExceeded) || + errors.Is(err, service.ErrWeeklyLimitExceeded) || + errors.Is(err, service.ErrMonthlyLimitExceeded) { + code = "USAGE_LIMIT_EXCEEDED" + status = 429 + } + AbortWithError(c, status, code, err.Error()) return } // 将订阅信息存入上下文 c.Set(string(ContextKeySubscription), subscription) + + // 窗口维护异步化(不阻塞请求) + // 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race + if needsMaintenance { + maintenanceCopy := *subscription + go subscriptionService.DoWindowMaintenance(&maintenanceCopy) + } } else { // 余额模式:检查用户余额 if apiKey.User.Balance <= 0 { diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 9d514818..3605aaff 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -60,7 +60,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { cfg := &config.Config{RunMode: config.RunModeSimple} apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) - subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) + subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) w := httptest.NewRecorder() @@ -99,7 +99,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, } - subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil) + subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) w := httptest.NewRecorder() diff --git a/backend/internal/server/middleware/cors.go b/backend/internal/server/middleware/cors.go index 7d82f183..b54a0b0e 100644 --- a/backend/internal/server/middleware/cors.go +++ b/backend/internal/server/middleware/cors.go @@ -72,6 +72,7 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc { c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key") c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") + c.Writer.Header().Set("Access-Control-Max-Age", "86400") // 处理预检请求 if c.Request.Method == http.MethodOptions { diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 304c5781..7698223e 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -36,8 +36,8 @@ type UsageLogRepository interface { GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) - GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) - GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) + GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) + GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) // User dashboard stats GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 3d3c9cca..8f93ac3a 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -1582,6 +1582,208 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque return changed, nil } +// ForwardUpstream 透传请求到上游 Antigravity 服务 +// 用于 upstream 类型账号,直接使用 base_url + api_key 转发,不走 OAuth token +func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { + startTime := time.Now() + sessionID := getSessionID(c) + prefix := logPrefix(sessionID, account.Name) + + // 获取上游配置 + baseURL := strings.TrimSpace(account.GetCredential("base_url")) + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if baseURL == "" || apiKey == "" { + return nil, fmt.Errorf("upstream account missing base_url or api_key") + } + baseURL = strings.TrimSuffix(baseURL, "/") + + // 解析请求获取模型信息 + var claudeReq antigravity.ClaudeRequest + if err := json.Unmarshal(body, &claudeReq); err != nil { + return nil, fmt.Errorf("parse claude request: %w", err) + } + if strings.TrimSpace(claudeReq.Model) == "" { + return nil, fmt.Errorf("missing model") + } + originalModel := claudeReq.Model + billingModel := originalModel + + // 构建上游请求 URL + upstreamURL := baseURL + "/v1/messages" + + // 创建请求 + req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create upstream request: %w", err) + } + + // 设置请求头 + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("x-api-key", apiKey) // Claude API 兼容 + + // 透传 Claude 相关 headers + if v := c.GetHeader("anthropic-version"); v != "" { + req.Header.Set("anthropic-version", v) + } + if v := c.GetHeader("anthropic-beta"); v != "" { + req.Header.Set("anthropic-beta", v) + } + + // 代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 发送请求 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + log.Printf("%s upstream request failed: %v", prefix, err) + return nil, fmt.Errorf("upstream request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + // 处理错误响应 + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + // 429 错误时标记账号限流 + if resp.StatusCode == http.StatusTooManyRequests { + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, AntigravityQuotaScopeClaude) + } + + // 透传上游错误 + c.Header("Content-Type", resp.Header.Get("Content-Type")) + c.Status(resp.StatusCode) + _, _ = c.Writer.Write(respBody) + + return &ForwardResult{ + Model: billingModel, + }, nil + } + + // 处理成功响应(流式/非流式) + var usage *ClaudeUsage + var firstTokenMs *int + + if claudeReq.Stream { + // 流式响应:透传 + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + c.Status(http.StatusOK) + + usage, firstTokenMs = s.streamUpstreamResponse(c, resp, startTime) + } else { + // 非流式响应:直接透传 + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read upstream response: %w", err) + } + + // 提取 usage + usage = s.extractClaudeUsage(respBody) + + c.Header("Content-Type", resp.Header.Get("Content-Type")) + c.Status(http.StatusOK) + _, _ = c.Writer.Write(respBody) + } + + // 构建计费结果 + duration := time.Since(startTime) + log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds()) + + return &ForwardResult{ + Model: billingModel, + Stream: claudeReq.Stream, + Duration: duration, + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + CacheReadInputTokens: usage.CacheReadInputTokens, + CacheCreationInputTokens: usage.CacheCreationInputTokens, + }, + }, nil +} + +// streamUpstreamResponse 透传上游流式响应并提取 usage +func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*ClaudeUsage, *int) { + usage := &ClaudeUsage{} + var firstTokenMs *int + var firstTokenRecorded bool + + scanner := bufio.NewScanner(resp.Body) + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 1024*1024) + + for scanner.Scan() { + line := scanner.Bytes() + + // 记录首 token 时间 + if !firstTokenRecorded && len(line) > 0 { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + firstTokenRecorded = true + } + + // 尝试从 message_delta 或 message_stop 事件提取 usage + if bytes.HasPrefix(line, []byte("data: ")) { + dataStr := bytes.TrimPrefix(line, []byte("data: ")) + var event map[string]any + if json.Unmarshal(dataStr, &event) == nil { + if u, ok := event["usage"].(map[string]any); ok { + if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 { + usage.InputTokens = int(v) + } + if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 { + usage.OutputTokens = int(v) + } + if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 { + usage.CacheReadInputTokens = int(v) + } + if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 { + usage.CacheCreationInputTokens = int(v) + } + } + } + } + + // 透传行 + _, _ = c.Writer.Write(line) + _, _ = c.Writer.Write([]byte("\n")) + c.Writer.Flush() + } + + return usage, firstTokenMs +} + +// extractClaudeUsage 从非流式 Claude 响应提取 usage +func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage { + usage := &ClaudeUsage{} + var resp map[string]any + if json.Unmarshal(body, &resp) != nil { + return usage + } + if u, ok := resp["usage"].(map[string]any); ok { + if v, ok := u["input_tokens"].(float64); ok { + usage.InputTokens = int(v) + } + if v, ok := u["output_tokens"].(float64); ok { + usage.OutputTokens = int(v) + } + if v, ok := u["cache_read_input_tokens"].(float64); ok { + usage.CacheReadInputTokens = int(v) + } + if v, ok := u["cache_creation_input_tokens"].(float64); ok { + usage.CacheCreationInputTokens = int(v) + } + } + return usage +} + // ForwardGemini 转发 Gemini 协议请求 func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := time.Now() @@ -1613,7 +1815,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co Usage: ClaudeUsage{}, Model: originalModel, Stream: false, - Duration: time.Since(time.Now()), + Duration: time.Since(startTime), FirstTokenMs: nil, }, nil default: @@ -2288,7 +2490,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) usage := &ClaudeUsage{} var firstTokenMs *int @@ -2309,7 +2512,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -2320,7 +2524,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) // 上游数据间隔超时保护(防止上游挂起长期占用连接) @@ -2445,7 +2649,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) usage := &ClaudeUsage{} var firstTokenMs *int @@ -2473,7 +2678,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -2484,7 +2690,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) // 上游数据间隔超时保护(防止上游挂起长期占用连接) @@ -2888,7 +3094,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) var firstTokenMs *int var last map[string]any @@ -2914,7 +3121,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -2925,7 +3133,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) // 上游数据间隔超时保护(防止上游挂起长期占用连接) @@ -3068,7 +3276,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) // 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage { @@ -3100,7 +3309,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -3111,7 +3321,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) streamInterval := time.Duration(0) diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 91cefc28..ecad4171 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -391,3 +392,37 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling( require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode) require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch") } + +func TestAntigravityStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"cache_read_input_tokens\":3,\"cache_creation_input_tokens\":4}}\n")) + _, _ = pw.Write([]byte("data: {\"usage\":{\"output_tokens\":5}}\n")) + }() + + svc := &AntigravityGatewayService{} + start := time.Now().Add(-10 * time.Millisecond) + usage, firstTokenMs := svc.streamUpstreamResponse(c, resp, start) + _ = pr.Close() + + require.NotNil(t, usage) + require.Equal(t, 1, usage.InputTokens) + // 第二次事件覆盖 output_tokens + require.Equal(t, 5, usage.OutputTokens) + require.Equal(t, 3, usage.CacheReadInputTokens) + require.Equal(t, 4, usage.CacheCreationInputTokens) + + if firstTokenMs == nil { + t.Fatalf("expected firstTokenMs to be set") + } + // 确保有透传输出 + require.True(t, strings.Contains(writer.Body.String(), "data:")) +} diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index f5bba7d0..4bf84ddd 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -6,8 +6,7 @@ import ( "encoding/hex" "errors" "fmt" - "math/rand" - "sync" + "math/rand/v2" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -23,12 +22,6 @@ type apiKeyAuthCacheConfig struct { singleflight bool } -var ( - jitterRandMu sync.Mutex - // 认证缓存抖动使用独立随机源,避免全局 Seed - jitterRand = rand.New(rand.NewSource(time.Now().UnixNano())) -) - func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig { if cfg == nil { return apiKeyAuthCacheConfig{} @@ -56,6 +49,8 @@ func (c apiKeyAuthCacheConfig) negativeEnabled() bool { return c.negativeTTL > 0 } +// jitterTTL 为缓存 TTL 添加抖动,避免多个请求在同一时刻同时过期触发集中回源。 +// 这里直接使用 rand/v2 的顶层函数:并发安全,无需全局互斥锁。 func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration { if ttl <= 0 { return ttl @@ -68,9 +63,7 @@ func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration { percent = 100 } delta := float64(percent) / 100 - jitterRandMu.Lock() - randVal := jitterRand.Float64() - jitterRandMu.Unlock() + randVal := rand.Float64() factor := 1 - delta + randVal*(2*delta) if factor <= 0 { return ttl diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index cd11923e..32704a94 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -319,16 +319,16 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end return trend, nil } -func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { - stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs) +func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { + stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime) if err != nil { return nil, fmt.Errorf("get batch user usage stats: %w", err) } return stats, nil } -func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { - stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs) +func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { + stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime) if err != nil { return nil, fmt.Errorf("get batch api key usage stats: %w", err) } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 32646b11..efe24ca4 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -4145,7 +4145,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) type scanEvent struct { line string @@ -4164,7 +4165,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -4175,7 +4177,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) streamInterval := time.Duration(0) @@ -4481,24 +4483,16 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } // replaceModelInResponseBody 替换响应体中的model字段 +// 使用 gjson/sjson 精确替换,避免全量 JSON 反序列化 func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { - var resp map[string]any - if err := json.Unmarshal(body, &resp); err != nil { - return body + if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel { + newBody, err := sjson.SetBytes(body, "model", toModel) + if err != nil { + return body + } + return newBody } - - model, ok := resp["model"].(string) - if !ok || model != fromModel { - return body - } - - resp["model"] = toModel - newBody, err := json.Marshal(resp) - if err != nil { - return body - } - - return newBody + return body } // RecordUsageInput 记录使用量的输入参数 diff --git a/backend/internal/service/gateway_service_streaming_test.go b/backend/internal/service/gateway_service_streaming_test.go new file mode 100644 index 00000000..48667f58 --- /dev/null +++ b/backend/internal/service/gateway_service_streaming_test.go @@ -0,0 +1,52 @@ +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + + svc := &GatewayService{ + cfg: cfg, + rateLimitService: &RateLimitService{}, + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + // Minimal SSE event to trigger parseSSEUsage + _, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":3}}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":7}}\n\n")) + _, _ = pw.Write([]byte("data: [DONE]\n\n")) + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", nil, false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 3, result.usage.InputTokens) + require.Equal(t, 7, result.usage.OutputTokens) +} diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index cea81693..6b6e8398 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -2,19 +2,7 @@ package service import ( _ "embed" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "path/filepath" "strings" - "time" -) - -const ( - opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt" - codexCacheTTL = 15 * time.Minute ) //go:embed prompts/codex_cli_instructions.md @@ -77,12 +65,6 @@ type codexTransformResult struct { PromptCacheKey string } -type opencodeCacheMetadata struct { - ETag string `json:"etag"` - LastFetch string `json:"lastFetch,omitempty"` - LastChecked int64 `json:"lastChecked"` -} - func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult { result := codexTransformResult{} // 工具续链需求会影响存储策略与 input 过滤逻辑。 @@ -216,54 +198,9 @@ func getNormalizedCodexModel(modelID string) string { return "" } -func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string { - cacheDir := codexCachePath("") - if cacheDir == "" { - return "" - } - cacheFile := filepath.Join(cacheDir, cacheFileName) - metaFile := filepath.Join(cacheDir, metaFileName) - - var cachedContent string - if content, ok := readFile(cacheFile); ok { - cachedContent = content - } - - var meta opencodeCacheMetadata - if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" { - if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL { - return cachedContent - } - } - - content, etag, status, err := fetchWithETag(url, meta.ETag) - if err == nil && status == http.StatusNotModified && cachedContent != "" { - return cachedContent - } - if err == nil && status >= 200 && status < 300 && content != "" { - _ = writeFile(cacheFile, content) - meta = opencodeCacheMetadata{ - ETag: etag, - LastFetch: time.Now().UTC().Format(time.RFC3339), - LastChecked: time.Now().UnixMilli(), - } - _ = writeJSON(metaFile, meta) - return content - } - - return cachedContent -} - func getOpenCodeCodexHeader() string { - // 优先从 opencode 仓库缓存获取指令。 - opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json") - - // 若 opencode 指令可用,直接返回。 - if opencodeInstructions != "" { - return opencodeInstructions - } - - // 否则回退使用本地 Codex CLI 指令。 + // 兼容保留:历史上这里会从 opencode 仓库拉取 codex_header.txt。 + // 现在我们与 Codex CLI 一致,直接使用仓库内置的 instructions,避免读写缓存与外网依赖。 return getCodexCLIInstructions() } @@ -281,8 +218,8 @@ func GetCodexCLIInstructions() string { } // applyInstructions 处理 instructions 字段 -// isCodexCLI=true: 仅补充缺失的 instructions(使用 opencode 指令) -// isCodexCLI=false: 优先使用 opencode 指令覆盖 +// isCodexCLI=true: 仅补充缺失的 instructions(使用内置 Codex CLI 指令) +// isCodexCLI=false: 优先使用内置 Codex CLI 指令覆盖 func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool { if isCodexCLI { return applyCodexCLIInstructions(reqBody) @@ -291,13 +228,13 @@ func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool { } // applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions -// 仅在 instructions 为空时添加 opencode 指令 +// 仅在 instructions 为空时添加内置 Codex CLI 指令(不依赖 opencode 缓存/回源) func applyCodexCLIInstructions(reqBody map[string]any) bool { if !isInstructionsEmpty(reqBody) { return false // 已有有效 instructions,不修改 } - instructions := strings.TrimSpace(getOpenCodeCodexHeader()) + instructions := strings.TrimSpace(getCodexCLIInstructions()) if instructions != "" { reqBody["instructions"] = instructions return true @@ -306,8 +243,8 @@ func applyCodexCLIInstructions(reqBody map[string]any) bool { return false } -// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令 -// 优先使用 opencode 指令覆盖 +// applyOpenCodeInstructions 为非 Codex CLI 请求应用内置 Codex CLI 指令(兼容历史函数名) +// 优先使用内置 Codex CLI 指令覆盖 func applyOpenCodeInstructions(reqBody map[string]any) bool { instructions := strings.TrimSpace(getOpenCodeCodexHeader()) existingInstructions, _ := reqBody["instructions"].(string) @@ -489,85 +426,3 @@ func normalizeCodexTools(reqBody map[string]any) bool { return modified } - -func codexCachePath(filename string) string { - home, err := os.UserHomeDir() - if err != nil { - return "" - } - cacheDir := filepath.Join(home, ".opencode", "cache") - if filename == "" { - return cacheDir - } - return filepath.Join(cacheDir, filename) -} - -func readFile(path string) (string, bool) { - if path == "" { - return "", false - } - data, err := os.ReadFile(path) - if err != nil { - return "", false - } - return string(data), true -} - -func writeFile(path, content string) error { - if path == "" { - return fmt.Errorf("empty cache path") - } - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return err - } - return os.WriteFile(path, []byte(content), 0o644) -} - -func loadJSON(path string, target any) bool { - data, err := os.ReadFile(path) - if err != nil { - return false - } - if err := json.Unmarshal(data, target); err != nil { - return false - } - return true -} - -func writeJSON(path string, value any) error { - if path == "" { - return fmt.Errorf("empty json path") - } - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return err - } - data, err := json.Marshal(value) - if err != nil { - return err - } - return os.WriteFile(path, data, 0o644) -} - -func fetchWithETag(url, etag string) (string, string, int, error) { - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return "", "", 0, err - } - req.Header.Set("User-Agent", "sub2api-codex") - if etag != "" { - req.Header.Set("If-None-Match", etag) - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return "", "", 0, err - } - defer func() { - _ = resp.Body.Close() - }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", "", resp.StatusCode, err - } - return string(body), resp.Header.Get("etag"), resp.StatusCode, nil -} diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index cc0acafc..106bcee8 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -1,18 +1,13 @@ package service import ( - "encoding/json" - "os" - "path/filepath" "testing" - "time" "github.com/stretchr/testify/require" ) func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { // 续链场景:保留 item_reference 与 id,但不再强制 store=true。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.2", @@ -48,7 +43,6 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { // 续链场景:显式 store=false 不再强制为 true,保持 false。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -68,7 +62,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) { // 显式 store=true 也会强制为 false。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -88,7 +81,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) { func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) { // 非续链场景:未设置 store 时默认 false,并移除 input 中的 id。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -130,8 +122,6 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) { } func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) { - setupCodexCache(t) - reqBody := map[string]any{ "model": "gpt-5.1", "tools": []any{ @@ -162,7 +152,6 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { // 空 input 应保持为空且不触发异常。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -187,88 +176,27 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { for input, expected := range cases { require.Equal(t, expected, normalizeCodexModel(input)) } - } func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) { - // Codex CLI 场景:已有 instructions 时保持不变 - setupCodexCache(t) + // Codex CLI 场景:已有 instructions 时不修改 reqBody := map[string]any{ "model": "gpt-5.1", - "instructions": "user custom instructions", - "input": []any{}, + "instructions": "existing instructions", } - result := applyCodexOAuthTransform(reqBody, true) + result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true instructions, ok := reqBody["instructions"].(string) require.True(t, ok) - require.Equal(t, "user custom instructions", instructions) - // instructions 未变,但其他字段(如 store、stream)可能被修改 - require.True(t, result.Modified) -} - -func TestApplyCodexOAuthTransform_CodexCLI_AddsInstructionsWhenEmpty(t *testing.T) { - // Codex CLI 场景:无 instructions 时补充内置指令 - setupCodexCache(t) - - reqBody := map[string]any{ - "model": "gpt-5.1", - "input": []any{}, - } - - result := applyCodexOAuthTransform(reqBody, true) - - instructions, ok := reqBody["instructions"].(string) - require.True(t, ok) - require.NotEmpty(t, instructions) - require.True(t, result.Modified) -} - -func TestApplyCodexOAuthTransform_NonCodexCLI_UsesOpenCodeInstructions(t *testing.T) { - // 非 Codex CLI 场景:使用 opencode 指令(缓存中有 header) - setupCodexCache(t) - - reqBody := map[string]any{ - "model": "gpt-5.1", - "input": []any{}, - } - - result := applyCodexOAuthTransform(reqBody, false) - - instructions, ok := reqBody["instructions"].(string) - require.True(t, ok) - require.Equal(t, "header", instructions) // setupCodexCache 设置的缓存内容 - require.True(t, result.Modified) -} - -func setupCodexCache(t *testing.T) { - t.Helper() - - // 使用临时 HOME 避免触发网络拉取 header。 - // Windows 使用 USERPROFILE,Unix 使用 HOME。 - tempDir := t.TempDir() - t.Setenv("HOME", tempDir) - t.Setenv("USERPROFILE", tempDir) - - cacheDir := filepath.Join(tempDir, ".opencode", "cache") - require.NoError(t, os.MkdirAll(cacheDir, 0o755)) - require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644)) - - meta := map[string]any{ - "etag": "", - "lastFetch": time.Now().UTC().Format(time.RFC3339), - "lastChecked": time.Now().UnixMilli(), - } - data, err := json.Marshal(meta) - require.NoError(t, err) - require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644)) + require.Equal(t, "existing instructions", instructions) + // Modified 仍可能为 true(因为其他字段被修改),但 instructions 应保持不变 + _ = result } func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T) { // Codex CLI 场景:无 instructions 时补充默认值 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -284,8 +212,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T } func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) { - // 非 Codex CLI 场景:使用 opencode 指令覆盖 - setupCodexCache(t) + // 非 Codex CLI 场景:使用内置 Codex CLI 指令覆盖 reqBody := map[string]any{ "model": "gpt-5.1", diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index fbe81cb4..450075fb 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -24,6 +24,8 @@ import ( "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) const ( @@ -765,7 +767,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco bodyModified := false originalModel := reqModel - isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) + isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) // 对所有请求执行模型映射(包含 Codex CLI)。 mappedModel := account.GetMappedModel(reqModel) @@ -969,6 +971,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } } + if usage == nil { + usage = &OpenAIUsage{} + } + reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) return &OpenAIForwardResult{ @@ -1053,6 +1059,12 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. req.Header.Set("user-agent", customUA) } + // 若开启 ForceCodexCLI,则强制将上游 User-Agent 伪装为 Codex CLI。 + // 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。 + if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + req.Header.Set("user-agent", "codex_cli_rs/0.98.0") + } + // Ensure required headers exist if req.Header.Get("content-type") == "" { req.Header.Set("content-type", "application/json") @@ -1233,7 +1245,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) type scanEvent struct { line string @@ -1252,7 +1265,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -1263,7 +1277,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) streamInterval := time.Duration(0) @@ -1442,31 +1456,22 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st return line } - var event map[string]any - if err := json.Unmarshal([]byte(data), &event); err != nil { - return line - } - - // Replace model in response - if m, ok := event["model"].(string); ok && m == fromModel { - event["model"] = toModel - newData, err := json.Marshal(event) + // 使用 gjson 精确检查 model 字段,避免全量 JSON 反序列化 + if m := gjson.Get(data, "model"); m.Exists() && m.Str == fromModel { + newData, err := sjson.Set(data, "model", toModel) if err != nil { return line } - return "data: " + string(newData) + return "data: " + newData } - // Check nested response - if response, ok := event["response"].(map[string]any); ok { - if m, ok := response["model"].(string); ok && m == fromModel { - response["model"] = toModel - newData, err := json.Marshal(event) - if err != nil { - return line - } - return "data: " + string(newData) + // 检查嵌套的 response.model 字段 + if m := gjson.Get(data, "response.model"); m.Exists() && m.Str == fromModel { + newData, err := sjson.Set(data, "response.model", toModel) + if err != nil { + return line } + return "data: " + newData } return line @@ -1686,23 +1691,15 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro } func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { - var resp map[string]any - if err := json.Unmarshal(body, &resp); err != nil { - return body + // 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化 + if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel { + newBody, err := sjson.SetBytes(body, "model", toModel) + if err != nil { + return body + } + return newBody } - - model, ok := resp["model"].(string) - if !ok || model != fromModel { - return body - } - - resp["model"] = toModel - newBody, err := json.Marshal(resp) - if err != nil { - return body - } - - return newBody + return body } // OpenAIRecordUsageInput input for recording usage diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 1c2c81ca..91dbaa4b 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -14,6 +14,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" ) type stubOpenAIAccountRepo struct { @@ -1082,6 +1083,43 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) { } } +func TestOpenAIStreamingReuseScannerBufferAndStillWorks(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"input_tokens_details\":{\"cached_tokens\":3}}}}\n\n")) + }() + + result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 1, result.usage.InputTokens) + require.Equal(t, 2, result.usage.OutputTokens) + require.Equal(t, 3, result.usage.CacheReadInputTokens) +} + func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ @@ -1165,3 +1203,226 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) { t.Fatalf("expected non-allowlisted host to fail") } } + +// ==================== P1-08 修复:model 替换性能优化测试 ==================== + +func TestReplaceModelInSSELine(t *testing.T) { + svc := &OpenAIGatewayService{} + + tests := []struct { + name string + line string + from string + to string + expected string + }{ + { + name: "顶层 model 字段替换", + line: `data: {"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`, + from: "gpt-4o", + to: "my-custom-model", + expected: `data: {"id":"chatcmpl-123","model":"my-custom-model","choices":[]}`, + }, + { + name: "嵌套 response.model 替换", + line: `data: {"type":"response","response":{"id":"resp-1","model":"gpt-4o","output":[]}}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {"type":"response","response":{"id":"resp-1","model":"my-model","output":[]}}`, + }, + { + name: "model 不匹配时不替换", + line: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, + }, + { + name: "无 model 字段时不替换", + line: `data: {"id":"chatcmpl-123","choices":[]}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {"id":"chatcmpl-123","choices":[]}`, + }, + { + name: "空 data 行", + line: `data: `, + from: "gpt-4o", + to: "my-model", + expected: `data: `, + }, + { + name: "[DONE] 行", + line: `data: [DONE]`, + from: "gpt-4o", + to: "my-model", + expected: `data: [DONE]`, + }, + { + name: "非 data: 前缀行", + line: `event: message`, + from: "gpt-4o", + to: "my-model", + expected: `event: message`, + }, + { + name: "非法 JSON 不替换", + line: `data: {invalid json}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {invalid json}`, + }, + { + name: "无空格 data: 格式", + line: `data:{"id":"x","model":"gpt-4o"}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {"id":"x","model":"my-model"}`, + }, + { + name: "model 名含特殊字符", + line: `data: {"model":"org/model-v2.1-beta"}`, + from: "org/model-v2.1-beta", + to: "custom/alias", + expected: `data: {"model":"custom/alias"}`, + }, + { + name: "空行", + line: "", + from: "gpt-4o", + to: "my-model", + expected: "", + }, + { + name: "保持其他字段不变", + line: `data: {"id":"abc","object":"chat.completion.chunk","model":"gpt-4o","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`, + from: "gpt-4o", + to: "alias", + expected: `data: {"id":"abc","object":"chat.completion.chunk","model":"alias","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`, + }, + { + name: "顶层优先于嵌套:同时存在两个 model", + line: `data: {"model":"gpt-4o","response":{"model":"gpt-4o"}}`, + from: "gpt-4o", + to: "replaced", + expected: `data: {"model":"replaced","response":{"model":"gpt-4o"}}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.replaceModelInSSELine(tt.line, tt.from, tt.to) + require.Equal(t, tt.expected, got) + }) + } +} + +func TestReplaceModelInSSEBody(t *testing.T) { + svc := &OpenAIGatewayService{} + + tests := []struct { + name string + body string + from string + to string + expected string + }{ + { + name: "多行 SSE body 替换", + body: "data: {\"model\":\"gpt-4o\",\"choices\":[]}\n\ndata: {\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n", + from: "gpt-4o", + to: "alias", + expected: "data: {\"model\":\"alias\",\"choices\":[]}\n\ndata: {\"model\":\"alias\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n", + }, + { + name: "无需替换的 body", + body: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n", + from: "gpt-4o", + to: "alias", + expected: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n", + }, + { + name: "混合 event 和 data 行", + body: "event: message\ndata: {\"model\":\"gpt-4o\"}\n\n", + from: "gpt-4o", + to: "alias", + expected: "event: message\ndata: {\"model\":\"alias\"}\n\n", + }, + { + name: "空 body", + body: "", + from: "gpt-4o", + to: "alias", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.replaceModelInSSEBody(tt.body, tt.from, tt.to) + require.Equal(t, tt.expected, got) + }) + } +} + +func TestReplaceModelInResponseBody(t *testing.T) { + svc := &OpenAIGatewayService{} + + tests := []struct { + name string + body string + from string + to string + expected string + }{ + { + name: "替换顶层 model", + body: `{"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`, + from: "gpt-4o", + to: "alias", + expected: `{"id":"chatcmpl-123","model":"alias","choices":[]}`, + }, + { + name: "model 不匹配不替换", + body: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, + from: "gpt-4o", + to: "alias", + expected: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, + }, + { + name: "无 model 字段不替换", + body: `{"id":"chatcmpl-123","choices":[]}`, + from: "gpt-4o", + to: "alias", + expected: `{"id":"chatcmpl-123","choices":[]}`, + }, + { + name: "非法 JSON 返回原值", + body: `not json`, + from: "gpt-4o", + to: "alias", + expected: `not json`, + }, + { + name: "空 body 返回原值", + body: ``, + from: "gpt-4o", + to: "alias", + expected: ``, + }, + { + name: "保持嵌套结构不变", + body: `{"model":"gpt-4o","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`, + from: "gpt-4o", + to: "alias", + expected: `{"model":"alias","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.replaceModelInResponseBody([]byte(tt.body), tt.from, tt.to) + require.Equal(t, tt.expected, string(got)) + }) + } +} diff --git a/backend/internal/service/sse_scanner_buffer_pool.go b/backend/internal/service/sse_scanner_buffer_pool.go new file mode 100644 index 00000000..7475547f --- /dev/null +++ b/backend/internal/service/sse_scanner_buffer_pool.go @@ -0,0 +1,24 @@ +package service + +import "sync" + +const sseScannerBuf64KSize = 64 * 1024 + +type sseScannerBuf64K [sseScannerBuf64KSize]byte + +var sseScannerBuf64KPool = sync.Pool{ + New: func() any { + return new(sseScannerBuf64K) + }, +} + +func getSSEScannerBuf64K() *sseScannerBuf64K { + return sseScannerBuf64KPool.Get().(*sseScannerBuf64K) +} + +func putSSEScannerBuf64K(buf *sseScannerBuf64K) { + if buf == nil { + return + } + sseScannerBuf64KPool.Put(buf) +} diff --git a/backend/internal/service/sse_scanner_buffer_pool_test.go b/backend/internal/service/sse_scanner_buffer_pool_test.go new file mode 100644 index 00000000..09b8ad21 --- /dev/null +++ b/backend/internal/service/sse_scanner_buffer_pool_test.go @@ -0,0 +1,19 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSSEScannerBuf64KPool_GetPutDoesNotPanic(t *testing.T) { + buf := getSSEScannerBuf64K() + require.NotNil(t, buf) + require.Equal(t, sseScannerBuf64KSize, len(buf[:])) + + buf[0] = 1 + putSSEScannerBuf64K(buf) + + // 允许传入 nil,确保不会 panic + putSSEScannerBuf64K(nil) +} diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index 3c42852e..21694d41 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -4,10 +4,15 @@ import ( "context" "fmt" "log" + "math/rand/v2" + "strconv" "time" + "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/dgraph-io/ristretto" + "golang.org/x/sync/singleflight" ) // MaxExpiresAt is the maximum allowed expiration date (year 2099) @@ -35,15 +40,76 @@ type SubscriptionService struct { groupRepo GroupRepository userSubRepo UserSubscriptionRepository billingCacheService *BillingCacheService + + // L1 缓存:加速中间件热路径的订阅查询 + subCacheL1 *ristretto.Cache + subCacheGroup singleflight.Group + subCacheTTL time.Duration + subCacheJitter int // 抖动百分比 } // NewSubscriptionService 创建订阅服务 -func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService) *SubscriptionService { - return &SubscriptionService{ +func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, cfg *config.Config) *SubscriptionService { + svc := &SubscriptionService{ groupRepo: groupRepo, userSubRepo: userSubRepo, billingCacheService: billingCacheService, } + svc.initSubCache(cfg) + return svc +} + +// initSubCache 初始化订阅 L1 缓存 +func (s *SubscriptionService) initSubCache(cfg *config.Config) { + if cfg == nil { + return + } + sc := cfg.SubscriptionCache + if sc.L1Size <= 0 || sc.L1TTLSeconds <= 0 { + return + } + cache, err := ristretto.NewCache(&ristretto.Config{ + NumCounters: int64(sc.L1Size) * 10, + MaxCost: int64(sc.L1Size), + BufferItems: 64, + }) + if err != nil { + log.Printf("Warning: failed to init subscription L1 cache: %v", err) + return + } + s.subCacheL1 = cache + s.subCacheTTL = time.Duration(sc.L1TTLSeconds) * time.Second + s.subCacheJitter = sc.JitterPercent +} + +// subCacheKey 生成订阅缓存 key(热路径,避免 fmt.Sprintf 开销) +func subCacheKey(userID, groupID int64) string { + return "sub:" + strconv.FormatInt(userID, 10) + ":" + strconv.FormatInt(groupID, 10) +} + +// jitteredTTL 为 TTL 添加抖动,避免集中过期 +func (s *SubscriptionService) jitteredTTL(ttl time.Duration) time.Duration { + if ttl <= 0 || s.subCacheJitter <= 0 { + return ttl + } + pct := s.subCacheJitter + if pct > 100 { + pct = 100 + } + delta := float64(pct) / 100 + factor := 1 - delta + rand.Float64()*(2*delta) + if factor <= 0 { + return ttl + } + return time.Duration(float64(ttl) * factor) +} + +// InvalidateSubCache 失效指定用户+分组的订阅 L1 缓存 +func (s *SubscriptionService) InvalidateSubCache(userID, groupID int64) { + if s.subCacheL1 == nil { + return + } + s.subCacheL1.Del(subCacheKey(userID, groupID)) } // AssignSubscriptionInput 分配订阅输入 @@ -81,6 +147,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass } // 失效订阅缓存 + s.InvalidateSubCache(input.UserID, input.GroupID) if s.billingCacheService != nil { userID, groupID := input.UserID, input.GroupID go func() { @@ -167,6 +234,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in } // 失效订阅缓存 + s.InvalidateSubCache(input.UserID, input.GroupID) if s.billingCacheService != nil { userID, groupID := input.UserID, input.GroupID go func() { @@ -188,6 +256,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in } // 失效订阅缓存 + s.InvalidateSubCache(input.UserID, input.GroupID) if s.billingCacheService != nil { userID, groupID := input.UserID, input.GroupID go func() { @@ -297,6 +366,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti } // 失效订阅缓存 + s.InvalidateSubCache(sub.UserID, sub.GroupID) if s.billingCacheService != nil { userID, groupID := sub.UserID, sub.GroupID go func() { @@ -363,6 +433,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti } // 失效订阅缓存 + s.InvalidateSubCache(sub.UserID, sub.GroupID) if s.billingCacheService != nil { userID, groupID := sub.UserID, sub.GroupID go func() { @@ -381,12 +452,39 @@ func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*UserSubsc } // GetActiveSubscription 获取用户对特定分组的有效订阅 +// 使用 L1 缓存 + singleflight 加速中间件热路径。 +// 返回缓存对象的浅拷贝,调用方可安全修改字段而不会污染缓存或触发 data race。 func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*UserSubscription, error) { - sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID) - if err != nil { - return nil, ErrSubscriptionNotFound + key := subCacheKey(userID, groupID) + + // L1 缓存命中:返回浅拷贝 + if s.subCacheL1 != nil { + if v, ok := s.subCacheL1.Get(key); ok { + if sub, ok := v.(*UserSubscription); ok { + cp := *sub + return &cp, nil + } + } } - return sub, nil + + // singleflight 防止并发击穿 + value, err, _ := s.subCacheGroup.Do(key, func() (any, error) { + sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID) + if err != nil { + return nil, ErrSubscriptionNotFound + } + // 写入 L1 缓存 + if s.subCacheL1 != nil { + _ = s.subCacheL1.SetWithTTL(key, sub, 1, s.jitteredTTL(s.subCacheTTL)) + } + return sub, nil + }) + if err != nil { + return nil, err + } + // singleflight 返回的也是缓存指针,需要浅拷贝 + cp := *value.(*UserSubscription) + return &cp, nil } // ListUserSubscriptions 获取用户的所有订阅 @@ -521,9 +619,12 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *Use needsInvalidateCache = true } - // 如果有窗口被重置,失效 Redis 缓存以保持一致性 - if needsInvalidateCache && s.billingCacheService != nil { - _ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID) + // 如果有窗口被重置,失效缓存以保持一致性 + if needsInvalidateCache { + s.InvalidateSubCache(sub.UserID, sub.GroupID) + if s.billingCacheService != nil { + _ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID) + } } return nil @@ -544,6 +645,78 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSub return nil } +// ValidateAndCheckLimits 合并验证+限额检查(中间件热路径专用) +// 仅做内存检查,不触发 DB 写入。窗口重置的 DB 写入由 DoWindowMaintenance 异步完成。 +// 返回 needsMaintenance 表示是否需要异步执行窗口维护。 +func (s *SubscriptionService) ValidateAndCheckLimits(sub *UserSubscription, group *Group) (needsMaintenance bool, err error) { + // 1. 验证订阅状态 + if sub.Status == SubscriptionStatusExpired { + return false, ErrSubscriptionExpired + } + if sub.Status == SubscriptionStatusSuspended { + return false, ErrSubscriptionSuspended + } + if sub.IsExpired() { + return false, ErrSubscriptionExpired + } + + // 2. 内存中修正过期窗口的用量,确保 CheckUsageLimits 不会误拒绝用户 + // 实际的 DB 窗口重置由 DoWindowMaintenance 异步完成 + if sub.NeedsDailyReset() { + sub.DailyUsageUSD = 0 + needsMaintenance = true + } + if sub.NeedsWeeklyReset() { + sub.WeeklyUsageUSD = 0 + needsMaintenance = true + } + if sub.NeedsMonthlyReset() { + sub.MonthlyUsageUSD = 0 + needsMaintenance = true + } + if !sub.IsWindowActivated() { + needsMaintenance = true + } + + // 3. 检查用量限额 + if !sub.CheckDailyLimit(group, 0) { + return needsMaintenance, ErrDailyLimitExceeded + } + if !sub.CheckWeeklyLimit(group, 0) { + return needsMaintenance, ErrWeeklyLimitExceeded + } + if !sub.CheckMonthlyLimit(group, 0) { + return needsMaintenance, ErrMonthlyLimitExceeded + } + + return needsMaintenance, nil +} + +// DoWindowMaintenance 异步执行窗口维护(激活+重置) +// 使用独立 context,不受请求取消影响。 +// 注意:此方法仅在 ValidateAndCheckLimits 返回 needsMaintenance=true 时调用, +// 而 IsExpired()=true 的订阅在 ValidateAndCheckLimits 中已被拦截返回错误, +// 因此进入此方法的订阅一定未过期,无需处理过期状态同步。 +func (s *SubscriptionService) DoWindowMaintenance(sub *UserSubscription) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // 激活窗口(首次使用时) + if !sub.IsWindowActivated() { + if err := s.CheckAndActivateWindow(ctx, sub); err != nil { + log.Printf("Failed to activate subscription windows: %v", err) + } + } + + // 重置过期窗口 + if err := s.CheckAndResetWindows(ctx, sub); err != nil { + log.Printf("Failed to reset subscription windows: %v", err) + } + + // 失效 L1 缓存,确保后续请求拿到更新后的数据 + s.InvalidateSubCache(sub.UserID, sub.GroupID) +} + // RecordUsage 记录使用量到订阅 func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error { return s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD) diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index 5594e53f..f21a2855 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -316,8 +316,8 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star } // GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys. -func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { - stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs) +func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { + stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime) if err != nil { return nil, fmt.Errorf("get batch api key usage stats: %w", err) } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 1bfb392e..510e734e 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -3,6 +3,8 @@ package service import ( "context" "fmt" + "log" + "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" @@ -62,13 +64,15 @@ type ChangePasswordRequest struct { type UserService struct { userRepo UserRepository authCacheInvalidator APIKeyAuthCacheInvalidator + billingCache BillingCache } // NewUserService 创建用户服务实例 -func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService { +func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService { return &UserService{ userRepo: userRepo, authCacheInvalidator: authCacheInvalidator, + billingCache: billingCache, } } @@ -183,6 +187,15 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) } + if s.billingCache != nil { + go func() { + cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil { + log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err) + } + }() + } return nil } diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go new file mode 100644 index 00000000..0f355d70 --- /dev/null +++ b/backend/internal/service/user_service_test.go @@ -0,0 +1,186 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +// --- mock: UserRepository --- + +type mockUserRepo struct { + updateBalanceErr error + updateBalanceFn func(ctx context.Context, id int64, amount float64) error +} + +func (m *mockUserRepo) Create(context.Context, *User) error { return nil } +func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil } +func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil } +func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil } +func (m *mockUserRepo) Update(context.Context, *User) error { return nil } +func (m *mockUserRepo) Delete(context.Context, int64) error { return nil } +func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (m *mockUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (m *mockUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error { + if m.updateBalanceFn != nil { + return m.updateBalanceFn(ctx, id, amount) + } + return m.updateBalanceErr +} +func (m *mockUserRepo) DeductBalance(context.Context, int64, float64) error { return nil } +func (m *mockUserRepo) UpdateConcurrency(context.Context, int64, int) error { return nil } +func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { return false, nil } +func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + return 0, nil +} +func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } +func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } + +// --- mock: APIKeyAuthCacheInvalidator --- + +type mockAuthCacheInvalidator struct { + invalidatedUserIDs []int64 + mu sync.Mutex +} + +func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByKey(context.Context, string) {} +func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByGroupID(context.Context, int64) {} +func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) { + m.mu.Lock() + defer m.mu.Unlock() + m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID) +} + +// --- mock: BillingCache --- + +type mockBillingCache struct { + invalidateErr error + invalidateCallCount atomic.Int64 + invalidatedUserIDs []int64 + mu sync.Mutex +} + +func (m *mockBillingCache) GetUserBalance(context.Context, int64) (float64, error) { return 0, nil } +func (m *mockBillingCache) SetUserBalance(context.Context, int64, float64) error { return nil } +func (m *mockBillingCache) DeductUserBalance(context.Context, int64, float64) error { return nil } +func (m *mockBillingCache) InvalidateUserBalance(_ context.Context, userID int64) error { + m.invalidateCallCount.Add(1) + m.mu.Lock() + defer m.mu.Unlock() + m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID) + return m.invalidateErr +} +func (m *mockBillingCache) GetSubscriptionCache(context.Context, int64, int64) (*SubscriptionCacheData, error) { + return nil, nil +} +func (m *mockBillingCache) SetSubscriptionCache(context.Context, int64, int64, *SubscriptionCacheData) error { + return nil +} +func (m *mockBillingCache) UpdateSubscriptionUsage(context.Context, int64, int64, float64) error { + return nil +} +func (m *mockBillingCache) InvalidateSubscriptionCache(context.Context, int64, int64) error { + return nil +} + +// --- 测试 --- + +func TestUpdateBalance_Success(t *testing.T) { + repo := &mockUserRepo{} + cache := &mockBillingCache{} + svc := NewUserService(repo, nil, cache) + + err := svc.UpdateBalance(context.Background(), 42, 100.0) + require.NoError(t, err) + + // 等待异步 goroutine 完成 + require.Eventually(t, func() bool { + return cache.invalidateCallCount.Load() == 1 + }, 2*time.Second, 10*time.Millisecond, "应异步调用 InvalidateUserBalance") + + cache.mu.Lock() + defer cache.mu.Unlock() + require.Equal(t, []int64{42}, cache.invalidatedUserIDs, "应对 userID=42 失效缓存") +} + +func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) { + repo := &mockUserRepo{} + svc := NewUserService(repo, nil, nil) // billingCache = nil + + err := svc.UpdateBalance(context.Background(), 1, 50.0) + require.NoError(t, err, "billingCache 为 nil 时不应 panic") +} + +func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) { + repo := &mockUserRepo{} + cache := &mockBillingCache{invalidateErr: errors.New("redis connection refused")} + svc := NewUserService(repo, nil, cache) + + err := svc.UpdateBalance(context.Background(), 99, 200.0) + require.NoError(t, err, "缓存失效失败不应影响主流程返回值") + + // 等待异步 goroutine 完成(即使失败也应调用) + require.Eventually(t, func() bool { + return cache.invalidateCallCount.Load() == 1 + }, 2*time.Second, 10*time.Millisecond, "即使失败也应调用 InvalidateUserBalance") +} + +func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) { + repo := &mockUserRepo{updateBalanceErr: errors.New("database error")} + cache := &mockBillingCache{} + svc := NewUserService(repo, nil, cache) + + err := svc.UpdateBalance(context.Background(), 1, 100.0) + require.Error(t, err, "repo 失败时应返回错误") + require.Contains(t, err.Error(), "update balance") + + // repo 失败时不应触发缓存失效 + time.Sleep(100 * time.Millisecond) + require.Equal(t, int64(0), cache.invalidateCallCount.Load(), + "repo 失败时不应调用 InvalidateUserBalance") +} + +func TestUpdateBalance_WithAuthCacheInvalidator(t *testing.T) { + repo := &mockUserRepo{} + auth := &mockAuthCacheInvalidator{} + cache := &mockBillingCache{} + svc := NewUserService(repo, auth, cache) + + err := svc.UpdateBalance(context.Background(), 77, 300.0) + require.NoError(t, err) + + // 验证 auth cache 同步失效 + auth.mu.Lock() + require.Equal(t, []int64{77}, auth.invalidatedUserIDs) + auth.mu.Unlock() + + // 验证 billing cache 异步失效 + require.Eventually(t, func() bool { + return cache.invalidateCallCount.Load() == 1 + }, 2*time.Second, 10*time.Millisecond) +} + +func TestNewUserService_FieldsAssignment(t *testing.T) { + repo := &mockUserRepo{} + auth := &mockAuthCacheInvalidator{} + cache := &mockBillingCache{} + + svc := NewUserService(repo, auth, cache) + require.NotNil(t, svc) + require.Equal(t, repo, svc.userRepo) + require.Equal(t, auth, svc.authCacheInvalidator) + require.Equal(t, cache, svc.billingCache) +} diff --git a/deploy/.env.example b/deploy/.env.example index c5e850ae..26bb99b5 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -58,13 +58,67 @@ TZ=Asia/Shanghai POSTGRES_USER=sub2api POSTGRES_PASSWORD=change_this_secure_password POSTGRES_DB=sub2api +# PostgreSQL 监听端口(同时用于 PG 服务端和应用连接,默认 5432) +DATABASE_PORT=5432 + +# ----------------------------------------------------------------------------- +# PostgreSQL 服务端参数(可选;主要用于 deploy/docker-compose-aicodex.yml) +# ----------------------------------------------------------------------------- +# POSTGRES_MAX_CONNECTIONS:PostgreSQL 服务端允许的最大连接数。 +# 必须 >=(所有 Sub2API 实例的 DATABASE_MAX_OPEN_CONNS 之和)+ 预留余量(例如 20%)。 +POSTGRES_MAX_CONNECTIONS=1024 +# POSTGRES_SHARED_BUFFERS:PostgreSQL 用于缓存数据页的共享内存。 +# 常见建议:物理内存的 10%~25%(容器内存受限时请按实际限制调整)。 +# 8GB 内存容器参考:1GB。 +POSTGRES_SHARED_BUFFERS=1GB +# POSTGRES_EFFECTIVE_CACHE_SIZE:查询规划器“假设可用的 OS 缓存大小”(不等于实际分配)。 +# 常见建议:物理内存的 50%~75%。 +# 8GB 内存容器参考:6GB。 +POSTGRES_EFFECTIVE_CACHE_SIZE=4GB +# POSTGRES_MAINTENANCE_WORK_MEM:维护操作内存(VACUUM/CREATE INDEX 等)。 +# 值越大维护越快,但会占用更多内存。 +# 8GB 内存容器参考:128MB。 +POSTGRES_MAINTENANCE_WORK_MEM=128MB + +# ----------------------------------------------------------------------------- +# PostgreSQL 连接池参数(可选,默认与程序内置一致) +# ----------------------------------------------------------------------------- +# 说明: +# - 这些参数控制 Sub2API 进程到 PostgreSQL 的连接池大小(不是 PostgreSQL 自身的 max_connections)。 +# - 多实例/多副本部署时,总连接上限约等于:实例数 * DATABASE_MAX_OPEN_CONNS。 +# - 连接池过大可能导致:数据库连接耗尽、内存占用上升、上下文切换增多,反而变慢。 +# - 建议结合 PostgreSQL 的 max_connections 与机器规格逐步调优: +# 通常把应用总连接上限控制在 max_connections 的 50%~80% 更稳妥。 +# +# DATABASE_MAX_OPEN_CONNS:最大打开连接数(活跃+空闲),达到后新请求会等待可用连接。 +# 典型范围:50~500(取决于 DB 规格、实例数、SQL 复杂度)。 +DATABASE_MAX_OPEN_CONNS=256 +# DATABASE_MAX_IDLE_CONNS:最大空闲连接数(热连接),建议 <= MAX_OPEN。 +# 太小会频繁建连增加延迟;太大会长期占用数据库资源。 +DATABASE_MAX_IDLE_CONNS=128 +# DATABASE_CONN_MAX_LIFETIME_MINUTES:单个连接最大存活时间(单位:分钟)。 +# 用于避免连接长期不重建导致的中间件/LB/NAT 异常或服务端重启后的“僵尸连接”。 +# 设置为 0 表示不限制(一般不建议生产环境)。 +DATABASE_CONN_MAX_LIFETIME_MINUTES=30 +# DATABASE_CONN_MAX_IDLE_TIME_MINUTES:空闲连接最大存活时间(单位:分钟)。 +# 超过该时间的空闲连接会被回收,防止长时间闲置占用连接数。 +# 设置为 0 表示不限制(一般不建议生产环境)。 +DATABASE_CONN_MAX_IDLE_TIME_MINUTES=5 # ----------------------------------------------------------------------------- # Redis Configuration # ----------------------------------------------------------------------------- +# Redis 监听端口(同时用于应用连接和 Redis 服务端,默认 6379) +REDIS_PORT=6379 # Leave empty for no password (default for local development) REDIS_PASSWORD= REDIS_DB=0 +# Redis 服务端最大客户端连接数(可选;主要用于 deploy/docker-compose-aicodex.yml) +REDIS_MAXCLIENTS=50000 +# Redis 连接池大小(默认 1024) +REDIS_POOL_SIZE=4096 +# Redis 最小空闲连接数(默认 10) +REDIS_MIN_IDLE_CONNS=256 REDIS_ENABLE_TLS=false # ----------------------------------------------------------------------------- @@ -119,6 +173,19 @@ RATE_LIMIT_OVERLOAD_COOLDOWN_MINUTES=10 # Gateway Scheduling (Optional) # 调度缓存与受控回源配置(缓存就绪且命中时不读 DB) # ----------------------------------------------------------------------------- +# Force Codex CLI mode: treat all /openai/v1/responses requests as Codex CLI. +# 强制按 Codex CLI 处理 /openai/v1/responses 请求(用于网关未透传/改写 User-Agent 的兜底)。 +# +# 注意:开启后会影响所有客户端的行为(不仅限于 VS Code / Codex CLI),请谨慎开启。 +# +# 默认:false +GATEWAY_FORCE_CODEX_CLI=false +# 上游连接池:每主机最大连接数(默认 1024;流式/HTTP1.1 场景可调大,如 2400/4096) +GATEWAY_MAX_CONNS_PER_HOST=2048 +# 上游连接池:最大空闲连接总数(默认 2560;账号/代理隔离 + 高并发场景可调大) +GATEWAY_MAX_IDLE_CONNS=8192 +# 上游连接池:每主机最大空闲连接(默认 120) +GATEWAY_MAX_IDLE_CONNS_PER_HOST=4096 # 粘性会话最大排队长度 GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING=3 # 粘性会话等待超时(时间段,例如 45s) diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index d9f5f2ab..8b15b54f 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -20,6 +20,10 @@ server: # Mode: "debug" for development, "release" for production # 运行模式:"debug" 用于开发,"release" 用于生产环境 mode: "release" + # Frontend base URL used to generate external links in emails (e.g. password reset) + # 用于生成邮件中的外部链接(例如:重置密码链接)的前端基础地址 + # Example: "https://example.com" + frontend_url: "" # Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies. # 信任的代理地址(CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。 trusted_proxies: [] @@ -108,9 +112,9 @@ security: # 白名单禁用时是否允许 http:// URL(默认: false,要求 https) allow_insecure_http: true response_headers: - # Enable configurable response header filtering (disable to use default allowlist) - # 启用可配置的响应头过滤(禁用则使用默认白名单) - enabled: false + # Enable configurable response header filtering (default: true) + # 启用可配置的响应头过滤(默认启用,过滤上游敏感响应头) + enabled: true # Extra allowed response headers from upstream # 额外允许的上游响应头 additional_allowed: [] @@ -151,17 +155,22 @@ gateway: # - account_proxy: Isolate by account+proxy combination (default, finest granularity) # - account_proxy: 按账户+代理组合隔离(默认,最细粒度) connection_pool_isolation: "account_proxy" + # Force Codex CLI mode: treat all /openai/v1/responses requests as Codex CLI. + # 强制按 Codex CLI 处理 /openai/v1/responses 请求(用于网关未透传/改写 User-Agent 的兜底)。 + # + # 注意:开启后会影响所有客户端的行为(不仅限于 VS Code / Codex CLI),请谨慎开启。 + force_codex_cli: false # HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults) # HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值) # Max idle connections across all hosts # 所有主机的最大空闲连接数 - max_idle_conns: 240 + max_idle_conns: 2560 # Max idle connections per host # 每个主机的最大空闲连接数 max_idle_conns_per_host: 120 # Max connections per host # 每个主机的最大连接数 - max_conns_per_host: 240 + max_conns_per_host: 1024 # Idle connection timeout (seconds) # 空闲连接超时时间(秒) idle_conn_timeout_seconds: 90 @@ -381,9 +390,22 @@ database: # Database name # 数据库名称 dbname: "sub2api" - # SSL mode: disable, require, verify-ca, verify-full - # SSL 模式:disable(禁用), require(要求), verify-ca(验证CA), verify-full(完全验证) - sslmode: "disable" + # SSL mode: disable, prefer, require, verify-ca, verify-full + # SSL 模式:disable(禁用), prefer(优先加密,默认), require(要求), verify-ca(验证CA), verify-full(完全验证) + # 默认值为 "prefer",数据库支持 SSL 时自动使用加密连接,不支持时回退明文 + sslmode: "prefer" + # Max open connections (高并发场景建议 256+,需配合 PostgreSQL max_connections 调整) + # 最大打开连接数 + max_open_conns: 256 + # Max idle connections (建议为 max_open_conns 的 50%,减少频繁建连开销) + # 最大空闲连接数 + max_idle_conns: 128 + # Connection max lifetime (minutes) + # 连接最大存活时间(分钟) + conn_max_lifetime_minutes: 30 + # Connection max idle time (minutes) + # 空闲连接最大存活时间(分钟) + conn_max_idle_time_minutes: 5 # ============================================================================= # Redis Configuration @@ -402,6 +424,12 @@ redis: # Database number (0-15) # 数据库编号(0-15) db: 0 + # Connection pool size (max concurrent connections) + # 连接池大小(最大并发连接数) + pool_size: 1024 + # Minimum number of idle connections (高并发场景建议 128+,保持足够热连接) + # 最小空闲连接数 + min_idle_conns: 128 # Enable TLS/SSL connection # 是否启用 TLS/SSL 连接 enable_tls: false diff --git a/deploy/docker-compose-aicodex.yml b/deploy/docker-compose-aicodex.yml new file mode 100644 index 00000000..f650a60e --- /dev/null +++ b/deploy/docker-compose-aicodex.yml @@ -0,0 +1,233 @@ +# ============================================================================= +# Sub2API Docker Compose Host Configuration (Local Build) +# ============================================================================= +# Quick Start: +# 1. Copy .env.example to .env and configure +# 2. docker-compose -f docker-compose-host.yml up -d --build +# 3. Check logs: docker-compose -f docker-compose-host.yml logs -f sub2api +# 4. Access: http://localhost:8080 +# +# This configuration builds the image from source (Dockerfile in project root). +# All configuration is done via environment variables. +# No Setup Wizard needed - the system auto-initializes on first run. +# ============================================================================= + +services: + # =========================================================================== + # Sub2API Application + # =========================================================================== + sub2api: + #image: weishaw/sub2api:latest + image: yangjianbo/aicodex2api:latest + build: + context: .. + dockerfile: Dockerfile + container_name: sub2api + restart: unless-stopped + network_mode: host + ulimits: + nofile: + soft: 800000 + hard: 800000 + volumes: + # Data persistence (config.yaml will be auto-generated here) + - sub2api_data:/app/data + # Mount custom config.yaml (optional, overrides auto-generated config) + #- ./config.yaml:/app/data/config.yaml:ro + environment: + # ======================================================================= + # Auto Setup (REQUIRED for Docker deployment) + # ======================================================================= + - AUTO_SETUP=true + + # ======================================================================= + # Server Configuration + # ======================================================================= + - SERVER_HOST=0.0.0.0 + - SERVER_PORT=8080 + - SERVER_MODE=${SERVER_MODE:-release} + - RUN_MODE=${RUN_MODE:-standard} + + # ======================================================================= + # Database Configuration (PostgreSQL) + # ======================================================================= + # Using host network: point to host/external DB by DATABASE_HOST/DATABASE_PORT + - DATABASE_HOST=${DATABASE_HOST:-127.0.0.1} + - DATABASE_PORT=${DATABASE_PORT:-5432} + - DATABASE_USER=${POSTGRES_USER:-sub2api} + - DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} + - DATABASE_DBNAME=${POSTGRES_DB:-sub2api} + - DATABASE_SSLMODE=disable + - DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50} + - DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10} + - DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30} + - DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5} + + # ======================================================================= + # Gateway Configuration + # ======================================================================= + - GATEWAY_FORCE_CODEX_CLI=${GATEWAY_FORCE_CODEX_CLI:-false} + - GATEWAY_MAX_IDLE_CONNS=${GATEWAY_MAX_IDLE_CONNS:-2560} + - GATEWAY_MAX_IDLE_CONNS_PER_HOST=${GATEWAY_MAX_IDLE_CONNS_PER_HOST:-120} + - GATEWAY_MAX_CONNS_PER_HOST=${GATEWAY_MAX_CONNS_PER_HOST:-8192} + + # ======================================================================= + # Redis Configuration + # ======================================================================= + # Using host network: point to host/external Redis by REDIS_HOST/REDIS_PORT + - REDIS_HOST=${REDIS_HOST:-127.0.0.1} + - REDIS_PORT=${REDIS_PORT:-6379} + - REDIS_PASSWORD=${REDIS_PASSWORD:-} + - REDIS_DB=${REDIS_DB:-0} + - REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024} + - REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10} + - REDIS_ENABLE_TLS=${REDIS_ENABLE_TLS:-false} + + # ======================================================================= + # Admin Account (auto-created on first run) + # ======================================================================= + - ADMIN_EMAIL=${ADMIN_EMAIL:-admin@sub2api.local} + - ADMIN_PASSWORD=${ADMIN_PASSWORD:-} + + # ======================================================================= + # JWT Configuration + # ======================================================================= + # Leave empty to auto-generate (recommended) + - JWT_SECRET=${JWT_SECRET:-} + - JWT_EXPIRE_HOUR=${JWT_EXPIRE_HOUR:-24} + + # ======================================================================= + # TOTP (2FA) Configuration + # ======================================================================= + # IMPORTANT: Set a fixed encryption key for TOTP secrets. If left empty, + # a random key will be generated on each startup, causing all existing + # TOTP configurations to become invalid (users won't be able to login + # with 2FA). + # Generate a secure key: openssl rand -hex 32 + - TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-} + + # ======================================================================= + # Timezone Configuration + # This affects ALL time operations in the application: + # - Database timestamps + # - Usage statistics "today" boundary + # - Subscription expiry times + # - Log timestamps + # Common values: Asia/Shanghai, America/New_York, Europe/London, UTC + # ======================================================================= + - TZ=${TZ:-Asia/Shanghai} + + # ======================================================================= + # Gemini OAuth Configuration (for Gemini accounts) + # ======================================================================= + - GEMINI_OAUTH_CLIENT_ID=${GEMINI_OAUTH_CLIENT_ID:-} + - GEMINI_OAUTH_CLIENT_SECRET=${GEMINI_OAUTH_CLIENT_SECRET:-} + - GEMINI_OAUTH_SCOPES=${GEMINI_OAUTH_SCOPES:-} + - GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-} + + # ======================================================================= + # Security Configuration (URL Allowlist) + # ======================================================================= + # Allow private IP addresses for CRS sync (for internal deployments) + - SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=${SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS:-true} + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8080/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 30s + + # =========================================================================== + # PostgreSQL Database + # =========================================================================== + postgres: + image: postgres:18-alpine + container_name: sub2api-postgres + restart: unless-stopped + network_mode: host + ulimits: + nofile: + soft: 800000 + hard: 800000 + volumes: + - postgres_data:/var/lib/postgresql/data + environment: + - POSTGRES_USER=${POSTGRES_USER:-sub2api} + - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} + - POSTGRES_DB=${POSTGRES_DB:-sub2api} + - TZ=${TZ:-Asia/Shanghai} + command: + - "postgres" + - "-c" + - "listen_addresses=127.0.0.1" + # 监听端口:与应用侧 DATABASE_PORT 保持一致。 + - "-c" + - "port=${DATABASE_PORT:-5432}" + # 连接数上限:需要结合应用侧 DATABASE_MAX_OPEN_CONNS 调整。 + # 注意:max_connections 过大可能导致内存占用与上下文切换开销显著上升。 + - "-c" + - "max_connections=${POSTGRES_MAX_CONNECTIONS:-1024}" + # 典型内存参数(建议结合机器内存调优;不确定就保持默认或小步调大)。 + - "-c" + - "shared_buffers=${POSTGRES_SHARED_BUFFERS:-1GB}" + - "-c" + - "effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-6GB}" + - "-c" + - "maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-128MB}" + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api} -p ${DATABASE_PORT:-5432}"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 10s + # Note: bound to localhost only; not exposed to external network by default. + + # =========================================================================== + # Redis Cache + # =========================================================================== + redis: + image: redis:8-alpine + container_name: sub2api-redis + restart: unless-stopped + network_mode: host + ulimits: + nofile: + soft: 100000 + hard: 100000 + volumes: + - redis_data:/data + command: > + redis-server + --bind 127.0.0.1 + --port ${REDIS_PORT:-6379} + --maxclients ${REDIS_MAXCLIENTS:-50000} + --save 60 1 + --appendonly yes + --appendfsync everysec + ${REDIS_PASSWORD:+--requirepass ${REDIS_PASSWORD}} + environment: + - TZ=${TZ:-Asia/Shanghai} + # REDISCLI_AUTH is used by redis-cli for authentication (safer than -a flag) + - REDISCLI_AUTH=${REDIS_PASSWORD:-} + healthcheck: + test: ["CMD-SHELL", "redis-cli -p ${REDIS_PORT:-6379} -a \"$REDISCLI_AUTH\" ping | grep -q PONG || redis-cli -p ${REDIS_PORT:-6379} ping | grep -q PONG"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 5s + +# ============================================================================= +# Volumes +# ============================================================================= +volumes: + sub2api_data: + driver: local + postgres_data: + driver: local + redis_data: + driver: local diff --git a/deploy/docker-compose-test.yml b/deploy/docker-compose-test.yml index bcda3141..81cd5222 100644 --- a/deploy/docker-compose-test.yml +++ b/deploy/docker-compose-test.yml @@ -57,6 +57,10 @@ services: - DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} - DATABASE_DBNAME=${POSTGRES_DB:-sub2api} - DATABASE_SSLMODE=disable + - DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50} + - DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10} + - DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30} + - DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5} # ======================================================================= # Redis Configuration @@ -65,6 +69,8 @@ services: - REDIS_PORT=6379 - REDIS_PASSWORD=${REDIS_PASSWORD:-} - REDIS_DB=${REDIS_DB:-0} + - REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024} + - REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10} # ======================================================================= # Admin Account (auto-created on first run) diff --git a/deploy/docker-compose.local.yml b/deploy/docker-compose.local.yml index 05ce129a..e778612c 100644 --- a/deploy/docker-compose.local.yml +++ b/deploy/docker-compose.local.yml @@ -62,6 +62,10 @@ services: - DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} - DATABASE_DBNAME=${POSTGRES_DB:-sub2api} - DATABASE_SSLMODE=disable + - DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50} + - DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10} + - DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30} + - DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5} # ======================================================================= # Redis Configuration @@ -70,6 +74,8 @@ services: - REDIS_PORT=6379 - REDIS_PASSWORD=${REDIS_PASSWORD:-} - REDIS_DB=${REDIS_DB:-0} + - REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024} + - REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10} - REDIS_ENABLE_TLS=${REDIS_ENABLE_TLS:-false} # ======================================================================= diff --git a/deploy/docker-compose.standalone.yml b/deploy/docker-compose.standalone.yml index 97903bc5..bb0041de 100644 --- a/deploy/docker-compose.standalone.yml +++ b/deploy/docker-compose.standalone.yml @@ -48,6 +48,10 @@ services: - DATABASE_PASSWORD=${DATABASE_PASSWORD:?DATABASE_PASSWORD is required} - DATABASE_DBNAME=${DATABASE_DBNAME:-sub2api} - DATABASE_SSLMODE=${DATABASE_SSLMODE:-disable} + - DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50} + - DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10} + - DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30} + - DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5} # ======================================================================= # Redis Configuration - Required @@ -56,6 +60,8 @@ services: - REDIS_PORT=${REDIS_PORT:-6379} - REDIS_PASSWORD=${REDIS_PASSWORD:-} - REDIS_DB=${REDIS_DB:-0} + - REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024} + - REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10} - REDIS_ENABLE_TLS=${REDIS_ENABLE_TLS:-false} # ======================================================================= diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 033731ac..4297ad0e 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -54,6 +54,10 @@ services: - DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} - DATABASE_DBNAME=${POSTGRES_DB:-sub2api} - DATABASE_SSLMODE=disable + - DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50} + - DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10} + - DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30} + - DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5} # ======================================================================= # Redis Configuration @@ -62,6 +66,8 @@ services: - REDIS_PORT=6379 - REDIS_PASSWORD=${REDIS_PASSWORD:-} - REDIS_DB=${REDIS_DB:-0} + - REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024} + - REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10} - REDIS_ENABLE_TLS=${REDIS_ENABLE_TLS:-false} # =======================================================================