perf(backend): performance optimization and connection pool configuration

- Add DB/Redis connection pool configuration and validation, with unit tests
- Enforce a gateway request body size limit with 413 handling
- Pool HTTP/req clients and tune upstream connection pool defaults
- Switch concurrency slots to ZSET + Lua with exponential backoff
- Move usage statistics to SQL aggregation and add an index migration
- Move billing cache writes to a worker pool, with tests and benchmarks

Tests: run go test ./... under backend/
@@ -67,6 +67,7 @@ func provideCleanup(
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
pricing *service.PricingService,
|
||||
emailQueue *service.EmailQueueService,
|
||||
billingCache *service.BillingCacheService,
|
||||
oauth *service.OAuthService,
|
||||
openaiOAuth *service.OpenAIOAuthService,
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
@@ -94,6 +95,10 @@ func provideCleanup(
|
||||
emailQueue.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"BillingCacheService", func() error {
|
||||
billingCache.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OAuthService", func() error {
|
||||
oauth.Stop()
|
||||
return nil
|
||||
|
||||
@@ -39,11 +39,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlDB, err := infrastructure.ProvideSQLDB(client)
|
||||
db, err := infrastructure.ProvideSQLDB(client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userRepository := repository.NewUserRepository(client, sqlDB)
|
||||
userRepository := repository.NewUserRepository(client, db)
|
||||
settingRepository := repository.NewSettingRepository(client)
|
||||
settingService := service.NewSettingService(settingRepository, configConfig)
|
||||
redisClient := infrastructure.ProvideRedis(configConfig)
|
||||
@@ -57,12 +57,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService)
|
||||
userHandler := handler.NewUserHandler(userService)
|
||||
apiKeyRepository := repository.NewApiKeyRepository(client)
|
||||
groupRepository := repository.NewGroupRepository(client, sqlDB)
|
||||
groupRepository := repository.NewGroupRepository(client, db)
|
||||
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
|
||||
apiKeyCache := repository.NewApiKeyCache(redisClient)
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(client, sqlDB)
|
||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
|
||||
@@ -75,8 +75,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||
dashboardService := service.NewDashboardService(usageLogRepository)
|
||||
dashboardHandler := admin.NewDashboardHandler(dashboardService)
|
||||
accountRepository := repository.NewAccountRepository(client, sqlDB)
|
||||
proxyRepository := repository.NewProxyRepository(client, sqlDB)
|
||||
accountRepository := repository.NewAccountRepository(client, db)
|
||||
proxyRepository := repository.NewProxyRepository(client, db)
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber()
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
|
||||
adminUserHandler := admin.NewUserHandler(adminService)
|
||||
@@ -95,7 +95,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, httpUpstream)
|
||||
concurrencyCache := repository.NewConcurrencyCache(redisClient)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
|
||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
|
||||
@@ -142,7 +142,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
|
||||
antigravityQuotaRefresher := service.ProvideAntigravityQuotaRefresher(accountRepository, proxyRepository, antigravityOAuthService, configConfig)
|
||||
v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, antigravityQuotaRefresher)
|
||||
v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, antigravityQuotaRefresher)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -170,6 +170,7 @@ func provideCleanup(
|
||||
tokenRefresh *service.TokenRefreshService,
|
||||
pricing *service.PricingService,
|
||||
emailQueue *service.EmailQueueService,
|
||||
billingCache *service.BillingCacheService,
|
||||
oauth *service.OAuthService,
|
||||
openaiOAuth *service.OpenAIOAuthService,
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
@@ -196,6 +197,10 @@ func provideCleanup(
|
||||
emailQueue.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"BillingCacheService", func() error {
|
||||
billingCache.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"OAuthService", func() error {
|
||||
oauth.Stop()
|
||||
return nil
|
||||
|
||||
@@ -79,12 +79,29 @@ type GatewayConfig struct {
|
||||
// 等待上游响应头的超时时间(秒),0表示无超时
|
||||
// 注意:这不影响流式数据传输,只控制等待响应头的时间
|
||||
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
|
||||
// 请求体最大字节数,用于网关请求体大小限制
|
||||
MaxBodySize int64 `mapstructure:"max_body_size"`
|
||||
|
||||
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
|
||||
// MaxIdleConns: 所有主机的最大空闲连接总数
|
||||
MaxIdleConns int `mapstructure:"max_idle_conns"`
|
||||
// MaxIdleConnsPerHost: 每个主机的最大空闲连接数(关键参数,影响连接复用率)
|
||||
MaxIdleConnsPerHost int `mapstructure:"max_idle_conns_per_host"`
|
||||
// MaxConnsPerHost: 每个主机的最大连接数(包括活跃+空闲),0表示无限制
|
||||
MaxConnsPerHost int `mapstructure:"max_conns_per_host"`
|
||||
// IdleConnTimeoutSeconds: 空闲连接超时时间(秒)
|
||||
IdleConnTimeoutSeconds int `mapstructure:"idle_conn_timeout_seconds"`
|
||||
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
|
||||
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
|
||||
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
|
||||
}
|
||||
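The upstream client that consumes these gateway pool fields is wired elsewhere in the repository and is not part of this excerpt. As a rough, hedged illustration only (field names taken from the struct above, everything else assumed, with net/http and time imported), such settings typically map onto an http.Transport like this:

func buildUpstreamTransport(cfg GatewayConfig) *http.Transport {
	// Illustrative sketch, not the committed upstream construction.
	return &http.Transport{
		MaxIdleConns:          cfg.MaxIdleConns,
		MaxIdleConnsPerHost:   cfg.MaxIdleConnsPerHost,
		MaxConnsPerHost:       cfg.MaxConnsPerHost, // 0 means unlimited
		IdleConnTimeout:       time.Duration(cfg.IdleConnTimeoutSeconds) * time.Second,
		ResponseHeaderTimeout: time.Duration(cfg.ResponseHeaderTimeout) * time.Second,
	}
}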
|
||||
func (s *ServerConfig) Address() string {
|
||||
return fmt.Sprintf("%s:%d", s.Host, s.Port)
|
||||
}
|
||||
|
||||
// DatabaseConfig 数据库连接配置
|
||||
// 性能优化:新增连接池参数,避免频繁创建/销毁连接
|
||||
type DatabaseConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
@@ -92,6 +109,15 @@ type DatabaseConfig struct {
|
||||
Password string `mapstructure:"password"`
|
||||
DBName string `mapstructure:"dbname"`
|
||||
SSLMode string `mapstructure:"sslmode"`
|
||||
// 连接池配置(性能优化:可配置化连接池参数)
|
||||
// MaxOpenConns: 最大打开连接数,控制数据库连接上限,防止资源耗尽
|
||||
MaxOpenConns int `mapstructure:"max_open_conns"`
|
||||
// MaxIdleConns: 最大空闲连接数,保持热连接减少建连延迟
|
||||
MaxIdleConns int `mapstructure:"max_idle_conns"`
|
||||
// ConnMaxLifetimeMinutes: 连接最大存活时间,防止长连接导致的资源泄漏
|
||||
ConnMaxLifetimeMinutes int `mapstructure:"conn_max_lifetime_minutes"`
|
||||
// ConnMaxIdleTimeMinutes: 空闲连接最大存活时间,及时释放不活跃连接
|
||||
ConnMaxIdleTimeMinutes int `mapstructure:"conn_max_idle_time_minutes"`
|
||||
}
|
||||
|
||||
func (d *DatabaseConfig) DSN() string {
|
||||
@@ -112,11 +138,24 @@ func (d *DatabaseConfig) DSNWithTimezone(tz string) string {
|
||||
)
|
||||
}
|
||||
|
||||
// RedisConfig Redis 连接配置
|
||||
// 性能优化:新增连接池和超时参数,提升高并发场景下的吞吐量
|
||||
type RedisConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Password string `mapstructure:"password"`
|
||||
DB int `mapstructure:"db"`
|
||||
// 连接池与超时配置(性能优化:可配置化连接池参数)
|
||||
// DialTimeoutSeconds: 建立连接超时,防止慢连接阻塞
|
||||
DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"`
|
||||
// ReadTimeoutSeconds: 读取超时,避免慢查询阻塞连接池
|
||||
ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
|
||||
// WriteTimeoutSeconds: 写入超时,避免慢写入阻塞连接池
|
||||
WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
|
||||
// PoolSize: 连接池大小,控制最大并发连接数
|
||||
PoolSize int `mapstructure:"pool_size"`
|
||||
// MinIdleConns: 最小空闲连接数,保持热连接减少冷启动延迟
|
||||
MinIdleConns int `mapstructure:"min_idle_conns"`
|
||||
}
|
||||
|
||||
func (r *RedisConfig) Address() string {
|
||||
@@ -203,12 +242,21 @@ func setDefaults() {
|
||||
viper.SetDefault("database.password", "postgres")
|
||||
viper.SetDefault("database.dbname", "sub2api")
|
||||
viper.SetDefault("database.sslmode", "disable")
|
||||
viper.SetDefault("database.max_open_conns", 50)
|
||||
viper.SetDefault("database.max_idle_conns", 10)
|
||||
viper.SetDefault("database.conn_max_lifetime_minutes", 30)
|
||||
viper.SetDefault("database.conn_max_idle_time_minutes", 5)
|
||||
|
||||
// Redis
|
||||
viper.SetDefault("redis.host", "localhost")
|
||||
viper.SetDefault("redis.port", 6379)
|
||||
viper.SetDefault("redis.password", "")
|
||||
viper.SetDefault("redis.db", 0)
|
||||
viper.SetDefault("redis.dial_timeout_seconds", 5)
|
||||
viper.SetDefault("redis.read_timeout_seconds", 3)
|
||||
viper.SetDefault("redis.write_timeout_seconds", 3)
|
||||
viper.SetDefault("redis.pool_size", 128)
|
||||
viper.SetDefault("redis.min_idle_conns", 10)
|
||||
|
||||
// JWT
|
||||
viper.SetDefault("jwt.secret", "change-me-in-production")
|
||||
@@ -240,6 +288,13 @@ func setDefaults() {
|
||||
|
||||
// Gateway
|
||||
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
|
||||
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
||||
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
|
||||
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认)
|
||||
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认)
|
||||
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认)
|
||||
viper.SetDefault("gateway.idle_conn_timeout_seconds", 300) // 空闲连接超时(秒)
|
||||
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求)
|
||||
|
||||
// TokenRefresh
|
||||
viper.SetDefault("token_refresh.enabled", true)
|
||||
@@ -263,6 +318,57 @@ func (c *Config) Validate() error {
|
||||
if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" {
|
||||
return fmt.Errorf("jwt.secret must be changed in production")
|
||||
}
|
||||
if c.Database.MaxOpenConns <= 0 {
|
||||
return fmt.Errorf("database.max_open_conns must be positive")
|
||||
}
|
||||
if c.Database.MaxIdleConns < 0 {
|
||||
return fmt.Errorf("database.max_idle_conns must be non-negative")
|
||||
}
|
||||
if c.Database.MaxIdleConns > c.Database.MaxOpenConns {
|
||||
return fmt.Errorf("database.max_idle_conns cannot exceed database.max_open_conns")
|
||||
}
|
||||
if c.Database.ConnMaxLifetimeMinutes < 0 {
|
||||
return fmt.Errorf("database.conn_max_lifetime_minutes must be non-negative")
|
||||
}
|
||||
if c.Database.ConnMaxIdleTimeMinutes < 0 {
|
||||
return fmt.Errorf("database.conn_max_idle_time_minutes must be non-negative")
|
||||
}
|
||||
if c.Redis.DialTimeoutSeconds <= 0 {
|
||||
return fmt.Errorf("redis.dial_timeout_seconds must be positive")
|
||||
}
|
||||
if c.Redis.ReadTimeoutSeconds <= 0 {
|
||||
return fmt.Errorf("redis.read_timeout_seconds must be positive")
|
||||
}
|
||||
if c.Redis.WriteTimeoutSeconds <= 0 {
|
||||
return fmt.Errorf("redis.write_timeout_seconds must be positive")
|
||||
}
|
||||
if c.Redis.PoolSize <= 0 {
|
||||
return fmt.Errorf("redis.pool_size must be positive")
|
||||
}
|
||||
if c.Redis.MinIdleConns < 0 {
|
||||
return fmt.Errorf("redis.min_idle_conns must be non-negative")
|
||||
}
|
||||
if c.Redis.MinIdleConns > c.Redis.PoolSize {
|
||||
return fmt.Errorf("redis.min_idle_conns cannot exceed redis.pool_size")
|
||||
}
|
||||
if c.Gateway.MaxBodySize <= 0 {
|
||||
return fmt.Errorf("gateway.max_body_size must be positive")
|
||||
}
|
||||
if c.Gateway.MaxIdleConns <= 0 {
|
||||
return fmt.Errorf("gateway.max_idle_conns must be positive")
|
||||
}
|
||||
if c.Gateway.MaxIdleConnsPerHost <= 0 {
|
||||
return fmt.Errorf("gateway.max_idle_conns_per_host must be positive")
|
||||
}
|
||||
if c.Gateway.MaxConnsPerHost < 0 {
|
||||
return fmt.Errorf("gateway.max_conns_per_host must be non-negative")
|
||||
}
|
||||
if c.Gateway.IdleConnTimeoutSeconds <= 0 {
|
||||
return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive")
|
||||
}
|
||||
if c.Gateway.ConcurrencySlotTTLMinutes <= 0 {
|
||||
return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
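The commit message mentions unit tests for this validation; a minimal sketch of what such a test could look like (test name and the defaultTestConfig helper are assumptions, the field names and error conditions come from Validate above):

func TestValidate_PoolSettings(t *testing.T) {
	cfg := defaultTestConfig() // hypothetical helper returning an otherwise valid Config
	cfg.Database.MaxIdleConns = cfg.Database.MaxOpenConns + 1
	if err := cfg.Validate(); err == nil {
		t.Fatal("expected error when max_idle_conns exceeds max_open_conns")
	}

	cfg = defaultTestConfig()
	cfg.Redis.PoolSize = 0
	if err := cfg.Validate(); err == nil {
		t.Fatal("expected error when redis.pool_size is not positive")
	}
}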
|
||||
|
||||
@@ -67,6 +67,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
@@ -76,15 +80,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求获取模型名和stream
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
parsedReq, err := service.ParseGatewayRequest(body)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
reqModel := parsedReq.Model
|
||||
reqStream := parsedReq.Stream
|
||||
|
||||
// Track if we've started streaming (for error handling)
|
||||
streamStarted := false
|
||||
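service.ParseGatewayRequest and its return type are not included in this diff; a plausible minimal shape, purely to illustrate what the handler above relies on (type name, fields, and body are assumptions, not the repository's code):

// Hypothetical sketch — not taken from the repository.
type GatewayRequest struct {
	Model  string
	Stream bool
	Raw    []byte
}

func ParseGatewayRequest(body []byte) (*GatewayRequest, error) {
	var req struct {
		Model  string `json:"model"`
		Stream bool   `json:"stream"`
	}
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, err
	}
	return &GatewayRequest{Model: req.Model, Stream: req.Stream, Raw: body}, nil
}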
@@ -106,7 +108,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
|
||||
// 1. 首先获取用户并发槽位
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, req.Stream, &streamStarted)
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
@@ -124,7 +126,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 计算粘性会话hash
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
|
||||
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
|
||||
platform := ""
|
||||
@@ -141,7 +143,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
lastFailoverStatus := 0
|
||||
|
||||
for {
|
||||
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
|
||||
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
@@ -153,16 +155,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if req.Stream {
|
||||
sendMockWarmupStream(c, req.Model)
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, req.Model)
|
||||
sendMockWarmupResponse(c, reqModel)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
@@ -172,7 +174,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, req.Model, "generateContent", req.Stream, body)
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body)
|
||||
} else {
|
||||
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
|
||||
}
|
||||
@@ -223,7 +225,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
|
||||
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
@@ -235,16 +237,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if req.Stream {
|
||||
sendMockWarmupStream(c, req.Model)
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, req.Model)
|
||||
sendMockWarmupResponse(c, reqModel)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
@@ -256,7 +258,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
} else {
|
||||
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
@@ -496,6 +498,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
@@ -505,11 +511,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求获取模型名
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
parsedReq, err := service.ParseGatewayRequest(body)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
@@ -525,17 +528,17 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 计算粘性会话 hash
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
|
||||
// 选择支持该模型的账号
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转发请求(不记录使用量)
|
||||
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil {
|
||||
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil {
|
||||
log.Printf("Forward count_tokens request failed: %v", err)
|
||||
// 错误响应已在 ForwardCountTokens 中处理
|
||||
return
|
||||
|
||||
@@ -3,6 +3,7 @@ package handler
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -11,11 +12,28 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 并发槽位等待相关常量
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现使用固定间隔(100ms)轮询并发槽位,存在以下问题:
|
||||
// 1. 高并发时频繁轮询增加 Redis 压力
|
||||
// 2. 固定间隔可能导致多个请求同时重试(惊群效应)
|
||||
//
|
||||
// 新实现使用指数退避 + 抖动算法:
|
||||
// 1. 初始退避 100ms,每次乘以 1.5,最大 2s
|
||||
// 2. 添加 ±20% 的随机抖动,分散重试时间点
|
||||
// 3. 减少 Redis 压力,避免惊群效应
|
||||
const (
|
||||
// maxConcurrencyWait is the maximum time to wait for a concurrency slot
|
||||
// maxConcurrencyWait 等待并发槽位的最大时间
|
||||
maxConcurrencyWait = 30 * time.Second
|
||||
// pingInterval is the interval for sending ping events during slot wait
|
||||
// pingInterval 流式响应等待时发送 ping 的间隔
|
||||
pingInterval = 15 * time.Second
|
||||
// initialBackoff 初始退避时间
|
||||
initialBackoff = 100 * time.Millisecond
|
||||
// backoffMultiplier 退避时间乘数(指数退避)
|
||||
backoffMultiplier = 1.5
|
||||
// maxBackoff 最大退避时间
|
||||
maxBackoff = 2 * time.Second
|
||||
)
|
||||
|
||||
// SSEPingFormat defines the format of SSE ping events for different platforms
|
||||
@@ -131,8 +149,10 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
|
||||
pingCh = pingTicker.C
|
||||
}
|
||||
|
||||
pollTicker := time.NewTicker(100 * time.Millisecond)
|
||||
defer pollTicker.Stop()
|
||||
backoff := initialBackoff
|
||||
timer := time.NewTimer(backoff)
|
||||
defer timer.Stop()
|
||||
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -156,7 +176,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
case <-pollTicker.C:
|
||||
case <-timer.C:
|
||||
// Try to acquire slot
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
@@ -174,6 +194,35 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
backoff = nextBackoff(backoff, rng)
|
||||
timer.Reset(backoff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// nextBackoff 计算下一次退避时间
|
||||
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
|
||||
// current: 当前退避时间
|
||||
// rng: 随机数生成器(可为 nil,此时不添加抖动)
|
||||
// 返回值:下一次退避时间(100ms ~ 2s 之间)
|
||||
func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration {
|
||||
// 指数退避:当前时间 * 1.5
|
||||
next := time.Duration(float64(current) * backoffMultiplier)
|
||||
if next > maxBackoff {
|
||||
next = maxBackoff
|
||||
}
|
||||
if rng == nil {
|
||||
return next
|
||||
}
|
||||
// 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2)
|
||||
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
|
||||
jitter := 0.8 + rng.Float64()*0.4
|
||||
jittered := time.Duration(float64(next) * jitter)
|
||||
if jittered < initialBackoff {
|
||||
return initialBackoff
|
||||
}
|
||||
if jittered > maxBackoff {
|
||||
return maxBackoff
|
||||
}
|
||||
return jittered
|
||||
}
|
||||
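With jitter disabled (rng == nil) the sequence grows 100ms, 150ms, 225ms, ... and stays capped at 2s. A small test sketch of that behaviour, assuming it lives in the same handler package as nextBackoff:

func TestNextBackoffWithoutJitter(t *testing.T) {
	// 100ms * 1.5 = 150ms, no jitter when rng is nil
	if got := nextBackoff(initialBackoff, nil); got != 150*time.Millisecond {
		t.Fatalf("expected 150ms, got %v", got)
	}
	// once at the cap, the backoff stays at maxBackoff
	if got := nextBackoff(maxBackoff, nil); got != maxBackoff {
		t.Fatalf("expected backoff to stay capped at %v, got %v", maxBackoff, got)
	}
}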
|
||||
@@ -148,6 +148,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
googleError(c, http.StatusBadRequest, "Failed to read request body")
|
||||
return
|
||||
}
|
||||
@@ -191,7 +195,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3) select account (sticky session based on request body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(body)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
|
||||
@@ -56,6 +56,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Read request body
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
backend/internal/handler/request_body_limit.go (new file, 27 lines)
@@ -0,0 +1,27 @@
package handler

import (
	"errors"
	"fmt"
	"net/http"
)

func extractMaxBytesError(err error) (*http.MaxBytesError, bool) {
	var maxErr *http.MaxBytesError
	if errors.As(err, &maxErr) {
		return maxErr, true
	}
	return nil, false
}

func formatBodyLimit(limit int64) string {
	const mb = 1024 * 1024
	if limit >= mb {
		return fmt.Sprintf("%dMB", limit/mb)
	}
	return fmt.Sprintf("%dB", limit)
}

func buildBodyTooLargeMessage(limit int64) string {
	return fmt.Sprintf("Request body too large, limit is %s", formatBodyLimit(limit))
}
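The middleware.RequestBodyLimit used by the test below is not included in this excerpt; presumably it wraps the request body with http.MaxBytesReader, roughly like the following sketch (not the committed implementation, gin and net/http imports assumed):

// Sketch only — the real middleware lives under internal/server/middleware.
func RequestBodyLimit(limit int64) gin.HandlerFunc {
	return func(c *gin.Context) {
		// MaxBytesReader makes io.ReadAll return an *http.MaxBytesError once the limit is exceeded,
		// which the handlers above translate into a 413 response.
		c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, limit)
		c.Next()
	}
}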
backend/internal/handler/request_body_limit_test.go (new file, 45 lines)
@@ -0,0 +1,45 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRequestBodyLimitTooLarge(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
limit := int64(16)
|
||||
router := gin.New()
|
||||
router.Use(middleware.RequestBodyLimit(limit))
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
_, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
|
||||
"error": buildBodyTooLargeMessage(maxErr.Limit),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "read_failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
payload := bytes.Repeat([]byte("a"), int(limit+1))
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(payload))
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code)
|
||||
require.Contains(t, recorder.Body.String(), buildBodyTooLargeMessage(limit))
|
||||
}
|
||||
backend/internal/infrastructure/db_pool.go (new file, 32 lines)
@@ -0,0 +1,32 @@
package infrastructure

import (
	"database/sql"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
)

type dbPoolSettings struct {
	MaxOpenConns    int
	MaxIdleConns    int
	ConnMaxLifetime time.Duration
	ConnMaxIdleTime time.Duration
}

func buildDBPoolSettings(cfg *config.Config) dbPoolSettings {
	return dbPoolSettings{
		MaxOpenConns:    cfg.Database.MaxOpenConns,
		MaxIdleConns:    cfg.Database.MaxIdleConns,
		ConnMaxLifetime: time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute,
		ConnMaxIdleTime: time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute,
	}
}

func applyDBPoolSettings(db *sql.DB, cfg *config.Config) {
	settings := buildDBPoolSettings(cfg)
	db.SetMaxOpenConns(settings.MaxOpenConns)
	db.SetMaxIdleConns(settings.MaxIdleConns)
	db.SetConnMaxLifetime(settings.ConnMaxLifetime)
	db.SetConnMaxIdleTime(settings.ConnMaxIdleTime)
}
backend/internal/infrastructure/db_pool_test.go (new file, 50 lines)
@@ -0,0 +1,50 @@
|
||||
package infrastructure
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func TestBuildDBPoolSettings(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
MaxOpenConns: 50,
|
||||
MaxIdleConns: 10,
|
||||
ConnMaxLifetimeMinutes: 30,
|
||||
ConnMaxIdleTimeMinutes: 5,
|
||||
},
|
||||
}
|
||||
|
||||
settings := buildDBPoolSettings(cfg)
|
||||
require.Equal(t, 50, settings.MaxOpenConns)
|
||||
require.Equal(t, 10, settings.MaxIdleConns)
|
||||
require.Equal(t, 30*time.Minute, settings.ConnMaxLifetime)
|
||||
require.Equal(t, 5*time.Minute, settings.ConnMaxIdleTime)
|
||||
}
|
||||
|
||||
func TestApplyDBPoolSettings(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
MaxOpenConns: 40,
|
||||
MaxIdleConns: 8,
|
||||
ConnMaxLifetimeMinutes: 15,
|
||||
ConnMaxIdleTimeMinutes: 3,
|
||||
},
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", "host=127.0.0.1 port=5432 user=postgres sslmode=disable")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = db.Close()
|
||||
})
|
||||
|
||||
applyDBPoolSettings(db, cfg)
|
||||
stats := db.Stats()
|
||||
require.Equal(t, 40, stats.MaxOpenConnections)
|
||||
}
|
||||
@@ -51,6 +51,7 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
applyDBPoolSettings(drv.DB(), cfg)
|
||||
|
||||
// 确保数据库 schema 已准备就绪。
|
||||
// SQL 迁移文件是 schema 的权威来源(source of truth)。
|
||||
|
||||
@@ -1,16 +1,39 @@
|
||||
package infrastructure
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// InitRedis 初始化 Redis 客户端
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现使用 go-redis 默认配置,未设置连接池和超时参数:
|
||||
// 1. 默认连接池大小可能不足以支撑高并发
|
||||
// 2. 无超时控制可能导致慢操作阻塞
|
||||
//
|
||||
// 新实现支持可配置的连接池和超时参数:
|
||||
// 1. PoolSize: 控制最大并发连接数(默认 128)
|
||||
// 2. MinIdleConns: 保持最小空闲连接,减少冷启动延迟(默认 10)
|
||||
// 3. DialTimeout/ReadTimeout/WriteTimeout: 精确控制各阶段超时
|
||||
func InitRedis(cfg *config.Config) *redis.Client {
|
||||
return redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Redis.Address(),
|
||||
Password: cfg.Redis.Password,
|
||||
DB: cfg.Redis.DB,
|
||||
})
|
||||
return redis.NewClient(buildRedisOptions(cfg))
|
||||
}
|
||||
|
||||
// buildRedisOptions 构建 Redis 连接选项
|
||||
// 从配置文件读取连接池和超时参数,支持生产环境调优
|
||||
func buildRedisOptions(cfg *config.Config) *redis.Options {
|
||||
return &redis.Options{
|
||||
Addr: cfg.Redis.Address(),
|
||||
Password: cfg.Redis.Password,
|
||||
DB: cfg.Redis.DB,
|
||||
DialTimeout: time.Duration(cfg.Redis.DialTimeoutSeconds) * time.Second, // 建连超时
|
||||
ReadTimeout: time.Duration(cfg.Redis.ReadTimeoutSeconds) * time.Second, // 读取超时
|
||||
WriteTimeout: time.Duration(cfg.Redis.WriteTimeoutSeconds) * time.Second, // 写入超时
|
||||
PoolSize: cfg.Redis.PoolSize, // 连接池大小
|
||||
MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接
|
||||
}
|
||||
}
|
||||
|
||||
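For checking the effect of these pool settings at runtime, go-redis exposes pool counters on the client; a small hedged sketch (logging helper and its placement are assumptions):

func logRedisPoolStats(rdb *redis.Client) {
	stats := rdb.PoolStats()
	log.Printf("redis pool: total=%d idle=%d stale=%d hits=%d misses=%d timeouts=%d",
		stats.TotalConns, stats.IdleConns, stats.StaleConns, stats.Hits, stats.Misses, stats.Timeouts)
}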
backend/internal/infrastructure/redis_test.go (new file, 35 lines)
@@ -0,0 +1,35 @@
|
||||
package infrastructure
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildRedisOptions(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Redis: config.RedisConfig{
|
||||
Host: "localhost",
|
||||
Port: 6379,
|
||||
Password: "secret",
|
||||
DB: 2,
|
||||
DialTimeoutSeconds: 5,
|
||||
ReadTimeoutSeconds: 3,
|
||||
WriteTimeoutSeconds: 4,
|
||||
PoolSize: 100,
|
||||
MinIdleConns: 10,
|
||||
},
|
||||
}
|
||||
|
||||
opts := buildRedisOptions(cfg)
|
||||
require.Equal(t, "localhost:6379", opts.Addr)
|
||||
require.Equal(t, "secret", opts.Password)
|
||||
require.Equal(t, 2, opts.DB)
|
||||
require.Equal(t, 5*time.Second, opts.DialTimeout)
|
||||
require.Equal(t, 3*time.Second, opts.ReadTimeout)
|
||||
require.Equal(t, 4*time.Second, opts.WriteTimeout)
|
||||
require.Equal(t, 100, opts.PoolSize)
|
||||
require.Equal(t, 10, opts.MinIdleConns)
|
||||
}
|
||||
backend/internal/pkg/httpclient/pool.go (new file, 152 lines)
@@ -0,0 +1,152 @@
|
||||
// Package httpclient 提供共享 HTTP 客户端池
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现在多个服务中重复创建 http.Client:
|
||||
// 1. proxy_probe_service.go: 每次探测创建新客户端
|
||||
// 2. pricing_service.go: 每次请求创建新客户端
|
||||
// 3. turnstile_service.go: 每次验证创建新客户端
|
||||
// 4. github_release_service.go: 每次请求创建新客户端
|
||||
// 5. claude_usage_service.go: 每次请求创建新客户端
|
||||
//
|
||||
// 新实现使用统一的客户端池:
|
||||
// 1. 相同配置复用同一 http.Client 实例
|
||||
// 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销
|
||||
// 3. 支持 HTTP/HTTPS/SOCKS5 代理
|
||||
// 4. 支持严格代理模式(代理失败则返回错误)
|
||||
package httpclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// Transport 连接池默认配置
|
||||
const (
|
||||
defaultMaxIdleConns = 100 // 最大空闲连接数
|
||||
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
|
||||
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间
|
||||
)
|
||||
|
||||
// Options 定义共享 HTTP 客户端的构建参数
|
||||
type Options struct {
|
||||
ProxyURL string // 代理 URL(支持 http/https/socks5)
|
||||
Timeout time.Duration // 请求总超时时间
|
||||
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
|
||||
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
|
||||
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
|
||||
|
||||
// 可选的连接池参数(不设置则使用默认值)
|
||||
MaxIdleConns int // 最大空闲连接总数(默认 100)
|
||||
MaxIdleConnsPerHost int // 每主机最大空闲连接(默认 10)
|
||||
MaxConnsPerHost int // 每主机最大连接数(默认 0 无限制)
|
||||
}
|
||||
|
||||
// sharedClients 存储按配置参数缓存的 http.Client 实例
|
||||
var sharedClients sync.Map
|
||||
|
||||
// GetClient 返回共享的 HTTP 客户端实例
|
||||
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
|
||||
func GetClient(opts Options) (*http.Client, error) {
|
||||
key := buildClientKey(opts)
|
||||
if cached, ok := sharedClients.Load(key); ok {
|
||||
return cached.(*http.Client), nil
|
||||
}
|
||||
|
||||
client, err := buildClient(opts)
|
||||
if err != nil {
|
||||
if opts.ProxyStrict {
|
||||
return nil, err
|
||||
}
|
||||
fallback := opts
|
||||
fallback.ProxyURL = ""
|
||||
client, _ = buildClient(fallback)
|
||||
}
|
||||
|
||||
actual, _ := sharedClients.LoadOrStore(key, client)
|
||||
return actual.(*http.Client), nil
|
||||
}
|
||||
|
||||
func buildClient(opts Options) (*http.Client, error) {
|
||||
transport, err := buildTransport(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: opts.Timeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildTransport(opts Options) (*http.Transport, error) {
|
||||
// 使用自定义值或默认值
|
||||
maxIdleConns := opts.MaxIdleConns
|
||||
if maxIdleConns <= 0 {
|
||||
maxIdleConns = defaultMaxIdleConns
|
||||
}
|
||||
maxIdleConnsPerHost := opts.MaxIdleConnsPerHost
|
||||
if maxIdleConnsPerHost <= 0 {
|
||||
maxIdleConnsPerHost = defaultMaxIdleConnsPerHost
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: maxIdleConns,
|
||||
MaxIdleConnsPerHost: maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: opts.MaxConnsPerHost, // 0 表示无限制
|
||||
IdleConnTimeout: defaultIdleConnTimeout,
|
||||
ResponseHeaderTimeout: opts.ResponseHeaderTimeout,
|
||||
}
|
||||
|
||||
if opts.InsecureSkipVerify {
|
||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
}
|
||||
|
||||
proxyURL := strings.TrimSpace(opts.ProxyURL)
|
||||
if proxyURL == "" {
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "http", "https":
|
||||
transport.Proxy = http.ProxyURL(parsed)
|
||||
case "socks5", "socks5h":
|
||||
dialer, err := proxy.FromURL(parsed, proxy.Direct)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsed.Scheme)
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
func buildClientKey(opts Options) string {
|
||||
return fmt.Sprintf("%s|%s|%s|%t|%t|%d|%d|%d",
|
||||
strings.TrimSpace(opts.ProxyURL),
|
||||
opts.Timeout.String(),
|
||||
opts.ResponseHeaderTimeout.String(),
|
||||
opts.InsecureSkipVerify,
|
||||
opts.ProxyStrict,
|
||||
opts.MaxIdleConns,
|
||||
opts.MaxIdleConnsPerHost,
|
||||
opts.MaxConnsPerHost,
|
||||
)
|
||||
}
|
||||
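A usage sketch of the shared client pool above; the proxy URL, target URL, and surrounding function are illustrative values, not code from this commit:

func fetchHealth() error {
	client, err := httpclient.GetClient(httpclient.Options{
		ProxyURL:            "socks5://127.0.0.1:1080", // example value; empty string means a direct connection
		Timeout:             30 * time.Second,
		MaxIdleConnsPerHost: 20,
	})
	if err != nil {
		return err // only possible with ProxyStrict: true; otherwise GetClient falls back to a proxy-less client
	}
	resp, err := client.Get("https://api.example.com/health") // example URL
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	return nil
}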
@@ -233,15 +233,11 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
}
|
||||
|
||||
func createReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().
|
||||
ImpersonateChrome().
|
||||
SetTimeout(60 * time.Second)
|
||||
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
|
||||
return client
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 60 * time.Second,
|
||||
Impersonate: true,
|
||||
})
|
||||
}
|
||||
|
||||
func prefix(s string, n int) string {
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
@@ -23,20 +23,12 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||
}
|
||||
|
||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||
transport, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to get default transport")
|
||||
}
|
||||
transport = transport.Clone()
|
||||
if proxyURL != "" {
|
||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 30 * time.Second,
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
|
||||
|
||||
@@ -3,67 +3,90 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 并发控制缓存常量定义
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现使用 SCAN 命令遍历独立的槽位键(concurrency:account:{id}:{requestID}),
|
||||
// 在高并发场景下 SCAN 需要多次往返,且遍历大量键时性能下降明显。
|
||||
//
|
||||
// 新实现改用 Redis 有序集合(Sorted Set):
|
||||
// 1. 每个账号/用户只有一个键,成员为 requestID,分数为时间戳
|
||||
// 2. 使用 ZCARD 原子获取并发数,时间复杂度 O(1)
|
||||
// 3. 使用 ZREMRANGEBYSCORE 清理过期槽位,避免手动管理 TTL
|
||||
// 4. 单次 Redis 调用完成计数,减少网络往返
|
||||
const (
|
||||
// Key prefixes for independent slot keys
|
||||
// Format: concurrency:account:{accountID}:{requestID}
|
||||
// 并发槽位键前缀(有序集合)
|
||||
// 格式: concurrency:account:{accountID}
|
||||
accountSlotKeyPrefix = "concurrency:account:"
|
||||
// Format: concurrency:user:{userID}:{requestID}
|
||||
// 格式: concurrency:user:{userID}
|
||||
userSlotKeyPrefix = "concurrency:user:"
|
||||
// Wait queue keeps counter format: concurrency:wait:{userID}
|
||||
// 等待队列计数器格式: concurrency:wait:{userID}
|
||||
waitQueueKeyPrefix = "concurrency:wait:"
|
||||
|
||||
// Slot TTL - each slot expires independently
|
||||
slotTTL = 5 * time.Minute
|
||||
// 默认槽位过期时间(分钟),可通过配置覆盖
|
||||
defaultSlotTTLMinutes = 15
|
||||
)
|
||||
|
||||
var (
|
||||
// acquireScript uses SCAN to count existing slots and creates new slot if under limit
|
||||
// KEYS[1] = pattern for SCAN (e.g., "concurrency:account:2:*")
|
||||
// KEYS[2] = full slot key (e.g., "concurrency:account:2:req_xxx")
|
||||
// acquireScript 使用有序集合计数并在未达上限时添加槽位
|
||||
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步问题
|
||||
// KEYS[1] = 有序集合键 (concurrency:account:{id} / concurrency:user:{id})
|
||||
// ARGV[1] = maxConcurrency
|
||||
// ARGV[2] = TTL in seconds
|
||||
// ARGV[2] = TTL(秒)
|
||||
// ARGV[3] = requestID
|
||||
acquireScript = redis.NewScript(`
|
||||
local pattern = KEYS[1]
|
||||
local slotKey = KEYS[2]
|
||||
local key = KEYS[1]
|
||||
local maxConcurrency = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local requestID = ARGV[3]
|
||||
|
||||
-- Count existing slots using SCAN
|
||||
local cursor = "0"
|
||||
local count = 0
|
||||
repeat
|
||||
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
|
||||
cursor = result[1]
|
||||
count = count + #result[2]
|
||||
until cursor == "0"
|
||||
-- 使用 Redis 服务器时间,确保多实例时钟一致
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
|
||||
-- Check if we can acquire a slot
|
||||
-- 清理过期槽位
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
|
||||
-- 检查是否已存在(支持重试场景刷新时间戳)
|
||||
local exists = redis.call('ZSCORE', key, requestID)
|
||||
if exists ~= false then
|
||||
redis.call('ZADD', key, now, requestID)
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
return 1
|
||||
end
|
||||
|
||||
-- 检查是否达到并发上限
|
||||
local count = redis.call('ZCARD', key)
|
||||
if count < maxConcurrency then
|
||||
redis.call('SET', slotKey, '1', 'EX', ttl)
|
||||
redis.call('ZADD', key, now, requestID)
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
return 1
|
||||
end
|
||||
|
||||
return 0
|
||||
`)
|
||||
|
||||
// getCountScript counts slots using SCAN
|
||||
// KEYS[1] = pattern for SCAN
|
||||
// getCountScript 统计有序集合中的槽位数量并清理过期条目
|
||||
// 使用 Redis TIME 命令获取服务器时间
|
||||
// KEYS[1] = 有序集合键
|
||||
// ARGV[1] = TTL(秒)
|
||||
getCountScript = redis.NewScript(`
|
||||
local pattern = KEYS[1]
|
||||
local cursor = "0"
|
||||
local count = 0
|
||||
repeat
|
||||
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
|
||||
cursor = result[1]
|
||||
count = count + #result[2]
|
||||
until cursor == "0"
|
||||
return count
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
|
||||
-- 使用 Redis 服务器时间
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
return redis.call('ZCARD', key)
|
||||
`)
|
||||
|
||||
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
|
||||
@@ -103,28 +126,29 @@ var (
|
||||
)
|
||||
|
||||
type concurrencyCache struct {
|
||||
rdb *redis.Client
|
||||
rdb *redis.Client
|
||||
slotTTLSeconds int // 槽位过期时间(秒)
|
||||
}
|
||||
|
||||
func NewConcurrencyCache(rdb *redis.Client) service.ConcurrencyCache {
|
||||
return &concurrencyCache{rdb: rdb}
|
||||
// NewConcurrencyCache 创建并发控制缓存
|
||||
// slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟
|
||||
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.ConcurrencyCache {
|
||||
if slotTTLMinutes <= 0 {
|
||||
slotTTLMinutes = defaultSlotTTLMinutes
|
||||
}
|
||||
return &concurrencyCache{
|
||||
rdb: rdb,
|
||||
slotTTLSeconds: slotTTLMinutes * 60,
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for key generation
|
||||
func accountSlotKey(accountID int64, requestID string) string {
|
||||
return fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, requestID)
|
||||
func accountSlotKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
func accountSlotPattern(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
func userSlotKey(userID int64, requestID string) string {
|
||||
return fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, requestID)
|
||||
}
|
||||
|
||||
func userSlotPattern(userID int64) string {
|
||||
return fmt.Sprintf("%s%d:*", userSlotKeyPrefix, userID)
|
||||
func userSlotKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
}
|
||||
|
||||
func waitQueueKey(userID int64) string {
|
||||
@@ -134,10 +158,9 @@ func waitQueueKey(userID int64) string {
|
||||
// Account slot operations
|
||||
|
||||
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
pattern := accountSlotPattern(accountID)
|
||||
slotKey := accountSlotKey(accountID, requestID)
|
||||
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
|
||||
key := accountSlotKey(accountID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -145,13 +168,14 @@ func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
slotKey := accountSlotKey(accountID, requestID)
|
||||
return c.rdb.Del(ctx, slotKey).Err()
|
||||
key := accountSlotKey(accountID)
|
||||
return c.rdb.ZRem(ctx, key, requestID).Err()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
pattern := accountSlotPattern(accountID)
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
|
||||
key := accountSlotKey(accountID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -161,10 +185,9 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID
|
||||
// User slot operations
|
||||
|
||||
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
pattern := userSlotPattern(userID)
|
||||
slotKey := userSlotKey(userID, requestID)
|
||||
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
|
||||
key := userSlotKey(userID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -172,13 +195,14 @@ func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, ma
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||
slotKey := userSlotKey(userID, requestID)
|
||||
return c.rdb.Del(ctx, slotKey).Err()
|
||||
key := userSlotKey(userID)
|
||||
return c.rdb.ZRem(ctx, key, requestID).Err()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
pattern := userSlotPattern(userID)
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
|
||||
key := userSlotKey(userID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -189,7 +213,7 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
|
||||
|
||||
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
key := waitQueueKey(userID)
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(slotTTL.Seconds())).Int()
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
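The call pattern against the reworked cache is unchanged by the ZSET migration; a sketch of the acquire/release flow as callers use it (variable names and the sentinel error are assumptions, the interface methods are the ones shown above):

ok, err := cache.AcquireAccountSlot(ctx, accountID, maxConcurrency, requestID)
if err != nil {
	return err
}
if !ok {
	return ErrConcurrencyLimitReached // hypothetical sentinel; the handler layer retries with backoff instead
}
defer func() {
	_ = cache.ReleaseAccountSlot(ctx, accountID, requestID)
}()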
backend/internal/repository/concurrency_cache_benchmark_test.go (new file, 135 lines)
@@ -0,0 +1,135 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 基准测试用 TTL 配置
|
||||
const benchSlotTTLMinutes = 15
|
||||
|
||||
var benchSlotTTL = time.Duration(benchSlotTTLMinutes) * time.Minute
|
||||
|
||||
// BenchmarkAccountConcurrency 用于对比 SCAN 与有序集合的计数性能。
|
||||
func BenchmarkAccountConcurrency(b *testing.B) {
|
||||
rdb := newBenchmarkRedisClient(b)
|
||||
defer func() {
|
||||
_ = rdb.Close()
|
||||
}()
|
||||
|
||||
cache := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache)
|
||||
ctx := context.Background()
|
||||
|
||||
for _, size := range []int{10, 100, 1000} {
|
||||
size := size
|
||||
b.Run(fmt.Sprintf("zset/slots=%d", size), func(b *testing.B) {
|
||||
accountID := time.Now().UnixNano()
|
||||
key := accountSlotKey(accountID)
|
||||
|
||||
b.StopTimer()
|
||||
members := make([]redis.Z, 0, size)
|
||||
now := float64(time.Now().Unix())
|
||||
for i := 0; i < size; i++ {
|
||||
members = append(members, redis.Z{
|
||||
Score: now,
|
||||
Member: fmt.Sprintf("req_%d", i),
|
||||
})
|
||||
}
|
||||
if err := rdb.ZAdd(ctx, key, members...).Err(); err != nil {
|
||||
b.Fatalf("初始化有序集合失败: %v", err)
|
||||
}
|
||||
if err := rdb.Expire(ctx, key, benchSlotTTL).Err(); err != nil {
|
||||
b.Fatalf("设置有序集合 TTL 失败: %v", err)
|
||||
}
|
||||
b.StartTimer()
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := cache.GetAccountConcurrency(ctx, accountID); err != nil {
|
||||
b.Fatalf("获取并发数量失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
if err := rdb.Del(ctx, key).Err(); err != nil {
|
||||
b.Fatalf("清理有序集合失败: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("scan/slots=%d", size), func(b *testing.B) {
|
||||
accountID := time.Now().UnixNano()
|
||||
pattern := fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
|
||||
keys := make([]string, 0, size)
|
||||
|
||||
b.StopTimer()
|
||||
pipe := rdb.Pipeline()
|
||||
for i := 0; i < size; i++ {
|
||||
key := fmt.Sprintf("%s%d:req_%d", accountSlotKeyPrefix, accountID, i)
|
||||
keys = append(keys, key)
|
||||
pipe.Set(ctx, key, "1", benchSlotTTL)
|
||||
}
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
b.Fatalf("初始化扫描键失败: %v", err)
|
||||
}
|
||||
b.StartTimer()
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := scanSlotCount(ctx, rdb, pattern); err != nil {
|
||||
b.Fatalf("SCAN 计数失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
if err := rdb.Del(ctx, keys...).Err(); err != nil {
|
||||
b.Fatalf("清理扫描键失败: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func scanSlotCount(ctx context.Context, rdb *redis.Client, pattern string) (int, error) {
|
||||
var cursor uint64
|
||||
count := 0
|
||||
for {
|
||||
keys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, 100).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
count += len(keys)
|
||||
if nextCursor == 0 {
|
||||
break
|
||||
}
|
||||
cursor = nextCursor
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func newBenchmarkRedisClient(b *testing.B) *redis.Client {
|
||||
b.Helper()
|
||||
|
||||
redisURL := os.Getenv("TEST_REDIS_URL")
|
||||
if redisURL == "" {
|
||||
b.Skip("未设置 TEST_REDIS_URL,跳过 Redis 基准测试")
|
||||
}
|
||||
|
||||
opt, err := redis.ParseURL(redisURL)
|
||||
if err != nil {
|
||||
b.Fatalf("解析 TEST_REDIS_URL 失败: %v", err)
|
||||
}
|
||||
|
||||
client := redis.NewClient(opt)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
b.Fatalf("Redis 连接失败: %v", err)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
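These benchmarks are skipped unless TEST_REDIS_URL points at a reachable Redis instance; running them with something like go test -bench BenchmarkAccountConcurrency -run '^$' ./internal/repository gives a direct ZSET-versus-SCAN comparison at 10, 100, and 1000 occupied slots.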
@@ -14,6 +14,12 @@ import (
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// 测试用 TTL 配置(15 分钟,与默认值一致)
|
||||
const testSlotTTLMinutes = 15
|
||||
|
||||
// 测试用 TTL Duration,用于 TTL 断言
|
||||
var testSlotTTL = time.Duration(testSlotTTLMinutes) * time.Minute
|
||||
|
||||
type ConcurrencyCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.ConcurrencyCache
|
||||
@@ -21,7 +27,7 @@ type ConcurrencyCacheSuite struct {
|
||||
|
||||
func (s *ConcurrencyCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewConcurrencyCache(s.rdb)
|
||||
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
|
||||
@@ -54,7 +60,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
|
||||
accountID := int64(11)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, reqID)
|
||||
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot")
|
||||
@@ -62,7 +68,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
|
||||
@@ -139,7 +145,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
|
||||
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
|
||||
userID := int64(200)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, reqID)
|
||||
slotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
|
||||
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||
@@ -147,7 +153,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
|
||||
@@ -168,7 +174,7 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
|
||||
require.NoError(s.T(), err, "TTL waitKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||
|
||||
|
||||
@@ -109,9 +109,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
|
||||
}
|
||||
|
||||
func createGeminiReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().SetTimeout(60 * time.Second)
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
return client
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 60 * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -76,11 +76,10 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
|
||||
}
|
||||
|
||||
func createGeminiCliReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().SetTimeout(30 * time.Second)
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
return client
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
func defaultLoadCodeAssistRequest() *geminicli.LoadCodeAssistRequest {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
@@ -17,10 +18,14 @@ type githubReleaseClient struct {
|
||||
}
|
||||
|
||||
func NewGitHubReleaseClient() service.GitHubReleaseClient {
|
||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 30 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
return &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
httpClient: sharedClient,
|
||||
}
|
||||
}
|
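The shared client above comes from an internal httpclient package that is not part of this diff. A rough sketch of what its GetClient helper might look like, assuming it caches *http.Client instances keyed by their Options (the field names mirror the call sites in this commit; the real package likely also handles socks5 proxies, as the removed createProxyTransport did):

package httpclient

import (
	"crypto/tls"
	"fmt"
	"net/http"
	"net/url"
	"sync"
	"time"
)

// Options mirrors the fields used by callers in this commit; treat the exact
// shape as an assumption rather than the project's actual definition.
type Options struct {
	ProxyURL           string
	Timeout            time.Duration
	InsecureSkipVerify bool
	ProxyStrict        bool // fail on a bad proxy URL instead of silently going direct
}

var clients sync.Map // options key -> *http.Client

// GetClient returns a shared client for identical Options so that the
// underlying Transport connection pool is reused across callers.
func GetClient(opts Options) (*http.Client, error) {
	key := fmt.Sprintf("%s|%s|%t|%t", opts.ProxyURL, opts.Timeout, opts.InsecureSkipVerify, opts.ProxyStrict)
	if cached, ok := clients.Load(key); ok {
		return cached.(*http.Client), nil
	}

	transport := &http.Transport{
		TLSClientConfig: &tls.Config{InsecureSkipVerify: opts.InsecureSkipVerify},
	}
	if opts.ProxyURL != "" {
		parsed, err := url.Parse(opts.ProxyURL)
		if err != nil && opts.ProxyStrict {
			return nil, fmt.Errorf("invalid proxy URL: %w", err)
		}
		if err == nil {
			transport.Proxy = http.ProxyURL(parsed)
		}
	}

	client := &http.Client{Transport: transport, Timeout: opts.Timeout}
	actual, _ := clients.LoadOrStore(key, client)
	return actual.(*http.Client), nil
}

Callers then fall back to a plain &http.Client{Timeout: ...} when GetClient returns an error, exactly as the hunks in this commit do.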
||||
|
||||
@@ -58,8 +63,13 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
||||
return err
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Minute}
|
||||
resp, err := client.Do(req)
|
||||
downloadClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 10 * time.Minute,
|
||||
})
|
||||
if err != nil {
|
||||
downloadClient = &http.Client{Timeout: 10 * time.Minute}
|
||||
}
|
||||
resp, err := downloadClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,65 +3,104 @@ package repository
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// httpUpstreamService is a generic HTTP upstream service that can be used for
|
||||
// making requests to any HTTP API (Claude, OpenAI, etc.) with optional proxy support.
|
||||
// httpUpstreamService 通用 HTTP 上游服务
|
||||
// 用于向任意 HTTP API(Claude、OpenAI 等)发送请求,支持可选代理
|
||||
//
|
||||
// 性能优化:
|
||||
// 1. 使用 sync.Map 缓存代理客户端实例,避免每次请求都创建新的 http.Client
|
||||
// 2. 复用 Transport 连接池,减少 TCP 握手和 TLS 协商开销
|
||||
// 3. 原实现每次请求都 new 一个 http.Client,导致连接无法复用
|
||||
type httpUpstreamService struct {
|
||||
// defaultClient: 无代理时使用的默认客户端(单例复用)
|
||||
defaultClient *http.Client
|
||||
cfg *config.Config
|
||||
// proxyClients: 按代理 URL 缓存的客户端池,避免重复创建
|
||||
proxyClients sync.Map
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewHTTPUpstream creates a new generic HTTP upstream service
|
||||
// NewHTTPUpstream 创建通用 HTTP 上游服务
|
||||
// 使用配置中的连接池参数构建 Transport
|
||||
func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
|
||||
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
responseHeaderTimeout = 300 * time.Second
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||
}
|
||||
|
||||
return &httpUpstreamService{
|
||||
defaultClient: &http.Client{Transport: transport},
|
||||
defaultClient: &http.Client{Transport: buildUpstreamTransport(cfg, nil)},
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
|
||||
if proxyURL == "" {
|
||||
if strings.TrimSpace(proxyURL) == "" {
|
||||
return s.defaultClient.Do(req)
|
||||
}
|
||||
client := s.createProxyClient(proxyURL)
|
||||
client := s.getOrCreateClient(proxyURL)
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) createProxyClient(proxyURL string) *http.Client {
|
||||
// getOrCreateClient 获取或创建代理客户端
|
||||
// 性能优化:使用 sync.Map 实现无锁缓存,相同代理 URL 复用同一客户端
|
||||
// LoadOrStore 保证并发安全,避免重复创建
|
||||
func (s *httpUpstreamService) getOrCreateClient(proxyURL string) *http.Client {
|
||||
proxyURL = strings.TrimSpace(proxyURL)
|
||||
if proxyURL == "" {
|
||||
return s.defaultClient
|
||||
}
|
||||
// 优先从缓存获取,命中则直接返回
|
||||
if cached, ok := s.proxyClients.Load(proxyURL); ok {
|
||||
return cached.(*http.Client)
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return s.defaultClient
|
||||
}
|
||||
|
||||
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
// 创建新客户端并缓存,LoadOrStore 保证只有一个实例被存储
|
||||
client := &http.Client{Transport: buildUpstreamTransport(s.cfg, parsedURL)}
|
||||
actual, _ := s.proxyClients.LoadOrStore(proxyURL, client)
|
||||
return actual.(*http.Client)
|
||||
}
|
||||
|
||||
// buildUpstreamTransport 构建上游请求的 Transport
|
||||
// 使用配置文件中的连接池参数,支持生产环境调优
|
||||
func buildUpstreamTransport(cfg *config.Config, proxyURL *url.URL) *http.Transport {
|
||||
// 读取配置,使用合理的默认值
|
||||
maxIdleConns := cfg.Gateway.MaxIdleConns
|
||||
if maxIdleConns <= 0 {
|
||||
maxIdleConns = 240
|
||||
}
|
||||
maxIdleConnsPerHost := cfg.Gateway.MaxIdleConnsPerHost
|
||||
if maxIdleConnsPerHost <= 0 {
|
||||
maxIdleConnsPerHost = 120
|
||||
}
|
||||
maxConnsPerHost := cfg.Gateway.MaxConnsPerHost
|
||||
if maxConnsPerHost < 0 {
|
||||
maxConnsPerHost = 240
|
||||
}
|
||||
idleConnTimeout := time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second
|
||||
if idleConnTimeout <= 0 {
|
||||
idleConnTimeout = 300 * time.Second
|
||||
}
|
||||
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout <= 0 {
|
||||
responseHeaderTimeout = 300 * time.Second
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyURL(parsedURL),
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
MaxIdleConns: maxIdleConns, // 最大空闲连接总数
|
||||
MaxIdleConnsPerHost: maxIdleConnsPerHost, // 每主机最大空闲连接
|
||||
MaxConnsPerHost: maxConnsPerHost, // 每主机最大连接数(含活跃)
|
||||
IdleConnTimeout: idleConnTimeout, // 空闲连接超时
|
||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||
}
|
||||
|
||||
return &http.Client{Transport: transport}
|
||||
if proxyURL != nil {
|
||||
transport.Proxy = http.ProxyURL(proxyURL)
|
||||
}
|
||||
return transport
|
||||
}
|
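A short usage sketch of the pooled upstream service (the target URL and proxy address are placeholders): repeated calls with the same proxy URL hit the cached client, so the Transport's connection pool is actually reused.

// Illustrative only; not part of the commit.
func exampleUpstreamReuse(cfg *config.Config) error {
	upstream := NewHTTPUpstream(cfg)

	req, err := http.NewRequest(http.MethodGet, "https://example.com/v1/ping", nil)
	if err != nil {
		return err
	}

	// The first call for this proxy URL builds an *http.Client; later calls reuse it.
	resp, err := upstream.Do(req, "http://127.0.0.1:8080")
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	return nil
}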
||||
|
||||
backend/internal/repository/http_upstream_benchmark_test.go (new file, 46 lines)
@@ -0,0 +1,46 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
var httpClientSink *http.Client
|
||||
|
||||
// BenchmarkHTTPUpstreamProxyClient 对比重复创建与复用代理客户端的开销。
|
||||
func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{ResponseHeaderTimeout: 300},
|
||||
}
|
||||
upstream := NewHTTPUpstream(cfg)
|
||||
svc, ok := upstream.(*httpUpstreamService)
|
||||
if !ok {
|
||||
b.Fatalf("类型断言失败,无法获取 httpUpstreamService")
|
||||
}
|
||||
|
||||
proxyURL := "http://127.0.0.1:8080"
|
||||
b.ReportAllocs()
|
||||
|
||||
b.Run("新建", func(b *testing.B) {
|
||||
parsedProxy, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
b.Fatalf("解析代理地址失败: %v", err)
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
httpClientSink = &http.Client{
|
||||
Transport: buildUpstreamTransport(cfg, parsedProxy),
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("复用", func(b *testing.B) {
|
||||
client := svc.getOrCreateClient(proxyURL)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
httpClientSink = client
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -40,13 +40,13 @@ func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() {
|
||||
require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
|
||||
}
|
||||
|
||||
func (s *HTTPUpstreamSuite) TestCreateProxyClient_InvalidURLFallsBackToDefault() {
|
||||
func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDefault() {
|
||||
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 5}
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
svc, ok := up.(*httpUpstreamService)
|
||||
require.True(s.T(), ok, "expected *httpUpstreamService")
|
||||
|
||||
got := svc.createProxyClient("://bad-proxy-url")
|
||||
got := svc.getOrCreateClient("://bad-proxy-url")
|
||||
require.Equal(s.T(), svc.defaultClient, got, "expected defaultClient fallback")
|
||||
}
|
||||
|
||||
|
||||
@@ -82,12 +82,8 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
}
|
||||
|
||||
func createOpenAIReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().
|
||||
SetTimeout(60 * time.Second)
|
||||
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
|
||||
return client
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 60 * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
@@ -16,10 +17,14 @@ type pricingRemoteClient struct {
|
||||
}
|
||||
|
||||
func NewPricingRemoteClient() service.PricingRemoteClient {
|
||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 30 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
return &pricingRemoteClient{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
httpClient: sharedClient,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,18 +2,14 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
|
||||
@@ -27,14 +23,14 @@ type proxyProbeService struct {
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||
transport, err := createProxyTransport(proxyURL)
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 15 * time.Second,
|
||||
InsecureSkipVerify: true,
|
||||
ProxyStrict: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to create proxy transport: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 15 * time.Second,
|
||||
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
@@ -78,31 +74,3 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
Country: ipInfo.Country,
|
||||
}, latencyMs, nil
|
||||
}
|
||||
|
||||
func createProxyTransport(proxyURL string) (*http.Transport, error) {
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL: %w", err)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
switch parsedURL.Scheme {
|
||||
case "http", "https":
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
case "socks5":
|
||||
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
|
||||
}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
@@ -34,22 +34,16 @@ func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
|
||||
s.proxySrv = httptest.NewServer(handler)
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_InvalidURL() {
|
||||
_, err := createProxyTransport("://bad")
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() {
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, "://bad")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "invalid proxy URL")
|
||||
require.ErrorContains(s.T(), err, "failed to create proxy client")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_UnsupportedScheme() {
|
||||
_, err := createProxyTransport("ftp://127.0.0.1:1")
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() {
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, "ftp://127.0.0.1:1")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "unsupported proxy protocol")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_Socks5SetsDialer() {
|
||||
tr, err := createProxyTransport("socks5://127.0.0.1:1080")
|
||||
require.NoError(s.T(), err, "createProxyTransport")
|
||||
require.NotNil(s.T(), tr.DialContext, "expected DialContext to be set for socks5")
|
||||
require.ErrorContains(s.T(), err, "failed to create proxy client")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
|
||||
|
||||
backend/internal/repository/req_client_pool.go (new file, 59 lines)
@@ -0,0 +1,59 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
// reqClientOptions 定义 req 客户端的构建参数
|
||||
type reqClientOptions struct {
|
||||
ProxyURL string // 代理 URL(支持 http/https/socks5)
|
||||
Timeout time.Duration // 请求超时时间
|
||||
Impersonate bool // 是否模拟 Chrome 浏览器指纹
|
||||
}
|
||||
|
||||
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现在每次 OAuth 刷新时都创建新的 req.Client:
|
||||
// 1. claude_oauth_service.go: 每次刷新创建新客户端
|
||||
// 2. openai_oauth_service.go: 每次刷新创建新客户端
|
||||
// 3. gemini_oauth_client.go: 每次刷新创建新客户端
|
||||
//
|
||||
// 新实现使用 sync.Map 缓存客户端:
|
||||
// 1. 相同配置(代理+超时+模拟设置)复用同一客户端
|
||||
// 2. 复用底层连接池,减少 TLS 握手开销
|
||||
// 3. LoadOrStore 保证并发安全,避免重复创建
|
||||
var sharedReqClients sync.Map
|
||||
|
||||
// getSharedReqClient 获取共享的 req 客户端实例
|
||||
// 性能优化:相同配置复用同一客户端,避免重复创建
|
||||
func getSharedReqClient(opts reqClientOptions) *req.Client {
|
||||
key := buildReqClientKey(opts)
|
||||
if cached, ok := sharedReqClients.Load(key); ok {
|
||||
return cached.(*req.Client)
|
||||
}
|
||||
|
||||
client := req.C().SetTimeout(opts.Timeout)
|
||||
if opts.Impersonate {
|
||||
client = client.ImpersonateChrome()
|
||||
}
|
||||
if strings.TrimSpace(opts.ProxyURL) != "" {
|
||||
client.SetProxyURL(strings.TrimSpace(opts.ProxyURL))
|
||||
}
|
||||
|
||||
actual, _ := sharedReqClients.LoadOrStore(key, client)
|
||||
return actual.(*req.Client)
|
||||
}
|
||||
|
||||
func buildReqClientKey(opts reqClientOptions) string {
|
||||
return fmt.Sprintf("%s|%s|%t",
|
||||
strings.TrimSpace(opts.ProxyURL),
|
||||
opts.Timeout.String(),
|
||||
opts.Impersonate,
|
||||
)
|
||||
}
|
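A usage sketch (the proxy URL is a placeholder): callers that ask for the same proxy, timeout and impersonation settings get the exact same *req.Client back, which is what lets the OAuth refresh paths above share one connection pool.

// Illustrative only; identical options resolve to the same cached client.
func exampleSharedReqClient() {
	a := getSharedReqClient(reqClientOptions{ProxyURL: "http://127.0.0.1:8080", Timeout: 60 * time.Second})
	b := getSharedReqClient(reqClientOptions{ProxyURL: "http://127.0.0.1:8080", Timeout: 60 * time.Second})
	if a != b {
		panic("expected identical options to reuse one pooled client")
	}
}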
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
@@ -20,11 +21,15 @@ type turnstileVerifier struct {
|
||||
}
|
||||
|
||||
func NewTurnstileVerifier() service.TurnstileVerifier {
|
||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 10 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
sharedClient = &http.Client{Timeout: 10 * time.Second}
|
||||
}
|
||||
return &turnstileVerifier{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
verifyURL: turnstileVerifyURL,
|
||||
httpClient: sharedClient,
|
||||
verifyURL: turnstileVerifyURL,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -452,6 +452,161 @@ func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKe
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetAccountStatsAggregated 使用 SQL 聚合统计账号使用数据
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现先查询所有日志记录,再在应用层循环计算统计值:
|
||||
// 1. 需要传输大量数据到应用层
|
||||
// 2. 应用层循环计算增加 CPU 和内存开销
|
||||
//
|
||||
// 新实现使用 SQL 聚合函数:
|
||||
// 1. 在数据库层完成 COUNT/SUM/AVG 计算
|
||||
// 2. 只返回单行聚合结果,大幅减少数据传输量
|
||||
// 3. 利用数据库索引优化聚合查询性能
|
||||
func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
query := `
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
|
||||
`
|
||||
|
||||
var stats usagestats.UsageStats
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
[]any{accountID, startTime, endTime},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||
return &stats, nil
|
||||
}
|
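These aggregate queries benefit from composite indexes on the filter columns; a plausible shape for such an index migration, derived only from the WHERE clauses shown here (index names and the migration file are guesses, not the actual migration):

// Hypothetical migration SQL; the real index names and migration file may differ.
const addUsageLogsAggregationIndexes = `
CREATE INDEX IF NOT EXISTS idx_usage_logs_account_id_created_at ON usage_logs (account_id, created_at);
CREATE INDEX IF NOT EXISTS idx_usage_logs_model_created_at      ON usage_logs (model, created_at);
CREATE INDEX IF NOT EXISTS idx_usage_logs_user_id_created_at    ON usage_logs (user_id, created_at);
`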
||||
|
||||
// GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据
|
||||
// 性能优化:数据库层聚合计算,避免应用层循环统计
|
||||
func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
query := `
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
WHERE model = $1 AND created_at >= $2 AND created_at < $3
|
||||
`
|
||||
|
||||
var stats usagestats.UsageStats
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
[]any{modelName, startTime, endTime},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetDailyStatsAggregated 使用 SQL 聚合统计用户的每日使用数据
|
||||
// 性能优化:使用 GROUP BY 在数据库层按日期分组聚合,避免应用层循环分组统计
|
||||
func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) {
|
||||
query := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM-DD') as date,
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
|
||||
GROUP BY 1
|
||||
ORDER BY 1
|
||||
`
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
result = nil
|
||||
}
|
||||
}()
|
||||
|
||||
result = make([]map[string]any, 0)
|
||||
for rows.Next() {
|
||||
var (
|
||||
date string
|
||||
totalRequests int64
|
||||
totalInputTokens int64
|
||||
totalOutputTokens int64
|
||||
totalCacheTokens int64
|
||||
totalCost float64
|
||||
totalActualCost float64
|
||||
avgDurationMs float64
|
||||
)
|
||||
if err = rows.Scan(
|
||||
&date,
|
||||
&totalRequests,
|
||||
&totalInputTokens,
|
||||
&totalOutputTokens,
|
||||
&totalCacheTokens,
|
||||
&totalCost,
|
||||
&totalActualCost,
|
||||
&avgDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, map[string]any{
|
||||
"date": date,
|
||||
"total_requests": totalRequests,
|
||||
"total_input_tokens": totalInputTokens,
|
||||
"total_output_tokens": totalOutputTokens,
|
||||
"total_cache_tokens": totalCacheTokens,
|
||||
"total_tokens": totalInputTokens + totalOutputTokens + totalCacheTokens,
|
||||
"total_cost": totalCost,
|
||||
"total_actual_cost": totalActualCost,
|
||||
"average_duration_ms": avgDurationMs,
|
||||
})
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
|
||||
|
||||
@@ -1,9 +1,18 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数
|
||||
// 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景
|
||||
func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache {
|
||||
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes)
|
||||
}
|
||||
|
||||
// ProviderSet is the Wire provider set for all repositories
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewUserRepository,
|
||||
@@ -20,7 +29,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewGatewayCache,
|
||||
NewBillingCache,
|
||||
NewApiKeyCache,
|
||||
NewConcurrencyCache,
|
||||
ProvideConcurrencyCache,
|
||||
NewEmailCache,
|
||||
NewIdentityCache,
|
||||
NewRedeemCache,
|
||||
|
||||
@@ -981,6 +981,18 @@ func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyI
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
backend/internal/server/middleware/request_body_limit.go (new file, 15 lines)
@@ -0,0 +1,15 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RequestBodyLimit 使用 MaxBytesReader 限制请求体大小。
|
||||
func RequestBodyLimit(maxBytes int64) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes)
|
||||
c.Next()
|
||||
}
|
||||
}
|
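MaxBytesReader only arms the limit; the error surfaces when a handler reads the body, and the 413 mapping happens there. A minimal sketch of that mapping (the handler shape and error payload are illustrative, and it assumes errors and io are imported):

// Illustrative helper: once the limit set by RequestBodyLimit is exceeded,
// reading the body yields *http.MaxBytesError, which maps to 413.
func readLimitedBody(c *gin.Context) ([]byte, bool) {
	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		var maxBytesErr *http.MaxBytesError
		if errors.As(err, &maxBytesErr) {
			c.AbortWithStatusJSON(http.StatusRequestEntityTooLarge, gin.H{
				"error": gin.H{"type": "request_too_large", "message": "request body exceeds the configured limit"},
			})
			return nil, false
		}
		c.AbortWithStatus(http.StatusBadRequest)
		return nil, false
	}
	return body, true
}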
||||
@@ -18,8 +18,11 @@ func RegisterGatewayRoutes(
|
||||
subscriptionService *service.SubscriptionService,
|
||||
cfg *config.Config,
|
||||
) {
|
||||
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
|
||||
|
||||
// API网关(Claude API兼容)
|
||||
gateway := r.Group("/v1")
|
||||
gateway.Use(bodyLimit)
|
||||
gateway.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
{
|
||||
gateway.POST("/messages", h.Gateway.Messages)
|
||||
@@ -32,6 +35,7 @@ func RegisterGatewayRoutes(
|
||||
|
||||
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
||||
gemini := r.Group("/v1beta")
|
||||
gemini.Use(bodyLimit)
|
||||
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
{
|
||||
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
|
||||
@@ -41,10 +45,11 @@ func RegisterGatewayRoutes(
|
||||
}
|
||||
|
||||
// OpenAI Responses API(不带v1前缀的别名)
|
||||
r.POST("/responses", gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
|
||||
r.POST("/responses", bodyLimit, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
|
||||
|
||||
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
|
||||
antigravityV1 := r.Group("/antigravity/v1")
|
||||
antigravityV1.Use(bodyLimit)
|
||||
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
{
|
||||
@@ -55,6 +60,7 @@ func RegisterGatewayRoutes(
|
||||
}
|
||||
|
||||
antigravityV1Beta := r.Group("/antigravity/v1beta")
|
||||
antigravityV1Beta.Use(bodyLimit)
|
||||
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
{
|
||||
|
||||
@@ -52,6 +52,9 @@ type UsageLogRepository interface {
|
||||
// Aggregated stats (optimized)
|
||||
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
|
||||
}
|
||||
|
||||
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -27,6 +28,45 @@ type subscriptionCacheData struct {
|
||||
Version int64
|
||||
}
|
||||
|
||||
// 缓存写入任务类型
|
||||
type cacheWriteKind int
|
||||
|
||||
const (
|
||||
cacheWriteSetBalance cacheWriteKind = iota
|
||||
cacheWriteSetSubscription
|
||||
cacheWriteUpdateSubscriptionUsage
|
||||
cacheWriteDeductBalance
|
||||
)
|
||||
|
||||
// 异步缓存写入工作池配置
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现在请求热路径中使用 goroutine 异步更新缓存,存在以下问题:
|
||||
// 1. 每次请求创建新 goroutine,高并发下产生大量短生命周期 goroutine
|
||||
// 2. 无法控制并发数量,可能导致 Redis 连接耗尽
|
||||
// 3. goroutine 创建/销毁带来额外开销
|
||||
//
|
||||
// 新实现使用固定大小的工作池:
|
||||
// 1. 预创建 10 个 worker goroutine,避免频繁创建销毁
|
||||
// 2. 使用带缓冲的 channel(1000)作为任务队列,平滑写入峰值
|
||||
// 3. 非阻塞写入,队列满时丢弃任务(缓存最终一致性可接受)
|
||||
// 4. 统一超时控制,避免慢操作阻塞工作池
|
||||
const (
|
||||
cacheWriteWorkerCount = 10 // 工作协程数量
|
||||
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
|
||||
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
|
||||
)
|
||||
|
||||
// cacheWriteTask 缓存写入任务
|
||||
type cacheWriteTask struct {
|
||||
kind cacheWriteKind
|
||||
userID int64
|
||||
groupID int64
|
||||
balance float64
|
||||
amount float64
|
||||
subscriptionData *subscriptionCacheData
|
||||
}
|
||||
|
||||
// BillingCacheService 计费缓存服务
|
||||
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
|
||||
type BillingCacheService struct {
|
||||
@@ -34,16 +74,81 @@ type BillingCacheService struct {
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
cfg *config.Config
|
||||
|
||||
cacheWriteChan chan cacheWriteTask
|
||||
cacheWriteWg sync.WaitGroup
|
||||
cacheWriteStopOnce sync.Once
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService {
|
||||
return &BillingCacheService{
|
||||
svc := &BillingCacheService{
|
||||
cache: cache,
|
||||
userRepo: userRepo,
|
||||
subRepo: subRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.startCacheWriteWorkers()
|
||||
return svc
|
||||
}
|
||||
|
||||
// Stop 关闭缓存写入工作池
|
||||
func (s *BillingCacheService) Stop() {
|
||||
s.cacheWriteStopOnce.Do(func() {
|
||||
if s.cacheWriteChan == nil {
|
||||
return
|
||||
}
|
||||
close(s.cacheWriteChan)
|
||||
s.cacheWriteWg.Wait()
|
||||
s.cacheWriteChan = nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) startCacheWriteWorkers() {
|
||||
s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize)
|
||||
for i := 0; i < cacheWriteWorkerCount; i++ {
|
||||
s.cacheWriteWg.Add(1)
|
||||
go s.cacheWriteWorker()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) {
|
||||
if s.cacheWriteChan == nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = recover()
|
||||
}()
|
||||
select {
|
||||
case s.cacheWriteChan <- task:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) cacheWriteWorker() {
|
||||
defer s.cacheWriteWg.Done()
|
||||
for task := range s.cacheWriteChan {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
switch task.kind {
|
||||
case cacheWriteSetBalance:
|
||||
s.setBalanceCache(ctx, task.userID, task.balance)
|
||||
case cacheWriteSetSubscription:
|
||||
s.setSubscriptionCache(ctx, task.userID, task.groupID, task.subscriptionData)
|
||||
case cacheWriteUpdateSubscriptionUsage:
|
||||
if s.cache != nil {
|
||||
if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil {
|
||||
log.Printf("Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err)
|
||||
}
|
||||
}
|
||||
case cacheWriteDeductBalance:
|
||||
if s.cache != nil {
|
||||
if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", task.userID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================
|
||||
@@ -70,11 +175,11 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
s.setBalanceCache(cacheCtx, userID, balance)
|
||||
}()
|
||||
s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteSetBalance,
|
||||
userID: userID,
|
||||
balance: balance,
|
||||
})
|
||||
|
||||
return balance, nil
|
||||
}
|
||||
@@ -98,7 +203,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64,
|
||||
}
|
||||
}
|
||||
|
||||
// DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存)
|
||||
// DeductBalanceCache 扣减余额缓存(同步调用)
|
||||
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
@@ -106,6 +211,15 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int
|
||||
return s.cache.DeductUserBalance(ctx, userID, amount)
|
||||
}
|
||||
|
||||
// QueueDeductBalance 异步扣减余额缓存
|
||||
func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
|
||||
s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteDeductBalance,
|
||||
userID: userID,
|
||||
amount: amount,
|
||||
})
|
||||
}
|
||||
|
||||
// InvalidateUserBalance 失效用户余额缓存
|
||||
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
if s.cache == nil {
|
||||
@@ -141,11 +255,12 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
s.setSubscriptionCache(cacheCtx, userID, groupID, data)
|
||||
}()
|
||||
s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteSetSubscription,
|
||||
userID: userID,
|
||||
groupID: groupID,
|
||||
subscriptionData: data,
|
||||
})
|
||||
|
||||
return data, nil
|
||||
}
|
||||
@@ -199,7 +314,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存)
|
||||
// UpdateSubscriptionUsage 更新订阅用量缓存(同步调用)
|
||||
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
@@ -207,6 +322,16 @@ func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userI
|
||||
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
|
||||
}
|
||||
|
||||
// QueueUpdateSubscriptionUsage 异步更新订阅用量缓存
|
||||
func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) {
|
||||
s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteUpdateSubscriptionUsage,
|
||||
userID: userID,
|
||||
groupID: groupID,
|
||||
amount: costUSD,
|
||||
})
|
||||
}
|
||||
|
||||
// InvalidateSubscription 失效指定订阅缓存
|
||||
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
|
||||
if s.cache == nil {
|
||||
|
||||
backend/internal/service/billing_cache_service_test.go (new file, 75 lines)
@@ -0,0 +1,75 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type billingCacheWorkerStub struct {
|
||||
balanceUpdates int64
|
||||
subscriptionUpdates int64
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
atomic.AddInt64(&b.balanceUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
atomic.AddInt64(&b.balanceUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
|
||||
atomic.AddInt64(&b.subscriptionUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
atomic.AddInt64(&b.subscriptionUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
||||
cache := &billingCacheWorkerStub{}
|
||||
svc := NewBillingCacheService(cache, nil, nil, &config.Config{})
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
start := time.Now()
|
||||
for i := 0; i < cacheWriteBufferSize*2; i++ {
|
||||
svc.QueueDeductBalance(1, 1)
|
||||
}
|
||||
require.Less(t, time.Since(start), 2*time.Second)
|
||||
|
||||
svc.QueueUpdateSubscriptionUsage(1, 2, 1.5)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return atomic.LoadInt64(&cache.balanceUpdates) > 0
|
||||
}, 2*time.Second, 10*time.Millisecond)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return atomic.LoadInt64(&cache.subscriptionUpdates) > 0
|
||||
}, 2*time.Second, 10*time.Millisecond)
|
||||
}
|
||||
@@ -9,22 +9,22 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConcurrencyCache defines cache operations for concurrency service
|
||||
// Uses independent keys per request slot with native Redis TTL for automatic cleanup
|
||||
// ConcurrencyCache 定义并发控制的缓存接口
|
||||
// 使用有序集合存储槽位,按时间戳清理过期条目
|
||||
type ConcurrencyCache interface {
|
||||
// Account slot management - each slot is a separate key with independent TTL
|
||||
// Key format: concurrency:account:{accountID}:{requestID}
|
||||
// 账号槽位管理
|
||||
// 键格式: concurrency:account:{accountID}(有序集合,成员为 requestID)
|
||||
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
|
||||
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
// User slot management - each slot is a separate key with independent TTL
|
||||
// Key format: concurrency:user:{userID}:{requestID}
|
||||
// 用户槽位管理
|
||||
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID)
|
||||
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
|
||||
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
|
||||
|
||||
// Wait queue - uses counter with TTL set only on creation
|
||||
// 等待队列计数(只在首次创建时设置 TTL)
|
||||
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
||||
DecrementWaitCount(ctx context.Context, userID int64) error
|
||||
}
|
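The new interface comments describe slots stored in a sorted set and trimmed by timestamp; an atomic acquire over such a structure is typically done with a small Lua script. A minimal sketch using go-redis (key layout, argument order and TTL handling are assumptions, not the project's actual script): expired members are trimmed by score, capacity is checked with ZCARD, and the slot is added atomically.

// Assumed-shape acquire script.
// KEYS[1] = concurrency:account:{accountID} or concurrency:user:{userID}
// ARGV[1] = now (unix ms), ARGV[2] = slot TTL (ms), ARGV[3] = max concurrency, ARGV[4] = requestID
var acquireSlotScript = redis.NewScript(`
redis.call('ZREMRANGEBYSCORE', KEYS[1], '-inf', ARGV[1] - ARGV[2])
if redis.call('ZCARD', KEYS[1]) >= tonumber(ARGV[3]) then
  return 0
end
redis.call('ZADD', KEYS[1], ARGV[1], ARGV[4])
redis.call('PEXPIRE', KEYS[1], ARGV[2])
return 1
`)

// Invocation sketch:
// acquired, err := acquireSlotScript.Run(ctx, rdb, []string{key}, nowMs, ttlMs, maxConcurrency, requestID).Int()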
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
)
|
||||
|
||||
type CRSSyncService struct {
|
||||
@@ -193,7 +195,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
return nil, errors.New("username and password are required")
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 20 * time.Second}
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 20 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 20 * time.Second}
|
||||
}
|
||||
|
||||
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
|
||||
if err != nil {
|
||||
|
||||
backend/internal/service/gateway_request.go (new file, 70 lines)
@@ -0,0 +1,70 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ParsedRequest 保存网关请求的预解析结果
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现在多个位置重复解析请求体(Handler、Service 各解析一次):
|
||||
// 1. gateway_handler.go 解析获取 model 和 stream
|
||||
// 2. gateway_service.go 再次解析获取 system、messages、metadata
|
||||
// 3. GenerateSessionHash 又一次解析获取会话哈希所需字段
|
||||
//
|
||||
// 新实现一次解析,多处复用:
|
||||
// 1. 在 Handler 层统一调用 ParseGatewayRequest 一次性解析
|
||||
// 2. 将解析结果 ParsedRequest 传递给 Service 层
|
||||
// 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销
|
||||
type ParsedRequest struct {
|
||||
Body []byte // 原始请求体(保留用于转发)
|
||||
Model string // 请求的模型名称
|
||||
Stream bool // 是否为流式请求
|
||||
MetadataUserID string // metadata.user_id(用于会话亲和)
|
||||
System any // system 字段内容
|
||||
Messages []any // messages 数组
|
||||
HasSystem bool // 是否包含 system 字段
|
||||
}
|
||||
|
||||
// ParseGatewayRequest 解析网关请求体并返回结构化结果
|
||||
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal
|
||||
func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsed := &ParsedRequest{
|
||||
Body: body,
|
||||
}
|
||||
|
||||
if rawModel, exists := req["model"]; exists {
|
||||
model, ok := rawModel.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid model field type")
|
||||
}
|
||||
parsed.Model = model
|
||||
}
|
||||
if rawStream, exists := req["stream"]; exists {
|
||||
stream, ok := rawStream.(bool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid stream field type")
|
||||
}
|
||||
parsed.Stream = stream
|
||||
}
|
||||
if metadata, ok := req["metadata"].(map[string]any); ok {
|
||||
if userID, ok := metadata["user_id"].(string); ok {
|
||||
parsed.MetadataUserID = userID
|
||||
}
|
||||
}
|
||||
if system, ok := req["system"]; ok && system != nil {
|
||||
parsed.HasSystem = true
|
||||
parsed.System = system
|
||||
}
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
parsed.Messages = messages
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}
|
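A handler-side sketch of the parse-once flow (wiring and error responses are illustrative, and it assumes gin and net/http are imported; the real handler also handles account scheduling and failover): the body is unmarshalled once, and the same ParsedRequest feeds both session-affinity hashing and forwarding.

// Illustrative only.
func handleMessages(c *gin.Context, svc *GatewayService, account *Account, body []byte) {
	parsed, err := ParseGatewayRequest(body)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
		return
	}

	sessionHash := svc.GenerateSessionHash(parsed) // sticky-session key for account selection
	_ = sessionHash

	if _, err := svc.Forward(c.Request.Context(), c, account, parsed); err != nil {
		c.AbortWithStatusJSON(http.StatusBadGateway, gin.H{"error": "upstream request failed"})
	}
}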
||||
backend/internal/service/gateway_request_test.go (new file, 38 lines)
@@ -0,0 +1,38 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseGatewayRequest(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "claude-3-7-sonnet", parsed.Model)
|
||||
require.True(t, parsed.Stream)
|
||||
require.Equal(t, "session_123e4567-e89b-12d3-a456-426614174000", parsed.MetadataUserID)
|
||||
require.True(t, parsed.HasSystem)
|
||||
require.NotNil(t, parsed.System)
|
||||
require.Len(t, parsed.Messages, 1)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_SystemNull(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3","system":null}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
require.NoError(t, err)
|
||||
require.False(t, parsed.HasSystem)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
|
||||
body := []byte(`{"model":123}`)
|
||||
_, err := ParseGatewayRequest(body)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
|
||||
body := []byte(`{"stream":"true"}`)
|
||||
_, err := ParseGatewayRequest(body)
|
||||
require.Error(t, err)
|
||||
}
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -33,7 +32,10 @@ const (
|
||||
|
||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
var (
|
||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||
)
|
||||
|
||||
// allowedHeaders 白名单headers(参考CRS项目)
|
||||
var allowedHeaders = map[string]bool{
|
||||
@@ -141,40 +143,36 @@ func NewGatewayService(
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSessionHash 从请求体计算粘性会话hash
|
||||
func (s *GatewayService) GenerateSessionHash(body []byte) string {
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
// GenerateSessionHash 从预解析请求计算粘性会话 hash
|
||||
func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
if parsed == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 1. 最高优先级:从metadata.user_id提取session_xxx
|
||||
if metadata, ok := req["metadata"].(map[string]any); ok {
|
||||
if userID, ok := metadata["user_id"].(string); ok {
|
||||
re := regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||
if match := re.FindStringSubmatch(userID); len(match) > 1 {
|
||||
return match[1]
|
||||
}
|
||||
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx
|
||||
if parsed.MetadataUserID != "" {
|
||||
if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 {
|
||||
return match[1]
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 提取带cache_control: {type: "ephemeral"}的内容
|
||||
cacheableContent := s.extractCacheableContent(req)
|
||||
// 2. 提取带 cache_control: {type: "ephemeral"} 的内容
|
||||
cacheableContent := s.extractCacheableContent(parsed)
|
||||
if cacheableContent != "" {
|
||||
return s.hashContent(cacheableContent)
|
||||
}
|
||||
|
||||
// 3. Fallback: 使用system内容
|
||||
if system := req["system"]; system != nil {
|
||||
systemText := s.extractTextFromSystem(system)
|
||||
// 3. Fallback: 使用 system 内容
|
||||
if parsed.System != nil {
|
||||
systemText := s.extractTextFromSystem(parsed.System)
|
||||
if systemText != "" {
|
||||
return s.hashContent(systemText)
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 最后fallback: 使用第一条消息
|
||||
if messages, ok := req["messages"].([]any); ok && len(messages) > 0 {
|
||||
if firstMsg, ok := messages[0].(map[string]any); ok {
|
||||
// 4. 最后 fallback: 使用第一条消息
|
||||
if len(parsed.Messages) > 0 {
|
||||
if firstMsg, ok := parsed.Messages[0].(map[string]any); ok {
|
||||
msgText := s.extractTextFromContent(firstMsg["content"])
|
||||
if msgText != "" {
|
||||
return s.hashContent(msgText)
|
||||
@@ -185,36 +183,38 @@ func (s *GatewayService) GenerateSessionHash(body []byte) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractCacheableContent(req map[string]any) string {
|
||||
var content string
|
||||
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
||||
if parsed == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 检查system中的cacheable内容
|
||||
if system, ok := req["system"].([]any); ok {
|
||||
var builder strings.Builder
|
||||
|
||||
// 检查 system 中的 cacheable 内容
|
||||
if system, ok := parsed.System.([]any); ok {
|
||||
for _, part := range system {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if cc, ok := partMap["cache_control"].(map[string]any); ok {
|
||||
if cc["type"] == "ephemeral" {
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
content += text
|
||||
builder.WriteString(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
systemText := builder.String()
|
||||
|
||||
// 检查messages中的cacheable内容
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
for _, msg := range messages {
|
||||
if msgMap, ok := msg.(map[string]any); ok {
|
||||
if msgContent, ok := msgMap["content"].([]any); ok {
|
||||
for _, part := range msgContent {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if cc, ok := partMap["cache_control"].(map[string]any); ok {
|
||||
if cc["type"] == "ephemeral" {
|
||||
// 找到cacheable内容,提取第一条消息的文本
|
||||
return s.extractTextFromContent(msgMap["content"])
|
||||
}
|
||||
// 检查 messages 中的 cacheable 内容
|
||||
for _, msg := range parsed.Messages {
|
||||
if msgMap, ok := msg.(map[string]any); ok {
|
||||
if msgContent, ok := msgMap["content"].([]any); ok {
|
||||
for _, part := range msgContent {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if cc, ok := partMap["cache_control"].(map[string]any); ok {
|
||||
if cc["type"] == "ephemeral" {
|
||||
return s.extractTextFromContent(msgMap["content"])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -223,7 +223,7 @@ func (s *GatewayService) extractCacheableContent(req map[string]any) string {
|
||||
}
|
||||
}
|
||||
|
||||
return content
|
||||
return systemText
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractTextFromSystem(system any) string {
|
||||
@@ -588,19 +588,17 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||
}
|
||||
|
||||
// Forward 转发请求到Claude API
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 解析请求获取model和stream
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, fmt.Errorf("parse request: %w", err)
|
||||
if parsed == nil {
|
||||
return nil, fmt.Errorf("parse request: empty request")
|
||||
}
|
||||
|
||||
if !gjson.GetBytes(body, "system").Exists() {
|
||||
body := parsed.Body
|
||||
reqModel := parsed.Model
|
||||
reqStream := parsed.Stream
|
||||
|
||||
if !parsed.HasSystem {
|
||||
body, _ = sjson.SetBytes(body, "system", []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
@@ -613,13 +611,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
originalModel := req.Model
|
||||
originalModel := reqModel
|
||||
if account.Type == AccountTypeApiKey {
|
||||
mappedModel := account.GetMappedModel(req.Model)
|
||||
if mappedModel != req.Model {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
// 替换请求体中的模型名
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
req.Model = mappedModel
|
||||
reqModel = mappedModel
|
||||
log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
|
||||
}
|
||||
}
|
||||
@@ -640,7 +638,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
var resp *http.Response
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -692,8 +690,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// 处理正常响应
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
if req.Stream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, req.Model)
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
|
||||
if err != nil {
|
||||
if err.Error() == "have error in stream" {
|
||||
return nil, &UpstreamFailoverError{
|
||||
@@ -705,7 +703,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
usage = streamResult.usage
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, req.Model)
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -715,13 +713,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel, // 使用原始模型用于计费和日志
|
||||
Stream: req.Stream,
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
||||
// 确定目标URL
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == AccountTypeApiKey {
|
||||
@@ -787,7 +785,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
|
||||
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -795,7 +793,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
|
||||
// getBetaHeader 处理anthropic-beta header
|
||||
// 对于OAuth账号,需要确保包含oauth-2025-04-20
|
||||
func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) string {
|
||||
func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
|
||||
// 如果客户端传了anthropic-beta
|
||||
if clientBetaHeader != "" {
|
||||
// 已包含oauth beta则直接返回
|
||||
@@ -832,15 +830,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
|
||||
}
|
||||
|
||||
// 客户端没传,根据模型生成
|
||||
var modelID string
|
||||
var reqMap map[string]any
|
||||
if json.Unmarshal(body, &reqMap) == nil {
|
||||
if m, ok := reqMap["model"].(string); ok {
|
||||
modelID = m
|
||||
}
|
||||
}
|
||||
|
||||
// haiku模型不需要claude-code beta
|
||||
// haiku 模型不需要 claude-code beta
|
||||
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||
return claude.HaikuBetaHeader
|
||||
}
|
||||
@@ -1248,13 +1238,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
			log.Printf("Increment subscription usage failed: %v", err)
		}
		// Update the subscription cache asynchronously
		go func() {
			cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			if err := s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost); err != nil {
				log.Printf("Update subscription cache failed: %v", err)
			}
		}()
		s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
		}
	} else {
		// Balance mode: deduct the user's balance (ActualCost already includes the rate multiplier)
@@ -1263,13 +1247,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
			log.Printf("Deduct balance failed: %v", err)
		}
		// Update the balance cache asynchronously
		go func() {
			cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			if err := s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost); err != nil {
				log.Printf("Update balance cache failed: %v", err)
			}
		}()
		s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
	}
}

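Editor's note: the per-request `go func` cache writes are replaced by Queue* calls into the BillingCacheService worker pool (see also the new BillingCacheService cleanup hook). The pool's internals are not part of this excerpt; below is only a minimal sketch of the queue-plus-workers pattern, with all names, sizes, and the full-queue behavior being assumptions of the sketch rather than the project's actual implementation.

package main

import (
	"fmt"
	"sync"
)

// billingJob is a hypothetical unit of work for the cache worker pool.
type billingJob struct {
	userID int64
	amount float64
}

// billingCachePool: producers enqueue without spawning a goroutine per
// request, a fixed set of workers drains the channel, and Stop() waits for
// in-flight jobs before shutdown.
type billingCachePool struct {
	jobs chan billingJob
	wg   sync.WaitGroup
}

func newBillingCachePool(workers, queueSize int) *billingCachePool {
	p := &billingCachePool{jobs: make(chan billingJob, queueSize)}
	for i := 0; i < workers; i++ {
		p.wg.Add(1)
		go func() {
			defer p.wg.Done()
			for job := range p.jobs {
				// Stand-in for the real Redis/cache write.
				fmt.Printf("deduct %.4f from user %d cache\n", job.amount, job.userID)
			}
		}()
	}
	return p
}

// QueueDeductBalance never blocks the request path; in this sketch a full
// queue simply drops the job (the durable DB write has already happened).
func (p *billingCachePool) QueueDeductBalance(userID int64, amount float64) {
	select {
	case p.jobs <- billingJob{userID: userID, amount: amount}:
	default:
	}
}

func (p *billingCachePool) Stop() {
	close(p.jobs)
	p.wg.Wait()
}

func main() {
	pool := newBillingCachePool(4, 1024)
	pool.QueueDeductBalance(42, 0.0031)
	pool.Stop()
}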
@@ -1281,7 +1259,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu

// ForwardCountTokens forwards count_tokens requests to the upstream API
// Notes: records no usage and only supports non-streaming responses
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error {
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
	if parsed == nil {
		s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
		return fmt.Errorf("parse request: empty request")
	}

	body := parsed.Body
	reqModel := parsed.Model

	// Antigravity accounts do not support count_tokens forwarding; return an estimate instead
	// See the Antigravity-Manager and proxycast implementations
	if account.Platform == PlatformAntigravity {
@@ -1291,14 +1277,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,

	// Apply model mapping (apikey-type accounts only)
	if account.Type == AccountTypeApiKey {
		var req struct {
			Model string `json:"model"`
		}
		if err := json.Unmarshal(body, &req); err == nil && req.Model != "" {
			mappedModel := account.GetMappedModel(req.Model)
			if mappedModel != req.Model {
		if reqModel != "" {
			mappedModel := account.GetMappedModel(reqModel)
			if mappedModel != reqModel {
				body = s.replaceModelInBody(body, mappedModel)
				log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", req.Model, mappedModel, account.Name)
				reqModel = mappedModel
				log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
			}
		}
	}
@@ -1311,7 +1295,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
	}

	// Build the upstream request
	upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
	upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel)
	if err != nil {
		s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
		return err
@@ -1363,7 +1347,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}

// buildCountTokensRequest builds the upstream count_tokens request
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
	// Determine the target URL
	targetURL := claudeAPICountTokensURL
	if account.Type == AccountTypeApiKey {
@@ -1424,7 +1408,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con

	// OAuth accounts: handle the anthropic-beta header
	if tokenType == "oauth" {
		req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
		req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
	}

	return req, nil

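Editor's note: model mapping still rewrites the raw JSON body via replaceModelInBody. That helper is not part of this diff, so the following is only a plausible sketch of such a rewrite; the function name is the only thing taken from the hunks above, and note that re-marshaling a map does not preserve key order.

package main

import (
	"encoding/json"
	"fmt"
)

// replaceModelInBody (hypothetical sketch): decode the JSON body, swap the
// "model" field, and re-encode. On any error the original body is returned
// unchanged so the request can still be forwarded upstream.
func replaceModelInBody(body []byte, model string) []byte {
	var m map[string]any
	if err := json.Unmarshal(body, &m); err != nil {
		return body
	}
	m["model"] = model
	out, err := json.Marshal(m)
	if err != nil {
		return body
	}
	return out
}

func main() {
	body := []byte(`{"model":"claude-3-5-haiku","messages":[]}`)
	fmt.Println(string(replaceModelInBody(body, "claude-3-5-haiku-mapped")))
}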
backend/internal/service/gateway_service_benchmark_test.go (new file, 50 lines)
@@ -0,0 +1,50 @@
package service

import (
	"strconv"
	"testing"
)

var benchmarkStringSink string

// BenchmarkGenerateSessionHash_Metadata focuses on the cost of JSON parsing and regex matching.
func BenchmarkGenerateSessionHash_Metadata(b *testing.B) {
	svc := &GatewayService{}
	body := []byte(`{"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"messages":[{"content":"hello"}]}`)

	b.ReportAllocs()
	for i := 0; i < b.N; i++ {
		parsed, err := ParseGatewayRequest(body)
		if err != nil {
			b.Fatalf("parse request failed: %v", err)
		}
		benchmarkStringSink = svc.GenerateSessionHash(parsed)
	}
}

// BenchmarkExtractCacheableContent_System focuses on the string-concatenation path.
func BenchmarkExtractCacheableContent_System(b *testing.B) {
	svc := &GatewayService{}
	req := buildSystemCacheableRequest(12)

	b.ReportAllocs()
	for i := 0; i < b.N; i++ {
		benchmarkStringSink = svc.extractCacheableContent(req)
	}
}

func buildSystemCacheableRequest(parts int) *ParsedRequest {
	systemParts := make([]any, 0, parts)
	for i := 0; i < parts; i++ {
		systemParts = append(systemParts, map[string]any{
			"text": "system_part_" + strconv.Itoa(i),
			"cache_control": map[string]any{
				"type": "ephemeral",
			},
		})
	}
	return &ParsedRequest{
		System:    systemParts,
		HasSystem: true,
	}
}
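Editor's note: the package-level benchmarkStringSink keeps the compiler from optimizing away the calls under test. Assuming the standard Go toolchain and the backend/ layout mentioned in the commit message, these benchmarks should be runnable with something like `go test -bench=. -benchmem ./internal/service` from the backend/ directory.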
@@ -921,7 +921,10 @@ func sleepGeminiBackoff(attempt int) {
	time.Sleep(sleepFor)
}

var sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`)
var (
	sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`)
	retryInRegex             = regexp.MustCompile(`Please retry in ([0-9.]+)s`)
)

func sanitizeUpstreamErrorMessage(msg string) string {
	if msg == "" {
@@ -1925,7 +1928,6 @@ func ParseGeminiRateLimitResetTime(body []byte) *int64 {
	}

	// Match "Please retry in Xs"
	retryInRegex := regexp.MustCompile(`Please retry in ([0-9.]+)s`)
	matches := retryInRegex.FindStringSubmatch(string(body))
	if len(matches) == 2 {
		if dur, err := time.ParseDuration(matches[1] + "s"); err == nil {

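Editor's note: hoisting regexp.MustCompile to package level avoids recompiling the pattern on every rate-limit parse. A small self-contained illustration of the two styles follows; the pattern is copied from the hunk above, everything else (function names, the sample body) is illustrative only.

package main

import (
	"fmt"
	"regexp"
	"time"
)

// Compiled once at package init, shared by every call.
var retryInRegex = regexp.MustCompile(`Please retry in ([0-9.]+)s`)

// parseRetryDelayFast reuses the package-level regex.
func parseRetryDelayFast(body string) (time.Duration, bool) {
	m := retryInRegex.FindStringSubmatch(body)
	if len(m) != 2 {
		return 0, false
	}
	d, err := time.ParseDuration(m[1] + "s")
	return d, err == nil
}

// parseRetryDelaySlow recompiles the pattern on every call, which is what the
// old in-function MustCompile effectively did.
func parseRetryDelaySlow(body string) (time.Duration, bool) {
	re := regexp.MustCompile(`Please retry in ([0-9.]+)s`)
	m := re.FindStringSubmatch(body)
	if len(m) != 2 {
		return 0, false
	}
	d, err := time.ParseDuration(m[1] + "s")
	return d, err == nil
}

func main() {
	body := `{"error":{"message":"Please retry in 7.5s"}}`
	fast, _ := parseRetryDelayFast(body)
	slow, _ := parseRetryDelaySlow(body)
	fmt.Println(fast, slow) // same result; only the compilation cost differs
}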
@@ -7,13 +7,13 @@ import (
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
	"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
)

type GeminiOAuthService struct {
@@ -497,11 +497,12 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
	req.Header.Set("Authorization", "Bearer "+accessToken)
	req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)

	client := &http.Client{Timeout: 30 * time.Second}
	if strings.TrimSpace(proxyURL) != "" {
		if proxyURLParsed, err := url.Parse(strings.TrimSpace(proxyURL)); err == nil {
			client.Transport = &http.Transport{Proxy: http.ProxyURL(proxyURLParsed)}
		}
	client, err := httpclient.GetClient(httpclient.Options{
		ProxyURL: strings.TrimSpace(proxyURL),
		Timeout:  30 * time.Second,
	})
	if err != nil {
		client = &http.Client{Timeout: 30 * time.Second}
	}

	resp, err := client.Do(req)

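Editor's note: the hand-built per-call http.Client is replaced with httpclient.GetClient, the commit's shared client pool. Its internals are not shown in this excerpt, so the sketch below only illustrates the general pattern of caching one client per (proxy, timeout) key so transports and their upstream connection pools get reused; the package-level cache, the key format, and the Options field set are assumptions.

package main

import (
	"fmt"
	"net/http"
	"net/url"
	"sync"
	"time"
)

// Options mirrors the shape used at the call site above; the real package may
// carry more fields.
type Options struct {
	ProxyURL string
	Timeout  time.Duration
}

var (
	clientMu    sync.Mutex
	clientCache = map[string]*http.Client{}
)

// GetClient returns a cached *http.Client per options key, so repeated calls
// share one Transport (and therefore one connection pool) instead of building
// a fresh transport for every request.
func GetClient(opts Options) (*http.Client, error) {
	key := fmt.Sprintf("%s|%s", opts.ProxyURL, opts.Timeout)

	clientMu.Lock()
	defer clientMu.Unlock()
	if c, ok := clientCache[key]; ok {
		return c, nil
	}

	transport := http.DefaultTransport.(*http.Transport).Clone()
	if opts.ProxyURL != "" {
		proxy, err := url.Parse(opts.ProxyURL)
		if err != nil {
			return nil, err
		}
		transport.Proxy = http.ProxyURL(proxy)
	}

	c := &http.Client{Transport: transport, Timeout: opts.Timeout}
	clientCache[key] = c
	return c, nil
}

func main() {
	a, _ := GetClient(Options{Timeout: 30 * time.Second})
	b, _ := GetClient(Options{Timeout: 30 * time.Second})
	fmt.Println(a == b) // true: same pooled client
}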
@@ -768,20 +768,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
	if isSubscriptionBilling {
		if cost.TotalCost > 0 {
			_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
			go func() {
				cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
				defer cancel()
				_ = s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost)
			}()
			s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
		}
	} else {
		if cost.ActualCost > 0 {
			_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
			go func() {
				cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
				defer cancel()
				_ = s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost)
			}()
			s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
		}
	}

@@ -18,6 +18,11 @@ import (
	"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)

var (
	openAIModelDatePattern = regexp.MustCompile(`-\d{8}$`)
	openAIModelBasePattern = regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`)
)

// LiteLLMModelPricing is the LiteLLM pricing data structure
// Only the fields we need are kept; pointers handle values that may be missing
type LiteLLMModelPricing struct {
@@ -595,11 +600,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
// 2. gpt-5.2-20251222 -> gpt-5.2 (strip the date version suffix)
// 3. Finally fall back to DefaultTestModel (gpt-5.1-codex)
func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
	// Regex matching the date suffix (e.g. -20251222)
	datePattern := regexp.MustCompile(`-\d{8}$`)

	// Fallback variants to try
	variants := s.generateOpenAIModelVariants(model, datePattern)
	variants := s.generateOpenAIModelVariants(model, openAIModelDatePattern)

	for _, variant := range variants {
		if pricing, ok := s.pricingData[variant]; ok {
@@ -638,14 +640,13 @@ func (s *PricingService) generateOpenAIModelVariants(model string, datePattern *

	// 2. Extract the base version: gpt-5.2-codex -> gpt-5.2
	// Only match purely numeric versions gpt-X or gpt-X.Y, not letter-suffixed ones like gpt-4o
	basePattern := regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`)
	if matches := basePattern.FindStringSubmatch(model); len(matches) > 1 {
	if matches := openAIModelBasePattern.FindStringSubmatch(model); len(matches) > 1 {
		addVariant(matches[1])
	}

	// 3. Also extract the base version after stripping the date
	if withoutDate != model {
		if matches := basePattern.FindStringSubmatch(withoutDate); len(matches) > 1 {
		if matches := openAIModelBasePattern.FindStringSubmatch(withoutDate); len(matches) > 1 {
			addVariant(matches[1])
		}
	}

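Editor's note: the two patterns are now compiled once at package level. A standalone sketch of the fallback order they implement follows; only the regexes are taken from the diff, the surrounding function and names are illustrative.

package main

import (
	"fmt"
	"regexp"
)

var (
	openAIModelDatePattern = regexp.MustCompile(`-\d{8}$`)
	openAIModelBasePattern = regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`)
)

// modelVariants lists lookup candidates for an OpenAI model name: the name
// itself, the name with a trailing -YYYYMMDD date stripped, and the bare
// gpt-X / gpt-X.Y base version of both.
func modelVariants(model string) []string {
	seen := map[string]bool{}
	var out []string
	add := func(v string) {
		if v != "" && !seen[v] {
			seen[v] = true
			out = append(out, v)
		}
	}

	add(model)
	withoutDate := openAIModelDatePattern.ReplaceAllString(model, "")
	add(withoutDate)
	if m := openAIModelBasePattern.FindStringSubmatch(model); len(m) > 1 {
		add(m[1])
	}
	if m := openAIModelBasePattern.FindStringSubmatch(withoutDate); len(m) > 1 {
		add(m[1])
	}
	return out
}

func main() {
	fmt.Println(modelVariants("gpt-5.2-codex-20251222"))
	// prints: [gpt-5.2-codex-20251222 gpt-5.2-codex gpt-5.2]
}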
@@ -186,22 +186,40 @@ func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, sta

// GetStatsByAccount returns usage statistics for an account
func (s *UsageService) GetStatsByAccount(ctx context.Context, accountID int64, startTime, endTime time.Time) (*UsageStats, error) {
	logs, _, err := s.usageRepo.ListByAccountAndTimeRange(ctx, accountID, startTime, endTime)
	stats, err := s.usageRepo.GetAccountStatsAggregated(ctx, accountID, startTime, endTime)
	if err != nil {
		return nil, fmt.Errorf("list usage logs: %w", err)
		return nil, fmt.Errorf("get account stats: %w", err)
	}

	return s.calculateStats(logs), nil
	return &UsageStats{
		TotalRequests:     stats.TotalRequests,
		TotalInputTokens:  stats.TotalInputTokens,
		TotalOutputTokens: stats.TotalOutputTokens,
		TotalCacheTokens:  stats.TotalCacheTokens,
		TotalTokens:       stats.TotalTokens,
		TotalCost:         stats.TotalCost,
		TotalActualCost:   stats.TotalActualCost,
		AverageDurationMs: stats.AverageDurationMs,
	}, nil
}

// GetStatsByModel returns usage statistics for a model
func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, startTime, endTime time.Time) (*UsageStats, error) {
	logs, _, err := s.usageRepo.ListByModelAndTimeRange(ctx, modelName, startTime, endTime)
	stats, err := s.usageRepo.GetModelStatsAggregated(ctx, modelName, startTime, endTime)
	if err != nil {
		return nil, fmt.Errorf("list usage logs: %w", err)
		return nil, fmt.Errorf("get model stats: %w", err)
	}

	return s.calculateStats(logs), nil
	return &UsageStats{
		TotalRequests:     stats.TotalRequests,
		TotalInputTokens:  stats.TotalInputTokens,
		TotalOutputTokens: stats.TotalOutputTokens,
		TotalCacheTokens:  stats.TotalCacheTokens,
		TotalTokens:       stats.TotalTokens,
		TotalCost:         stats.TotalCost,
		TotalActualCost:   stats.TotalActualCost,
		AverageDurationMs: stats.AverageDurationMs,
	}, nil
}

// GetDailyStats returns daily usage statistics (last N days)
@@ -209,54 +227,12 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
	endTime := time.Now()
	startTime := endTime.AddDate(0, 0, -days)

	logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime)
	stats, err := s.usageRepo.GetDailyStatsAggregated(ctx, userID, startTime, endTime)
	if err != nil {
		return nil, fmt.Errorf("list usage logs: %w", err)
		return nil, fmt.Errorf("get daily stats: %w", err)
	}

	// Group statistics by date
	dailyStats := make(map[string]*UsageStats)
	for _, log := range logs {
		dateKey := log.CreatedAt.Format("2006-01-02")
		if _, exists := dailyStats[dateKey]; !exists {
			dailyStats[dateKey] = &UsageStats{}
		}

		stats := dailyStats[dateKey]
		stats.TotalRequests++
		stats.TotalInputTokens += int64(log.InputTokens)
		stats.TotalOutputTokens += int64(log.OutputTokens)
		stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
		stats.TotalTokens += int64(log.TotalTokens())
		stats.TotalCost += log.TotalCost
		stats.TotalActualCost += log.ActualCost

		if log.DurationMs != nil {
			stats.AverageDurationMs += float64(*log.DurationMs)
		}
	}

	// Compute averages and convert to a slice
	result := make([]map[string]any, 0, len(dailyStats))
	for date, stats := range dailyStats {
		if stats.TotalRequests > 0 {
			stats.AverageDurationMs /= float64(stats.TotalRequests)
		}

		result = append(result, map[string]any{
			"date":                date,
			"total_requests":      stats.TotalRequests,
			"total_input_tokens":  stats.TotalInputTokens,
			"total_output_tokens": stats.TotalOutputTokens,
			"total_cache_tokens":  stats.TotalCacheTokens,
			"total_tokens":        stats.TotalTokens,
			"total_cost":          stats.TotalCost,
			"total_actual_cost":   stats.TotalActualCost,
			"average_duration_ms": stats.AverageDurationMs,
		})
	}

	return result, nil
	return stats, nil
}

// calculateStats computes the statistics

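Editor's note: GetAccountStatsAggregated, GetModelStatsAggregated and GetDailyStatsAggregated push the aggregation into SQL instead of loading every log row into memory. Their implementation is not part of this excerpt; the following is only a rough sketch of what the account variant might run, assuming a Postgres-style driver, and with every column name other than account_id and created_at guessed from the struct fields above.

package repository

import (
	"context"
	"database/sql"
	"time"
)

// usageStatsRow mirrors the aggregated fields consumed by GetStatsByAccount.
type usageStatsRow struct {
	TotalRequests     int64
	TotalInputTokens  int64
	TotalOutputTokens int64
	TotalCost         float64
	AverageDurationMs float64
}

// getAccountStatsAggregated (hypothetical sketch): a single aggregate over the
// (account_id, created_at) range, which the composite index added below is
// meant to serve.
func getAccountStatsAggregated(ctx context.Context, db *sql.DB, accountID int64, start, end time.Time) (*usageStatsRow, error) {
	const q = `
		SELECT
			COUNT(*),
			COALESCE(SUM(input_tokens), 0),
			COALESCE(SUM(output_tokens), 0),
			COALESCE(SUM(total_cost), 0),
			COALESCE(AVG(duration_ms), 0)
		FROM usage_logs
		WHERE account_id = $1 AND created_at >= $2 AND created_at < $3`

	var row usageStatsRow
	err := db.QueryRowContext(ctx, q, accountID, start, end).Scan(
		&row.TotalRequests,
		&row.TotalInputTokens,
		&row.TotalOutputTokens,
		&row.TotalCost,
		&row.AverageDurationMs,
	)
	if err != nil {
		return nil, err
	}
	return &row, nil
}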
@@ -0,0 +1,4 @@
-- Add composite indexes to back the aggregated usage queries
CREATE INDEX IF NOT EXISTS idx_usage_logs_account_created_at ON usage_logs(account_id, created_at);
CREATE INDEX IF NOT EXISTS idx_usage_logs_api_key_created_at ON usage_logs(api_key_id, created_at);
CREATE INDEX IF NOT EXISTS idx_usage_logs_model_created_at ON usage_logs(model, created_at);
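Editor's note: each index leads with the equality column (account_id, api_key_id, model) and ends with created_at, so the time-range predicate in the aggregated queries can be answered with one index range scan per key, and `IF NOT EXISTS` keeps the migration safe to re-run.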