From 7efa8b54c4a15c2d9140f2774b2eb0921f039368 Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Wed, 31 Dec 2025 08:50:12 +0800
Subject: [PATCH] perf(backend): complete performance optimization and connection pool configuration
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add DB/Redis connection pool configuration and validation, with unit tests
Gateway request body size limit and 413 handling
Pool HTTP/req clients and adjust upstream connection pool defaults
Switch concurrency slots to ZSET + Lua with exponential backoff
Move usage statistics to SQL aggregation and add an index migration
Move billing cache writes to a worker pool, with added tests/benchmarks

Tests: run go test ./... under backend/
---
 backend/cmd/server/wire.go | 5 +
 backend/cmd/server/wire_gen.go | 21 ++-
 backend/internal/config/config.go | 106 +++++++++++
 backend/internal/handler/gateway_handler.go | 59 +++---
 backend/internal/handler/gateway_helper.go | 59 +++++-
 .../internal/handler/gemini_v1beta_handler.go | 7 +-
 .../handler/openai_gateway_handler.go | 4 +
 .../internal/handler/request_body_limit.go | 27 +++
 .../handler/request_body_limit_test.go | 45 +++++
 backend/internal/infrastructure/db_pool.go | 32 ++++
 .../internal/infrastructure/db_pool_test.go | 50 +++++
 backend/internal/infrastructure/ent.go | 1 +
 backend/internal/infrastructure/redis.go | 33 +++-
 backend/internal/infrastructure/redis_test.go | 35 ++++
 backend/internal/pkg/httpclient/pool.go | 152 +++++++++++++++
 .../repository/claude_oauth_service.go | 14 +-
 .../repository/claude_usage_service.go | 22 +--
 .../internal/repository/concurrency_cache.go | 156 +++++++++-------
 .../concurrency_cache_benchmark_test.go | 135 ++++++++++++++
 .../concurrency_cache_integration_test.go | 18 +-
 .../repository/gemini_oauth_client.go | 9 +-
 .../repository/geminicli_codeassist_client.go | 9 +-
 .../repository/github_release_service.go | 20 +-
 backend/internal/repository/http_upstream.go | 95 +++++++---
 .../http_upstream_benchmark_test.go | 46 +++++
 .../internal/repository/http_upstream_test.go | 4 +-
 .../repository/openai_oauth_service.go | 12 +-
 .../internal/repository/pricing_service.go | 11 +-
 .../repository/proxy_probe_service.go | 48 +----
 .../repository/proxy_probe_service_test.go | 18 +-
 .../internal/repository/req_client_pool.go | 59 ++++++
 .../internal/repository/turnstile_service.go | 13 +-
 backend/internal/repository/usage_log_repo.go | 155 ++++++++++++++++
 backend/internal/repository/wire.go | 11 +-
 backend/internal/server/api_contract_test.go | 12 ++
 .../server/middleware/request_body_limit.go | 15 ++
 backend/internal/server/routes/gateway.go | 8 +-
 .../internal/service/account_usage_service.go | 3 +
 .../internal/service/billing_cache_service.go | 151 +++++++++++++--
 .../service/billing_cache_service_test.go | 75 ++++++++
 .../internal/service/concurrency_service.go | 14 +-
 backend/internal/service/crs_sync_service.go | 9 +-
 backend/internal/service/gateway_request.go | 70 +++++++
 .../internal/service/gateway_request_test.go | 38 ++++
 backend/internal/service/gateway_service.go | 174 ++++++++----------
 .../service/gateway_service_benchmark_test.go | 50 +++++
 .../service/gemini_messages_compat_service.go | 6 +-
 .../internal/service/gemini_oauth_service.go | 13 +-
 .../service/openai_gateway_service.go | 12 +-
 backend/internal/service/pricing_service.go | 15 +-
 backend/internal/service/usage_service.go | 78 +++-----
 .../010_add_usage_logs_aggregated_indexes.sql | 4 +
 deploy/config.example.yaml | 16 ++
 53 files changed, 1805 insertions(+), 449 deletions(-)
 create mode 100644 backend/internal/handler/request_body_limit.go
 create mode 100644 backend/internal/handler/request_body_limit_test.go
 create
mode 100644 backend/internal/infrastructure/db_pool.go create mode 100644 backend/internal/infrastructure/db_pool_test.go create mode 100644 backend/internal/infrastructure/redis_test.go create mode 100644 backend/internal/pkg/httpclient/pool.go create mode 100644 backend/internal/repository/concurrency_cache_benchmark_test.go create mode 100644 backend/internal/repository/http_upstream_benchmark_test.go create mode 100644 backend/internal/repository/req_client_pool.go create mode 100644 backend/internal/server/middleware/request_body_limit.go create mode 100644 backend/internal/service/billing_cache_service_test.go create mode 100644 backend/internal/service/gateway_request.go create mode 100644 backend/internal/service/gateway_request_test.go create mode 100644 backend/internal/service/gateway_service_benchmark_test.go create mode 100644 backend/migrations/010_add_usage_logs_aggregated_indexes.sql diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index a746955b..fffcd5f9 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -67,6 +67,7 @@ func provideCleanup( tokenRefresh *service.TokenRefreshService, pricing *service.PricingService, emailQueue *service.EmailQueueService, + billingCache *service.BillingCacheService, oauth *service.OAuthService, openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, @@ -94,6 +95,10 @@ func provideCleanup( emailQueue.Stop() return nil }}, + {"BillingCacheService", func() error { + billingCache.Stop() + return nil + }}, {"OAuthService", func() error { oauth.Stop() return nil diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index f37d696b..664e7aca 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -39,11 +39,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { if err != nil { return nil, err } - sqlDB, err := infrastructure.ProvideSQLDB(client) + db, err := infrastructure.ProvideSQLDB(client) if err != nil { return nil, err } - userRepository := repository.NewUserRepository(client, sqlDB) + userRepository := repository.NewUserRepository(client, db) settingRepository := repository.NewSettingRepository(client) settingService := service.NewSettingService(settingRepository, configConfig) redisClient := infrastructure.ProvideRedis(configConfig) @@ -57,12 +57,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { authHandler := handler.NewAuthHandler(configConfig, authService, userService) userHandler := handler.NewUserHandler(userService) apiKeyRepository := repository.NewApiKeyRepository(client) - groupRepository := repository.NewGroupRepository(client, sqlDB) + groupRepository := repository.NewGroupRepository(client, db) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) apiKeyCache := repository.NewApiKeyCache(redisClient) apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) - usageLogRepository := repository.NewUsageLogRepository(client, sqlDB) + usageLogRepository := repository.NewUsageLogRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) redeemCodeRepository := repository.NewRedeemCodeRepository(client) @@ -75,8 +75,8 @@ func initializeApplication(buildInfo handler.BuildInfo) 
(*Application, error) { subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) dashboardService := service.NewDashboardService(usageLogRepository) dashboardHandler := admin.NewDashboardHandler(dashboardService) - accountRepository := repository.NewAccountRepository(client, sqlDB) - proxyRepository := repository.NewProxyRepository(client, sqlDB) + accountRepository := repository.NewAccountRepository(client, db) + proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber() adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber) adminUserHandler := admin.NewUserHandler(adminService) @@ -95,7 +95,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) httpUpstream := repository.NewHTTPUpstream(configConfig) accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, httpUpstream) - concurrencyCache := repository.NewConcurrencyCache(redisClient) + concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.NewConcurrencyService(concurrencyCache) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) @@ -142,7 +142,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { httpServer := server.ProvideHTTPServer(configConfig, engine) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig) antigravityQuotaRefresher := service.ProvideAntigravityQuotaRefresher(accountRepository, proxyRepository, antigravityOAuthService, configConfig) - v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, antigravityQuotaRefresher) + v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, antigravityQuotaRefresher) application := &Application{ Server: httpServer, Cleanup: v, @@ -170,6 +170,7 @@ func provideCleanup( tokenRefresh *service.TokenRefreshService, pricing *service.PricingService, emailQueue *service.EmailQueueService, + billingCache *service.BillingCacheService, oauth *service.OAuthService, openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, @@ -196,6 +197,10 @@ func provideCleanup( emailQueue.Stop() return nil }}, + {"BillingCacheService", func() error { + billingCache.Stop() + return nil + }}, {"OAuthService", func() error { oauth.Stop() return nil diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 86283fde..dfc9a844 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -79,12 +79,29 @@ type GatewayConfig struct { // 等待上游响应头的超时时间(秒),0表示无超时 // 
注意:这不影响流式数据传输,只控制等待响应头的时间 ResponseHeaderTimeout int `mapstructure:"response_header_timeout"` + // 请求体最大字节数,用于网关请求体大小限制 + MaxBodySize int64 `mapstructure:"max_body_size"` + + // HTTP 上游连接池配置(性能优化:支持高并发场景调优) + // MaxIdleConns: 所有主机的最大空闲连接总数 + MaxIdleConns int `mapstructure:"max_idle_conns"` + // MaxIdleConnsPerHost: 每个主机的最大空闲连接数(关键参数,影响连接复用率) + MaxIdleConnsPerHost int `mapstructure:"max_idle_conns_per_host"` + // MaxConnsPerHost: 每个主机的最大连接数(包括活跃+空闲),0表示无限制 + MaxConnsPerHost int `mapstructure:"max_conns_per_host"` + // IdleConnTimeoutSeconds: 空闲连接超时时间(秒) + IdleConnTimeoutSeconds int `mapstructure:"idle_conn_timeout_seconds"` + // ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟) + // 应大于最长 LLM 请求时间,防止请求完成前槽位过期 + ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` } func (s *ServerConfig) Address() string { return fmt.Sprintf("%s:%d", s.Host, s.Port) } +// DatabaseConfig 数据库连接配置 +// 性能优化:新增连接池参数,避免频繁创建/销毁连接 type DatabaseConfig struct { Host string `mapstructure:"host"` Port int `mapstructure:"port"` @@ -92,6 +109,15 @@ type DatabaseConfig struct { Password string `mapstructure:"password"` DBName string `mapstructure:"dbname"` SSLMode string `mapstructure:"sslmode"` + // 连接池配置(性能优化:可配置化连接池参数) + // MaxOpenConns: 最大打开连接数,控制数据库连接上限,防止资源耗尽 + MaxOpenConns int `mapstructure:"max_open_conns"` + // MaxIdleConns: 最大空闲连接数,保持热连接减少建连延迟 + MaxIdleConns int `mapstructure:"max_idle_conns"` + // ConnMaxLifetimeMinutes: 连接最大存活时间,防止长连接导致的资源泄漏 + ConnMaxLifetimeMinutes int `mapstructure:"conn_max_lifetime_minutes"` + // ConnMaxIdleTimeMinutes: 空闲连接最大存活时间,及时释放不活跃连接 + ConnMaxIdleTimeMinutes int `mapstructure:"conn_max_idle_time_minutes"` } func (d *DatabaseConfig) DSN() string { @@ -112,11 +138,24 @@ func (d *DatabaseConfig) DSNWithTimezone(tz string) string { ) } +// RedisConfig Redis 连接配置 +// 性能优化:新增连接池和超时参数,提升高并发场景下的吞吐量 type RedisConfig struct { Host string `mapstructure:"host"` Port int `mapstructure:"port"` Password string `mapstructure:"password"` DB int `mapstructure:"db"` + // 连接池与超时配置(性能优化:可配置化连接池参数) + // DialTimeoutSeconds: 建立连接超时,防止慢连接阻塞 + DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"` + // ReadTimeoutSeconds: 读取超时,避免慢查询阻塞连接池 + ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"` + // WriteTimeoutSeconds: 写入超时,避免慢写入阻塞连接池 + WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"` + // PoolSize: 连接池大小,控制最大并发连接数 + PoolSize int `mapstructure:"pool_size"` + // MinIdleConns: 最小空闲连接数,保持热连接减少冷启动延迟 + MinIdleConns int `mapstructure:"min_idle_conns"` } func (r *RedisConfig) Address() string { @@ -203,12 +242,21 @@ func setDefaults() { viper.SetDefault("database.password", "postgres") viper.SetDefault("database.dbname", "sub2api") viper.SetDefault("database.sslmode", "disable") + viper.SetDefault("database.max_open_conns", 50) + viper.SetDefault("database.max_idle_conns", 10) + viper.SetDefault("database.conn_max_lifetime_minutes", 30) + viper.SetDefault("database.conn_max_idle_time_minutes", 5) // Redis viper.SetDefault("redis.host", "localhost") viper.SetDefault("redis.port", 6379) viper.SetDefault("redis.password", "") viper.SetDefault("redis.db", 0) + viper.SetDefault("redis.dial_timeout_seconds", 5) + viper.SetDefault("redis.read_timeout_seconds", 3) + viper.SetDefault("redis.write_timeout_seconds", 3) + viper.SetDefault("redis.pool_size", 128) + viper.SetDefault("redis.min_idle_conns", 10) // JWT viper.SetDefault("jwt.secret", "change-me-in-production") @@ -240,6 +288,13 @@ func setDefaults() { // Gateway viper.SetDefault("gateway.response_header_timeout", 300) // 
300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久 + viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) + // HTTP 上游连接池配置(针对 5000+ 并发用户优化) + viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认) + viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认) + viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认) + viper.SetDefault("gateway.idle_conn_timeout_seconds", 300) // 空闲连接超时(秒) + viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求) // TokenRefresh viper.SetDefault("token_refresh.enabled", true) @@ -263,6 +318,57 @@ func (c *Config) Validate() error { if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" { return fmt.Errorf("jwt.secret must be changed in production") } + if c.Database.MaxOpenConns <= 0 { + return fmt.Errorf("database.max_open_conns must be positive") + } + if c.Database.MaxIdleConns < 0 { + return fmt.Errorf("database.max_idle_conns must be non-negative") + } + if c.Database.MaxIdleConns > c.Database.MaxOpenConns { + return fmt.Errorf("database.max_idle_conns cannot exceed database.max_open_conns") + } + if c.Database.ConnMaxLifetimeMinutes < 0 { + return fmt.Errorf("database.conn_max_lifetime_minutes must be non-negative") + } + if c.Database.ConnMaxIdleTimeMinutes < 0 { + return fmt.Errorf("database.conn_max_idle_time_minutes must be non-negative") + } + if c.Redis.DialTimeoutSeconds <= 0 { + return fmt.Errorf("redis.dial_timeout_seconds must be positive") + } + if c.Redis.ReadTimeoutSeconds <= 0 { + return fmt.Errorf("redis.read_timeout_seconds must be positive") + } + if c.Redis.WriteTimeoutSeconds <= 0 { + return fmt.Errorf("redis.write_timeout_seconds must be positive") + } + if c.Redis.PoolSize <= 0 { + return fmt.Errorf("redis.pool_size must be positive") + } + if c.Redis.MinIdleConns < 0 { + return fmt.Errorf("redis.min_idle_conns must be non-negative") + } + if c.Redis.MinIdleConns > c.Redis.PoolSize { + return fmt.Errorf("redis.min_idle_conns cannot exceed redis.pool_size") + } + if c.Gateway.MaxBodySize <= 0 { + return fmt.Errorf("gateway.max_body_size must be positive") + } + if c.Gateway.MaxIdleConns <= 0 { + return fmt.Errorf("gateway.max_idle_conns must be positive") + } + if c.Gateway.MaxIdleConnsPerHost <= 0 { + return fmt.Errorf("gateway.max_idle_conns_per_host must be positive") + } + if c.Gateway.MaxConnsPerHost < 0 { + return fmt.Errorf("gateway.max_conns_per_host must be non-negative") + } + if c.Gateway.IdleConnTimeoutSeconds <= 0 { + return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive") + } + if c.Gateway.ConcurrencySlotTTLMinutes <= 0 { + return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive") + } return nil } diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index bf179ea1..fc92b2d8 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -67,6 +67,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 读取请求体 body, err := io.ReadAll(c.Request.Body) if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") return } @@ -76,15 +80,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } - // 解析请求获取模型名和stream - var req struct { - Model string 
`json:"model"` - Stream bool `json:"stream"` - } - if err := json.Unmarshal(body, &req); err != nil { + parsedReq, err := service.ParseGatewayRequest(body) + if err != nil { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } + reqModel := parsedReq.Model + reqStream := parsedReq.Stream // Track if we've started streaming (for error handling) streamStarted := false @@ -106,7 +108,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) // 1. 首先获取用户并发槽位 - userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, req.Stream, &streamStarted) + userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) if err != nil { log.Printf("User concurrency acquire failed: %v", err) h.handleConcurrencyError(c, err, "user", streamStarted) @@ -124,7 +126,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 计算粘性会话hash - sessionHash := h.gatewayService.GenerateSessionHash(body) + sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) // 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台 platform := "" @@ -141,7 +143,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { lastFailoverStatus := 0 for { - account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs) + account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) if err != nil { if len(failedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -153,16 +155,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 检查预热请求拦截(在账号选择后、转发前检查) if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { - if req.Stream { - sendMockWarmupStream(c, req.Model) + if reqStream { + sendMockWarmupStream(c, reqModel) } else { - sendMockWarmupResponse(c, req.Model) + sendMockWarmupResponse(c, reqModel) } return } // 3. 
获取账号并发槽位 - accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted) + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) if err != nil { log.Printf("Account concurrency acquire failed: %v", err) h.handleConcurrencyError(c, err, "account", streamStarted) @@ -172,7 +174,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 转发请求 - 根据账号平台分流 var result *service.ForwardResult if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, req.Model, "generateContent", req.Stream, body) + result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body) } else { result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) } @@ -223,7 +225,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { for { // 选择支持该模型的账号 - account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs) + account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) if err != nil { if len(failedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -235,16 +237,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 检查预热请求拦截(在账号选择后、转发前检查) if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { - if req.Stream { - sendMockWarmupStream(c, req.Model) + if reqStream { + sendMockWarmupStream(c, reqModel) } else { - sendMockWarmupResponse(c, req.Model) + sendMockWarmupResponse(c, reqModel) } return } // 3. 
获取账号并发槽位 - accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted) + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) if err != nil { log.Printf("Account concurrency acquire failed: %v", err) h.handleConcurrencyError(c, err, "account", streamStarted) @@ -256,7 +258,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if account.Platform == service.PlatformAntigravity { result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body) } else { - result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body) + result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq) } if accountReleaseFunc != nil { accountReleaseFunc() @@ -496,6 +498,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { // 读取请求体 body, err := io.ReadAll(c.Request.Body) if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") return } @@ -505,11 +511,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { return } - // 解析请求获取模型名 - var req struct { - Model string `json:"model"` - } - if err := json.Unmarshal(body, &req); err != nil { + parsedReq, err := service.ParseGatewayRequest(body) + if err != nil { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } @@ -525,17 +528,17 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { } // 计算粘性会话 hash - sessionHash := h.gatewayService.GenerateSessionHash(body) + sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) // 选择支持该模型的账号 - account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) + account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model) if err != nil { h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) return } // 转发请求(不记录使用量) - if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil { + if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil { log.Printf("Forward count_tokens request failed: %v", err) // 错误响应已在 ForwardCountTokens 中处理 return diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 5cbe462d..4c7bd0f0 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -3,6 +3,7 @@ package handler import ( "context" "fmt" + "math/rand" "net/http" "time" @@ -11,11 +12,28 @@ import ( "github.com/gin-gonic/gin" ) +// 并发槽位等待相关常量 +// +// 性能优化说明: +// 原实现使用固定间隔(100ms)轮询并发槽位,存在以下问题: +// 1. 高并发时频繁轮询增加 Redis 压力 +// 2. 固定间隔可能导致多个请求同时重试(惊群效应) +// +// 新实现使用指数退避 + 抖动算法: +// 1. 初始退避 100ms,每次乘以 1.5,最大 2s +// 2. 添加 ±20% 的随机抖动,分散重试时间点 +// 3. 
减少 Redis 压力,避免惊群效应
 const (
-	// maxConcurrencyWait is the maximum time to wait for a concurrency slot
+	// maxConcurrencyWait 等待并发槽位的最大时间
 	maxConcurrencyWait = 30 * time.Second
-	// pingInterval is the interval for sending ping events during slot wait
+	// pingInterval 流式响应等待时发送 ping 的间隔
 	pingInterval = 15 * time.Second
+	// initialBackoff 初始退避时间
+	initialBackoff = 100 * time.Millisecond
+	// backoffMultiplier 退避时间乘数(指数退避)
+	backoffMultiplier = 1.5
+	// maxBackoff 最大退避时间
+	maxBackoff = 2 * time.Second
 )
 
 // SSEPingFormat defines the format of SSE ping events for different platforms
@@ -131,8 +149,10 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
 		pingCh = pingTicker.C
 	}
 
-	pollTicker := time.NewTicker(100 * time.Millisecond)
-	defer pollTicker.Stop()
+	backoff := initialBackoff
+	timer := time.NewTimer(backoff)
+	defer timer.Stop()
+	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
 
 	for {
 		select {
@@ -156,7 +176,7 @@
 			}
 			flusher.Flush()
 
-		case <-pollTicker.C:
+		case <-timer.C:
 			// Try to acquire slot
 			var result *service.AcquireResult
 			var err error
@@ -174,6 +194,35 @@
 			if result.Acquired {
 				return result.ReleaseFunc, nil
 			}
+			backoff = nextBackoff(backoff, rng)
+			timer.Reset(backoff)
 		}
 	}
 }
+
+// nextBackoff 计算下一次退避时间
+// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
+// current: 当前退避时间
+// rng: 随机数生成器(可为 nil,此时不添加抖动)
+// 返回值:下一次退避时间(100ms ~ 2s 之间)
+func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration {
+	// 指数退避:当前时间 * 1.5
+	next := time.Duration(float64(current) * backoffMultiplier)
+	if next > maxBackoff {
+		next = maxBackoff
+	}
+	if rng == nil {
+		return next
+	}
+	// 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2)
+	// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
+	jitter := 0.8 + rng.Float64()*0.4
+	jittered := time.Duration(float64(next) * jitter)
+	if jittered < initialBackoff {
+		return initialBackoff
+	}
+	if jittered > maxBackoff {
+		return maxBackoff
+	}
+	return jittered
+}
diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go
index ea1bdf5a..4e99e00d 100644
--- a/backend/internal/handler/gemini_v1beta_handler.go
+++ b/backend/internal/handler/gemini_v1beta_handler.go
@@ -148,6 +148,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
 
 	body, err := io.ReadAll(c.Request.Body)
 	if err != nil {
+		if maxErr, ok := extractMaxBytesError(err); ok {
+			googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
+			return
+		}
 		googleError(c, http.StatusBadRequest, "Failed to read request body")
 		return
 	}
@@ -191,7 +195,8 @@
 	}
 
 	// 3) select account (sticky session based on request body)
-	sessionHash := h.gatewayService.GenerateSessionHash(body)
+	parsedReq, _ := service.ParseGatewayRequest(body)
+	sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
 	const maxAccountSwitches = 3
 	switchCount := 0
 	failedAccountIDs := make(map[int64]struct{})
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index 2dee9ccd..7fcb329d 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -56,6 +56,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
 	// Read request body
 	body, err := io.ReadAll(c.Request.Body)
 	if err != nil {
+		if maxErr, ok := extractMaxBytesError(err); ok {
+			h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
+			return
+		}
 		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
 		return
 	}
diff --git a/backend/internal/handler/request_body_limit.go b/backend/internal/handler/request_body_limit.go
new file mode 100644
index 00000000..d746673b
--- /dev/null
+++ b/backend/internal/handler/request_body_limit.go
@@ -0,0 +1,27 @@
+package handler
+
+import (
+	"errors"
+	"fmt"
+	"net/http"
+)
+
+func extractMaxBytesError(err error) (*http.MaxBytesError, bool) {
+	var maxErr *http.MaxBytesError
+	if errors.As(err, &maxErr) {
+		return maxErr, true
+	}
+	return nil, false
+}
+
+func formatBodyLimit(limit int64) string {
+	const mb = 1024 * 1024
+	if limit >= mb {
+		return fmt.Sprintf("%dMB", limit/mb)
+	}
+	return fmt.Sprintf("%dB", limit)
+}
+
+func buildBodyTooLargeMessage(limit int64) string {
+	return fmt.Sprintf("Request body too large, limit is %s", formatBodyLimit(limit))
+}
diff --git a/backend/internal/handler/request_body_limit_test.go b/backend/internal/handler/request_body_limit_test.go
new file mode 100644
index 00000000..bd9b8177
--- /dev/null
+++ b/backend/internal/handler/request_body_limit_test.go
@@ -0,0 +1,45 @@
+package handler
+
+import (
+	"bytes"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/Wei-Shaw/sub2api/internal/server/middleware"
+	"github.com/gin-gonic/gin"
+	"github.com/stretchr/testify/require"
+)
+
+func TestRequestBodyLimitTooLarge(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	limit := int64(16)
+	router := gin.New()
+	router.Use(middleware.RequestBodyLimit(limit))
+	router.POST("/test", func(c *gin.Context) {
+		_, err := io.ReadAll(c.Request.Body)
+		if err != nil {
+			if maxErr, ok := extractMaxBytesError(err); ok {
+				c.JSON(http.StatusRequestEntityTooLarge, gin.H{
+					"error": buildBodyTooLargeMessage(maxErr.Limit),
+				})
+				return
+			}
+			c.JSON(http.StatusBadRequest, gin.H{
+				"error": "read_failed",
+			})
+			return
+		}
+		c.JSON(http.StatusOK, gin.H{"ok": true})
+	})
+
+	payload := bytes.Repeat([]byte("a"), int(limit+1))
+	req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(payload))
+	recorder := httptest.NewRecorder()
+	router.ServeHTTP(recorder, req)
+
+	require.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code)
+	require.Contains(t, recorder.Body.String(), buildBodyTooLargeMessage(limit))
+}
diff --git a/backend/internal/infrastructure/db_pool.go b/backend/internal/infrastructure/db_pool.go
new file mode 100644
index 00000000..612155bf
--- /dev/null
+++ b/backend/internal/infrastructure/db_pool.go
@@ -0,0 +1,32 @@
+package infrastructure
+
+import (
+	"database/sql"
+	"time"
+
+	"github.com/Wei-Shaw/sub2api/internal/config"
+)
+
+type dbPoolSettings struct {
+	MaxOpenConns    int
+	MaxIdleConns    int
+	ConnMaxLifetime time.Duration
+	ConnMaxIdleTime time.Duration
+}
+
+func buildDBPoolSettings(cfg *config.Config) dbPoolSettings {
+	return dbPoolSettings{
+		MaxOpenConns:    cfg.Database.MaxOpenConns,
+		MaxIdleConns:    cfg.Database.MaxIdleConns,
+		ConnMaxLifetime: time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute,
+		ConnMaxIdleTime: time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute,
+	}
+}
+
+func applyDBPoolSettings(db *sql.DB, cfg *config.Config) {
+	settings := buildDBPoolSettings(cfg)
+	db.SetMaxOpenConns(settings.MaxOpenConns)
+	db.SetMaxIdleConns(settings.MaxIdleConns)
+	db.SetConnMaxLifetime(settings.ConnMaxLifetime)
+	db.SetConnMaxIdleTime(settings.ConnMaxIdleTime)
+}
diff --git a/backend/internal/infrastructure/db_pool_test.go b/backend/internal/infrastructure/db_pool_test.go
new file mode 100644
index 00000000..0f0e9716
--- /dev/null
+++ b/backend/internal/infrastructure/db_pool_test.go
@@ -0,0 +1,50 @@
+package infrastructure
+
+import (
+	"database/sql"
+	"testing"
+	"time"
+
+	"github.com/Wei-Shaw/sub2api/internal/config"
+	"github.com/stretchr/testify/require"
+
+	_ "github.com/lib/pq"
+)
+
+func TestBuildDBPoolSettings(t *testing.T) {
+	cfg := &config.Config{
+		Database: config.DatabaseConfig{
+			MaxOpenConns:           50,
+			MaxIdleConns:           10,
+			ConnMaxLifetimeMinutes: 30,
+			ConnMaxIdleTimeMinutes: 5,
+		},
+	}
+
+	settings := buildDBPoolSettings(cfg)
+	require.Equal(t, 50, settings.MaxOpenConns)
+	require.Equal(t, 10, settings.MaxIdleConns)
+	require.Equal(t, 30*time.Minute, settings.ConnMaxLifetime)
+	require.Equal(t, 5*time.Minute, settings.ConnMaxIdleTime)
+}
+
+func TestApplyDBPoolSettings(t *testing.T) {
+	cfg := &config.Config{
+		Database: config.DatabaseConfig{
+			MaxOpenConns:           40,
+			MaxIdleConns:           8,
+			ConnMaxLifetimeMinutes: 15,
+			ConnMaxIdleTimeMinutes: 3,
+		},
+	}
+
+	db, err := sql.Open("postgres", "host=127.0.0.1 port=5432 user=postgres sslmode=disable")
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		_ = db.Close()
+	})
+
+	applyDBPoolSettings(db, cfg)
+	stats := db.Stats()
+	require.Equal(t, 40, stats.MaxOpenConnections)
+}
diff --git a/backend/internal/infrastructure/ent.go b/backend/internal/infrastructure/ent.go
index 13184a83..b1ab9a55 100644
--- a/backend/internal/infrastructure/ent.go
+++ b/backend/internal/infrastructure/ent.go
@@ -51,6 +51,7 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
 	if err != nil {
 		return nil, nil, err
 	}
+	applyDBPoolSettings(drv.DB(), cfg)
 
 	// 确保数据库 schema 已准备就绪。
 	// SQL 迁移文件是 schema 的权威来源(source of truth)。
diff --git a/backend/internal/infrastructure/redis.go b/backend/internal/infrastructure/redis.go
index 970a2595..5bb92d19 100644
--- a/backend/internal/infrastructure/redis.go
+++ b/backend/internal/infrastructure/redis.go
@@ -1,16 +1,39 @@
 package infrastructure
 
 import (
+	"time"
+
 	"github.com/Wei-Shaw/sub2api/internal/config"
 	"github.com/redis/go-redis/v9"
 )
 
 // InitRedis 初始化 Redis 客户端
+//
+// 性能优化说明:
+// 原实现使用 go-redis 默认配置,未设置连接池和超时参数:
+// 1. 默认连接池大小可能不足以支撑高并发
+// 2. 无超时控制可能导致慢操作阻塞
+//
+// 新实现支持可配置的连接池和超时参数:
+// 1. PoolSize: 控制最大并发连接数(默认 128)
+// 2. MinIdleConns: 保持最小空闲连接,减少冷启动延迟(默认 10)
+// 3.
DialTimeout/ReadTimeout/WriteTimeout: 精确控制各阶段超时 func InitRedis(cfg *config.Config) *redis.Client { - return redis.NewClient(&redis.Options{ - Addr: cfg.Redis.Address(), - Password: cfg.Redis.Password, - DB: cfg.Redis.DB, - }) + return redis.NewClient(buildRedisOptions(cfg)) +} + +// buildRedisOptions 构建 Redis 连接选项 +// 从配置文件读取连接池和超时参数,支持生产环境调优 +func buildRedisOptions(cfg *config.Config) *redis.Options { + return &redis.Options{ + Addr: cfg.Redis.Address(), + Password: cfg.Redis.Password, + DB: cfg.Redis.DB, + DialTimeout: time.Duration(cfg.Redis.DialTimeoutSeconds) * time.Second, // 建连超时 + ReadTimeout: time.Duration(cfg.Redis.ReadTimeoutSeconds) * time.Second, // 读取超时 + WriteTimeout: time.Duration(cfg.Redis.WriteTimeoutSeconds) * time.Second, // 写入超时 + PoolSize: cfg.Redis.PoolSize, // 连接池大小 + MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接 + } } diff --git a/backend/internal/infrastructure/redis_test.go b/backend/internal/infrastructure/redis_test.go new file mode 100644 index 00000000..5e38e826 --- /dev/null +++ b/backend/internal/infrastructure/redis_test.go @@ -0,0 +1,35 @@ +package infrastructure + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestBuildRedisOptions(t *testing.T) { + cfg := &config.Config{ + Redis: config.RedisConfig{ + Host: "localhost", + Port: 6379, + Password: "secret", + DB: 2, + DialTimeoutSeconds: 5, + ReadTimeoutSeconds: 3, + WriteTimeoutSeconds: 4, + PoolSize: 100, + MinIdleConns: 10, + }, + } + + opts := buildRedisOptions(cfg) + require.Equal(t, "localhost:6379", opts.Addr) + require.Equal(t, "secret", opts.Password) + require.Equal(t, 2, opts.DB) + require.Equal(t, 5*time.Second, opts.DialTimeout) + require.Equal(t, 3*time.Second, opts.ReadTimeout) + require.Equal(t, 4*time.Second, opts.WriteTimeout) + require.Equal(t, 100, opts.PoolSize) + require.Equal(t, 10, opts.MinIdleConns) +} diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go new file mode 100644 index 00000000..f68d50a5 --- /dev/null +++ b/backend/internal/pkg/httpclient/pool.go @@ -0,0 +1,152 @@ +// Package httpclient 提供共享 HTTP 客户端池 +// +// 性能优化说明: +// 原实现在多个服务中重复创建 http.Client: +// 1. proxy_probe_service.go: 每次探测创建新客户端 +// 2. pricing_service.go: 每次请求创建新客户端 +// 3. turnstile_service.go: 每次验证创建新客户端 +// 4. github_release_service.go: 每次请求创建新客户端 +// 5. claude_usage_service.go: 每次请求创建新客户端 +// +// 新实现使用统一的客户端池: +// 1. 相同配置复用同一 http.Client 实例 +// 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销 +// 3. 支持 HTTP/HTTPS/SOCKS5 代理 +// 4. 
支持严格代理模式(代理失败则返回错误) +package httpclient + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "golang.org/x/net/proxy" +) + +// Transport 连接池默认配置 +const ( + defaultMaxIdleConns = 100 // 最大空闲连接数 + defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数 + defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间 +) + +// Options 定义共享 HTTP 客户端的构建参数 +type Options struct { + ProxyURL string // 代理 URL(支持 http/https/socks5) + Timeout time.Duration // 请求总超时时间 + ResponseHeaderTimeout time.Duration // 等待响应头超时时间 + InsecureSkipVerify bool // 是否跳过 TLS 证书验证 + ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退 + + // 可选的连接池参数(不设置则使用默认值) + MaxIdleConns int // 最大空闲连接总数(默认 100) + MaxIdleConnsPerHost int // 每主机最大空闲连接(默认 10) + MaxConnsPerHost int // 每主机最大连接数(默认 0 无限制) +} + +// sharedClients 存储按配置参数缓存的 http.Client 实例 +var sharedClients sync.Map + +// GetClient 返回共享的 HTTP 客户端实例 +// 性能优化:相同配置复用同一客户端,避免重复创建 Transport +func GetClient(opts Options) (*http.Client, error) { + key := buildClientKey(opts) + if cached, ok := sharedClients.Load(key); ok { + return cached.(*http.Client), nil + } + + client, err := buildClient(opts) + if err != nil { + if opts.ProxyStrict { + return nil, err + } + fallback := opts + fallback.ProxyURL = "" + client, _ = buildClient(fallback) + } + + actual, _ := sharedClients.LoadOrStore(key, client) + return actual.(*http.Client), nil +} + +func buildClient(opts Options) (*http.Client, error) { + transport, err := buildTransport(opts) + if err != nil { + return nil, err + } + + return &http.Client{ + Transport: transport, + Timeout: opts.Timeout, + }, nil +} + +func buildTransport(opts Options) (*http.Transport, error) { + // 使用自定义值或默认值 + maxIdleConns := opts.MaxIdleConns + if maxIdleConns <= 0 { + maxIdleConns = defaultMaxIdleConns + } + maxIdleConnsPerHost := opts.MaxIdleConnsPerHost + if maxIdleConnsPerHost <= 0 { + maxIdleConnsPerHost = defaultMaxIdleConnsPerHost + } + + transport := &http.Transport{ + MaxIdleConns: maxIdleConns, + MaxIdleConnsPerHost: maxIdleConnsPerHost, + MaxConnsPerHost: opts.MaxConnsPerHost, // 0 表示无限制 + IdleConnTimeout: defaultIdleConnTimeout, + ResponseHeaderTimeout: opts.ResponseHeaderTimeout, + } + + if opts.InsecureSkipVerify { + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + + proxyURL := strings.TrimSpace(opts.ProxyURL) + if proxyURL == "" { + return transport, nil + } + + parsed, err := url.Parse(proxyURL) + if err != nil { + return nil, err + } + + switch strings.ToLower(parsed.Scheme) { + case "http", "https": + transport.Proxy = http.ProxyURL(parsed) + case "socks5", "socks5h": + dialer, err := proxy.FromURL(parsed, proxy.Direct) + if err != nil { + return nil, err + } + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + } + default: + return nil, fmt.Errorf("unsupported proxy protocol: %s", parsed.Scheme) + } + + return transport, nil +} + +func buildClientKey(opts Options) string { + return fmt.Sprintf("%s|%s|%s|%t|%t|%d|%d|%d", + strings.TrimSpace(opts.ProxyURL), + opts.Timeout.String(), + opts.ResponseHeaderTimeout.String(), + opts.InsecureSkipVerify, + opts.ProxyStrict, + opts.MaxIdleConns, + opts.MaxIdleConnsPerHost, + opts.MaxConnsPerHost, + ) +} diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index f7ff2341..b03b5415 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go 
@@ -233,15 +233,11 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro } func createReqClient(proxyURL string) *req.Client { - client := req.C(). - ImpersonateChrome(). - SetTimeout(60 * time.Second) - - if proxyURL != "" { - client.SetProxyURL(proxyURL) - } - - return client + return getSharedReqClient(reqClientOptions{ + ProxyURL: proxyURL, + Timeout: 60 * time.Second, + Impersonate: true, + }) } func prefix(s string, n int) string { diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go index 7ccbeafc..424d1a9a 100644 --- a/backend/internal/repository/claude_usage_service.go +++ b/backend/internal/repository/claude_usage_service.go @@ -6,9 +6,9 @@ import ( "fmt" "io" "net/http" - "net/url" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -23,20 +23,12 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher { } func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) { - transport, ok := http.DefaultTransport.(*http.Transport) - if !ok { - return nil, fmt.Errorf("failed to get default transport") - } - transport = transport.Clone() - if proxyURL != "" { - if parsedURL, err := url.Parse(proxyURL); err == nil { - transport.Proxy = http.ProxyURL(parsedURL) - } - } - - client := &http.Client{ - Transport: transport, - Timeout: 30 * time.Second, + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: proxyURL, + Timeout: 30 * time.Second, + }) + if err != nil { + client = &http.Client{Timeout: 30 * time.Second} } req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil) diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index 2946f691..31527f22 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -3,67 +3,90 @@ package repository import ( "context" "fmt" - "time" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/redis/go-redis/v9" ) +// 并发控制缓存常量定义 +// +// 性能优化说明: +// 原实现使用 SCAN 命令遍历独立的槽位键(concurrency:account:{id}:{requestID}), +// 在高并发场景下 SCAN 需要多次往返,且遍历大量键时性能下降明显。 +// +// 新实现改用 Redis 有序集合(Sorted Set): +// 1. 每个账号/用户只有一个键,成员为 requestID,分数为时间戳 +// 2. 使用 ZCARD 原子获取并发数,时间复杂度 O(1) +// 3. 使用 ZREMRANGEBYSCORE 清理过期槽位,避免手动管理 TTL +// 4. 
单次 Redis 调用完成计数,减少网络往返 const ( - // Key prefixes for independent slot keys - // Format: concurrency:account:{accountID}:{requestID} + // 并发槽位键前缀(有序集合) + // 格式: concurrency:account:{accountID} accountSlotKeyPrefix = "concurrency:account:" - // Format: concurrency:user:{userID}:{requestID} + // 格式: concurrency:user:{userID} userSlotKeyPrefix = "concurrency:user:" - // Wait queue keeps counter format: concurrency:wait:{userID} + // 等待队列计数器格式: concurrency:wait:{userID} waitQueueKeyPrefix = "concurrency:wait:" - // Slot TTL - each slot expires independently - slotTTL = 5 * time.Minute + // 默认槽位过期时间(分钟),可通过配置覆盖 + defaultSlotTTLMinutes = 15 ) var ( - // acquireScript uses SCAN to count existing slots and creates new slot if under limit - // KEYS[1] = pattern for SCAN (e.g., "concurrency:account:2:*") - // KEYS[2] = full slot key (e.g., "concurrency:account:2:req_xxx") + // acquireScript 使用有序集合计数并在未达上限时添加槽位 + // 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步问题 + // KEYS[1] = 有序集合键 (concurrency:account:{id} / concurrency:user:{id}) // ARGV[1] = maxConcurrency - // ARGV[2] = TTL in seconds + // ARGV[2] = TTL(秒) + // ARGV[3] = requestID acquireScript = redis.NewScript(` - local pattern = KEYS[1] - local slotKey = KEYS[2] + local key = KEYS[1] local maxConcurrency = tonumber(ARGV[1]) local ttl = tonumber(ARGV[2]) + local requestID = ARGV[3] - -- Count existing slots using SCAN - local cursor = "0" - local count = 0 - repeat - local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100) - cursor = result[1] - count = count + #result[2] - until cursor == "0" + -- 使用 Redis 服务器时间,确保多实例时钟一致 + local timeResult = redis.call('TIME') + local now = tonumber(timeResult[1]) + local expireBefore = now - ttl - -- Check if we can acquire a slot + -- 清理过期槽位 + redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) + + -- 检查是否已存在(支持重试场景刷新时间戳) + local exists = redis.call('ZSCORE', key, requestID) + if exists ~= false then + redis.call('ZADD', key, now, requestID) + redis.call('EXPIRE', key, ttl) + return 1 + end + + -- 检查是否达到并发上限 + local count = redis.call('ZCARD', key) if count < maxConcurrency then - redis.call('SET', slotKey, '1', 'EX', ttl) + redis.call('ZADD', key, now, requestID) + redis.call('EXPIRE', key, ttl) return 1 end return 0 `) - // getCountScript counts slots using SCAN - // KEYS[1] = pattern for SCAN + // getCountScript 统计有序集合中的槽位数量并清理过期条目 + // 使用 Redis TIME 命令获取服务器时间 + // KEYS[1] = 有序集合键 + // ARGV[1] = TTL(秒) getCountScript = redis.NewScript(` - local pattern = KEYS[1] - local cursor = "0" - local count = 0 - repeat - local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100) - cursor = result[1] - count = count + #result[2] - until cursor == "0" - return count + local key = KEYS[1] + local ttl = tonumber(ARGV[1]) + + -- 使用 Redis 服务器时间 + local timeResult = redis.call('TIME') + local now = tonumber(timeResult[1]) + local expireBefore = now - ttl + + redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) + return redis.call('ZCARD', key) `) // incrementWaitScript - only sets TTL on first creation to avoid refreshing @@ -103,28 +126,29 @@ var ( ) type concurrencyCache struct { - rdb *redis.Client + rdb *redis.Client + slotTTLSeconds int // 槽位过期时间(秒) } -func NewConcurrencyCache(rdb *redis.Client) service.ConcurrencyCache { - return &concurrencyCache{rdb: rdb} +// NewConcurrencyCache 创建并发控制缓存 +// slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟 +func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.ConcurrencyCache { + if slotTTLMinutes <= 0 { + slotTTLMinutes = 
defaultSlotTTLMinutes + } + return &concurrencyCache{ + rdb: rdb, + slotTTLSeconds: slotTTLMinutes * 60, + } } // Helper functions for key generation -func accountSlotKey(accountID int64, requestID string) string { - return fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, requestID) +func accountSlotKey(accountID int64) string { + return fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID) } -func accountSlotPattern(accountID int64) string { - return fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID) -} - -func userSlotKey(userID int64, requestID string) string { - return fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, requestID) -} - -func userSlotPattern(userID int64) string { - return fmt.Sprintf("%s%d:*", userSlotKeyPrefix, userID) +func userSlotKey(userID int64) string { + return fmt.Sprintf("%s%d", userSlotKeyPrefix, userID) } func waitQueueKey(userID int64) string { @@ -134,10 +158,9 @@ func waitQueueKey(userID int64) string { // Account slot operations func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { - pattern := accountSlotPattern(accountID) - slotKey := accountSlotKey(accountID, requestID) - - result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int() + key := accountSlotKey(accountID) + // 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致 + result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int() if err != nil { return false, err } @@ -145,13 +168,14 @@ func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int } func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { - slotKey := accountSlotKey(accountID, requestID) - return c.rdb.Del(ctx, slotKey).Err() + key := accountSlotKey(accountID) + return c.rdb.ZRem(ctx, key, requestID).Err() } func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { - pattern := accountSlotPattern(accountID) - result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int() + key := accountSlotKey(accountID) + // 时间戳在 Lua 脚本内使用 Redis TIME 命令获取 + result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int() if err != nil { return 0, err } @@ -161,10 +185,9 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID // User slot operations func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { - pattern := userSlotPattern(userID) - slotKey := userSlotKey(userID, requestID) - - result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int() + key := userSlotKey(userID) + // 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致 + result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int() if err != nil { return false, err } @@ -172,13 +195,14 @@ func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, ma } func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error { - slotKey := userSlotKey(userID, requestID) - return c.rdb.Del(ctx, slotKey).Err() + key := userSlotKey(userID) + return c.rdb.ZRem(ctx, key, requestID).Err() } func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { - pattern := 
userSlotPattern(userID) - result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int() + key := userSlotKey(userID) + // 时间戳在 Lua 脚本内使用 Redis TIME 命令获取 + result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int() if err != nil { return 0, err } @@ -189,7 +213,7 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { key := waitQueueKey(userID) - result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(slotTTL.Seconds())).Int() + result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int() if err != nil { return false, err } diff --git a/backend/internal/repository/concurrency_cache_benchmark_test.go b/backend/internal/repository/concurrency_cache_benchmark_test.go new file mode 100644 index 00000000..29cc7fbc --- /dev/null +++ b/backend/internal/repository/concurrency_cache_benchmark_test.go @@ -0,0 +1,135 @@ +package repository + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/redis/go-redis/v9" +) + +// 基准测试用 TTL 配置 +const benchSlotTTLMinutes = 15 + +var benchSlotTTL = time.Duration(benchSlotTTLMinutes) * time.Minute + +// BenchmarkAccountConcurrency 用于对比 SCAN 与有序集合的计数性能。 +func BenchmarkAccountConcurrency(b *testing.B) { + rdb := newBenchmarkRedisClient(b) + defer func() { + _ = rdb.Close() + }() + + cache := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache) + ctx := context.Background() + + for _, size := range []int{10, 100, 1000} { + size := size + b.Run(fmt.Sprintf("zset/slots=%d", size), func(b *testing.B) { + accountID := time.Now().UnixNano() + key := accountSlotKey(accountID) + + b.StopTimer() + members := make([]redis.Z, 0, size) + now := float64(time.Now().Unix()) + for i := 0; i < size; i++ { + members = append(members, redis.Z{ + Score: now, + Member: fmt.Sprintf("req_%d", i), + }) + } + if err := rdb.ZAdd(ctx, key, members...).Err(); err != nil { + b.Fatalf("初始化有序集合失败: %v", err) + } + if err := rdb.Expire(ctx, key, benchSlotTTL).Err(); err != nil { + b.Fatalf("设置有序集合 TTL 失败: %v", err) + } + b.StartTimer() + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + if _, err := cache.GetAccountConcurrency(ctx, accountID); err != nil { + b.Fatalf("获取并发数量失败: %v", err) + } + } + + b.StopTimer() + if err := rdb.Del(ctx, key).Err(); err != nil { + b.Fatalf("清理有序集合失败: %v", err) + } + }) + + b.Run(fmt.Sprintf("scan/slots=%d", size), func(b *testing.B) { + accountID := time.Now().UnixNano() + pattern := fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID) + keys := make([]string, 0, size) + + b.StopTimer() + pipe := rdb.Pipeline() + for i := 0; i < size; i++ { + key := fmt.Sprintf("%s%d:req_%d", accountSlotKeyPrefix, accountID, i) + keys = append(keys, key) + pipe.Set(ctx, key, "1", benchSlotTTL) + } + if _, err := pipe.Exec(ctx); err != nil { + b.Fatalf("初始化扫描键失败: %v", err) + } + b.StartTimer() + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + if _, err := scanSlotCount(ctx, rdb, pattern); err != nil { + b.Fatalf("SCAN 计数失败: %v", err) + } + } + + b.StopTimer() + if err := rdb.Del(ctx, keys...).Err(); err != nil { + b.Fatalf("清理扫描键失败: %v", err) + } + }) + } +} + +func scanSlotCount(ctx context.Context, rdb *redis.Client, pattern string) (int, error) { + var cursor uint64 + count := 0 + for { + keys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, 100).Result() + if err != nil { + return 0, err + } + count += 
len(keys) + if nextCursor == 0 { + break + } + cursor = nextCursor + } + return count, nil +} + +func newBenchmarkRedisClient(b *testing.B) *redis.Client { + b.Helper() + + redisURL := os.Getenv("TEST_REDIS_URL") + if redisURL == "" { + b.Skip("未设置 TEST_REDIS_URL,跳过 Redis 基准测试") + } + + opt, err := redis.ParseURL(redisURL) + if err != nil { + b.Fatalf("解析 TEST_REDIS_URL 失败: %v", err) + } + + client := redis.NewClient(opt) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + if err := client.Ping(ctx).Err(); err != nil { + b.Fatalf("Redis 连接失败: %v", err) + } + + return client +} diff --git a/backend/internal/repository/concurrency_cache_integration_test.go b/backend/internal/repository/concurrency_cache_integration_test.go index c1feaf85..6a7c83f4 100644 --- a/backend/internal/repository/concurrency_cache_integration_test.go +++ b/backend/internal/repository/concurrency_cache_integration_test.go @@ -14,6 +14,12 @@ import ( "github.com/stretchr/testify/suite" ) +// 测试用 TTL 配置(15 分钟,与默认值一致) +const testSlotTTLMinutes = 15 + +// 测试用 TTL Duration,用于 TTL 断言 +var testSlotTTL = time.Duration(testSlotTTLMinutes) * time.Minute + type ConcurrencyCacheSuite struct { IntegrationRedisSuite cache service.ConcurrencyCache @@ -21,7 +27,7 @@ type ConcurrencyCacheSuite struct { func (s *ConcurrencyCacheSuite) SetupTest() { s.IntegrationRedisSuite.SetupTest() - s.cache = NewConcurrencyCache(s.rdb) + s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes) } func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() { @@ -54,7 +60,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() { func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() { accountID := int64(11) reqID := "req_ttl_test" - slotKey := fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, reqID) + slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID) ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID) require.NoError(s.T(), err, "AcquireAccountSlot") @@ -62,7 +68,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() { ttl, err := s.rdb.TTL(s.ctx, slotKey).Result() require.NoError(s.T(), err, "TTL") - s.AssertTTLWithin(ttl, 1*time.Second, slotTTL) + s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL) } func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() { @@ -139,7 +145,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() { func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() { userID := int64(200) reqID := "req_ttl_test" - slotKey := fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, reqID) + slotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID) ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID) require.NoError(s.T(), err, "AcquireUserSlot") @@ -147,7 +153,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() { ttl, err := s.rdb.TTL(s.ctx, slotKey).Result() require.NoError(s.T(), err, "TTL") - s.AssertTTLWithin(ttl, 1*time.Second, slotTTL) + s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL) } func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() { @@ -168,7 +174,7 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() { ttl, err := s.rdb.TTL(s.ctx, waitKey).Result() require.NoError(s.T(), err, "TTL waitKey") - s.AssertTTLWithin(ttl, 1*time.Second, slotTTL) + s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL) require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount") diff --git a/backend/internal/repository/gemini_oauth_client.go 
b/backend/internal/repository/gemini_oauth_client.go index 4e9bae3e..bac8736b 100644 --- a/backend/internal/repository/gemini_oauth_client.go +++ b/backend/internal/repository/gemini_oauth_client.go @@ -109,9 +109,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh } func createGeminiReqClient(proxyURL string) *req.Client { - client := req.C().SetTimeout(60 * time.Second) - if proxyURL != "" { - client.SetProxyURL(proxyURL) - } - return client + return getSharedReqClient(reqClientOptions{ + ProxyURL: proxyURL, + Timeout: 60 * time.Second, + }) } diff --git a/backend/internal/repository/geminicli_codeassist_client.go b/backend/internal/repository/geminicli_codeassist_client.go index 0a5d813c..d7f54e85 100644 --- a/backend/internal/repository/geminicli_codeassist_client.go +++ b/backend/internal/repository/geminicli_codeassist_client.go @@ -76,11 +76,10 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken } func createGeminiCliReqClient(proxyURL string) *req.Client { - client := req.C().SetTimeout(30 * time.Second) - if proxyURL != "" { - client.SetProxyURL(proxyURL) - } - return client + return getSharedReqClient(reqClientOptions{ + ProxyURL: proxyURL, + Timeout: 30 * time.Second, + }) } func defaultLoadCodeAssistRequest() *geminicli.LoadCodeAssistRequest { diff --git a/backend/internal/repository/github_release_service.go b/backend/internal/repository/github_release_service.go index 25b65a4b..3fa4b1ff 100644 --- a/backend/internal/repository/github_release_service.go +++ b/backend/internal/repository/github_release_service.go @@ -9,6 +9,7 @@ import ( "os" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -17,10 +18,14 @@ type githubReleaseClient struct { } func NewGitHubReleaseClient() service.GitHubReleaseClient { + sharedClient, err := httpclient.GetClient(httpclient.Options{ + Timeout: 30 * time.Second, + }) + if err != nil { + sharedClient = &http.Client{Timeout: 30 * time.Second} + } return &githubReleaseClient{ - httpClient: &http.Client{ - Timeout: 30 * time.Second, - }, + httpClient: sharedClient, } } @@ -58,8 +63,13 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string return err } - client := &http.Client{Timeout: 10 * time.Minute} - resp, err := client.Do(req) + downloadClient, err := httpclient.GetClient(httpclient.Options{ + Timeout: 10 * time.Minute, + }) + if err != nil { + downloadClient = &http.Client{Timeout: 10 * time.Minute} + } + resp, err := downloadClient.Do(req) if err != nil { return err } diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index 26befb25..0ca85a09 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -3,65 +3,104 @@ package repository import ( "net/http" "net/url" + "strings" + "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/service" ) -// httpUpstreamService is a generic HTTP upstream service that can be used for -// making requests to any HTTP API (Claude, OpenAI, etc.) with optional proxy support. +// httpUpstreamService 通用 HTTP 上游服务 +// 用于向任意 HTTP API(Claude、OpenAI 等)发送请求,支持可选代理 +// +// 性能优化: +// 1. 使用 sync.Map 缓存代理客户端实例,避免每次请求都创建新的 http.Client +// 2. 复用 Transport 连接池,减少 TCP 握手和 TLS 协商开销 +// 3. 
原实现每次请求都 new 一个 http.Client,导致连接无法复用 type httpUpstreamService struct { + // defaultClient: 无代理时使用的默认客户端(单例复用) defaultClient *http.Client - cfg *config.Config + // proxyClients: 按代理 URL 缓存的客户端池,避免重复创建 + proxyClients sync.Map + cfg *config.Config } -// NewHTTPUpstream creates a new generic HTTP upstream service +// NewHTTPUpstream 创建通用 HTTP 上游服务 +// 使用配置中的连接池参数构建 Transport func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream { - responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second - if responseHeaderTimeout == 0 { - responseHeaderTimeout = 300 * time.Second - } - - transport := &http.Transport{ - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, - ResponseHeaderTimeout: responseHeaderTimeout, - } - return &httpUpstreamService{ - defaultClient: &http.Client{Transport: transport}, + defaultClient: &http.Client{Transport: buildUpstreamTransport(cfg, nil)}, cfg: cfg, } } func (s *httpUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) { - if proxyURL == "" { + if strings.TrimSpace(proxyURL) == "" { return s.defaultClient.Do(req) } - client := s.createProxyClient(proxyURL) + client := s.getOrCreateClient(proxyURL) return client.Do(req) } -func (s *httpUpstreamService) createProxyClient(proxyURL string) *http.Client { +// getOrCreateClient 获取或创建代理客户端 +// 性能优化:使用 sync.Map 实现无锁缓存,相同代理 URL 复用同一客户端 +// LoadOrStore 保证并发安全,避免重复创建 +func (s *httpUpstreamService) getOrCreateClient(proxyURL string) *http.Client { + proxyURL = strings.TrimSpace(proxyURL) + if proxyURL == "" { + return s.defaultClient + } + // 优先从缓存获取,命中则直接返回 + if cached, ok := s.proxyClients.Load(proxyURL); ok { + return cached.(*http.Client) + } + parsedURL, err := url.Parse(proxyURL) if err != nil { return s.defaultClient } - responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second - if responseHeaderTimeout == 0 { + // 创建新客户端并缓存,LoadOrStore 保证只有一个实例被存储 + client := &http.Client{Transport: buildUpstreamTransport(s.cfg, parsedURL)} + actual, _ := s.proxyClients.LoadOrStore(proxyURL, client) + return actual.(*http.Client) +} + +// buildUpstreamTransport 构建上游请求的 Transport +// 使用配置文件中的连接池参数,支持生产环境调优 +func buildUpstreamTransport(cfg *config.Config, proxyURL *url.URL) *http.Transport { + // 读取配置,使用合理的默认值 + maxIdleConns := cfg.Gateway.MaxIdleConns + if maxIdleConns <= 0 { + maxIdleConns = 240 + } + maxIdleConnsPerHost := cfg.Gateway.MaxIdleConnsPerHost + if maxIdleConnsPerHost <= 0 { + maxIdleConnsPerHost = 120 + } + maxConnsPerHost := cfg.Gateway.MaxConnsPerHost + if maxConnsPerHost < 0 { + maxConnsPerHost = 240 + } + idleConnTimeout := time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second + if idleConnTimeout <= 0 { + idleConnTimeout = 300 * time.Second + } + responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second + if responseHeaderTimeout <= 0 { responseHeaderTimeout = 300 * time.Second } transport := &http.Transport{ - Proxy: http.ProxyURL(parsedURL), - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, + MaxIdleConns: maxIdleConns, // 最大空闲连接总数 + MaxIdleConnsPerHost: maxIdleConnsPerHost, // 每主机最大空闲连接 + MaxConnsPerHost: maxConnsPerHost, // 每主机最大连接数(含活跃) + IdleConnTimeout: idleConnTimeout, // 空闲连接超时 ResponseHeaderTimeout: responseHeaderTimeout, } - - return &http.Client{Transport: transport} + if proxyURL != nil { + transport.Proxy = http.ProxyURL(proxyURL) + } + return transport } diff --git 
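// Illustrative sketch (not part of this patch): what the sync.Map cache buys. Calls
// with the same proxy URL must hand back the exact same *http.Client, so the
// Transport's connection pool is shared across requests instead of being rebuilt per
// call. A test along these lines (name hypothetical) could pin that behaviour:
package repository

import (
	"testing"

	"github.com/Wei-Shaw/sub2api/internal/config"
)

func TestGetOrCreateClient_ReusesClientPerProxyURL(t *testing.T) {
	cfg := &config.Config{Gateway: config.GatewayConfig{ResponseHeaderTimeout: 5}}
	svc := NewHTTPUpstream(cfg).(*httpUpstreamService)

	first := svc.getOrCreateClient("http://127.0.0.1:8080")
	second := svc.getOrCreateClient("http://127.0.0.1:8080")
	if first != second {
		t.Fatal("expected the same *http.Client for identical proxy URLs")
	}
	if other := svc.getOrCreateClient("http://127.0.0.1:9090"); other == first {
		t.Fatal("expected a distinct client for a distinct proxy URL")
	}
}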
a/backend/internal/repository/http_upstream_benchmark_test.go b/backend/internal/repository/http_upstream_benchmark_test.go new file mode 100644 index 00000000..2ea6e31a --- /dev/null +++ b/backend/internal/repository/http_upstream_benchmark_test.go @@ -0,0 +1,46 @@ +package repository + +import ( + "net/http" + "net/url" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +var httpClientSink *http.Client + +// BenchmarkHTTPUpstreamProxyClient 对比重复创建与复用代理客户端的开销。 +func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { + cfg := &config.Config{ + Gateway: config.GatewayConfig{ResponseHeaderTimeout: 300}, + } + upstream := NewHTTPUpstream(cfg) + svc, ok := upstream.(*httpUpstreamService) + if !ok { + b.Fatalf("类型断言失败,无法获取 httpUpstreamService") + } + + proxyURL := "http://127.0.0.1:8080" + b.ReportAllocs() + + b.Run("新建", func(b *testing.B) { + parsedProxy, err := url.Parse(proxyURL) + if err != nil { + b.Fatalf("解析代理地址失败: %v", err) + } + for i := 0; i < b.N; i++ { + httpClientSink = &http.Client{ + Transport: buildUpstreamTransport(cfg, parsedProxy), + } + } + }) + + b.Run("复用", func(b *testing.B) { + client := svc.getOrCreateClient(proxyURL) + b.ResetTimer() + for i := 0; i < b.N; i++ { + httpClientSink = client + } + }) +} diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go index 9bc38dae..74132e1d 100644 --- a/backend/internal/repository/http_upstream_test.go +++ b/backend/internal/repository/http_upstream_test.go @@ -40,13 +40,13 @@ func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() { require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") } -func (s *HTTPUpstreamSuite) TestCreateProxyClient_InvalidURLFallsBackToDefault() { +func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDefault() { s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 5} up := NewHTTPUpstream(s.cfg) svc, ok := up.(*httpUpstreamService) require.True(s.T(), ok, "expected *httpUpstreamService") - got := svc.createProxyClient("://bad-proxy-url") + got := svc.getOrCreateClient("://bad-proxy-url") require.Equal(s.T(), svc.defaultClient, got, "expected defaultClient fallback") } diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go index da14a338..07d57410 100644 --- a/backend/internal/repository/openai_oauth_service.go +++ b/backend/internal/repository/openai_oauth_service.go @@ -82,12 +82,8 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro } func createOpenAIReqClient(proxyURL string) *req.Client { - client := req.C(). 
- SetTimeout(60 * time.Second) - - if proxyURL != "" { - client.SetProxyURL(proxyURL) - } - - return client + return getSharedReqClient(reqClientOptions{ + ProxyURL: proxyURL, + Timeout: 60 * time.Second, + }) } diff --git a/backend/internal/repository/pricing_service.go b/backend/internal/repository/pricing_service.go index ccfebd1b..11f82fd3 100644 --- a/backend/internal/repository/pricing_service.go +++ b/backend/internal/repository/pricing_service.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -16,10 +17,14 @@ type pricingRemoteClient struct { } func NewPricingRemoteClient() service.PricingRemoteClient { + sharedClient, err := httpclient.GetClient(httpclient.Options{ + Timeout: 30 * time.Second, + }) + if err != nil { + sharedClient = &http.Client{Timeout: 30 * time.Second} + } return &pricingRemoteClient{ - httpClient: &http.Client{ - Timeout: 30 * time.Second, - }, + httpClient: sharedClient, } } diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index 9331859c..8b288c3c 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -2,18 +2,14 @@ package repository import ( "context" - "crypto/tls" "encoding/json" "fmt" "io" - "net" "net/http" - "net/url" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/service" - - "golang.org/x/net/proxy" ) func NewProxyExitInfoProber() service.ProxyExitInfoProber { @@ -27,14 +23,14 @@ type proxyProbeService struct { } func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { - transport, err := createProxyTransport(proxyURL) + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: proxyURL, + Timeout: 15 * time.Second, + InsecureSkipVerify: true, + ProxyStrict: true, + }) if err != nil { - return nil, 0, fmt.Errorf("failed to create proxy transport: %w", err) - } - - client := &http.Client{ - Transport: transport, - Timeout: 15 * time.Second, + return nil, 0, fmt.Errorf("failed to create proxy client: %w", err) } startTime := time.Now() @@ -78,31 +74,3 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s Country: ipInfo.Country, }, latencyMs, nil } - -func createProxyTransport(proxyURL string) (*http.Transport, error) { - parsedURL, err := url.Parse(proxyURL) - if err != nil { - return nil, fmt.Errorf("invalid proxy URL: %w", err) - } - - transport := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - - switch parsedURL.Scheme { - case "http", "https": - transport.Proxy = http.ProxyURL(parsedURL) - case "socks5": - dialer, err := proxy.FromURL(parsedURL, proxy.Direct) - if err != nil { - return nil, fmt.Errorf("failed to create socks5 dialer: %w", err) - } - transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - } - default: - return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme) - } - - return transport, nil -} diff --git a/backend/internal/repository/proxy_probe_service_test.go b/backend/internal/repository/proxy_probe_service_test.go index 25ab0f9c..74d99c6d 100644 --- a/backend/internal/repository/proxy_probe_service_test.go +++ b/backend/internal/repository/proxy_probe_service_test.go @@ -34,22 +34,16 @@ func (s *ProxyProbeServiceSuite) 
setupProxyServer(handler http.HandlerFunc) { s.proxySrv = httptest.NewServer(handler) } -func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_InvalidURL() { - _, err := createProxyTransport("://bad") +func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() { + _, _, err := s.prober.ProbeProxy(s.ctx, "://bad") require.Error(s.T(), err) - require.ErrorContains(s.T(), err, "invalid proxy URL") + require.ErrorContains(s.T(), err, "failed to create proxy client") } -func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_UnsupportedScheme() { - _, err := createProxyTransport("ftp://127.0.0.1:1") +func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() { + _, _, err := s.prober.ProbeProxy(s.ctx, "ftp://127.0.0.1:1") require.Error(s.T(), err) - require.ErrorContains(s.T(), err, "unsupported proxy protocol") -} - -func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_Socks5SetsDialer() { - tr, err := createProxyTransport("socks5://127.0.0.1:1080") - require.NoError(s.T(), err, "createProxyTransport") - require.NotNil(s.T(), tr.DialContext, "expected DialContext to be set for socks5") + require.ErrorContains(s.T(), err, "failed to create proxy client") } func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() { diff --git a/backend/internal/repository/req_client_pool.go b/backend/internal/repository/req_client_pool.go new file mode 100644 index 00000000..bfe0ccd2 --- /dev/null +++ b/backend/internal/repository/req_client_pool.go @@ -0,0 +1,59 @@ +package repository + +import ( + "fmt" + "strings" + "sync" + "time" + + "github.com/imroc/req/v3" +) + +// reqClientOptions 定义 req 客户端的构建参数 +type reqClientOptions struct { + ProxyURL string // 代理 URL(支持 http/https/socks5) + Timeout time.Duration // 请求超时时间 + Impersonate bool // 是否模拟 Chrome 浏览器指纹 +} + +// sharedReqClients 存储按配置参数缓存的 req 客户端实例 +// +// 性能优化说明: +// 原实现在每次 OAuth 刷新时都创建新的 req.Client: +// 1. claude_oauth_service.go: 每次刷新创建新客户端 +// 2. openai_oauth_service.go: 每次刷新创建新客户端 +// 3. gemini_oauth_client.go: 每次刷新创建新客户端 +// +// 新实现使用 sync.Map 缓存客户端: +// 1. 相同配置(代理+超时+模拟设置)复用同一客户端 +// 2. 复用底层连接池,减少 TLS 握手开销 +// 3. 
LoadOrStore 保证并发安全,避免重复创建 +var sharedReqClients sync.Map + +// getSharedReqClient 获取共享的 req 客户端实例 +// 性能优化:相同配置复用同一客户端,避免重复创建 +func getSharedReqClient(opts reqClientOptions) *req.Client { + key := buildReqClientKey(opts) + if cached, ok := sharedReqClients.Load(key); ok { + return cached.(*req.Client) + } + + client := req.C().SetTimeout(opts.Timeout) + if opts.Impersonate { + client = client.ImpersonateChrome() + } + if strings.TrimSpace(opts.ProxyURL) != "" { + client.SetProxyURL(strings.TrimSpace(opts.ProxyURL)) + } + + actual, _ := sharedReqClients.LoadOrStore(key, client) + return actual.(*req.Client) +} + +func buildReqClientKey(opts reqClientOptions) string { + return fmt.Sprintf("%s|%s|%t", + strings.TrimSpace(opts.ProxyURL), + opts.Timeout.String(), + opts.Impersonate, + ) +} diff --git a/backend/internal/repository/turnstile_service.go b/backend/internal/repository/turnstile_service.go index c3755011..cf6083e2 100644 --- a/backend/internal/repository/turnstile_service.go +++ b/backend/internal/repository/turnstile_service.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -20,11 +21,15 @@ type turnstileVerifier struct { } func NewTurnstileVerifier() service.TurnstileVerifier { + sharedClient, err := httpclient.GetClient(httpclient.Options{ + Timeout: 10 * time.Second, + }) + if err != nil { + sharedClient = &http.Client{Timeout: 10 * time.Second} + } return &turnstileVerifier{ - httpClient: &http.Client{ - Timeout: 10 * time.Second, - }, - verifyURL: turnstileVerifyURL, + httpClient: sharedClient, + verifyURL: turnstileVerifyURL, } } diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 9341f20e..4e26d751 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -452,6 +452,161 @@ func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKe return &stats, nil } +// GetAccountStatsAggregated 使用 SQL 聚合统计账号使用数据 +// +// 性能优化说明: +// 原实现先查询所有日志记录,再在应用层循环计算统计值: +// 1. 需要传输大量数据到应用层 +// 2. 应用层循环计算增加 CPU 和内存开销 +// +// 新实现使用 SQL 聚合函数: +// 1. 在数据库层完成 COUNT/SUM/AVG 计算 +// 2. 只返回单行聚合结果,大幅减少数据传输量 +// 3. 
利用数据库索引优化聚合查询性能 +func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + query := ` + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms + FROM usage_logs + WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 + ` + + var stats usagestats.UsageStats + if err := scanSingleRow( + ctx, + r.sql, + query, + []any{accountID, startTime, endTime}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { + return nil, err + } + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens + return &stats, nil +} + +// GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据 +// 性能优化:数据库层聚合计算,避免应用层循环统计 +func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + query := ` + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms + FROM usage_logs + WHERE model = $1 AND created_at >= $2 AND created_at < $3 + ` + + var stats usagestats.UsageStats + if err := scanSingleRow( + ctx, + r.sql, + query, + []any{modelName, startTime, endTime}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { + return nil, err + } + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens + return &stats, nil +} + +// GetDailyStatsAggregated 使用 SQL 聚合统计用户的每日使用数据 +// 性能优化:使用 GROUP BY 在数据库层按日期分组聚合,避免应用层循环分组统计 +func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) { + query := ` + SELECT + TO_CHAR(created_at, 'YYYY-MM-DD') as date, + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms + FROM usage_logs + WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 + GROUP BY 1 + ORDER BY 1 + ` + + rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime) + if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + result = nil + } + }() + + result = make([]map[string]any, 0) + for rows.Next() { + var ( + date string + totalRequests int64 + totalInputTokens int64 + totalOutputTokens 
int64 + totalCacheTokens int64 + totalCost float64 + totalActualCost float64 + avgDurationMs float64 + ) + if err = rows.Scan( + &date, + &totalRequests, + &totalInputTokens, + &totalOutputTokens, + &totalCacheTokens, + &totalCost, + &totalActualCost, + &avgDurationMs, + ); err != nil { + return nil, err + } + result = append(result, map[string]any{ + "date": date, + "total_requests": totalRequests, + "total_input_tokens": totalInputTokens, + "total_output_tokens": totalOutputTokens, + "total_cache_tokens": totalCacheTokens, + "total_tokens": totalInputTokens + totalOutputTokens + totalCacheTokens, + "total_cost": totalCost, + "total_actual_cost": totalActualCost, + "average_duration_ms": avgDurationMs, + }) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return result, nil +} + func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 53d42d90..edeaf782 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -1,9 +1,18 @@ package repository import ( + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/google/wire" + "github.com/redis/go-redis/v9" ) +// ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数 +// 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景 +func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache { + return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes) +} + // ProviderSet is the Wire provider set for all repositories var ProviderSet = wire.NewSet( NewUserRepository, @@ -20,7 +29,7 @@ var ProviderSet = wire.NewSet( NewGatewayCache, NewBillingCache, NewApiKeyCache, - NewConcurrencyCache, + ProvideConcurrencyCache, NewEmailCache, NewIdentityCache, NewRedeemCache, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 8d5ace96..5a243bfc 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -981,6 +981,18 @@ func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyI return nil, errors.New("not implemented") } +func (r *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) { + return nil, errors.New("not implemented") +} + func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { return nil, errors.New("not implemented") } diff --git a/backend/internal/server/middleware/request_body_limit.go b/backend/internal/server/middleware/request_body_limit.go new file mode 100644 index 00000000..fce13eea --- /dev/null +++ 
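// Illustrative sketch (not part of this patch): the aggregate queries above filter on
// (account_id, created_at), (model, created_at) and (user_id, created_at), so composite
// indexes of roughly this shape let PostgreSQL serve them with an index range scan
// rather than a full table scan. The DDL below is hypothetical and is not quoted from
// the project's migration files.
package repository

const exampleUsageLogsAggregateIndexDDL = `
CREATE INDEX IF NOT EXISTS idx_usage_logs_account_id_created_at ON usage_logs (account_id, created_at);
CREATE INDEX IF NOT EXISTS idx_usage_logs_model_created_at      ON usage_logs (model, created_at);
CREATE INDEX IF NOT EXISTS idx_usage_logs_user_id_created_at    ON usage_logs (user_id, created_at);
`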
b/backend/internal/server/middleware/request_body_limit.go @@ -0,0 +1,15 @@ +package middleware + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +// RequestBodyLimit 使用 MaxBytesReader 限制请求体大小。 +func RequestBodyLimit(maxBytes int64) gin.HandlerFunc { + return func(c *gin.Context) { + c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes) + c.Next() + } +} diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index 34792be8..38df9225 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -18,8 +18,11 @@ func RegisterGatewayRoutes( subscriptionService *service.SubscriptionService, cfg *config.Config, ) { + bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize) + // API网关(Claude API兼容) gateway := r.Group("/v1") + gateway.Use(bodyLimit) gateway.Use(gin.HandlerFunc(apiKeyAuth)) { gateway.POST("/messages", h.Gateway.Messages) @@ -32,6 +35,7 @@ func RegisterGatewayRoutes( // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) gemini := r.Group("/v1beta") + gemini.Use(bodyLimit) gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) { gemini.GET("/models", h.Gateway.GeminiV1BetaListModels) @@ -41,10 +45,11 @@ func RegisterGatewayRoutes( } // OpenAI Responses API(不带v1前缀的别名) - r.POST("/responses", gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses) + r.POST("/responses", bodyLimit, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses) // Antigravity 专用路由(仅使用 antigravity 账户,不混合调度) antigravityV1 := r.Group("/antigravity/v1") + antigravityV1.Use(bodyLimit) antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1.Use(gin.HandlerFunc(apiKeyAuth)) { @@ -55,6 +60,7 @@ func RegisterGatewayRoutes( } antigravityV1Beta := r.Group("/antigravity/v1beta") + antigravityV1Beta.Use(bodyLimit) antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) { diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 94d4c747..dba670b0 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -52,6 +52,9 @@ type UsageLogRepository interface { // Aggregated stats (optimized) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) + GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) + GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) + GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) } // apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at) diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index 9493a11f..ac320535 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -27,6 +28,45 @@ type subscriptionCacheData struct { Version int64 } +// 
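// Illustrative sketch (not part of this patch): once MaxBytesReader is installed by the
// middleware above, reading an oversized body fails with *http.MaxBytesError (Go 1.19+),
// which a handler can translate into 413 Request Entity Too Large. The helper and error
// payload below are hypothetical and are not the project's actual 413 handler.
package handler

import (
	"errors"
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
)

func readLimitedBody(c *gin.Context) ([]byte, bool) {
	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		var maxErr *http.MaxBytesError
		if errors.As(err, &maxErr) {
			c.AbortWithStatusJSON(http.StatusRequestEntityTooLarge, gin.H{
				"error": gin.H{"type": "request_too_large", "message": "request body exceeds the configured limit"},
			})
		} else {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
				"error": gin.H{"type": "invalid_request_error", "message": "failed to read request body"},
			})
		}
		return nil, false
	}
	return body, true
}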
缓存写入任务类型 +type cacheWriteKind int + +const ( + cacheWriteSetBalance cacheWriteKind = iota + cacheWriteSetSubscription + cacheWriteUpdateSubscriptionUsage + cacheWriteDeductBalance +) + +// 异步缓存写入工作池配置 +// +// 性能优化说明: +// 原实现在请求热路径中使用 goroutine 异步更新缓存,存在以下问题: +// 1. 每次请求创建新 goroutine,高并发下产生大量短生命周期 goroutine +// 2. 无法控制并发数量,可能导致 Redis 连接耗尽 +// 3. goroutine 创建/销毁带来额外开销 +// +// 新实现使用固定大小的工作池: +// 1. 预创建 10 个 worker goroutine,避免频繁创建销毁 +// 2. 使用带缓冲的 channel(1000)作为任务队列,平滑写入峰值 +// 3. 非阻塞写入,队列满时丢弃任务(缓存最终一致性可接受) +// 4. 统一超时控制,避免慢操作阻塞工作池 +const ( + cacheWriteWorkerCount = 10 // 工作协程数量 + cacheWriteBufferSize = 1000 // 任务队列缓冲大小 + cacheWriteTimeout = 2 * time.Second // 单个写入操作超时 +) + +// cacheWriteTask 缓存写入任务 +type cacheWriteTask struct { + kind cacheWriteKind + userID int64 + groupID int64 + balance float64 + amount float64 + subscriptionData *subscriptionCacheData +} + // BillingCacheService 计费缓存服务 // 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查 type BillingCacheService struct { @@ -34,16 +74,81 @@ type BillingCacheService struct { userRepo UserRepository subRepo UserSubscriptionRepository cfg *config.Config + + cacheWriteChan chan cacheWriteTask + cacheWriteWg sync.WaitGroup + cacheWriteStopOnce sync.Once } // NewBillingCacheService 创建计费缓存服务 func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService { - return &BillingCacheService{ + svc := &BillingCacheService{ cache: cache, userRepo: userRepo, subRepo: subRepo, cfg: cfg, } + svc.startCacheWriteWorkers() + return svc +} + +// Stop 关闭缓存写入工作池 +func (s *BillingCacheService) Stop() { + s.cacheWriteStopOnce.Do(func() { + if s.cacheWriteChan == nil { + return + } + close(s.cacheWriteChan) + s.cacheWriteWg.Wait() + s.cacheWriteChan = nil + }) +} + +func (s *BillingCacheService) startCacheWriteWorkers() { + s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize) + for i := 0; i < cacheWriteWorkerCount; i++ { + s.cacheWriteWg.Add(1) + go s.cacheWriteWorker() + } +} + +func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) { + if s.cacheWriteChan == nil { + return + } + defer func() { + _ = recover() + }() + select { + case s.cacheWriteChan <- task: + default: + } +} + +func (s *BillingCacheService) cacheWriteWorker() { + defer s.cacheWriteWg.Done() + for task := range s.cacheWriteChan { + ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) + switch task.kind { + case cacheWriteSetBalance: + s.setBalanceCache(ctx, task.userID, task.balance) + case cacheWriteSetSubscription: + s.setSubscriptionCache(ctx, task.userID, task.groupID, task.subscriptionData) + case cacheWriteUpdateSubscriptionUsage: + if s.cache != nil { + if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil { + log.Printf("Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err) + } + } + case cacheWriteDeductBalance: + if s.cache != nil { + if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil { + log.Printf("Warning: deduct balance cache failed for user %d: %v", task.userID, err) + } + } + } + cancel() + } } // ============================================ @@ -70,11 +175,11 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) } // 异步建立缓存 - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - s.setBalanceCache(cacheCtx, userID, balance) - }() + 
s.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteSetBalance, + userID: userID, + balance: balance, + }) return balance, nil } @@ -98,7 +203,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, } } -// DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存) +// DeductBalanceCache 扣减余额缓存(同步调用) func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error { if s.cache == nil { return nil @@ -106,6 +211,15 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int return s.cache.DeductUserBalance(ctx, userID, amount) } +// QueueDeductBalance 异步扣减余额缓存 +func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) { + s.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteDeductBalance, + userID: userID, + amount: amount, + }) +} + // InvalidateUserBalance 失效用户余额缓存 func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error { if s.cache == nil { @@ -141,11 +255,12 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, } // 异步建立缓存 - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - s.setSubscriptionCache(cacheCtx, userID, groupID, data) - }() + s.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteSetSubscription, + userID: userID, + groupID: groupID, + subscriptionData: data, + }) return data, nil } @@ -199,7 +314,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, } } -// UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存) +// UpdateSubscriptionUsage 更新订阅用量缓存(同步调用) func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error { if s.cache == nil { return nil @@ -207,6 +322,16 @@ func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userI return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD) } +// QueueUpdateSubscriptionUsage 异步更新订阅用量缓存 +func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) { + s.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteUpdateSubscriptionUsage, + userID: userID, + groupID: groupID, + amount: costUSD, + }) +} + // InvalidateSubscription 失效指定订阅缓存 func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error { if s.cache == nil { diff --git a/backend/internal/service/billing_cache_service_test.go b/backend/internal/service/billing_cache_service_test.go new file mode 100644 index 00000000..445d5319 --- /dev/null +++ b/backend/internal/service/billing_cache_service_test.go @@ -0,0 +1,75 @@ +package service + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type billingCacheWorkerStub struct { + balanceUpdates int64 + subscriptionUpdates int64 +} + +func (b *billingCacheWorkerStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) { + return 0, errors.New("not implemented") +} + +func (b *billingCacheWorkerStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error { + atomic.AddInt64(&b.balanceUpdates, 1) + return nil +} + +func (b *billingCacheWorkerStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { + atomic.AddInt64(&b.balanceUpdates, 1) + return nil +} + +func (b *billingCacheWorkerStub) InvalidateUserBalance(ctx context.Context, userID int64) error { + 
return nil +} + +func (b *billingCacheWorkerStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) { + return nil, errors.New("not implemented") +} + +func (b *billingCacheWorkerStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error { + atomic.AddInt64(&b.subscriptionUpdates, 1) + return nil +} + +func (b *billingCacheWorkerStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { + atomic.AddInt64(&b.subscriptionUpdates, 1) + return nil +} + +func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error { + return nil +} + +func TestBillingCacheServiceQueueHighLoad(t *testing.T) { + cache := &billingCacheWorkerStub{} + svc := NewBillingCacheService(cache, nil, nil, &config.Config{}) + t.Cleanup(svc.Stop) + + start := time.Now() + for i := 0; i < cacheWriteBufferSize*2; i++ { + svc.QueueDeductBalance(1, 1) + } + require.Less(t, time.Since(start), 2*time.Second) + + svc.QueueUpdateSubscriptionUsage(1, 2, 1.5) + + require.Eventually(t, func() bool { + return atomic.LoadInt64(&cache.balanceUpdates) > 0 + }, 2*time.Second, 10*time.Millisecond) + + require.Eventually(t, func() bool { + return atomic.LoadInt64(&cache.subscriptionUpdates) > 0 + }, 2*time.Second, 10*time.Millisecond) +} diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index a6cff234..b5229491 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -9,22 +9,22 @@ import ( "time" ) -// ConcurrencyCache defines cache operations for concurrency service -// Uses independent keys per request slot with native Redis TTL for automatic cleanup +// ConcurrencyCache 定义并发控制的缓存接口 +// 使用有序集合存储槽位,按时间戳清理过期条目 type ConcurrencyCache interface { - // Account slot management - each slot is a separate key with independent TTL - // Key format: concurrency:account:{accountID}:{requestID} + // 账号槽位管理 + // 键格式: concurrency:account:{accountID}(有序集合,成员为 requestID) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) - // User slot management - each slot is a separate key with independent TTL - // Key format: concurrency:user:{userID}:{requestID} + // 用户槽位管理 + // 键格式: concurrency:user:{userID}(有序集合,成员为 requestID) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error GetUserConcurrency(ctx context.Context, userID int64) (int, error) - // Wait queue - uses counter with TTL set only on creation + // 等待队列计数(只在首次创建时设置 TTL) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) DecrementWaitCount(ctx context.Context, userID int64) error } diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index cd1dbcec..fd23ecb2 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -12,6 +12,8 @@ import ( "strconv" "strings" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" ) type CRSSyncService struct { @@ -193,7 +195,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input 
SyncFromCRSInput return nil, errors.New("username and password are required") } - client := &http.Client{Timeout: 20 * time.Second} + client, err := httpclient.GetClient(httpclient.Options{ + Timeout: 20 * time.Second, + }) + if err != nil { + client = &http.Client{Timeout: 20 * time.Second} + } adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password) if err != nil { diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go new file mode 100644 index 00000000..6d358c36 --- /dev/null +++ b/backend/internal/service/gateway_request.go @@ -0,0 +1,70 @@ +package service + +import ( + "encoding/json" + "fmt" +) + +// ParsedRequest 保存网关请求的预解析结果 +// +// 性能优化说明: +// 原实现在多个位置重复解析请求体(Handler、Service 各解析一次): +// 1. gateway_handler.go 解析获取 model 和 stream +// 2. gateway_service.go 再次解析获取 system、messages、metadata +// 3. GenerateSessionHash 又一次解析获取会话哈希所需字段 +// +// 新实现一次解析,多处复用: +// 1. 在 Handler 层统一调用 ParseGatewayRequest 一次性解析 +// 2. 将解析结果 ParsedRequest 传递给 Service 层 +// 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销 +type ParsedRequest struct { + Body []byte // 原始请求体(保留用于转发) + Model string // 请求的模型名称 + Stream bool // 是否为流式请求 + MetadataUserID string // metadata.user_id(用于会话亲和) + System any // system 字段内容 + Messages []any // messages 数组 + HasSystem bool // 是否包含 system 字段 +} + +// ParseGatewayRequest 解析网关请求体并返回结构化结果 +// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal +func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { + var req map[string]any + if err := json.Unmarshal(body, &req); err != nil { + return nil, err + } + + parsed := &ParsedRequest{ + Body: body, + } + + if rawModel, exists := req["model"]; exists { + model, ok := rawModel.(string) + if !ok { + return nil, fmt.Errorf("invalid model field type") + } + parsed.Model = model + } + if rawStream, exists := req["stream"]; exists { + stream, ok := rawStream.(bool) + if !ok { + return nil, fmt.Errorf("invalid stream field type") + } + parsed.Stream = stream + } + if metadata, ok := req["metadata"].(map[string]any); ok { + if userID, ok := metadata["user_id"].(string); ok { + parsed.MetadataUserID = userID + } + } + if system, ok := req["system"]; ok && system != nil { + parsed.HasSystem = true + parsed.System = system + } + if messages, ok := req["messages"].([]any); ok { + parsed.Messages = messages + } + + return parsed, nil +} diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go new file mode 100644 index 00000000..c921e0f6 --- /dev/null +++ b/backend/internal/service/gateway_request_test.go @@ -0,0 +1,38 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseGatewayRequest(t *testing.T) { + body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`) + parsed, err := ParseGatewayRequest(body) + require.NoError(t, err) + require.Equal(t, "claude-3-7-sonnet", parsed.Model) + require.True(t, parsed.Stream) + require.Equal(t, "session_123e4567-e89b-12d3-a456-426614174000", parsed.MetadataUserID) + require.True(t, parsed.HasSystem) + require.NotNil(t, parsed.System) + require.Len(t, parsed.Messages, 1) +} + +func TestParseGatewayRequest_SystemNull(t *testing.T) { + body := []byte(`{"model":"claude-3","system":null}`) + parsed, err := ParseGatewayRequest(body) + require.NoError(t, err) + require.False(t, 
parsed.HasSystem) +} + +func TestParseGatewayRequest_InvalidModelType(t *testing.T) { + body := []byte(`{"model":123}`) + _, err := ParseGatewayRequest(body) + require.Error(t, err) +} + +func TestParseGatewayRequest_InvalidStreamType(t *testing.T) { + body := []byte(`{"stream":"true"}`) + _, err := ParseGatewayRequest(body) + require.Error(t, err) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 50bfd161..41362662 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -19,7 +19,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" - "github.com/tidwall/gjson" "github.com/tidwall/sjson" "github.com/gin-gonic/gin" @@ -33,7 +32,10 @@ const ( // sseDataRe matches SSE data lines with optional whitespace after colon. // Some upstream APIs return non-standard "data:" without space (should be "data: "). -var sseDataRe = regexp.MustCompile(`^data:\s*`) +var ( + sseDataRe = regexp.MustCompile(`^data:\s*`) + sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) +) // allowedHeaders 白名单headers(参考CRS项目) var allowedHeaders = map[string]bool{ @@ -141,40 +143,36 @@ func NewGatewayService( } } -// GenerateSessionHash 从请求体计算粘性会话hash -func (s *GatewayService) GenerateSessionHash(body []byte) string { - var req map[string]any - if err := json.Unmarshal(body, &req); err != nil { +// GenerateSessionHash 从预解析请求计算粘性会话 hash +func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { + if parsed == nil { return "" } - // 1. 最高优先级:从metadata.user_id提取session_xxx - if metadata, ok := req["metadata"].(map[string]any); ok { - if userID, ok := metadata["user_id"].(string); ok { - re := regexp.MustCompile(`session_([a-f0-9-]{36})`) - if match := re.FindStringSubmatch(userID); len(match) > 1 { - return match[1] - } + // 1. 最高优先级:从 metadata.user_id 提取 session_xxx + if parsed.MetadataUserID != "" { + if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 { + return match[1] } } - // 2. 提取带cache_control: {type: "ephemeral"}的内容 - cacheableContent := s.extractCacheableContent(req) + // 2. 提取带 cache_control: {type: "ephemeral"} 的内容 + cacheableContent := s.extractCacheableContent(parsed) if cacheableContent != "" { return s.hashContent(cacheableContent) } - // 3. Fallback: 使用system内容 - if system := req["system"]; system != nil { - systemText := s.extractTextFromSystem(system) + // 3. Fallback: 使用 system 内容 + if parsed.System != nil { + systemText := s.extractTextFromSystem(parsed.System) if systemText != "" { return s.hashContent(systemText) } } - // 4. 最后fallback: 使用第一条消息 - if messages, ok := req["messages"].([]any); ok && len(messages) > 0 { - if firstMsg, ok := messages[0].(map[string]any); ok { + // 4. 
最后 fallback: 使用第一条消息 + if len(parsed.Messages) > 0 { + if firstMsg, ok := parsed.Messages[0].(map[string]any); ok { msgText := s.extractTextFromContent(firstMsg["content"]) if msgText != "" { return s.hashContent(msgText) @@ -185,36 +183,38 @@ func (s *GatewayService) GenerateSessionHash(body []byte) string { return "" } -func (s *GatewayService) extractCacheableContent(req map[string]any) string { - var content string +func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { + if parsed == nil { + return "" + } - // 检查system中的cacheable内容 - if system, ok := req["system"].([]any); ok { + var builder strings.Builder + + // 检查 system 中的 cacheable 内容 + if system, ok := parsed.System.([]any); ok { for _, part := range system { if partMap, ok := part.(map[string]any); ok { if cc, ok := partMap["cache_control"].(map[string]any); ok { if cc["type"] == "ephemeral" { if text, ok := partMap["text"].(string); ok { - content += text + builder.WriteString(text) } } } } } } + systemText := builder.String() - // 检查messages中的cacheable内容 - if messages, ok := req["messages"].([]any); ok { - for _, msg := range messages { - if msgMap, ok := msg.(map[string]any); ok { - if msgContent, ok := msgMap["content"].([]any); ok { - for _, part := range msgContent { - if partMap, ok := part.(map[string]any); ok { - if cc, ok := partMap["cache_control"].(map[string]any); ok { - if cc["type"] == "ephemeral" { - // 找到cacheable内容,提取第一条消息的文本 - return s.extractTextFromContent(msgMap["content"]) - } + // 检查 messages 中的 cacheable 内容 + for _, msg := range parsed.Messages { + if msgMap, ok := msg.(map[string]any); ok { + if msgContent, ok := msgMap["content"].([]any); ok { + for _, part := range msgContent { + if partMap, ok := part.(map[string]any); ok { + if cc, ok := partMap["cache_control"].(map[string]any); ok { + if cc["type"] == "ephemeral" { + return s.extractTextFromContent(msgMap["content"]) } } } @@ -223,7 +223,7 @@ func (s *GatewayService) extractCacheableContent(req map[string]any) string { } } - return content + return systemText } func (s *GatewayService) extractTextFromSystem(system any) string { @@ -588,19 +588,17 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool { } // Forward 转发请求到Claude API -func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { +func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) { startTime := time.Now() - - // 解析请求获取model和stream - var req struct { - Model string `json:"model"` - Stream bool `json:"stream"` - } - if err := json.Unmarshal(body, &req); err != nil { - return nil, fmt.Errorf("parse request: %w", err) + if parsed == nil { + return nil, fmt.Errorf("parse request: empty request") } - if !gjson.GetBytes(body, "system").Exists() { + body := parsed.Body + reqModel := parsed.Model + reqStream := parsed.Stream + + if !parsed.HasSystem { body, _ = sjson.SetBytes(body, "system", []any{ map[string]any{ "type": "text", @@ -613,13 +611,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 应用模型映射(仅对apikey类型账号) - originalModel := req.Model + originalModel := reqModel if account.Type == AccountTypeApiKey { - mappedModel := account.GetMappedModel(req.Model) - if mappedModel != req.Model { + mappedModel := account.GetMappedModel(reqModel) + if mappedModel != reqModel { // 替换请求体中的模型名 body = s.replaceModelInBody(body, mappedModel) - req.Model = mappedModel + 
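// Illustrative sketch (not part of this patch): the intended parse-once flow across the
// layers. The handler function below is hypothetical; the point is that the raw body is
// unmarshalled exactly once by ParseGatewayRequest and the resulting *ParsedRequest is
// reused for session affinity and for forwarding, instead of each step re-parsing JSON.
package service

import (
	"fmt"

	"github.com/gin-gonic/gin"
)

func handleMessagesSketch(c *gin.Context, svc *GatewayService, account *Account, body []byte) error {
	parsed, err := ParseGatewayRequest(body)
	if err != nil {
		return fmt.Errorf("parse request: %w", err)
	}

	// Both consumers share the same parsed view of the request.
	_ = svc.GenerateSessionHash(parsed) // e.g. feeds sticky-session account selection
	_, err = svc.Forward(c.Request.Context(), c, account, parsed)
	return err
}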
reqModel = mappedModel log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name) } } @@ -640,7 +638,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A var resp *http.Response for attempt := 1; attempt <= maxRetries; attempt++ { // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType) + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel) if err != nil { return nil, err } @@ -692,8 +690,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 处理正常响应 var usage *ClaudeUsage var firstTokenMs *int - if req.Stream { - streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, req.Model) + if reqStream { + streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel) if err != nil { if err.Error() == "have error in stream" { return nil, &UpstreamFailoverError{ @@ -705,7 +703,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A usage = streamResult.usage firstTokenMs = streamResult.firstTokenMs } else { - usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, req.Model) + usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel) if err != nil { return nil, err } @@ -715,13 +713,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, // 使用原始模型用于计费和日志 - Stream: req.Stream, + Stream: reqStream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, }, nil } -func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) { +func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { // 确定目标URL targetURL := claudeAPIURL if account.Type == AccountTypeApiKey { @@ -787,7 +785,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // 处理anthropic-beta header(OAuth账号需要特殊处理) if tokenType == "oauth" { - req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta"))) + req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) } return req, nil @@ -795,7 +793,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // getBetaHeader 处理anthropic-beta header // 对于OAuth账号,需要确保包含oauth-2025-04-20 -func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) string { +func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string { // 如果客户端传了anthropic-beta if clientBetaHeader != "" { // 已包含oauth beta则直接返回 @@ -832,15 +830,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str } // 客户端没传,根据模型生成 - var modelID string - var reqMap map[string]any - if json.Unmarshal(body, &reqMap) == nil { - if m, ok := reqMap["model"].(string); ok { - modelID = m - } - } - - // haiku模型不需要claude-code beta + // haiku 模型不需要 claude-code beta if strings.Contains(strings.ToLower(modelID), "haiku") { return claude.HaikuBetaHeader } @@ -1248,13 +1238,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu log.Printf("Increment 
subscription usage failed: %v", err) } // 异步更新订阅缓存 - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost); err != nil { - log.Printf("Update subscription cache failed: %v", err) - } - }() + s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) } } else { // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) @@ -1263,13 +1247,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu log.Printf("Deduct balance failed: %v", err) } // 异步更新余额缓存 - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost); err != nil { - log.Printf("Update balance cache failed: %v", err) - } - }() + s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) } } @@ -1281,7 +1259,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu // ForwardCountTokens 转发 count_tokens 请求到上游 API // 特点:不记录使用量、仅支持非流式响应 -func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error { +func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { + if parsed == nil { + s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return fmt.Errorf("parse request: empty request") + } + + body := parsed.Body + reqModel := parsed.Model + // Antigravity 账户不支持 count_tokens 转发,返回估算值 // 参考 Antigravity-Manager 和 proxycast 实现 if account.Platform == PlatformAntigravity { @@ -1291,14 +1277,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, // 应用模型映射(仅对 apikey 类型账号) if account.Type == AccountTypeApiKey { - var req struct { - Model string `json:"model"` - } - if err := json.Unmarshal(body, &req); err == nil && req.Model != "" { - mappedModel := account.GetMappedModel(req.Model) - if mappedModel != req.Model { + if reqModel != "" { + mappedModel := account.GetMappedModel(reqModel) + if mappedModel != reqModel { body = s.replaceModelInBody(body, mappedModel) - log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", req.Model, mappedModel, account.Name) + reqModel = mappedModel + log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name) } } } @@ -1311,7 +1295,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 构建上游请求 - upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType) + upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel) if err != nil { s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") return err @@ -1363,7 +1347,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // buildCountTokensRequest 构建 count_tokens 上游请求 -func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) { +func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { // 确定目标 URL targetURL := claudeAPICountTokensURL if 
account.Type == AccountTypeApiKey { @@ -1424,7 +1408,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { - req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta"))) + req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) } return req, nil diff --git a/backend/internal/service/gateway_service_benchmark_test.go b/backend/internal/service/gateway_service_benchmark_test.go new file mode 100644 index 00000000..f15a85d6 --- /dev/null +++ b/backend/internal/service/gateway_service_benchmark_test.go @@ -0,0 +1,50 @@ +package service + +import ( + "strconv" + "testing" +) + +var benchmarkStringSink string + +// BenchmarkGenerateSessionHash_Metadata 关注 JSON 解析与正则匹配开销。 +func BenchmarkGenerateSessionHash_Metadata(b *testing.B) { + svc := &GatewayService{} + body := []byte(`{"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"messages":[{"content":"hello"}]}`) + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + parsed, err := ParseGatewayRequest(body) + if err != nil { + b.Fatalf("解析请求失败: %v", err) + } + benchmarkStringSink = svc.GenerateSessionHash(parsed) + } +} + +// BenchmarkExtractCacheableContent_System 关注字符串拼接路径的性能。 +func BenchmarkExtractCacheableContent_System(b *testing.B) { + svc := &GatewayService{} + req := buildSystemCacheableRequest(12) + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + benchmarkStringSink = svc.extractCacheableContent(req) + } +} + +func buildSystemCacheableRequest(parts int) *ParsedRequest { + systemParts := make([]any, 0, parts) + for i := 0; i < parts; i++ { + systemParts = append(systemParts, map[string]any{ + "text": "system_part_" + strconv.Itoa(i), + "cache_control": map[string]any{ + "type": "ephemeral", + }, + }) + } + return &ParsedRequest{ + System: systemParts, + HasSystem: true, + } +} diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 6ccbf8ec..34958541 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -921,7 +921,10 @@ func sleepGeminiBackoff(attempt int) { time.Sleep(sleepFor) } -var sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`) +var ( + sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`) + retryInRegex = regexp.MustCompile(`Please retry in ([0-9.]+)s`) +) func sanitizeUpstreamErrorMessage(msg string) string { if msg == "" { @@ -1925,7 +1928,6 @@ func ParseGeminiRateLimitResetTime(body []byte) *int64 { } // Match "Please retry in Xs" - retryInRegex := regexp.MustCompile(`Please retry in ([0-9.]+)s`) matches := retryInRegex.FindStringSubmatch(string(body)) if len(matches) == 2 { if dur, err := time.ParseDuration(matches[1] + "s"); err == nil { diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index 36257667..e4bda5f8 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -7,13 +7,13 @@ import ( "fmt" "io" "net/http" - "net/url" "strconv" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" ) type GeminiOAuthService struct { @@ 
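// Illustrative sketch (not part of this patch): why retryInRegex and the pricing
// patterns were hoisted to package-level vars. regexp.MustCompile re-parses the
// pattern and allocates on every call; compiling once at package init removes that
// cost from the hot path. Benchmark names below are hypothetical.
package service

import (
	"regexp"
	"testing"
)

var hoistedRetryInRegex = regexp.MustCompile(`Please retry in ([0-9.]+)s`)

func BenchmarkRetryInRegex_CompilePerCall(b *testing.B) {
	body := "Please retry in 2.5s"
	for i := 0; i < b.N; i++ {
		re := regexp.MustCompile(`Please retry in ([0-9.]+)s`)
		_ = re.FindStringSubmatch(body)
	}
}

func BenchmarkRetryInRegex_PackageLevel(b *testing.B) {
	body := "Please retry in 2.5s"
	for i := 0; i < b.N; i++ {
		_ = hoistedRetryInRegex.FindStringSubmatch(body)
	}
}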
diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go
index 36257667..e4bda5f8 100644
--- a/backend/internal/service/gemini_oauth_service.go
+++ b/backend/internal/service/gemini_oauth_service.go
@@ -7,13 +7,13 @@ import (
 	"fmt"
 	"io"
 	"net/http"
-	"net/url"
 	"strconv"
 	"strings"
 	"time"
 
 	"github.com/Wei-Shaw/sub2api/internal/config"
 	"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
+	"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
 )
 
 type GeminiOAuthService struct {
@@ -497,11 +497,12 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
 	req.Header.Set("Authorization", "Bearer "+accessToken)
 	req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
 
-	client := &http.Client{Timeout: 30 * time.Second}
-	if strings.TrimSpace(proxyURL) != "" {
-		if proxyURLParsed, err := url.Parse(strings.TrimSpace(proxyURL)); err == nil {
-			client.Transport = &http.Transport{Proxy: http.ProxyURL(proxyURLParsed)}
-		}
+	client, err := httpclient.GetClient(httpclient.Options{
+		ProxyURL: strings.TrimSpace(proxyURL),
+		Timeout:  30 * time.Second,
+	})
+	if err != nil {
+		client = &http.Client{Timeout: 30 * time.Second}
 	}
 
 	resp, err := client.Do(req)
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 79801b29..aa844554 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -768,20 +768,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
 	if isSubscriptionBilling {
 		if cost.TotalCost > 0 {
 			_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
-			go func() {
-				cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-				defer cancel()
-				_ = s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost)
-			}()
+			s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
 		}
 	} else {
 		if cost.ActualCost > 0 {
 			_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
-			go func() {
-				cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-				defer cancel()
-				_ = s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost)
-			}()
+			s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
 		}
 	}
 
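
Note: fetchProjectIDFromResourceManager now takes a pooled client from httpclient.GetClient and only falls back to a fresh http.Client when pooling fails. Likewise, the Queue* calls above replace the per-request go func(){...}() cache writes with a hand-off to a background worker pool, which bounds goroutine growth under load and gives shutdown a single place to drain pending writes. The actual BillingCacheService is not part of this excerpt; the following is only a minimal sketch of the pattern, with hypothetical names (billingQueue, billingTask) and an arbitrary buffer size, and it models only the balance-deduction path:

package service

import (
	"context"
	"sync"
	"time"
)

type billingTask struct {
	userID int64
	amount float64
}

type billingQueue struct {
	tasks chan billingTask
	done  chan struct{}
}

// newBillingQueue starts `workers` goroutines that apply queued tasks with a
// bounded per-task timeout, mirroring the 5s timeout used by the old goroutines.
func newBillingQueue(workers, buffer int, apply func(context.Context, billingTask)) *billingQueue {
	q := &billingQueue{tasks: make(chan billingTask, buffer), done: make(chan struct{})}
	var wg sync.WaitGroup
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for t := range q.tasks {
				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
				apply(ctx, t)
				cancel()
			}
		}()
	}
	go func() { wg.Wait(); close(q.done) }()
	return q
}

// QueueDeductBalance enqueues without blocking the request path; if the buffer
// is full the write is dropped instead of stalling the gateway hot path.
func (q *billingQueue) QueueDeductBalance(userID int64, amount float64) {
	select {
	case q.tasks <- billingTask{userID: userID, amount: amount}:
	default:
	}
}

// Stop closes the queue and waits for in-flight tasks to finish.
func (q *billingQueue) Stop() {
	close(q.tasks)
	<-q.done
}
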
 	if withoutDate != model {
-		if matches := basePattern.FindStringSubmatch(withoutDate); len(matches) > 1 {
+		if matches := openAIModelBasePattern.FindStringSubmatch(withoutDate); len(matches) > 1 {
 			addVariant(matches[1])
 		}
 	}
diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go
index 0df8a0de..f57e90eb 100644
--- a/backend/internal/service/usage_service.go
+++ b/backend/internal/service/usage_service.go
@@ -186,22 +186,40 @@ func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, sta
 
 // GetStatsByAccount 获取账号的使用统计
 func (s *UsageService) GetStatsByAccount(ctx context.Context, accountID int64, startTime, endTime time.Time) (*UsageStats, error) {
-	logs, _, err := s.usageRepo.ListByAccountAndTimeRange(ctx, accountID, startTime, endTime)
+	stats, err := s.usageRepo.GetAccountStatsAggregated(ctx, accountID, startTime, endTime)
 	if err != nil {
-		return nil, fmt.Errorf("list usage logs: %w", err)
+		return nil, fmt.Errorf("get account stats: %w", err)
 	}
 
-	return s.calculateStats(logs), nil
+	return &UsageStats{
+		TotalRequests:     stats.TotalRequests,
+		TotalInputTokens:  stats.TotalInputTokens,
+		TotalOutputTokens: stats.TotalOutputTokens,
+		TotalCacheTokens:  stats.TotalCacheTokens,
+		TotalTokens:       stats.TotalTokens,
+		TotalCost:         stats.TotalCost,
+		TotalActualCost:   stats.TotalActualCost,
+		AverageDurationMs: stats.AverageDurationMs,
+	}, nil
 }
 
 // GetStatsByModel 获取模型的使用统计
 func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, startTime, endTime time.Time) (*UsageStats, error) {
-	logs, _, err := s.usageRepo.ListByModelAndTimeRange(ctx, modelName, startTime, endTime)
+	stats, err := s.usageRepo.GetModelStatsAggregated(ctx, modelName, startTime, endTime)
 	if err != nil {
-		return nil, fmt.Errorf("list usage logs: %w", err)
+		return nil, fmt.Errorf("get model stats: %w", err)
 	}
 
-	return s.calculateStats(logs), nil
+	return &UsageStats{
+		TotalRequests:     stats.TotalRequests,
+		TotalInputTokens:  stats.TotalInputTokens,
+		TotalOutputTokens: stats.TotalOutputTokens,
+		TotalCacheTokens:  stats.TotalCacheTokens,
+		TotalTokens:       stats.TotalTokens,
+		TotalCost:         stats.TotalCost,
+		TotalActualCost:   stats.TotalActualCost,
+		AverageDurationMs: stats.AverageDurationMs,
+	}, nil
 }
 
 // GetDailyStats 获取每日使用统计(最近N天)
@@ -209,54 +227,12 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
 	endTime := time.Now()
 	startTime := endTime.AddDate(0, 0, -days)
 
-	logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime)
+	stats, err := s.usageRepo.GetDailyStatsAggregated(ctx, userID, startTime, endTime)
 	if err != nil {
-		return nil, fmt.Errorf("list usage logs: %w", err)
+		return nil, fmt.Errorf("get daily stats: %w", err)
 	}
 
-	// 按日期分组统计
-	dailyStats := make(map[string]*UsageStats)
-	for _, log := range logs {
-		dateKey := log.CreatedAt.Format("2006-01-02")
-		if _, exists := dailyStats[dateKey]; !exists {
-			dailyStats[dateKey] = &UsageStats{}
-		}
-
-		stats := dailyStats[dateKey]
-		stats.TotalRequests++
-		stats.TotalInputTokens += int64(log.InputTokens)
-		stats.TotalOutputTokens += int64(log.OutputTokens)
-		stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
-		stats.TotalTokens += int64(log.TotalTokens())
-		stats.TotalCost += log.TotalCost
-		stats.TotalActualCost += log.ActualCost
-
-		if log.DurationMs != nil {
-			stats.AverageDurationMs += float64(*log.DurationMs)
-		}
-	}
-
-	// 计算平均值并转换为数组
-	result := make([]map[string]any, 0, len(dailyStats))
-	for date, stats := range dailyStats {
-		if stats.TotalRequests > 0 {
-			stats.AverageDurationMs /= float64(stats.TotalRequests)
-		}
-
-		result = append(result, map[string]any{
-			"date":                date,
-			"total_requests":      stats.TotalRequests,
-			"total_input_tokens":  stats.TotalInputTokens,
-			"total_output_tokens": stats.TotalOutputTokens,
-			"total_cache_tokens":  stats.TotalCacheTokens,
-			"total_tokens":        stats.TotalTokens,
-			"total_cost":          stats.TotalCost,
-			"total_actual_cost":   stats.TotalActualCost,
-			"average_duration_ms": stats.AverageDurationMs,
-		})
-	}
-
-	return result, nil
+	return stats, nil
 }
 
 // calculateStats 计算统计数据
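
Note: GetStatsByAccount, GetStatsByModel and GetDailyStats now push the aggregation into SQL instead of loading every usage_logs row and summing in Go; the migration that follows adds composite indexes matching these filter-plus-time-range queries. As a rough illustration of the per-account aggregate (not the repository's actual query; the token and cost column names are assumptions), something of this shape is what the (account_id, created_at) index is meant to serve:

package repository

import (
	"context"
	"database/sql"
	"time"
)

// Hypothetical shape of the per-account aggregate; column names other than
// account_id and created_at are guesses for illustration only.
const accountStatsSQL = `
SELECT COUNT(*),
       COALESCE(SUM(input_tokens), 0),
       COALESCE(SUM(output_tokens), 0),
       COALESCE(SUM(total_cost), 0)
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3`

func queryAccountStats(ctx context.Context, db *sql.DB, accountID int64, start, end time.Time) (requests, inputTokens, outputTokens int64, totalCost float64, err error) {
	err = db.QueryRowContext(ctx, accountStatsSQL, accountID, start, end).
		Scan(&requests, &inputTokens, &outputTokens, &totalCost)
	return
}
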
diff --git a/backend/migrations/010_add_usage_logs_aggregated_indexes.sql b/backend/migrations/010_add_usage_logs_aggregated_indexes.sql
new file mode 100644
index 00000000..ab2dbbc1
--- /dev/null
+++ b/backend/migrations/010_add_usage_logs_aggregated_indexes.sql
@@ -0,0 +1,4 @@
+-- Composite indexes to support aggregated usage queries
+CREATE INDEX IF NOT EXISTS idx_usage_logs_account_created_at ON usage_logs(account_id, created_at);
+CREATE INDEX IF NOT EXISTS idx_usage_logs_api_key_created_at ON usage_logs(api_key_id, created_at);
+CREATE INDEX IF NOT EXISTS idx_usage_logs_model_created_at ON usage_logs(model, created_at);
diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml
index fcaa7b7c..8db4cbc9 100644
--- a/deploy/config.example.yaml
+++ b/deploy/config.example.yaml
@@ -21,6 +21,22 @@ server:
   #   - simple: Hides SaaS features and skips billing/balance checks
   run_mode: "standard"
 
+# =============================================================================
+# Gateway Configuration
+# =============================================================================
+gateway:
+  # Timeout for waiting on upstream response headers (seconds)
+  response_header_timeout: 300
+  # Maximum request body size in bytes (default 100MB)
+  max_body_size: 104857600
+  # HTTP upstream connection pool (defaults sized for HTTP/2 + multi-proxy usage)
+  max_idle_conns: 240
+  max_idle_conns_per_host: 120
+  max_conns_per_host: 240
+  idle_conn_timeout_seconds: 300
+  # Concurrency slot TTL (minutes)
+  concurrency_slot_ttl_minutes: 15
+
 # =============================================================================
 # Database Configuration (PostgreSQL)
 # =============================================================================
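
Note: the pool-related keys above correspond naturally to fields on Go's net/http Transport. The sketch below shows that correspondence using the example values; the exact wiring inside the gateway's pooled client is not shown in this excerpt, so treat the one-to-one mapping as an assumption. max_body_size and concurrency_slot_ttl_minutes are not Transport settings and are omitted here.

package main

import (
	"net/http"
	"time"
)

// newUpstreamClient builds a client from the example gateway values above.
// The field-to-key mapping is illustrative; the real pool code may differ.
func newUpstreamClient() *http.Client {
	transport := &http.Transport{
		ForceAttemptHTTP2:     true,              // pool sized with HTTP/2 multiplexing in mind
		MaxIdleConns:          240,               // gateway.max_idle_conns
		MaxIdleConnsPerHost:   120,               // gateway.max_idle_conns_per_host
		MaxConnsPerHost:       240,               // gateway.max_conns_per_host
		IdleConnTimeout:       300 * time.Second, // gateway.idle_conn_timeout_seconds
		ResponseHeaderTimeout: 300 * time.Second, // gateway.response_header_timeout
	}
	return &http.Client{Transport: transport}
}
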