perf(后端): 完成性能优化与连接池配置

新增 DB/Redis 连接池配置与校验,并补充单测

网关请求体大小限制与 413 处理

HTTP/req 客户端池化并调整上游连接池默认值

并发槽位改为 ZSET+Lua 与指数退避

用量统计改 SQL 聚合并新增索引迁移

计费缓存写入改工作池并补测试/基准

测试: 在 backend/ 下运行 go test ./...
This commit is contained in:
yangjianbo
2025-12-31 08:50:12 +08:00
parent 5376786694
commit 7efa8b54c4
53 changed files with 1805 additions and 449 deletions

View File

@@ -67,6 +67,7 @@ func provideCleanup(
tokenRefresh *service.TokenRefreshService, tokenRefresh *service.TokenRefreshService,
pricing *service.PricingService, pricing *service.PricingService,
emailQueue *service.EmailQueueService, emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService,
oauth *service.OAuthService, oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService, openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService, geminiOAuth *service.GeminiOAuthService,
@@ -94,6 +95,10 @@ func provideCleanup(
emailQueue.Stop() emailQueue.Stop()
return nil return nil
}}, }},
{"BillingCacheService", func() error {
billingCache.Stop()
return nil
}},
{"OAuthService", func() error { {"OAuthService", func() error {
oauth.Stop() oauth.Stop()
return nil return nil

View File

@@ -39,11 +39,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
sqlDB, err := infrastructure.ProvideSQLDB(client) db, err := infrastructure.ProvideSQLDB(client)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userRepository := repository.NewUserRepository(client, sqlDB) userRepository := repository.NewUserRepository(client, db)
settingRepository := repository.NewSettingRepository(client) settingRepository := repository.NewSettingRepository(client)
settingService := service.NewSettingService(settingRepository, configConfig) settingService := service.NewSettingService(settingRepository, configConfig)
redisClient := infrastructure.ProvideRedis(configConfig) redisClient := infrastructure.ProvideRedis(configConfig)
@@ -57,12 +57,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
authHandler := handler.NewAuthHandler(configConfig, authService, userService) authHandler := handler.NewAuthHandler(configConfig, authService, userService)
userHandler := handler.NewUserHandler(userService) userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewApiKeyRepository(client) apiKeyRepository := repository.NewApiKeyRepository(client)
groupRepository := repository.NewGroupRepository(client, sqlDB) groupRepository := repository.NewGroupRepository(client, db)
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
apiKeyCache := repository.NewApiKeyCache(redisClient) apiKeyCache := repository.NewApiKeyCache(redisClient)
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, sqlDB) usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository) usageService := service.NewUsageService(usageLogRepository, userRepository)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(client) redeemCodeRepository := repository.NewRedeemCodeRepository(client)
@@ -75,8 +75,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
dashboardService := service.NewDashboardService(usageLogRepository) dashboardService := service.NewDashboardService(usageLogRepository)
dashboardHandler := admin.NewDashboardHandler(dashboardService) dashboardHandler := admin.NewDashboardHandler(dashboardService)
accountRepository := repository.NewAccountRepository(client, sqlDB) accountRepository := repository.NewAccountRepository(client, db)
proxyRepository := repository.NewProxyRepository(client, sqlDB) proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber() proxyExitInfoProber := repository.NewProxyExitInfoProber()
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber) adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
adminUserHandler := admin.NewUserHandler(adminService) adminUserHandler := admin.NewUserHandler(adminService)
@@ -95,7 +95,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig) httpUpstream := repository.NewHTTPUpstream(configConfig)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, httpUpstream) accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, httpUpstream)
concurrencyCache := repository.NewConcurrencyCache(redisClient) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.NewConcurrencyService(concurrencyCache) concurrencyService := service.NewConcurrencyService(concurrencyCache)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
@@ -142,7 +142,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
httpServer := server.ProvideHTTPServer(configConfig, engine) httpServer := server.ProvideHTTPServer(configConfig, engine)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
antigravityQuotaRefresher := service.ProvideAntigravityQuotaRefresher(accountRepository, proxyRepository, antigravityOAuthService, configConfig) antigravityQuotaRefresher := service.ProvideAntigravityQuotaRefresher(accountRepository, proxyRepository, antigravityOAuthService, configConfig)
v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, antigravityQuotaRefresher) v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, antigravityQuotaRefresher)
application := &Application{ application := &Application{
Server: httpServer, Server: httpServer,
Cleanup: v, Cleanup: v,
@@ -170,6 +170,7 @@ func provideCleanup(
tokenRefresh *service.TokenRefreshService, tokenRefresh *service.TokenRefreshService,
pricing *service.PricingService, pricing *service.PricingService,
emailQueue *service.EmailQueueService, emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService,
oauth *service.OAuthService, oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService, openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService, geminiOAuth *service.GeminiOAuthService,
@@ -196,6 +197,10 @@ func provideCleanup(
emailQueue.Stop() emailQueue.Stop()
return nil return nil
}}, }},
{"BillingCacheService", func() error {
billingCache.Stop()
return nil
}},
{"OAuthService", func() error { {"OAuthService", func() error {
oauth.Stop() oauth.Stop()
return nil return nil

View File

@@ -79,12 +79,29 @@ type GatewayConfig struct {
// 等待上游响应头的超时时间0表示无超时 // 等待上游响应头的超时时间0表示无超时
// 注意:这不影响流式数据传输,只控制等待响应头的时间 // 注意:这不影响流式数据传输,只控制等待响应头的时间
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"` ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
// 请求体最大字节数,用于网关请求体大小限制
MaxBodySize int64 `mapstructure:"max_body_size"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
MaxIdleConns int `mapstructure:"max_idle_conns"`
// MaxIdleConnsPerHost: 每个主机的最大空闲连接数(关键参数,影响连接复用率)
MaxIdleConnsPerHost int `mapstructure:"max_idle_conns_per_host"`
// MaxConnsPerHost: 每个主机的最大连接数(包括活跃+空闲0表示无限制
MaxConnsPerHost int `mapstructure:"max_conns_per_host"`
// IdleConnTimeoutSeconds: 空闲连接超时时间(秒)
IdleConnTimeoutSeconds int `mapstructure:"idle_conn_timeout_seconds"`
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
} }
func (s *ServerConfig) Address() string { func (s *ServerConfig) Address() string {
return fmt.Sprintf("%s:%d", s.Host, s.Port) return fmt.Sprintf("%s:%d", s.Host, s.Port)
} }
// DatabaseConfig 数据库连接配置
// 性能优化:新增连接池参数,避免频繁创建/销毁连接
type DatabaseConfig struct { type DatabaseConfig struct {
Host string `mapstructure:"host"` Host string `mapstructure:"host"`
Port int `mapstructure:"port"` Port int `mapstructure:"port"`
@@ -92,6 +109,15 @@ type DatabaseConfig struct {
Password string `mapstructure:"password"` Password string `mapstructure:"password"`
DBName string `mapstructure:"dbname"` DBName string `mapstructure:"dbname"`
SSLMode string `mapstructure:"sslmode"` SSLMode string `mapstructure:"sslmode"`
// 连接池配置(性能优化:可配置化连接池参数)
// MaxOpenConns: 最大打开连接数,控制数据库连接上限,防止资源耗尽
MaxOpenConns int `mapstructure:"max_open_conns"`
// MaxIdleConns: 最大空闲连接数,保持热连接减少建连延迟
MaxIdleConns int `mapstructure:"max_idle_conns"`
// ConnMaxLifetimeMinutes: 连接最大存活时间,防止长连接导致的资源泄漏
ConnMaxLifetimeMinutes int `mapstructure:"conn_max_lifetime_minutes"`
// ConnMaxIdleTimeMinutes: 空闲连接最大存活时间,及时释放不活跃连接
ConnMaxIdleTimeMinutes int `mapstructure:"conn_max_idle_time_minutes"`
} }
func (d *DatabaseConfig) DSN() string { func (d *DatabaseConfig) DSN() string {
@@ -112,11 +138,24 @@ func (d *DatabaseConfig) DSNWithTimezone(tz string) string {
) )
} }
// RedisConfig Redis 连接配置
// 性能优化:新增连接池和超时参数,提升高并发场景下的吞吐量
type RedisConfig struct { type RedisConfig struct {
Host string `mapstructure:"host"` Host string `mapstructure:"host"`
Port int `mapstructure:"port"` Port int `mapstructure:"port"`
Password string `mapstructure:"password"` Password string `mapstructure:"password"`
DB int `mapstructure:"db"` DB int `mapstructure:"db"`
// 连接池与超时配置(性能优化:可配置化连接池参数)
// DialTimeoutSeconds: 建立连接超时,防止慢连接阻塞
DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"`
// ReadTimeoutSeconds: 读取超时,避免慢查询阻塞连接池
ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
// WriteTimeoutSeconds: 写入超时,避免慢写入阻塞连接池
WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
// PoolSize: 连接池大小,控制最大并发连接数
PoolSize int `mapstructure:"pool_size"`
// MinIdleConns: 最小空闲连接数,保持热连接减少冷启动延迟
MinIdleConns int `mapstructure:"min_idle_conns"`
} }
func (r *RedisConfig) Address() string { func (r *RedisConfig) Address() string {
@@ -203,12 +242,21 @@ func setDefaults() {
viper.SetDefault("database.password", "postgres") viper.SetDefault("database.password", "postgres")
viper.SetDefault("database.dbname", "sub2api") viper.SetDefault("database.dbname", "sub2api")
viper.SetDefault("database.sslmode", "disable") viper.SetDefault("database.sslmode", "disable")
viper.SetDefault("database.max_open_conns", 50)
viper.SetDefault("database.max_idle_conns", 10)
viper.SetDefault("database.conn_max_lifetime_minutes", 30)
viper.SetDefault("database.conn_max_idle_time_minutes", 5)
// Redis // Redis
viper.SetDefault("redis.host", "localhost") viper.SetDefault("redis.host", "localhost")
viper.SetDefault("redis.port", 6379) viper.SetDefault("redis.port", 6379)
viper.SetDefault("redis.password", "") viper.SetDefault("redis.password", "")
viper.SetDefault("redis.db", 0) viper.SetDefault("redis.db", 0)
viper.SetDefault("redis.dial_timeout_seconds", 5)
viper.SetDefault("redis.read_timeout_seconds", 3)
viper.SetDefault("redis.write_timeout_seconds", 3)
viper.SetDefault("redis.pool_size", 128)
viper.SetDefault("redis.min_idle_conns", 10)
// JWT // JWT
viper.SetDefault("jwt.secret", "change-me-in-production") viper.SetDefault("jwt.secret", "change-me-in-production")
@@ -240,6 +288,13 @@ func setDefaults() {
// Gateway // Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久 viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数HTTP/2 场景默认)
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接HTTP/2 场景默认)
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数含活跃HTTP/2 场景默认)
viper.SetDefault("gateway.idle_conn_timeout_seconds", 300) // 空闲连接超时(秒)
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求)
// TokenRefresh // TokenRefresh
viper.SetDefault("token_refresh.enabled", true) viper.SetDefault("token_refresh.enabled", true)
@@ -263,6 +318,57 @@ func (c *Config) Validate() error {
if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" { if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" {
return fmt.Errorf("jwt.secret must be changed in production") return fmt.Errorf("jwt.secret must be changed in production")
} }
if c.Database.MaxOpenConns <= 0 {
return fmt.Errorf("database.max_open_conns must be positive")
}
if c.Database.MaxIdleConns < 0 {
return fmt.Errorf("database.max_idle_conns must be non-negative")
}
if c.Database.MaxIdleConns > c.Database.MaxOpenConns {
return fmt.Errorf("database.max_idle_conns cannot exceed database.max_open_conns")
}
if c.Database.ConnMaxLifetimeMinutes < 0 {
return fmt.Errorf("database.conn_max_lifetime_minutes must be non-negative")
}
if c.Database.ConnMaxIdleTimeMinutes < 0 {
return fmt.Errorf("database.conn_max_idle_time_minutes must be non-negative")
}
if c.Redis.DialTimeoutSeconds <= 0 {
return fmt.Errorf("redis.dial_timeout_seconds must be positive")
}
if c.Redis.ReadTimeoutSeconds <= 0 {
return fmt.Errorf("redis.read_timeout_seconds must be positive")
}
if c.Redis.WriteTimeoutSeconds <= 0 {
return fmt.Errorf("redis.write_timeout_seconds must be positive")
}
if c.Redis.PoolSize <= 0 {
return fmt.Errorf("redis.pool_size must be positive")
}
if c.Redis.MinIdleConns < 0 {
return fmt.Errorf("redis.min_idle_conns must be non-negative")
}
if c.Redis.MinIdleConns > c.Redis.PoolSize {
return fmt.Errorf("redis.min_idle_conns cannot exceed redis.pool_size")
}
if c.Gateway.MaxBodySize <= 0 {
return fmt.Errorf("gateway.max_body_size must be positive")
}
if c.Gateway.MaxIdleConns <= 0 {
return fmt.Errorf("gateway.max_idle_conns must be positive")
}
if c.Gateway.MaxIdleConnsPerHost <= 0 {
return fmt.Errorf("gateway.max_idle_conns_per_host must be positive")
}
if c.Gateway.MaxConnsPerHost < 0 {
return fmt.Errorf("gateway.max_conns_per_host must be non-negative")
}
if c.Gateway.IdleConnTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive")
}
if c.Gateway.ConcurrencySlotTTLMinutes <= 0 {
return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive")
}
return nil return nil
} }

View File

@@ -67,6 +67,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 读取请求体 // 读取请求体
body, err := io.ReadAll(c.Request.Body) body, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return return
} }
@@ -76,15 +80,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 解析请求获取模型名和stream parsedReq, err := service.ParseGatewayRequest(body)
var req struct { if err != nil {
Model string `json:"model"`
Stream bool `json:"stream"`
}
if err := json.Unmarshal(body, &req); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return return
} }
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
// Track if we've started streaming (for error handling) // Track if we've started streaming (for error handling)
streamStarted := false streamStarted := false
@@ -106,7 +108,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. 首先获取用户并发槽位 // 1. 首先获取用户并发槽位
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, req.Stream, &streamStarted) userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil { if err != nil {
log.Printf("User concurrency acquire failed: %v", err) log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted) h.handleConcurrencyError(c, err, "user", streamStarted)
@@ -124,7 +126,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 计算粘性会话hash // 计算粘性会话hash
sessionHash := h.gatewayService.GenerateSessionHash(body) sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context否则使用分组平台 // 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context否则使用分组平台
platform := "" platform := ""
@@ -141,7 +143,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus := 0 lastFailoverStatus := 0
for { for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs) account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -153,16 +155,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 检查预热请求拦截(在账号选择后、转发前检查) // 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream { if reqStream {
sendMockWarmupStream(c, req.Model) sendMockWarmupStream(c, reqModel)
} else { } else {
sendMockWarmupResponse(c, req.Model) sendMockWarmupResponse(c, reqModel)
} }
return return
} }
// 3. 获取账号并发槽位 // 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted) accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil { if err != nil {
log.Printf("Account concurrency acquire failed: %v", err) log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
@@ -172,7 +174,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
var result *service.ForwardResult var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity { if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, req.Model, "generateContent", req.Stream, body) result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body)
} else { } else {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
} }
@@ -223,7 +225,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for { for {
// 选择支持该模型的账号 // 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs) account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -235,16 +237,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 检查预热请求拦截(在账号选择后、转发前检查) // 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream { if reqStream {
sendMockWarmupStream(c, req.Model) sendMockWarmupStream(c, reqModel)
} else { } else {
sendMockWarmupResponse(c, req.Model) sendMockWarmupResponse(c, reqModel)
} }
return return
} }
// 3. 获取账号并发槽位 // 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted) accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil { if err != nil {
log.Printf("Account concurrency acquire failed: %v", err) log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
@@ -256,7 +258,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if account.Platform == service.PlatformAntigravity { if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body) result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
} else { } else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body) result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
} }
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
@@ -496,6 +498,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 读取请求体 // 读取请求体
body, err := io.ReadAll(c.Request.Body) body, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return return
} }
@@ -505,11 +511,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return return
} }
// 解析请求获取模型名 parsedReq, err := service.ParseGatewayRequest(body)
var req struct { if err != nil {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return return
} }
@@ -525,17 +528,17 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
} }
// 计算粘性会话 hash // 计算粘性会话 hash
sessionHash := h.gatewayService.GenerateSessionHash(body) sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 选择支持该模型的账号 // 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil { if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return return
} }
// 转发请求(不记录使用量) // 转发请求(不记录使用量)
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil { if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil {
log.Printf("Forward count_tokens request failed: %v", err) log.Printf("Forward count_tokens request failed: %v", err)
// 错误响应已在 ForwardCountTokens 中处理 // 错误响应已在 ForwardCountTokens 中处理
return return

View File

@@ -3,6 +3,7 @@ package handler
import ( import (
"context" "context"
"fmt" "fmt"
"math/rand"
"net/http" "net/http"
"time" "time"
@@ -11,11 +12,28 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// 并发槽位等待相关常量
//
// 性能优化说明:
// 原实现使用固定间隔100ms轮询并发槽位存在以下问题
// 1. 高并发时频繁轮询增加 Redis 压力
// 2. 固定间隔可能导致多个请求同时重试(惊群效应)
//
// 新实现使用指数退避 + 抖动算法:
// 1. 初始退避 100ms每次乘以 1.5,最大 2s
// 2. 添加 ±20% 的随机抖动,分散重试时间点
// 3. 减少 Redis 压力,避免惊群效应
const ( const (
// maxConcurrencyWait is the maximum time to wait for a concurrency slot // maxConcurrencyWait 等待并发槽位的最大时间
maxConcurrencyWait = 30 * time.Second maxConcurrencyWait = 30 * time.Second
// pingInterval is the interval for sending ping events during slot wait // pingInterval 流式响应等待时发送 ping 的间隔
pingInterval = 15 * time.Second pingInterval = 15 * time.Second
// initialBackoff 初始退避时间
initialBackoff = 100 * time.Millisecond
// backoffMultiplier 退避时间乘数(指数退避)
backoffMultiplier = 1.5
// maxBackoff 最大退避时间
maxBackoff = 2 * time.Second
) )
// SSEPingFormat defines the format of SSE ping events for different platforms // SSEPingFormat defines the format of SSE ping events for different platforms
@@ -131,8 +149,10 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
pingCh = pingTicker.C pingCh = pingTicker.C
} }
pollTicker := time.NewTicker(100 * time.Millisecond) backoff := initialBackoff
defer pollTicker.Stop() timer := time.NewTimer(backoff)
defer timer.Stop()
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for { for {
select { select {
@@ -156,7 +176,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
} }
flusher.Flush() flusher.Flush()
case <-pollTicker.C: case <-timer.C:
// Try to acquire slot // Try to acquire slot
var result *service.AcquireResult var result *service.AcquireResult
var err error var err error
@@ -174,6 +194,35 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
if result.Acquired { if result.Acquired {
return result.ReleaseFunc, nil return result.ReleaseFunc, nil
} }
backoff = nextBackoff(backoff, rng)
timer.Reset(backoff)
} }
} }
} }
// nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间
// rng: 随机数生成器(可为 nil此时不添加抖动
// 返回值下一次退避时间100ms ~ 2s 之间)
func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration {
// 指数退避:当前时间 * 1.5
next := time.Duration(float64(current) * backoffMultiplier)
if next > maxBackoff {
next = maxBackoff
}
if rng == nil {
return next
}
// 添加 ±20% 的随机抖动jitter 范围 0.8 ~ 1.2
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
jitter := 0.8 + rng.Float64()*0.4
jittered := time.Duration(float64(next) * jitter)
if jittered < initialBackoff {
return initialBackoff
}
if jittered > maxBackoff {
return maxBackoff
}
return jittered
}

View File

@@ -148,6 +148,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body) body, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
return
}
googleError(c, http.StatusBadRequest, "Failed to read request body") googleError(c, http.StatusBadRequest, "Failed to read request body")
return return
} }
@@ -191,7 +195,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
// 3) select account (sticky session based on request body) // 3) select account (sticky session based on request body)
sessionHash := h.gatewayService.GenerateSessionHash(body) parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
const maxAccountSwitches = 3 const maxAccountSwitches = 3
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})

View File

@@ -56,6 +56,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Read request body // Read request body
body, err := io.ReadAll(c.Request.Body) body, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return return
} }

View File

@@ -0,0 +1,27 @@
package handler
import (
"errors"
"fmt"
"net/http"
)
func extractMaxBytesError(err error) (*http.MaxBytesError, bool) {
var maxErr *http.MaxBytesError
if errors.As(err, &maxErr) {
return maxErr, true
}
return nil, false
}
func formatBodyLimit(limit int64) string {
const mb = 1024 * 1024
if limit >= mb {
return fmt.Sprintf("%dMB", limit/mb)
}
return fmt.Sprintf("%dB", limit)
}
func buildBodyTooLargeMessage(limit int64) string {
return fmt.Sprintf("Request body too large, limit is %s", formatBodyLimit(limit))
}

View File

@@ -0,0 +1,45 @@
package handler
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestRequestBodyLimitTooLarge(t *testing.T) {
gin.SetMode(gin.TestMode)
limit := int64(16)
router := gin.New()
router.Use(middleware.RequestBodyLimit(limit))
router.POST("/test", func(c *gin.Context) {
_, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
"error": buildBodyTooLargeMessage(maxErr.Limit),
})
return
}
c.JSON(http.StatusBadRequest, gin.H{
"error": "read_failed",
})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
payload := bytes.Repeat([]byte("a"), int(limit+1))
req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(payload))
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code)
require.Contains(t, recorder.Body.String(), buildBodyTooLargeMessage(limit))
}

View File

@@ -0,0 +1,32 @@
package infrastructure
import (
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
type dbPoolSettings struct {
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime time.Duration
ConnMaxIdleTime time.Duration
}
func buildDBPoolSettings(cfg *config.Config) dbPoolSettings {
return dbPoolSettings{
MaxOpenConns: cfg.Database.MaxOpenConns,
MaxIdleConns: cfg.Database.MaxIdleConns,
ConnMaxLifetime: time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute,
ConnMaxIdleTime: time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute,
}
}
func applyDBPoolSettings(db *sql.DB, cfg *config.Config) {
settings := buildDBPoolSettings(cfg)
db.SetMaxOpenConns(settings.MaxOpenConns)
db.SetMaxIdleConns(settings.MaxIdleConns)
db.SetConnMaxLifetime(settings.ConnMaxLifetime)
db.SetConnMaxIdleTime(settings.ConnMaxIdleTime)
}

View File

@@ -0,0 +1,50 @@
package infrastructure
import (
"database/sql"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
_ "github.com/lib/pq"
)
func TestBuildDBPoolSettings(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
MaxOpenConns: 50,
MaxIdleConns: 10,
ConnMaxLifetimeMinutes: 30,
ConnMaxIdleTimeMinutes: 5,
},
}
settings := buildDBPoolSettings(cfg)
require.Equal(t, 50, settings.MaxOpenConns)
require.Equal(t, 10, settings.MaxIdleConns)
require.Equal(t, 30*time.Minute, settings.ConnMaxLifetime)
require.Equal(t, 5*time.Minute, settings.ConnMaxIdleTime)
}
func TestApplyDBPoolSettings(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
MaxOpenConns: 40,
MaxIdleConns: 8,
ConnMaxLifetimeMinutes: 15,
ConnMaxIdleTimeMinutes: 3,
},
}
db, err := sql.Open("postgres", "host=127.0.0.1 port=5432 user=postgres sslmode=disable")
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
applyDBPoolSettings(db, cfg)
stats := db.Stats()
require.Equal(t, 40, stats.MaxOpenConnections)
}

View File

@@ -51,6 +51,7 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
applyDBPoolSettings(drv.DB(), cfg)
// 确保数据库 schema 已准备就绪。 // 确保数据库 schema 已准备就绪。
// SQL 迁移文件是 schema 的权威来源source of truth // SQL 迁移文件是 schema 的权威来源source of truth

View File

@@ -1,16 +1,39 @@
package infrastructure package infrastructure
import ( import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
// InitRedis 初始化 Redis 客户端 // InitRedis 初始化 Redis 客户端
//
// 性能优化说明:
// 原实现使用 go-redis 默认配置,未设置连接池和超时参数:
// 1. 默认连接池大小可能不足以支撑高并发
// 2. 无超时控制可能导致慢操作阻塞
//
// 新实现支持可配置的连接池和超时参数:
// 1. PoolSize: 控制最大并发连接数(默认 128
// 2. MinIdleConns: 保持最小空闲连接,减少冷启动延迟(默认 10
// 3. DialTimeout/ReadTimeout/WriteTimeout: 精确控制各阶段超时
func InitRedis(cfg *config.Config) *redis.Client { func InitRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(&redis.Options{ return redis.NewClient(buildRedisOptions(cfg))
}
// buildRedisOptions 构建 Redis 连接选项
// 从配置文件读取连接池和超时参数,支持生产环境调优
func buildRedisOptions(cfg *config.Config) *redis.Options {
return &redis.Options{
Addr: cfg.Redis.Address(), Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password, Password: cfg.Redis.Password,
DB: cfg.Redis.DB, DB: cfg.Redis.DB,
}) DialTimeout: time.Duration(cfg.Redis.DialTimeoutSeconds) * time.Second, // 建连超时
ReadTimeout: time.Duration(cfg.Redis.ReadTimeoutSeconds) * time.Second, // 读取超时
WriteTimeout: time.Duration(cfg.Redis.WriteTimeoutSeconds) * time.Second, // 写入超时
PoolSize: cfg.Redis.PoolSize, // 连接池大小
MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接
}
} }

View File

@@ -0,0 +1,35 @@
package infrastructure
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestBuildRedisOptions(t *testing.T) {
cfg := &config.Config{
Redis: config.RedisConfig{
Host: "localhost",
Port: 6379,
Password: "secret",
DB: 2,
DialTimeoutSeconds: 5,
ReadTimeoutSeconds: 3,
WriteTimeoutSeconds: 4,
PoolSize: 100,
MinIdleConns: 10,
},
}
opts := buildRedisOptions(cfg)
require.Equal(t, "localhost:6379", opts.Addr)
require.Equal(t, "secret", opts.Password)
require.Equal(t, 2, opts.DB)
require.Equal(t, 5*time.Second, opts.DialTimeout)
require.Equal(t, 3*time.Second, opts.ReadTimeout)
require.Equal(t, 4*time.Second, opts.WriteTimeout)
require.Equal(t, 100, opts.PoolSize)
require.Equal(t, 10, opts.MinIdleConns)
}

View File

@@ -0,0 +1,152 @@
// Package httpclient 提供共享 HTTP 客户端池
//
// 性能优化说明:
// 原实现在多个服务中重复创建 http.Client
// 1. proxy_probe_service.go: 每次探测创建新客户端
// 2. pricing_service.go: 每次请求创建新客户端
// 3. turnstile_service.go: 每次验证创建新客户端
// 4. github_release_service.go: 每次请求创建新客户端
// 5. claude_usage_service.go: 每次请求创建新客户端
//
// 新实现使用统一的客户端池:
// 1. 相同配置复用同一 http.Client 实例
// 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销
// 3. 支持 HTTP/HTTPS/SOCKS5 代理
// 4. 支持严格代理模式(代理失败则返回错误)
package httpclient
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"golang.org/x/net/proxy"
)
// Transport 连接池默认配置
const (
defaultMaxIdleConns = 100 // 最大空闲连接数
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间
)
// Options 定义共享 HTTP 客户端的构建参数
type Options struct {
ProxyURL string // 代理 URL支持 http/https/socks5
Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
// 可选的连接池参数(不设置则使用默认值)
MaxIdleConns int // 最大空闲连接总数(默认 100
MaxIdleConnsPerHost int // 每主机最大空闲连接(默认 10
MaxConnsPerHost int // 每主机最大连接数(默认 0 无限制)
}
// sharedClients 存储按配置参数缓存的 http.Client 实例
var sharedClients sync.Map
// GetClient 返回共享的 HTTP 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
func GetClient(opts Options) (*http.Client, error) {
key := buildClientKey(opts)
if cached, ok := sharedClients.Load(key); ok {
return cached.(*http.Client), nil
}
client, err := buildClient(opts)
if err != nil {
if opts.ProxyStrict {
return nil, err
}
fallback := opts
fallback.ProxyURL = ""
client, _ = buildClient(fallback)
}
actual, _ := sharedClients.LoadOrStore(key, client)
return actual.(*http.Client), nil
}
func buildClient(opts Options) (*http.Client, error) {
transport, err := buildTransport(opts)
if err != nil {
return nil, err
}
return &http.Client{
Transport: transport,
Timeout: opts.Timeout,
}, nil
}
func buildTransport(opts Options) (*http.Transport, error) {
// 使用自定义值或默认值
maxIdleConns := opts.MaxIdleConns
if maxIdleConns <= 0 {
maxIdleConns = defaultMaxIdleConns
}
maxIdleConnsPerHost := opts.MaxIdleConnsPerHost
if maxIdleConnsPerHost <= 0 {
maxIdleConnsPerHost = defaultMaxIdleConnsPerHost
}
transport := &http.Transport{
MaxIdleConns: maxIdleConns,
MaxIdleConnsPerHost: maxIdleConnsPerHost,
MaxConnsPerHost: opts.MaxConnsPerHost, // 0 表示无限制
IdleConnTimeout: defaultIdleConnTimeout,
ResponseHeaderTimeout: opts.ResponseHeaderTimeout,
}
if opts.InsecureSkipVerify {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
proxyURL := strings.TrimSpace(opts.ProxyURL)
if proxyURL == "" {
return transport, nil
}
parsed, err := url.Parse(proxyURL)
if err != nil {
return nil, err
}
switch strings.ToLower(parsed.Scheme) {
case "http", "https":
transport.Proxy = http.ProxyURL(parsed)
case "socks5", "socks5h":
dialer, err := proxy.FromURL(parsed, proxy.Direct)
if err != nil {
return nil, err
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsed.Scheme)
}
return transport, nil
}
func buildClientKey(opts Options) string {
return fmt.Sprintf("%s|%s|%s|%t|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify,
opts.ProxyStrict,
opts.MaxIdleConns,
opts.MaxIdleConnsPerHost,
opts.MaxConnsPerHost,
)
}

View File

@@ -233,15 +233,11 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
} }
func createReqClient(proxyURL string) *req.Client { func createReqClient(proxyURL string) *req.Client {
client := req.C(). return getSharedReqClient(reqClientOptions{
ImpersonateChrome(). ProxyURL: proxyURL,
SetTimeout(60 * time.Second) Timeout: 60 * time.Second,
Impersonate: true,
if proxyURL != "" { })
client.SetProxyURL(proxyURL)
}
return client
} }
func prefix(s string, n int) string { func prefix(s string, n int) string {

View File

@@ -6,9 +6,9 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
@@ -23,20 +23,12 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
} }
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) { func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
transport, ok := http.DefaultTransport.(*http.Transport) client, err := httpclient.GetClient(httpclient.Options{
if !ok { ProxyURL: proxyURL,
return nil, fmt.Errorf("failed to get default transport")
}
transport = transport.Clone()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
client := &http.Client{
Transport: transport,
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
} }
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil) req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)

View File

@@ -3,67 +3,90 @@ package repository
import ( import (
"context" "context"
"fmt" "fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
// 并发控制缓存常量定义
//
// 性能优化说明:
// 原实现使用 SCAN 命令遍历独立的槽位键concurrency:account:{id}:{requestID}
// 在高并发场景下 SCAN 需要多次往返,且遍历大量键时性能下降明显。
//
// 新实现改用 Redis 有序集合Sorted Set
// 1. 每个账号/用户只有一个键,成员为 requestID分数为时间戳
// 2. 使用 ZCARD 原子获取并发数,时间复杂度 O(1)
// 3. 使用 ZREMRANGEBYSCORE 清理过期槽位,避免手动管理 TTL
// 4. 单次 Redis 调用完成计数,减少网络往返
const ( const (
// Key prefixes for independent slot keys // 并发槽位键前缀(有序集合)
// Format: concurrency:account:{accountID}:{requestID} // 格式: concurrency:account:{accountID}
accountSlotKeyPrefix = "concurrency:account:" accountSlotKeyPrefix = "concurrency:account:"
// Format: concurrency:user:{userID}:{requestID} // 格式: concurrency:user:{userID}
userSlotKeyPrefix = "concurrency:user:" userSlotKeyPrefix = "concurrency:user:"
// Wait queue keeps counter format: concurrency:wait:{userID} // 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:" waitQueueKeyPrefix = "concurrency:wait:"
// Slot TTL - each slot expires independently // 默认槽位过期时间(分钟),可通过配置覆盖
slotTTL = 5 * time.Minute defaultSlotTTLMinutes = 15
) )
var ( var (
// acquireScript uses SCAN to count existing slots and creates new slot if under limit // acquireScript 使用有序集合计数并在未达上限时添加槽位
// KEYS[1] = pattern for SCAN (e.g., "concurrency:account:2:*") // 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步问题
// KEYS[2] = full slot key (e.g., "concurrency:account:2:req_xxx") // KEYS[1] = 有序集合键 (concurrency:account:{id} / concurrency:user:{id})
// ARGV[1] = maxConcurrency // ARGV[1] = maxConcurrency
// ARGV[2] = TTL in seconds // ARGV[2] = TTL(秒)
// ARGV[3] = requestID
acquireScript = redis.NewScript(` acquireScript = redis.NewScript(`
local pattern = KEYS[1] local key = KEYS[1]
local slotKey = KEYS[2]
local maxConcurrency = tonumber(ARGV[1]) local maxConcurrency = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2]) local ttl = tonumber(ARGV[2])
local requestID = ARGV[3]
-- Count existing slots using SCAN -- 使用 Redis 服务器时间,确保多实例时钟一致
local cursor = "0" local timeResult = redis.call('TIME')
local count = 0 local now = tonumber(timeResult[1])
repeat local expireBefore = now - ttl
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
cursor = result[1]
count = count + #result[2]
until cursor == "0"
-- Check if we can acquire a slot -- 清理过期槽位
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
-- 检查是否已存在(支持重试场景刷新时间戳)
local exists = redis.call('ZSCORE', key, requestID)
if exists ~= false then
redis.call('ZADD', key, now, requestID)
redis.call('EXPIRE', key, ttl)
return 1
end
-- 检查是否达到并发上限
local count = redis.call('ZCARD', key)
if count < maxConcurrency then if count < maxConcurrency then
redis.call('SET', slotKey, '1', 'EX', ttl) redis.call('ZADD', key, now, requestID)
redis.call('EXPIRE', key, ttl)
return 1 return 1
end end
return 0 return 0
`) `)
// getCountScript counts slots using SCAN // getCountScript 统计有序集合中的槽位数量并清理过期条目
// KEYS[1] = pattern for SCAN // 使用 Redis TIME 命令获取服务器时间
// KEYS[1] = 有序集合键
// ARGV[1] = TTL
getCountScript = redis.NewScript(` getCountScript = redis.NewScript(`
local pattern = KEYS[1] local key = KEYS[1]
local cursor = "0" local ttl = tonumber(ARGV[1])
local count = 0
repeat -- 使用 Redis 服务器时间
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100) local timeResult = redis.call('TIME')
cursor = result[1] local now = tonumber(timeResult[1])
count = count + #result[2] local expireBefore = now - ttl
until cursor == "0"
return count redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
return redis.call('ZCARD', key)
`) `)
// incrementWaitScript - only sets TTL on first creation to avoid refreshing // incrementWaitScript - only sets TTL on first creation to avoid refreshing
@@ -104,27 +127,28 @@ var (
type concurrencyCache struct { type concurrencyCache struct {
rdb *redis.Client rdb *redis.Client
slotTTLSeconds int // 槽位过期时间(秒)
} }
func NewConcurrencyCache(rdb *redis.Client) service.ConcurrencyCache { // NewConcurrencyCache 创建并发控制缓存
return &concurrencyCache{rdb: rdb} // slotTTLMinutes: 槽位过期时间分钟0 或负数使用默认值 15 分钟
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.ConcurrencyCache {
if slotTTLMinutes <= 0 {
slotTTLMinutes = defaultSlotTTLMinutes
}
return &concurrencyCache{
rdb: rdb,
slotTTLSeconds: slotTTLMinutes * 60,
}
} }
// Helper functions for key generation // Helper functions for key generation
func accountSlotKey(accountID int64, requestID string) string { func accountSlotKey(accountID int64) string {
return fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, requestID) return fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
} }
func accountSlotPattern(accountID int64) string { func userSlotKey(userID int64) string {
return fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID) return fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
}
func userSlotKey(userID int64, requestID string) string {
return fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, requestID)
}
func userSlotPattern(userID int64) string {
return fmt.Sprintf("%s%d:*", userSlotKeyPrefix, userID)
} }
func waitQueueKey(userID int64) string { func waitQueueKey(userID int64) string {
@@ -134,10 +158,9 @@ func waitQueueKey(userID int64) string {
// Account slot operations // Account slot operations
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := accountSlotPattern(accountID) key := accountSlotKey(accountID)
slotKey := accountSlotKey(accountID, requestID) // 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -145,13 +168,14 @@ func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int
} }
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
slotKey := accountSlotKey(accountID, requestID) key := accountSlotKey(accountID)
return c.rdb.Del(ctx, slotKey).Err() return c.rdb.ZRem(ctx, key, requestID).Err()
} }
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
pattern := accountSlotPattern(accountID) key := accountSlotKey(accountID)
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int() // 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -161,10 +185,9 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID
// User slot operations // User slot operations
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := userSlotPattern(userID) key := userSlotKey(userID)
slotKey := userSlotKey(userID, requestID) // 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -172,13 +195,14 @@ func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, ma
} }
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error { func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
slotKey := userSlotKey(userID, requestID) key := userSlotKey(userID)
return c.rdb.Del(ctx, slotKey).Err() return c.rdb.ZRem(ctx, key, requestID).Err()
} }
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
pattern := userSlotPattern(userID) key := userSlotKey(userID)
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int() // 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -189,7 +213,7 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := waitQueueKey(userID) key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(slotTTL.Seconds())).Int() result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int()
if err != nil { if err != nil {
return false, err return false, err
} }

View File

@@ -0,0 +1,135 @@
package repository
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/redis/go-redis/v9"
)
// 基准测试用 TTL 配置
const benchSlotTTLMinutes = 15
var benchSlotTTL = time.Duration(benchSlotTTLMinutes) * time.Minute
// BenchmarkAccountConcurrency 用于对比 SCAN 与有序集合的计数性能。
func BenchmarkAccountConcurrency(b *testing.B) {
rdb := newBenchmarkRedisClient(b)
defer func() {
_ = rdb.Close()
}()
cache := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache)
ctx := context.Background()
for _, size := range []int{10, 100, 1000} {
size := size
b.Run(fmt.Sprintf("zset/slots=%d", size), func(b *testing.B) {
accountID := time.Now().UnixNano()
key := accountSlotKey(accountID)
b.StopTimer()
members := make([]redis.Z, 0, size)
now := float64(time.Now().Unix())
for i := 0; i < size; i++ {
members = append(members, redis.Z{
Score: now,
Member: fmt.Sprintf("req_%d", i),
})
}
if err := rdb.ZAdd(ctx, key, members...).Err(); err != nil {
b.Fatalf("初始化有序集合失败: %v", err)
}
if err := rdb.Expire(ctx, key, benchSlotTTL).Err(); err != nil {
b.Fatalf("设置有序集合 TTL 失败: %v", err)
}
b.StartTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := cache.GetAccountConcurrency(ctx, accountID); err != nil {
b.Fatalf("获取并发数量失败: %v", err)
}
}
b.StopTimer()
if err := rdb.Del(ctx, key).Err(); err != nil {
b.Fatalf("清理有序集合失败: %v", err)
}
})
b.Run(fmt.Sprintf("scan/slots=%d", size), func(b *testing.B) {
accountID := time.Now().UnixNano()
pattern := fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
keys := make([]string, 0, size)
b.StopTimer()
pipe := rdb.Pipeline()
for i := 0; i < size; i++ {
key := fmt.Sprintf("%s%d:req_%d", accountSlotKeyPrefix, accountID, i)
keys = append(keys, key)
pipe.Set(ctx, key, "1", benchSlotTTL)
}
if _, err := pipe.Exec(ctx); err != nil {
b.Fatalf("初始化扫描键失败: %v", err)
}
b.StartTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := scanSlotCount(ctx, rdb, pattern); err != nil {
b.Fatalf("SCAN 计数失败: %v", err)
}
}
b.StopTimer()
if err := rdb.Del(ctx, keys...).Err(); err != nil {
b.Fatalf("清理扫描键失败: %v", err)
}
})
}
}
func scanSlotCount(ctx context.Context, rdb *redis.Client, pattern string) (int, error) {
var cursor uint64
count := 0
for {
keys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, 100).Result()
if err != nil {
return 0, err
}
count += len(keys)
if nextCursor == 0 {
break
}
cursor = nextCursor
}
return count, nil
}
func newBenchmarkRedisClient(b *testing.B) *redis.Client {
b.Helper()
redisURL := os.Getenv("TEST_REDIS_URL")
if redisURL == "" {
b.Skip("未设置 TEST_REDIS_URL跳过 Redis 基准测试")
}
opt, err := redis.ParseURL(redisURL)
if err != nil {
b.Fatalf("解析 TEST_REDIS_URL 失败: %v", err)
}
client := redis.NewClient(opt)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
b.Fatalf("Redis 连接失败: %v", err)
}
return client
}

View File

@@ -14,6 +14,12 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
// 测试用 TTL 配置15 分钟,与默认值一致)
const testSlotTTLMinutes = 15
// 测试用 TTL Duration用于 TTL 断言
var testSlotTTL = time.Duration(testSlotTTLMinutes) * time.Minute
type ConcurrencyCacheSuite struct { type ConcurrencyCacheSuite struct {
IntegrationRedisSuite IntegrationRedisSuite
cache service.ConcurrencyCache cache service.ConcurrencyCache
@@ -21,7 +27,7 @@ type ConcurrencyCacheSuite struct {
func (s *ConcurrencyCacheSuite) SetupTest() { func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest() s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb) s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes)
} }
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() { func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
@@ -54,7 +60,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() { func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
accountID := int64(11) accountID := int64(11)
reqID := "req_ttl_test" reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, reqID) slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID) ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
require.NoError(s.T(), err, "AcquireAccountSlot") require.NoError(s.T(), err, "AcquireAccountSlot")
@@ -62,7 +68,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result() ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL") require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL) s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
} }
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() { func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
@@ -139,7 +145,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() { func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
userID := int64(200) userID := int64(200)
reqID := "req_ttl_test" reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, reqID) slotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID) ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
require.NoError(s.T(), err, "AcquireUserSlot") require.NoError(s.T(), err, "AcquireUserSlot")
@@ -147,7 +153,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result() ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL") require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL) s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
} }
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() { func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
@@ -168,7 +174,7 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result() ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL waitKey") require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL) s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount") require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")

View File

@@ -109,9 +109,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
} }
func createGeminiReqClient(proxyURL string) *req.Client { func createGeminiReqClient(proxyURL string) *req.Client {
client := req.C().SetTimeout(60 * time.Second) return getSharedReqClient(reqClientOptions{
if proxyURL != "" { ProxyURL: proxyURL,
client.SetProxyURL(proxyURL) Timeout: 60 * time.Second,
} })
return client
} }

View File

@@ -76,11 +76,10 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
} }
func createGeminiCliReqClient(proxyURL string) *req.Client { func createGeminiCliReqClient(proxyURL string) *req.Client {
client := req.C().SetTimeout(30 * time.Second) return getSharedReqClient(reqClientOptions{
if proxyURL != "" { ProxyURL: proxyURL,
client.SetProxyURL(proxyURL) Timeout: 30 * time.Second,
} })
return client
} }
func defaultLoadCodeAssistRequest() *geminicli.LoadCodeAssistRequest { func defaultLoadCodeAssistRequest() *geminicli.LoadCodeAssistRequest {

View File

@@ -9,6 +9,7 @@ import (
"os" "os"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
@@ -17,10 +18,14 @@ type githubReleaseClient struct {
} }
func NewGitHubReleaseClient() service.GitHubReleaseClient { func NewGitHubReleaseClient() service.GitHubReleaseClient {
return &githubReleaseClient{ sharedClient, err := httpclient.GetClient(httpclient.Options{
httpClient: &http.Client{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
}, })
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
return &githubReleaseClient{
httpClient: sharedClient,
} }
} }
@@ -58,8 +63,13 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
return err return err
} }
client := &http.Client{Timeout: 10 * time.Minute} downloadClient, err := httpclient.GetClient(httpclient.Options{
resp, err := client.Do(req) Timeout: 10 * time.Minute,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}
}
resp, err := downloadClient.Do(req)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -3,65 +3,104 @@ package repository
import ( import (
"net/http" "net/http"
"net/url" "net/url"
"strings"
"sync"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
// httpUpstreamService is a generic HTTP upstream service that can be used for // httpUpstreamService 通用 HTTP 上游服务
// making requests to any HTTP API (Claude, OpenAI, etc.) with optional proxy support. // 用于向任意 HTTP APIClaudeOpenAI 等)发送请求,支持可选代理
//
// 性能优化:
// 1. 使用 sync.Map 缓存代理客户端实例,避免每次请求都创建新的 http.Client
// 2. 复用 Transport 连接池,减少 TCP 握手和 TLS 协商开销
// 3. 原实现每次请求都 new 一个 http.Client导致连接无法复用
type httpUpstreamService struct { type httpUpstreamService struct {
// defaultClient: 无代理时使用的默认客户端(单例复用)
defaultClient *http.Client defaultClient *http.Client
// proxyClients: 按代理 URL 缓存的客户端池,避免重复创建
proxyClients sync.Map
cfg *config.Config cfg *config.Config
} }
// NewHTTPUpstream creates a new generic HTTP upstream service // NewHTTPUpstream 创建通用 HTTP 上游服务
// 使用配置中的连接池参数构建 Transport
func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream { func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second
}
transport := &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout,
}
return &httpUpstreamService{ return &httpUpstreamService{
defaultClient: &http.Client{Transport: transport}, defaultClient: &http.Client{Transport: buildUpstreamTransport(cfg, nil)},
cfg: cfg, cfg: cfg,
} }
} }
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) { func (s *httpUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
if proxyURL == "" { if strings.TrimSpace(proxyURL) == "" {
return s.defaultClient.Do(req) return s.defaultClient.Do(req)
} }
client := s.createProxyClient(proxyURL) client := s.getOrCreateClient(proxyURL)
return client.Do(req) return client.Do(req)
} }
func (s *httpUpstreamService) createProxyClient(proxyURL string) *http.Client { // getOrCreateClient 获取或创建代理客户端
// 性能优化:使用 sync.Map 实现无锁缓存,相同代理 URL 复用同一客户端
// LoadOrStore 保证并发安全,避免重复创建
func (s *httpUpstreamService) getOrCreateClient(proxyURL string) *http.Client {
proxyURL = strings.TrimSpace(proxyURL)
if proxyURL == "" {
return s.defaultClient
}
// 优先从缓存获取,命中则直接返回
if cached, ok := s.proxyClients.Load(proxyURL); ok {
return cached.(*http.Client)
}
parsedURL, err := url.Parse(proxyURL) parsedURL, err := url.Parse(proxyURL)
if err != nil { if err != nil {
return s.defaultClient return s.defaultClient
} }
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second // 创建新客户端并缓存LoadOrStore 保证只有一个实例被存储
if responseHeaderTimeout == 0 { client := &http.Client{Transport: buildUpstreamTransport(s.cfg, parsedURL)}
actual, _ := s.proxyClients.LoadOrStore(proxyURL, client)
return actual.(*http.Client)
}
// buildUpstreamTransport 构建上游请求的 Transport
// 使用配置文件中的连接池参数,支持生产环境调优
func buildUpstreamTransport(cfg *config.Config, proxyURL *url.URL) *http.Transport {
// 读取配置,使用合理的默认值
maxIdleConns := cfg.Gateway.MaxIdleConns
if maxIdleConns <= 0 {
maxIdleConns = 240
}
maxIdleConnsPerHost := cfg.Gateway.MaxIdleConnsPerHost
if maxIdleConnsPerHost <= 0 {
maxIdleConnsPerHost = 120
}
maxConnsPerHost := cfg.Gateway.MaxConnsPerHost
if maxConnsPerHost < 0 {
maxConnsPerHost = 240
}
idleConnTimeout := time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second
if idleConnTimeout <= 0 {
idleConnTimeout = 300 * time.Second
}
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout <= 0 {
responseHeaderTimeout = 300 * time.Second responseHeaderTimeout = 300 * time.Second
} }
transport := &http.Transport{ transport := &http.Transport{
Proxy: http.ProxyURL(parsedURL), MaxIdleConns: maxIdleConns, // 最大空闲连接总数
MaxIdleConns: 100, MaxIdleConnsPerHost: maxIdleConnsPerHost, // 每主机最大空闲连接
MaxIdleConnsPerHost: 10, MaxConnsPerHost: maxConnsPerHost, // 每主机最大连接数(含活跃)
IdleConnTimeout: 90 * time.Second, IdleConnTimeout: idleConnTimeout, // 空闲连接超时
ResponseHeaderTimeout: responseHeaderTimeout, ResponseHeaderTimeout: responseHeaderTimeout,
} }
if proxyURL != nil {
return &http.Client{Transport: transport} transport.Proxy = http.ProxyURL(proxyURL)
}
return transport
} }

View File

@@ -0,0 +1,46 @@
package repository
import (
"net/http"
"net/url"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
var httpClientSink *http.Client
// BenchmarkHTTPUpstreamProxyClient 对比重复创建与复用代理客户端的开销。
func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
cfg := &config.Config{
Gateway: config.GatewayConfig{ResponseHeaderTimeout: 300},
}
upstream := NewHTTPUpstream(cfg)
svc, ok := upstream.(*httpUpstreamService)
if !ok {
b.Fatalf("类型断言失败,无法获取 httpUpstreamService")
}
proxyURL := "http://127.0.0.1:8080"
b.ReportAllocs()
b.Run("新建", func(b *testing.B) {
parsedProxy, err := url.Parse(proxyURL)
if err != nil {
b.Fatalf("解析代理地址失败: %v", err)
}
for i := 0; i < b.N; i++ {
httpClientSink = &http.Client{
Transport: buildUpstreamTransport(cfg, parsedProxy),
}
}
})
b.Run("复用", func(b *testing.B) {
client := svc.getOrCreateClient(proxyURL)
b.ResetTimer()
for i := 0; i < b.N; i++ {
httpClientSink = client
}
})
}

View File

@@ -40,13 +40,13 @@ func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() {
require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
} }
func (s *HTTPUpstreamSuite) TestCreateProxyClient_InvalidURLFallsBackToDefault() { func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDefault() {
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 5} s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 5}
up := NewHTTPUpstream(s.cfg) up := NewHTTPUpstream(s.cfg)
svc, ok := up.(*httpUpstreamService) svc, ok := up.(*httpUpstreamService)
require.True(s.T(), ok, "expected *httpUpstreamService") require.True(s.T(), ok, "expected *httpUpstreamService")
got := svc.createProxyClient("://bad-proxy-url") got := svc.getOrCreateClient("://bad-proxy-url")
require.Equal(s.T(), svc.defaultClient, got, "expected defaultClient fallback") require.Equal(s.T(), svc.defaultClient, got, "expected defaultClient fallback")
} }

View File

@@ -82,12 +82,8 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
} }
func createOpenAIReqClient(proxyURL string) *req.Client { func createOpenAIReqClient(proxyURL string) *req.Client {
client := req.C(). return getSharedReqClient(reqClientOptions{
SetTimeout(60 * time.Second) ProxyURL: proxyURL,
Timeout: 60 * time.Second,
if proxyURL != "" { })
client.SetProxyURL(proxyURL)
}
return client
} }

View File

@@ -8,6 +8,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
@@ -16,10 +17,14 @@ type pricingRemoteClient struct {
} }
func NewPricingRemoteClient() service.PricingRemoteClient { func NewPricingRemoteClient() service.PricingRemoteClient {
return &pricingRemoteClient{ sharedClient, err := httpclient.GetClient(httpclient.Options{
httpClient: &http.Client{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
}, })
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
return &pricingRemoteClient{
httpClient: sharedClient,
} }
} }

View File

@@ -2,18 +2,14 @@ package repository
import ( import (
"context" "context"
"crypto/tls"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"golang.org/x/net/proxy"
) )
func NewProxyExitInfoProber() service.ProxyExitInfoProber { func NewProxyExitInfoProber() service.ProxyExitInfoProber {
@@ -27,14 +23,14 @@ type proxyProbeService struct {
} }
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
transport, err := createProxyTransport(proxyURL) client, err := httpclient.GetClient(httpclient.Options{
if err != nil { ProxyURL: proxyURL,
return nil, 0, fmt.Errorf("failed to create proxy transport: %w", err)
}
client := &http.Client{
Transport: transport,
Timeout: 15 * time.Second, Timeout: 15 * time.Second,
InsecureSkipVerify: true,
ProxyStrict: true,
})
if err != nil {
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
} }
startTime := time.Now() startTime := time.Now()
@@ -78,31 +74,3 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
Country: ipInfo.Country, Country: ipInfo.Country,
}, latencyMs, nil }, latencyMs, nil
} }
func createProxyTransport(proxyURL string) (*http.Transport, error) {
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
switch parsedURL.Scheme {
case "http", "https":
transport.Proxy = http.ProxyURL(parsedURL)
case "socks5":
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
}
return transport, nil
}

View File

@@ -34,22 +34,16 @@ func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
s.proxySrv = httptest.NewServer(handler) s.proxySrv = httptest.NewServer(handler)
} }
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_InvalidURL() { func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() {
_, err := createProxyTransport("://bad") _, _, err := s.prober.ProbeProxy(s.ctx, "://bad")
require.Error(s.T(), err) require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "invalid proxy URL") require.ErrorContains(s.T(), err, "failed to create proxy client")
} }
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_UnsupportedScheme() { func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() {
_, err := createProxyTransport("ftp://127.0.0.1:1") _, _, err := s.prober.ProbeProxy(s.ctx, "ftp://127.0.0.1:1")
require.Error(s.T(), err) require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "unsupported proxy protocol") require.ErrorContains(s.T(), err, "failed to create proxy client")
}
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_Socks5SetsDialer() {
tr, err := createProxyTransport("socks5://127.0.0.1:1080")
require.NoError(s.T(), err, "createProxyTransport")
require.NotNil(s.T(), tr.DialContext, "expected DialContext to be set for socks5")
} }
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() { func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {

View File

@@ -0,0 +1,59 @@
package repository
import (
"fmt"
"strings"
"sync"
"time"
"github.com/imroc/req/v3"
)
// reqClientOptions 定义 req 客户端的构建参数
type reqClientOptions struct {
ProxyURL string // 代理 URL支持 http/https/socks5
Timeout time.Duration // 请求超时时间
Impersonate bool // 是否模拟 Chrome 浏览器指纹
}
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
//
// 性能优化说明:
// 原实现在每次 OAuth 刷新时都创建新的 req.Client
// 1. claude_oauth_service.go: 每次刷新创建新客户端
// 2. openai_oauth_service.go: 每次刷新创建新客户端
// 3. gemini_oauth_client.go: 每次刷新创建新客户端
//
// 新实现使用 sync.Map 缓存客户端:
// 1. 相同配置(代理+超时+模拟设置)复用同一客户端
// 2. 复用底层连接池,减少 TLS 握手开销
// 3. LoadOrStore 保证并发安全,避免重复创建
var sharedReqClients sync.Map
// getSharedReqClient 获取共享的 req 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建
func getSharedReqClient(opts reqClientOptions) *req.Client {
key := buildReqClientKey(opts)
if cached, ok := sharedReqClients.Load(key); ok {
return cached.(*req.Client)
}
client := req.C().SetTimeout(opts.Timeout)
if opts.Impersonate {
client = client.ImpersonateChrome()
}
if strings.TrimSpace(opts.ProxyURL) != "" {
client.SetProxyURL(strings.TrimSpace(opts.ProxyURL))
}
actual, _ := sharedReqClients.LoadOrStore(key, client)
return actual.(*req.Client)
}
func buildReqClientKey(opts reqClientOptions) string {
return fmt.Sprintf("%s|%s|%t",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.Impersonate,
)
}

View File

@@ -9,6 +9,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
@@ -20,10 +21,14 @@ type turnstileVerifier struct {
} }
func NewTurnstileVerifier() service.TurnstileVerifier { func NewTurnstileVerifier() service.TurnstileVerifier {
return &turnstileVerifier{ sharedClient, err := httpclient.GetClient(httpclient.Options{
httpClient: &http.Client{
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
}, })
if err != nil {
sharedClient = &http.Client{Timeout: 10 * time.Second}
}
return &turnstileVerifier{
httpClient: sharedClient,
verifyURL: turnstileVerifyURL, verifyURL: turnstileVerifyURL,
} }
} }

View File

@@ -452,6 +452,161 @@ func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKe
return &stats, nil return &stats, nil
} }
// GetAccountStatsAggregated 使用 SQL 聚合统计账号使用数据
//
// 性能优化说明:
// 原实现先查询所有日志记录,再在应用层循环计算统计值:
// 1. 需要传输大量数据到应用层
// 2. 应用层循环计算增加 CPU 和内存开销
//
// 新实现使用 SQL 聚合函数:
// 1. 在数据库层完成 COUNT/SUM/AVG 计算
// 2. 只返回单行聚合结果,大幅减少数据传输量
// 3. 利用数据库索引优化聚合查询性能
func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
`
var stats usagestats.UsageStats
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{accountID, startTime, endTime},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return &stats, nil
}
// GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据
// 性能优化:数据库层聚合计算,避免应用层循环统计
func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
FROM usage_logs
WHERE model = $1 AND created_at >= $2 AND created_at < $3
`
var stats usagestats.UsageStats
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{modelName, startTime, endTime},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return &stats, nil
}
// GetDailyStatsAggregated 使用 SQL 聚合统计用户的每日使用数据
// 性能优化:使用 GROUP BY 在数据库层按日期分组聚合,避免应用层循环分组统计
func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) {
query := `
SELECT
TO_CHAR(created_at, 'YYYY-MM-DD') as date,
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
FROM usage_logs
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY 1
ORDER BY 1
`
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
result = nil
}
}()
result = make([]map[string]any, 0)
for rows.Next() {
var (
date string
totalRequests int64
totalInputTokens int64
totalOutputTokens int64
totalCacheTokens int64
totalCost float64
totalActualCost float64
avgDurationMs float64
)
if err = rows.Scan(
&date,
&totalRequests,
&totalInputTokens,
&totalOutputTokens,
&totalCacheTokens,
&totalCost,
&totalActualCost,
&avgDurationMs,
); err != nil {
return nil, err
}
result = append(result, map[string]any{
"date": date,
"total_requests": totalRequests,
"total_input_tokens": totalInputTokens,
"total_output_tokens": totalOutputTokens,
"total_cache_tokens": totalCacheTokens,
"total_tokens": totalInputTokens + totalOutputTokens + totalCacheTokens,
"total_cost": totalCost,
"total_actual_cost": totalActualCost,
"average_duration_ms": avgDurationMs,
})
}
if err = rows.Err(); err != nil {
return nil, err
}
return result, nil
}
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)

View File

@@ -1,9 +1,18 @@
package repository package repository
import ( import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/wire" "github.com/google/wire"
"github.com/redis/go-redis/v9"
) )
// ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数
// 性能优化TTL 可配置,支持长时间运行的 LLM 请求场景
func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache {
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes)
}
// ProviderSet is the Wire provider set for all repositories // ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet( var ProviderSet = wire.NewSet(
NewUserRepository, NewUserRepository,
@@ -20,7 +29,7 @@ var ProviderSet = wire.NewSet(
NewGatewayCache, NewGatewayCache,
NewBillingCache, NewBillingCache,
NewApiKeyCache, NewApiKeyCache,
NewConcurrencyCache, ProvideConcurrencyCache,
NewEmailCache, NewEmailCache,
NewIdentityCache, NewIdentityCache,
NewRedeemCache, NewRedeemCache,

View File

@@ -981,6 +981,18 @@ func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyI
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (r *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }

View File

@@ -0,0 +1,15 @@
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
)
// RequestBodyLimit 使用 MaxBytesReader 限制请求体大小。
func RequestBodyLimit(maxBytes int64) gin.HandlerFunc {
return func(c *gin.Context) {
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes)
c.Next()
}
}

View File

@@ -18,8 +18,11 @@ func RegisterGatewayRoutes(
subscriptionService *service.SubscriptionService, subscriptionService *service.SubscriptionService,
cfg *config.Config, cfg *config.Config,
) { ) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
// API网关Claude API兼容 // API网关Claude API兼容
gateway := r.Group("/v1") gateway := r.Group("/v1")
gateway.Use(bodyLimit)
gateway.Use(gin.HandlerFunc(apiKeyAuth)) gateway.Use(gin.HandlerFunc(apiKeyAuth))
{ {
gateway.POST("/messages", h.Gateway.Messages) gateway.POST("/messages", h.Gateway.Messages)
@@ -32,6 +35,7 @@ func RegisterGatewayRoutes(
// Gemini 原生 API 兼容层Gemini SDK/CLI 直连) // Gemini 原生 API 兼容层Gemini SDK/CLI 直连)
gemini := r.Group("/v1beta") gemini := r.Group("/v1beta")
gemini.Use(bodyLimit)
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{ {
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels) gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
@@ -41,10 +45,11 @@ func RegisterGatewayRoutes(
} }
// OpenAI Responses API不带v1前缀的别名 // OpenAI Responses API不带v1前缀的别名
r.POST("/responses", gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses) r.POST("/responses", bodyLimit, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度) // Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
antigravityV1 := r.Group("/antigravity/v1") antigravityV1 := r.Group("/antigravity/v1")
antigravityV1.Use(bodyLimit)
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth)) antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
{ {
@@ -55,6 +60,7 @@ func RegisterGatewayRoutes(
} }
antigravityV1Beta := r.Group("/antigravity/v1beta") antigravityV1Beta := r.Group("/antigravity/v1beta")
antigravityV1Beta.Use(bodyLimit)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{ {

View File

@@ -52,6 +52,9 @@ type UsageLogRepository interface {
// Aggregated stats (optimized) // Aggregated stats (optimized)
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
} }
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据utilization, resets_at // apiUsageCache 缓存从 Anthropic API 获取的使用率数据utilization, resets_at

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"sync"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
@@ -27,6 +28,45 @@ type subscriptionCacheData struct {
Version int64 Version int64
} }
// 缓存写入任务类型
type cacheWriteKind int
const (
cacheWriteSetBalance cacheWriteKind = iota
cacheWriteSetSubscription
cacheWriteUpdateSubscriptionUsage
cacheWriteDeductBalance
)
// 异步缓存写入工作池配置
//
// 性能优化说明:
// 原实现在请求热路径中使用 goroutine 异步更新缓存,存在以下问题:
// 1. 每次请求创建新 goroutine高并发下产生大量短生命周期 goroutine
// 2. 无法控制并发数量,可能导致 Redis 连接耗尽
// 3. goroutine 创建/销毁带来额外开销
//
// 新实现使用固定大小的工作池:
// 1. 预创建 10 个 worker goroutine避免频繁创建销毁
// 2. 使用带缓冲的 channel1000作为任务队列平滑写入峰值
// 3. 非阻塞写入,队列满时丢弃任务(缓存最终一致性可接受)
// 4. 统一超时控制,避免慢操作阻塞工作池
const (
cacheWriteWorkerCount = 10 // 工作协程数量
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
)
// cacheWriteTask 缓存写入任务
type cacheWriteTask struct {
kind cacheWriteKind
userID int64
groupID int64
balance float64
amount float64
subscriptionData *subscriptionCacheData
}
// BillingCacheService 计费缓存服务 // BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查 // 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct { type BillingCacheService struct {
@@ -34,16 +74,81 @@ type BillingCacheService struct {
userRepo UserRepository userRepo UserRepository
subRepo UserSubscriptionRepository subRepo UserSubscriptionRepository
cfg *config.Config cfg *config.Config
cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup
cacheWriteStopOnce sync.Once
} }
// NewBillingCacheService 创建计费缓存服务 // NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService { func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService {
return &BillingCacheService{ svc := &BillingCacheService{
cache: cache, cache: cache,
userRepo: userRepo, userRepo: userRepo,
subRepo: subRepo, subRepo: subRepo,
cfg: cfg, cfg: cfg,
} }
svc.startCacheWriteWorkers()
return svc
}
// Stop 关闭缓存写入工作池
func (s *BillingCacheService) Stop() {
s.cacheWriteStopOnce.Do(func() {
if s.cacheWriteChan == nil {
return
}
close(s.cacheWriteChan)
s.cacheWriteWg.Wait()
s.cacheWriteChan = nil
})
}
func (s *BillingCacheService) startCacheWriteWorkers() {
s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize)
for i := 0; i < cacheWriteWorkerCount; i++ {
s.cacheWriteWg.Add(1)
go s.cacheWriteWorker()
}
}
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) {
if s.cacheWriteChan == nil {
return
}
defer func() {
_ = recover()
}()
select {
case s.cacheWriteChan <- task:
default:
}
}
func (s *BillingCacheService) cacheWriteWorker() {
defer s.cacheWriteWg.Done()
for task := range s.cacheWriteChan {
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
switch task.kind {
case cacheWriteSetBalance:
s.setBalanceCache(ctx, task.userID, task.balance)
case cacheWriteSetSubscription:
s.setSubscriptionCache(ctx, task.userID, task.groupID, task.subscriptionData)
case cacheWriteUpdateSubscriptionUsage:
if s.cache != nil {
if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil {
log.Printf("Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err)
}
}
case cacheWriteDeductBalance:
if s.cache != nil {
if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil {
log.Printf("Warning: deduct balance cache failed for user %d: %v", task.userID, err)
}
}
}
cancel()
}
} }
// ============================================ // ============================================
@@ -70,11 +175,11 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
} }
// 异步建立缓存 // 异步建立缓存
go func() { s.enqueueCacheWrite(cacheWriteTask{
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) kind: cacheWriteSetBalance,
defer cancel() userID: userID,
s.setBalanceCache(cacheCtx, userID, balance) balance: balance,
}() })
return balance, nil return balance, nil
} }
@@ -98,7 +203,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64,
} }
} }
// DeductBalanceCache 扣减余额缓存(步调用,用于扣费后更新缓存 // DeductBalanceCache 扣减余额缓存(步调用)
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error { func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
if s.cache == nil { if s.cache == nil {
return nil return nil
@@ -106,6 +211,15 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int
return s.cache.DeductUserBalance(ctx, userID, amount) return s.cache.DeductUserBalance(ctx, userID, amount)
} }
// QueueDeductBalance 异步扣减余额缓存
func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteDeductBalance,
userID: userID,
amount: amount,
})
}
// InvalidateUserBalance 失效用户余额缓存 // InvalidateUserBalance 失效用户余额缓存
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error { func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
if s.cache == nil { if s.cache == nil {
@@ -141,11 +255,12 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
} }
// 异步建立缓存 // 异步建立缓存
go func() { s.enqueueCacheWrite(cacheWriteTask{
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) kind: cacheWriteSetSubscription,
defer cancel() userID: userID,
s.setSubscriptionCache(cacheCtx, userID, groupID, data) groupID: groupID,
}() subscriptionData: data,
})
return data, nil return data, nil
} }
@@ -199,7 +314,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID,
} }
} }
// UpdateSubscriptionUsage 更新订阅用量缓存(步调用,用于扣费后更新缓存 // UpdateSubscriptionUsage 更新订阅用量缓存(步调用)
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error { func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
if s.cache == nil { if s.cache == nil {
return nil return nil
@@ -207,6 +322,16 @@ func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userI
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD) return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
} }
// QueueUpdateSubscriptionUsage 异步更新订阅用量缓存
func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) {
s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteUpdateSubscriptionUsage,
userID: userID,
groupID: groupID,
amount: costUSD,
})
}
// InvalidateSubscription 失效指定订阅缓存 // InvalidateSubscription 失效指定订阅缓存
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error { func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
if s.cache == nil { if s.cache == nil {

View File

@@ -0,0 +1,75 @@
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type billingCacheWorkerStub struct {
balanceUpdates int64
subscriptionUpdates int64
}
func (b *billingCacheWorkerStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
return 0, errors.New("not implemented")
}
func (b *billingCacheWorkerStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
atomic.AddInt64(&b.balanceUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
atomic.AddInt64(&b.balanceUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
return nil
}
func (b *billingCacheWorkerStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
return nil, errors.New("not implemented")
}
func (b *billingCacheWorkerStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
atomic.AddInt64(&b.subscriptionUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
atomic.AddInt64(&b.subscriptionUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
return nil
}
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, &config.Config{})
t.Cleanup(svc.Stop)
start := time.Now()
for i := 0; i < cacheWriteBufferSize*2; i++ {
svc.QueueDeductBalance(1, 1)
}
require.Less(t, time.Since(start), 2*time.Second)
svc.QueueUpdateSubscriptionUsage(1, 2, 1.5)
require.Eventually(t, func() bool {
return atomic.LoadInt64(&cache.balanceUpdates) > 0
}, 2*time.Second, 10*time.Millisecond)
require.Eventually(t, func() bool {
return atomic.LoadInt64(&cache.subscriptionUpdates) > 0
}, 2*time.Second, 10*time.Millisecond)
}

View File

@@ -9,22 +9,22 @@ import (
"time" "time"
) )
// ConcurrencyCache defines cache operations for concurrency service // ConcurrencyCache 定义并发控制的缓存接口
// Uses independent keys per request slot with native Redis TTL for automatic cleanup // 使用有序集合存储槽位,按时间戳清理过期条目
type ConcurrencyCache interface { type ConcurrencyCache interface {
// Account slot management - each slot is a separate key with independent TTL // 账号槽位管理
// Key format: concurrency:account:{accountID}:{requestID} // 键格式: concurrency:account:{accountID}(有序集合,成员为 requestID
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
// User slot management - each slot is a separate key with independent TTL // 用户槽位管理
// Key format: concurrency:user:{userID}:{requestID} // 键格式: concurrency:user:{userID}(有序集合,成员为 requestID
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
GetUserConcurrency(ctx context.Context, userID int64) (int, error) GetUserConcurrency(ctx context.Context, userID int64) (int, error)
// Wait queue - uses counter with TTL set only on creation // 等待队列计数(只在首次创建时设置 TTL
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error DecrementWaitCount(ctx context.Context, userID int64) error
} }

View File

@@ -12,6 +12,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
) )
type CRSSyncService struct { type CRSSyncService struct {
@@ -193,7 +195,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
return nil, errors.New("username and password are required") return nil, errors.New("username and password are required")
} }
client := &http.Client{Timeout: 20 * time.Second} client, err := httpclient.GetClient(httpclient.Options{
Timeout: 20 * time.Second,
})
if err != nil {
client = &http.Client{Timeout: 20 * time.Second}
}
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password) adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
if err != nil { if err != nil {

View File

@@ -0,0 +1,70 @@
package service
import (
"encoding/json"
"fmt"
)
// ParsedRequest 保存网关请求的预解析结果
//
// 性能优化说明:
// 原实现在多个位置重复解析请求体Handler、Service 各解析一次):
// 1. gateway_handler.go 解析获取 model 和 stream
// 2. gateway_service.go 再次解析获取 system、messages、metadata
// 3. GenerateSessionHash 又一次解析获取会话哈希所需字段
//
// 新实现一次解析,多处复用:
// 1. 在 Handler 层统一调用 ParseGatewayRequest 一次性解析
// 2. 将解析结果 ParsedRequest 传递给 Service 层
// 3. 避免重复 json.Unmarshal减少 CPU 和内存开销
type ParsedRequest struct {
Body []byte // 原始请求体(保留用于转发)
Model string // 请求的模型名称
Stream bool // 是否为流式请求
MetadataUserID string // metadata.user_id用于会话亲和
System any // system 字段内容
Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal
func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
parsed := &ParsedRequest{
Body: body,
}
if rawModel, exists := req["model"]; exists {
model, ok := rawModel.(string)
if !ok {
return nil, fmt.Errorf("invalid model field type")
}
parsed.Model = model
}
if rawStream, exists := req["stream"]; exists {
stream, ok := rawStream.(bool)
if !ok {
return nil, fmt.Errorf("invalid stream field type")
}
parsed.Stream = stream
}
if metadata, ok := req["metadata"].(map[string]any); ok {
if userID, ok := metadata["user_id"].(string); ok {
parsed.MetadataUserID = userID
}
}
if system, ok := req["system"]; ok && system != nil {
parsed.HasSystem = true
parsed.System = system
}
if messages, ok := req["messages"].([]any); ok {
parsed.Messages = messages
}
return parsed, nil
}

View File

@@ -0,0 +1,38 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestParseGatewayRequest(t *testing.T) {
body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`)
parsed, err := ParseGatewayRequest(body)
require.NoError(t, err)
require.Equal(t, "claude-3-7-sonnet", parsed.Model)
require.True(t, parsed.Stream)
require.Equal(t, "session_123e4567-e89b-12d3-a456-426614174000", parsed.MetadataUserID)
require.True(t, parsed.HasSystem)
require.NotNil(t, parsed.System)
require.Len(t, parsed.Messages, 1)
}
func TestParseGatewayRequest_SystemNull(t *testing.T) {
body := []byte(`{"model":"claude-3","system":null}`)
parsed, err := ParseGatewayRequest(body)
require.NoError(t, err)
require.False(t, parsed.HasSystem)
}
func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
body := []byte(`{"model":123}`)
_, err := ParseGatewayRequest(body)
require.Error(t, err)
}
func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
body := []byte(`{"stream":"true"}`)
_, err := ParseGatewayRequest(body)
require.Error(t, err)
}

View File

@@ -19,7 +19,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -33,7 +32,10 @@ const (
// sseDataRe matches SSE data lines with optional whitespace after colon. // sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: "). // Some upstream APIs return non-standard "data:" without space (should be "data: ").
var sseDataRe = regexp.MustCompile(`^data:\s*`) var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
)
// allowedHeaders 白名单headers参考CRS项目 // allowedHeaders 白名单headers参考CRS项目
var allowedHeaders = map[string]bool{ var allowedHeaders = map[string]bool{
@@ -141,40 +143,36 @@ func NewGatewayService(
} }
} }
// GenerateSessionHash 从请求计算粘性会话hash // GenerateSessionHash 从预解析请求计算粘性会话 hash
func (s *GatewayService) GenerateSessionHash(body []byte) string { func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
var req map[string]any if parsed == nil {
if err := json.Unmarshal(body, &req); err != nil {
return "" return ""
} }
// 1. 最高优先级从metadata.user_id提取session_xxx // 1. 最高优先级:从 metadata.user_id 提取 session_xxx
if metadata, ok := req["metadata"].(map[string]any); ok { if parsed.MetadataUserID != "" {
if userID, ok := metadata["user_id"].(string); ok { if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 {
re := regexp.MustCompile(`session_([a-f0-9-]{36})`)
if match := re.FindStringSubmatch(userID); len(match) > 1 {
return match[1] return match[1]
} }
} }
}
// 2. 提取带cache_control: {type: "ephemeral"}的内容 // 2. 提取带 cache_control: {type: "ephemeral"} 的内容
cacheableContent := s.extractCacheableContent(req) cacheableContent := s.extractCacheableContent(parsed)
if cacheableContent != "" { if cacheableContent != "" {
return s.hashContent(cacheableContent) return s.hashContent(cacheableContent)
} }
// 3. Fallback: 使用system内容 // 3. Fallback: 使用 system 内容
if system := req["system"]; system != nil { if parsed.System != nil {
systemText := s.extractTextFromSystem(system) systemText := s.extractTextFromSystem(parsed.System)
if systemText != "" { if systemText != "" {
return s.hashContent(systemText) return s.hashContent(systemText)
} }
} }
// 4. 最后fallback: 使用第一条消息 // 4. 最后 fallback: 使用第一条消息
if messages, ok := req["messages"].([]any); ok && len(messages) > 0 { if len(parsed.Messages) > 0 {
if firstMsg, ok := messages[0].(map[string]any); ok { if firstMsg, ok := parsed.Messages[0].(map[string]any); ok {
msgText := s.extractTextFromContent(firstMsg["content"]) msgText := s.extractTextFromContent(firstMsg["content"])
if msgText != "" { if msgText != "" {
return s.hashContent(msgText) return s.hashContent(msgText)
@@ -185,34 +183,37 @@ func (s *GatewayService) GenerateSessionHash(body []byte) string {
return "" return ""
} }
func (s *GatewayService) extractCacheableContent(req map[string]any) string { func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
var content string if parsed == nil {
return ""
}
// 检查system中的cacheable内容 var builder strings.Builder
if system, ok := req["system"].([]any); ok {
// 检查 system 中的 cacheable 内容
if system, ok := parsed.System.([]any); ok {
for _, part := range system { for _, part := range system {
if partMap, ok := part.(map[string]any); ok { if partMap, ok := part.(map[string]any); ok {
if cc, ok := partMap["cache_control"].(map[string]any); ok { if cc, ok := partMap["cache_control"].(map[string]any); ok {
if cc["type"] == "ephemeral" { if cc["type"] == "ephemeral" {
if text, ok := partMap["text"].(string); ok { if text, ok := partMap["text"].(string); ok {
content += text builder.WriteString(text)
} }
} }
} }
} }
} }
} }
systemText := builder.String()
// 检查messages中的cacheable内容 // 检查 messages 中的 cacheable 内容
if messages, ok := req["messages"].([]any); ok { for _, msg := range parsed.Messages {
for _, msg := range messages {
if msgMap, ok := msg.(map[string]any); ok { if msgMap, ok := msg.(map[string]any); ok {
if msgContent, ok := msgMap["content"].([]any); ok { if msgContent, ok := msgMap["content"].([]any); ok {
for _, part := range msgContent { for _, part := range msgContent {
if partMap, ok := part.(map[string]any); ok { if partMap, ok := part.(map[string]any); ok {
if cc, ok := partMap["cache_control"].(map[string]any); ok { if cc, ok := partMap["cache_control"].(map[string]any); ok {
if cc["type"] == "ephemeral" { if cc["type"] == "ephemeral" {
// 找到cacheable内容提取第一条消息的文本
return s.extractTextFromContent(msgMap["content"]) return s.extractTextFromContent(msgMap["content"])
} }
} }
@@ -221,9 +222,8 @@ func (s *GatewayService) extractCacheableContent(req map[string]any) string {
} }
} }
} }
}
return content return systemText
} }
func (s *GatewayService) extractTextFromSystem(system any) string { func (s *GatewayService) extractTextFromSystem(system any) string {
@@ -588,19 +588,17 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
} }
// Forward 转发请求到Claude API // Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
startTime := time.Now() startTime := time.Now()
if parsed == nil {
// 解析请求获取model和stream return nil, fmt.Errorf("parse request: empty request")
var req struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
if err := json.Unmarshal(body, &req); err != nil {
return nil, fmt.Errorf("parse request: %w", err)
} }
if !gjson.GetBytes(body, "system").Exists() { body := parsed.Body
reqModel := parsed.Model
reqStream := parsed.Stream
if !parsed.HasSystem {
body, _ = sjson.SetBytes(body, "system", []any{ body, _ = sjson.SetBytes(body, "system", []any{
map[string]any{ map[string]any{
"type": "text", "type": "text",
@@ -613,13 +611,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
} }
// 应用模型映射仅对apikey类型账号 // 应用模型映射仅对apikey类型账号
originalModel := req.Model originalModel := reqModel
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeApiKey {
mappedModel := account.GetMappedModel(req.Model) mappedModel := account.GetMappedModel(reqModel)
if mappedModel != req.Model { if mappedModel != reqModel {
// 替换请求体中的模型名 // 替换请求体中的模型名
body = s.replaceModelInBody(body, mappedModel) body = s.replaceModelInBody(body, mappedModel)
req.Model = mappedModel reqModel = mappedModel
log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name) log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
} }
} }
@@ -640,7 +638,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var resp *http.Response var resp *http.Response
for attempt := 1; attempt <= maxRetries; attempt++ { for attempt := 1; attempt <= maxRetries; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType) upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -692,8 +690,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理正常响应 // 处理正常响应
var usage *ClaudeUsage var usage *ClaudeUsage
var firstTokenMs *int var firstTokenMs *int
if req.Stream { if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, req.Model) streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
if err != nil { if err != nil {
if err.Error() == "have error in stream" { if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{ return nil, &UpstreamFailoverError{
@@ -705,7 +703,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
usage = streamResult.usage usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs firstTokenMs = streamResult.firstTokenMs
} else { } else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, req.Model) usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -715,13 +713,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
RequestID: resp.Header.Get("x-request-id"), RequestID: resp.Header.Get("x-request-id"),
Usage: *usage, Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志 Model: originalModel, // 使用原始模型用于计费和日志
Stream: req.Stream, Stream: reqStream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
}, nil }, nil
} }
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) { func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL // 确定目标URL
targetURL := claudeAPIURL targetURL := claudeAPIURL
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeApiKey {
@@ -787,7 +785,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta headerOAuth账号需要特殊处理 // 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} }
return req, nil return req, nil
@@ -795,7 +793,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// getBetaHeader 处理anthropic-beta header // getBetaHeader 处理anthropic-beta header
// 对于OAuth账号需要确保包含oauth-2025-04-20 // 对于OAuth账号需要确保包含oauth-2025-04-20
func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) string { func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
// 如果客户端传了anthropic-beta // 如果客户端传了anthropic-beta
if clientBetaHeader != "" { if clientBetaHeader != "" {
// 已包含oauth beta则直接返回 // 已包含oauth beta则直接返回
@@ -832,15 +830,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
} }
// 客户端没传,根据模型生成 // 客户端没传,根据模型生成
var modelID string // haiku 模型不需要 claude-code beta
var reqMap map[string]any
if json.Unmarshal(body, &reqMap) == nil {
if m, ok := reqMap["model"].(string); ok {
modelID = m
}
}
// haiku模型不需要claude-code beta
if strings.Contains(strings.ToLower(modelID), "haiku") { if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.HaikuBetaHeader return claude.HaikuBetaHeader
} }
@@ -1248,13 +1238,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
log.Printf("Increment subscription usage failed: %v", err) log.Printf("Increment subscription usage failed: %v", err)
} }
// 异步更新订阅缓存 // 异步更新订阅缓存
go func() { s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost); err != nil {
log.Printf("Update subscription cache failed: %v", err)
}
}()
} }
} else { } else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
@@ -1263,13 +1247,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
log.Printf("Deduct balance failed: %v", err) log.Printf("Deduct balance failed: %v", err)
} }
// 异步更新余额缓存 // 异步更新余额缓存
go func() { s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost); err != nil {
log.Printf("Update balance cache failed: %v", err)
}
}()
} }
} }
@@ -1281,7 +1259,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
// ForwardCountTokens 转发 count_tokens 请求到上游 API // ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应 // 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error { func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
if parsed == nil {
s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return fmt.Errorf("parse request: empty request")
}
body := parsed.Body
reqModel := parsed.Model
// Antigravity 账户不支持 count_tokens 转发,返回估算值 // Antigravity 账户不支持 count_tokens 转发,返回估算值
// 参考 Antigravity-Manager 和 proxycast 实现 // 参考 Antigravity-Manager 和 proxycast 实现
if account.Platform == PlatformAntigravity { if account.Platform == PlatformAntigravity {
@@ -1291,14 +1277,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// 应用模型映射(仅对 apikey 类型账号) // 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeApiKey {
var req struct { if reqModel != "" {
Model string `json:"model"` mappedModel := account.GetMappedModel(reqModel)
} if mappedModel != reqModel {
if err := json.Unmarshal(body, &req); err == nil && req.Model != "" {
mappedModel := account.GetMappedModel(req.Model)
if mappedModel != req.Model {
body = s.replaceModelInBody(body, mappedModel) body = s.replaceModelInBody(body, mappedModel)
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", req.Model, mappedModel, account.Name) reqModel = mappedModel
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
} }
} }
} }
@@ -1311,7 +1295,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
} }
// 构建上游请求 // 构建上游请求
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType) upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel)
if err != nil { if err != nil {
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
return err return err
@@ -1363,7 +1347,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
} }
// buildCountTokensRequest 构建 count_tokens 上游请求 // buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) { func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL // 确定目标 URL
targetURL := claudeAPICountTokensURL targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeApiKey { if account.Type == AccountTypeApiKey {
@@ -1424,7 +1408,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header // OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} }
return req, nil return req, nil

View File

@@ -0,0 +1,50 @@
package service
import (
"strconv"
"testing"
)
var benchmarkStringSink string
// BenchmarkGenerateSessionHash_Metadata 关注 JSON 解析与正则匹配开销。
func BenchmarkGenerateSessionHash_Metadata(b *testing.B) {
svc := &GatewayService{}
body := []byte(`{"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"messages":[{"content":"hello"}]}`)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
parsed, err := ParseGatewayRequest(body)
if err != nil {
b.Fatalf("解析请求失败: %v", err)
}
benchmarkStringSink = svc.GenerateSessionHash(parsed)
}
}
// BenchmarkExtractCacheableContent_System 关注字符串拼接路径的性能。
func BenchmarkExtractCacheableContent_System(b *testing.B) {
svc := &GatewayService{}
req := buildSystemCacheableRequest(12)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
benchmarkStringSink = svc.extractCacheableContent(req)
}
}
func buildSystemCacheableRequest(parts int) *ParsedRequest {
systemParts := make([]any, 0, parts)
for i := 0; i < parts; i++ {
systemParts = append(systemParts, map[string]any{
"text": "system_part_" + strconv.Itoa(i),
"cache_control": map[string]any{
"type": "ephemeral",
},
})
}
return &ParsedRequest{
System: systemParts,
HasSystem: true,
}
}

View File

@@ -921,7 +921,10 @@ func sleepGeminiBackoff(attempt int) {
time.Sleep(sleepFor) time.Sleep(sleepFor)
} }
var sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`) var (
sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`)
retryInRegex = regexp.MustCompile(`Please retry in ([0-9.]+)s`)
)
func sanitizeUpstreamErrorMessage(msg string) string { func sanitizeUpstreamErrorMessage(msg string) string {
if msg == "" { if msg == "" {
@@ -1925,7 +1928,6 @@ func ParseGeminiRateLimitResetTime(body []byte) *int64 {
} }
// Match "Please retry in Xs" // Match "Please retry in Xs"
retryInRegex := regexp.MustCompile(`Please retry in ([0-9.]+)s`)
matches := retryInRegex.FindStringSubmatch(string(body)) matches := retryInRegex.FindStringSubmatch(string(body))
if len(matches) == 2 { if len(matches) == 2 {
if dur, err := time.ParseDuration(matches[1] + "s"); err == nil { if dur, err := time.ParseDuration(matches[1] + "s"); err == nil {

View File

@@ -7,13 +7,13 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
) )
type GeminiOAuthService struct { type GeminiOAuthService struct {
@@ -497,11 +497,12 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
client := &http.Client{Timeout: 30 * time.Second} client, err := httpclient.GetClient(httpclient.Options{
if strings.TrimSpace(proxyURL) != "" { ProxyURL: strings.TrimSpace(proxyURL),
if proxyURLParsed, err := url.Parse(strings.TrimSpace(proxyURL)); err == nil { Timeout: 30 * time.Second,
client.Transport = &http.Transport{Proxy: http.ProxyURL(proxyURLParsed)} })
} if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
} }
resp, err := client.Do(req) resp, err := client.Do(req)

View File

@@ -768,20 +768,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
if isSubscriptionBilling { if isSubscriptionBilling {
if cost.TotalCost > 0 { if cost.TotalCost > 0 {
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost) _ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
go func() { s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost)
}()
} }
} else { } else {
if cost.ActualCost > 0 { if cost.ActualCost > 0 {
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost) _ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
go func() { s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost)
}()
} }
} }

View File

@@ -18,6 +18,11 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
) )
var (
openAIModelDatePattern = regexp.MustCompile(`-\d{8}$`)
openAIModelBasePattern = regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`)
)
// LiteLLMModelPricing LiteLLM价格数据结构 // LiteLLMModelPricing LiteLLM价格数据结构
// 只保留我们需要的字段,使用指针来处理可能缺失的值 // 只保留我们需要的字段,使用指针来处理可能缺失的值
type LiteLLMModelPricing struct { type LiteLLMModelPricing struct {
@@ -595,11 +600,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
// 2. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号) // 2. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号)
// 3. 最终回退到 DefaultTestModel (gpt-5.1-codex) // 3. 最终回退到 DefaultTestModel (gpt-5.1-codex)
func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
// 正则匹配日期后缀 (如 -20251222)
datePattern := regexp.MustCompile(`-\d{8}$`)
// 尝试的回退变体 // 尝试的回退变体
variants := s.generateOpenAIModelVariants(model, datePattern) variants := s.generateOpenAIModelVariants(model, openAIModelDatePattern)
for _, variant := range variants { for _, variant := range variants {
if pricing, ok := s.pricingData[variant]; ok { if pricing, ok := s.pricingData[variant]; ok {
@@ -638,14 +640,13 @@ func (s *PricingService) generateOpenAIModelVariants(model string, datePattern *
// 2. 提取基础版本号: gpt-5.2-codex -> gpt-5.2 // 2. 提取基础版本号: gpt-5.2-codex -> gpt-5.2
// 只匹配纯数字版本号格式 gpt-X 或 gpt-X.Y不匹配 gpt-4o 这种带字母后缀的 // 只匹配纯数字版本号格式 gpt-X 或 gpt-X.Y不匹配 gpt-4o 这种带字母后缀的
basePattern := regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`) if matches := openAIModelBasePattern.FindStringSubmatch(model); len(matches) > 1 {
if matches := basePattern.FindStringSubmatch(model); len(matches) > 1 {
addVariant(matches[1]) addVariant(matches[1])
} }
// 3. 同时去掉日期后再提取基础版本号 // 3. 同时去掉日期后再提取基础版本号
if withoutDate != model { if withoutDate != model {
if matches := basePattern.FindStringSubmatch(withoutDate); len(matches) > 1 { if matches := openAIModelBasePattern.FindStringSubmatch(withoutDate); len(matches) > 1 {
addVariant(matches[1]) addVariant(matches[1])
} }
} }

View File

@@ -186,22 +186,40 @@ func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, sta
// GetStatsByAccount 获取账号的使用统计 // GetStatsByAccount 获取账号的使用统计
func (s *UsageService) GetStatsByAccount(ctx context.Context, accountID int64, startTime, endTime time.Time) (*UsageStats, error) { func (s *UsageService) GetStatsByAccount(ctx context.Context, accountID int64, startTime, endTime time.Time) (*UsageStats, error) {
logs, _, err := s.usageRepo.ListByAccountAndTimeRange(ctx, accountID, startTime, endTime) stats, err := s.usageRepo.GetAccountStatsAggregated(ctx, accountID, startTime, endTime)
if err != nil { if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err) return nil, fmt.Errorf("get account stats: %w", err)
} }
return s.calculateStats(logs), nil return &UsageStats{
TotalRequests: stats.TotalRequests,
TotalInputTokens: stats.TotalInputTokens,
TotalOutputTokens: stats.TotalOutputTokens,
TotalCacheTokens: stats.TotalCacheTokens,
TotalTokens: stats.TotalTokens,
TotalCost: stats.TotalCost,
TotalActualCost: stats.TotalActualCost,
AverageDurationMs: stats.AverageDurationMs,
}, nil
} }
// GetStatsByModel 获取模型的使用统计 // GetStatsByModel 获取模型的使用统计
func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, startTime, endTime time.Time) (*UsageStats, error) { func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, startTime, endTime time.Time) (*UsageStats, error) {
logs, _, err := s.usageRepo.ListByModelAndTimeRange(ctx, modelName, startTime, endTime) stats, err := s.usageRepo.GetModelStatsAggregated(ctx, modelName, startTime, endTime)
if err != nil { if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err) return nil, fmt.Errorf("get model stats: %w", err)
} }
return s.calculateStats(logs), nil return &UsageStats{
TotalRequests: stats.TotalRequests,
TotalInputTokens: stats.TotalInputTokens,
TotalOutputTokens: stats.TotalOutputTokens,
TotalCacheTokens: stats.TotalCacheTokens,
TotalTokens: stats.TotalTokens,
TotalCost: stats.TotalCost,
TotalActualCost: stats.TotalActualCost,
AverageDurationMs: stats.AverageDurationMs,
}, nil
} }
// GetDailyStats 获取每日使用统计最近N天 // GetDailyStats 获取每日使用统计最近N天
@@ -209,54 +227,12 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
endTime := time.Now() endTime := time.Now()
startTime := endTime.AddDate(0, 0, -days) startTime := endTime.AddDate(0, 0, -days)
logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime) stats, err := s.usageRepo.GetDailyStatsAggregated(ctx, userID, startTime, endTime)
if err != nil { if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err) return nil, fmt.Errorf("get daily stats: %w", err)
} }
// 按日期分组统计 return stats, nil
dailyStats := make(map[string]*UsageStats)
for _, log := range logs {
dateKey := log.CreatedAt.Format("2006-01-02")
if _, exists := dailyStats[dateKey]; !exists {
dailyStats[dateKey] = &UsageStats{}
}
stats := dailyStats[dateKey]
stats.TotalRequests++
stats.TotalInputTokens += int64(log.InputTokens)
stats.TotalOutputTokens += int64(log.OutputTokens)
stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
stats.TotalTokens += int64(log.TotalTokens())
stats.TotalCost += log.TotalCost
stats.TotalActualCost += log.ActualCost
if log.DurationMs != nil {
stats.AverageDurationMs += float64(*log.DurationMs)
}
}
// 计算平均值并转换为数组
result := make([]map[string]any, 0, len(dailyStats))
for date, stats := range dailyStats {
if stats.TotalRequests > 0 {
stats.AverageDurationMs /= float64(stats.TotalRequests)
}
result = append(result, map[string]any{
"date": date,
"total_requests": stats.TotalRequests,
"total_input_tokens": stats.TotalInputTokens,
"total_output_tokens": stats.TotalOutputTokens,
"total_cache_tokens": stats.TotalCacheTokens,
"total_tokens": stats.TotalTokens,
"total_cost": stats.TotalCost,
"total_actual_cost": stats.TotalActualCost,
"average_duration_ms": stats.AverageDurationMs,
})
}
return result, nil
} }
// calculateStats 计算统计数据 // calculateStats 计算统计数据

View File

@@ -0,0 +1,4 @@
-- 为聚合查询补充复合索引
CREATE INDEX IF NOT EXISTS idx_usage_logs_account_created_at ON usage_logs(account_id, created_at);
CREATE INDEX IF NOT EXISTS idx_usage_logs_api_key_created_at ON usage_logs(api_key_id, created_at);
CREATE INDEX IF NOT EXISTS idx_usage_logs_model_created_at ON usage_logs(model, created_at);

View File

@@ -21,6 +21,22 @@ server:
# - simple: Hides SaaS features and skips billing/balance checks # - simple: Hides SaaS features and skips billing/balance checks
run_mode: "standard" run_mode: "standard"
# =============================================================================
# 网关配置
# =============================================================================
gateway:
# 等待上游响应头超时时间(秒)
response_header_timeout: 300
# 请求体最大字节数(默认 100MB
max_body_size: 104857600
# HTTP 上游连接池配置HTTP/2 + 多代理场景默认)
max_idle_conns: 240
max_idle_conns_per_host: 120
max_conns_per_host: 240
idle_conn_timeout_seconds: 300
# 并发槽位过期时间(分钟)
concurrency_slot_ttl_minutes: 15
# ============================================================================= # =============================================================================
# Database Configuration (PostgreSQL) # Database Configuration (PostgreSQL)
# ============================================================================= # =============================================================================