From 537678669461df0a814c70f563312e9b6097e2ab Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Tue, 30 Dec 2025 20:28:41 +0800 Subject: [PATCH 01/49] =?UTF-8?q?chore(=E9=85=8D=E7=BD=AE):=20=E6=8F=90?= =?UTF-8?q?=E5=8D=87=E5=AE=B9=E5=99=A8=E6=96=87=E4=BB=B6=E6=8F=8F=E8=BF=B0?= =?UTF-8?q?=E7=AC=A6=E4=B8=8A=E9=99=90=E5=88=B010=E4=B8=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 调整原因: - 防止高并发下出现 "too many open files" 错误 - 统一测试与生产环境的 ulimits 配置 改动内容: - 为 sub2api、postgres、redis 设置 nofile - 软硬限制均为 100000 测试: 未运行 --- deploy/docker-compose-test.yml | 12 ++++++++++++ deploy/docker-compose.yml | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/deploy/docker-compose-test.yml b/deploy/docker-compose-test.yml index defc0aa8..35aa553b 100644 --- a/deploy/docker-compose-test.yml +++ b/deploy/docker-compose-test.yml @@ -19,6 +19,10 @@ services: image: sub2api:latest container_name: sub2api restart: unless-stopped + ulimits: + nofile: + soft: 100000 + hard: 100000 ports: - "${BIND_HOST:-0.0.0.0}:${SERVER_PORT:-8080}:8080" volumes: @@ -107,6 +111,10 @@ services: image: postgres:18-alpine container_name: sub2api-postgres restart: unless-stopped + ulimits: + nofile: + soft: 100000 + hard: 100000 volumes: - postgres_data:/var/lib/postgresql/data environment: @@ -132,6 +140,10 @@ services: image: redis:7-alpine container_name: sub2api-redis restart: unless-stopped + ulimits: + nofile: + soft: 100000 + hard: 100000 volumes: - redis_data:/data command: > diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 0e3fb16e..45b3796b 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -19,6 +19,10 @@ services: image: weishaw/sub2api:latest container_name: sub2api restart: unless-stopped + ulimits: + nofile: + soft: 100000 + hard: 100000 ports: - "${BIND_HOST:-0.0.0.0}:${SERVER_PORT:-8080}:8080" volumes: @@ -107,6 +111,10 @@ services: image: postgres:18-alpine container_name: sub2api-postgres restart: unless-stopped + ulimits: + nofile: + soft: 100000 + hard: 100000 volumes: - postgres_data:/var/lib/postgresql/data environment: @@ -132,6 +140,10 @@ services: image: redis:7-alpine container_name: sub2api-redis restart: unless-stopped + ulimits: + nofile: + soft: 100000 + hard: 100000 volumes: - redis_data:/data command: > From 7efa8b54c4a15c2d9140f2774b2eb0921f039368 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Wed, 31 Dec 2025 08:50:12 +0800 Subject: [PATCH 02/49] =?UTF-8?q?perf(=E5=90=8E=E7=AB=AF):=20=E5=AE=8C?= =?UTF-8?q?=E6=88=90=E6=80=A7=E8=83=BD=E4=BC=98=E5=8C=96=E4=B8=8E=E8=BF=9E?= =?UTF-8?q?=E6=8E=A5=E6=B1=A0=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增 DB/Redis 连接池配置与校验,并补充单测 网关请求体大小限制与 413 处理 HTTP/req 客户端池化并调整上游连接池默认值 并发槽位改为 ZSET+Lua 与指数退避 用量统计改 SQL 聚合并新增索引迁移 计费缓存写入改工作池并补测试/基准 测试: 在 backend/ 下运行 go test ./... --- backend/cmd/server/wire.go | 5 + backend/cmd/server/wire_gen.go | 21 ++- backend/internal/config/config.go | 106 +++++++++++ backend/internal/handler/gateway_handler.go | 59 +++--- backend/internal/handler/gateway_helper.go | 59 +++++- .../internal/handler/gemini_v1beta_handler.go | 7 +- .../handler/openai_gateway_handler.go | 4 + .../internal/handler/request_body_limit.go | 27 +++ .../handler/request_body_limit_test.go | 45 +++++ backend/internal/infrastructure/db_pool.go | 32 ++++ .../internal/infrastructure/db_pool_test.go | 50 +++++ backend/internal/infrastructure/ent.go | 1 + backend/internal/infrastructure/redis.go | 33 +++- backend/internal/infrastructure/redis_test.go | 35 ++++ backend/internal/pkg/httpclient/pool.go | 152 +++++++++++++++ .../repository/claude_oauth_service.go | 14 +- .../repository/claude_usage_service.go | 22 +-- .../internal/repository/concurrency_cache.go | 156 +++++++++------- .../concurrency_cache_benchmark_test.go | 135 ++++++++++++++ .../concurrency_cache_integration_test.go | 18 +- .../repository/gemini_oauth_client.go | 9 +- .../repository/geminicli_codeassist_client.go | 9 +- .../repository/github_release_service.go | 20 +- backend/internal/repository/http_upstream.go | 95 +++++++--- .../http_upstream_benchmark_test.go | 46 +++++ .../internal/repository/http_upstream_test.go | 4 +- .../repository/openai_oauth_service.go | 12 +- .../internal/repository/pricing_service.go | 11 +- .../repository/proxy_probe_service.go | 48 +---- .../repository/proxy_probe_service_test.go | 18 +- .../internal/repository/req_client_pool.go | 59 ++++++ .../internal/repository/turnstile_service.go | 13 +- backend/internal/repository/usage_log_repo.go | 155 ++++++++++++++++ backend/internal/repository/wire.go | 11 +- backend/internal/server/api_contract_test.go | 12 ++ .../server/middleware/request_body_limit.go | 15 ++ backend/internal/server/routes/gateway.go | 8 +- .../internal/service/account_usage_service.go | 3 + .../internal/service/billing_cache_service.go | 151 +++++++++++++-- .../service/billing_cache_service_test.go | 75 ++++++++ .../internal/service/concurrency_service.go | 14 +- backend/internal/service/crs_sync_service.go | 9 +- backend/internal/service/gateway_request.go | 70 +++++++ .../internal/service/gateway_request_test.go | 38 ++++ backend/internal/service/gateway_service.go | 174 ++++++++---------- .../service/gateway_service_benchmark_test.go | 50 +++++ .../service/gemini_messages_compat_service.go | 6 +- .../internal/service/gemini_oauth_service.go | 13 +- .../service/openai_gateway_service.go | 12 +- backend/internal/service/pricing_service.go | 15 +- backend/internal/service/usage_service.go | 78 +++----- .../010_add_usage_logs_aggregated_indexes.sql | 4 + deploy/config.example.yaml | 16 ++ 53 files changed, 1805 insertions(+), 449 deletions(-) create mode 100644 backend/internal/handler/request_body_limit.go create mode 100644 backend/internal/handler/request_body_limit_test.go create mode 100644 backend/internal/infrastructure/db_pool.go create mode 100644 backend/internal/infrastructure/db_pool_test.go create mode 100644 backend/internal/infrastructure/redis_test.go create mode 100644 backend/internal/pkg/httpclient/pool.go create mode 100644 backend/internal/repository/concurrency_cache_benchmark_test.go create mode 100644 backend/internal/repository/http_upstream_benchmark_test.go create mode 100644 backend/internal/repository/req_client_pool.go create mode 100644 backend/internal/server/middleware/request_body_limit.go create mode 100644 backend/internal/service/billing_cache_service_test.go create mode 100644 backend/internal/service/gateway_request.go create mode 100644 backend/internal/service/gateway_request_test.go create mode 100644 backend/internal/service/gateway_service_benchmark_test.go create mode 100644 backend/migrations/010_add_usage_logs_aggregated_indexes.sql diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index a746955b..fffcd5f9 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -67,6 +67,7 @@ func provideCleanup( tokenRefresh *service.TokenRefreshService, pricing *service.PricingService, emailQueue *service.EmailQueueService, + billingCache *service.BillingCacheService, oauth *service.OAuthService, openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, @@ -94,6 +95,10 @@ func provideCleanup( emailQueue.Stop() return nil }}, + {"BillingCacheService", func() error { + billingCache.Stop() + return nil + }}, {"OAuthService", func() error { oauth.Stop() return nil diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index f37d696b..664e7aca 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -39,11 +39,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { if err != nil { return nil, err } - sqlDB, err := infrastructure.ProvideSQLDB(client) + db, err := infrastructure.ProvideSQLDB(client) if err != nil { return nil, err } - userRepository := repository.NewUserRepository(client, sqlDB) + userRepository := repository.NewUserRepository(client, db) settingRepository := repository.NewSettingRepository(client) settingService := service.NewSettingService(settingRepository, configConfig) redisClient := infrastructure.ProvideRedis(configConfig) @@ -57,12 +57,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { authHandler := handler.NewAuthHandler(configConfig, authService, userService) userHandler := handler.NewUserHandler(userService) apiKeyRepository := repository.NewApiKeyRepository(client) - groupRepository := repository.NewGroupRepository(client, sqlDB) + groupRepository := repository.NewGroupRepository(client, db) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) apiKeyCache := repository.NewApiKeyCache(redisClient) apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) - usageLogRepository := repository.NewUsageLogRepository(client, sqlDB) + usageLogRepository := repository.NewUsageLogRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) redeemCodeRepository := repository.NewRedeemCodeRepository(client) @@ -75,8 +75,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) dashboardService := service.NewDashboardService(usageLogRepository) dashboardHandler := admin.NewDashboardHandler(dashboardService) - accountRepository := repository.NewAccountRepository(client, sqlDB) - proxyRepository := repository.NewProxyRepository(client, sqlDB) + accountRepository := repository.NewAccountRepository(client, db) + proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber() adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber) adminUserHandler := admin.NewUserHandler(adminService) @@ -95,7 +95,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) httpUpstream := repository.NewHTTPUpstream(configConfig) accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, httpUpstream) - concurrencyCache := repository.NewConcurrencyCache(redisClient) + concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.NewConcurrencyService(concurrencyCache) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) @@ -142,7 +142,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { httpServer := server.ProvideHTTPServer(configConfig, engine) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig) antigravityQuotaRefresher := service.ProvideAntigravityQuotaRefresher(accountRepository, proxyRepository, antigravityOAuthService, configConfig) - v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, antigravityQuotaRefresher) + v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, antigravityQuotaRefresher) application := &Application{ Server: httpServer, Cleanup: v, @@ -170,6 +170,7 @@ func provideCleanup( tokenRefresh *service.TokenRefreshService, pricing *service.PricingService, emailQueue *service.EmailQueueService, + billingCache *service.BillingCacheService, oauth *service.OAuthService, openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, @@ -196,6 +197,10 @@ func provideCleanup( emailQueue.Stop() return nil }}, + {"BillingCacheService", func() error { + billingCache.Stop() + return nil + }}, {"OAuthService", func() error { oauth.Stop() return nil diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 86283fde..dfc9a844 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -79,12 +79,29 @@ type GatewayConfig struct { // 等待上游响应头的超时时间(秒),0表示无超时 // 注意:这不影响流式数据传输,只控制等待响应头的时间 ResponseHeaderTimeout int `mapstructure:"response_header_timeout"` + // 请求体最大字节数,用于网关请求体大小限制 + MaxBodySize int64 `mapstructure:"max_body_size"` + + // HTTP 上游连接池配置(性能优化:支持高并发场景调优) + // MaxIdleConns: 所有主机的最大空闲连接总数 + MaxIdleConns int `mapstructure:"max_idle_conns"` + // MaxIdleConnsPerHost: 每个主机的最大空闲连接数(关键参数,影响连接复用率) + MaxIdleConnsPerHost int `mapstructure:"max_idle_conns_per_host"` + // MaxConnsPerHost: 每个主机的最大连接数(包括活跃+空闲),0表示无限制 + MaxConnsPerHost int `mapstructure:"max_conns_per_host"` + // IdleConnTimeoutSeconds: 空闲连接超时时间(秒) + IdleConnTimeoutSeconds int `mapstructure:"idle_conn_timeout_seconds"` + // ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟) + // 应大于最长 LLM 请求时间,防止请求完成前槽位过期 + ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` } func (s *ServerConfig) Address() string { return fmt.Sprintf("%s:%d", s.Host, s.Port) } +// DatabaseConfig 数据库连接配置 +// 性能优化:新增连接池参数,避免频繁创建/销毁连接 type DatabaseConfig struct { Host string `mapstructure:"host"` Port int `mapstructure:"port"` @@ -92,6 +109,15 @@ type DatabaseConfig struct { Password string `mapstructure:"password"` DBName string `mapstructure:"dbname"` SSLMode string `mapstructure:"sslmode"` + // 连接池配置(性能优化:可配置化连接池参数) + // MaxOpenConns: 最大打开连接数,控制数据库连接上限,防止资源耗尽 + MaxOpenConns int `mapstructure:"max_open_conns"` + // MaxIdleConns: 最大空闲连接数,保持热连接减少建连延迟 + MaxIdleConns int `mapstructure:"max_idle_conns"` + // ConnMaxLifetimeMinutes: 连接最大存活时间,防止长连接导致的资源泄漏 + ConnMaxLifetimeMinutes int `mapstructure:"conn_max_lifetime_minutes"` + // ConnMaxIdleTimeMinutes: 空闲连接最大存活时间,及时释放不活跃连接 + ConnMaxIdleTimeMinutes int `mapstructure:"conn_max_idle_time_minutes"` } func (d *DatabaseConfig) DSN() string { @@ -112,11 +138,24 @@ func (d *DatabaseConfig) DSNWithTimezone(tz string) string { ) } +// RedisConfig Redis 连接配置 +// 性能优化:新增连接池和超时参数,提升高并发场景下的吞吐量 type RedisConfig struct { Host string `mapstructure:"host"` Port int `mapstructure:"port"` Password string `mapstructure:"password"` DB int `mapstructure:"db"` + // 连接池与超时配置(性能优化:可配置化连接池参数) + // DialTimeoutSeconds: 建立连接超时,防止慢连接阻塞 + DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"` + // ReadTimeoutSeconds: 读取超时,避免慢查询阻塞连接池 + ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"` + // WriteTimeoutSeconds: 写入超时,避免慢写入阻塞连接池 + WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"` + // PoolSize: 连接池大小,控制最大并发连接数 + PoolSize int `mapstructure:"pool_size"` + // MinIdleConns: 最小空闲连接数,保持热连接减少冷启动延迟 + MinIdleConns int `mapstructure:"min_idle_conns"` } func (r *RedisConfig) Address() string { @@ -203,12 +242,21 @@ func setDefaults() { viper.SetDefault("database.password", "postgres") viper.SetDefault("database.dbname", "sub2api") viper.SetDefault("database.sslmode", "disable") + viper.SetDefault("database.max_open_conns", 50) + viper.SetDefault("database.max_idle_conns", 10) + viper.SetDefault("database.conn_max_lifetime_minutes", 30) + viper.SetDefault("database.conn_max_idle_time_minutes", 5) // Redis viper.SetDefault("redis.host", "localhost") viper.SetDefault("redis.port", 6379) viper.SetDefault("redis.password", "") viper.SetDefault("redis.db", 0) + viper.SetDefault("redis.dial_timeout_seconds", 5) + viper.SetDefault("redis.read_timeout_seconds", 3) + viper.SetDefault("redis.write_timeout_seconds", 3) + viper.SetDefault("redis.pool_size", 128) + viper.SetDefault("redis.min_idle_conns", 10) // JWT viper.SetDefault("jwt.secret", "change-me-in-production") @@ -240,6 +288,13 @@ func setDefaults() { // Gateway viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久 + viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) + // HTTP 上游连接池配置(针对 5000+ 并发用户优化) + viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认) + viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认) + viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认) + viper.SetDefault("gateway.idle_conn_timeout_seconds", 300) // 空闲连接超时(秒) + viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求) // TokenRefresh viper.SetDefault("token_refresh.enabled", true) @@ -263,6 +318,57 @@ func (c *Config) Validate() error { if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" { return fmt.Errorf("jwt.secret must be changed in production") } + if c.Database.MaxOpenConns <= 0 { + return fmt.Errorf("database.max_open_conns must be positive") + } + if c.Database.MaxIdleConns < 0 { + return fmt.Errorf("database.max_idle_conns must be non-negative") + } + if c.Database.MaxIdleConns > c.Database.MaxOpenConns { + return fmt.Errorf("database.max_idle_conns cannot exceed database.max_open_conns") + } + if c.Database.ConnMaxLifetimeMinutes < 0 { + return fmt.Errorf("database.conn_max_lifetime_minutes must be non-negative") + } + if c.Database.ConnMaxIdleTimeMinutes < 0 { + return fmt.Errorf("database.conn_max_idle_time_minutes must be non-negative") + } + if c.Redis.DialTimeoutSeconds <= 0 { + return fmt.Errorf("redis.dial_timeout_seconds must be positive") + } + if c.Redis.ReadTimeoutSeconds <= 0 { + return fmt.Errorf("redis.read_timeout_seconds must be positive") + } + if c.Redis.WriteTimeoutSeconds <= 0 { + return fmt.Errorf("redis.write_timeout_seconds must be positive") + } + if c.Redis.PoolSize <= 0 { + return fmt.Errorf("redis.pool_size must be positive") + } + if c.Redis.MinIdleConns < 0 { + return fmt.Errorf("redis.min_idle_conns must be non-negative") + } + if c.Redis.MinIdleConns > c.Redis.PoolSize { + return fmt.Errorf("redis.min_idle_conns cannot exceed redis.pool_size") + } + if c.Gateway.MaxBodySize <= 0 { + return fmt.Errorf("gateway.max_body_size must be positive") + } + if c.Gateway.MaxIdleConns <= 0 { + return fmt.Errorf("gateway.max_idle_conns must be positive") + } + if c.Gateway.MaxIdleConnsPerHost <= 0 { + return fmt.Errorf("gateway.max_idle_conns_per_host must be positive") + } + if c.Gateway.MaxConnsPerHost < 0 { + return fmt.Errorf("gateway.max_conns_per_host must be non-negative") + } + if c.Gateway.IdleConnTimeoutSeconds <= 0 { + return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive") + } + if c.Gateway.ConcurrencySlotTTLMinutes <= 0 { + return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive") + } return nil } diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index bf179ea1..fc92b2d8 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -67,6 +67,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 读取请求体 body, err := io.ReadAll(c.Request.Body) if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") return } @@ -76,15 +80,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } - // 解析请求获取模型名和stream - var req struct { - Model string `json:"model"` - Stream bool `json:"stream"` - } - if err := json.Unmarshal(body, &req); err != nil { + parsedReq, err := service.ParseGatewayRequest(body) + if err != nil { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } + reqModel := parsedReq.Model + reqStream := parsedReq.Stream // Track if we've started streaming (for error handling) streamStarted := false @@ -106,7 +108,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) // 1. 首先获取用户并发槽位 - userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, req.Stream, &streamStarted) + userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) if err != nil { log.Printf("User concurrency acquire failed: %v", err) h.handleConcurrencyError(c, err, "user", streamStarted) @@ -124,7 +126,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 计算粘性会话hash - sessionHash := h.gatewayService.GenerateSessionHash(body) + sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) // 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台 platform := "" @@ -141,7 +143,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { lastFailoverStatus := 0 for { - account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs) + account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) if err != nil { if len(failedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -153,16 +155,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 检查预热请求拦截(在账号选择后、转发前检查) if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { - if req.Stream { - sendMockWarmupStream(c, req.Model) + if reqStream { + sendMockWarmupStream(c, reqModel) } else { - sendMockWarmupResponse(c, req.Model) + sendMockWarmupResponse(c, reqModel) } return } // 3. 获取账号并发槽位 - accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted) + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) if err != nil { log.Printf("Account concurrency acquire failed: %v", err) h.handleConcurrencyError(c, err, "account", streamStarted) @@ -172,7 +174,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 转发请求 - 根据账号平台分流 var result *service.ForwardResult if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, req.Model, "generateContent", req.Stream, body) + result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body) } else { result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) } @@ -223,7 +225,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { for { // 选择支持该模型的账号 - account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs) + account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) if err != nil { if len(failedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -235,16 +237,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 检查预热请求拦截(在账号选择后、转发前检查) if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { - if req.Stream { - sendMockWarmupStream(c, req.Model) + if reqStream { + sendMockWarmupStream(c, reqModel) } else { - sendMockWarmupResponse(c, req.Model) + sendMockWarmupResponse(c, reqModel) } return } // 3. 获取账号并发槽位 - accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted) + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) if err != nil { log.Printf("Account concurrency acquire failed: %v", err) h.handleConcurrencyError(c, err, "account", streamStarted) @@ -256,7 +258,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if account.Platform == service.PlatformAntigravity { result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body) } else { - result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body) + result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq) } if accountReleaseFunc != nil { accountReleaseFunc() @@ -496,6 +498,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { // 读取请求体 body, err := io.ReadAll(c.Request.Body) if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") return } @@ -505,11 +511,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { return } - // 解析请求获取模型名 - var req struct { - Model string `json:"model"` - } - if err := json.Unmarshal(body, &req); err != nil { + parsedReq, err := service.ParseGatewayRequest(body) + if err != nil { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } @@ -525,17 +528,17 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { } // 计算粘性会话 hash - sessionHash := h.gatewayService.GenerateSessionHash(body) + sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) // 选择支持该模型的账号 - account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) + account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model) if err != nil { h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) return } // 转发请求(不记录使用量) - if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil { + if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil { log.Printf("Forward count_tokens request failed: %v", err) // 错误响应已在 ForwardCountTokens 中处理 return diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 5cbe462d..4c7bd0f0 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -3,6 +3,7 @@ package handler import ( "context" "fmt" + "math/rand" "net/http" "time" @@ -11,11 +12,28 @@ import ( "github.com/gin-gonic/gin" ) +// 并发槽位等待相关常量 +// +// 性能优化说明: +// 原实现使用固定间隔(100ms)轮询并发槽位,存在以下问题: +// 1. 高并发时频繁轮询增加 Redis 压力 +// 2. 固定间隔可能导致多个请求同时重试(惊群效应) +// +// 新实现使用指数退避 + 抖动算法: +// 1. 初始退避 100ms,每次乘以 1.5,最大 2s +// 2. 添加 ±20% 的随机抖动,分散重试时间点 +// 3. 减少 Redis 压力,避免惊群效应 const ( - // maxConcurrencyWait is the maximum time to wait for a concurrency slot + // maxConcurrencyWait 等待并发槽位的最大时间 maxConcurrencyWait = 30 * time.Second - // pingInterval is the interval for sending ping events during slot wait + // pingInterval 流式响应等待时发送 ping 的间隔 pingInterval = 15 * time.Second + // initialBackoff 初始退避时间 + initialBackoff = 100 * time.Millisecond + // backoffMultiplier 退避时间乘数(指数退避) + backoffMultiplier = 1.5 + // maxBackoff 最大退避时间 + maxBackoff = 2 * time.Second ) // SSEPingFormat defines the format of SSE ping events for different platforms @@ -131,8 +149,10 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, pingCh = pingTicker.C } - pollTicker := time.NewTicker(100 * time.Millisecond) - defer pollTicker.Stop() + backoff := initialBackoff + timer := time.NewTimer(backoff) + defer timer.Stop() + rng := rand.New(rand.NewSource(time.Now().UnixNano())) for { select { @@ -156,7 +176,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, } flusher.Flush() - case <-pollTicker.C: + case <-timer.C: // Try to acquire slot var result *service.AcquireResult var err error @@ -174,6 +194,35 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, if result.Acquired { return result.ReleaseFunc, nil } + backoff = nextBackoff(backoff, rng) + timer.Reset(backoff) } } } + +// nextBackoff 计算下一次退避时间 +// 性能优化:使用指数退避 + 随机抖动,避免惊群效应 +// current: 当前退避时间 +// rng: 随机数生成器(可为 nil,此时不添加抖动) +// 返回值:下一次退避时间(100ms ~ 2s 之间) +func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration { + // 指数退避:当前时间 * 1.5 + next := time.Duration(float64(current) * backoffMultiplier) + if next > maxBackoff { + next = maxBackoff + } + if rng == nil { + return next + } + // 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2) + // 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis + jitter := 0.8 + rng.Float64()*0.4 + jittered := time.Duration(float64(next) * jitter) + if jittered < initialBackoff { + return initialBackoff + } + if jittered > maxBackoff { + return maxBackoff + } + return jittered +} diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index ea1bdf5a..4e99e00d 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -148,6 +148,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { body, err := io.ReadAll(c.Request.Body) if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit)) + return + } googleError(c, http.StatusBadRequest, "Failed to read request body") return } @@ -191,7 +195,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } // 3) select account (sticky session based on request body) - sessionHash := h.gatewayService.GenerateSessionHash(body) + parsedReq, _ := service.ParseGatewayRequest(body) + sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) const maxAccountSwitches = 3 switchCount := 0 failedAccountIDs := make(map[int64]struct{}) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 2dee9ccd..7fcb329d 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -56,6 +56,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // Read request body body, err := io.ReadAll(c.Request.Body) if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") return } diff --git a/backend/internal/handler/request_body_limit.go b/backend/internal/handler/request_body_limit.go new file mode 100644 index 00000000..d746673b --- /dev/null +++ b/backend/internal/handler/request_body_limit.go @@ -0,0 +1,27 @@ +package handler + +import ( + "errors" + "fmt" + "net/http" +) + +func extractMaxBytesError(err error) (*http.MaxBytesError, bool) { + var maxErr *http.MaxBytesError + if errors.As(err, &maxErr) { + return maxErr, true + } + return nil, false +} + +func formatBodyLimit(limit int64) string { + const mb = 1024 * 1024 + if limit >= mb { + return fmt.Sprintf("%dMB", limit/mb) + } + return fmt.Sprintf("%dB", limit) +} + +func buildBodyTooLargeMessage(limit int64) string { + return fmt.Sprintf("Request body too large, limit is %s", formatBodyLimit(limit)) +} diff --git a/backend/internal/handler/request_body_limit_test.go b/backend/internal/handler/request_body_limit_test.go new file mode 100644 index 00000000..bd9b8177 --- /dev/null +++ b/backend/internal/handler/request_body_limit_test.go @@ -0,0 +1,45 @@ +package handler + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestRequestBodyLimitTooLarge(t *testing.T) { + gin.SetMode(gin.TestMode) + + limit := int64(16) + router := gin.New() + router.Use(middleware.RequestBodyLimit(limit)) + router.POST("/test", func(c *gin.Context) { + _, err := io.ReadAll(c.Request.Body) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + c.JSON(http.StatusRequestEntityTooLarge, gin.H{ + "error": buildBodyTooLargeMessage(maxErr.Limit), + }) + return + } + c.JSON(http.StatusBadRequest, gin.H{ + "error": "read_failed", + }) + return + } + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + payload := bytes.Repeat([]byte("a"), int(limit+1)) + req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(payload)) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code) + require.Contains(t, recorder.Body.String(), buildBodyTooLargeMessage(limit)) +} diff --git a/backend/internal/infrastructure/db_pool.go b/backend/internal/infrastructure/db_pool.go new file mode 100644 index 00000000..612155bf --- /dev/null +++ b/backend/internal/infrastructure/db_pool.go @@ -0,0 +1,32 @@ +package infrastructure + +import ( + "database/sql" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +type dbPoolSettings struct { + MaxOpenConns int + MaxIdleConns int + ConnMaxLifetime time.Duration + ConnMaxIdleTime time.Duration +} + +func buildDBPoolSettings(cfg *config.Config) dbPoolSettings { + return dbPoolSettings{ + MaxOpenConns: cfg.Database.MaxOpenConns, + MaxIdleConns: cfg.Database.MaxIdleConns, + ConnMaxLifetime: time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute, + ConnMaxIdleTime: time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute, + } +} + +func applyDBPoolSettings(db *sql.DB, cfg *config.Config) { + settings := buildDBPoolSettings(cfg) + db.SetMaxOpenConns(settings.MaxOpenConns) + db.SetMaxIdleConns(settings.MaxIdleConns) + db.SetConnMaxLifetime(settings.ConnMaxLifetime) + db.SetConnMaxIdleTime(settings.ConnMaxIdleTime) +} diff --git a/backend/internal/infrastructure/db_pool_test.go b/backend/internal/infrastructure/db_pool_test.go new file mode 100644 index 00000000..0f0e9716 --- /dev/null +++ b/backend/internal/infrastructure/db_pool_test.go @@ -0,0 +1,50 @@ +package infrastructure + +import ( + "database/sql" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" + + _ "github.com/lib/pq" +) + +func TestBuildDBPoolSettings(t *testing.T) { + cfg := &config.Config{ + Database: config.DatabaseConfig{ + MaxOpenConns: 50, + MaxIdleConns: 10, + ConnMaxLifetimeMinutes: 30, + ConnMaxIdleTimeMinutes: 5, + }, + } + + settings := buildDBPoolSettings(cfg) + require.Equal(t, 50, settings.MaxOpenConns) + require.Equal(t, 10, settings.MaxIdleConns) + require.Equal(t, 30*time.Minute, settings.ConnMaxLifetime) + require.Equal(t, 5*time.Minute, settings.ConnMaxIdleTime) +} + +func TestApplyDBPoolSettings(t *testing.T) { + cfg := &config.Config{ + Database: config.DatabaseConfig{ + MaxOpenConns: 40, + MaxIdleConns: 8, + ConnMaxLifetimeMinutes: 15, + ConnMaxIdleTimeMinutes: 3, + }, + } + + db, err := sql.Open("postgres", "host=127.0.0.1 port=5432 user=postgres sslmode=disable") + require.NoError(t, err) + t.Cleanup(func() { + _ = db.Close() + }) + + applyDBPoolSettings(db, cfg) + stats := db.Stats() + require.Equal(t, 40, stats.MaxOpenConnections) +} diff --git a/backend/internal/infrastructure/ent.go b/backend/internal/infrastructure/ent.go index 13184a83..b1ab9a55 100644 --- a/backend/internal/infrastructure/ent.go +++ b/backend/internal/infrastructure/ent.go @@ -51,6 +51,7 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) { if err != nil { return nil, nil, err } + applyDBPoolSettings(drv.DB(), cfg) // 确保数据库 schema 已准备就绪。 // SQL 迁移文件是 schema 的权威来源(source of truth)。 diff --git a/backend/internal/infrastructure/redis.go b/backend/internal/infrastructure/redis.go index 970a2595..5bb92d19 100644 --- a/backend/internal/infrastructure/redis.go +++ b/backend/internal/infrastructure/redis.go @@ -1,16 +1,39 @@ package infrastructure import ( + "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/redis/go-redis/v9" ) // InitRedis 初始化 Redis 客户端 +// +// 性能优化说明: +// 原实现使用 go-redis 默认配置,未设置连接池和超时参数: +// 1. 默认连接池大小可能不足以支撑高并发 +// 2. 无超时控制可能导致慢操作阻塞 +// +// 新实现支持可配置的连接池和超时参数: +// 1. PoolSize: 控制最大并发连接数(默认 128) +// 2. MinIdleConns: 保持最小空闲连接,减少冷启动延迟(默认 10) +// 3. DialTimeout/ReadTimeout/WriteTimeout: 精确控制各阶段超时 func InitRedis(cfg *config.Config) *redis.Client { - return redis.NewClient(&redis.Options{ - Addr: cfg.Redis.Address(), - Password: cfg.Redis.Password, - DB: cfg.Redis.DB, - }) + return redis.NewClient(buildRedisOptions(cfg)) +} + +// buildRedisOptions 构建 Redis 连接选项 +// 从配置文件读取连接池和超时参数,支持生产环境调优 +func buildRedisOptions(cfg *config.Config) *redis.Options { + return &redis.Options{ + Addr: cfg.Redis.Address(), + Password: cfg.Redis.Password, + DB: cfg.Redis.DB, + DialTimeout: time.Duration(cfg.Redis.DialTimeoutSeconds) * time.Second, // 建连超时 + ReadTimeout: time.Duration(cfg.Redis.ReadTimeoutSeconds) * time.Second, // 读取超时 + WriteTimeout: time.Duration(cfg.Redis.WriteTimeoutSeconds) * time.Second, // 写入超时 + PoolSize: cfg.Redis.PoolSize, // 连接池大小 + MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接 + } } diff --git a/backend/internal/infrastructure/redis_test.go b/backend/internal/infrastructure/redis_test.go new file mode 100644 index 00000000..5e38e826 --- /dev/null +++ b/backend/internal/infrastructure/redis_test.go @@ -0,0 +1,35 @@ +package infrastructure + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestBuildRedisOptions(t *testing.T) { + cfg := &config.Config{ + Redis: config.RedisConfig{ + Host: "localhost", + Port: 6379, + Password: "secret", + DB: 2, + DialTimeoutSeconds: 5, + ReadTimeoutSeconds: 3, + WriteTimeoutSeconds: 4, + PoolSize: 100, + MinIdleConns: 10, + }, + } + + opts := buildRedisOptions(cfg) + require.Equal(t, "localhost:6379", opts.Addr) + require.Equal(t, "secret", opts.Password) + require.Equal(t, 2, opts.DB) + require.Equal(t, 5*time.Second, opts.DialTimeout) + require.Equal(t, 3*time.Second, opts.ReadTimeout) + require.Equal(t, 4*time.Second, opts.WriteTimeout) + require.Equal(t, 100, opts.PoolSize) + require.Equal(t, 10, opts.MinIdleConns) +} diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go new file mode 100644 index 00000000..f68d50a5 --- /dev/null +++ b/backend/internal/pkg/httpclient/pool.go @@ -0,0 +1,152 @@ +// Package httpclient 提供共享 HTTP 客户端池 +// +// 性能优化说明: +// 原实现在多个服务中重复创建 http.Client: +// 1. proxy_probe_service.go: 每次探测创建新客户端 +// 2. pricing_service.go: 每次请求创建新客户端 +// 3. turnstile_service.go: 每次验证创建新客户端 +// 4. github_release_service.go: 每次请求创建新客户端 +// 5. claude_usage_service.go: 每次请求创建新客户端 +// +// 新实现使用统一的客户端池: +// 1. 相同配置复用同一 http.Client 实例 +// 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销 +// 3. 支持 HTTP/HTTPS/SOCKS5 代理 +// 4. 支持严格代理模式(代理失败则返回错误) +package httpclient + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "golang.org/x/net/proxy" +) + +// Transport 连接池默认配置 +const ( + defaultMaxIdleConns = 100 // 最大空闲连接数 + defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数 + defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间 +) + +// Options 定义共享 HTTP 客户端的构建参数 +type Options struct { + ProxyURL string // 代理 URL(支持 http/https/socks5) + Timeout time.Duration // 请求总超时时间 + ResponseHeaderTimeout time.Duration // 等待响应头超时时间 + InsecureSkipVerify bool // 是否跳过 TLS 证书验证 + ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退 + + // 可选的连接池参数(不设置则使用默认值) + MaxIdleConns int // 最大空闲连接总数(默认 100) + MaxIdleConnsPerHost int // 每主机最大空闲连接(默认 10) + MaxConnsPerHost int // 每主机最大连接数(默认 0 无限制) +} + +// sharedClients 存储按配置参数缓存的 http.Client 实例 +var sharedClients sync.Map + +// GetClient 返回共享的 HTTP 客户端实例 +// 性能优化:相同配置复用同一客户端,避免重复创建 Transport +func GetClient(opts Options) (*http.Client, error) { + key := buildClientKey(opts) + if cached, ok := sharedClients.Load(key); ok { + return cached.(*http.Client), nil + } + + client, err := buildClient(opts) + if err != nil { + if opts.ProxyStrict { + return nil, err + } + fallback := opts + fallback.ProxyURL = "" + client, _ = buildClient(fallback) + } + + actual, _ := sharedClients.LoadOrStore(key, client) + return actual.(*http.Client), nil +} + +func buildClient(opts Options) (*http.Client, error) { + transport, err := buildTransport(opts) + if err != nil { + return nil, err + } + + return &http.Client{ + Transport: transport, + Timeout: opts.Timeout, + }, nil +} + +func buildTransport(opts Options) (*http.Transport, error) { + // 使用自定义值或默认值 + maxIdleConns := opts.MaxIdleConns + if maxIdleConns <= 0 { + maxIdleConns = defaultMaxIdleConns + } + maxIdleConnsPerHost := opts.MaxIdleConnsPerHost + if maxIdleConnsPerHost <= 0 { + maxIdleConnsPerHost = defaultMaxIdleConnsPerHost + } + + transport := &http.Transport{ + MaxIdleConns: maxIdleConns, + MaxIdleConnsPerHost: maxIdleConnsPerHost, + MaxConnsPerHost: opts.MaxConnsPerHost, // 0 表示无限制 + IdleConnTimeout: defaultIdleConnTimeout, + ResponseHeaderTimeout: opts.ResponseHeaderTimeout, + } + + if opts.InsecureSkipVerify { + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + + proxyURL := strings.TrimSpace(opts.ProxyURL) + if proxyURL == "" { + return transport, nil + } + + parsed, err := url.Parse(proxyURL) + if err != nil { + return nil, err + } + + switch strings.ToLower(parsed.Scheme) { + case "http", "https": + transport.Proxy = http.ProxyURL(parsed) + case "socks5", "socks5h": + dialer, err := proxy.FromURL(parsed, proxy.Direct) + if err != nil { + return nil, err + } + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + } + default: + return nil, fmt.Errorf("unsupported proxy protocol: %s", parsed.Scheme) + } + + return transport, nil +} + +func buildClientKey(opts Options) string { + return fmt.Sprintf("%s|%s|%s|%t|%t|%d|%d|%d", + strings.TrimSpace(opts.ProxyURL), + opts.Timeout.String(), + opts.ResponseHeaderTimeout.String(), + opts.InsecureSkipVerify, + opts.ProxyStrict, + opts.MaxIdleConns, + opts.MaxIdleConnsPerHost, + opts.MaxConnsPerHost, + ) +} diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index f7ff2341..b03b5415 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go @@ -233,15 +233,11 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro } func createReqClient(proxyURL string) *req.Client { - client := req.C(). - ImpersonateChrome(). - SetTimeout(60 * time.Second) - - if proxyURL != "" { - client.SetProxyURL(proxyURL) - } - - return client + return getSharedReqClient(reqClientOptions{ + ProxyURL: proxyURL, + Timeout: 60 * time.Second, + Impersonate: true, + }) } func prefix(s string, n int) string { diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go index 7ccbeafc..424d1a9a 100644 --- a/backend/internal/repository/claude_usage_service.go +++ b/backend/internal/repository/claude_usage_service.go @@ -6,9 +6,9 @@ import ( "fmt" "io" "net/http" - "net/url" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -23,20 +23,12 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher { } func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) { - transport, ok := http.DefaultTransport.(*http.Transport) - if !ok { - return nil, fmt.Errorf("failed to get default transport") - } - transport = transport.Clone() - if proxyURL != "" { - if parsedURL, err := url.Parse(proxyURL); err == nil { - transport.Proxy = http.ProxyURL(parsedURL) - } - } - - client := &http.Client{ - Transport: transport, - Timeout: 30 * time.Second, + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: proxyURL, + Timeout: 30 * time.Second, + }) + if err != nil { + client = &http.Client{Timeout: 30 * time.Second} } req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil) diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index 2946f691..31527f22 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -3,67 +3,90 @@ package repository import ( "context" "fmt" - "time" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/redis/go-redis/v9" ) +// 并发控制缓存常量定义 +// +// 性能优化说明: +// 原实现使用 SCAN 命令遍历独立的槽位键(concurrency:account:{id}:{requestID}), +// 在高并发场景下 SCAN 需要多次往返,且遍历大量键时性能下降明显。 +// +// 新实现改用 Redis 有序集合(Sorted Set): +// 1. 每个账号/用户只有一个键,成员为 requestID,分数为时间戳 +// 2. 使用 ZCARD 原子获取并发数,时间复杂度 O(1) +// 3. 使用 ZREMRANGEBYSCORE 清理过期槽位,避免手动管理 TTL +// 4. 单次 Redis 调用完成计数,减少网络往返 const ( - // Key prefixes for independent slot keys - // Format: concurrency:account:{accountID}:{requestID} + // 并发槽位键前缀(有序集合) + // 格式: concurrency:account:{accountID} accountSlotKeyPrefix = "concurrency:account:" - // Format: concurrency:user:{userID}:{requestID} + // 格式: concurrency:user:{userID} userSlotKeyPrefix = "concurrency:user:" - // Wait queue keeps counter format: concurrency:wait:{userID} + // 等待队列计数器格式: concurrency:wait:{userID} waitQueueKeyPrefix = "concurrency:wait:" - // Slot TTL - each slot expires independently - slotTTL = 5 * time.Minute + // 默认槽位过期时间(分钟),可通过配置覆盖 + defaultSlotTTLMinutes = 15 ) var ( - // acquireScript uses SCAN to count existing slots and creates new slot if under limit - // KEYS[1] = pattern for SCAN (e.g., "concurrency:account:2:*") - // KEYS[2] = full slot key (e.g., "concurrency:account:2:req_xxx") + // acquireScript 使用有序集合计数并在未达上限时添加槽位 + // 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步问题 + // KEYS[1] = 有序集合键 (concurrency:account:{id} / concurrency:user:{id}) // ARGV[1] = maxConcurrency - // ARGV[2] = TTL in seconds + // ARGV[2] = TTL(秒) + // ARGV[3] = requestID acquireScript = redis.NewScript(` - local pattern = KEYS[1] - local slotKey = KEYS[2] + local key = KEYS[1] local maxConcurrency = tonumber(ARGV[1]) local ttl = tonumber(ARGV[2]) + local requestID = ARGV[3] - -- Count existing slots using SCAN - local cursor = "0" - local count = 0 - repeat - local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100) - cursor = result[1] - count = count + #result[2] - until cursor == "0" + -- 使用 Redis 服务器时间,确保多实例时钟一致 + local timeResult = redis.call('TIME') + local now = tonumber(timeResult[1]) + local expireBefore = now - ttl - -- Check if we can acquire a slot + -- 清理过期槽位 + redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) + + -- 检查是否已存在(支持重试场景刷新时间戳) + local exists = redis.call('ZSCORE', key, requestID) + if exists ~= false then + redis.call('ZADD', key, now, requestID) + redis.call('EXPIRE', key, ttl) + return 1 + end + + -- 检查是否达到并发上限 + local count = redis.call('ZCARD', key) if count < maxConcurrency then - redis.call('SET', slotKey, '1', 'EX', ttl) + redis.call('ZADD', key, now, requestID) + redis.call('EXPIRE', key, ttl) return 1 end return 0 `) - // getCountScript counts slots using SCAN - // KEYS[1] = pattern for SCAN + // getCountScript 统计有序集合中的槽位数量并清理过期条目 + // 使用 Redis TIME 命令获取服务器时间 + // KEYS[1] = 有序集合键 + // ARGV[1] = TTL(秒) getCountScript = redis.NewScript(` - local pattern = KEYS[1] - local cursor = "0" - local count = 0 - repeat - local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100) - cursor = result[1] - count = count + #result[2] - until cursor == "0" - return count + local key = KEYS[1] + local ttl = tonumber(ARGV[1]) + + -- 使用 Redis 服务器时间 + local timeResult = redis.call('TIME') + local now = tonumber(timeResult[1]) + local expireBefore = now - ttl + + redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) + return redis.call('ZCARD', key) `) // incrementWaitScript - only sets TTL on first creation to avoid refreshing @@ -103,28 +126,29 @@ var ( ) type concurrencyCache struct { - rdb *redis.Client + rdb *redis.Client + slotTTLSeconds int // 槽位过期时间(秒) } -func NewConcurrencyCache(rdb *redis.Client) service.ConcurrencyCache { - return &concurrencyCache{rdb: rdb} +// NewConcurrencyCache 创建并发控制缓存 +// slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟 +func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.ConcurrencyCache { + if slotTTLMinutes <= 0 { + slotTTLMinutes = defaultSlotTTLMinutes + } + return &concurrencyCache{ + rdb: rdb, + slotTTLSeconds: slotTTLMinutes * 60, + } } // Helper functions for key generation -func accountSlotKey(accountID int64, requestID string) string { - return fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, requestID) +func accountSlotKey(accountID int64) string { + return fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID) } -func accountSlotPattern(accountID int64) string { - return fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID) -} - -func userSlotKey(userID int64, requestID string) string { - return fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, requestID) -} - -func userSlotPattern(userID int64) string { - return fmt.Sprintf("%s%d:*", userSlotKeyPrefix, userID) +func userSlotKey(userID int64) string { + return fmt.Sprintf("%s%d", userSlotKeyPrefix, userID) } func waitQueueKey(userID int64) string { @@ -134,10 +158,9 @@ func waitQueueKey(userID int64) string { // Account slot operations func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { - pattern := accountSlotPattern(accountID) - slotKey := accountSlotKey(accountID, requestID) - - result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int() + key := accountSlotKey(accountID) + // 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致 + result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int() if err != nil { return false, err } @@ -145,13 +168,14 @@ func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int } func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { - slotKey := accountSlotKey(accountID, requestID) - return c.rdb.Del(ctx, slotKey).Err() + key := accountSlotKey(accountID) + return c.rdb.ZRem(ctx, key, requestID).Err() } func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { - pattern := accountSlotPattern(accountID) - result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int() + key := accountSlotKey(accountID) + // 时间戳在 Lua 脚本内使用 Redis TIME 命令获取 + result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int() if err != nil { return 0, err } @@ -161,10 +185,9 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID // User slot operations func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { - pattern := userSlotPattern(userID) - slotKey := userSlotKey(userID, requestID) - - result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int() + key := userSlotKey(userID) + // 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致 + result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int() if err != nil { return false, err } @@ -172,13 +195,14 @@ func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, ma } func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error { - slotKey := userSlotKey(userID, requestID) - return c.rdb.Del(ctx, slotKey).Err() + key := userSlotKey(userID) + return c.rdb.ZRem(ctx, key, requestID).Err() } func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { - pattern := userSlotPattern(userID) - result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int() + key := userSlotKey(userID) + // 时间戳在 Lua 脚本内使用 Redis TIME 命令获取 + result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int() if err != nil { return 0, err } @@ -189,7 +213,7 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { key := waitQueueKey(userID) - result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(slotTTL.Seconds())).Int() + result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int() if err != nil { return false, err } diff --git a/backend/internal/repository/concurrency_cache_benchmark_test.go b/backend/internal/repository/concurrency_cache_benchmark_test.go new file mode 100644 index 00000000..29cc7fbc --- /dev/null +++ b/backend/internal/repository/concurrency_cache_benchmark_test.go @@ -0,0 +1,135 @@ +package repository + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/redis/go-redis/v9" +) + +// 基准测试用 TTL 配置 +const benchSlotTTLMinutes = 15 + +var benchSlotTTL = time.Duration(benchSlotTTLMinutes) * time.Minute + +// BenchmarkAccountConcurrency 用于对比 SCAN 与有序集合的计数性能。 +func BenchmarkAccountConcurrency(b *testing.B) { + rdb := newBenchmarkRedisClient(b) + defer func() { + _ = rdb.Close() + }() + + cache := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache) + ctx := context.Background() + + for _, size := range []int{10, 100, 1000} { + size := size + b.Run(fmt.Sprintf("zset/slots=%d", size), func(b *testing.B) { + accountID := time.Now().UnixNano() + key := accountSlotKey(accountID) + + b.StopTimer() + members := make([]redis.Z, 0, size) + now := float64(time.Now().Unix()) + for i := 0; i < size; i++ { + members = append(members, redis.Z{ + Score: now, + Member: fmt.Sprintf("req_%d", i), + }) + } + if err := rdb.ZAdd(ctx, key, members...).Err(); err != nil { + b.Fatalf("初始化有序集合失败: %v", err) + } + if err := rdb.Expire(ctx, key, benchSlotTTL).Err(); err != nil { + b.Fatalf("设置有序集合 TTL 失败: %v", err) + } + b.StartTimer() + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + if _, err := cache.GetAccountConcurrency(ctx, accountID); err != nil { + b.Fatalf("获取并发数量失败: %v", err) + } + } + + b.StopTimer() + if err := rdb.Del(ctx, key).Err(); err != nil { + b.Fatalf("清理有序集合失败: %v", err) + } + }) + + b.Run(fmt.Sprintf("scan/slots=%d", size), func(b *testing.B) { + accountID := time.Now().UnixNano() + pattern := fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID) + keys := make([]string, 0, size) + + b.StopTimer() + pipe := rdb.Pipeline() + for i := 0; i < size; i++ { + key := fmt.Sprintf("%s%d:req_%d", accountSlotKeyPrefix, accountID, i) + keys = append(keys, key) + pipe.Set(ctx, key, "1", benchSlotTTL) + } + if _, err := pipe.Exec(ctx); err != nil { + b.Fatalf("初始化扫描键失败: %v", err) + } + b.StartTimer() + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + if _, err := scanSlotCount(ctx, rdb, pattern); err != nil { + b.Fatalf("SCAN 计数失败: %v", err) + } + } + + b.StopTimer() + if err := rdb.Del(ctx, keys...).Err(); err != nil { + b.Fatalf("清理扫描键失败: %v", err) + } + }) + } +} + +func scanSlotCount(ctx context.Context, rdb *redis.Client, pattern string) (int, error) { + var cursor uint64 + count := 0 + for { + keys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, 100).Result() + if err != nil { + return 0, err + } + count += len(keys) + if nextCursor == 0 { + break + } + cursor = nextCursor + } + return count, nil +} + +func newBenchmarkRedisClient(b *testing.B) *redis.Client { + b.Helper() + + redisURL := os.Getenv("TEST_REDIS_URL") + if redisURL == "" { + b.Skip("未设置 TEST_REDIS_URL,跳过 Redis 基准测试") + } + + opt, err := redis.ParseURL(redisURL) + if err != nil { + b.Fatalf("解析 TEST_REDIS_URL 失败: %v", err) + } + + client := redis.NewClient(opt) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + if err := client.Ping(ctx).Err(); err != nil { + b.Fatalf("Redis 连接失败: %v", err) + } + + return client +} diff --git a/backend/internal/repository/concurrency_cache_integration_test.go b/backend/internal/repository/concurrency_cache_integration_test.go index c1feaf85..6a7c83f4 100644 --- a/backend/internal/repository/concurrency_cache_integration_test.go +++ b/backend/internal/repository/concurrency_cache_integration_test.go @@ -14,6 +14,12 @@ import ( "github.com/stretchr/testify/suite" ) +// 测试用 TTL 配置(15 分钟,与默认值一致) +const testSlotTTLMinutes = 15 + +// 测试用 TTL Duration,用于 TTL 断言 +var testSlotTTL = time.Duration(testSlotTTLMinutes) * time.Minute + type ConcurrencyCacheSuite struct { IntegrationRedisSuite cache service.ConcurrencyCache @@ -21,7 +27,7 @@ type ConcurrencyCacheSuite struct { func (s *ConcurrencyCacheSuite) SetupTest() { s.IntegrationRedisSuite.SetupTest() - s.cache = NewConcurrencyCache(s.rdb) + s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes) } func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() { @@ -54,7 +60,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() { func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() { accountID := int64(11) reqID := "req_ttl_test" - slotKey := fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, reqID) + slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID) ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID) require.NoError(s.T(), err, "AcquireAccountSlot") @@ -62,7 +68,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() { ttl, err := s.rdb.TTL(s.ctx, slotKey).Result() require.NoError(s.T(), err, "TTL") - s.AssertTTLWithin(ttl, 1*time.Second, slotTTL) + s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL) } func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() { @@ -139,7 +145,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() { func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() { userID := int64(200) reqID := "req_ttl_test" - slotKey := fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, reqID) + slotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID) ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID) require.NoError(s.T(), err, "AcquireUserSlot") @@ -147,7 +153,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() { ttl, err := s.rdb.TTL(s.ctx, slotKey).Result() require.NoError(s.T(), err, "TTL") - s.AssertTTLWithin(ttl, 1*time.Second, slotTTL) + s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL) } func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() { @@ -168,7 +174,7 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() { ttl, err := s.rdb.TTL(s.ctx, waitKey).Result() require.NoError(s.T(), err, "TTL waitKey") - s.AssertTTLWithin(ttl, 1*time.Second, slotTTL) + s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL) require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount") diff --git a/backend/internal/repository/gemini_oauth_client.go b/backend/internal/repository/gemini_oauth_client.go index 4e9bae3e..bac8736b 100644 --- a/backend/internal/repository/gemini_oauth_client.go +++ b/backend/internal/repository/gemini_oauth_client.go @@ -109,9 +109,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh } func createGeminiReqClient(proxyURL string) *req.Client { - client := req.C().SetTimeout(60 * time.Second) - if proxyURL != "" { - client.SetProxyURL(proxyURL) - } - return client + return getSharedReqClient(reqClientOptions{ + ProxyURL: proxyURL, + Timeout: 60 * time.Second, + }) } diff --git a/backend/internal/repository/geminicli_codeassist_client.go b/backend/internal/repository/geminicli_codeassist_client.go index 0a5d813c..d7f54e85 100644 --- a/backend/internal/repository/geminicli_codeassist_client.go +++ b/backend/internal/repository/geminicli_codeassist_client.go @@ -76,11 +76,10 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken } func createGeminiCliReqClient(proxyURL string) *req.Client { - client := req.C().SetTimeout(30 * time.Second) - if proxyURL != "" { - client.SetProxyURL(proxyURL) - } - return client + return getSharedReqClient(reqClientOptions{ + ProxyURL: proxyURL, + Timeout: 30 * time.Second, + }) } func defaultLoadCodeAssistRequest() *geminicli.LoadCodeAssistRequest { diff --git a/backend/internal/repository/github_release_service.go b/backend/internal/repository/github_release_service.go index 25b65a4b..3fa4b1ff 100644 --- a/backend/internal/repository/github_release_service.go +++ b/backend/internal/repository/github_release_service.go @@ -9,6 +9,7 @@ import ( "os" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -17,10 +18,14 @@ type githubReleaseClient struct { } func NewGitHubReleaseClient() service.GitHubReleaseClient { + sharedClient, err := httpclient.GetClient(httpclient.Options{ + Timeout: 30 * time.Second, + }) + if err != nil { + sharedClient = &http.Client{Timeout: 30 * time.Second} + } return &githubReleaseClient{ - httpClient: &http.Client{ - Timeout: 30 * time.Second, - }, + httpClient: sharedClient, } } @@ -58,8 +63,13 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string return err } - client := &http.Client{Timeout: 10 * time.Minute} - resp, err := client.Do(req) + downloadClient, err := httpclient.GetClient(httpclient.Options{ + Timeout: 10 * time.Minute, + }) + if err != nil { + downloadClient = &http.Client{Timeout: 10 * time.Minute} + } + resp, err := downloadClient.Do(req) if err != nil { return err } diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index 26befb25..0ca85a09 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -3,65 +3,104 @@ package repository import ( "net/http" "net/url" + "strings" + "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/service" ) -// httpUpstreamService is a generic HTTP upstream service that can be used for -// making requests to any HTTP API (Claude, OpenAI, etc.) with optional proxy support. +// httpUpstreamService 通用 HTTP 上游服务 +// 用于向任意 HTTP API(Claude、OpenAI 等)发送请求,支持可选代理 +// +// 性能优化: +// 1. 使用 sync.Map 缓存代理客户端实例,避免每次请求都创建新的 http.Client +// 2. 复用 Transport 连接池,减少 TCP 握手和 TLS 协商开销 +// 3. 原实现每次请求都 new 一个 http.Client,导致连接无法复用 type httpUpstreamService struct { + // defaultClient: 无代理时使用的默认客户端(单例复用) defaultClient *http.Client - cfg *config.Config + // proxyClients: 按代理 URL 缓存的客户端池,避免重复创建 + proxyClients sync.Map + cfg *config.Config } -// NewHTTPUpstream creates a new generic HTTP upstream service +// NewHTTPUpstream 创建通用 HTTP 上游服务 +// 使用配置中的连接池参数构建 Transport func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream { - responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second - if responseHeaderTimeout == 0 { - responseHeaderTimeout = 300 * time.Second - } - - transport := &http.Transport{ - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, - ResponseHeaderTimeout: responseHeaderTimeout, - } - return &httpUpstreamService{ - defaultClient: &http.Client{Transport: transport}, + defaultClient: &http.Client{Transport: buildUpstreamTransport(cfg, nil)}, cfg: cfg, } } func (s *httpUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) { - if proxyURL == "" { + if strings.TrimSpace(proxyURL) == "" { return s.defaultClient.Do(req) } - client := s.createProxyClient(proxyURL) + client := s.getOrCreateClient(proxyURL) return client.Do(req) } -func (s *httpUpstreamService) createProxyClient(proxyURL string) *http.Client { +// getOrCreateClient 获取或创建代理客户端 +// 性能优化:使用 sync.Map 实现无锁缓存,相同代理 URL 复用同一客户端 +// LoadOrStore 保证并发安全,避免重复创建 +func (s *httpUpstreamService) getOrCreateClient(proxyURL string) *http.Client { + proxyURL = strings.TrimSpace(proxyURL) + if proxyURL == "" { + return s.defaultClient + } + // 优先从缓存获取,命中则直接返回 + if cached, ok := s.proxyClients.Load(proxyURL); ok { + return cached.(*http.Client) + } + parsedURL, err := url.Parse(proxyURL) if err != nil { return s.defaultClient } - responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second - if responseHeaderTimeout == 0 { + // 创建新客户端并缓存,LoadOrStore 保证只有一个实例被存储 + client := &http.Client{Transport: buildUpstreamTransport(s.cfg, parsedURL)} + actual, _ := s.proxyClients.LoadOrStore(proxyURL, client) + return actual.(*http.Client) +} + +// buildUpstreamTransport 构建上游请求的 Transport +// 使用配置文件中的连接池参数,支持生产环境调优 +func buildUpstreamTransport(cfg *config.Config, proxyURL *url.URL) *http.Transport { + // 读取配置,使用合理的默认值 + maxIdleConns := cfg.Gateway.MaxIdleConns + if maxIdleConns <= 0 { + maxIdleConns = 240 + } + maxIdleConnsPerHost := cfg.Gateway.MaxIdleConnsPerHost + if maxIdleConnsPerHost <= 0 { + maxIdleConnsPerHost = 120 + } + maxConnsPerHost := cfg.Gateway.MaxConnsPerHost + if maxConnsPerHost < 0 { + maxConnsPerHost = 240 + } + idleConnTimeout := time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second + if idleConnTimeout <= 0 { + idleConnTimeout = 300 * time.Second + } + responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second + if responseHeaderTimeout <= 0 { responseHeaderTimeout = 300 * time.Second } transport := &http.Transport{ - Proxy: http.ProxyURL(parsedURL), - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, + MaxIdleConns: maxIdleConns, // 最大空闲连接总数 + MaxIdleConnsPerHost: maxIdleConnsPerHost, // 每主机最大空闲连接 + MaxConnsPerHost: maxConnsPerHost, // 每主机最大连接数(含活跃) + IdleConnTimeout: idleConnTimeout, // 空闲连接超时 ResponseHeaderTimeout: responseHeaderTimeout, } - - return &http.Client{Transport: transport} + if proxyURL != nil { + transport.Proxy = http.ProxyURL(proxyURL) + } + return transport } diff --git a/backend/internal/repository/http_upstream_benchmark_test.go b/backend/internal/repository/http_upstream_benchmark_test.go new file mode 100644 index 00000000..2ea6e31a --- /dev/null +++ b/backend/internal/repository/http_upstream_benchmark_test.go @@ -0,0 +1,46 @@ +package repository + +import ( + "net/http" + "net/url" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +var httpClientSink *http.Client + +// BenchmarkHTTPUpstreamProxyClient 对比重复创建与复用代理客户端的开销。 +func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { + cfg := &config.Config{ + Gateway: config.GatewayConfig{ResponseHeaderTimeout: 300}, + } + upstream := NewHTTPUpstream(cfg) + svc, ok := upstream.(*httpUpstreamService) + if !ok { + b.Fatalf("类型断言失败,无法获取 httpUpstreamService") + } + + proxyURL := "http://127.0.0.1:8080" + b.ReportAllocs() + + b.Run("新建", func(b *testing.B) { + parsedProxy, err := url.Parse(proxyURL) + if err != nil { + b.Fatalf("解析代理地址失败: %v", err) + } + for i := 0; i < b.N; i++ { + httpClientSink = &http.Client{ + Transport: buildUpstreamTransport(cfg, parsedProxy), + } + } + }) + + b.Run("复用", func(b *testing.B) { + client := svc.getOrCreateClient(proxyURL) + b.ResetTimer() + for i := 0; i < b.N; i++ { + httpClientSink = client + } + }) +} diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go index 9bc38dae..74132e1d 100644 --- a/backend/internal/repository/http_upstream_test.go +++ b/backend/internal/repository/http_upstream_test.go @@ -40,13 +40,13 @@ func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() { require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") } -func (s *HTTPUpstreamSuite) TestCreateProxyClient_InvalidURLFallsBackToDefault() { +func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDefault() { s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 5} up := NewHTTPUpstream(s.cfg) svc, ok := up.(*httpUpstreamService) require.True(s.T(), ok, "expected *httpUpstreamService") - got := svc.createProxyClient("://bad-proxy-url") + got := svc.getOrCreateClient("://bad-proxy-url") require.Equal(s.T(), svc.defaultClient, got, "expected defaultClient fallback") } diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go index da14a338..07d57410 100644 --- a/backend/internal/repository/openai_oauth_service.go +++ b/backend/internal/repository/openai_oauth_service.go @@ -82,12 +82,8 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro } func createOpenAIReqClient(proxyURL string) *req.Client { - client := req.C(). - SetTimeout(60 * time.Second) - - if proxyURL != "" { - client.SetProxyURL(proxyURL) - } - - return client + return getSharedReqClient(reqClientOptions{ + ProxyURL: proxyURL, + Timeout: 60 * time.Second, + }) } diff --git a/backend/internal/repository/pricing_service.go b/backend/internal/repository/pricing_service.go index ccfebd1b..11f82fd3 100644 --- a/backend/internal/repository/pricing_service.go +++ b/backend/internal/repository/pricing_service.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -16,10 +17,14 @@ type pricingRemoteClient struct { } func NewPricingRemoteClient() service.PricingRemoteClient { + sharedClient, err := httpclient.GetClient(httpclient.Options{ + Timeout: 30 * time.Second, + }) + if err != nil { + sharedClient = &http.Client{Timeout: 30 * time.Second} + } return &pricingRemoteClient{ - httpClient: &http.Client{ - Timeout: 30 * time.Second, - }, + httpClient: sharedClient, } } diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index 9331859c..8b288c3c 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -2,18 +2,14 @@ package repository import ( "context" - "crypto/tls" "encoding/json" "fmt" "io" - "net" "net/http" - "net/url" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/service" - - "golang.org/x/net/proxy" ) func NewProxyExitInfoProber() service.ProxyExitInfoProber { @@ -27,14 +23,14 @@ type proxyProbeService struct { } func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { - transport, err := createProxyTransport(proxyURL) + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: proxyURL, + Timeout: 15 * time.Second, + InsecureSkipVerify: true, + ProxyStrict: true, + }) if err != nil { - return nil, 0, fmt.Errorf("failed to create proxy transport: %w", err) - } - - client := &http.Client{ - Transport: transport, - Timeout: 15 * time.Second, + return nil, 0, fmt.Errorf("failed to create proxy client: %w", err) } startTime := time.Now() @@ -78,31 +74,3 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s Country: ipInfo.Country, }, latencyMs, nil } - -func createProxyTransport(proxyURL string) (*http.Transport, error) { - parsedURL, err := url.Parse(proxyURL) - if err != nil { - return nil, fmt.Errorf("invalid proxy URL: %w", err) - } - - transport := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - - switch parsedURL.Scheme { - case "http", "https": - transport.Proxy = http.ProxyURL(parsedURL) - case "socks5": - dialer, err := proxy.FromURL(parsedURL, proxy.Direct) - if err != nil { - return nil, fmt.Errorf("failed to create socks5 dialer: %w", err) - } - transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - } - default: - return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme) - } - - return transport, nil -} diff --git a/backend/internal/repository/proxy_probe_service_test.go b/backend/internal/repository/proxy_probe_service_test.go index 25ab0f9c..74d99c6d 100644 --- a/backend/internal/repository/proxy_probe_service_test.go +++ b/backend/internal/repository/proxy_probe_service_test.go @@ -34,22 +34,16 @@ func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) { s.proxySrv = httptest.NewServer(handler) } -func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_InvalidURL() { - _, err := createProxyTransport("://bad") +func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() { + _, _, err := s.prober.ProbeProxy(s.ctx, "://bad") require.Error(s.T(), err) - require.ErrorContains(s.T(), err, "invalid proxy URL") + require.ErrorContains(s.T(), err, "failed to create proxy client") } -func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_UnsupportedScheme() { - _, err := createProxyTransport("ftp://127.0.0.1:1") +func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() { + _, _, err := s.prober.ProbeProxy(s.ctx, "ftp://127.0.0.1:1") require.Error(s.T(), err) - require.ErrorContains(s.T(), err, "unsupported proxy protocol") -} - -func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_Socks5SetsDialer() { - tr, err := createProxyTransport("socks5://127.0.0.1:1080") - require.NoError(s.T(), err, "createProxyTransport") - require.NotNil(s.T(), tr.DialContext, "expected DialContext to be set for socks5") + require.ErrorContains(s.T(), err, "failed to create proxy client") } func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() { diff --git a/backend/internal/repository/req_client_pool.go b/backend/internal/repository/req_client_pool.go new file mode 100644 index 00000000..bfe0ccd2 --- /dev/null +++ b/backend/internal/repository/req_client_pool.go @@ -0,0 +1,59 @@ +package repository + +import ( + "fmt" + "strings" + "sync" + "time" + + "github.com/imroc/req/v3" +) + +// reqClientOptions 定义 req 客户端的构建参数 +type reqClientOptions struct { + ProxyURL string // 代理 URL(支持 http/https/socks5) + Timeout time.Duration // 请求超时时间 + Impersonate bool // 是否模拟 Chrome 浏览器指纹 +} + +// sharedReqClients 存储按配置参数缓存的 req 客户端实例 +// +// 性能优化说明: +// 原实现在每次 OAuth 刷新时都创建新的 req.Client: +// 1. claude_oauth_service.go: 每次刷新创建新客户端 +// 2. openai_oauth_service.go: 每次刷新创建新客户端 +// 3. gemini_oauth_client.go: 每次刷新创建新客户端 +// +// 新实现使用 sync.Map 缓存客户端: +// 1. 相同配置(代理+超时+模拟设置)复用同一客户端 +// 2. 复用底层连接池,减少 TLS 握手开销 +// 3. LoadOrStore 保证并发安全,避免重复创建 +var sharedReqClients sync.Map + +// getSharedReqClient 获取共享的 req 客户端实例 +// 性能优化:相同配置复用同一客户端,避免重复创建 +func getSharedReqClient(opts reqClientOptions) *req.Client { + key := buildReqClientKey(opts) + if cached, ok := sharedReqClients.Load(key); ok { + return cached.(*req.Client) + } + + client := req.C().SetTimeout(opts.Timeout) + if opts.Impersonate { + client = client.ImpersonateChrome() + } + if strings.TrimSpace(opts.ProxyURL) != "" { + client.SetProxyURL(strings.TrimSpace(opts.ProxyURL)) + } + + actual, _ := sharedReqClients.LoadOrStore(key, client) + return actual.(*req.Client) +} + +func buildReqClientKey(opts reqClientOptions) string { + return fmt.Sprintf("%s|%s|%t", + strings.TrimSpace(opts.ProxyURL), + opts.Timeout.String(), + opts.Impersonate, + ) +} diff --git a/backend/internal/repository/turnstile_service.go b/backend/internal/repository/turnstile_service.go index c3755011..cf6083e2 100644 --- a/backend/internal/repository/turnstile_service.go +++ b/backend/internal/repository/turnstile_service.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -20,11 +21,15 @@ type turnstileVerifier struct { } func NewTurnstileVerifier() service.TurnstileVerifier { + sharedClient, err := httpclient.GetClient(httpclient.Options{ + Timeout: 10 * time.Second, + }) + if err != nil { + sharedClient = &http.Client{Timeout: 10 * time.Second} + } return &turnstileVerifier{ - httpClient: &http.Client{ - Timeout: 10 * time.Second, - }, - verifyURL: turnstileVerifyURL, + httpClient: sharedClient, + verifyURL: turnstileVerifyURL, } } diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 9341f20e..4e26d751 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -452,6 +452,161 @@ func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKe return &stats, nil } +// GetAccountStatsAggregated 使用 SQL 聚合统计账号使用数据 +// +// 性能优化说明: +// 原实现先查询所有日志记录,再在应用层循环计算统计值: +// 1. 需要传输大量数据到应用层 +// 2. 应用层循环计算增加 CPU 和内存开销 +// +// 新实现使用 SQL 聚合函数: +// 1. 在数据库层完成 COUNT/SUM/AVG 计算 +// 2. 只返回单行聚合结果,大幅减少数据传输量 +// 3. 利用数据库索引优化聚合查询性能 +func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + query := ` + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms + FROM usage_logs + WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 + ` + + var stats usagestats.UsageStats + if err := scanSingleRow( + ctx, + r.sql, + query, + []any{accountID, startTime, endTime}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { + return nil, err + } + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens + return &stats, nil +} + +// GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据 +// 性能优化:数据库层聚合计算,避免应用层循环统计 +func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + query := ` + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms + FROM usage_logs + WHERE model = $1 AND created_at >= $2 AND created_at < $3 + ` + + var stats usagestats.UsageStats + if err := scanSingleRow( + ctx, + r.sql, + query, + []any{modelName, startTime, endTime}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { + return nil, err + } + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens + return &stats, nil +} + +// GetDailyStatsAggregated 使用 SQL 聚合统计用户的每日使用数据 +// 性能优化:使用 GROUP BY 在数据库层按日期分组聚合,避免应用层循环分组统计 +func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) { + query := ` + SELECT + TO_CHAR(created_at, 'YYYY-MM-DD') as date, + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms + FROM usage_logs + WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 + GROUP BY 1 + ORDER BY 1 + ` + + rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime) + if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + result = nil + } + }() + + result = make([]map[string]any, 0) + for rows.Next() { + var ( + date string + totalRequests int64 + totalInputTokens int64 + totalOutputTokens int64 + totalCacheTokens int64 + totalCost float64 + totalActualCost float64 + avgDurationMs float64 + ) + if err = rows.Scan( + &date, + &totalRequests, + &totalInputTokens, + &totalOutputTokens, + &totalCacheTokens, + &totalCost, + &totalActualCost, + &avgDurationMs, + ); err != nil { + return nil, err + } + result = append(result, map[string]any{ + "date": date, + "total_requests": totalRequests, + "total_input_tokens": totalInputTokens, + "total_output_tokens": totalOutputTokens, + "total_cache_tokens": totalCacheTokens, + "total_tokens": totalInputTokens + totalOutputTokens + totalCacheTokens, + "total_cost": totalCost, + "total_actual_cost": totalActualCost, + "average_duration_ms": avgDurationMs, + }) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return result, nil +} + func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 53d42d90..edeaf782 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -1,9 +1,18 @@ package repository import ( + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/google/wire" + "github.com/redis/go-redis/v9" ) +// ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数 +// 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景 +func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache { + return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes) +} + // ProviderSet is the Wire provider set for all repositories var ProviderSet = wire.NewSet( NewUserRepository, @@ -20,7 +29,7 @@ var ProviderSet = wire.NewSet( NewGatewayCache, NewBillingCache, NewApiKeyCache, - NewConcurrencyCache, + ProvideConcurrencyCache, NewEmailCache, NewIdentityCache, NewRedeemCache, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 8d5ace96..5a243bfc 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -981,6 +981,18 @@ func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyI return nil, errors.New("not implemented") } +func (r *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) { + return nil, errors.New("not implemented") +} + func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { return nil, errors.New("not implemented") } diff --git a/backend/internal/server/middleware/request_body_limit.go b/backend/internal/server/middleware/request_body_limit.go new file mode 100644 index 00000000..fce13eea --- /dev/null +++ b/backend/internal/server/middleware/request_body_limit.go @@ -0,0 +1,15 @@ +package middleware + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +// RequestBodyLimit 使用 MaxBytesReader 限制请求体大小。 +func RequestBodyLimit(maxBytes int64) gin.HandlerFunc { + return func(c *gin.Context) { + c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes) + c.Next() + } +} diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index 34792be8..38df9225 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -18,8 +18,11 @@ func RegisterGatewayRoutes( subscriptionService *service.SubscriptionService, cfg *config.Config, ) { + bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize) + // API网关(Claude API兼容) gateway := r.Group("/v1") + gateway.Use(bodyLimit) gateway.Use(gin.HandlerFunc(apiKeyAuth)) { gateway.POST("/messages", h.Gateway.Messages) @@ -32,6 +35,7 @@ func RegisterGatewayRoutes( // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) gemini := r.Group("/v1beta") + gemini.Use(bodyLimit) gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) { gemini.GET("/models", h.Gateway.GeminiV1BetaListModels) @@ -41,10 +45,11 @@ func RegisterGatewayRoutes( } // OpenAI Responses API(不带v1前缀的别名) - r.POST("/responses", gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses) + r.POST("/responses", bodyLimit, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses) // Antigravity 专用路由(仅使用 antigravity 账户,不混合调度) antigravityV1 := r.Group("/antigravity/v1") + antigravityV1.Use(bodyLimit) antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1.Use(gin.HandlerFunc(apiKeyAuth)) { @@ -55,6 +60,7 @@ func RegisterGatewayRoutes( } antigravityV1Beta := r.Group("/antigravity/v1beta") + antigravityV1Beta.Use(bodyLimit) antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) { diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 94d4c747..dba670b0 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -52,6 +52,9 @@ type UsageLogRepository interface { // Aggregated stats (optimized) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) + GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) + GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) + GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) } // apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at) diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index 9493a11f..ac320535 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -27,6 +28,45 @@ type subscriptionCacheData struct { Version int64 } +// 缓存写入任务类型 +type cacheWriteKind int + +const ( + cacheWriteSetBalance cacheWriteKind = iota + cacheWriteSetSubscription + cacheWriteUpdateSubscriptionUsage + cacheWriteDeductBalance +) + +// 异步缓存写入工作池配置 +// +// 性能优化说明: +// 原实现在请求热路径中使用 goroutine 异步更新缓存,存在以下问题: +// 1. 每次请求创建新 goroutine,高并发下产生大量短生命周期 goroutine +// 2. 无法控制并发数量,可能导致 Redis 连接耗尽 +// 3. goroutine 创建/销毁带来额外开销 +// +// 新实现使用固定大小的工作池: +// 1. 预创建 10 个 worker goroutine,避免频繁创建销毁 +// 2. 使用带缓冲的 channel(1000)作为任务队列,平滑写入峰值 +// 3. 非阻塞写入,队列满时丢弃任务(缓存最终一致性可接受) +// 4. 统一超时控制,避免慢操作阻塞工作池 +const ( + cacheWriteWorkerCount = 10 // 工作协程数量 + cacheWriteBufferSize = 1000 // 任务队列缓冲大小 + cacheWriteTimeout = 2 * time.Second // 单个写入操作超时 +) + +// cacheWriteTask 缓存写入任务 +type cacheWriteTask struct { + kind cacheWriteKind + userID int64 + groupID int64 + balance float64 + amount float64 + subscriptionData *subscriptionCacheData +} + // BillingCacheService 计费缓存服务 // 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查 type BillingCacheService struct { @@ -34,16 +74,81 @@ type BillingCacheService struct { userRepo UserRepository subRepo UserSubscriptionRepository cfg *config.Config + + cacheWriteChan chan cacheWriteTask + cacheWriteWg sync.WaitGroup + cacheWriteStopOnce sync.Once } // NewBillingCacheService 创建计费缓存服务 func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService { - return &BillingCacheService{ + svc := &BillingCacheService{ cache: cache, userRepo: userRepo, subRepo: subRepo, cfg: cfg, } + svc.startCacheWriteWorkers() + return svc +} + +// Stop 关闭缓存写入工作池 +func (s *BillingCacheService) Stop() { + s.cacheWriteStopOnce.Do(func() { + if s.cacheWriteChan == nil { + return + } + close(s.cacheWriteChan) + s.cacheWriteWg.Wait() + s.cacheWriteChan = nil + }) +} + +func (s *BillingCacheService) startCacheWriteWorkers() { + s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize) + for i := 0; i < cacheWriteWorkerCount; i++ { + s.cacheWriteWg.Add(1) + go s.cacheWriteWorker() + } +} + +func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) { + if s.cacheWriteChan == nil { + return + } + defer func() { + _ = recover() + }() + select { + case s.cacheWriteChan <- task: + default: + } +} + +func (s *BillingCacheService) cacheWriteWorker() { + defer s.cacheWriteWg.Done() + for task := range s.cacheWriteChan { + ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) + switch task.kind { + case cacheWriteSetBalance: + s.setBalanceCache(ctx, task.userID, task.balance) + case cacheWriteSetSubscription: + s.setSubscriptionCache(ctx, task.userID, task.groupID, task.subscriptionData) + case cacheWriteUpdateSubscriptionUsage: + if s.cache != nil { + if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil { + log.Printf("Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err) + } + } + case cacheWriteDeductBalance: + if s.cache != nil { + if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil { + log.Printf("Warning: deduct balance cache failed for user %d: %v", task.userID, err) + } + } + } + cancel() + } } // ============================================ @@ -70,11 +175,11 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) } // 异步建立缓存 - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - s.setBalanceCache(cacheCtx, userID, balance) - }() + s.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteSetBalance, + userID: userID, + balance: balance, + }) return balance, nil } @@ -98,7 +203,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, } } -// DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存) +// DeductBalanceCache 扣减余额缓存(同步调用) func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error { if s.cache == nil { return nil @@ -106,6 +211,15 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int return s.cache.DeductUserBalance(ctx, userID, amount) } +// QueueDeductBalance 异步扣减余额缓存 +func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) { + s.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteDeductBalance, + userID: userID, + amount: amount, + }) +} + // InvalidateUserBalance 失效用户余额缓存 func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error { if s.cache == nil { @@ -141,11 +255,12 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, } // 异步建立缓存 - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - s.setSubscriptionCache(cacheCtx, userID, groupID, data) - }() + s.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteSetSubscription, + userID: userID, + groupID: groupID, + subscriptionData: data, + }) return data, nil } @@ -199,7 +314,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, } } -// UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存) +// UpdateSubscriptionUsage 更新订阅用量缓存(同步调用) func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error { if s.cache == nil { return nil @@ -207,6 +322,16 @@ func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userI return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD) } +// QueueUpdateSubscriptionUsage 异步更新订阅用量缓存 +func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) { + s.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteUpdateSubscriptionUsage, + userID: userID, + groupID: groupID, + amount: costUSD, + }) +} + // InvalidateSubscription 失效指定订阅缓存 func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error { if s.cache == nil { diff --git a/backend/internal/service/billing_cache_service_test.go b/backend/internal/service/billing_cache_service_test.go new file mode 100644 index 00000000..445d5319 --- /dev/null +++ b/backend/internal/service/billing_cache_service_test.go @@ -0,0 +1,75 @@ +package service + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type billingCacheWorkerStub struct { + balanceUpdates int64 + subscriptionUpdates int64 +} + +func (b *billingCacheWorkerStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) { + return 0, errors.New("not implemented") +} + +func (b *billingCacheWorkerStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error { + atomic.AddInt64(&b.balanceUpdates, 1) + return nil +} + +func (b *billingCacheWorkerStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { + atomic.AddInt64(&b.balanceUpdates, 1) + return nil +} + +func (b *billingCacheWorkerStub) InvalidateUserBalance(ctx context.Context, userID int64) error { + return nil +} + +func (b *billingCacheWorkerStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) { + return nil, errors.New("not implemented") +} + +func (b *billingCacheWorkerStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error { + atomic.AddInt64(&b.subscriptionUpdates, 1) + return nil +} + +func (b *billingCacheWorkerStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { + atomic.AddInt64(&b.subscriptionUpdates, 1) + return nil +} + +func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error { + return nil +} + +func TestBillingCacheServiceQueueHighLoad(t *testing.T) { + cache := &billingCacheWorkerStub{} + svc := NewBillingCacheService(cache, nil, nil, &config.Config{}) + t.Cleanup(svc.Stop) + + start := time.Now() + for i := 0; i < cacheWriteBufferSize*2; i++ { + svc.QueueDeductBalance(1, 1) + } + require.Less(t, time.Since(start), 2*time.Second) + + svc.QueueUpdateSubscriptionUsage(1, 2, 1.5) + + require.Eventually(t, func() bool { + return atomic.LoadInt64(&cache.balanceUpdates) > 0 + }, 2*time.Second, 10*time.Millisecond) + + require.Eventually(t, func() bool { + return atomic.LoadInt64(&cache.subscriptionUpdates) > 0 + }, 2*time.Second, 10*time.Millisecond) +} diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index a6cff234..b5229491 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -9,22 +9,22 @@ import ( "time" ) -// ConcurrencyCache defines cache operations for concurrency service -// Uses independent keys per request slot with native Redis TTL for automatic cleanup +// ConcurrencyCache 定义并发控制的缓存接口 +// 使用有序集合存储槽位,按时间戳清理过期条目 type ConcurrencyCache interface { - // Account slot management - each slot is a separate key with independent TTL - // Key format: concurrency:account:{accountID}:{requestID} + // 账号槽位管理 + // 键格式: concurrency:account:{accountID}(有序集合,成员为 requestID) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) - // User slot management - each slot is a separate key with independent TTL - // Key format: concurrency:user:{userID}:{requestID} + // 用户槽位管理 + // 键格式: concurrency:user:{userID}(有序集合,成员为 requestID) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error GetUserConcurrency(ctx context.Context, userID int64) (int, error) - // Wait queue - uses counter with TTL set only on creation + // 等待队列计数(只在首次创建时设置 TTL) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) DecrementWaitCount(ctx context.Context, userID int64) error } diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index cd1dbcec..fd23ecb2 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -12,6 +12,8 @@ import ( "strconv" "strings" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" ) type CRSSyncService struct { @@ -193,7 +195,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput return nil, errors.New("username and password are required") } - client := &http.Client{Timeout: 20 * time.Second} + client, err := httpclient.GetClient(httpclient.Options{ + Timeout: 20 * time.Second, + }) + if err != nil { + client = &http.Client{Timeout: 20 * time.Second} + } adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password) if err != nil { diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go new file mode 100644 index 00000000..6d358c36 --- /dev/null +++ b/backend/internal/service/gateway_request.go @@ -0,0 +1,70 @@ +package service + +import ( + "encoding/json" + "fmt" +) + +// ParsedRequest 保存网关请求的预解析结果 +// +// 性能优化说明: +// 原实现在多个位置重复解析请求体(Handler、Service 各解析一次): +// 1. gateway_handler.go 解析获取 model 和 stream +// 2. gateway_service.go 再次解析获取 system、messages、metadata +// 3. GenerateSessionHash 又一次解析获取会话哈希所需字段 +// +// 新实现一次解析,多处复用: +// 1. 在 Handler 层统一调用 ParseGatewayRequest 一次性解析 +// 2. 将解析结果 ParsedRequest 传递给 Service 层 +// 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销 +type ParsedRequest struct { + Body []byte // 原始请求体(保留用于转发) + Model string // 请求的模型名称 + Stream bool // 是否为流式请求 + MetadataUserID string // metadata.user_id(用于会话亲和) + System any // system 字段内容 + Messages []any // messages 数组 + HasSystem bool // 是否包含 system 字段 +} + +// ParseGatewayRequest 解析网关请求体并返回结构化结果 +// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal +func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { + var req map[string]any + if err := json.Unmarshal(body, &req); err != nil { + return nil, err + } + + parsed := &ParsedRequest{ + Body: body, + } + + if rawModel, exists := req["model"]; exists { + model, ok := rawModel.(string) + if !ok { + return nil, fmt.Errorf("invalid model field type") + } + parsed.Model = model + } + if rawStream, exists := req["stream"]; exists { + stream, ok := rawStream.(bool) + if !ok { + return nil, fmt.Errorf("invalid stream field type") + } + parsed.Stream = stream + } + if metadata, ok := req["metadata"].(map[string]any); ok { + if userID, ok := metadata["user_id"].(string); ok { + parsed.MetadataUserID = userID + } + } + if system, ok := req["system"]; ok && system != nil { + parsed.HasSystem = true + parsed.System = system + } + if messages, ok := req["messages"].([]any); ok { + parsed.Messages = messages + } + + return parsed, nil +} diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go new file mode 100644 index 00000000..c921e0f6 --- /dev/null +++ b/backend/internal/service/gateway_request_test.go @@ -0,0 +1,38 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseGatewayRequest(t *testing.T) { + body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`) + parsed, err := ParseGatewayRequest(body) + require.NoError(t, err) + require.Equal(t, "claude-3-7-sonnet", parsed.Model) + require.True(t, parsed.Stream) + require.Equal(t, "session_123e4567-e89b-12d3-a456-426614174000", parsed.MetadataUserID) + require.True(t, parsed.HasSystem) + require.NotNil(t, parsed.System) + require.Len(t, parsed.Messages, 1) +} + +func TestParseGatewayRequest_SystemNull(t *testing.T) { + body := []byte(`{"model":"claude-3","system":null}`) + parsed, err := ParseGatewayRequest(body) + require.NoError(t, err) + require.False(t, parsed.HasSystem) +} + +func TestParseGatewayRequest_InvalidModelType(t *testing.T) { + body := []byte(`{"model":123}`) + _, err := ParseGatewayRequest(body) + require.Error(t, err) +} + +func TestParseGatewayRequest_InvalidStreamType(t *testing.T) { + body := []byte(`{"stream":"true"}`) + _, err := ParseGatewayRequest(body) + require.Error(t, err) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 50bfd161..41362662 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -19,7 +19,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" - "github.com/tidwall/gjson" "github.com/tidwall/sjson" "github.com/gin-gonic/gin" @@ -33,7 +32,10 @@ const ( // sseDataRe matches SSE data lines with optional whitespace after colon. // Some upstream APIs return non-standard "data:" without space (should be "data: "). -var sseDataRe = regexp.MustCompile(`^data:\s*`) +var ( + sseDataRe = regexp.MustCompile(`^data:\s*`) + sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) +) // allowedHeaders 白名单headers(参考CRS项目) var allowedHeaders = map[string]bool{ @@ -141,40 +143,36 @@ func NewGatewayService( } } -// GenerateSessionHash 从请求体计算粘性会话hash -func (s *GatewayService) GenerateSessionHash(body []byte) string { - var req map[string]any - if err := json.Unmarshal(body, &req); err != nil { +// GenerateSessionHash 从预解析请求计算粘性会话 hash +func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { + if parsed == nil { return "" } - // 1. 最高优先级:从metadata.user_id提取session_xxx - if metadata, ok := req["metadata"].(map[string]any); ok { - if userID, ok := metadata["user_id"].(string); ok { - re := regexp.MustCompile(`session_([a-f0-9-]{36})`) - if match := re.FindStringSubmatch(userID); len(match) > 1 { - return match[1] - } + // 1. 最高优先级:从 metadata.user_id 提取 session_xxx + if parsed.MetadataUserID != "" { + if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 { + return match[1] } } - // 2. 提取带cache_control: {type: "ephemeral"}的内容 - cacheableContent := s.extractCacheableContent(req) + // 2. 提取带 cache_control: {type: "ephemeral"} 的内容 + cacheableContent := s.extractCacheableContent(parsed) if cacheableContent != "" { return s.hashContent(cacheableContent) } - // 3. Fallback: 使用system内容 - if system := req["system"]; system != nil { - systemText := s.extractTextFromSystem(system) + // 3. Fallback: 使用 system 内容 + if parsed.System != nil { + systemText := s.extractTextFromSystem(parsed.System) if systemText != "" { return s.hashContent(systemText) } } - // 4. 最后fallback: 使用第一条消息 - if messages, ok := req["messages"].([]any); ok && len(messages) > 0 { - if firstMsg, ok := messages[0].(map[string]any); ok { + // 4. 最后 fallback: 使用第一条消息 + if len(parsed.Messages) > 0 { + if firstMsg, ok := parsed.Messages[0].(map[string]any); ok { msgText := s.extractTextFromContent(firstMsg["content"]) if msgText != "" { return s.hashContent(msgText) @@ -185,36 +183,38 @@ func (s *GatewayService) GenerateSessionHash(body []byte) string { return "" } -func (s *GatewayService) extractCacheableContent(req map[string]any) string { - var content string +func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { + if parsed == nil { + return "" + } - // 检查system中的cacheable内容 - if system, ok := req["system"].([]any); ok { + var builder strings.Builder + + // 检查 system 中的 cacheable 内容 + if system, ok := parsed.System.([]any); ok { for _, part := range system { if partMap, ok := part.(map[string]any); ok { if cc, ok := partMap["cache_control"].(map[string]any); ok { if cc["type"] == "ephemeral" { if text, ok := partMap["text"].(string); ok { - content += text + builder.WriteString(text) } } } } } } + systemText := builder.String() - // 检查messages中的cacheable内容 - if messages, ok := req["messages"].([]any); ok { - for _, msg := range messages { - if msgMap, ok := msg.(map[string]any); ok { - if msgContent, ok := msgMap["content"].([]any); ok { - for _, part := range msgContent { - if partMap, ok := part.(map[string]any); ok { - if cc, ok := partMap["cache_control"].(map[string]any); ok { - if cc["type"] == "ephemeral" { - // 找到cacheable内容,提取第一条消息的文本 - return s.extractTextFromContent(msgMap["content"]) - } + // 检查 messages 中的 cacheable 内容 + for _, msg := range parsed.Messages { + if msgMap, ok := msg.(map[string]any); ok { + if msgContent, ok := msgMap["content"].([]any); ok { + for _, part := range msgContent { + if partMap, ok := part.(map[string]any); ok { + if cc, ok := partMap["cache_control"].(map[string]any); ok { + if cc["type"] == "ephemeral" { + return s.extractTextFromContent(msgMap["content"]) } } } @@ -223,7 +223,7 @@ func (s *GatewayService) extractCacheableContent(req map[string]any) string { } } - return content + return systemText } func (s *GatewayService) extractTextFromSystem(system any) string { @@ -588,19 +588,17 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool { } // Forward 转发请求到Claude API -func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { +func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) { startTime := time.Now() - - // 解析请求获取model和stream - var req struct { - Model string `json:"model"` - Stream bool `json:"stream"` - } - if err := json.Unmarshal(body, &req); err != nil { - return nil, fmt.Errorf("parse request: %w", err) + if parsed == nil { + return nil, fmt.Errorf("parse request: empty request") } - if !gjson.GetBytes(body, "system").Exists() { + body := parsed.Body + reqModel := parsed.Model + reqStream := parsed.Stream + + if !parsed.HasSystem { body, _ = sjson.SetBytes(body, "system", []any{ map[string]any{ "type": "text", @@ -613,13 +611,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 应用模型映射(仅对apikey类型账号) - originalModel := req.Model + originalModel := reqModel if account.Type == AccountTypeApiKey { - mappedModel := account.GetMappedModel(req.Model) - if mappedModel != req.Model { + mappedModel := account.GetMappedModel(reqModel) + if mappedModel != reqModel { // 替换请求体中的模型名 body = s.replaceModelInBody(body, mappedModel) - req.Model = mappedModel + reqModel = mappedModel log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name) } } @@ -640,7 +638,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A var resp *http.Response for attempt := 1; attempt <= maxRetries; attempt++ { // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType) + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel) if err != nil { return nil, err } @@ -692,8 +690,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 处理正常响应 var usage *ClaudeUsage var firstTokenMs *int - if req.Stream { - streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, req.Model) + if reqStream { + streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel) if err != nil { if err.Error() == "have error in stream" { return nil, &UpstreamFailoverError{ @@ -705,7 +703,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A usage = streamResult.usage firstTokenMs = streamResult.firstTokenMs } else { - usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, req.Model) + usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel) if err != nil { return nil, err } @@ -715,13 +713,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, // 使用原始模型用于计费和日志 - Stream: req.Stream, + Stream: reqStream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, }, nil } -func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) { +func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { // 确定目标URL targetURL := claudeAPIURL if account.Type == AccountTypeApiKey { @@ -787,7 +785,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // 处理anthropic-beta header(OAuth账号需要特殊处理) if tokenType == "oauth" { - req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta"))) + req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) } return req, nil @@ -795,7 +793,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // getBetaHeader 处理anthropic-beta header // 对于OAuth账号,需要确保包含oauth-2025-04-20 -func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) string { +func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string { // 如果客户端传了anthropic-beta if clientBetaHeader != "" { // 已包含oauth beta则直接返回 @@ -832,15 +830,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str } // 客户端没传,根据模型生成 - var modelID string - var reqMap map[string]any - if json.Unmarshal(body, &reqMap) == nil { - if m, ok := reqMap["model"].(string); ok { - modelID = m - } - } - - // haiku模型不需要claude-code beta + // haiku 模型不需要 claude-code beta if strings.Contains(strings.ToLower(modelID), "haiku") { return claude.HaikuBetaHeader } @@ -1248,13 +1238,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu log.Printf("Increment subscription usage failed: %v", err) } // 异步更新订阅缓存 - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost); err != nil { - log.Printf("Update subscription cache failed: %v", err) - } - }() + s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) } } else { // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) @@ -1263,13 +1247,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu log.Printf("Deduct balance failed: %v", err) } // 异步更新余额缓存 - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost); err != nil { - log.Printf("Update balance cache failed: %v", err) - } - }() + s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) } } @@ -1281,7 +1259,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu // ForwardCountTokens 转发 count_tokens 请求到上游 API // 特点:不记录使用量、仅支持非流式响应 -func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error { +func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { + if parsed == nil { + s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return fmt.Errorf("parse request: empty request") + } + + body := parsed.Body + reqModel := parsed.Model + // Antigravity 账户不支持 count_tokens 转发,返回估算值 // 参考 Antigravity-Manager 和 proxycast 实现 if account.Platform == PlatformAntigravity { @@ -1291,14 +1277,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, // 应用模型映射(仅对 apikey 类型账号) if account.Type == AccountTypeApiKey { - var req struct { - Model string `json:"model"` - } - if err := json.Unmarshal(body, &req); err == nil && req.Model != "" { - mappedModel := account.GetMappedModel(req.Model) - if mappedModel != req.Model { + if reqModel != "" { + mappedModel := account.GetMappedModel(reqModel) + if mappedModel != reqModel { body = s.replaceModelInBody(body, mappedModel) - log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", req.Model, mappedModel, account.Name) + reqModel = mappedModel + log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name) } } } @@ -1311,7 +1295,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 构建上游请求 - upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType) + upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel) if err != nil { s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") return err @@ -1363,7 +1347,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // buildCountTokensRequest 构建 count_tokens 上游请求 -func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) { +func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { // 确定目标 URL targetURL := claudeAPICountTokensURL if account.Type == AccountTypeApiKey { @@ -1424,7 +1408,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { - req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta"))) + req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) } return req, nil diff --git a/backend/internal/service/gateway_service_benchmark_test.go b/backend/internal/service/gateway_service_benchmark_test.go new file mode 100644 index 00000000..f15a85d6 --- /dev/null +++ b/backend/internal/service/gateway_service_benchmark_test.go @@ -0,0 +1,50 @@ +package service + +import ( + "strconv" + "testing" +) + +var benchmarkStringSink string + +// BenchmarkGenerateSessionHash_Metadata 关注 JSON 解析与正则匹配开销。 +func BenchmarkGenerateSessionHash_Metadata(b *testing.B) { + svc := &GatewayService{} + body := []byte(`{"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"messages":[{"content":"hello"}]}`) + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + parsed, err := ParseGatewayRequest(body) + if err != nil { + b.Fatalf("解析请求失败: %v", err) + } + benchmarkStringSink = svc.GenerateSessionHash(parsed) + } +} + +// BenchmarkExtractCacheableContent_System 关注字符串拼接路径的性能。 +func BenchmarkExtractCacheableContent_System(b *testing.B) { + svc := &GatewayService{} + req := buildSystemCacheableRequest(12) + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + benchmarkStringSink = svc.extractCacheableContent(req) + } +} + +func buildSystemCacheableRequest(parts int) *ParsedRequest { + systemParts := make([]any, 0, parts) + for i := 0; i < parts; i++ { + systemParts = append(systemParts, map[string]any{ + "text": "system_part_" + strconv.Itoa(i), + "cache_control": map[string]any{ + "type": "ephemeral", + }, + }) + } + return &ParsedRequest{ + System: systemParts, + HasSystem: true, + } +} diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 6ccbf8ec..34958541 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -921,7 +921,10 @@ func sleepGeminiBackoff(attempt int) { time.Sleep(sleepFor) } -var sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`) +var ( + sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`) + retryInRegex = regexp.MustCompile(`Please retry in ([0-9.]+)s`) +) func sanitizeUpstreamErrorMessage(msg string) string { if msg == "" { @@ -1925,7 +1928,6 @@ func ParseGeminiRateLimitResetTime(body []byte) *int64 { } // Match "Please retry in Xs" - retryInRegex := regexp.MustCompile(`Please retry in ([0-9.]+)s`) matches := retryInRegex.FindStringSubmatch(string(body)) if len(matches) == 2 { if dur, err := time.ParseDuration(matches[1] + "s"); err == nil { diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index 36257667..e4bda5f8 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -7,13 +7,13 @@ import ( "fmt" "io" "net/http" - "net/url" "strconv" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" ) type GeminiOAuthService struct { @@ -497,11 +497,12 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) - client := &http.Client{Timeout: 30 * time.Second} - if strings.TrimSpace(proxyURL) != "" { - if proxyURLParsed, err := url.Parse(strings.TrimSpace(proxyURL)); err == nil { - client.Transport = &http.Transport{Proxy: http.ProxyURL(proxyURLParsed)} - } + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: strings.TrimSpace(proxyURL), + Timeout: 30 * time.Second, + }) + if err != nil { + client = &http.Client{Timeout: 30 * time.Second} } resp, err := client.Do(req) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 79801b29..aa844554 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -768,20 +768,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec if isSubscriptionBilling { if cost.TotalCost > 0 { _ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost) - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _ = s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost) - }() + s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) } } else { if cost.ActualCost > 0 { _ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost) - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _ = s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost) - }() + s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) } } diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index e2e263d9..bb050d0a 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -18,6 +18,11 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/openai" ) +var ( + openAIModelDatePattern = regexp.MustCompile(`-\d{8}$`) + openAIModelBasePattern = regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`) +) + // LiteLLMModelPricing LiteLLM价格数据结构 // 只保留我们需要的字段,使用指针来处理可能缺失的值 type LiteLLMModelPricing struct { @@ -595,11 +600,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { // 2. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号) // 3. 最终回退到 DefaultTestModel (gpt-5.1-codex) func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { - // 正则匹配日期后缀 (如 -20251222) - datePattern := regexp.MustCompile(`-\d{8}$`) - // 尝试的回退变体 - variants := s.generateOpenAIModelVariants(model, datePattern) + variants := s.generateOpenAIModelVariants(model, openAIModelDatePattern) for _, variant := range variants { if pricing, ok := s.pricingData[variant]; ok { @@ -638,14 +640,13 @@ func (s *PricingService) generateOpenAIModelVariants(model string, datePattern * // 2. 提取基础版本号: gpt-5.2-codex -> gpt-5.2 // 只匹配纯数字版本号格式 gpt-X 或 gpt-X.Y,不匹配 gpt-4o 这种带字母后缀的 - basePattern := regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`) - if matches := basePattern.FindStringSubmatch(model); len(matches) > 1 { + if matches := openAIModelBasePattern.FindStringSubmatch(model); len(matches) > 1 { addVariant(matches[1]) } // 3. 同时去掉日期后再提取基础版本号 if withoutDate != model { - if matches := basePattern.FindStringSubmatch(withoutDate); len(matches) > 1 { + if matches := openAIModelBasePattern.FindStringSubmatch(withoutDate); len(matches) > 1 { addVariant(matches[1]) } } diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index 0df8a0de..f57e90eb 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -186,22 +186,40 @@ func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, sta // GetStatsByAccount 获取账号的使用统计 func (s *UsageService) GetStatsByAccount(ctx context.Context, accountID int64, startTime, endTime time.Time) (*UsageStats, error) { - logs, _, err := s.usageRepo.ListByAccountAndTimeRange(ctx, accountID, startTime, endTime) + stats, err := s.usageRepo.GetAccountStatsAggregated(ctx, accountID, startTime, endTime) if err != nil { - return nil, fmt.Errorf("list usage logs: %w", err) + return nil, fmt.Errorf("get account stats: %w", err) } - return s.calculateStats(logs), nil + return &UsageStats{ + TotalRequests: stats.TotalRequests, + TotalInputTokens: stats.TotalInputTokens, + TotalOutputTokens: stats.TotalOutputTokens, + TotalCacheTokens: stats.TotalCacheTokens, + TotalTokens: stats.TotalTokens, + TotalCost: stats.TotalCost, + TotalActualCost: stats.TotalActualCost, + AverageDurationMs: stats.AverageDurationMs, + }, nil } // GetStatsByModel 获取模型的使用统计 func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, startTime, endTime time.Time) (*UsageStats, error) { - logs, _, err := s.usageRepo.ListByModelAndTimeRange(ctx, modelName, startTime, endTime) + stats, err := s.usageRepo.GetModelStatsAggregated(ctx, modelName, startTime, endTime) if err != nil { - return nil, fmt.Errorf("list usage logs: %w", err) + return nil, fmt.Errorf("get model stats: %w", err) } - return s.calculateStats(logs), nil + return &UsageStats{ + TotalRequests: stats.TotalRequests, + TotalInputTokens: stats.TotalInputTokens, + TotalOutputTokens: stats.TotalOutputTokens, + TotalCacheTokens: stats.TotalCacheTokens, + TotalTokens: stats.TotalTokens, + TotalCost: stats.TotalCost, + TotalActualCost: stats.TotalActualCost, + AverageDurationMs: stats.AverageDurationMs, + }, nil } // GetDailyStats 获取每日使用统计(最近N天) @@ -209,54 +227,12 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int endTime := time.Now() startTime := endTime.AddDate(0, 0, -days) - logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime) + stats, err := s.usageRepo.GetDailyStatsAggregated(ctx, userID, startTime, endTime) if err != nil { - return nil, fmt.Errorf("list usage logs: %w", err) + return nil, fmt.Errorf("get daily stats: %w", err) } - // 按日期分组统计 - dailyStats := make(map[string]*UsageStats) - for _, log := range logs { - dateKey := log.CreatedAt.Format("2006-01-02") - if _, exists := dailyStats[dateKey]; !exists { - dailyStats[dateKey] = &UsageStats{} - } - - stats := dailyStats[dateKey] - stats.TotalRequests++ - stats.TotalInputTokens += int64(log.InputTokens) - stats.TotalOutputTokens += int64(log.OutputTokens) - stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens) - stats.TotalTokens += int64(log.TotalTokens()) - stats.TotalCost += log.TotalCost - stats.TotalActualCost += log.ActualCost - - if log.DurationMs != nil { - stats.AverageDurationMs += float64(*log.DurationMs) - } - } - - // 计算平均值并转换为数组 - result := make([]map[string]any, 0, len(dailyStats)) - for date, stats := range dailyStats { - if stats.TotalRequests > 0 { - stats.AverageDurationMs /= float64(stats.TotalRequests) - } - - result = append(result, map[string]any{ - "date": date, - "total_requests": stats.TotalRequests, - "total_input_tokens": stats.TotalInputTokens, - "total_output_tokens": stats.TotalOutputTokens, - "total_cache_tokens": stats.TotalCacheTokens, - "total_tokens": stats.TotalTokens, - "total_cost": stats.TotalCost, - "total_actual_cost": stats.TotalActualCost, - "average_duration_ms": stats.AverageDurationMs, - }) - } - - return result, nil + return stats, nil } // calculateStats 计算统计数据 diff --git a/backend/migrations/010_add_usage_logs_aggregated_indexes.sql b/backend/migrations/010_add_usage_logs_aggregated_indexes.sql new file mode 100644 index 00000000..ab2dbbc1 --- /dev/null +++ b/backend/migrations/010_add_usage_logs_aggregated_indexes.sql @@ -0,0 +1,4 @@ +-- 为聚合查询补充复合索引 +CREATE INDEX IF NOT EXISTS idx_usage_logs_account_created_at ON usage_logs(account_id, created_at); +CREATE INDEX IF NOT EXISTS idx_usage_logs_api_key_created_at ON usage_logs(api_key_id, created_at); +CREATE INDEX IF NOT EXISTS idx_usage_logs_model_created_at ON usage_logs(model, created_at); diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index fcaa7b7c..8db4cbc9 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -21,6 +21,22 @@ server: # - simple: Hides SaaS features and skips billing/balance checks run_mode: "standard" +# ============================================================================= +# 网关配置 +# ============================================================================= +gateway: + # 等待上游响应头超时时间(秒) + response_header_timeout: 300 + # 请求体最大字节数(默认 100MB) + max_body_size: 104857600 + # HTTP 上游连接池配置(HTTP/2 + 多代理场景默认) + max_idle_conns: 240 + max_idle_conns_per_host: 120 + max_conns_per_host: 240 + idle_conn_timeout_seconds: 300 + # 并发槽位过期时间(分钟) + concurrency_slot_ttl_minutes: 15 + # ============================================================================= # Database Configuration (PostgreSQL) # ============================================================================= From 3d7f8e4b3ac96c537014cd5f8af625aa6a8273b9 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Wed, 31 Dec 2025 10:17:38 +0800 Subject: [PATCH 03/49] =?UTF-8?q?fix(=E6=9C=8D=E5=8A=A1):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8Dsystem=E5=88=A4=E5=AE=9A=E3=80=81=E7=BB=9F=E8=AE=A1?= =?UTF-8?q?=E6=97=B6=E5=8C=BA=E4=B8=8E=E7=BC=93=E5=AD=98=E6=97=A5=E5=BF=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - system 字段存在即视为显式提供,避免 null 触发默认注入 - 日统计分组显式使用应用时区,缺失时从 TZ 回退到 UTC - 缓存写入队列丢弃日志节流汇总,关键任务同步回退 测试: go test ./internal/service -run TestBillingCacheServiceQueueHighLoad --- backend/internal/repository/usage_log_repo.go | 20 ++- .../internal/service/billing_cache_service.go | 120 ++++++++++++++++-- backend/internal/service/gateway_request.go | 6 +- .../internal/service/gateway_request_test.go | 4 +- 4 files changed, 132 insertions(+), 18 deletions(-) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 4e26d751..9a210bde 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "os" "strings" "time" @@ -536,9 +537,11 @@ func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelN // GetDailyStatsAggregated 使用 SQL 聚合统计用户的每日使用数据 // 性能优化:使用 GROUP BY 在数据库层按日期分组聚合,避免应用层循环分组统计 func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) { + tzName := resolveUsageStatsTimezone() query := ` SELECT - TO_CHAR(created_at, 'YYYY-MM-DD') as date, + -- 使用应用时区分组,避免数据库会话时区导致日边界偏移。 + TO_CHAR(created_at AT TIME ZONE $4, 'YYYY-MM-DD') as date, COUNT(*) as total_requests, COALESCE(SUM(input_tokens), 0) as total_input_tokens, COALESCE(SUM(output_tokens), 0) as total_output_tokens, @@ -552,7 +555,7 @@ func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID ORDER BY 1 ` - rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime) + rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime, tzName) if err != nil { return nil, err } @@ -607,6 +610,19 @@ func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID return result, nil } +// resolveUsageStatsTimezone 获取用于 SQL 分组的时区名称。 +// 优先使用应用初始化的时区,其次尝试读取 TZ 环境变量,最后回落为 UTC。 +func resolveUsageStatsTimezone() string { + tzName := timezone.Name() + if tzName != "" && tzName != "Local" { + return tzName + } + if envTZ := strings.TrimSpace(os.Getenv("TZ")); envTZ != "" { + return envTZ + } + return "UTC" +} + func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index ac320535..58ed555a 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "sync" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -49,12 +50,13 @@ const ( // 新实现使用固定大小的工作池: // 1. 预创建 10 个 worker goroutine,避免频繁创建销毁 // 2. 使用带缓冲的 channel(1000)作为任务队列,平滑写入峰值 -// 3. 非阻塞写入,队列满时丢弃任务(缓存最终一致性可接受) +// 3. 非阻塞写入,队列满时关键任务同步回退,非关键任务丢弃并告警 // 4. 统一超时控制,避免慢操作阻塞工作池 const ( - cacheWriteWorkerCount = 10 // 工作协程数量 - cacheWriteBufferSize = 1000 // 任务队列缓冲大小 - cacheWriteTimeout = 2 * time.Second // 单个写入操作超时 + cacheWriteWorkerCount = 10 // 工作协程数量 + cacheWriteBufferSize = 1000 // 任务队列缓冲大小 + cacheWriteTimeout = 2 * time.Second // 单个写入操作超时 + cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔 ) // cacheWriteTask 缓存写入任务 @@ -78,6 +80,11 @@ type BillingCacheService struct { cacheWriteChan chan cacheWriteTask cacheWriteWg sync.WaitGroup cacheWriteStopOnce sync.Once + // 丢弃日志节流计数器(减少高负载下日志噪音) + cacheWriteDropFullCount uint64 + cacheWriteDropFullLastLog int64 + cacheWriteDropClosedCount uint64 + cacheWriteDropClosedLastLog int64 } // NewBillingCacheService 创建计费缓存服务 @@ -112,16 +119,25 @@ func (s *BillingCacheService) startCacheWriteWorkers() { } } -func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) { +// enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。 +func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) { if s.cacheWriteChan == nil { - return + return false } defer func() { - _ = recover() + if recovered := recover(); recovered != nil { + // 队列已关闭时可能触发 panic,记录后静默失败。 + s.logCacheWriteDrop(task, "closed") + enqueued = false + } }() select { case s.cacheWriteChan <- task: + return true default: + // 队列满时不阻塞主流程,交由调用方决定是否同步回退。 + s.logCacheWriteDrop(task, "full") + return false } } @@ -151,6 +167,62 @@ func (s *BillingCacheService) cacheWriteWorker() { } } +// cacheWriteKindName 用于日志中的任务类型标识,便于排查丢弃原因。 +func cacheWriteKindName(kind cacheWriteKind) string { + switch kind { + case cacheWriteSetBalance: + return "set_balance" + case cacheWriteSetSubscription: + return "set_subscription" + case cacheWriteUpdateSubscriptionUsage: + return "update_subscription_usage" + case cacheWriteDeductBalance: + return "deduct_balance" + default: + return "unknown" + } +} + +// logCacheWriteDrop 使用节流方式记录丢弃情况,并汇总丢弃数量。 +func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason string) { + var ( + countPtr *uint64 + lastPtr *int64 + ) + switch reason { + case "full": + countPtr = &s.cacheWriteDropFullCount + lastPtr = &s.cacheWriteDropFullLastLog + case "closed": + countPtr = &s.cacheWriteDropClosedCount + lastPtr = &s.cacheWriteDropClosedLastLog + default: + return + } + + atomic.AddUint64(countPtr, 1) + now := time.Now().UnixNano() + last := atomic.LoadInt64(lastPtr) + if now-last < int64(cacheWriteDropLogInterval) { + return + } + if !atomic.CompareAndSwapInt64(lastPtr, last, now) { + return + } + dropped := atomic.SwapUint64(countPtr, 0) + if dropped == 0 { + return + } + log.Printf("Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)", + reason, + dropped, + cacheWriteDropLogInterval, + cacheWriteKindName(task.kind), + task.userID, + task.groupID, + ) +} + // ============================================ // 余额缓存方法 // ============================================ @@ -175,7 +247,7 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) } // 异步建立缓存 - s.enqueueCacheWrite(cacheWriteTask{ + _ = s.enqueueCacheWrite(cacheWriteTask{ kind: cacheWriteSetBalance, userID: userID, balance: balance, @@ -213,11 +285,22 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int // QueueDeductBalance 异步扣减余额缓存 func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) { - s.enqueueCacheWrite(cacheWriteTask{ + if s.cache == nil { + return + } + // 队列满时同步回退,避免关键扣减被静默丢弃。 + if s.enqueueCacheWrite(cacheWriteTask{ kind: cacheWriteDeductBalance, userID: userID, amount: amount, - }) + }) { + return + } + ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) + defer cancel() + if err := s.DeductBalanceCache(ctx, userID, amount); err != nil { + log.Printf("Warning: deduct balance cache fallback failed for user %d: %v", userID, err) + } } // InvalidateUserBalance 失效用户余额缓存 @@ -255,7 +338,7 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, } // 异步建立缓存 - s.enqueueCacheWrite(cacheWriteTask{ + _ = s.enqueueCacheWrite(cacheWriteTask{ kind: cacheWriteSetSubscription, userID: userID, groupID: groupID, @@ -324,12 +407,23 @@ func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userI // QueueUpdateSubscriptionUsage 异步更新订阅用量缓存 func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) { - s.enqueueCacheWrite(cacheWriteTask{ + if s.cache == nil { + return + } + // 队列满时同步回退,确保订阅用量及时更新。 + if s.enqueueCacheWrite(cacheWriteTask{ kind: cacheWriteUpdateSubscriptionUsage, userID: userID, groupID: groupID, amount: costUSD, - }) + }) { + return + } + ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) + defer cancel() + if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil { + log.Printf("Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err) + } } // InvalidateSubscription 失效指定订阅缓存 diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 6d358c36..fbec1371 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -24,7 +24,7 @@ type ParsedRequest struct { MetadataUserID string // metadata.user_id(用于会话亲和) System any // system 字段内容 Messages []any // messages 数组 - HasSystem bool // 是否包含 system 字段 + HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入) } // ParseGatewayRequest 解析网关请求体并返回结构化结果 @@ -58,7 +58,9 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { parsed.MetadataUserID = userID } } - if system, ok := req["system"]; ok && system != nil { + // system 字段只要存在就视为显式提供(即使为 null), + // 以避免客户端传 null 时被默认 system 误注入。 + if system, ok := req["system"]; ok { parsed.HasSystem = true parsed.System = system } diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index c921e0f6..5d411e2c 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -22,7 +22,9 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) { body := []byte(`{"model":"claude-3","system":null}`) parsed, err := ParseGatewayRequest(body) require.NoError(t, err) - require.False(t, parsed.HasSystem) + // 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。 + require.True(t, parsed.HasSystem) + require.Nil(t, parsed.System) } func TestParseGatewayRequest_InvalidModelType(t *testing.T) { From d1c98896094fc4bc2f4ecfa0095083cbf29d7ccd Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Wed, 31 Dec 2025 11:43:58 +0800 Subject: [PATCH 04/49] =?UTF-8?q?perf(=E7=BD=91=E5=85=B3):=20=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=E4=B8=8A=E6=B8=B8=E8=B4=A6=E5=8F=B7=E8=BF=9E=E6=8E=A5?= =?UTF-8?q?=E6=B1=A0=E9=9A=94=E7=A6=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增隔离策略与连接池缓存回收 连接池大小跟随账号并发并处理代理切换 同步配置默认值与示例并补充测试 --- backend/internal/config/config.go | 48 +- backend/internal/repository/http_upstream.go | 573 ++++++++++++++++-- .../http_upstream_benchmark_test.go | 30 +- .../internal/repository/http_upstream_test.go | 184 +++++- .../internal/service/account_test_service.go | 6 +- .../service/antigravity_gateway_service.go | 4 +- backend/internal/service/gateway_service.go | 4 +- .../service/gemini_messages_compat_service.go | 6 +- .../internal/service/http_upstream_port.go | 27 +- .../service/openai_gateway_service.go | 2 +- deploy/config.example.yaml | 10 + 11 files changed, 790 insertions(+), 104 deletions(-) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index dfc9a844..aeeddcb4 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -12,6 +12,20 @@ const ( RunModeSimple = "simple" ) +// 连接池隔离策略常量 +// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗 +const ( + // ConnectionPoolIsolationProxy: 按代理隔离 + // 同一代理地址共享连接池,适合代理数量少、账户数量多的场景 + ConnectionPoolIsolationProxy = "proxy" + // ConnectionPoolIsolationAccount: 按账户隔离 + // 每个账户独立连接池,适合账户数量少、需要严格隔离的场景 + ConnectionPoolIsolationAccount = "account" + // ConnectionPoolIsolationAccountProxy: 按账户+代理组合隔离(默认) + // 同一账户+代理组合共享连接池,提供最细粒度的隔离 + ConnectionPoolIsolationAccountProxy = "account_proxy" +) + type Config struct { Server ServerConfig `mapstructure:"server"` Database DatabaseConfig `mapstructure:"database"` @@ -81,6 +95,8 @@ type GatewayConfig struct { ResponseHeaderTimeout int `mapstructure:"response_header_timeout"` // 请求体最大字节数,用于网关请求体大小限制 MaxBodySize int64 `mapstructure:"max_body_size"` + // ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy) + ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"` // HTTP 上游连接池配置(性能优化:支持高并发场景调优) // MaxIdleConns: 所有主机的最大空闲连接总数 @@ -91,6 +107,15 @@ type GatewayConfig struct { MaxConnsPerHost int `mapstructure:"max_conns_per_host"` // IdleConnTimeoutSeconds: 空闲连接超时时间(秒) IdleConnTimeoutSeconds int `mapstructure:"idle_conn_timeout_seconds"` + // MaxUpstreamClients: 上游连接池客户端最大缓存数量 + // 当使用连接池隔离策略时,系统会为不同的账户/代理组合创建独立的 HTTP 客户端 + // 此参数限制缓存的客户端数量,超出后会淘汰最久未使用的客户端 + // 建议值:预估的活跃账户数 * 1.2(留有余量) + MaxUpstreamClients int `mapstructure:"max_upstream_clients"` + // ClientIdleTTLSeconds: 上游连接池客户端空闲回收阈值(秒) + // 超过此时间未使用的客户端会被标记为可回收 + // 建议值:根据用户访问频率设置,一般 10-30 分钟 + ClientIdleTTLSeconds int `mapstructure:"client_idle_ttl_seconds"` // ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟) // 应大于最长 LLM 请求时间,防止请求完成前槽位过期 ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` @@ -289,11 +314,14 @@ func setDefaults() { // Gateway viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) + viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) // HTTP 上游连接池配置(针对 5000+ 并发用户优化) - viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认) - viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认) - viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认) + viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认) + viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认) + viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认) viper.SetDefault("gateway.idle_conn_timeout_seconds", 300) // 空闲连接超时(秒) + viper.SetDefault("gateway.max_upstream_clients", 5000) + viper.SetDefault("gateway.client_idle_ttl_seconds", 900) viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求) // TokenRefresh @@ -354,6 +382,14 @@ func (c *Config) Validate() error { if c.Gateway.MaxBodySize <= 0 { return fmt.Errorf("gateway.max_body_size must be positive") } + if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" { + switch c.Gateway.ConnectionPoolIsolation { + case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy: + default: + return fmt.Errorf("gateway.connection_pool_isolation must be one of: %s/%s/%s", + ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy) + } + } if c.Gateway.MaxIdleConns <= 0 { return fmt.Errorf("gateway.max_idle_conns must be positive") } @@ -366,6 +402,12 @@ func (c *Config) Validate() error { if c.Gateway.IdleConnTimeoutSeconds <= 0 { return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive") } + if c.Gateway.MaxUpstreamClients <= 0 { + return fmt.Errorf("gateway.max_upstream_clients must be positive") + } + if c.Gateway.ClientIdleTTLSeconds <= 0 { + return fmt.Errorf("gateway.client_idle_ttl_seconds must be positive") + } if c.Gateway.ConcurrencySlotTTLMinutes <= 0 { return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive") } diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index 0ca85a09..e7ae46dc 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -1,106 +1,553 @@ package repository import ( + "fmt" + "io" "net/http" "net/url" "strings" "sync" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/service" ) +// 默认配置常量 +// 这些值在配置文件未指定时作为回退默认值使用 +const ( + // directProxyKey: 无代理时的缓存键标识 + directProxyKey = "direct" + // defaultMaxIdleConns: 默认最大空闲连接总数 + // HTTP/2 场景下,单连接可多路复用,240 足以支撑高并发 + defaultMaxIdleConns = 240 + // defaultMaxIdleConnsPerHost: 默认每主机最大空闲连接数 + defaultMaxIdleConnsPerHost = 120 + // defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接) + // 达到上限后新请求会等待,而非无限创建连接 + defaultMaxConnsPerHost = 240 + // defaultIdleConnTimeout: 默认空闲连接超时时间(5分钟) + // 超时后连接会被关闭,释放系统资源 + defaultIdleConnTimeout = 300 * time.Second + // defaultResponseHeaderTimeout: 默认等待响应头超时时间(5分钟) + // LLM 请求可能排队较久,需要较长超时 + defaultResponseHeaderTimeout = 300 * time.Second + // defaultMaxUpstreamClients: 默认最大客户端缓存数量 + // 超出后会淘汰最久未使用的客户端 + defaultMaxUpstreamClients = 5000 + // defaultClientIdleTTLSeconds: 默认客户端空闲回收阈值(15分钟) + defaultClientIdleTTLSeconds = 900 +) + +// poolSettings 连接池配置参数 +// 封装 Transport 所需的各项连接池参数 +type poolSettings struct { + maxIdleConns int // 最大空闲连接总数 + maxIdleConnsPerHost int // 每主机最大空闲连接数 + maxConnsPerHost int // 每主机最大连接数(含活跃) + idleConnTimeout time.Duration // 空闲连接超时时间 + responseHeaderTimeout time.Duration // 等待响应头超时时间 +} + +// upstreamClientEntry 上游客户端缓存条目 +// 记录客户端实例及其元数据,用于连接池管理和淘汰策略 +type upstreamClientEntry struct { + client *http.Client // HTTP 客户端实例 + proxyKey string // 代理标识(用于检测代理变更) + poolKey string // 连接池配置标识(用于检测配置变更) + lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰 + inFlight int64 // 当前进行中的请求数,>0 时不可淘汰 +} + // httpUpstreamService 通用 HTTP 上游服务 // 用于向任意 HTTP API(Claude、OpenAI 等)发送请求,支持可选代理 // +// 架构设计: +// - 根据隔离策略(proxy/account/account_proxy)缓存客户端实例 +// - 每个客户端拥有独立的 Transport 连接池 +// - 支持 LRU + 空闲时间双重淘汰策略 +// // 性能优化: -// 1. 使用 sync.Map 缓存代理客户端实例,避免每次请求都创建新的 http.Client +// 1. 根据隔离策略缓存客户端实例,避免频繁创建 http.Client // 2. 复用 Transport 连接池,减少 TCP 握手和 TLS 协商开销 -// 3. 原实现每次请求都 new 一个 http.Client,导致连接无法复用 +// 3. 支持账号级隔离与空闲回收,降低连接层关联风险 +// 4. 达到最大连接数后等待可用连接,而非无限创建 +// 5. 仅回收空闲客户端,避免中断活跃请求 +// 6. HTTP/2 多路复用,连接上限不等于并发请求上限 +// 7. 代理变更时清空旧连接池,避免复用错误代理 +// 8. 账号并发数与连接池上限对应(账号隔离策略下) type httpUpstreamService struct { - // defaultClient: 无代理时使用的默认客户端(单例复用) - defaultClient *http.Client - // proxyClients: 按代理 URL 缓存的客户端池,避免重复创建 - proxyClients sync.Map - cfg *config.Config + cfg *config.Config // 全局配置 + mu sync.RWMutex // 保护 clients map 的读写锁 + clients map[string]*upstreamClientEntry // 客户端缓存池,key 由隔离策略决定 } // NewHTTPUpstream 创建通用 HTTP 上游服务 // 使用配置中的连接池参数构建 Transport +// +// 参数: +// - cfg: 全局配置,包含连接池参数和隔离策略 +// +// 返回: +// - service.HTTPUpstream 接口实现 func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream { return &httpUpstreamService{ - defaultClient: &http.Client{Transport: buildUpstreamTransport(cfg, nil)}, - cfg: cfg, + cfg: cfg, + clients: make(map[string]*upstreamClientEntry), } } -func (s *httpUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) { - if strings.TrimSpace(proxyURL) == "" { - return s.defaultClient.Do(req) - } - client := s.getOrCreateClient(proxyURL) - return client.Do(req) -} +// Do 执行 HTTP 请求 +// 根据隔离策略获取或创建客户端,并跟踪请求生命周期 +// +// 参数: +// - req: HTTP 请求对象 +// - proxyURL: 代理地址,空字符串表示直连 +// - accountID: 账户 ID,用于账户级隔离 +// - accountConcurrency: 账户并发限制,用于动态调整连接池大小 +// +// 返回: +// - *http.Response: HTTP 响应(Body 已包装,关闭时自动更新计数) +// - error: 请求错误 +// +// 注意: +// - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏 +// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断 +func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + // 获取或创建对应的客户端,并标记请求占用 + entry := s.acquireClient(proxyURL, accountID, accountConcurrency) -// getOrCreateClient 获取或创建代理客户端 -// 性能优化:使用 sync.Map 实现无锁缓存,相同代理 URL 复用同一客户端 -// LoadOrStore 保证并发安全,避免重复创建 -func (s *httpUpstreamService) getOrCreateClient(proxyURL string) *http.Client { - proxyURL = strings.TrimSpace(proxyURL) - if proxyURL == "" { - return s.defaultClient - } - // 优先从缓存获取,命中则直接返回 - if cached, ok := s.proxyClients.Load(proxyURL); ok { - return cached.(*http.Client) - } - - parsedURL, err := url.Parse(proxyURL) + // 执行请求 + resp, err := entry.client.Do(req) if err != nil { - return s.defaultClient + // 请求失败,立即减少计数 + atomic.AddInt64(&entry.inFlight, -1) + return nil, err } - // 创建新客户端并缓存,LoadOrStore 保证只有一个实例被存储 - client := &http.Client{Transport: buildUpstreamTransport(s.cfg, parsedURL)} - actual, _ := s.proxyClients.LoadOrStore(proxyURL, client) - return actual.(*http.Client) + // 包装响应体,在关闭时自动减少计数并更新时间戳 + // 这确保了流式响应(如 SSE)在完全读取前不会被淘汰 + resp.Body = wrapTrackedBody(resp.Body, func() { + atomic.AddInt64(&entry.inFlight, -1) + atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano()) + }) + + return resp, nil +} + +// acquireClient 获取或创建客户端,并标记为进行中请求 +// 用于请求路径,避免在获取后被淘汰 +func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry { + return s.getClientEntry(proxyURL, accountID, accountConcurrency, true) +} + +// getOrCreateClient 获取或创建客户端 +// 根据隔离策略和参数决定缓存键,处理代理变更和配置变更 +// +// 参数: +// - proxyURL: 代理地址 +// - accountID: 账户 ID +// - accountConcurrency: 账户并发限制 +// +// 返回: +// - *upstreamClientEntry: 客户端缓存条目 +// +// 隔离策略说明: +// - proxy: 按代理地址隔离,同一代理共享客户端 +// - account: 按账户隔离,同一账户共享客户端(代理变更时重建) +// - account_proxy: 按账户+代理组合隔离,最细粒度 +func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry { + return s.getClientEntry(proxyURL, accountID, accountConcurrency, false) +} + +// getClientEntry 获取或创建客户端条目 +// markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰 +func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool) *upstreamClientEntry { + // 获取隔离模式 + isolation := s.getIsolationMode() + // 标准化代理 URL 并解析 + proxyKey, parsedProxy := normalizeProxyURL(proxyURL) + // 构建缓存键(根据隔离策略不同) + cacheKey := buildCacheKey(isolation, proxyKey, accountID) + // 构建连接池配置键(用于检测配置变更) + poolKey := s.buildPoolKey(isolation, accountConcurrency) + + now := time.Now() + nowUnix := now.UnixNano() + + // 读锁快速路径:命中缓存直接返回,减少锁竞争 + s.mu.RLock() + if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) { + atomic.StoreInt64(&entry.lastUsed, nowUnix) + if markInFlight { + atomic.AddInt64(&entry.inFlight, 1) + } + s.mu.RUnlock() + return entry + } + s.mu.RUnlock() + + // 写锁慢路径:创建或重建客户端 + s.mu.Lock() + if entry, ok := s.clients[cacheKey]; ok { + if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) { + atomic.StoreInt64(&entry.lastUsed, nowUnix) + if markInFlight { + atomic.AddInt64(&entry.inFlight, 1) + } + s.mu.Unlock() + return entry + } + s.removeClientLocked(cacheKey, entry) + } + + // 缓存未命中或需要重建,创建新客户端 + settings := s.resolvePoolSettings(isolation, accountConcurrency) + client := &http.Client{Transport: buildUpstreamTransport(settings, parsedProxy)} + entry := &upstreamClientEntry{ + client: client, + proxyKey: proxyKey, + poolKey: poolKey, + } + atomic.StoreInt64(&entry.lastUsed, nowUnix) + if markInFlight { + atomic.StoreInt64(&entry.inFlight, 1) + } + s.clients[cacheKey] = entry + + // 执行淘汰策略:先淘汰空闲超时的,再淘汰超出数量限制的 + s.evictIdleLocked(now) + s.evictOverLimitLocked() + s.mu.Unlock() + return entry +} + +// shouldReuseEntry 判断缓存条目是否可复用 +// 若代理或连接池配置发生变化,则需要重建客户端 +func (s *httpUpstreamService) shouldReuseEntry(entry *upstreamClientEntry, isolation, proxyKey, poolKey string) bool { + if entry == nil { + return false + } + if isolation == config.ConnectionPoolIsolationAccount && entry.proxyKey != proxyKey { + return false + } + if entry.poolKey != poolKey { + return false + } + return true +} + +// removeClientLocked 移除客户端(需持有锁) +// 从缓存中删除并关闭空闲连接 +// +// 参数: +// - key: 缓存键 +// - entry: 客户端条目 +func (s *httpUpstreamService) removeClientLocked(key string, entry *upstreamClientEntry) { + delete(s.clients, key) + if entry != nil && entry.client != nil { + // 关闭空闲连接,释放系统资源 + // 注意:这不会中断活跃连接 + entry.client.CloseIdleConnections() + } +} + +// evictIdleLocked 淘汰空闲超时的客户端(需持有锁) +// 遍历所有客户端,移除超过 TTL 且无活跃请求的条目 +// +// 参数: +// - now: 当前时间 +func (s *httpUpstreamService) evictIdleLocked(now time.Time) { + ttl := s.clientIdleTTL() + if ttl <= 0 { + return + } + // 计算淘汰截止时间 + cutoff := now.Add(-ttl).UnixNano() + for key, entry := range s.clients { + // 跳过有活跃请求的客户端 + if atomic.LoadInt64(&entry.inFlight) != 0 { + continue + } + // 淘汰超时的空闲客户端 + if atomic.LoadInt64(&entry.lastUsed) <= cutoff { + s.removeClientLocked(key, entry) + } + } +} + +// evictOverLimitLocked 淘汰超出数量限制的客户端(需持有锁) +// 使用 LRU 策略,优先淘汰最久未使用且无活跃请求的客户端 +func (s *httpUpstreamService) evictOverLimitLocked() { + maxClients := s.maxUpstreamClients() + if maxClients <= 0 { + return + } + // 循环淘汰直到满足数量限制 + for len(s.clients) > maxClients { + var ( + oldestKey string + oldestEntry *upstreamClientEntry + oldestTime int64 + ) + // 查找最久未使用且无活跃请求的客户端 + for key, entry := range s.clients { + // 跳过有活跃请求的客户端 + if atomic.LoadInt64(&entry.inFlight) != 0 { + continue + } + lastUsed := atomic.LoadInt64(&entry.lastUsed) + if oldestEntry == nil || lastUsed < oldestTime { + oldestKey = key + oldestEntry = entry + oldestTime = lastUsed + } + } + // 所有客户端都有活跃请求,无法淘汰 + if oldestEntry == nil { + return + } + s.removeClientLocked(oldestKey, oldestEntry) + } +} + +// getIsolationMode 获取连接池隔离模式 +// 从配置中读取,无效值回退到 account_proxy 模式 +// +// 返回: +// - string: 隔离模式(proxy/account/account_proxy) +func (s *httpUpstreamService) getIsolationMode() string { + if s.cfg == nil { + return config.ConnectionPoolIsolationAccountProxy + } + mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.ConnectionPoolIsolation)) + if mode == "" { + return config.ConnectionPoolIsolationAccountProxy + } + switch mode { + case config.ConnectionPoolIsolationProxy, config.ConnectionPoolIsolationAccount, config.ConnectionPoolIsolationAccountProxy: + return mode + default: + return config.ConnectionPoolIsolationAccountProxy + } +} + +// maxUpstreamClients 获取最大客户端缓存数量 +// 从配置中读取,无效值使用默认值 +func (s *httpUpstreamService) maxUpstreamClients() int { + if s.cfg == nil { + return defaultMaxUpstreamClients + } + if s.cfg.Gateway.MaxUpstreamClients > 0 { + return s.cfg.Gateway.MaxUpstreamClients + } + return defaultMaxUpstreamClients +} + +// clientIdleTTL 获取客户端空闲回收阈值 +// 从配置中读取,无效值使用默认值 +func (s *httpUpstreamService) clientIdleTTL() time.Duration { + if s.cfg == nil { + return time.Duration(defaultClientIdleTTLSeconds) * time.Second + } + if s.cfg.Gateway.ClientIdleTTLSeconds > 0 { + return time.Duration(s.cfg.Gateway.ClientIdleTTLSeconds) * time.Second + } + return time.Duration(defaultClientIdleTTLSeconds) * time.Second +} + +// resolvePoolSettings 解析连接池配置 +// 根据隔离策略和账户并发数动态调整连接池参数 +// +// 参数: +// - isolation: 隔离模式 +// - accountConcurrency: 账户并发限制 +// +// 返回: +// - poolSettings: 连接池配置 +// +// 说明: +// - 账户隔离模式下,连接池大小与账户并发数对应 +// - 这确保了单账户不会占用过多连接资源 +func (s *httpUpstreamService) resolvePoolSettings(isolation string, accountConcurrency int) poolSettings { + settings := defaultPoolSettings(s.cfg) + // 账户隔离模式下,根据账户并发数调整连接池大小 + if (isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy) && accountConcurrency > 0 { + settings.maxIdleConns = accountConcurrency + settings.maxIdleConnsPerHost = accountConcurrency + settings.maxConnsPerHost = accountConcurrency + } + return settings +} + +// buildPoolKey 构建连接池配置键 +// 用于检测配置变更,配置变更时需要重建客户端 +// +// 参数: +// - isolation: 隔离模式 +// - accountConcurrency: 账户并发限制 +// +// 返回: +// - string: 配置键 +func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int) string { + if isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy { + if accountConcurrency > 0 { + return fmt.Sprintf("account:%d", accountConcurrency) + } + } + return "default" +} + +// buildCacheKey 构建客户端缓存键 +// 根据隔离策略决定缓存键的组成 +// +// 参数: +// - isolation: 隔离模式 +// - proxyKey: 代理标识 +// - accountID: 账户 ID +// +// 返回: +// - string: 缓存键 +// +// 缓存键格式: +// - proxy 模式: "proxy:{proxyKey}" +// - account 模式: "account:{accountID}" +// - account_proxy 模式: "account:{accountID}|proxy:{proxyKey}" +func buildCacheKey(isolation, proxyKey string, accountID int64) string { + switch isolation { + case config.ConnectionPoolIsolationAccount: + return fmt.Sprintf("account:%d", accountID) + case config.ConnectionPoolIsolationAccountProxy: + return fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey) + default: + return fmt.Sprintf("proxy:%s", proxyKey) + } +} + +// normalizeProxyURL 标准化代理 URL +// 处理空值和解析错误,返回标准化的键和解析后的 URL +// +// 参数: +// - raw: 原始代理 URL 字符串 +// +// 返回: +// - string: 标准化的代理键(空或解析失败返回 "direct") +// - *url.URL: 解析后的 URL(空或解析失败返回 nil) +func normalizeProxyURL(raw string) (string, *url.URL) { + proxyURL := strings.TrimSpace(raw) + if proxyURL == "" { + return directProxyKey, nil + } + parsed, err := url.Parse(proxyURL) + if err != nil { + return directProxyKey, nil + } + return proxyURL, parsed +} + +// defaultPoolSettings 获取默认连接池配置 +// 从全局配置中读取,无效值使用常量默认值 +// +// 参数: +// - cfg: 全局配置 +// +// 返回: +// - poolSettings: 连接池配置 +func defaultPoolSettings(cfg *config.Config) poolSettings { + maxIdleConns := defaultMaxIdleConns + maxIdleConnsPerHost := defaultMaxIdleConnsPerHost + maxConnsPerHost := defaultMaxConnsPerHost + idleConnTimeout := defaultIdleConnTimeout + responseHeaderTimeout := defaultResponseHeaderTimeout + + if cfg != nil { + if cfg.Gateway.MaxIdleConns > 0 { + maxIdleConns = cfg.Gateway.MaxIdleConns + } + if cfg.Gateway.MaxIdleConnsPerHost > 0 { + maxIdleConnsPerHost = cfg.Gateway.MaxIdleConnsPerHost + } + if cfg.Gateway.MaxConnsPerHost >= 0 { + maxConnsPerHost = cfg.Gateway.MaxConnsPerHost + } + if cfg.Gateway.IdleConnTimeoutSeconds > 0 { + idleConnTimeout = time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second + } + if cfg.Gateway.ResponseHeaderTimeout > 0 { + responseHeaderTimeout = time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second + } + } + + return poolSettings{ + maxIdleConns: maxIdleConns, + maxIdleConnsPerHost: maxIdleConnsPerHost, + maxConnsPerHost: maxConnsPerHost, + idleConnTimeout: idleConnTimeout, + responseHeaderTimeout: responseHeaderTimeout, + } } // buildUpstreamTransport 构建上游请求的 Transport // 使用配置文件中的连接池参数,支持生产环境调优 -func buildUpstreamTransport(cfg *config.Config, proxyURL *url.URL) *http.Transport { - // 读取配置,使用合理的默认值 - maxIdleConns := cfg.Gateway.MaxIdleConns - if maxIdleConns <= 0 { - maxIdleConns = 240 - } - maxIdleConnsPerHost := cfg.Gateway.MaxIdleConnsPerHost - if maxIdleConnsPerHost <= 0 { - maxIdleConnsPerHost = 120 - } - maxConnsPerHost := cfg.Gateway.MaxConnsPerHost - if maxConnsPerHost < 0 { - maxConnsPerHost = 240 - } - idleConnTimeout := time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second - if idleConnTimeout <= 0 { - idleConnTimeout = 300 * time.Second - } - responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second - if responseHeaderTimeout <= 0 { - responseHeaderTimeout = 300 * time.Second - } - +// +// 参数: +// - settings: 连接池配置 +// - proxyURL: 代理 URL(nil 表示直连) +// +// 返回: +// - *http.Transport: 配置好的 Transport 实例 +// +// Transport 参数说明: +// - MaxIdleConns: 所有主机的最大空闲连接总数 +// - MaxIdleConnsPerHost: 每主机最大空闲连接数(影响连接复用率) +// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待) +// - IdleConnTimeout: 空闲连接超时(超时后关闭) +// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输) +func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) *http.Transport { transport := &http.Transport{ - MaxIdleConns: maxIdleConns, // 最大空闲连接总数 - MaxIdleConnsPerHost: maxIdleConnsPerHost, // 每主机最大空闲连接 - MaxConnsPerHost: maxConnsPerHost, // 每主机最大连接数(含活跃) - IdleConnTimeout: idleConnTimeout, // 空闲连接超时 - ResponseHeaderTimeout: responseHeaderTimeout, + MaxIdleConns: settings.maxIdleConns, + MaxIdleConnsPerHost: settings.maxIdleConnsPerHost, + MaxConnsPerHost: settings.maxConnsPerHost, + IdleConnTimeout: settings.idleConnTimeout, + ResponseHeaderTimeout: settings.responseHeaderTimeout, } if proxyURL != nil { transport.Proxy = http.ProxyURL(proxyURL) } return transport } + +// trackedBody 带跟踪功能的响应体包装器 +// 在 Close 时执行回调,用于更新请求计数 +type trackedBody struct { + io.ReadCloser // 原始响应体 + once sync.Once + onClose func() // 关闭时的回调函数 +} + +// Close 关闭响应体并执行回调 +// 使用 sync.Once 确保回调只执行一次 +func (b *trackedBody) Close() error { + err := b.ReadCloser.Close() + if b.onClose != nil { + b.once.Do(b.onClose) + } + return err +} + +// wrapTrackedBody 包装响应体以跟踪关闭事件 +// 用于在响应体关闭时更新 inFlight 计数 +// +// 参数: +// - body: 原始响应体 +// - onClose: 关闭时的回调函数 +// +// 返回: +// - io.ReadCloser: 包装后的响应体 +func wrapTrackedBody(body io.ReadCloser, onClose func()) io.ReadCloser { + if body == nil { + return body + } + return &trackedBody{ReadCloser: body, onClose: onClose} +} diff --git a/backend/internal/repository/http_upstream_benchmark_test.go b/backend/internal/repository/http_upstream_benchmark_test.go index 2ea6e31a..3219c6da 100644 --- a/backend/internal/repository/http_upstream_benchmark_test.go +++ b/backend/internal/repository/http_upstream_benchmark_test.go @@ -8,10 +8,21 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" ) +// httpClientSink 用于防止编译器优化掉基准测试中的赋值操作 +// 这是 Go 基准测试的常见模式,确保测试结果准确 var httpClientSink *http.Client -// BenchmarkHTTPUpstreamProxyClient 对比重复创建与复用代理客户端的开销。 +// BenchmarkHTTPUpstreamProxyClient 对比重复创建与复用代理客户端的开销 +// +// 测试目的: +// - 验证连接池复用相比每次新建的性能提升 +// - 量化内存分配差异 +// +// 预期结果: +// - "复用" 子测试应显著快于 "新建" +// - "复用" 子测试应零内存分配 func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { + // 创建测试配置 cfg := &config.Config{ Gateway: config.GatewayConfig{ResponseHeaderTimeout: 300}, } @@ -22,24 +33,33 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { } proxyURL := "http://127.0.0.1:8080" - b.ReportAllocs() + b.ReportAllocs() // 报告内存分配统计 + // 子测试:每次新建客户端 + // 模拟未优化前的行为,每次请求都创建新的 http.Client b.Run("新建", func(b *testing.B) { parsedProxy, err := url.Parse(proxyURL) if err != nil { b.Fatalf("解析代理地址失败: %v", err) } + settings := defaultPoolSettings(cfg) for i := 0; i < b.N; i++ { + // 每次迭代都创建新客户端,包含 Transport 分配 httpClientSink = &http.Client{ - Transport: buildUpstreamTransport(cfg, parsedProxy), + Transport: buildUpstreamTransport(settings, parsedProxy), } } }) + // 子测试:复用已缓存的客户端 + // 模拟优化后的行为,从缓存获取客户端 b.Run("复用", func(b *testing.B) { - client := svc.getOrCreateClient(proxyURL) - b.ResetTimer() + // 预热:确保客户端已缓存 + entry := svc.getOrCreateClient(proxyURL, 1, 1) + client := entry.client + b.ResetTimer() // 重置计时器,排除预热时间 for i := 0; i < b.N; i++ { + // 直接使用缓存的客户端,无内存分配 httpClientSink = client } }) diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go index 74132e1d..763b254f 100644 --- a/backend/internal/repository/http_upstream_test.go +++ b/backend/internal/repository/http_upstream_test.go @@ -4,6 +4,7 @@ import ( "io" "net/http" "net/http/httptest" + "sync/atomic" "testing" "time" @@ -12,45 +13,61 @@ import ( "github.com/stretchr/testify/suite" ) +// HTTPUpstreamSuite HTTP 上游服务测试套件 +// 使用 testify/suite 组织测试,支持 SetupTest 初始化 type HTTPUpstreamSuite struct { suite.Suite - cfg *config.Config + cfg *config.Config // 测试用配置 } +// SetupTest 每个测试用例执行前的初始化 +// 创建空配置,各测试用例可按需覆盖 func (s *HTTPUpstreamSuite) SetupTest() { s.cfg = &config.Config{} } -func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() { +// newService 创建测试用的 httpUpstreamService 实例 +// 返回具体类型以便访问内部状态进行断言 +func (s *HTTPUpstreamSuite) newService() *httpUpstreamService { up := NewHTTPUpstream(s.cfg) svc, ok := up.(*httpUpstreamService) require.True(s.T(), ok, "expected *httpUpstreamService") - transport, ok := svc.defaultClient.Transport.(*http.Transport) + return svc +} + +// TestDefaultResponseHeaderTimeout 测试默认响应头超时配置 +// 验证未配置时使用 300 秒默认值 +func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() { + svc := s.newService() + entry := svc.getOrCreateClient("", 0, 0) + transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") } +// TestCustomResponseHeaderTimeout 测试自定义响应头超时配置 +// 验证配置值能正确应用到 Transport func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() { s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7} - up := NewHTTPUpstream(s.cfg) - svc, ok := up.(*httpUpstreamService) - require.True(s.T(), ok, "expected *httpUpstreamService") - transport, ok := svc.defaultClient.Transport.(*http.Transport) + svc := s.newService() + entry := svc.getOrCreateClient("", 0, 0) + transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") } -func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDefault() { - s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 5} - up := NewHTTPUpstream(s.cfg) - svc, ok := up.(*httpUpstreamService) - require.True(s.T(), ok, "expected *httpUpstreamService") - - got := svc.getOrCreateClient("://bad-proxy-url") - require.Equal(s.T(), svc.defaultClient, got, "expected defaultClient fallback") +// TestGetOrCreateClient_InvalidURLFallsBackToDirect 测试无效代理 URL 回退 +// 验证解析失败时回退到直连模式 +func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDirect() { + svc := s.newService() + entry := svc.getOrCreateClient("://bad-proxy-url", 1, 1) + require.Equal(s.T(), directProxyKey, entry.proxyKey, "expected direct proxy fallback") } +// TestDo_WithoutProxy_GoesDirect 测试无代理时直连 +// 验证空代理 URL 时请求直接发送到目标服务器 func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() { + // 创建模拟上游服务器 upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, "direct") })) @@ -60,17 +77,21 @@ func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() { req, err := http.NewRequest(http.MethodGet, upstream.URL+"/x", nil) require.NoError(s.T(), err, "NewRequest") - resp, err := up.Do(req, "") + resp, err := up.Do(req, "", 1, 1) require.NoError(s.T(), err, "Do") defer func() { _ = resp.Body.Close() }() b, _ := io.ReadAll(resp.Body) require.Equal(s.T(), "direct", string(b), "unexpected body") } +// TestDo_WithHTTPProxy_UsesProxy 测试 HTTP 代理功能 +// 验证请求通过代理服务器转发,使用绝对 URI 格式 func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() { + // 用于接收代理请求的通道 seen := make(chan string, 1) + // 创建模拟代理服务器 proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - seen <- r.RequestURI + seen <- r.RequestURI // 记录请求 URI _, _ = io.WriteString(w, "proxied") })) s.T().Cleanup(proxySrv.Close) @@ -78,14 +99,16 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() { s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 1} up := NewHTTPUpstream(s.cfg) + // 发送请求到外部地址,应通过代理 req, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil) require.NoError(s.T(), err, "NewRequest") - resp, err := up.Do(req, proxySrv.URL) + resp, err := up.Do(req, proxySrv.URL, 1, 1) require.NoError(s.T(), err, "Do") defer func() { _ = resp.Body.Close() }() b, _ := io.ReadAll(resp.Body) require.Equal(s.T(), "proxied", string(b), "unexpected body") + // 验证代理收到的是绝对 URI 格式(HTTP 代理规范要求) select { case uri := <-seen: require.Equal(s.T(), "http://example.com/test", uri, "expected absolute-form request URI") @@ -94,6 +117,8 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() { } } +// TestDo_EmptyProxy_UsesDirect 测试空代理字符串 +// 验证空字符串代理等同于直连 func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, "direct-empty") @@ -103,13 +128,134 @@ func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() { up := NewHTTPUpstream(s.cfg) req, err := http.NewRequest(http.MethodGet, upstream.URL+"/y", nil) require.NoError(s.T(), err, "NewRequest") - resp, err := up.Do(req, "") + resp, err := up.Do(req, "", 1, 1) require.NoError(s.T(), err, "Do with empty proxy") defer func() { _ = resp.Body.Close() }() b, _ := io.ReadAll(resp.Body) require.Equal(s.T(), "direct-empty", string(b)) } +// TestAccountIsolation_DifferentAccounts 测试账户隔离模式 +// 验证不同账户使用独立的连接池 +func (s *HTTPUpstreamSuite) TestAccountIsolation_DifferentAccounts() { + s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} + svc := s.newService() + // 同一代理,不同账户 + entry1 := svc.getOrCreateClient("http://proxy.local:8080", 1, 3) + entry2 := svc.getOrCreateClient("http://proxy.local:8080", 2, 3) + require.NotSame(s.T(), entry1, entry2, "不同账号不应共享连接池") + require.Equal(s.T(), 2, len(svc.clients), "账号隔离应缓存两个客户端") +} + +// TestAccountProxyIsolation_DifferentProxy 测试账户+代理组合隔离模式 +// 验证同一账户使用不同代理时创建独立连接池 +func (s *HTTPUpstreamSuite) TestAccountProxyIsolation_DifferentProxy() { + s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy} + svc := s.newService() + // 同一账户,不同代理 + entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3) + entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3) + require.NotSame(s.T(), entry1, entry2, "账号+代理隔离应区分不同代理") + require.Equal(s.T(), 2, len(svc.clients), "账号+代理隔离应缓存两个客户端") +} + +// TestAccountModeProxyChangeClearsPool 测试账户模式下代理变更 +// 验证账户切换代理时清理旧连接池,避免复用错误代理 +func (s *HTTPUpstreamSuite) TestAccountModeProxyChangeClearsPool() { + s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} + svc := s.newService() + // 同一账户,先后使用不同代理 + entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3) + entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3) + require.NotSame(s.T(), entry1, entry2, "账号切换代理应创建新连接池") + require.Equal(s.T(), 1, len(svc.clients), "账号模式下应仅保留一个连接池") + require.False(s.T(), hasEntry(svc, entry1), "旧连接池应被清理") +} + +// TestAccountConcurrencyOverridesPoolSettings 测试账户并发数覆盖连接池配置 +// 验证账户隔离模式下,连接池大小与账户并发数对应 +func (s *HTTPUpstreamSuite) TestAccountConcurrencyOverridesPoolSettings() { + s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} + svc := s.newService() + // 账户并发数为 12 + entry := svc.getOrCreateClient("", 1, 12) + transport, ok := entry.client.Transport.(*http.Transport) + require.True(s.T(), ok, "expected *http.Transport") + // 连接池参数应与并发数一致 + require.Equal(s.T(), 12, transport.MaxConnsPerHost, "MaxConnsPerHost mismatch") + require.Equal(s.T(), 12, transport.MaxIdleConns, "MaxIdleConns mismatch") + require.Equal(s.T(), 12, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost mismatch") +} + +// TestAccountConcurrencyFallbackToDefault 测试账户并发数为 0 时回退到默认配置 +// 验证未指定并发数时使用全局配置值 +func (s *HTTPUpstreamSuite) TestAccountConcurrencyFallbackToDefault() { + s.cfg.Gateway = config.GatewayConfig{ + ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount, + MaxIdleConns: 77, + MaxIdleConnsPerHost: 55, + MaxConnsPerHost: 66, + } + svc := s.newService() + // 账户并发数为 0,应使用全局配置 + entry := svc.getOrCreateClient("", 1, 0) + transport, ok := entry.client.Transport.(*http.Transport) + require.True(s.T(), ok, "expected *http.Transport") + require.Equal(s.T(), 66, transport.MaxConnsPerHost, "MaxConnsPerHost fallback mismatch") + require.Equal(s.T(), 77, transport.MaxIdleConns, "MaxIdleConns fallback mismatch") + require.Equal(s.T(), 55, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost fallback mismatch") +} + +// TestEvictOverLimitRemovesOldestIdle 测试超出数量限制时的 LRU 淘汰 +// 验证优先淘汰最久未使用的空闲客户端 +func (s *HTTPUpstreamSuite) TestEvictOverLimitRemovesOldestIdle() { + s.cfg.Gateway = config.GatewayConfig{ + ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy, + MaxUpstreamClients: 2, // 最多缓存 2 个客户端 + } + svc := s.newService() + // 创建两个客户端,设置不同的最后使用时间 + entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 1) + entry2 := svc.getOrCreateClient("http://proxy-b:8080", 2, 1) + atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Hour).UnixNano()) // 最久 + atomic.StoreInt64(&entry2.lastUsed, time.Now().Add(-time.Hour).UnixNano()) + // 创建第三个客户端,触发淘汰 + _ = svc.getOrCreateClient("http://proxy-c:8080", 3, 1) + + require.LessOrEqual(s.T(), len(svc.clients), 2, "应保持在缓存上限内") + require.False(s.T(), hasEntry(svc, entry1), "最久未使用的连接池应被清理") +} + +// TestIdleTTLDoesNotEvictActive 测试活跃请求保护 +// 验证有进行中请求的客户端不会被空闲超时淘汰 +func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() { + s.cfg.Gateway = config.GatewayConfig{ + ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount, + ClientIdleTTLSeconds: 1, // 1 秒空闲超时 + } + svc := s.newService() + entry1 := svc.getOrCreateClient("", 1, 1) + // 设置为很久之前使用,但有活跃请求 + atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Minute).UnixNano()) + atomic.StoreInt64(&entry1.inFlight, 1) // 模拟有活跃请求 + // 创建新客户端,触发淘汰检查 + _ = svc.getOrCreateClient("", 2, 1) + + require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收") +} + +// TestHTTPUpstreamSuite 运行测试套件 func TestHTTPUpstreamSuite(t *testing.T) { suite.Run(t, new(HTTPUpstreamSuite)) } + +// hasEntry 检查客户端是否存在于缓存中 +// 辅助函数,用于验证淘汰逻辑 +func hasEntry(svc *httpUpstreamService, target *upstreamClientEntry) bool { + for _, entry := range svc.clients { + if entry == target { + return true + } + } + return false +} diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 6296f2fe..bfa9b60f 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -256,7 +256,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account proxyURL = account.Proxy.URL() } - resp, err := s.httpUpstream.Do(req, proxyURL) + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) } @@ -371,7 +371,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account proxyURL = account.Proxy.URL() } - resp, err := s.httpUpstream.Do(req, proxyURL) + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) } @@ -442,7 +442,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account proxyURL = account.Proxy.URL() } - resp, err := s.httpUpstream.Do(req, proxyURL) + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) } diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 18a67fdf..25d9066b 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -230,7 +230,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) upstreamReq.Header.Set("User-Agent", antigravity.UserAgent) - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL) + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) if err != nil { if attempt < antigravityMaxRetries { log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err) @@ -380,7 +380,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) upstreamReq.Header.Set("User-Agent", antigravity.UserAgent) - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL) + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) if err != nil { if attempt < antigravityMaxRetries { log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 41362662..dd879da2 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -644,7 +644,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 发送请求 - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL) + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) if err != nil { return nil, fmt.Errorf("upstream request failed: %w", err) } @@ -1308,7 +1308,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 发送请求 - resp, err := s.httpUpstream.Do(upstreamReq, proxyURL) + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) if err != nil { s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") return fmt.Errorf("upstream request failed: %w", err) diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 34958541..ee3ade16 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -472,7 +472,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex } requestIDHeader = idHeader - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL) + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) if err != nil { if attempt < geminiMaxRetries { log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) @@ -725,7 +725,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. } requestIDHeader = idHeader - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL) + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) if err != nil { if attempt < geminiMaxRetries { log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) @@ -1756,7 +1756,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac return nil, fmt.Errorf("unsupported account type: %s", account.Type) } - resp, err := s.httpUpstream.Do(req, proxyURL) + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) if err != nil { return nil, err } diff --git a/backend/internal/service/http_upstream_port.go b/backend/internal/service/http_upstream_port.go index 7fb9407f..9357f763 100644 --- a/backend/internal/service/http_upstream_port.go +++ b/backend/internal/service/http_upstream_port.go @@ -2,8 +2,29 @@ package service import "net/http" -// HTTPUpstream interface for making HTTP requests to upstream APIs (Claude, OpenAI, etc.) -// This is a generic interface that can be used for any HTTP-based upstream service. +// HTTPUpstream 上游 HTTP 请求接口 +// 用于向上游 API(Claude、OpenAI、Gemini 等)发送请求 +// 这是一个通用接口,可用于任何基于 HTTP 的上游服务 +// +// 设计说明: +// - 支持可选代理配置 +// - 支持账户级连接池隔离 +// - 实现类负责连接池管理和复用 type HTTPUpstream interface { - Do(req *http.Request, proxyURL string) (*http.Response, error) + // Do 执行 HTTP 请求 + // + // 参数: + // - req: HTTP 请求对象,由调用方构建 + // - proxyURL: 代理服务器地址,空字符串表示直连 + // - accountID: 账户 ID,用于连接池隔离(隔离策略为 account 或 account_proxy 时生效) + // - accountConcurrency: 账户并发限制,用于动态调整连接池大小 + // + // 返回: + // - *http.Response: HTTP 响应,调用方必须关闭 Body + // - error: 请求错误(网络错误、超时等) + // + // 注意: + // - 调用方必须关闭 resp.Body,否则会导致连接泄漏 + // - 响应体可能已被包装以跟踪请求生命周期 + Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index aa844554..769d0c3c 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -311,7 +311,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } // Send request - resp, err := s.httpUpstream.Do(upstreamReq, proxyURL) + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) if err != nil { return nil, fmt.Errorf("upstream request failed: %w", err) } diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 8db4cbc9..5bd85d7d 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -29,11 +29,21 @@ gateway: response_header_timeout: 300 # 请求体最大字节数(默认 100MB) max_body_size: 104857600 + # 连接池隔离策略: + # - proxy: 按代理隔离,同一代理共享连接池(适合代理少、账户多) + # - account: 按账户隔离,同一账户共享连接池(适合账户少、需严格隔离) + # - account_proxy: 按账户+代理组合隔离(默认,最细粒度) + connection_pool_isolation: "account_proxy" # HTTP 上游连接池配置(HTTP/2 + 多代理场景默认) max_idle_conns: 240 max_idle_conns_per_host: 120 max_conns_per_host: 240 idle_conn_timeout_seconds: 300 + # 上游连接池客户端缓存配置 + # max_upstream_clients: 最大缓存客户端数量,超出后淘汰最久未使用的 + # client_idle_ttl_seconds: 客户端空闲回收阈值(秒),超时且无活跃请求时回收 + max_upstream_clients: 5000 + client_idle_ttl_seconds: 900 # 并发槽位过期时间(分钟) concurrency_slot_ttl_minutes: 15 From 820bb16ca751dca6c399423963ebc3365f74b3a6 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Wed, 31 Dec 2025 12:01:31 +0800 Subject: [PATCH 05/49] =?UTF-8?q?fix(=E7=BD=91=E5=85=B3):=20=E9=98=B2?= =?UTF-8?q?=E6=AD=A2=E8=BF=9E=E6=8E=A5=E6=B1=A0=E7=BC=93=E5=AD=98=E5=A4=B1?= =?UTF-8?q?=E6=8E=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 超限且无可淘汰条目时拒绝新建 规范化代理地址并更新失败时的访问时间 补充连接池上限与代理规范化测试 --- backend/internal/repository/http_upstream.go | 117 +++++++++++++----- .../internal/repository/http_upstream_test.go | 25 ++++ 2 files changed, 109 insertions(+), 33 deletions(-) diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index e7ae46dc..061866b1 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -1,8 +1,10 @@ package repository import ( + "errors" "fmt" "io" + "net" "net/http" "net/url" "strings" @@ -40,6 +42,8 @@ const ( defaultClientIdleTTLSeconds = 900 ) +var errUpstreamClientLimitReached = errors.New("upstream client cache limit reached") + // poolSettings 连接池配置参数 // 封装 Transport 所需的各项连接池参数 type poolSettings struct { @@ -116,13 +120,17 @@ func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream { // - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断 func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { // 获取或创建对应的客户端,并标记请求占用 - entry := s.acquireClient(proxyURL, accountID, accountConcurrency) + entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency) + if err != nil { + return nil, err + } // 执行请求 resp, err := entry.client.Do(req) if err != nil { // 请求失败,立即减少计数 atomic.AddInt64(&entry.inFlight, -1) + atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano()) return nil, err } @@ -138,8 +146,8 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i // acquireClient 获取或创建客户端,并标记为进行中请求 // 用于请求路径,避免在获取后被淘汰 -func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry { - return s.getClientEntry(proxyURL, accountID, accountConcurrency, true) +func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) { + return s.getClientEntry(proxyURL, accountID, accountConcurrency, true, true) } // getOrCreateClient 获取或创建客户端 @@ -158,12 +166,14 @@ func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, ac // - account: 按账户隔离,同一账户共享客户端(代理变更时重建) // - account_proxy: 按账户+代理组合隔离,最细粒度 func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry { - return s.getClientEntry(proxyURL, accountID, accountConcurrency, false) + entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false) + return entry } // getClientEntry 获取或创建客户端条目 // markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰 -func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool) *upstreamClientEntry { +// enforceLimit=true 时会限制客户端数量,超限且无法淘汰时返回错误 +func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { // 获取隔离模式 isolation := s.getIsolationMode() // 标准化代理 URL 并解析 @@ -184,7 +194,7 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a atomic.AddInt64(&entry.inFlight, 1) } s.mu.RUnlock() - return entry + return entry, nil } s.mu.RUnlock() @@ -197,11 +207,22 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a atomic.AddInt64(&entry.inFlight, 1) } s.mu.Unlock() - return entry + return entry, nil } s.removeClientLocked(cacheKey, entry) } + // 超出缓存上限时尝试淘汰,无法淘汰则拒绝新建 + if enforceLimit && s.maxUpstreamClients() > 0 { + s.evictIdleLocked(now) + if len(s.clients) >= s.maxUpstreamClients() { + if !s.evictOldestIdleLocked() { + s.mu.Unlock() + return nil, errUpstreamClientLimitReached + } + } + } + // 缓存未命中或需要重建,创建新客户端 settings := s.resolvePoolSettings(isolation, accountConcurrency) client := &http.Client{Transport: buildUpstreamTransport(settings, parsedProxy)} @@ -220,7 +241,7 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a s.evictIdleLocked(now) s.evictOverLimitLocked() s.mu.Unlock() - return entry + return entry, nil } // shouldReuseEntry 判断缓存条目是否可复用 @@ -277,39 +298,50 @@ func (s *httpUpstreamService) evictIdleLocked(now time.Time) { } } +// evictOldestIdleLocked 淘汰最久未使用且无活跃请求的客户端(需持有锁) +func (s *httpUpstreamService) evictOldestIdleLocked() bool { + var ( + oldestKey string + oldestEntry *upstreamClientEntry + oldestTime int64 + ) + // 查找最久未使用且无活跃请求的客户端 + for key, entry := range s.clients { + // 跳过有活跃请求的客户端 + if atomic.LoadInt64(&entry.inFlight) != 0 { + continue + } + lastUsed := atomic.LoadInt64(&entry.lastUsed) + if oldestEntry == nil || lastUsed < oldestTime { + oldestKey = key + oldestEntry = entry + oldestTime = lastUsed + } + } + // 所有客户端都有活跃请求,无法淘汰 + if oldestEntry == nil { + return false + } + s.removeClientLocked(oldestKey, oldestEntry) + return true +} + // evictOverLimitLocked 淘汰超出数量限制的客户端(需持有锁) // 使用 LRU 策略,优先淘汰最久未使用且无活跃请求的客户端 -func (s *httpUpstreamService) evictOverLimitLocked() { +func (s *httpUpstreamService) evictOverLimitLocked() bool { maxClients := s.maxUpstreamClients() if maxClients <= 0 { - return + return false } + evicted := false // 循环淘汰直到满足数量限制 for len(s.clients) > maxClients { - var ( - oldestKey string - oldestEntry *upstreamClientEntry - oldestTime int64 - ) - // 查找最久未使用且无活跃请求的客户端 - for key, entry := range s.clients { - // 跳过有活跃请求的客户端 - if atomic.LoadInt64(&entry.inFlight) != 0 { - continue - } - lastUsed := atomic.LoadInt64(&entry.lastUsed) - if oldestEntry == nil || lastUsed < oldestTime { - oldestKey = key - oldestEntry = entry - oldestTime = lastUsed - } + if !s.evictOldestIdleLocked() { + return evicted } - // 所有客户端都有活跃请求,无法淘汰 - if oldestEntry == nil { - return - } - s.removeClientLocked(oldestKey, oldestEntry) + evicted = true } + return evicted } // getIsolationMode 获取连接池隔离模式 @@ -443,7 +475,26 @@ func normalizeProxyURL(raw string) (string, *url.URL) { if err != nil { return directProxyKey, nil } - return proxyURL, parsed + parsed.Scheme = strings.ToLower(parsed.Scheme) + parsed.Host = strings.ToLower(parsed.Host) + parsed.Path = "" + parsed.RawPath = "" + parsed.RawQuery = "" + parsed.Fragment = "" + parsed.ForceQuery = false + if hostname := parsed.Hostname(); hostname != "" { + port := parsed.Port() + if (parsed.Scheme == "http" && port == "80") || (parsed.Scheme == "https" && port == "443") { + port = "" + } + hostname = strings.ToLower(hostname) + if port != "" { + parsed.Host = net.JoinHostPort(hostname, port) + } else { + parsed.Host = hostname + } + } + return parsed.String(), parsed } // defaultPoolSettings 获取默认连接池配置 diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go index 763b254f..70676b7a 100644 --- a/backend/internal/repository/http_upstream_test.go +++ b/backend/internal/repository/http_upstream_test.go @@ -64,6 +64,31 @@ func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDirect() require.Equal(s.T(), directProxyKey, entry.proxyKey, "expected direct proxy fallback") } +// TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化 +// 验证等价地址能够映射到同一缓存键 +func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() { + key1, _ := normalizeProxyURL("http://proxy.local:8080") + key2, _ := normalizeProxyURL("http://proxy.local:8080/") + require.Equal(s.T(), key1, key2, "expected normalized proxy keys to match") +} + +// TestAcquireClient_OverLimitReturnsError 测试连接池缓存上限保护 +// 验证超限且无可淘汰条目时返回错误 +func (s *HTTPUpstreamSuite) TestAcquireClient_OverLimitReturnsError() { + s.cfg.Gateway = config.GatewayConfig{ + ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy, + MaxUpstreamClients: 1, + } + svc := s.newService() + entry1, err := svc.acquireClient("http://proxy-a:8080", 1, 1) + require.NoError(s.T(), err, "expected first acquire to succeed") + require.NotNil(s.T(), entry1, "expected entry") + + entry2, err := svc.acquireClient("http://proxy-b:8080", 2, 1) + require.Error(s.T(), err, "expected error when cache limit reached") + require.Nil(s.T(), entry2, "expected nil entry when cache limit reached") +} + // TestDo_WithoutProxy_GoesDirect 测试无代理时直连 // 验证空代理 URL 时请求直接发送到目标服务器 func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() { From 5906f9ab9843bc3e115a1b7e272de5ccd2719b41 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Wed, 31 Dec 2025 14:11:57 +0800 Subject: [PATCH 06/49] =?UTF-8?q?fix(=E6=95=B0=E6=8D=AE=E5=B1=82):=20?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=95=B0=E6=8D=AE=E5=AE=8C=E6=95=B4=E6=80=A7?= =?UTF-8?q?=E4=B8=8E=E4=BB=93=E5=82=A8=E4=B8=80=E8=87=B4=E6=80=A7=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## 数据完整性修复 (fix-critical-data-integrity) - 添加 error_translate.go 统一错误转换层 - 修复 nil 输入和 NotFound 错误处理 - 增强仓储层错误一致性 ## 仓储一致性修复 (fix-high-repository-consistency) - Group schema 添加 default_validity_days 字段 - Account schema 添加 proxy edge 关联 - 新增 UsageLog ent schema 定义 - 修复 UpdateBalance/UpdateConcurrency 受影响行数校验 ## 数据卫生修复 (fix-medium-data-hygiene) - UserSubscription 添加软删除支持 (SoftDeleteMixin) - RedeemCode/Setting 添加硬删除策略文档 - account_groups/user_allowed_groups 的 created_at 声明 timestamptz - 停止写入 legacy users.allowed_groups 列 - 新增迁移: 011-014 (索引优化、软删除、孤立数据审计、列清理) ## 测试补充 - 添加 UserSubscription 软删除测试 - 添加迁移回归测试 - 添加 NotFound 错误测试 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .gitignore | 1 + backend/cmd/server/wire_gen.go | 2 +- backend/ent/account.go | 39 +- backend/ent/account/account.go | 53 + backend/ent/account/where.go | 66 +- backend/ent/account_create.go | 79 +- backend/ent/account_query.go | 154 +- backend/ent/account_update.go | 276 +- backend/ent/apikey.go | 18 +- backend/ent/apikey/apikey.go | 30 + backend/ent/apikey/where.go | 23 + backend/ent/apikey_create.go | 32 + backend/ent/apikey_query.go | 103 +- backend/ent/apikey_update.go | 163 + backend/ent/client.go | 353 +- backend/ent/ent.go | 2 + backend/ent/group.go | 39 +- backend/ent/group/group.go | 40 + backend/ent/group/where.go | 68 + backend/ent/group_create.go | 117 + backend/ent/group_query.go | 79 +- backend/ent/group_update.go | 217 ++ backend/ent/hook/hook.go | 12 + backend/ent/intercept/intercept.go | 30 + backend/ent/migrate/schema.go | 203 +- backend/ent/mutation.go | 3426 ++++++++++++++++- backend/ent/predicate/predicate.go | 3 + backend/ent/proxy.go | 28 +- backend/ent/proxy/proxy.go | 31 + backend/ent/proxy/where.go | 24 + backend/ent/proxy_create.go | 32 + backend/ent/proxy_query.go | 104 +- backend/ent/proxy_update.go | 163 + backend/ent/runtime/runtime.go | 111 + backend/ent/schema/account.go | 7 + backend/ent/schema/account_group.go | 4 +- backend/ent/schema/api_key.go | 3 +- backend/ent/schema/group.go | 5 +- backend/ent/schema/proxy.go | 10 + backend/ent/schema/redeem_code.go | 10 +- backend/ent/schema/setting.go | 14 +- backend/ent/schema/usage_log.go | 152 + backend/ent/schema/user.go | 3 +- backend/ent/schema/user_allowed_group.go | 4 +- backend/ent/schema/user_subscription.go | 3 + backend/ent/tx.go | 5 +- backend/ent/usagelog.go | 491 +++ backend/ent/usagelog/usagelog.go | 396 ++ backend/ent/usagelog/where.go | 1271 ++++++ backend/ent/usagelog_create.go | 2431 ++++++++++++ backend/ent/usagelog_delete.go | 88 + backend/ent/usagelog_query.go | 912 +++++ backend/ent/usagelog_update.go | 1800 +++++++++ backend/ent/user.go | 20 +- backend/ent/user/user.go | 30 + backend/ent/user/where.go | 23 + backend/ent/user_create.go | 32 + backend/ent/user_query.go | 76 +- backend/ent/user_update.go | 163 + backend/ent/usersubscription.go | 34 +- .../ent/usersubscription/usersubscription.go | 46 + backend/ent/usersubscription/where.go | 78 + backend/ent/usersubscription_create.go | 126 +- backend/ent/usersubscription_query.go | 80 +- backend/ent/usersubscription_update.go | 235 +- backend/internal/repository/account_repo.go | 96 +- backend/internal/repository/api_key_repo.go | 13 +- .../internal/repository/error_translate.go | 20 + backend/internal/repository/group_repo.go | 27 +- .../repository/group_repo_integration_test.go | 55 + .../migrations_schema_integration_test.go | 14 + backend/internal/repository/proxy_repo.go | 2 +- .../internal/repository/redeem_code_repo.go | 3 +- .../soft_delete_ent_integration_test.go | 103 + backend/internal/repository/usage_log_repo.go | 12 + backend/internal/repository/user_repo.go | 104 +- .../repository/user_repo_integration_test.go | 21 + .../repository/user_subscription_repo.go | 171 +- ...user_subscription_repo_integration_test.go | 247 ++ backend/internal/service/account_service.go | 1 + backend/internal/service/group.go | 9 +- backend/internal/service/redeem_service.go | 72 +- .../internal/service/subscription_service.go | 1 + .../011_remove_duplicate_unique_indexes.sql | 39 + .../012_add_user_subscription_soft_delete.sql | 13 + .../013_log_orphan_allowed_groups.sql | 32 + .../014_drop_legacy_allowed_groups.sql | 15 + 87 files changed, 15258 insertions(+), 485 deletions(-) create mode 100644 backend/ent/schema/usage_log.go create mode 100644 backend/ent/usagelog.go create mode 100644 backend/ent/usagelog/usagelog.go create mode 100644 backend/ent/usagelog/where.go create mode 100644 backend/ent/usagelog_create.go create mode 100644 backend/ent/usagelog_delete.go create mode 100644 backend/ent/usagelog_query.go create mode 100644 backend/ent/usagelog_update.go create mode 100644 backend/migrations/011_remove_duplicate_unique_indexes.sql create mode 100644 backend/migrations/012_add_user_subscription_soft_delete.sql create mode 100644 backend/migrations/013_log_orphan_allowed_groups.sql create mode 100644 backend/migrations/014_drop_legacy_allowed_groups.sql diff --git a/.gitignore b/.gitignore index 5a611909..390c8a03 100644 --- a/.gitignore +++ b/.gitignore @@ -116,3 +116,4 @@ openspec/ docs/ code-reviews/ AGENTS.md +backend/cmd/server/server diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 664e7aca..ebbaa172 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -70,7 +70,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) redeemCache := repository.NewRedeemCache(redisClient) - redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService) + redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client) redeemHandler := handler.NewRedeemHandler(redeemService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) dashboardService := service.NewDashboardService(usageLogRepository) diff --git a/backend/ent/account.go b/backend/ent/account.go index 59f55edb..82867111 100644 --- a/backend/ent/account.go +++ b/backend/ent/account.go @@ -11,6 +11,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/Wei-Shaw/sub2api/ent/account" + "github.com/Wei-Shaw/sub2api/ent/proxy" ) // Account is the model entity for the Account schema. @@ -70,11 +71,15 @@ type Account struct { type AccountEdges struct { // Groups holds the value of the groups edge. Groups []*Group `json:"groups,omitempty"` + // Proxy holds the value of the proxy edge. + Proxy *Proxy `json:"proxy,omitempty"` + // UsageLogs holds the value of the usage_logs edge. + UsageLogs []*UsageLog `json:"usage_logs,omitempty"` // AccountGroups holds the value of the account_groups edge. AccountGroups []*AccountGroup `json:"account_groups,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [2]bool + loadedTypes [4]bool } // GroupsOrErr returns the Groups value or an error if the edge @@ -86,10 +91,30 @@ func (e AccountEdges) GroupsOrErr() ([]*Group, error) { return nil, &NotLoadedError{edge: "groups"} } +// ProxyOrErr returns the Proxy value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e AccountEdges) ProxyOrErr() (*Proxy, error) { + if e.Proxy != nil { + return e.Proxy, nil + } else if e.loadedTypes[1] { + return nil, &NotFoundError{label: proxy.Label} + } + return nil, &NotLoadedError{edge: "proxy"} +} + +// UsageLogsOrErr returns the UsageLogs value or an error if the edge +// was not loaded in eager-loading. +func (e AccountEdges) UsageLogsOrErr() ([]*UsageLog, error) { + if e.loadedTypes[2] { + return e.UsageLogs, nil + } + return nil, &NotLoadedError{edge: "usage_logs"} +} + // AccountGroupsOrErr returns the AccountGroups value or an error if the edge // was not loaded in eager-loading. func (e AccountEdges) AccountGroupsOrErr() ([]*AccountGroup, error) { - if e.loadedTypes[1] { + if e.loadedTypes[3] { return e.AccountGroups, nil } return nil, &NotLoadedError{edge: "account_groups"} @@ -289,6 +314,16 @@ func (_m *Account) QueryGroups() *GroupQuery { return NewAccountClient(_m.config).QueryGroups(_m) } +// QueryProxy queries the "proxy" edge of the Account entity. +func (_m *Account) QueryProxy() *ProxyQuery { + return NewAccountClient(_m.config).QueryProxy(_m) +} + +// QueryUsageLogs queries the "usage_logs" edge of the Account entity. +func (_m *Account) QueryUsageLogs() *UsageLogQuery { + return NewAccountClient(_m.config).QueryUsageLogs(_m) +} + // QueryAccountGroups queries the "account_groups" edge of the Account entity. func (_m *Account) QueryAccountGroups() *AccountGroupQuery { return NewAccountClient(_m.config).QueryAccountGroups(_m) diff --git a/backend/ent/account/account.go b/backend/ent/account/account.go index 65a130fd..c48db1e3 100644 --- a/backend/ent/account/account.go +++ b/backend/ent/account/account.go @@ -59,6 +59,10 @@ const ( FieldSessionWindowStatus = "session_window_status" // EdgeGroups holds the string denoting the groups edge name in mutations. EdgeGroups = "groups" + // EdgeProxy holds the string denoting the proxy edge name in mutations. + EdgeProxy = "proxy" + // EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations. + EdgeUsageLogs = "usage_logs" // EdgeAccountGroups holds the string denoting the account_groups edge name in mutations. EdgeAccountGroups = "account_groups" // Table holds the table name of the account in the database. @@ -68,6 +72,20 @@ const ( // GroupsInverseTable is the table name for the Group entity. // It exists in this package in order to avoid circular dependency with the "group" package. GroupsInverseTable = "groups" + // ProxyTable is the table that holds the proxy relation/edge. + ProxyTable = "accounts" + // ProxyInverseTable is the table name for the Proxy entity. + // It exists in this package in order to avoid circular dependency with the "proxy" package. + ProxyInverseTable = "proxies" + // ProxyColumn is the table column denoting the proxy relation/edge. + ProxyColumn = "proxy_id" + // UsageLogsTable is the table that holds the usage_logs relation/edge. + UsageLogsTable = "usage_logs" + // UsageLogsInverseTable is the table name for the UsageLog entity. + // It exists in this package in order to avoid circular dependency with the "usagelog" package. + UsageLogsInverseTable = "usage_logs" + // UsageLogsColumn is the table column denoting the usage_logs relation/edge. + UsageLogsColumn = "account_id" // AccountGroupsTable is the table that holds the account_groups relation/edge. AccountGroupsTable = "account_groups" // AccountGroupsInverseTable is the table name for the AccountGroup entity. @@ -274,6 +292,27 @@ func ByGroups(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { } } +// ByProxyField orders the results by proxy field. +func ByProxyField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newProxyStep(), sql.OrderByField(field, opts...)) + } +} + +// ByUsageLogsCount orders the results by usage_logs count. +func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...) + } +} + +// ByUsageLogs orders the results by usage_logs terms. +func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + // ByAccountGroupsCount orders the results by account_groups count. func ByAccountGroupsCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -294,6 +333,20 @@ func newGroupsStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.M2M, false, GroupsTable, GroupsPrimaryKey...), ) } +func newProxyStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(ProxyInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, ProxyTable, ProxyColumn), + ) +} +func newUsageLogsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UsageLogsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) +} func newAccountGroupsStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), diff --git a/backend/ent/account/where.go b/backend/ent/account/where.go index f54f538f..b79b5f8b 100644 --- a/backend/ent/account/where.go +++ b/backend/ent/account/where.go @@ -495,26 +495,6 @@ func ProxyIDNotIn(vs ...int64) predicate.Account { return predicate.Account(sql.FieldNotIn(FieldProxyID, vs...)) } -// ProxyIDGT applies the GT predicate on the "proxy_id" field. -func ProxyIDGT(v int64) predicate.Account { - return predicate.Account(sql.FieldGT(FieldProxyID, v)) -} - -// ProxyIDGTE applies the GTE predicate on the "proxy_id" field. -func ProxyIDGTE(v int64) predicate.Account { - return predicate.Account(sql.FieldGTE(FieldProxyID, v)) -} - -// ProxyIDLT applies the LT predicate on the "proxy_id" field. -func ProxyIDLT(v int64) predicate.Account { - return predicate.Account(sql.FieldLT(FieldProxyID, v)) -} - -// ProxyIDLTE applies the LTE predicate on the "proxy_id" field. -func ProxyIDLTE(v int64) predicate.Account { - return predicate.Account(sql.FieldLTE(FieldProxyID, v)) -} - // ProxyIDIsNil applies the IsNil predicate on the "proxy_id" field. func ProxyIDIsNil() predicate.Account { return predicate.Account(sql.FieldIsNull(FieldProxyID)) @@ -1153,6 +1133,52 @@ func HasGroupsWith(preds ...predicate.Group) predicate.Account { }) } +// HasProxy applies the HasEdge predicate on the "proxy" edge. +func HasProxy() predicate.Account { + return predicate.Account(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, ProxyTable, ProxyColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasProxyWith applies the HasEdge predicate on the "proxy" edge with a given conditions (other predicates). +func HasProxyWith(preds ...predicate.Proxy) predicate.Account { + return predicate.Account(func(s *sql.Selector) { + step := newProxyStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge. +func HasUsageLogs() predicate.Account { + return predicate.Account(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates). +func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.Account { + return predicate.Account(func(s *sql.Selector) { + step := newUsageLogsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // HasAccountGroups applies the HasEdge predicate on the "account_groups" edge. func HasAccountGroups() predicate.Account { return predicate.Account(func(s *sql.Selector) { diff --git a/backend/ent/account_create.go b/backend/ent/account_create.go index 6d813817..2fb52a81 100644 --- a/backend/ent/account_create.go +++ b/backend/ent/account_create.go @@ -13,6 +13,8 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/account" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/proxy" + "github.com/Wei-Shaw/sub2api/ent/usagelog" ) // AccountCreate is the builder for creating a Account entity. @@ -292,6 +294,26 @@ func (_c *AccountCreate) AddGroups(v ...*Group) *AccountCreate { return _c.AddGroupIDs(ids...) } +// SetProxy sets the "proxy" edge to the Proxy entity. +func (_c *AccountCreate) SetProxy(v *Proxy) *AccountCreate { + return _c.SetProxyID(v.ID) +} + +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_c *AccountCreate) AddUsageLogIDs(ids ...int64) *AccountCreate { + _c.mutation.AddUsageLogIDs(ids...) + return _c +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_c *AccountCreate) AddUsageLogs(v ...*UsageLog) *AccountCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddUsageLogIDs(ids...) +} + // Mutation returns the AccountMutation object of the builder. func (_c *AccountCreate) Mutation() *AccountMutation { return _c.mutation @@ -495,10 +517,6 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) { _spec.SetField(account.FieldExtra, field.TypeJSON, value) _node.Extra = value } - if value, ok := _c.mutation.ProxyID(); ok { - _spec.SetField(account.FieldProxyID, field.TypeInt64, value) - _node.ProxyID = &value - } if value, ok := _c.mutation.Concurrency(); ok { _spec.SetField(account.FieldConcurrency, field.TypeInt, value) _node.Concurrency = value @@ -567,6 +585,39 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) { edge.Target.Fields = specE.Fields _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.ProxyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: account.ProxyTable, + Columns: []string{account.ProxyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(proxy.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.ProxyID = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: account.UsageLogsTable, + Columns: []string{account.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } @@ -721,12 +772,6 @@ func (u *AccountUpsert) UpdateProxyID() *AccountUpsert { return u } -// AddProxyID adds v to the "proxy_id" field. -func (u *AccountUpsert) AddProxyID(v int64) *AccountUpsert { - u.Add(account.FieldProxyID, v) - return u -} - // ClearProxyID clears the value of the "proxy_id" field. func (u *AccountUpsert) ClearProxyID() *AccountUpsert { u.SetNull(account.FieldProxyID) @@ -1094,13 +1139,6 @@ func (u *AccountUpsertOne) SetProxyID(v int64) *AccountUpsertOne { }) } -// AddProxyID adds v to the "proxy_id" field. -func (u *AccountUpsertOne) AddProxyID(v int64) *AccountUpsertOne { - return u.Update(func(s *AccountUpsert) { - s.AddProxyID(v) - }) -} - // UpdateProxyID sets the "proxy_id" field to the value that was provided on create. func (u *AccountUpsertOne) UpdateProxyID() *AccountUpsertOne { return u.Update(func(s *AccountUpsert) { @@ -1676,13 +1714,6 @@ func (u *AccountUpsertBulk) SetProxyID(v int64) *AccountUpsertBulk { }) } -// AddProxyID adds v to the "proxy_id" field. -func (u *AccountUpsertBulk) AddProxyID(v int64) *AccountUpsertBulk { - return u.Update(func(s *AccountUpsert) { - s.AddProxyID(v) - }) -} - // UpdateProxyID sets the "proxy_id" field to the value that was provided on create. func (u *AccountUpsertBulk) UpdateProxyID() *AccountUpsertBulk { return u.Update(func(s *AccountUpsert) { diff --git a/backend/ent/account_query.go b/backend/ent/account_query.go index e5712884..3e363ecd 100644 --- a/backend/ent/account_query.go +++ b/backend/ent/account_query.go @@ -16,6 +16,8 @@ import ( "github.com/Wei-Shaw/sub2api/ent/accountgroup" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/proxy" + "github.com/Wei-Shaw/sub2api/ent/usagelog" ) // AccountQuery is the builder for querying Account entities. @@ -26,6 +28,8 @@ type AccountQuery struct { inters []Interceptor predicates []predicate.Account withGroups *GroupQuery + withProxy *ProxyQuery + withUsageLogs *UsageLogQuery withAccountGroups *AccountGroupQuery // intermediate query (i.e. traversal path). sql *sql.Selector @@ -85,6 +89,50 @@ func (_q *AccountQuery) QueryGroups() *GroupQuery { return query } +// QueryProxy chains the current query on the "proxy" edge. +func (_q *AccountQuery) QueryProxy() *ProxyQuery { + query := (&ProxyClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(account.Table, account.FieldID, selector), + sqlgraph.To(proxy.Table, proxy.FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, account.ProxyTable, account.ProxyColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryUsageLogs chains the current query on the "usage_logs" edge. +func (_q *AccountQuery) QueryUsageLogs() *UsageLogQuery { + query := (&UsageLogClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(account.Table, account.FieldID, selector), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, account.UsageLogsTable, account.UsageLogsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // QueryAccountGroups chains the current query on the "account_groups" edge. func (_q *AccountQuery) QueryAccountGroups() *AccountGroupQuery { query := (&AccountGroupClient{config: _q.config}).Query() @@ -300,6 +348,8 @@ func (_q *AccountQuery) Clone() *AccountQuery { inters: append([]Interceptor{}, _q.inters...), predicates: append([]predicate.Account{}, _q.predicates...), withGroups: _q.withGroups.Clone(), + withProxy: _q.withProxy.Clone(), + withUsageLogs: _q.withUsageLogs.Clone(), withAccountGroups: _q.withAccountGroups.Clone(), // clone intermediate query. sql: _q.sql.Clone(), @@ -318,6 +368,28 @@ func (_q *AccountQuery) WithGroups(opts ...func(*GroupQuery)) *AccountQuery { return _q } +// WithProxy tells the query-builder to eager-load the nodes that are connected to +// the "proxy" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AccountQuery) WithProxy(opts ...func(*ProxyQuery)) *AccountQuery { + query := (&ProxyClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withProxy = query + return _q +} + +// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to +// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AccountQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *AccountQuery { + query := (&UsageLogClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUsageLogs = query + return _q +} + // WithAccountGroups tells the query-builder to eager-load the nodes that are connected to // the "account_groups" edge. The optional arguments are used to configure the query builder of the edge. func (_q *AccountQuery) WithAccountGroups(opts ...func(*AccountGroupQuery)) *AccountQuery { @@ -407,8 +479,10 @@ func (_q *AccountQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Acco var ( nodes = []*Account{} _spec = _q.querySpec() - loadedTypes = [2]bool{ + loadedTypes = [4]bool{ _q.withGroups != nil, + _q.withProxy != nil, + _q.withUsageLogs != nil, _q.withAccountGroups != nil, } ) @@ -437,6 +511,19 @@ func (_q *AccountQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Acco return nil, err } } + if query := _q.withProxy; query != nil { + if err := _q.loadProxy(ctx, query, nodes, nil, + func(n *Account, e *Proxy) { n.Edges.Proxy = e }); err != nil { + return nil, err + } + } + if query := _q.withUsageLogs; query != nil { + if err := _q.loadUsageLogs(ctx, query, nodes, + func(n *Account) { n.Edges.UsageLogs = []*UsageLog{} }, + func(n *Account, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil { + return nil, err + } + } if query := _q.withAccountGroups; query != nil { if err := _q.loadAccountGroups(ctx, query, nodes, func(n *Account) { n.Edges.AccountGroups = []*AccountGroup{} }, @@ -508,6 +595,68 @@ func (_q *AccountQuery) loadGroups(ctx context.Context, query *GroupQuery, nodes } return nil } +func (_q *AccountQuery) loadProxy(ctx context.Context, query *ProxyQuery, nodes []*Account, init func(*Account), assign func(*Account, *Proxy)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*Account) + for i := range nodes { + if nodes[i].ProxyID == nil { + continue + } + fk := *nodes[i].ProxyID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(proxy.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "proxy_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *AccountQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*Account, init func(*Account), assign func(*Account, *UsageLog)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*Account) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(usagelog.FieldAccountID) + } + query.Where(predicate.UsageLog(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(account.UsageLogsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.AccountID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "account_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} func (_q *AccountQuery) loadAccountGroups(ctx context.Context, query *AccountGroupQuery, nodes []*Account, init func(*Account), assign func(*Account, *AccountGroup)) error { fks := make([]driver.Value, 0, len(nodes)) nodeids := make(map[int64]*Account) @@ -564,6 +713,9 @@ func (_q *AccountQuery) querySpec() *sqlgraph.QuerySpec { _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) } } + if _q.withProxy != nil { + _spec.Node.AddColumnOnce(account.FieldProxyID) + } } if ps := _q.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { diff --git a/backend/ent/account_update.go b/backend/ent/account_update.go index 49eaaea8..cf8708c5 100644 --- a/backend/ent/account_update.go +++ b/backend/ent/account_update.go @@ -14,6 +14,8 @@ import ( "github.com/Wei-Shaw/sub2api/ent/account" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/proxy" + "github.com/Wei-Shaw/sub2api/ent/usagelog" ) // AccountUpdate is the builder for updating Account entities. @@ -111,7 +113,6 @@ func (_u *AccountUpdate) SetExtra(v map[string]interface{}) *AccountUpdate { // SetProxyID sets the "proxy_id" field. func (_u *AccountUpdate) SetProxyID(v int64) *AccountUpdate { - _u.mutation.ResetProxyID() _u.mutation.SetProxyID(v) return _u } @@ -124,12 +125,6 @@ func (_u *AccountUpdate) SetNillableProxyID(v *int64) *AccountUpdate { return _u } -// AddProxyID adds value to the "proxy_id" field. -func (_u *AccountUpdate) AddProxyID(v int64) *AccountUpdate { - _u.mutation.AddProxyID(v) - return _u -} - // ClearProxyID clears the value of the "proxy_id" field. func (_u *AccountUpdate) ClearProxyID() *AccountUpdate { _u.mutation.ClearProxyID() @@ -381,6 +376,26 @@ func (_u *AccountUpdate) AddGroups(v ...*Group) *AccountUpdate { return _u.AddGroupIDs(ids...) } +// SetProxy sets the "proxy" edge to the Proxy entity. +func (_u *AccountUpdate) SetProxy(v *Proxy) *AccountUpdate { + return _u.SetProxyID(v.ID) +} + +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *AccountUpdate) AddUsageLogIDs(ids ...int64) *AccountUpdate { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *AccountUpdate) AddUsageLogs(v ...*UsageLog) *AccountUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // Mutation returns the AccountMutation object of the builder. func (_u *AccountUpdate) Mutation() *AccountMutation { return _u.mutation @@ -407,6 +422,33 @@ func (_u *AccountUpdate) RemoveGroups(v ...*Group) *AccountUpdate { return _u.RemoveGroupIDs(ids...) } +// ClearProxy clears the "proxy" edge to the Proxy entity. +func (_u *AccountUpdate) ClearProxy() *AccountUpdate { + _u.mutation.ClearProxy() + return _u +} + +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *AccountUpdate) ClearUsageLogs() *AccountUpdate { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *AccountUpdate) RemoveUsageLogIDs(ids ...int64) *AccountUpdate { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *AccountUpdate) RemoveUsageLogs(v ...*UsageLog) *AccountUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *AccountUpdate) Save(ctx context.Context) (int, error) { if err := _u.defaults(); err != nil { @@ -515,15 +557,6 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.Extra(); ok { _spec.SetField(account.FieldExtra, field.TypeJSON, value) } - if value, ok := _u.mutation.ProxyID(); ok { - _spec.SetField(account.FieldProxyID, field.TypeInt64, value) - } - if value, ok := _u.mutation.AddedProxyID(); ok { - _spec.AddField(account.FieldProxyID, field.TypeInt64, value) - } - if _u.mutation.ProxyIDCleared() { - _spec.ClearField(account.FieldProxyID, field.TypeInt64) - } if value, ok := _u.mutation.Concurrency(); ok { _spec.SetField(account.FieldConcurrency, field.TypeInt, value) } @@ -647,6 +680,80 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) { edge.Target.Fields = specE.Fields _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.ProxyCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: account.ProxyTable, + Columns: []string{account.ProxyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(proxy.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.ProxyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: account.ProxyTable, + Columns: []string{account.ProxyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(proxy.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: account.UsageLogsTable, + Columns: []string{account.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: account.UsageLogsTable, + Columns: []string{account.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: account.UsageLogsTable, + Columns: []string{account.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{account.Label} @@ -749,7 +856,6 @@ func (_u *AccountUpdateOne) SetExtra(v map[string]interface{}) *AccountUpdateOne // SetProxyID sets the "proxy_id" field. func (_u *AccountUpdateOne) SetProxyID(v int64) *AccountUpdateOne { - _u.mutation.ResetProxyID() _u.mutation.SetProxyID(v) return _u } @@ -762,12 +868,6 @@ func (_u *AccountUpdateOne) SetNillableProxyID(v *int64) *AccountUpdateOne { return _u } -// AddProxyID adds value to the "proxy_id" field. -func (_u *AccountUpdateOne) AddProxyID(v int64) *AccountUpdateOne { - _u.mutation.AddProxyID(v) - return _u -} - // ClearProxyID clears the value of the "proxy_id" field. func (_u *AccountUpdateOne) ClearProxyID() *AccountUpdateOne { _u.mutation.ClearProxyID() @@ -1019,6 +1119,26 @@ func (_u *AccountUpdateOne) AddGroups(v ...*Group) *AccountUpdateOne { return _u.AddGroupIDs(ids...) } +// SetProxy sets the "proxy" edge to the Proxy entity. +func (_u *AccountUpdateOne) SetProxy(v *Proxy) *AccountUpdateOne { + return _u.SetProxyID(v.ID) +} + +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *AccountUpdateOne) AddUsageLogIDs(ids ...int64) *AccountUpdateOne { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *AccountUpdateOne) AddUsageLogs(v ...*UsageLog) *AccountUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // Mutation returns the AccountMutation object of the builder. func (_u *AccountUpdateOne) Mutation() *AccountMutation { return _u.mutation @@ -1045,6 +1165,33 @@ func (_u *AccountUpdateOne) RemoveGroups(v ...*Group) *AccountUpdateOne { return _u.RemoveGroupIDs(ids...) } +// ClearProxy clears the "proxy" edge to the Proxy entity. +func (_u *AccountUpdateOne) ClearProxy() *AccountUpdateOne { + _u.mutation.ClearProxy() + return _u +} + +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *AccountUpdateOne) ClearUsageLogs() *AccountUpdateOne { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *AccountUpdateOne) RemoveUsageLogIDs(ids ...int64) *AccountUpdateOne { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *AccountUpdateOne) RemoveUsageLogs(v ...*UsageLog) *AccountUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // Where appends a list predicates to the AccountUpdate builder. func (_u *AccountUpdateOne) Where(ps ...predicate.Account) *AccountUpdateOne { _u.mutation.Where(ps...) @@ -1183,15 +1330,6 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er if value, ok := _u.mutation.Extra(); ok { _spec.SetField(account.FieldExtra, field.TypeJSON, value) } - if value, ok := _u.mutation.ProxyID(); ok { - _spec.SetField(account.FieldProxyID, field.TypeInt64, value) - } - if value, ok := _u.mutation.AddedProxyID(); ok { - _spec.AddField(account.FieldProxyID, field.TypeInt64, value) - } - if _u.mutation.ProxyIDCleared() { - _spec.ClearField(account.FieldProxyID, field.TypeInt64) - } if value, ok := _u.mutation.Concurrency(); ok { _spec.SetField(account.FieldConcurrency, field.TypeInt, value) } @@ -1315,6 +1453,80 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er edge.Target.Fields = specE.Fields _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.ProxyCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: account.ProxyTable, + Columns: []string{account.ProxyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(proxy.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.ProxyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: false, + Table: account.ProxyTable, + Columns: []string{account.ProxyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(proxy.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: account.UsageLogsTable, + Columns: []string{account.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: account.UsageLogsTable, + Columns: []string{account.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: account.UsageLogsTable, + Columns: []string{account.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &Account{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/backend/ent/apikey.go b/backend/ent/apikey.go index 30cf9b4d..61ac15fa 100644 --- a/backend/ent/apikey.go +++ b/backend/ent/apikey.go @@ -47,9 +47,11 @@ type ApiKeyEdges struct { User *User `json:"user,omitempty"` // Group holds the value of the group edge. Group *Group `json:"group,omitempty"` + // UsageLogs holds the value of the usage_logs edge. + UsageLogs []*UsageLog `json:"usage_logs,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [2]bool + loadedTypes [3]bool } // UserOrErr returns the User value or an error if the edge @@ -74,6 +76,15 @@ func (e ApiKeyEdges) GroupOrErr() (*Group, error) { return nil, &NotLoadedError{edge: "group"} } +// UsageLogsOrErr returns the UsageLogs value or an error if the edge +// was not loaded in eager-loading. +func (e ApiKeyEdges) UsageLogsOrErr() ([]*UsageLog, error) { + if e.loadedTypes[2] { + return e.UsageLogs, nil + } + return nil, &NotLoadedError{edge: "usage_logs"} +} + // scanValues returns the types for scanning values from sql.Rows. func (*ApiKey) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) @@ -179,6 +190,11 @@ func (_m *ApiKey) QueryGroup() *GroupQuery { return NewApiKeyClient(_m.config).QueryGroup(_m) } +// QueryUsageLogs queries the "usage_logs" edge of the ApiKey entity. +func (_m *ApiKey) QueryUsageLogs() *UsageLogQuery { + return NewApiKeyClient(_m.config).QueryUsageLogs(_m) +} + // Update returns a builder for updating this ApiKey. // Note that you need to call ApiKey.Unwrap() before calling this method if this ApiKey // was returned from a transaction, and the transaction was committed or rolled back. diff --git a/backend/ent/apikey/apikey.go b/backend/ent/apikey/apikey.go index 4eba5f53..f03b2daa 100644 --- a/backend/ent/apikey/apikey.go +++ b/backend/ent/apikey/apikey.go @@ -35,6 +35,8 @@ const ( EdgeUser = "user" // EdgeGroup holds the string denoting the group edge name in mutations. EdgeGroup = "group" + // EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations. + EdgeUsageLogs = "usage_logs" // Table holds the table name of the apikey in the database. Table = "api_keys" // UserTable is the table that holds the user relation/edge. @@ -51,6 +53,13 @@ const ( GroupInverseTable = "groups" // GroupColumn is the table column denoting the group relation/edge. GroupColumn = "group_id" + // UsageLogsTable is the table that holds the usage_logs relation/edge. + UsageLogsTable = "usage_logs" + // UsageLogsInverseTable is the table name for the UsageLog entity. + // It exists in this package in order to avoid circular dependency with the "usagelog" package. + UsageLogsInverseTable = "usage_logs" + // UsageLogsColumn is the table column denoting the usage_logs relation/edge. + UsageLogsColumn = "api_key_id" ) // Columns holds all SQL columns for apikey fields. @@ -161,6 +170,20 @@ func ByGroupField(field string, opts ...sql.OrderTermOption) OrderOption { sqlgraph.OrderByNeighborTerms(s, newGroupStep(), sql.OrderByField(field, opts...)) } } + +// ByUsageLogsCount orders the results by usage_logs count. +func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...) + } +} + +// ByUsageLogs orders the results by usage_logs terms. +func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} func newUserStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), @@ -175,3 +198,10 @@ func newGroupStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.M2O, true, GroupTable, GroupColumn), ) } +func newUsageLogsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UsageLogsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) +} diff --git a/backend/ent/apikey/where.go b/backend/ent/apikey/where.go index 11cabd3f..95bc4e2a 100644 --- a/backend/ent/apikey/where.go +++ b/backend/ent/apikey/where.go @@ -516,6 +516,29 @@ func HasGroupWith(preds ...predicate.Group) predicate.ApiKey { }) } +// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge. +func HasUsageLogs() predicate.ApiKey { + return predicate.ApiKey(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates). +func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.ApiKey { + return predicate.ApiKey(func(s *sql.Selector) { + step := newUsageLogsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.ApiKey) predicate.ApiKey { return predicate.ApiKey(sql.AndPredicates(predicates...)) diff --git a/backend/ent/apikey_create.go b/backend/ent/apikey_create.go index 8d7ddb69..5b984b21 100644 --- a/backend/ent/apikey_create.go +++ b/backend/ent/apikey_create.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" ) @@ -122,6 +123,21 @@ func (_c *ApiKeyCreate) SetGroup(v *Group) *ApiKeyCreate { return _c.SetGroupID(v.ID) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_c *ApiKeyCreate) AddUsageLogIDs(ids ...int64) *ApiKeyCreate { + _c.mutation.AddUsageLogIDs(ids...) + return _c +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_c *ApiKeyCreate) AddUsageLogs(v ...*UsageLog) *ApiKeyCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddUsageLogIDs(ids...) +} + // Mutation returns the ApiKeyMutation object of the builder. func (_c *ApiKeyCreate) Mutation() *ApiKeyMutation { return _c.mutation @@ -303,6 +319,22 @@ func (_c *ApiKeyCreate) createSpec() (*ApiKey, *sqlgraph.CreateSpec) { _node.GroupID = &nodes[0] _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: apikey.UsageLogsTable, + Columns: []string{apikey.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } diff --git a/backend/ent/apikey_query.go b/backend/ent/apikey_query.go index 86051a60..d4029feb 100644 --- a/backend/ent/apikey_query.go +++ b/backend/ent/apikey_query.go @@ -4,6 +4,7 @@ package ent import ( "context" + "database/sql/driver" "fmt" "math" @@ -14,18 +15,20 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" ) // ApiKeyQuery is the builder for querying ApiKey entities. type ApiKeyQuery struct { config - ctx *QueryContext - order []apikey.OrderOption - inters []Interceptor - predicates []predicate.ApiKey - withUser *UserQuery - withGroup *GroupQuery + ctx *QueryContext + order []apikey.OrderOption + inters []Interceptor + predicates []predicate.ApiKey + withUser *UserQuery + withGroup *GroupQuery + withUsageLogs *UsageLogQuery // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -106,6 +109,28 @@ func (_q *ApiKeyQuery) QueryGroup() *GroupQuery { return query } +// QueryUsageLogs chains the current query on the "usage_logs" edge. +func (_q *ApiKeyQuery) QueryUsageLogs() *UsageLogQuery { + query := (&UsageLogClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(apikey.Table, apikey.FieldID, selector), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, apikey.UsageLogsTable, apikey.UsageLogsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // First returns the first ApiKey entity from the query. // Returns a *NotFoundError when no ApiKey was found. func (_q *ApiKeyQuery) First(ctx context.Context) (*ApiKey, error) { @@ -293,13 +318,14 @@ func (_q *ApiKeyQuery) Clone() *ApiKeyQuery { return nil } return &ApiKeyQuery{ - config: _q.config, - ctx: _q.ctx.Clone(), - order: append([]apikey.OrderOption{}, _q.order...), - inters: append([]Interceptor{}, _q.inters...), - predicates: append([]predicate.ApiKey{}, _q.predicates...), - withUser: _q.withUser.Clone(), - withGroup: _q.withGroup.Clone(), + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]apikey.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.ApiKey{}, _q.predicates...), + withUser: _q.withUser.Clone(), + withGroup: _q.withGroup.Clone(), + withUsageLogs: _q.withUsageLogs.Clone(), // clone intermediate query. sql: _q.sql.Clone(), path: _q.path, @@ -328,6 +354,17 @@ func (_q *ApiKeyQuery) WithGroup(opts ...func(*GroupQuery)) *ApiKeyQuery { return _q } +// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to +// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *ApiKeyQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *ApiKeyQuery { + query := (&UsageLogClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUsageLogs = query + return _q +} + // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. // @@ -406,9 +443,10 @@ func (_q *ApiKeyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ApiKe var ( nodes = []*ApiKey{} _spec = _q.querySpec() - loadedTypes = [2]bool{ + loadedTypes = [3]bool{ _q.withUser != nil, _q.withGroup != nil, + _q.withUsageLogs != nil, } ) _spec.ScanValues = func(columns []string) ([]any, error) { @@ -441,6 +479,13 @@ func (_q *ApiKeyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ApiKe return nil, err } } + if query := _q.withUsageLogs; query != nil { + if err := _q.loadUsageLogs(ctx, query, nodes, + func(n *ApiKey) { n.Edges.UsageLogs = []*UsageLog{} }, + func(n *ApiKey, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil { + return nil, err + } + } return nodes, nil } @@ -505,6 +550,36 @@ func (_q *ApiKeyQuery) loadGroup(ctx context.Context, query *GroupQuery, nodes [ } return nil } +func (_q *ApiKeyQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*ApiKey, init func(*ApiKey), assign func(*ApiKey, *UsageLog)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*ApiKey) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(usagelog.FieldAPIKeyID) + } + query.Where(predicate.UsageLog(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(apikey.UsageLogsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.APIKeyID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "api_key_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} func (_q *ApiKeyQuery) sqlCount(ctx context.Context) (int, error) { _spec := _q.querySpec() diff --git a/backend/ent/apikey_update.go b/backend/ent/apikey_update.go index 3917d068..3259bfd9 100644 --- a/backend/ent/apikey_update.go +++ b/backend/ent/apikey_update.go @@ -14,6 +14,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" ) @@ -142,6 +143,21 @@ func (_u *ApiKeyUpdate) SetGroup(v *Group) *ApiKeyUpdate { return _u.SetGroupID(v.ID) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *ApiKeyUpdate) AddUsageLogIDs(ids ...int64) *ApiKeyUpdate { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *ApiKeyUpdate) AddUsageLogs(v ...*UsageLog) *ApiKeyUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // Mutation returns the ApiKeyMutation object of the builder. func (_u *ApiKeyUpdate) Mutation() *ApiKeyMutation { return _u.mutation @@ -159,6 +175,27 @@ func (_u *ApiKeyUpdate) ClearGroup() *ApiKeyUpdate { return _u } +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *ApiKeyUpdate) ClearUsageLogs() *ApiKeyUpdate { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *ApiKeyUpdate) RemoveUsageLogIDs(ids ...int64) *ApiKeyUpdate { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *ApiKeyUpdate) RemoveUsageLogs(v ...*UsageLog) *ApiKeyUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *ApiKeyUpdate) Save(ctx context.Context) (int, error) { if err := _u.defaults(); err != nil { @@ -312,6 +349,51 @@ func (_u *ApiKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: apikey.UsageLogsTable, + Columns: []string{apikey.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: apikey.UsageLogsTable, + Columns: []string{apikey.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: apikey.UsageLogsTable, + Columns: []string{apikey.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{apikey.Label} @@ -444,6 +526,21 @@ func (_u *ApiKeyUpdateOne) SetGroup(v *Group) *ApiKeyUpdateOne { return _u.SetGroupID(v.ID) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *ApiKeyUpdateOne) AddUsageLogIDs(ids ...int64) *ApiKeyUpdateOne { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *ApiKeyUpdateOne) AddUsageLogs(v ...*UsageLog) *ApiKeyUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // Mutation returns the ApiKeyMutation object of the builder. func (_u *ApiKeyUpdateOne) Mutation() *ApiKeyMutation { return _u.mutation @@ -461,6 +558,27 @@ func (_u *ApiKeyUpdateOne) ClearGroup() *ApiKeyUpdateOne { return _u } +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *ApiKeyUpdateOne) ClearUsageLogs() *ApiKeyUpdateOne { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *ApiKeyUpdateOne) RemoveUsageLogIDs(ids ...int64) *ApiKeyUpdateOne { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *ApiKeyUpdateOne) RemoveUsageLogs(v ...*UsageLog) *ApiKeyUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // Where appends a list predicates to the ApiKeyUpdate builder. func (_u *ApiKeyUpdateOne) Where(ps ...predicate.ApiKey) *ApiKeyUpdateOne { _u.mutation.Where(ps...) @@ -644,6 +762,51 @@ func (_u *ApiKeyUpdateOne) sqlSave(ctx context.Context) (_node *ApiKey, err erro } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: apikey.UsageLogsTable, + Columns: []string{apikey.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: apikey.UsageLogsTable, + Columns: []string{apikey.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: apikey.UsageLogsTable, + Columns: []string{apikey.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &ApiKey{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/backend/ent/client.go b/backend/ent/client.go index 113dc7ff..909226fa 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -22,6 +22,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -48,6 +49,8 @@ type Client struct { RedeemCode *RedeemCodeClient // Setting is the client for interacting with the Setting builders. Setting *SettingClient + // UsageLog is the client for interacting with the UsageLog builders. + UsageLog *UsageLogClient // User is the client for interacting with the User builders. User *UserClient // UserAllowedGroup is the client for interacting with the UserAllowedGroup builders. @@ -72,6 +75,7 @@ func (c *Client) init() { c.Proxy = NewProxyClient(c.config) c.RedeemCode = NewRedeemCodeClient(c.config) c.Setting = NewSettingClient(c.config) + c.UsageLog = NewUsageLogClient(c.config) c.User = NewUserClient(c.config) c.UserAllowedGroup = NewUserAllowedGroupClient(c.config) c.UserSubscription = NewUserSubscriptionClient(c.config) @@ -174,6 +178,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { Proxy: NewProxyClient(cfg), RedeemCode: NewRedeemCodeClient(cfg), Setting: NewSettingClient(cfg), + UsageLog: NewUsageLogClient(cfg), User: NewUserClient(cfg), UserAllowedGroup: NewUserAllowedGroupClient(cfg), UserSubscription: NewUserSubscriptionClient(cfg), @@ -203,6 +208,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) Proxy: NewProxyClient(cfg), RedeemCode: NewRedeemCodeClient(cfg), Setting: NewSettingClient(cfg), + UsageLog: NewUsageLogClient(cfg), User: NewUserClient(cfg), UserAllowedGroup: NewUserAllowedGroupClient(cfg), UserSubscription: NewUserSubscriptionClient(cfg), @@ -236,7 +242,7 @@ func (c *Client) Close() error { func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.Account, c.AccountGroup, c.ApiKey, c.Group, c.Proxy, c.RedeemCode, c.Setting, - c.User, c.UserAllowedGroup, c.UserSubscription, + c.UsageLog, c.User, c.UserAllowedGroup, c.UserSubscription, } { n.Use(hooks...) } @@ -247,7 +253,7 @@ func (c *Client) Use(hooks ...Hook) { func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.Account, c.AccountGroup, c.ApiKey, c.Group, c.Proxy, c.RedeemCode, c.Setting, - c.User, c.UserAllowedGroup, c.UserSubscription, + c.UsageLog, c.User, c.UserAllowedGroup, c.UserSubscription, } { n.Intercept(interceptors...) } @@ -270,6 +276,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.RedeemCode.mutate(ctx, m) case *SettingMutation: return c.Setting.mutate(ctx, m) + case *UsageLogMutation: + return c.UsageLog.mutate(ctx, m) case *UserMutation: return c.User.mutate(ctx, m) case *UserAllowedGroupMutation: @@ -405,6 +413,38 @@ func (c *AccountClient) QueryGroups(_m *Account) *GroupQuery { return query } +// QueryProxy queries the proxy edge of a Account. +func (c *AccountClient) QueryProxy(_m *Account) *ProxyQuery { + query := (&ProxyClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(account.Table, account.FieldID, id), + sqlgraph.To(proxy.Table, proxy.FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, account.ProxyTable, account.ProxyColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryUsageLogs queries the usage_logs edge of a Account. +func (c *AccountClient) QueryUsageLogs(_m *Account) *UsageLogQuery { + query := (&UsageLogClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(account.Table, account.FieldID, id), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, account.UsageLogsTable, account.UsageLogsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // QueryAccountGroups queries the account_groups edge of a Account. func (c *AccountClient) QueryAccountGroups(_m *Account) *AccountGroupQuery { query := (&AccountGroupClient{config: c.config}).Query() @@ -704,6 +744,22 @@ func (c *ApiKeyClient) QueryGroup(_m *ApiKey) *GroupQuery { return query } +// QueryUsageLogs queries the usage_logs edge of a ApiKey. +func (c *ApiKeyClient) QueryUsageLogs(_m *ApiKey) *UsageLogQuery { + query := (&UsageLogClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(apikey.Table, apikey.FieldID, id), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, apikey.UsageLogsTable, apikey.UsageLogsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // Hooks returns the client hooks. func (c *ApiKeyClient) Hooks() []Hook { hooks := c.hooks.ApiKey @@ -887,6 +943,22 @@ func (c *GroupClient) QuerySubscriptions(_m *Group) *UserSubscriptionQuery { return query } +// QueryUsageLogs queries the usage_logs edge of a Group. +func (c *GroupClient) QueryUsageLogs(_m *Group) *UsageLogQuery { + query := (&UsageLogClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(group.Table, group.FieldID, id), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, group.UsageLogsTable, group.UsageLogsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // QueryAccounts queries the accounts edge of a Group. func (c *GroupClient) QueryAccounts(_m *Group) *AccountQuery { query := (&AccountClient{config: c.config}).Query() @@ -1086,6 +1158,22 @@ func (c *ProxyClient) GetX(ctx context.Context, id int64) *Proxy { return obj } +// QueryAccounts queries the accounts edge of a Proxy. +func (c *ProxyClient) QueryAccounts(_m *Proxy) *AccountQuery { + query := (&AccountClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(proxy.Table, proxy.FieldID, id), + sqlgraph.To(account.Table, account.FieldID), + sqlgraph.Edge(sqlgraph.O2M, true, proxy.AccountsTable, proxy.AccountsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // Hooks returns the client hooks. func (c *ProxyClient) Hooks() []Hook { hooks := c.hooks.Proxy @@ -1411,6 +1499,219 @@ func (c *SettingClient) mutate(ctx context.Context, m *SettingMutation) (Value, } } +// UsageLogClient is a client for the UsageLog schema. +type UsageLogClient struct { + config +} + +// NewUsageLogClient returns a client for the UsageLog from the given config. +func NewUsageLogClient(c config) *UsageLogClient { + return &UsageLogClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `usagelog.Hooks(f(g(h())))`. +func (c *UsageLogClient) Use(hooks ...Hook) { + c.hooks.UsageLog = append(c.hooks.UsageLog, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `usagelog.Intercept(f(g(h())))`. +func (c *UsageLogClient) Intercept(interceptors ...Interceptor) { + c.inters.UsageLog = append(c.inters.UsageLog, interceptors...) +} + +// Create returns a builder for creating a UsageLog entity. +func (c *UsageLogClient) Create() *UsageLogCreate { + mutation := newUsageLogMutation(c.config, OpCreate) + return &UsageLogCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of UsageLog entities. +func (c *UsageLogClient) CreateBulk(builders ...*UsageLogCreate) *UsageLogCreateBulk { + return &UsageLogCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *UsageLogClient) MapCreateBulk(slice any, setFunc func(*UsageLogCreate, int)) *UsageLogCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &UsageLogCreateBulk{err: fmt.Errorf("calling to UsageLogClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*UsageLogCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &UsageLogCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for UsageLog. +func (c *UsageLogClient) Update() *UsageLogUpdate { + mutation := newUsageLogMutation(c.config, OpUpdate) + return &UsageLogUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *UsageLogClient) UpdateOne(_m *UsageLog) *UsageLogUpdateOne { + mutation := newUsageLogMutation(c.config, OpUpdateOne, withUsageLog(_m)) + return &UsageLogUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *UsageLogClient) UpdateOneID(id int64) *UsageLogUpdateOne { + mutation := newUsageLogMutation(c.config, OpUpdateOne, withUsageLogID(id)) + return &UsageLogUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for UsageLog. +func (c *UsageLogClient) Delete() *UsageLogDelete { + mutation := newUsageLogMutation(c.config, OpDelete) + return &UsageLogDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *UsageLogClient) DeleteOne(_m *UsageLog) *UsageLogDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *UsageLogClient) DeleteOneID(id int64) *UsageLogDeleteOne { + builder := c.Delete().Where(usagelog.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &UsageLogDeleteOne{builder} +} + +// Query returns a query builder for UsageLog. +func (c *UsageLogClient) Query() *UsageLogQuery { + return &UsageLogQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeUsageLog}, + inters: c.Interceptors(), + } +} + +// Get returns a UsageLog entity by its id. +func (c *UsageLogClient) Get(ctx context.Context, id int64) (*UsageLog, error) { + return c.Query().Where(usagelog.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *UsageLogClient) GetX(ctx context.Context, id int64) *UsageLog { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryUser queries the user edge of a UsageLog. +func (c *UsageLogClient) QueryUser(_m *UsageLog) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.UserTable, usagelog.UserColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryAPIKey queries the api_key edge of a UsageLog. +func (c *UsageLogClient) QueryAPIKey(_m *UsageLog) *ApiKeyQuery { + query := (&ApiKeyClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, id), + sqlgraph.To(apikey.Table, apikey.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.APIKeyTable, usagelog.APIKeyColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryAccount queries the account edge of a UsageLog. +func (c *UsageLogClient) QueryAccount(_m *UsageLog) *AccountQuery { + query := (&AccountClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, id), + sqlgraph.To(account.Table, account.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.AccountTable, usagelog.AccountColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryGroup queries the group edge of a UsageLog. +func (c *UsageLogClient) QueryGroup(_m *UsageLog) *GroupQuery { + query := (&GroupClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, id), + sqlgraph.To(group.Table, group.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.GroupTable, usagelog.GroupColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QuerySubscription queries the subscription edge of a UsageLog. +func (c *UsageLogClient) QuerySubscription(_m *UsageLog) *UserSubscriptionQuery { + query := (&UserSubscriptionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, id), + sqlgraph.To(usersubscription.Table, usersubscription.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.SubscriptionTable, usagelog.SubscriptionColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *UsageLogClient) Hooks() []Hook { + return c.hooks.UsageLog +} + +// Interceptors returns the client interceptors. +func (c *UsageLogClient) Interceptors() []Interceptor { + return c.inters.UsageLog +} + +func (c *UsageLogClient) mutate(ctx context.Context, m *UsageLogMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&UsageLogCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&UsageLogUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&UsageLogUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&UsageLogDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown UsageLog mutation op: %q", m.Op()) + } +} + // UserClient is a client for the User schema. type UserClient struct { config @@ -1599,6 +1900,22 @@ func (c *UserClient) QueryAllowedGroups(_m *User) *GroupQuery { return query } +// QueryUsageLogs queries the usage_logs edge of a User. +func (c *UserClient) QueryUsageLogs(_m *User) *UsageLogQuery { + query := (&UsageLogClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.UsageLogsTable, user.UsageLogsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // QueryUserAllowedGroups queries the user_allowed_groups edge of a User. func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: c.config}).Query() @@ -1914,14 +2231,32 @@ func (c *UserSubscriptionClient) QueryAssignedByUser(_m *UserSubscription) *User return query } +// QueryUsageLogs queries the usage_logs edge of a UserSubscription. +func (c *UserSubscriptionClient) QueryUsageLogs(_m *UserSubscription) *UsageLogQuery { + query := (&UsageLogClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(usersubscription.Table, usersubscription.FieldID, id), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, usersubscription.UsageLogsTable, usersubscription.UsageLogsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // Hooks returns the client hooks. func (c *UserSubscriptionClient) Hooks() []Hook { - return c.hooks.UserSubscription + hooks := c.hooks.UserSubscription + return append(hooks[:len(hooks):len(hooks)], usersubscription.Hooks[:]...) } // Interceptors returns the client interceptors. func (c *UserSubscriptionClient) Interceptors() []Interceptor { - return c.inters.UserSubscription + inters := c.inters.UserSubscription + return append(inters[:len(inters):len(inters)], usersubscription.Interceptors[:]...) } func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscriptionMutation) (Value, error) { @@ -1942,16 +2277,15 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription // hooks and interceptors per client, for fast access. type ( hooks struct { - Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, User, - UserAllowedGroup, UserSubscription []ent.Hook + Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, UsageLog, + User, UserAllowedGroup, UserSubscription []ent.Hook } inters struct { - Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, User, - UserAllowedGroup, UserSubscription []ent.Interceptor + Account, AccountGroup, ApiKey, Group, Proxy, RedeemCode, Setting, UsageLog, + User, UserAllowedGroup, UserSubscription []ent.Interceptor } ) -// ExecContext 透传到底层 driver,用于在 ent 事务中执行原生 SQL(例如同步 legacy 字段)。 // ExecContext allows calling the underlying ExecContext method of the driver if it is supported by it. // See, database/sql#DB.ExecContext for more information. func (c *config) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) { @@ -1964,7 +2298,6 @@ func (c *config) ExecContext(ctx context.Context, query string, args ...any) (st return ex.ExecContext(ctx, query, args...) } -// QueryContext 透传到底层 driver,用于在事务内执行原生查询并共享锁/一致性语义。 // QueryContext allows calling the underlying QueryContext method of the driver if it is supported by it. // See, database/sql#DB.QueryContext for more information. func (c *config) QueryContext(ctx context.Context, query string, args ...any) (*stdsql.Rows, error) { diff --git a/backend/ent/ent.go b/backend/ent/ent.go index e2c8b56c..29890206 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -19,6 +19,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -89,6 +90,7 @@ func checkColumn(t, c string) error { proxy.Table: proxy.ValidColumn, redeemcode.Table: redeemcode.ValidColumn, setting.Table: setting.ValidColumn, + usagelog.Table: usagelog.ValidColumn, user.Table: user.ValidColumn, userallowedgroup.Table: userallowedgroup.ValidColumn, usersubscription.Table: usersubscription.ValidColumn, diff --git a/backend/ent/group.go b/backend/ent/group.go index fecb202a..9b1e8604 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -43,6 +43,8 @@ type Group struct { WeeklyLimitUsd *float64 `json:"weekly_limit_usd,omitempty"` // MonthlyLimitUsd holds the value of the "monthly_limit_usd" field. MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"` + // DefaultValidityDays holds the value of the "default_validity_days" field. + DefaultValidityDays int `json:"default_validity_days,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. Edges GroupEdges `json:"edges"` @@ -57,6 +59,8 @@ type GroupEdges struct { RedeemCodes []*RedeemCode `json:"redeem_codes,omitempty"` // Subscriptions holds the value of the subscriptions edge. Subscriptions []*UserSubscription `json:"subscriptions,omitempty"` + // UsageLogs holds the value of the usage_logs edge. + UsageLogs []*UsageLog `json:"usage_logs,omitempty"` // Accounts holds the value of the accounts edge. Accounts []*Account `json:"accounts,omitempty"` // AllowedUsers holds the value of the allowed_users edge. @@ -67,7 +71,7 @@ type GroupEdges struct { UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [7]bool + loadedTypes [8]bool } // APIKeysOrErr returns the APIKeys value or an error if the edge @@ -97,10 +101,19 @@ func (e GroupEdges) SubscriptionsOrErr() ([]*UserSubscription, error) { return nil, &NotLoadedError{edge: "subscriptions"} } +// UsageLogsOrErr returns the UsageLogs value or an error if the edge +// was not loaded in eager-loading. +func (e GroupEdges) UsageLogsOrErr() ([]*UsageLog, error) { + if e.loadedTypes[3] { + return e.UsageLogs, nil + } + return nil, &NotLoadedError{edge: "usage_logs"} +} + // AccountsOrErr returns the Accounts value or an error if the edge // was not loaded in eager-loading. func (e GroupEdges) AccountsOrErr() ([]*Account, error) { - if e.loadedTypes[3] { + if e.loadedTypes[4] { return e.Accounts, nil } return nil, &NotLoadedError{edge: "accounts"} @@ -109,7 +122,7 @@ func (e GroupEdges) AccountsOrErr() ([]*Account, error) { // AllowedUsersOrErr returns the AllowedUsers value or an error if the edge // was not loaded in eager-loading. func (e GroupEdges) AllowedUsersOrErr() ([]*User, error) { - if e.loadedTypes[4] { + if e.loadedTypes[5] { return e.AllowedUsers, nil } return nil, &NotLoadedError{edge: "allowed_users"} @@ -118,7 +131,7 @@ func (e GroupEdges) AllowedUsersOrErr() ([]*User, error) { // AccountGroupsOrErr returns the AccountGroups value or an error if the edge // was not loaded in eager-loading. func (e GroupEdges) AccountGroupsOrErr() ([]*AccountGroup, error) { - if e.loadedTypes[5] { + if e.loadedTypes[6] { return e.AccountGroups, nil } return nil, &NotLoadedError{edge: "account_groups"} @@ -127,7 +140,7 @@ func (e GroupEdges) AccountGroupsOrErr() ([]*AccountGroup, error) { // UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge // was not loaded in eager-loading. func (e GroupEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) { - if e.loadedTypes[6] { + if e.loadedTypes[7] { return e.UserAllowedGroups, nil } return nil, &NotLoadedError{edge: "user_allowed_groups"} @@ -142,7 +155,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd: values[i] = new(sql.NullFloat64) - case group.FieldID: + case group.FieldID, group.FieldDefaultValidityDays: values[i] = new(sql.NullInt64) case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType: values[i] = new(sql.NullString) @@ -252,6 +265,12 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.MonthlyLimitUsd = new(float64) *_m.MonthlyLimitUsd = value.Float64 } + case group.FieldDefaultValidityDays: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field default_validity_days", values[i]) + } else if value.Valid { + _m.DefaultValidityDays = int(value.Int64) + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -280,6 +299,11 @@ func (_m *Group) QuerySubscriptions() *UserSubscriptionQuery { return NewGroupClient(_m.config).QuerySubscriptions(_m) } +// QueryUsageLogs queries the "usage_logs" edge of the Group entity. +func (_m *Group) QueryUsageLogs() *UsageLogQuery { + return NewGroupClient(_m.config).QueryUsageLogs(_m) +} + // QueryAccounts queries the "accounts" edge of the Group entity. func (_m *Group) QueryAccounts() *AccountQuery { return NewGroupClient(_m.config).QueryAccounts(_m) @@ -371,6 +395,9 @@ func (_m *Group) String() string { builder.WriteString("monthly_limit_usd=") builder.WriteString(fmt.Sprintf("%v", *v)) } + builder.WriteString(", ") + builder.WriteString("default_validity_days=") + builder.WriteString(fmt.Sprintf("%v", _m.DefaultValidityDays)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 05a5673d..8dc53c49 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -41,12 +41,16 @@ const ( FieldWeeklyLimitUsd = "weekly_limit_usd" // FieldMonthlyLimitUsd holds the string denoting the monthly_limit_usd field in the database. FieldMonthlyLimitUsd = "monthly_limit_usd" + // FieldDefaultValidityDays holds the string denoting the default_validity_days field in the database. + FieldDefaultValidityDays = "default_validity_days" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. EdgeRedeemCodes = "redeem_codes" // EdgeSubscriptions holds the string denoting the subscriptions edge name in mutations. EdgeSubscriptions = "subscriptions" + // EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations. + EdgeUsageLogs = "usage_logs" // EdgeAccounts holds the string denoting the accounts edge name in mutations. EdgeAccounts = "accounts" // EdgeAllowedUsers holds the string denoting the allowed_users edge name in mutations. @@ -78,6 +82,13 @@ const ( SubscriptionsInverseTable = "user_subscriptions" // SubscriptionsColumn is the table column denoting the subscriptions relation/edge. SubscriptionsColumn = "group_id" + // UsageLogsTable is the table that holds the usage_logs relation/edge. + UsageLogsTable = "usage_logs" + // UsageLogsInverseTable is the table name for the UsageLog entity. + // It exists in this package in order to avoid circular dependency with the "usagelog" package. + UsageLogsInverseTable = "usage_logs" + // UsageLogsColumn is the table column denoting the usage_logs relation/edge. + UsageLogsColumn = "group_id" // AccountsTable is the table that holds the accounts relation/edge. The primary key declared below. AccountsTable = "account_groups" // AccountsInverseTable is the table name for the Account entity. @@ -120,6 +131,7 @@ var Columns = []string{ FieldDailyLimitUsd, FieldWeeklyLimitUsd, FieldMonthlyLimitUsd, + FieldDefaultValidityDays, } var ( @@ -173,6 +185,8 @@ var ( DefaultSubscriptionType string // SubscriptionTypeValidator is a validator for the "subscription_type" field. It is called by the builders before save. SubscriptionTypeValidator func(string) error + // DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field. + DefaultDefaultValidityDays int ) // OrderOption defines the ordering options for the Group queries. @@ -248,6 +262,11 @@ func ByMonthlyLimitUsd(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldMonthlyLimitUsd, opts...).ToFunc() } +// ByDefaultValidityDays orders the results by the default_validity_days field. +func ByDefaultValidityDays(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDefaultValidityDays, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -290,6 +309,20 @@ func BySubscriptions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { } } +// ByUsageLogsCount orders the results by usage_logs count. +func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...) + } +} + +// ByUsageLogs orders the results by usage_logs terms. +func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + // ByAccountsCount orders the results by accounts count. func ByAccountsCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -366,6 +399,13 @@ func newSubscriptionsStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.O2M, false, SubscriptionsTable, SubscriptionsColumn), ) } +func newUsageLogsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UsageLogsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) +} func newAccountsStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index fd597be9..ac18a418 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -120,6 +120,11 @@ func MonthlyLimitUsd(v float64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldMonthlyLimitUsd, v)) } +// DefaultValidityDays applies equality check predicate on the "default_validity_days" field. It's identical to DefaultValidityDaysEQ. +func DefaultValidityDays(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldDefaultValidityDays, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Group { return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) @@ -785,6 +790,46 @@ func MonthlyLimitUsdNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldMonthlyLimitUsd)) } +// DefaultValidityDaysEQ applies the EQ predicate on the "default_validity_days" field. +func DefaultValidityDaysEQ(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldDefaultValidityDays, v)) +} + +// DefaultValidityDaysNEQ applies the NEQ predicate on the "default_validity_days" field. +func DefaultValidityDaysNEQ(v int) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldDefaultValidityDays, v)) +} + +// DefaultValidityDaysIn applies the In predicate on the "default_validity_days" field. +func DefaultValidityDaysIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldIn(FieldDefaultValidityDays, vs...)) +} + +// DefaultValidityDaysNotIn applies the NotIn predicate on the "default_validity_days" field. +func DefaultValidityDaysNotIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldDefaultValidityDays, vs...)) +} + +// DefaultValidityDaysGT applies the GT predicate on the "default_validity_days" field. +func DefaultValidityDaysGT(v int) predicate.Group { + return predicate.Group(sql.FieldGT(FieldDefaultValidityDays, v)) +} + +// DefaultValidityDaysGTE applies the GTE predicate on the "default_validity_days" field. +func DefaultValidityDaysGTE(v int) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldDefaultValidityDays, v)) +} + +// DefaultValidityDaysLT applies the LT predicate on the "default_validity_days" field. +func DefaultValidityDaysLT(v int) predicate.Group { + return predicate.Group(sql.FieldLT(FieldDefaultValidityDays, v)) +} + +// DefaultValidityDaysLTE applies the LTE predicate on the "default_validity_days" field. +func DefaultValidityDaysLTE(v int) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldDefaultValidityDays, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.Group { return predicate.Group(func(s *sql.Selector) { @@ -854,6 +899,29 @@ func HasSubscriptionsWith(preds ...predicate.UserSubscription) predicate.Group { }) } +// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge. +func HasUsageLogs() predicate.Group { + return predicate.Group(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates). +func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.Group { + return predicate.Group(func(s *sql.Selector) { + step := newUsageLogsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // HasAccounts applies the HasEdge predicate on the "accounts" edge. func HasAccounts() predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 873cf84c..383a1352 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -15,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -201,6 +202,20 @@ func (_c *GroupCreate) SetNillableMonthlyLimitUsd(v *float64) *GroupCreate { return _c } +// SetDefaultValidityDays sets the "default_validity_days" field. +func (_c *GroupCreate) SetDefaultValidityDays(v int) *GroupCreate { + _c.mutation.SetDefaultValidityDays(v) + return _c +} + +// SetNillableDefaultValidityDays sets the "default_validity_days" field if the given value is not nil. +func (_c *GroupCreate) SetNillableDefaultValidityDays(v *int) *GroupCreate { + if v != nil { + _c.SetDefaultValidityDays(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the ApiKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -246,6 +261,21 @@ func (_c *GroupCreate) AddSubscriptions(v ...*UserSubscription) *GroupCreate { return _c.AddSubscriptionIDs(ids...) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_c *GroupCreate) AddUsageLogIDs(ids ...int64) *GroupCreate { + _c.mutation.AddUsageLogIDs(ids...) + return _c +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_c *GroupCreate) AddUsageLogs(v ...*UsageLog) *GroupCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddUsageLogIDs(ids...) +} + // AddAccountIDs adds the "accounts" edge to the Account entity by IDs. func (_c *GroupCreate) AddAccountIDs(ids ...int64) *GroupCreate { _c.mutation.AddAccountIDs(ids...) @@ -347,6 +377,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultSubscriptionType _c.mutation.SetSubscriptionType(v) } + if _, ok := _c.mutation.DefaultValidityDays(); !ok { + v := group.DefaultDefaultValidityDays + _c.mutation.SetDefaultValidityDays(v) + } return nil } @@ -396,6 +430,9 @@ func (_c *GroupCreate) check() error { return &ValidationError{Name: "subscription_type", err: fmt.Errorf(`ent: validator failed for field "Group.subscription_type": %w`, err)} } } + if _, ok := _c.mutation.DefaultValidityDays(); !ok { + return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)} + } return nil } @@ -475,6 +512,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldMonthlyLimitUsd, field.TypeFloat64, value) _node.MonthlyLimitUsd = &value } + if value, ok := _c.mutation.DefaultValidityDays(); ok { + _spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value) + _node.DefaultValidityDays = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -523,6 +564,22 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { } _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsageLogsTable, + Columns: []string{group.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } if nodes := _c.mutation.AccountsIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2M, @@ -813,6 +870,24 @@ func (u *GroupUpsert) ClearMonthlyLimitUsd() *GroupUpsert { return u } +// SetDefaultValidityDays sets the "default_validity_days" field. +func (u *GroupUpsert) SetDefaultValidityDays(v int) *GroupUpsert { + u.Set(group.FieldDefaultValidityDays, v) + return u +} + +// UpdateDefaultValidityDays sets the "default_validity_days" field to the value that was provided on create. +func (u *GroupUpsert) UpdateDefaultValidityDays() *GroupUpsert { + u.SetExcluded(group.FieldDefaultValidityDays) + return u +} + +// AddDefaultValidityDays adds v to the "default_validity_days" field. +func (u *GroupUpsert) AddDefaultValidityDays(v int) *GroupUpsert { + u.Add(group.FieldDefaultValidityDays, v) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1089,6 +1164,27 @@ func (u *GroupUpsertOne) ClearMonthlyLimitUsd() *GroupUpsertOne { }) } +// SetDefaultValidityDays sets the "default_validity_days" field. +func (u *GroupUpsertOne) SetDefaultValidityDays(v int) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetDefaultValidityDays(v) + }) +} + +// AddDefaultValidityDays adds v to the "default_validity_days" field. +func (u *GroupUpsertOne) AddDefaultValidityDays(v int) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddDefaultValidityDays(v) + }) +} + +// UpdateDefaultValidityDays sets the "default_validity_days" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateDefaultValidityDays() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateDefaultValidityDays() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1531,6 +1627,27 @@ func (u *GroupUpsertBulk) ClearMonthlyLimitUsd() *GroupUpsertBulk { }) } +// SetDefaultValidityDays sets the "default_validity_days" field. +func (u *GroupUpsertBulk) SetDefaultValidityDays(v int) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetDefaultValidityDays(v) + }) +} + +// AddDefaultValidityDays adds v to the "default_validity_days" field. +func (u *GroupUpsertBulk) AddDefaultValidityDays(v int) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddDefaultValidityDays(v) + }) +} + +// UpdateDefaultValidityDays sets the "default_validity_days" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateDefaultValidityDays() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateDefaultValidityDays() + }) +} + // Exec executes the query. func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_query.go b/backend/ent/group_query.go index 0b86e069..93a8d8c2 100644 --- a/backend/ent/group_query.go +++ b/backend/ent/group_query.go @@ -18,6 +18,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -33,6 +34,7 @@ type GroupQuery struct { withAPIKeys *ApiKeyQuery withRedeemCodes *RedeemCodeQuery withSubscriptions *UserSubscriptionQuery + withUsageLogs *UsageLogQuery withAccounts *AccountQuery withAllowedUsers *UserQuery withAccountGroups *AccountGroupQuery @@ -139,6 +141,28 @@ func (_q *GroupQuery) QuerySubscriptions() *UserSubscriptionQuery { return query } +// QueryUsageLogs chains the current query on the "usage_logs" edge. +func (_q *GroupQuery) QueryUsageLogs() *UsageLogQuery { + query := (&UsageLogClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(group.Table, group.FieldID, selector), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, group.UsageLogsTable, group.UsageLogsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // QueryAccounts chains the current query on the "accounts" edge. func (_q *GroupQuery) QueryAccounts() *AccountQuery { query := (&AccountClient{config: _q.config}).Query() @@ -422,6 +446,7 @@ func (_q *GroupQuery) Clone() *GroupQuery { withAPIKeys: _q.withAPIKeys.Clone(), withRedeemCodes: _q.withRedeemCodes.Clone(), withSubscriptions: _q.withSubscriptions.Clone(), + withUsageLogs: _q.withUsageLogs.Clone(), withAccounts: _q.withAccounts.Clone(), withAllowedUsers: _q.withAllowedUsers.Clone(), withAccountGroups: _q.withAccountGroups.Clone(), @@ -465,6 +490,17 @@ func (_q *GroupQuery) WithSubscriptions(opts ...func(*UserSubscriptionQuery)) *G return _q } +// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to +// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *GroupQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *GroupQuery { + query := (&UsageLogClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUsageLogs = query + return _q +} + // WithAccounts tells the query-builder to eager-load the nodes that are connected to // the "accounts" edge. The optional arguments are used to configure the query builder of the edge. func (_q *GroupQuery) WithAccounts(opts ...func(*AccountQuery)) *GroupQuery { @@ -587,10 +623,11 @@ func (_q *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, var ( nodes = []*Group{} _spec = _q.querySpec() - loadedTypes = [7]bool{ + loadedTypes = [8]bool{ _q.withAPIKeys != nil, _q.withRedeemCodes != nil, _q.withSubscriptions != nil, + _q.withUsageLogs != nil, _q.withAccounts != nil, _q.withAllowedUsers != nil, _q.withAccountGroups != nil, @@ -636,6 +673,13 @@ func (_q *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, return nil, err } } + if query := _q.withUsageLogs; query != nil { + if err := _q.loadUsageLogs(ctx, query, nodes, + func(n *Group) { n.Edges.UsageLogs = []*UsageLog{} }, + func(n *Group, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil { + return nil, err + } + } if query := _q.withAccounts; query != nil { if err := _q.loadAccounts(ctx, query, nodes, func(n *Group) { n.Edges.Accounts = []*Account{} }, @@ -763,6 +807,39 @@ func (_q *GroupQuery) loadSubscriptions(ctx context.Context, query *UserSubscrip } return nil } +func (_q *GroupQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*Group, init func(*Group), assign func(*Group, *UsageLog)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*Group) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(usagelog.FieldGroupID) + } + query.Where(predicate.UsageLog(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(group.UsageLogsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.GroupID + if fk == nil { + return fmt.Errorf(`foreign-key "group_id" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "group_id" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} func (_q *GroupQuery) loadAccounts(ctx context.Context, query *AccountQuery, nodes []*Group, init func(*Group), assign func(*Group, *Account)) error { edgeIDs := make([]driver.Value, len(nodes)) byID := make(map[int64]*Group) diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index 0ed1e3fd..1825a892 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -16,6 +16,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -251,6 +252,27 @@ func (_u *GroupUpdate) ClearMonthlyLimitUsd() *GroupUpdate { return _u } +// SetDefaultValidityDays sets the "default_validity_days" field. +func (_u *GroupUpdate) SetDefaultValidityDays(v int) *GroupUpdate { + _u.mutation.ResetDefaultValidityDays() + _u.mutation.SetDefaultValidityDays(v) + return _u +} + +// SetNillableDefaultValidityDays sets the "default_validity_days" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableDefaultValidityDays(v *int) *GroupUpdate { + if v != nil { + _u.SetDefaultValidityDays(*v) + } + return _u +} + +// AddDefaultValidityDays adds value to the "default_validity_days" field. +func (_u *GroupUpdate) AddDefaultValidityDays(v int) *GroupUpdate { + _u.mutation.AddDefaultValidityDays(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the ApiKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -296,6 +318,21 @@ func (_u *GroupUpdate) AddSubscriptions(v ...*UserSubscription) *GroupUpdate { return _u.AddSubscriptionIDs(ids...) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *GroupUpdate) AddUsageLogIDs(ids ...int64) *GroupUpdate { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *GroupUpdate) AddUsageLogs(v ...*UsageLog) *GroupUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // AddAccountIDs adds the "accounts" edge to the Account entity by IDs. func (_u *GroupUpdate) AddAccountIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAccountIDs(ids...) @@ -394,6 +431,27 @@ func (_u *GroupUpdate) RemoveSubscriptions(v ...*UserSubscription) *GroupUpdate return _u.RemoveSubscriptionIDs(ids...) } +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *GroupUpdate) ClearUsageLogs() *GroupUpdate { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *GroupUpdate) RemoveUsageLogIDs(ids ...int64) *GroupUpdate { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *GroupUpdate) RemoveUsageLogs(v ...*UsageLog) *GroupUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // ClearAccounts clears all "accounts" edges to the Account entity. func (_u *GroupUpdate) ClearAccounts() *GroupUpdate { _u.mutation.ClearAccounts() @@ -578,6 +636,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.MonthlyLimitUsdCleared() { _spec.ClearField(group.FieldMonthlyLimitUsd, field.TypeFloat64) } + if value, ok := _u.mutation.DefaultValidityDays(); ok { + _spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedDefaultValidityDays(); ok { + _spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -713,6 +777,51 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsageLogsTable, + Columns: []string{group.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsageLogsTable, + Columns: []string{group.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsageLogsTable, + Columns: []string{group.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _u.mutation.AccountsCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2M, @@ -1065,6 +1174,27 @@ func (_u *GroupUpdateOne) ClearMonthlyLimitUsd() *GroupUpdateOne { return _u } +// SetDefaultValidityDays sets the "default_validity_days" field. +func (_u *GroupUpdateOne) SetDefaultValidityDays(v int) *GroupUpdateOne { + _u.mutation.ResetDefaultValidityDays() + _u.mutation.SetDefaultValidityDays(v) + return _u +} + +// SetNillableDefaultValidityDays sets the "default_validity_days" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableDefaultValidityDays(v *int) *GroupUpdateOne { + if v != nil { + _u.SetDefaultValidityDays(*v) + } + return _u +} + +// AddDefaultValidityDays adds value to the "default_validity_days" field. +func (_u *GroupUpdateOne) AddDefaultValidityDays(v int) *GroupUpdateOne { + _u.mutation.AddDefaultValidityDays(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the ApiKey entity by IDs. func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -1110,6 +1240,21 @@ func (_u *GroupUpdateOne) AddSubscriptions(v ...*UserSubscription) *GroupUpdateO return _u.AddSubscriptionIDs(ids...) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *GroupUpdateOne) AddUsageLogIDs(ids ...int64) *GroupUpdateOne { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *GroupUpdateOne) AddUsageLogs(v ...*UsageLog) *GroupUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // AddAccountIDs adds the "accounts" edge to the Account entity by IDs. func (_u *GroupUpdateOne) AddAccountIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAccountIDs(ids...) @@ -1208,6 +1353,27 @@ func (_u *GroupUpdateOne) RemoveSubscriptions(v ...*UserSubscription) *GroupUpda return _u.RemoveSubscriptionIDs(ids...) } +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *GroupUpdateOne) ClearUsageLogs() *GroupUpdateOne { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *GroupUpdateOne) RemoveUsageLogIDs(ids ...int64) *GroupUpdateOne { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *GroupUpdateOne) RemoveUsageLogs(v ...*UsageLog) *GroupUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // ClearAccounts clears all "accounts" edges to the Account entity. func (_u *GroupUpdateOne) ClearAccounts() *GroupUpdateOne { _u.mutation.ClearAccounts() @@ -1422,6 +1588,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.MonthlyLimitUsdCleared() { _spec.ClearField(group.FieldMonthlyLimitUsd, field.TypeFloat64) } + if value, ok := _u.mutation.DefaultValidityDays(); ok { + _spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedDefaultValidityDays(); ok { + _spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1557,6 +1729,51 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsageLogsTable, + Columns: []string{group.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsageLogsTable, + Columns: []string{group.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: group.UsageLogsTable, + Columns: []string{group.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _u.mutation.AccountsCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2M, diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index 46933bb0..33955cbb 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -93,6 +93,18 @@ func (f SettingFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, err return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SettingMutation", m) } +// The UsageLogFunc type is an adapter to allow the use of ordinary +// function as UsageLog mutator. +type UsageLogFunc func(context.Context, *ent.UsageLogMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f UsageLogFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.UsageLogMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UsageLogMutation", m) +} + // The UserFunc type is an adapter to allow the use of ordinary // function as User mutator. type UserFunc func(context.Context, *ent.UserMutation) (ent.Value, error) diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index ab5f5554..9815f477 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -16,6 +16,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -266,6 +267,33 @@ func (f TraverseSetting) Traverse(ctx context.Context, q ent.Query) error { return fmt.Errorf("unexpected query type %T. expect *ent.SettingQuery", q) } +// The UsageLogFunc type is an adapter to allow the use of ordinary function as a Querier. +type UsageLogFunc func(context.Context, *ent.UsageLogQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f UsageLogFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.UsageLogQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.UsageLogQuery", q) +} + +// The TraverseUsageLog type is an adapter to allow the use of ordinary function as Traverser. +type TraverseUsageLog func(context.Context, *ent.UsageLogQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseUsageLog) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseUsageLog) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.UsageLogQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.UsageLogQuery", q) +} + // The UserFunc type is an adapter to allow the use of ordinary function as a Querier. type UserFunc func(context.Context, *ent.UserQuery) (ent.Value, error) @@ -364,6 +392,8 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.RedeemCodeQuery, predicate.RedeemCode, redeemcode.OrderOption]{typ: ent.TypeRedeemCode, tq: q}, nil case *ent.SettingQuery: return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil + case *ent.UsageLogQuery: + return &query[*ent.UsageLogQuery, predicate.UsageLog, usagelog.OrderOption]{typ: ent.TypeUsageLog, tq: q}, nil case *ent.UserQuery: return &query[*ent.UserQuery, predicate.User, user.OrderOption]{typ: ent.TypeUser, tq: q}, nil case *ent.UserAllowedGroupQuery: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 45408760..848ac74c 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -20,7 +20,6 @@ var ( {Name: "type", Type: field.TypeString, Size: 20}, {Name: "credentials", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, - {Name: "proxy_id", Type: field.TypeInt64, Nullable: true}, {Name: "concurrency", Type: field.TypeInt, Default: 3}, {Name: "priority", Type: field.TypeInt, Default: 50}, {Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, @@ -33,12 +32,21 @@ var ( {Name: "session_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "session_window_end", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "session_window_status", Type: field.TypeString, Nullable: true, Size: 20}, + {Name: "proxy_id", Type: field.TypeInt64, Nullable: true}, } // AccountsTable holds the schema information for the "accounts" table. AccountsTable = &schema.Table{ Name: "accounts", Columns: AccountsColumns, PrimaryKey: []*schema.Column{AccountsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "accounts_proxies_proxy", + Columns: []*schema.Column{AccountsColumns[21]}, + RefColumns: []*schema.Column{ProxiesColumns[0]}, + OnDelete: schema.SetNull, + }, + }, Indexes: []*schema.Index{ { Name: "account_platform", @@ -53,42 +61,42 @@ var ( { Name: "account_status", Unique: false, - Columns: []*schema.Column{AccountsColumns[12]}, + Columns: []*schema.Column{AccountsColumns[11]}, }, { Name: "account_proxy_id", Unique: false, - Columns: []*schema.Column{AccountsColumns[9]}, + Columns: []*schema.Column{AccountsColumns[21]}, }, { Name: "account_priority", Unique: false, - Columns: []*schema.Column{AccountsColumns[11]}, + Columns: []*schema.Column{AccountsColumns[10]}, }, { Name: "account_last_used_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[14]}, + Columns: []*schema.Column{AccountsColumns[13]}, }, { Name: "account_schedulable", Unique: false, - Columns: []*schema.Column{AccountsColumns[15]}, + Columns: []*schema.Column{AccountsColumns[14]}, }, { Name: "account_rate_limited_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[16]}, + Columns: []*schema.Column{AccountsColumns[15]}, }, { Name: "account_rate_limit_reset_at", Unique: false, - Columns: []*schema.Column{AccountsColumns[17]}, + Columns: []*schema.Column{AccountsColumns[16]}, }, { Name: "account_overload_until", Unique: false, - Columns: []*schema.Column{AccountsColumns[18]}, + Columns: []*schema.Column{AccountsColumns[17]}, }, { Name: "account_deleted_at", @@ -100,7 +108,7 @@ var ( // AccountGroupsColumns holds the columns for the "account_groups" table. AccountGroupsColumns = []*schema.Column{ {Name: "priority", Type: field.TypeInt, Default: 50}, - {Name: "created_at", Type: field.TypeTime}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "account_id", Type: field.TypeInt64}, {Name: "group_id", Type: field.TypeInt64}, } @@ -168,11 +176,6 @@ var ( }, }, Indexes: []*schema.Index{ - { - Name: "apikey_key", - Unique: true, - Columns: []*schema.Column{APIKeysColumns[4]}, - }, { Name: "apikey_user_id", Unique: false, @@ -211,6 +214,7 @@ var ( {Name: "daily_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "default_validity_days", Type: field.TypeInt, Default: 30}, } // GroupsTable holds the schema information for the "groups" table. GroupsTable = &schema.Table{ @@ -218,11 +222,6 @@ var ( Columns: GroupsColumns, PrimaryKey: []*schema.Column{GroupsColumns[0]}, Indexes: []*schema.Index{ - { - Name: "group_name", - Unique: true, - Columns: []*schema.Column{GroupsColumns[4]}, - }, { Name: "group_status", Unique: false, @@ -316,11 +315,6 @@ var ( }, }, Indexes: []*schema.Index{ - { - Name: "redeemcode_code", - Unique: true, - Columns: []*schema.Column{RedeemCodesColumns[1]}, - }, { Name: "redeemcode_status", Unique: false, @@ -350,11 +344,123 @@ var ( Name: "settings", Columns: SettingsColumns, PrimaryKey: []*schema.Column{SettingsColumns[0]}, + } + // UsageLogsColumns holds the columns for the "usage_logs" table. + UsageLogsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "request_id", Type: field.TypeString, Size: 64}, + {Name: "model", Type: field.TypeString, Size: 100}, + {Name: "input_tokens", Type: field.TypeInt, Default: 0}, + {Name: "output_tokens", Type: field.TypeInt, Default: 0}, + {Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0}, + {Name: "cache_read_tokens", Type: field.TypeInt, Default: 0}, + {Name: "cache_creation_5m_tokens", Type: field.TypeInt, Default: 0}, + {Name: "cache_creation_1h_tokens", Type: field.TypeInt, Default: 0}, + {Name: "input_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "output_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "cache_creation_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "cache_read_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "total_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "actual_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}}, + {Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}}, + {Name: "billing_type", Type: field.TypeInt8, Default: 0}, + {Name: "stream", Type: field.TypeBool, Default: false}, + {Name: "duration_ms", Type: field.TypeInt, Nullable: true}, + {Name: "first_token_ms", Type: field.TypeInt, Nullable: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "account_id", Type: field.TypeInt64}, + {Name: "api_key_id", Type: field.TypeInt64}, + {Name: "group_id", Type: field.TypeInt64, Nullable: true}, + {Name: "user_id", Type: field.TypeInt64}, + {Name: "subscription_id", Type: field.TypeInt64, Nullable: true}, + } + // UsageLogsTable holds the schema information for the "usage_logs" table. + UsageLogsTable = &schema.Table{ + Name: "usage_logs", + Columns: UsageLogsColumns, + PrimaryKey: []*schema.Column{UsageLogsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "usage_logs_accounts_usage_logs", + Columns: []*schema.Column{UsageLogsColumns[21]}, + RefColumns: []*schema.Column{AccountsColumns[0]}, + OnDelete: schema.NoAction, + }, + { + Symbol: "usage_logs_api_keys_usage_logs", + Columns: []*schema.Column{UsageLogsColumns[22]}, + RefColumns: []*schema.Column{APIKeysColumns[0]}, + OnDelete: schema.NoAction, + }, + { + Symbol: "usage_logs_groups_usage_logs", + Columns: []*schema.Column{UsageLogsColumns[23]}, + RefColumns: []*schema.Column{GroupsColumns[0]}, + OnDelete: schema.SetNull, + }, + { + Symbol: "usage_logs_users_usage_logs", + Columns: []*schema.Column{UsageLogsColumns[24]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.NoAction, + }, + { + Symbol: "usage_logs_user_subscriptions_usage_logs", + Columns: []*schema.Column{UsageLogsColumns[25]}, + RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, + OnDelete: schema.SetNull, + }, + }, Indexes: []*schema.Index{ { - Name: "setting_key", - Unique: true, - Columns: []*schema.Column{SettingsColumns[1]}, + Name: "usagelog_user_id", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[24]}, + }, + { + Name: "usagelog_api_key_id", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[22]}, + }, + { + Name: "usagelog_account_id", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[21]}, + }, + { + Name: "usagelog_group_id", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[23]}, + }, + { + Name: "usagelog_subscription_id", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[25]}, + }, + { + Name: "usagelog_created_at", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[20]}, + }, + { + Name: "usagelog_model", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[2]}, + }, + { + Name: "usagelog_request_id", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[1]}, + }, + { + Name: "usagelog_user_id_created_at", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[24], UsageLogsColumns[20]}, + }, + { + Name: "usagelog_api_key_id_created_at", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[22], UsageLogsColumns[20]}, }, }, } @@ -380,11 +486,6 @@ var ( Columns: UsersColumns, PrimaryKey: []*schema.Column{UsersColumns[0]}, Indexes: []*schema.Index{ - { - Name: "user_email", - Unique: true, - Columns: []*schema.Column{UsersColumns[4]}, - }, { Name: "user_status", Unique: false, @@ -399,7 +500,7 @@ var ( } // UserAllowedGroupsColumns holds the columns for the "user_allowed_groups" table. UserAllowedGroupsColumns = []*schema.Column{ - {Name: "created_at", Type: field.TypeTime}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "user_id", Type: field.TypeInt64}, {Name: "group_id", Type: field.TypeInt64}, } @@ -435,6 +536,7 @@ var ( {Name: "id", Type: field.TypeInt64, Increment: true}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "starts_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "expires_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, @@ -458,19 +560,19 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "user_subscriptions_groups_subscriptions", - Columns: []*schema.Column{UserSubscriptionsColumns[14]}, + Columns: []*schema.Column{UserSubscriptionsColumns[15]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "user_subscriptions_users_subscriptions", - Columns: []*schema.Column{UserSubscriptionsColumns[15]}, + Columns: []*schema.Column{UserSubscriptionsColumns[16]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "user_subscriptions_users_assigned_subscriptions", - Columns: []*schema.Column{UserSubscriptionsColumns[16]}, + Columns: []*schema.Column{UserSubscriptionsColumns[17]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.SetNull, }, @@ -479,32 +581,37 @@ var ( { Name: "usersubscription_user_id", Unique: false, - Columns: []*schema.Column{UserSubscriptionsColumns[15]}, + Columns: []*schema.Column{UserSubscriptionsColumns[16]}, }, { Name: "usersubscription_group_id", Unique: false, - Columns: []*schema.Column{UserSubscriptionsColumns[14]}, + Columns: []*schema.Column{UserSubscriptionsColumns[15]}, }, { Name: "usersubscription_status", Unique: false, - Columns: []*schema.Column{UserSubscriptionsColumns[5]}, + Columns: []*schema.Column{UserSubscriptionsColumns[6]}, }, { Name: "usersubscription_expires_at", Unique: false, - Columns: []*schema.Column{UserSubscriptionsColumns[4]}, + Columns: []*schema.Column{UserSubscriptionsColumns[5]}, }, { Name: "usersubscription_assigned_by", Unique: false, - Columns: []*schema.Column{UserSubscriptionsColumns[16]}, + Columns: []*schema.Column{UserSubscriptionsColumns[17]}, }, { Name: "usersubscription_user_id_group_id", Unique: true, - Columns: []*schema.Column{UserSubscriptionsColumns[15], UserSubscriptionsColumns[14]}, + Columns: []*schema.Column{UserSubscriptionsColumns[16], UserSubscriptionsColumns[15]}, + }, + { + Name: "usersubscription_deleted_at", + Unique: false, + Columns: []*schema.Column{UserSubscriptionsColumns[3]}, }, }, } @@ -517,6 +624,7 @@ var ( ProxiesTable, RedeemCodesTable, SettingsTable, + UsageLogsTable, UsersTable, UserAllowedGroupsTable, UserSubscriptionsTable, @@ -524,6 +632,7 @@ var ( ) func init() { + AccountsTable.ForeignKeys[0].RefTable = ProxiesTable AccountsTable.Annotation = &entsql.Annotation{ Table: "accounts", } @@ -551,6 +660,14 @@ func init() { SettingsTable.Annotation = &entsql.Annotation{ Table: "settings", } + UsageLogsTable.ForeignKeys[0].RefTable = AccountsTable + UsageLogsTable.ForeignKeys[1].RefTable = APIKeysTable + UsageLogsTable.ForeignKeys[2].RefTable = GroupsTable + UsageLogsTable.ForeignKeys[3].RefTable = UsersTable + UsageLogsTable.ForeignKeys[4].RefTable = UserSubscriptionsTable + UsageLogsTable.Annotation = &entsql.Annotation{ + Table: "usage_logs", + } UsersTable.Annotation = &entsql.Annotation{ Table: "users", } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 45a6f5a7..9e4359ab 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -19,6 +19,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -40,6 +41,7 @@ const ( TypeProxy = "Proxy" TypeRedeemCode = "RedeemCode" TypeSetting = "Setting" + TypeUsageLog = "UsageLog" TypeUser = "User" TypeUserAllowedGroup = "UserAllowedGroup" TypeUserSubscription = "UserSubscription" @@ -59,8 +61,6 @@ type AccountMutation struct { _type *string credentials *map[string]interface{} extra *map[string]interface{} - proxy_id *int64 - addproxy_id *int64 concurrency *int addconcurrency *int priority *int @@ -79,6 +79,11 @@ type AccountMutation struct { groups map[int64]struct{} removedgroups map[int64]struct{} clearedgroups bool + proxy *int64 + clearedproxy bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool done bool oldValue func(context.Context) (*Account, error) predicates []predicate.Account @@ -485,13 +490,12 @@ func (m *AccountMutation) ResetExtra() { // SetProxyID sets the "proxy_id" field. func (m *AccountMutation) SetProxyID(i int64) { - m.proxy_id = &i - m.addproxy_id = nil + m.proxy = &i } // ProxyID returns the value of the "proxy_id" field in the mutation. func (m *AccountMutation) ProxyID() (r int64, exists bool) { - v := m.proxy_id + v := m.proxy if v == nil { return } @@ -515,28 +519,9 @@ func (m *AccountMutation) OldProxyID(ctx context.Context) (v *int64, err error) return oldValue.ProxyID, nil } -// AddProxyID adds i to the "proxy_id" field. -func (m *AccountMutation) AddProxyID(i int64) { - if m.addproxy_id != nil { - *m.addproxy_id += i - } else { - m.addproxy_id = &i - } -} - -// AddedProxyID returns the value that was added to the "proxy_id" field in this mutation. -func (m *AccountMutation) AddedProxyID() (r int64, exists bool) { - v := m.addproxy_id - if v == nil { - return - } - return *v, true -} - // ClearProxyID clears the value of the "proxy_id" field. func (m *AccountMutation) ClearProxyID() { - m.proxy_id = nil - m.addproxy_id = nil + m.proxy = nil m.clearedFields[account.FieldProxyID] = struct{}{} } @@ -548,8 +533,7 @@ func (m *AccountMutation) ProxyIDCleared() bool { // ResetProxyID resets all changes to the "proxy_id" field. func (m *AccountMutation) ResetProxyID() { - m.proxy_id = nil - m.addproxy_id = nil + m.proxy = nil delete(m.clearedFields, account.FieldProxyID) } @@ -1183,6 +1167,87 @@ func (m *AccountMutation) ResetGroups() { m.removedgroups = nil } +// ClearProxy clears the "proxy" edge to the Proxy entity. +func (m *AccountMutation) ClearProxy() { + m.clearedproxy = true + m.clearedFields[account.FieldProxyID] = struct{}{} +} + +// ProxyCleared reports if the "proxy" edge to the Proxy entity was cleared. +func (m *AccountMutation) ProxyCleared() bool { + return m.ProxyIDCleared() || m.clearedproxy +} + +// ProxyIDs returns the "proxy" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// ProxyID instead. It exists only for internal usage by the builders. +func (m *AccountMutation) ProxyIDs() (ids []int64) { + if id := m.proxy; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetProxy resets all changes to the "proxy" edge. +func (m *AccountMutation) ResetProxy() { + m.proxy = nil + m.clearedproxy = false +} + +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids. +func (m *AccountMutation) AddUsageLogIDs(ids ...int64) { + if m.usage_logs == nil { + m.usage_logs = make(map[int64]struct{}) + } + for i := range ids { + m.usage_logs[ids[i]] = struct{}{} + } +} + +// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity. +func (m *AccountMutation) ClearUsageLogs() { + m.clearedusage_logs = true +} + +// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared. +func (m *AccountMutation) UsageLogsCleared() bool { + return m.clearedusage_logs +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs. +func (m *AccountMutation) RemoveUsageLogIDs(ids ...int64) { + if m.removedusage_logs == nil { + m.removedusage_logs = make(map[int64]struct{}) + } + for i := range ids { + delete(m.usage_logs, ids[i]) + m.removedusage_logs[ids[i]] = struct{}{} + } +} + +// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity. +func (m *AccountMutation) RemovedUsageLogsIDs() (ids []int64) { + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return +} + +// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation. +func (m *AccountMutation) UsageLogsIDs() (ids []int64) { + for id := range m.usage_logs { + ids = append(ids, id) + } + return +} + +// ResetUsageLogs resets all changes to the "usage_logs" edge. +func (m *AccountMutation) ResetUsageLogs() { + m.usage_logs = nil + m.clearedusage_logs = false + m.removedusage_logs = nil +} + // Where appends a list predicates to the AccountMutation builder. func (m *AccountMutation) Where(ps ...predicate.Account) { m.predicates = append(m.predicates, ps...) @@ -1242,7 +1307,7 @@ func (m *AccountMutation) Fields() []string { if m.extra != nil { fields = append(fields, account.FieldExtra) } - if m.proxy_id != nil { + if m.proxy != nil { fields = append(fields, account.FieldProxyID) } if m.concurrency != nil { @@ -1546,9 +1611,6 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error { // this mutation. func (m *AccountMutation) AddedFields() []string { var fields []string - if m.addproxy_id != nil { - fields = append(fields, account.FieldProxyID) - } if m.addconcurrency != nil { fields = append(fields, account.FieldConcurrency) } @@ -1563,8 +1625,6 @@ func (m *AccountMutation) AddedFields() []string { // was not set, or was not defined in the schema. func (m *AccountMutation) AddedField(name string) (ent.Value, bool) { switch name { - case account.FieldProxyID: - return m.AddedProxyID() case account.FieldConcurrency: return m.AddedConcurrency() case account.FieldPriority: @@ -1578,13 +1638,6 @@ func (m *AccountMutation) AddedField(name string) (ent.Value, bool) { // type. func (m *AccountMutation) AddField(name string, value ent.Value) error { switch name { - case account.FieldProxyID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddProxyID(v) - return nil case account.FieldConcurrency: v, ok := value.(int) if !ok { @@ -1758,10 +1811,16 @@ func (m *AccountMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *AccountMutation) AddedEdges() []string { - edges := make([]string, 0, 1) + edges := make([]string, 0, 3) if m.groups != nil { edges = append(edges, account.EdgeGroups) } + if m.proxy != nil { + edges = append(edges, account.EdgeProxy) + } + if m.usage_logs != nil { + edges = append(edges, account.EdgeUsageLogs) + } return edges } @@ -1775,16 +1834,29 @@ func (m *AccountMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case account.EdgeProxy: + if id := m.proxy; id != nil { + return []ent.Value{*id} + } + case account.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.usage_logs)) + for id := range m.usage_logs { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *AccountMutation) RemovedEdges() []string { - edges := make([]string, 0, 1) + edges := make([]string, 0, 3) if m.removedgroups != nil { edges = append(edges, account.EdgeGroups) } + if m.removedusage_logs != nil { + edges = append(edges, account.EdgeUsageLogs) + } return edges } @@ -1798,16 +1870,28 @@ func (m *AccountMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case account.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.removedusage_logs)) + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return ids } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *AccountMutation) ClearedEdges() []string { - edges := make([]string, 0, 1) + edges := make([]string, 0, 3) if m.clearedgroups { edges = append(edges, account.EdgeGroups) } + if m.clearedproxy { + edges = append(edges, account.EdgeProxy) + } + if m.clearedusage_logs { + edges = append(edges, account.EdgeUsageLogs) + } return edges } @@ -1817,6 +1901,10 @@ func (m *AccountMutation) EdgeCleared(name string) bool { switch name { case account.EdgeGroups: return m.clearedgroups + case account.EdgeProxy: + return m.clearedproxy + case account.EdgeUsageLogs: + return m.clearedusage_logs } return false } @@ -1825,6 +1913,9 @@ func (m *AccountMutation) EdgeCleared(name string) bool { // if that edge is not defined in the schema. func (m *AccountMutation) ClearEdge(name string) error { switch name { + case account.EdgeProxy: + m.ClearProxy() + return nil } return fmt.Errorf("unknown Account unique edge %s", name) } @@ -1836,6 +1927,12 @@ func (m *AccountMutation) ResetEdge(name string) error { case account.EdgeGroups: m.ResetGroups() return nil + case account.EdgeProxy: + m.ResetProxy() + return nil + case account.EdgeUsageLogs: + m.ResetUsageLogs() + return nil } return fmt.Errorf("unknown Account edge %s", name) } @@ -2328,23 +2425,26 @@ func (m *AccountGroupMutation) ResetEdge(name string) error { // ApiKeyMutation represents an operation that mutates the ApiKey nodes in the graph. type ApiKeyMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - key *string - name *string - status *string - clearedFields map[string]struct{} - user *int64 - cleareduser bool - group *int64 - clearedgroup bool - done bool - oldValue func(context.Context) (*ApiKey, error) - predicates []predicate.ApiKey + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + key *string + name *string + status *string + clearedFields map[string]struct{} + user *int64 + cleareduser bool + group *int64 + clearedgroup bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + done bool + oldValue func(context.Context) (*ApiKey, error) + predicates []predicate.ApiKey } var _ ent.Mutation = (*ApiKeyMutation)(nil) @@ -2813,6 +2913,60 @@ func (m *ApiKeyMutation) ResetGroup() { m.clearedgroup = false } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids. +func (m *ApiKeyMutation) AddUsageLogIDs(ids ...int64) { + if m.usage_logs == nil { + m.usage_logs = make(map[int64]struct{}) + } + for i := range ids { + m.usage_logs[ids[i]] = struct{}{} + } +} + +// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity. +func (m *ApiKeyMutation) ClearUsageLogs() { + m.clearedusage_logs = true +} + +// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared. +func (m *ApiKeyMutation) UsageLogsCleared() bool { + return m.clearedusage_logs +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs. +func (m *ApiKeyMutation) RemoveUsageLogIDs(ids ...int64) { + if m.removedusage_logs == nil { + m.removedusage_logs = make(map[int64]struct{}) + } + for i := range ids { + delete(m.usage_logs, ids[i]) + m.removedusage_logs[ids[i]] = struct{}{} + } +} + +// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity. +func (m *ApiKeyMutation) RemovedUsageLogsIDs() (ids []int64) { + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return +} + +// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation. +func (m *ApiKeyMutation) UsageLogsIDs() (ids []int64) { + for id := range m.usage_logs { + ids = append(ids, id) + } + return +} + +// ResetUsageLogs resets all changes to the "usage_logs" edge. +func (m *ApiKeyMutation) ResetUsageLogs() { + m.usage_logs = nil + m.clearedusage_logs = false + m.removedusage_logs = nil +} + // Where appends a list predicates to the ApiKeyMutation builder. func (m *ApiKeyMutation) Where(ps ...predicate.ApiKey) { m.predicates = append(m.predicates, ps...) @@ -3083,13 +3237,16 @@ func (m *ApiKeyMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *ApiKeyMutation) AddedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) if m.user != nil { edges = append(edges, apikey.EdgeUser) } if m.group != nil { edges = append(edges, apikey.EdgeGroup) } + if m.usage_logs != nil { + edges = append(edges, apikey.EdgeUsageLogs) + } return edges } @@ -3105,31 +3262,51 @@ func (m *ApiKeyMutation) AddedIDs(name string) []ent.Value { if id := m.group; id != nil { return []ent.Value{*id} } + case apikey.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.usage_logs)) + for id := range m.usage_logs { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *ApiKeyMutation) RemovedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) + if m.removedusage_logs != nil { + edges = append(edges, apikey.EdgeUsageLogs) + } return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. func (m *ApiKeyMutation) RemovedIDs(name string) []ent.Value { + switch name { + case apikey.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.removedusage_logs)) + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return ids + } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *ApiKeyMutation) ClearedEdges() []string { - edges := make([]string, 0, 2) + edges := make([]string, 0, 3) if m.cleareduser { edges = append(edges, apikey.EdgeUser) } if m.clearedgroup { edges = append(edges, apikey.EdgeGroup) } + if m.clearedusage_logs { + edges = append(edges, apikey.EdgeUsageLogs) + } return edges } @@ -3141,6 +3318,8 @@ func (m *ApiKeyMutation) EdgeCleared(name string) bool { return m.cleareduser case apikey.EdgeGroup: return m.clearedgroup + case apikey.EdgeUsageLogs: + return m.clearedusage_logs } return false } @@ -3169,6 +3348,9 @@ func (m *ApiKeyMutation) ResetEdge(name string) error { case apikey.EdgeGroup: m.ResetGroup() return nil + case apikey.EdgeUsageLogs: + m.ResetUsageLogs() + return nil } return fmt.Errorf("unknown ApiKey edge %s", name) } @@ -3176,45 +3358,50 @@ func (m *ApiKeyMutation) ResetEdge(name string) error { // GroupMutation represents an operation that mutates the Group nodes in the graph. type GroupMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - description *string - rate_multiplier *float64 - addrate_multiplier *float64 - is_exclusive *bool - status *string - platform *string - subscription_type *string - daily_limit_usd *float64 - adddaily_limit_usd *float64 - weekly_limit_usd *float64 - addweekly_limit_usd *float64 - monthly_limit_usd *float64 - addmonthly_limit_usd *float64 - clearedFields map[string]struct{} - api_keys map[int64]struct{} - removedapi_keys map[int64]struct{} - clearedapi_keys bool - redeem_codes map[int64]struct{} - removedredeem_codes map[int64]struct{} - clearedredeem_codes bool - subscriptions map[int64]struct{} - removedsubscriptions map[int64]struct{} - clearedsubscriptions bool - accounts map[int64]struct{} - removedaccounts map[int64]struct{} - clearedaccounts bool - allowed_users map[int64]struct{} - removedallowed_users map[int64]struct{} - clearedallowed_users bool - done bool - oldValue func(context.Context) (*Group, error) - predicates []predicate.Group + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + description *string + rate_multiplier *float64 + addrate_multiplier *float64 + is_exclusive *bool + status *string + platform *string + subscription_type *string + daily_limit_usd *float64 + adddaily_limit_usd *float64 + weekly_limit_usd *float64 + addweekly_limit_usd *float64 + monthly_limit_usd *float64 + addmonthly_limit_usd *float64 + default_validity_days *int + adddefault_validity_days *int + clearedFields map[string]struct{} + api_keys map[int64]struct{} + removedapi_keys map[int64]struct{} + clearedapi_keys bool + redeem_codes map[int64]struct{} + removedredeem_codes map[int64]struct{} + clearedredeem_codes bool + subscriptions map[int64]struct{} + removedsubscriptions map[int64]struct{} + clearedsubscriptions bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + accounts map[int64]struct{} + removedaccounts map[int64]struct{} + clearedaccounts bool + allowed_users map[int64]struct{} + removedallowed_users map[int64]struct{} + clearedallowed_users bool + done bool + oldValue func(context.Context) (*Group, error) + predicates []predicate.Group } var _ ent.Mutation = (*GroupMutation)(nil) @@ -3931,6 +4118,62 @@ func (m *GroupMutation) ResetMonthlyLimitUsd() { delete(m.clearedFields, group.FieldMonthlyLimitUsd) } +// SetDefaultValidityDays sets the "default_validity_days" field. +func (m *GroupMutation) SetDefaultValidityDays(i int) { + m.default_validity_days = &i + m.adddefault_validity_days = nil +} + +// DefaultValidityDays returns the value of the "default_validity_days" field in the mutation. +func (m *GroupMutation) DefaultValidityDays() (r int, exists bool) { + v := m.default_validity_days + if v == nil { + return + } + return *v, true +} + +// OldDefaultValidityDays returns the old "default_validity_days" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldDefaultValidityDays(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDefaultValidityDays is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDefaultValidityDays requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDefaultValidityDays: %w", err) + } + return oldValue.DefaultValidityDays, nil +} + +// AddDefaultValidityDays adds i to the "default_validity_days" field. +func (m *GroupMutation) AddDefaultValidityDays(i int) { + if m.adddefault_validity_days != nil { + *m.adddefault_validity_days += i + } else { + m.adddefault_validity_days = &i + } +} + +// AddedDefaultValidityDays returns the value that was added to the "default_validity_days" field in this mutation. +func (m *GroupMutation) AddedDefaultValidityDays() (r int, exists bool) { + v := m.adddefault_validity_days + if v == nil { + return + } + return *v, true +} + +// ResetDefaultValidityDays resets all changes to the "default_validity_days" field. +func (m *GroupMutation) ResetDefaultValidityDays() { + m.default_validity_days = nil + m.adddefault_validity_days = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the ApiKey entity by ids. func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -4093,6 +4336,60 @@ func (m *GroupMutation) ResetSubscriptions() { m.removedsubscriptions = nil } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids. +func (m *GroupMutation) AddUsageLogIDs(ids ...int64) { + if m.usage_logs == nil { + m.usage_logs = make(map[int64]struct{}) + } + for i := range ids { + m.usage_logs[ids[i]] = struct{}{} + } +} + +// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity. +func (m *GroupMutation) ClearUsageLogs() { + m.clearedusage_logs = true +} + +// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared. +func (m *GroupMutation) UsageLogsCleared() bool { + return m.clearedusage_logs +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs. +func (m *GroupMutation) RemoveUsageLogIDs(ids ...int64) { + if m.removedusage_logs == nil { + m.removedusage_logs = make(map[int64]struct{}) + } + for i := range ids { + delete(m.usage_logs, ids[i]) + m.removedusage_logs[ids[i]] = struct{}{} + } +} + +// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity. +func (m *GroupMutation) RemovedUsageLogsIDs() (ids []int64) { + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return +} + +// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation. +func (m *GroupMutation) UsageLogsIDs() (ids []int64) { + for id := range m.usage_logs { + ids = append(ids, id) + } + return +} + +// ResetUsageLogs resets all changes to the "usage_logs" edge. +func (m *GroupMutation) ResetUsageLogs() { + m.usage_logs = nil + m.clearedusage_logs = false + m.removedusage_logs = nil +} + // AddAccountIDs adds the "accounts" edge to the Account entity by ids. func (m *GroupMutation) AddAccountIDs(ids ...int64) { if m.accounts == nil { @@ -4235,7 +4532,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 13) + fields := make([]string, 0, 14) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -4275,6 +4572,9 @@ func (m *GroupMutation) Fields() []string { if m.monthly_limit_usd != nil { fields = append(fields, group.FieldMonthlyLimitUsd) } + if m.default_validity_days != nil { + fields = append(fields, group.FieldDefaultValidityDays) + } return fields } @@ -4309,6 +4609,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.WeeklyLimitUsd() case group.FieldMonthlyLimitUsd: return m.MonthlyLimitUsd() + case group.FieldDefaultValidityDays: + return m.DefaultValidityDays() } return nil, false } @@ -4344,6 +4646,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldWeeklyLimitUsd(ctx) case group.FieldMonthlyLimitUsd: return m.OldMonthlyLimitUsd(ctx) + case group.FieldDefaultValidityDays: + return m.OldDefaultValidityDays(ctx) } return nil, fmt.Errorf("unknown Group field %s", name) } @@ -4444,6 +4748,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetMonthlyLimitUsd(v) return nil + case group.FieldDefaultValidityDays: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDefaultValidityDays(v) + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -4464,6 +4775,9 @@ func (m *GroupMutation) AddedFields() []string { if m.addmonthly_limit_usd != nil { fields = append(fields, group.FieldMonthlyLimitUsd) } + if m.adddefault_validity_days != nil { + fields = append(fields, group.FieldDefaultValidityDays) + } return fields } @@ -4480,6 +4794,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedWeeklyLimitUsd() case group.FieldMonthlyLimitUsd: return m.AddedMonthlyLimitUsd() + case group.FieldDefaultValidityDays: + return m.AddedDefaultValidityDays() } return nil, false } @@ -4517,6 +4833,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddMonthlyLimitUsd(v) return nil + case group.FieldDefaultValidityDays: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDefaultValidityDays(v) + return nil } return fmt.Errorf("unknown Group numeric field %s", name) } @@ -4616,13 +4939,16 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldMonthlyLimitUsd: m.ResetMonthlyLimitUsd() return nil + case group.FieldDefaultValidityDays: + m.ResetDefaultValidityDays() + return nil } return fmt.Errorf("unknown Group field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. func (m *GroupMutation) AddedEdges() []string { - edges := make([]string, 0, 5) + edges := make([]string, 0, 6) if m.api_keys != nil { edges = append(edges, group.EdgeAPIKeys) } @@ -4632,6 +4958,9 @@ func (m *GroupMutation) AddedEdges() []string { if m.subscriptions != nil { edges = append(edges, group.EdgeSubscriptions) } + if m.usage_logs != nil { + edges = append(edges, group.EdgeUsageLogs) + } if m.accounts != nil { edges = append(edges, group.EdgeAccounts) } @@ -4663,6 +4992,12 @@ func (m *GroupMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case group.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.usage_logs)) + for id := range m.usage_logs { + ids = append(ids, id) + } + return ids case group.EdgeAccounts: ids := make([]ent.Value, 0, len(m.accounts)) for id := range m.accounts { @@ -4681,7 +5016,7 @@ func (m *GroupMutation) AddedIDs(name string) []ent.Value { // RemovedEdges returns all edge names that were removed in this mutation. func (m *GroupMutation) RemovedEdges() []string { - edges := make([]string, 0, 5) + edges := make([]string, 0, 6) if m.removedapi_keys != nil { edges = append(edges, group.EdgeAPIKeys) } @@ -4691,6 +5026,9 @@ func (m *GroupMutation) RemovedEdges() []string { if m.removedsubscriptions != nil { edges = append(edges, group.EdgeSubscriptions) } + if m.removedusage_logs != nil { + edges = append(edges, group.EdgeUsageLogs) + } if m.removedaccounts != nil { edges = append(edges, group.EdgeAccounts) } @@ -4722,6 +5060,12 @@ func (m *GroupMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case group.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.removedusage_logs)) + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return ids case group.EdgeAccounts: ids := make([]ent.Value, 0, len(m.removedaccounts)) for id := range m.removedaccounts { @@ -4740,7 +5084,7 @@ func (m *GroupMutation) RemovedIDs(name string) []ent.Value { // ClearedEdges returns all edge names that were cleared in this mutation. func (m *GroupMutation) ClearedEdges() []string { - edges := make([]string, 0, 5) + edges := make([]string, 0, 6) if m.clearedapi_keys { edges = append(edges, group.EdgeAPIKeys) } @@ -4750,6 +5094,9 @@ func (m *GroupMutation) ClearedEdges() []string { if m.clearedsubscriptions { edges = append(edges, group.EdgeSubscriptions) } + if m.clearedusage_logs { + edges = append(edges, group.EdgeUsageLogs) + } if m.clearedaccounts { edges = append(edges, group.EdgeAccounts) } @@ -4769,6 +5116,8 @@ func (m *GroupMutation) EdgeCleared(name string) bool { return m.clearedredeem_codes case group.EdgeSubscriptions: return m.clearedsubscriptions + case group.EdgeUsageLogs: + return m.clearedusage_logs case group.EdgeAccounts: return m.clearedaccounts case group.EdgeAllowedUsers: @@ -4798,6 +5147,9 @@ func (m *GroupMutation) ResetEdge(name string) error { case group.EdgeSubscriptions: m.ResetSubscriptions() return nil + case group.EdgeUsageLogs: + m.ResetUsageLogs() + return nil case group.EdgeAccounts: m.ResetAccounts() return nil @@ -4811,24 +5163,27 @@ func (m *GroupMutation) ResetEdge(name string) error { // ProxyMutation represents an operation that mutates the Proxy nodes in the graph. type ProxyMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - protocol *string - host *string - port *int - addport *int - username *string - password *string - status *string - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*Proxy, error) - predicates []predicate.Proxy + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + protocol *string + host *string + port *int + addport *int + username *string + password *string + status *string + clearedFields map[string]struct{} + accounts map[int64]struct{} + removedaccounts map[int64]struct{} + clearedaccounts bool + done bool + oldValue func(context.Context) (*Proxy, error) + predicates []predicate.Proxy } var _ ent.Mutation = (*ProxyMutation)(nil) @@ -5348,6 +5703,60 @@ func (m *ProxyMutation) ResetStatus() { m.status = nil } +// AddAccountIDs adds the "accounts" edge to the Account entity by ids. +func (m *ProxyMutation) AddAccountIDs(ids ...int64) { + if m.accounts == nil { + m.accounts = make(map[int64]struct{}) + } + for i := range ids { + m.accounts[ids[i]] = struct{}{} + } +} + +// ClearAccounts clears the "accounts" edge to the Account entity. +func (m *ProxyMutation) ClearAccounts() { + m.clearedaccounts = true +} + +// AccountsCleared reports if the "accounts" edge to the Account entity was cleared. +func (m *ProxyMutation) AccountsCleared() bool { + return m.clearedaccounts +} + +// RemoveAccountIDs removes the "accounts" edge to the Account entity by IDs. +func (m *ProxyMutation) RemoveAccountIDs(ids ...int64) { + if m.removedaccounts == nil { + m.removedaccounts = make(map[int64]struct{}) + } + for i := range ids { + delete(m.accounts, ids[i]) + m.removedaccounts[ids[i]] = struct{}{} + } +} + +// RemovedAccounts returns the removed IDs of the "accounts" edge to the Account entity. +func (m *ProxyMutation) RemovedAccountsIDs() (ids []int64) { + for id := range m.removedaccounts { + ids = append(ids, id) + } + return +} + +// AccountsIDs returns the "accounts" edge IDs in the mutation. +func (m *ProxyMutation) AccountsIDs() (ids []int64) { + for id := range m.accounts { + ids = append(ids, id) + } + return +} + +// ResetAccounts resets all changes to the "accounts" edge. +func (m *ProxyMutation) ResetAccounts() { + m.accounts = nil + m.clearedaccounts = false + m.removedaccounts = nil +} + // Where appends a list predicates to the ProxyMutation builder. func (m *ProxyMutation) Where(ps ...predicate.Proxy) { m.predicates = append(m.predicates, ps...) @@ -5670,49 +6079,85 @@ func (m *ProxyMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *ProxyMutation) AddedEdges() []string { - edges := make([]string, 0, 0) + edges := make([]string, 0, 1) + if m.accounts != nil { + edges = append(edges, proxy.EdgeAccounts) + } return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. func (m *ProxyMutation) AddedIDs(name string) []ent.Value { + switch name { + case proxy.EdgeAccounts: + ids := make([]ent.Value, 0, len(m.accounts)) + for id := range m.accounts { + ids = append(ids, id) + } + return ids + } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *ProxyMutation) RemovedEdges() []string { - edges := make([]string, 0, 0) + edges := make([]string, 0, 1) + if m.removedaccounts != nil { + edges = append(edges, proxy.EdgeAccounts) + } return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. func (m *ProxyMutation) RemovedIDs(name string) []ent.Value { + switch name { + case proxy.EdgeAccounts: + ids := make([]ent.Value, 0, len(m.removedaccounts)) + for id := range m.removedaccounts { + ids = append(ids, id) + } + return ids + } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *ProxyMutation) ClearedEdges() []string { - edges := make([]string, 0, 0) + edges := make([]string, 0, 1) + if m.clearedaccounts { + edges = append(edges, proxy.EdgeAccounts) + } return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. func (m *ProxyMutation) EdgeCleared(name string) bool { + switch name { + case proxy.EdgeAccounts: + return m.clearedaccounts + } return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. func (m *ProxyMutation) ClearEdge(name string) error { + switch name { + } return fmt.Errorf("unknown Proxy unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. func (m *ProxyMutation) ResetEdge(name string) error { + switch name { + case proxy.EdgeAccounts: + m.ResetAccounts() + return nil + } return fmt.Errorf("unknown Proxy edge %s", name) } @@ -7223,6 +7668,2478 @@ func (m *SettingMutation) ResetEdge(name string) error { return fmt.Errorf("unknown Setting edge %s", name) } +// UsageLogMutation represents an operation that mutates the UsageLog nodes in the graph. +type UsageLogMutation struct { + config + op Op + typ string + id *int64 + request_id *string + model *string + input_tokens *int + addinput_tokens *int + output_tokens *int + addoutput_tokens *int + cache_creation_tokens *int + addcache_creation_tokens *int + cache_read_tokens *int + addcache_read_tokens *int + cache_creation_5m_tokens *int + addcache_creation_5m_tokens *int + cache_creation_1h_tokens *int + addcache_creation_1h_tokens *int + input_cost *float64 + addinput_cost *float64 + output_cost *float64 + addoutput_cost *float64 + cache_creation_cost *float64 + addcache_creation_cost *float64 + cache_read_cost *float64 + addcache_read_cost *float64 + total_cost *float64 + addtotal_cost *float64 + actual_cost *float64 + addactual_cost *float64 + rate_multiplier *float64 + addrate_multiplier *float64 + billing_type *int8 + addbilling_type *int8 + stream *bool + duration_ms *int + addduration_ms *int + first_token_ms *int + addfirst_token_ms *int + created_at *time.Time + clearedFields map[string]struct{} + user *int64 + cleareduser bool + api_key *int64 + clearedapi_key bool + account *int64 + clearedaccount bool + group *int64 + clearedgroup bool + subscription *int64 + clearedsubscription bool + done bool + oldValue func(context.Context) (*UsageLog, error) + predicates []predicate.UsageLog +} + +var _ ent.Mutation = (*UsageLogMutation)(nil) + +// usagelogOption allows management of the mutation configuration using functional options. +type usagelogOption func(*UsageLogMutation) + +// newUsageLogMutation creates new mutation for the UsageLog entity. +func newUsageLogMutation(c config, op Op, opts ...usagelogOption) *UsageLogMutation { + m := &UsageLogMutation{ + config: c, + op: op, + typ: TypeUsageLog, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withUsageLogID sets the ID field of the mutation. +func withUsageLogID(id int64) usagelogOption { + return func(m *UsageLogMutation) { + var ( + err error + once sync.Once + value *UsageLog + ) + m.oldValue = func(ctx context.Context) (*UsageLog, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().UsageLog.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withUsageLog sets the old UsageLog of the mutation. +func withUsageLog(node *UsageLog) usagelogOption { + return func(m *UsageLogMutation) { + m.oldValue = func(context.Context) (*UsageLog, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m UsageLogMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m UsageLogMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *UsageLogMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *UsageLogMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().UsageLog.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetUserID sets the "user_id" field. +func (m *UsageLogMutation) SetUserID(i int64) { + m.user = &i +} + +// UserID returns the value of the "user_id" field in the mutation. +func (m *UsageLogMutation) UserID() (r int64, exists bool) { + v := m.user + if v == nil { + return + } + return *v, true +} + +// OldUserID returns the old "user_id" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldUserID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUserID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUserID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUserID: %w", err) + } + return oldValue.UserID, nil +} + +// ResetUserID resets all changes to the "user_id" field. +func (m *UsageLogMutation) ResetUserID() { + m.user = nil +} + +// SetAPIKeyID sets the "api_key_id" field. +func (m *UsageLogMutation) SetAPIKeyID(i int64) { + m.api_key = &i +} + +// APIKeyID returns the value of the "api_key_id" field in the mutation. +func (m *UsageLogMutation) APIKeyID() (r int64, exists bool) { + v := m.api_key + if v == nil { + return + } + return *v, true +} + +// OldAPIKeyID returns the old "api_key_id" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldAPIKeyID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAPIKeyID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAPIKeyID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAPIKeyID: %w", err) + } + return oldValue.APIKeyID, nil +} + +// ResetAPIKeyID resets all changes to the "api_key_id" field. +func (m *UsageLogMutation) ResetAPIKeyID() { + m.api_key = nil +} + +// SetAccountID sets the "account_id" field. +func (m *UsageLogMutation) SetAccountID(i int64) { + m.account = &i +} + +// AccountID returns the value of the "account_id" field in the mutation. +func (m *UsageLogMutation) AccountID() (r int64, exists bool) { + v := m.account + if v == nil { + return + } + return *v, true +} + +// OldAccountID returns the old "account_id" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldAccountID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAccountID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAccountID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAccountID: %w", err) + } + return oldValue.AccountID, nil +} + +// ResetAccountID resets all changes to the "account_id" field. +func (m *UsageLogMutation) ResetAccountID() { + m.account = nil +} + +// SetRequestID sets the "request_id" field. +func (m *UsageLogMutation) SetRequestID(s string) { + m.request_id = &s +} + +// RequestID returns the value of the "request_id" field in the mutation. +func (m *UsageLogMutation) RequestID() (r string, exists bool) { + v := m.request_id + if v == nil { + return + } + return *v, true +} + +// OldRequestID returns the old "request_id" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldRequestID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRequestID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRequestID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRequestID: %w", err) + } + return oldValue.RequestID, nil +} + +// ResetRequestID resets all changes to the "request_id" field. +func (m *UsageLogMutation) ResetRequestID() { + m.request_id = nil +} + +// SetModel sets the "model" field. +func (m *UsageLogMutation) SetModel(s string) { + m.model = &s +} + +// Model returns the value of the "model" field in the mutation. +func (m *UsageLogMutation) Model() (r string, exists bool) { + v := m.model + if v == nil { + return + } + return *v, true +} + +// OldModel returns the old "model" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldModel(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldModel is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldModel requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldModel: %w", err) + } + return oldValue.Model, nil +} + +// ResetModel resets all changes to the "model" field. +func (m *UsageLogMutation) ResetModel() { + m.model = nil +} + +// SetGroupID sets the "group_id" field. +func (m *UsageLogMutation) SetGroupID(i int64) { + m.group = &i +} + +// GroupID returns the value of the "group_id" field in the mutation. +func (m *UsageLogMutation) GroupID() (r int64, exists bool) { + v := m.group + if v == nil { + return + } + return *v, true +} + +// OldGroupID returns the old "group_id" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldGroupID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldGroupID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldGroupID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldGroupID: %w", err) + } + return oldValue.GroupID, nil +} + +// ClearGroupID clears the value of the "group_id" field. +func (m *UsageLogMutation) ClearGroupID() { + m.group = nil + m.clearedFields[usagelog.FieldGroupID] = struct{}{} +} + +// GroupIDCleared returns if the "group_id" field was cleared in this mutation. +func (m *UsageLogMutation) GroupIDCleared() bool { + _, ok := m.clearedFields[usagelog.FieldGroupID] + return ok +} + +// ResetGroupID resets all changes to the "group_id" field. +func (m *UsageLogMutation) ResetGroupID() { + m.group = nil + delete(m.clearedFields, usagelog.FieldGroupID) +} + +// SetSubscriptionID sets the "subscription_id" field. +func (m *UsageLogMutation) SetSubscriptionID(i int64) { + m.subscription = &i +} + +// SubscriptionID returns the value of the "subscription_id" field in the mutation. +func (m *UsageLogMutation) SubscriptionID() (r int64, exists bool) { + v := m.subscription + if v == nil { + return + } + return *v, true +} + +// OldSubscriptionID returns the old "subscription_id" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldSubscriptionID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSubscriptionID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSubscriptionID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSubscriptionID: %w", err) + } + return oldValue.SubscriptionID, nil +} + +// ClearSubscriptionID clears the value of the "subscription_id" field. +func (m *UsageLogMutation) ClearSubscriptionID() { + m.subscription = nil + m.clearedFields[usagelog.FieldSubscriptionID] = struct{}{} +} + +// SubscriptionIDCleared returns if the "subscription_id" field was cleared in this mutation. +func (m *UsageLogMutation) SubscriptionIDCleared() bool { + _, ok := m.clearedFields[usagelog.FieldSubscriptionID] + return ok +} + +// ResetSubscriptionID resets all changes to the "subscription_id" field. +func (m *UsageLogMutation) ResetSubscriptionID() { + m.subscription = nil + delete(m.clearedFields, usagelog.FieldSubscriptionID) +} + +// SetInputTokens sets the "input_tokens" field. +func (m *UsageLogMutation) SetInputTokens(i int) { + m.input_tokens = &i + m.addinput_tokens = nil +} + +// InputTokens returns the value of the "input_tokens" field in the mutation. +func (m *UsageLogMutation) InputTokens() (r int, exists bool) { + v := m.input_tokens + if v == nil { + return + } + return *v, true +} + +// OldInputTokens returns the old "input_tokens" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldInputTokens(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldInputTokens is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldInputTokens requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldInputTokens: %w", err) + } + return oldValue.InputTokens, nil +} + +// AddInputTokens adds i to the "input_tokens" field. +func (m *UsageLogMutation) AddInputTokens(i int) { + if m.addinput_tokens != nil { + *m.addinput_tokens += i + } else { + m.addinput_tokens = &i + } +} + +// AddedInputTokens returns the value that was added to the "input_tokens" field in this mutation. +func (m *UsageLogMutation) AddedInputTokens() (r int, exists bool) { + v := m.addinput_tokens + if v == nil { + return + } + return *v, true +} + +// ResetInputTokens resets all changes to the "input_tokens" field. +func (m *UsageLogMutation) ResetInputTokens() { + m.input_tokens = nil + m.addinput_tokens = nil +} + +// SetOutputTokens sets the "output_tokens" field. +func (m *UsageLogMutation) SetOutputTokens(i int) { + m.output_tokens = &i + m.addoutput_tokens = nil +} + +// OutputTokens returns the value of the "output_tokens" field in the mutation. +func (m *UsageLogMutation) OutputTokens() (r int, exists bool) { + v := m.output_tokens + if v == nil { + return + } + return *v, true +} + +// OldOutputTokens returns the old "output_tokens" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldOutputTokens(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOutputTokens is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOutputTokens requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOutputTokens: %w", err) + } + return oldValue.OutputTokens, nil +} + +// AddOutputTokens adds i to the "output_tokens" field. +func (m *UsageLogMutation) AddOutputTokens(i int) { + if m.addoutput_tokens != nil { + *m.addoutput_tokens += i + } else { + m.addoutput_tokens = &i + } +} + +// AddedOutputTokens returns the value that was added to the "output_tokens" field in this mutation. +func (m *UsageLogMutation) AddedOutputTokens() (r int, exists bool) { + v := m.addoutput_tokens + if v == nil { + return + } + return *v, true +} + +// ResetOutputTokens resets all changes to the "output_tokens" field. +func (m *UsageLogMutation) ResetOutputTokens() { + m.output_tokens = nil + m.addoutput_tokens = nil +} + +// SetCacheCreationTokens sets the "cache_creation_tokens" field. +func (m *UsageLogMutation) SetCacheCreationTokens(i int) { + m.cache_creation_tokens = &i + m.addcache_creation_tokens = nil +} + +// CacheCreationTokens returns the value of the "cache_creation_tokens" field in the mutation. +func (m *UsageLogMutation) CacheCreationTokens() (r int, exists bool) { + v := m.cache_creation_tokens + if v == nil { + return + } + return *v, true +} + +// OldCacheCreationTokens returns the old "cache_creation_tokens" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldCacheCreationTokens(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCacheCreationTokens is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCacheCreationTokens requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCacheCreationTokens: %w", err) + } + return oldValue.CacheCreationTokens, nil +} + +// AddCacheCreationTokens adds i to the "cache_creation_tokens" field. +func (m *UsageLogMutation) AddCacheCreationTokens(i int) { + if m.addcache_creation_tokens != nil { + *m.addcache_creation_tokens += i + } else { + m.addcache_creation_tokens = &i + } +} + +// AddedCacheCreationTokens returns the value that was added to the "cache_creation_tokens" field in this mutation. +func (m *UsageLogMutation) AddedCacheCreationTokens() (r int, exists bool) { + v := m.addcache_creation_tokens + if v == nil { + return + } + return *v, true +} + +// ResetCacheCreationTokens resets all changes to the "cache_creation_tokens" field. +func (m *UsageLogMutation) ResetCacheCreationTokens() { + m.cache_creation_tokens = nil + m.addcache_creation_tokens = nil +} + +// SetCacheReadTokens sets the "cache_read_tokens" field. +func (m *UsageLogMutation) SetCacheReadTokens(i int) { + m.cache_read_tokens = &i + m.addcache_read_tokens = nil +} + +// CacheReadTokens returns the value of the "cache_read_tokens" field in the mutation. +func (m *UsageLogMutation) CacheReadTokens() (r int, exists bool) { + v := m.cache_read_tokens + if v == nil { + return + } + return *v, true +} + +// OldCacheReadTokens returns the old "cache_read_tokens" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldCacheReadTokens(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCacheReadTokens is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCacheReadTokens requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCacheReadTokens: %w", err) + } + return oldValue.CacheReadTokens, nil +} + +// AddCacheReadTokens adds i to the "cache_read_tokens" field. +func (m *UsageLogMutation) AddCacheReadTokens(i int) { + if m.addcache_read_tokens != nil { + *m.addcache_read_tokens += i + } else { + m.addcache_read_tokens = &i + } +} + +// AddedCacheReadTokens returns the value that was added to the "cache_read_tokens" field in this mutation. +func (m *UsageLogMutation) AddedCacheReadTokens() (r int, exists bool) { + v := m.addcache_read_tokens + if v == nil { + return + } + return *v, true +} + +// ResetCacheReadTokens resets all changes to the "cache_read_tokens" field. +func (m *UsageLogMutation) ResetCacheReadTokens() { + m.cache_read_tokens = nil + m.addcache_read_tokens = nil +} + +// SetCacheCreation5mTokens sets the "cache_creation_5m_tokens" field. +func (m *UsageLogMutation) SetCacheCreation5mTokens(i int) { + m.cache_creation_5m_tokens = &i + m.addcache_creation_5m_tokens = nil +} + +// CacheCreation5mTokens returns the value of the "cache_creation_5m_tokens" field in the mutation. +func (m *UsageLogMutation) CacheCreation5mTokens() (r int, exists bool) { + v := m.cache_creation_5m_tokens + if v == nil { + return + } + return *v, true +} + +// OldCacheCreation5mTokens returns the old "cache_creation_5m_tokens" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldCacheCreation5mTokens(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCacheCreation5mTokens is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCacheCreation5mTokens requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCacheCreation5mTokens: %w", err) + } + return oldValue.CacheCreation5mTokens, nil +} + +// AddCacheCreation5mTokens adds i to the "cache_creation_5m_tokens" field. +func (m *UsageLogMutation) AddCacheCreation5mTokens(i int) { + if m.addcache_creation_5m_tokens != nil { + *m.addcache_creation_5m_tokens += i + } else { + m.addcache_creation_5m_tokens = &i + } +} + +// AddedCacheCreation5mTokens returns the value that was added to the "cache_creation_5m_tokens" field in this mutation. +func (m *UsageLogMutation) AddedCacheCreation5mTokens() (r int, exists bool) { + v := m.addcache_creation_5m_tokens + if v == nil { + return + } + return *v, true +} + +// ResetCacheCreation5mTokens resets all changes to the "cache_creation_5m_tokens" field. +func (m *UsageLogMutation) ResetCacheCreation5mTokens() { + m.cache_creation_5m_tokens = nil + m.addcache_creation_5m_tokens = nil +} + +// SetCacheCreation1hTokens sets the "cache_creation_1h_tokens" field. +func (m *UsageLogMutation) SetCacheCreation1hTokens(i int) { + m.cache_creation_1h_tokens = &i + m.addcache_creation_1h_tokens = nil +} + +// CacheCreation1hTokens returns the value of the "cache_creation_1h_tokens" field in the mutation. +func (m *UsageLogMutation) CacheCreation1hTokens() (r int, exists bool) { + v := m.cache_creation_1h_tokens + if v == nil { + return + } + return *v, true +} + +// OldCacheCreation1hTokens returns the old "cache_creation_1h_tokens" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldCacheCreation1hTokens(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCacheCreation1hTokens is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCacheCreation1hTokens requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCacheCreation1hTokens: %w", err) + } + return oldValue.CacheCreation1hTokens, nil +} + +// AddCacheCreation1hTokens adds i to the "cache_creation_1h_tokens" field. +func (m *UsageLogMutation) AddCacheCreation1hTokens(i int) { + if m.addcache_creation_1h_tokens != nil { + *m.addcache_creation_1h_tokens += i + } else { + m.addcache_creation_1h_tokens = &i + } +} + +// AddedCacheCreation1hTokens returns the value that was added to the "cache_creation_1h_tokens" field in this mutation. +func (m *UsageLogMutation) AddedCacheCreation1hTokens() (r int, exists bool) { + v := m.addcache_creation_1h_tokens + if v == nil { + return + } + return *v, true +} + +// ResetCacheCreation1hTokens resets all changes to the "cache_creation_1h_tokens" field. +func (m *UsageLogMutation) ResetCacheCreation1hTokens() { + m.cache_creation_1h_tokens = nil + m.addcache_creation_1h_tokens = nil +} + +// SetInputCost sets the "input_cost" field. +func (m *UsageLogMutation) SetInputCost(f float64) { + m.input_cost = &f + m.addinput_cost = nil +} + +// InputCost returns the value of the "input_cost" field in the mutation. +func (m *UsageLogMutation) InputCost() (r float64, exists bool) { + v := m.input_cost + if v == nil { + return + } + return *v, true +} + +// OldInputCost returns the old "input_cost" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldInputCost(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldInputCost is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldInputCost requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldInputCost: %w", err) + } + return oldValue.InputCost, nil +} + +// AddInputCost adds f to the "input_cost" field. +func (m *UsageLogMutation) AddInputCost(f float64) { + if m.addinput_cost != nil { + *m.addinput_cost += f + } else { + m.addinput_cost = &f + } +} + +// AddedInputCost returns the value that was added to the "input_cost" field in this mutation. +func (m *UsageLogMutation) AddedInputCost() (r float64, exists bool) { + v := m.addinput_cost + if v == nil { + return + } + return *v, true +} + +// ResetInputCost resets all changes to the "input_cost" field. +func (m *UsageLogMutation) ResetInputCost() { + m.input_cost = nil + m.addinput_cost = nil +} + +// SetOutputCost sets the "output_cost" field. +func (m *UsageLogMutation) SetOutputCost(f float64) { + m.output_cost = &f + m.addoutput_cost = nil +} + +// OutputCost returns the value of the "output_cost" field in the mutation. +func (m *UsageLogMutation) OutputCost() (r float64, exists bool) { + v := m.output_cost + if v == nil { + return + } + return *v, true +} + +// OldOutputCost returns the old "output_cost" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldOutputCost(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOutputCost is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOutputCost requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOutputCost: %w", err) + } + return oldValue.OutputCost, nil +} + +// AddOutputCost adds f to the "output_cost" field. +func (m *UsageLogMutation) AddOutputCost(f float64) { + if m.addoutput_cost != nil { + *m.addoutput_cost += f + } else { + m.addoutput_cost = &f + } +} + +// AddedOutputCost returns the value that was added to the "output_cost" field in this mutation. +func (m *UsageLogMutation) AddedOutputCost() (r float64, exists bool) { + v := m.addoutput_cost + if v == nil { + return + } + return *v, true +} + +// ResetOutputCost resets all changes to the "output_cost" field. +func (m *UsageLogMutation) ResetOutputCost() { + m.output_cost = nil + m.addoutput_cost = nil +} + +// SetCacheCreationCost sets the "cache_creation_cost" field. +func (m *UsageLogMutation) SetCacheCreationCost(f float64) { + m.cache_creation_cost = &f + m.addcache_creation_cost = nil +} + +// CacheCreationCost returns the value of the "cache_creation_cost" field in the mutation. +func (m *UsageLogMutation) CacheCreationCost() (r float64, exists bool) { + v := m.cache_creation_cost + if v == nil { + return + } + return *v, true +} + +// OldCacheCreationCost returns the old "cache_creation_cost" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldCacheCreationCost(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCacheCreationCost is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCacheCreationCost requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCacheCreationCost: %w", err) + } + return oldValue.CacheCreationCost, nil +} + +// AddCacheCreationCost adds f to the "cache_creation_cost" field. +func (m *UsageLogMutation) AddCacheCreationCost(f float64) { + if m.addcache_creation_cost != nil { + *m.addcache_creation_cost += f + } else { + m.addcache_creation_cost = &f + } +} + +// AddedCacheCreationCost returns the value that was added to the "cache_creation_cost" field in this mutation. +func (m *UsageLogMutation) AddedCacheCreationCost() (r float64, exists bool) { + v := m.addcache_creation_cost + if v == nil { + return + } + return *v, true +} + +// ResetCacheCreationCost resets all changes to the "cache_creation_cost" field. +func (m *UsageLogMutation) ResetCacheCreationCost() { + m.cache_creation_cost = nil + m.addcache_creation_cost = nil +} + +// SetCacheReadCost sets the "cache_read_cost" field. +func (m *UsageLogMutation) SetCacheReadCost(f float64) { + m.cache_read_cost = &f + m.addcache_read_cost = nil +} + +// CacheReadCost returns the value of the "cache_read_cost" field in the mutation. +func (m *UsageLogMutation) CacheReadCost() (r float64, exists bool) { + v := m.cache_read_cost + if v == nil { + return + } + return *v, true +} + +// OldCacheReadCost returns the old "cache_read_cost" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldCacheReadCost(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCacheReadCost is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCacheReadCost requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCacheReadCost: %w", err) + } + return oldValue.CacheReadCost, nil +} + +// AddCacheReadCost adds f to the "cache_read_cost" field. +func (m *UsageLogMutation) AddCacheReadCost(f float64) { + if m.addcache_read_cost != nil { + *m.addcache_read_cost += f + } else { + m.addcache_read_cost = &f + } +} + +// AddedCacheReadCost returns the value that was added to the "cache_read_cost" field in this mutation. +func (m *UsageLogMutation) AddedCacheReadCost() (r float64, exists bool) { + v := m.addcache_read_cost + if v == nil { + return + } + return *v, true +} + +// ResetCacheReadCost resets all changes to the "cache_read_cost" field. +func (m *UsageLogMutation) ResetCacheReadCost() { + m.cache_read_cost = nil + m.addcache_read_cost = nil +} + +// SetTotalCost sets the "total_cost" field. +func (m *UsageLogMutation) SetTotalCost(f float64) { + m.total_cost = &f + m.addtotal_cost = nil +} + +// TotalCost returns the value of the "total_cost" field in the mutation. +func (m *UsageLogMutation) TotalCost() (r float64, exists bool) { + v := m.total_cost + if v == nil { + return + } + return *v, true +} + +// OldTotalCost returns the old "total_cost" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldTotalCost(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTotalCost is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTotalCost requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTotalCost: %w", err) + } + return oldValue.TotalCost, nil +} + +// AddTotalCost adds f to the "total_cost" field. +func (m *UsageLogMutation) AddTotalCost(f float64) { + if m.addtotal_cost != nil { + *m.addtotal_cost += f + } else { + m.addtotal_cost = &f + } +} + +// AddedTotalCost returns the value that was added to the "total_cost" field in this mutation. +func (m *UsageLogMutation) AddedTotalCost() (r float64, exists bool) { + v := m.addtotal_cost + if v == nil { + return + } + return *v, true +} + +// ResetTotalCost resets all changes to the "total_cost" field. +func (m *UsageLogMutation) ResetTotalCost() { + m.total_cost = nil + m.addtotal_cost = nil +} + +// SetActualCost sets the "actual_cost" field. +func (m *UsageLogMutation) SetActualCost(f float64) { + m.actual_cost = &f + m.addactual_cost = nil +} + +// ActualCost returns the value of the "actual_cost" field in the mutation. +func (m *UsageLogMutation) ActualCost() (r float64, exists bool) { + v := m.actual_cost + if v == nil { + return + } + return *v, true +} + +// OldActualCost returns the old "actual_cost" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldActualCost(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldActualCost is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldActualCost requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldActualCost: %w", err) + } + return oldValue.ActualCost, nil +} + +// AddActualCost adds f to the "actual_cost" field. +func (m *UsageLogMutation) AddActualCost(f float64) { + if m.addactual_cost != nil { + *m.addactual_cost += f + } else { + m.addactual_cost = &f + } +} + +// AddedActualCost returns the value that was added to the "actual_cost" field in this mutation. +func (m *UsageLogMutation) AddedActualCost() (r float64, exists bool) { + v := m.addactual_cost + if v == nil { + return + } + return *v, true +} + +// ResetActualCost resets all changes to the "actual_cost" field. +func (m *UsageLogMutation) ResetActualCost() { + m.actual_cost = nil + m.addactual_cost = nil +} + +// SetRateMultiplier sets the "rate_multiplier" field. +func (m *UsageLogMutation) SetRateMultiplier(f float64) { + m.rate_multiplier = &f + m.addrate_multiplier = nil +} + +// RateMultiplier returns the value of the "rate_multiplier" field in the mutation. +func (m *UsageLogMutation) RateMultiplier() (r float64, exists bool) { + v := m.rate_multiplier + if v == nil { + return + } + return *v, true +} + +// OldRateMultiplier returns the old "rate_multiplier" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldRateMultiplier(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRateMultiplier is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRateMultiplier requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRateMultiplier: %w", err) + } + return oldValue.RateMultiplier, nil +} + +// AddRateMultiplier adds f to the "rate_multiplier" field. +func (m *UsageLogMutation) AddRateMultiplier(f float64) { + if m.addrate_multiplier != nil { + *m.addrate_multiplier += f + } else { + m.addrate_multiplier = &f + } +} + +// AddedRateMultiplier returns the value that was added to the "rate_multiplier" field in this mutation. +func (m *UsageLogMutation) AddedRateMultiplier() (r float64, exists bool) { + v := m.addrate_multiplier + if v == nil { + return + } + return *v, true +} + +// ResetRateMultiplier resets all changes to the "rate_multiplier" field. +func (m *UsageLogMutation) ResetRateMultiplier() { + m.rate_multiplier = nil + m.addrate_multiplier = nil +} + +// SetBillingType sets the "billing_type" field. +func (m *UsageLogMutation) SetBillingType(i int8) { + m.billing_type = &i + m.addbilling_type = nil +} + +// BillingType returns the value of the "billing_type" field in the mutation. +func (m *UsageLogMutation) BillingType() (r int8, exists bool) { + v := m.billing_type + if v == nil { + return + } + return *v, true +} + +// OldBillingType returns the old "billing_type" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldBillingType(ctx context.Context) (v int8, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBillingType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBillingType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBillingType: %w", err) + } + return oldValue.BillingType, nil +} + +// AddBillingType adds i to the "billing_type" field. +func (m *UsageLogMutation) AddBillingType(i int8) { + if m.addbilling_type != nil { + *m.addbilling_type += i + } else { + m.addbilling_type = &i + } +} + +// AddedBillingType returns the value that was added to the "billing_type" field in this mutation. +func (m *UsageLogMutation) AddedBillingType() (r int8, exists bool) { + v := m.addbilling_type + if v == nil { + return + } + return *v, true +} + +// ResetBillingType resets all changes to the "billing_type" field. +func (m *UsageLogMutation) ResetBillingType() { + m.billing_type = nil + m.addbilling_type = nil +} + +// SetStream sets the "stream" field. +func (m *UsageLogMutation) SetStream(b bool) { + m.stream = &b +} + +// Stream returns the value of the "stream" field in the mutation. +func (m *UsageLogMutation) Stream() (r bool, exists bool) { + v := m.stream + if v == nil { + return + } + return *v, true +} + +// OldStream returns the old "stream" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldStream(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStream is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStream requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStream: %w", err) + } + return oldValue.Stream, nil +} + +// ResetStream resets all changes to the "stream" field. +func (m *UsageLogMutation) ResetStream() { + m.stream = nil +} + +// SetDurationMs sets the "duration_ms" field. +func (m *UsageLogMutation) SetDurationMs(i int) { + m.duration_ms = &i + m.addduration_ms = nil +} + +// DurationMs returns the value of the "duration_ms" field in the mutation. +func (m *UsageLogMutation) DurationMs() (r int, exists bool) { + v := m.duration_ms + if v == nil { + return + } + return *v, true +} + +// OldDurationMs returns the old "duration_ms" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldDurationMs(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDurationMs is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDurationMs requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDurationMs: %w", err) + } + return oldValue.DurationMs, nil +} + +// AddDurationMs adds i to the "duration_ms" field. +func (m *UsageLogMutation) AddDurationMs(i int) { + if m.addduration_ms != nil { + *m.addduration_ms += i + } else { + m.addduration_ms = &i + } +} + +// AddedDurationMs returns the value that was added to the "duration_ms" field in this mutation. +func (m *UsageLogMutation) AddedDurationMs() (r int, exists bool) { + v := m.addduration_ms + if v == nil { + return + } + return *v, true +} + +// ClearDurationMs clears the value of the "duration_ms" field. +func (m *UsageLogMutation) ClearDurationMs() { + m.duration_ms = nil + m.addduration_ms = nil + m.clearedFields[usagelog.FieldDurationMs] = struct{}{} +} + +// DurationMsCleared returns if the "duration_ms" field was cleared in this mutation. +func (m *UsageLogMutation) DurationMsCleared() bool { + _, ok := m.clearedFields[usagelog.FieldDurationMs] + return ok +} + +// ResetDurationMs resets all changes to the "duration_ms" field. +func (m *UsageLogMutation) ResetDurationMs() { + m.duration_ms = nil + m.addduration_ms = nil + delete(m.clearedFields, usagelog.FieldDurationMs) +} + +// SetFirstTokenMs sets the "first_token_ms" field. +func (m *UsageLogMutation) SetFirstTokenMs(i int) { + m.first_token_ms = &i + m.addfirst_token_ms = nil +} + +// FirstTokenMs returns the value of the "first_token_ms" field in the mutation. +func (m *UsageLogMutation) FirstTokenMs() (r int, exists bool) { + v := m.first_token_ms + if v == nil { + return + } + return *v, true +} + +// OldFirstTokenMs returns the old "first_token_ms" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldFirstTokenMs(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFirstTokenMs is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFirstTokenMs requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFirstTokenMs: %w", err) + } + return oldValue.FirstTokenMs, nil +} + +// AddFirstTokenMs adds i to the "first_token_ms" field. +func (m *UsageLogMutation) AddFirstTokenMs(i int) { + if m.addfirst_token_ms != nil { + *m.addfirst_token_ms += i + } else { + m.addfirst_token_ms = &i + } +} + +// AddedFirstTokenMs returns the value that was added to the "first_token_ms" field in this mutation. +func (m *UsageLogMutation) AddedFirstTokenMs() (r int, exists bool) { + v := m.addfirst_token_ms + if v == nil { + return + } + return *v, true +} + +// ClearFirstTokenMs clears the value of the "first_token_ms" field. +func (m *UsageLogMutation) ClearFirstTokenMs() { + m.first_token_ms = nil + m.addfirst_token_ms = nil + m.clearedFields[usagelog.FieldFirstTokenMs] = struct{}{} +} + +// FirstTokenMsCleared returns if the "first_token_ms" field was cleared in this mutation. +func (m *UsageLogMutation) FirstTokenMsCleared() bool { + _, ok := m.clearedFields[usagelog.FieldFirstTokenMs] + return ok +} + +// ResetFirstTokenMs resets all changes to the "first_token_ms" field. +func (m *UsageLogMutation) ResetFirstTokenMs() { + m.first_token_ms = nil + m.addfirst_token_ms = nil + delete(m.clearedFields, usagelog.FieldFirstTokenMs) +} + +// SetCreatedAt sets the "created_at" field. +func (m *UsageLogMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *UsageLogMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *UsageLogMutation) ResetCreatedAt() { + m.created_at = nil +} + +// ClearUser clears the "user" edge to the User entity. +func (m *UsageLogMutation) ClearUser() { + m.cleareduser = true + m.clearedFields[usagelog.FieldUserID] = struct{}{} +} + +// UserCleared reports if the "user" edge to the User entity was cleared. +func (m *UsageLogMutation) UserCleared() bool { + return m.cleareduser +} + +// UserIDs returns the "user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// UserID instead. It exists only for internal usage by the builders. +func (m *UsageLogMutation) UserIDs() (ids []int64) { + if id := m.user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetUser resets all changes to the "user" edge. +func (m *UsageLogMutation) ResetUser() { + m.user = nil + m.cleareduser = false +} + +// ClearAPIKey clears the "api_key" edge to the ApiKey entity. +func (m *UsageLogMutation) ClearAPIKey() { + m.clearedapi_key = true + m.clearedFields[usagelog.FieldAPIKeyID] = struct{}{} +} + +// APIKeyCleared reports if the "api_key" edge to the ApiKey entity was cleared. +func (m *UsageLogMutation) APIKeyCleared() bool { + return m.clearedapi_key +} + +// APIKeyIDs returns the "api_key" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// APIKeyID instead. It exists only for internal usage by the builders. +func (m *UsageLogMutation) APIKeyIDs() (ids []int64) { + if id := m.api_key; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetAPIKey resets all changes to the "api_key" edge. +func (m *UsageLogMutation) ResetAPIKey() { + m.api_key = nil + m.clearedapi_key = false +} + +// ClearAccount clears the "account" edge to the Account entity. +func (m *UsageLogMutation) ClearAccount() { + m.clearedaccount = true + m.clearedFields[usagelog.FieldAccountID] = struct{}{} +} + +// AccountCleared reports if the "account" edge to the Account entity was cleared. +func (m *UsageLogMutation) AccountCleared() bool { + return m.clearedaccount +} + +// AccountIDs returns the "account" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// AccountID instead. It exists only for internal usage by the builders. +func (m *UsageLogMutation) AccountIDs() (ids []int64) { + if id := m.account; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetAccount resets all changes to the "account" edge. +func (m *UsageLogMutation) ResetAccount() { + m.account = nil + m.clearedaccount = false +} + +// ClearGroup clears the "group" edge to the Group entity. +func (m *UsageLogMutation) ClearGroup() { + m.clearedgroup = true + m.clearedFields[usagelog.FieldGroupID] = struct{}{} +} + +// GroupCleared reports if the "group" edge to the Group entity was cleared. +func (m *UsageLogMutation) GroupCleared() bool { + return m.GroupIDCleared() || m.clearedgroup +} + +// GroupIDs returns the "group" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// GroupID instead. It exists only for internal usage by the builders. +func (m *UsageLogMutation) GroupIDs() (ids []int64) { + if id := m.group; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetGroup resets all changes to the "group" edge. +func (m *UsageLogMutation) ResetGroup() { + m.group = nil + m.clearedgroup = false +} + +// ClearSubscription clears the "subscription" edge to the UserSubscription entity. +func (m *UsageLogMutation) ClearSubscription() { + m.clearedsubscription = true + m.clearedFields[usagelog.FieldSubscriptionID] = struct{}{} +} + +// SubscriptionCleared reports if the "subscription" edge to the UserSubscription entity was cleared. +func (m *UsageLogMutation) SubscriptionCleared() bool { + return m.SubscriptionIDCleared() || m.clearedsubscription +} + +// SubscriptionIDs returns the "subscription" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// SubscriptionID instead. It exists only for internal usage by the builders. +func (m *UsageLogMutation) SubscriptionIDs() (ids []int64) { + if id := m.subscription; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetSubscription resets all changes to the "subscription" edge. +func (m *UsageLogMutation) ResetSubscription() { + m.subscription = nil + m.clearedsubscription = false +} + +// Where appends a list predicates to the UsageLogMutation builder. +func (m *UsageLogMutation) Where(ps ...predicate.UsageLog) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the UsageLogMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *UsageLogMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.UsageLog, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *UsageLogMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *UsageLogMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (UsageLog). +func (m *UsageLogMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *UsageLogMutation) Fields() []string { + fields := make([]string, 0, 25) + if m.user != nil { + fields = append(fields, usagelog.FieldUserID) + } + if m.api_key != nil { + fields = append(fields, usagelog.FieldAPIKeyID) + } + if m.account != nil { + fields = append(fields, usagelog.FieldAccountID) + } + if m.request_id != nil { + fields = append(fields, usagelog.FieldRequestID) + } + if m.model != nil { + fields = append(fields, usagelog.FieldModel) + } + if m.group != nil { + fields = append(fields, usagelog.FieldGroupID) + } + if m.subscription != nil { + fields = append(fields, usagelog.FieldSubscriptionID) + } + if m.input_tokens != nil { + fields = append(fields, usagelog.FieldInputTokens) + } + if m.output_tokens != nil { + fields = append(fields, usagelog.FieldOutputTokens) + } + if m.cache_creation_tokens != nil { + fields = append(fields, usagelog.FieldCacheCreationTokens) + } + if m.cache_read_tokens != nil { + fields = append(fields, usagelog.FieldCacheReadTokens) + } + if m.cache_creation_5m_tokens != nil { + fields = append(fields, usagelog.FieldCacheCreation5mTokens) + } + if m.cache_creation_1h_tokens != nil { + fields = append(fields, usagelog.FieldCacheCreation1hTokens) + } + if m.input_cost != nil { + fields = append(fields, usagelog.FieldInputCost) + } + if m.output_cost != nil { + fields = append(fields, usagelog.FieldOutputCost) + } + if m.cache_creation_cost != nil { + fields = append(fields, usagelog.FieldCacheCreationCost) + } + if m.cache_read_cost != nil { + fields = append(fields, usagelog.FieldCacheReadCost) + } + if m.total_cost != nil { + fields = append(fields, usagelog.FieldTotalCost) + } + if m.actual_cost != nil { + fields = append(fields, usagelog.FieldActualCost) + } + if m.rate_multiplier != nil { + fields = append(fields, usagelog.FieldRateMultiplier) + } + if m.billing_type != nil { + fields = append(fields, usagelog.FieldBillingType) + } + if m.stream != nil { + fields = append(fields, usagelog.FieldStream) + } + if m.duration_ms != nil { + fields = append(fields, usagelog.FieldDurationMs) + } + if m.first_token_ms != nil { + fields = append(fields, usagelog.FieldFirstTokenMs) + } + if m.created_at != nil { + fields = append(fields, usagelog.FieldCreatedAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { + switch name { + case usagelog.FieldUserID: + return m.UserID() + case usagelog.FieldAPIKeyID: + return m.APIKeyID() + case usagelog.FieldAccountID: + return m.AccountID() + case usagelog.FieldRequestID: + return m.RequestID() + case usagelog.FieldModel: + return m.Model() + case usagelog.FieldGroupID: + return m.GroupID() + case usagelog.FieldSubscriptionID: + return m.SubscriptionID() + case usagelog.FieldInputTokens: + return m.InputTokens() + case usagelog.FieldOutputTokens: + return m.OutputTokens() + case usagelog.FieldCacheCreationTokens: + return m.CacheCreationTokens() + case usagelog.FieldCacheReadTokens: + return m.CacheReadTokens() + case usagelog.FieldCacheCreation5mTokens: + return m.CacheCreation5mTokens() + case usagelog.FieldCacheCreation1hTokens: + return m.CacheCreation1hTokens() + case usagelog.FieldInputCost: + return m.InputCost() + case usagelog.FieldOutputCost: + return m.OutputCost() + case usagelog.FieldCacheCreationCost: + return m.CacheCreationCost() + case usagelog.FieldCacheReadCost: + return m.CacheReadCost() + case usagelog.FieldTotalCost: + return m.TotalCost() + case usagelog.FieldActualCost: + return m.ActualCost() + case usagelog.FieldRateMultiplier: + return m.RateMultiplier() + case usagelog.FieldBillingType: + return m.BillingType() + case usagelog.FieldStream: + return m.Stream() + case usagelog.FieldDurationMs: + return m.DurationMs() + case usagelog.FieldFirstTokenMs: + return m.FirstTokenMs() + case usagelog.FieldCreatedAt: + return m.CreatedAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case usagelog.FieldUserID: + return m.OldUserID(ctx) + case usagelog.FieldAPIKeyID: + return m.OldAPIKeyID(ctx) + case usagelog.FieldAccountID: + return m.OldAccountID(ctx) + case usagelog.FieldRequestID: + return m.OldRequestID(ctx) + case usagelog.FieldModel: + return m.OldModel(ctx) + case usagelog.FieldGroupID: + return m.OldGroupID(ctx) + case usagelog.FieldSubscriptionID: + return m.OldSubscriptionID(ctx) + case usagelog.FieldInputTokens: + return m.OldInputTokens(ctx) + case usagelog.FieldOutputTokens: + return m.OldOutputTokens(ctx) + case usagelog.FieldCacheCreationTokens: + return m.OldCacheCreationTokens(ctx) + case usagelog.FieldCacheReadTokens: + return m.OldCacheReadTokens(ctx) + case usagelog.FieldCacheCreation5mTokens: + return m.OldCacheCreation5mTokens(ctx) + case usagelog.FieldCacheCreation1hTokens: + return m.OldCacheCreation1hTokens(ctx) + case usagelog.FieldInputCost: + return m.OldInputCost(ctx) + case usagelog.FieldOutputCost: + return m.OldOutputCost(ctx) + case usagelog.FieldCacheCreationCost: + return m.OldCacheCreationCost(ctx) + case usagelog.FieldCacheReadCost: + return m.OldCacheReadCost(ctx) + case usagelog.FieldTotalCost: + return m.OldTotalCost(ctx) + case usagelog.FieldActualCost: + return m.OldActualCost(ctx) + case usagelog.FieldRateMultiplier: + return m.OldRateMultiplier(ctx) + case usagelog.FieldBillingType: + return m.OldBillingType(ctx) + case usagelog.FieldStream: + return m.OldStream(ctx) + case usagelog.FieldDurationMs: + return m.OldDurationMs(ctx) + case usagelog.FieldFirstTokenMs: + return m.OldFirstTokenMs(ctx) + case usagelog.FieldCreatedAt: + return m.OldCreatedAt(ctx) + } + return nil, fmt.Errorf("unknown UsageLog field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *UsageLogMutation) SetField(name string, value ent.Value) error { + switch name { + case usagelog.FieldUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserID(v) + return nil + case usagelog.FieldAPIKeyID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAPIKeyID(v) + return nil + case usagelog.FieldAccountID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAccountID(v) + return nil + case usagelog.FieldRequestID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequestID(v) + return nil + case usagelog.FieldModel: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetModel(v) + return nil + case usagelog.FieldGroupID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetGroupID(v) + return nil + case usagelog.FieldSubscriptionID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSubscriptionID(v) + return nil + case usagelog.FieldInputTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetInputTokens(v) + return nil + case usagelog.FieldOutputTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOutputTokens(v) + return nil + case usagelog.FieldCacheCreationTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCacheCreationTokens(v) + return nil + case usagelog.FieldCacheReadTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCacheReadTokens(v) + return nil + case usagelog.FieldCacheCreation5mTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCacheCreation5mTokens(v) + return nil + case usagelog.FieldCacheCreation1hTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCacheCreation1hTokens(v) + return nil + case usagelog.FieldInputCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetInputCost(v) + return nil + case usagelog.FieldOutputCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOutputCost(v) + return nil + case usagelog.FieldCacheCreationCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCacheCreationCost(v) + return nil + case usagelog.FieldCacheReadCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCacheReadCost(v) + return nil + case usagelog.FieldTotalCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotalCost(v) + return nil + case usagelog.FieldActualCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetActualCost(v) + return nil + case usagelog.FieldRateMultiplier: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRateMultiplier(v) + return nil + case usagelog.FieldBillingType: + v, ok := value.(int8) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBillingType(v) + return nil + case usagelog.FieldStream: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStream(v) + return nil + case usagelog.FieldDurationMs: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDurationMs(v) + return nil + case usagelog.FieldFirstTokenMs: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFirstTokenMs(v) + return nil + case usagelog.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + } + return fmt.Errorf("unknown UsageLog field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *UsageLogMutation) AddedFields() []string { + var fields []string + if m.addinput_tokens != nil { + fields = append(fields, usagelog.FieldInputTokens) + } + if m.addoutput_tokens != nil { + fields = append(fields, usagelog.FieldOutputTokens) + } + if m.addcache_creation_tokens != nil { + fields = append(fields, usagelog.FieldCacheCreationTokens) + } + if m.addcache_read_tokens != nil { + fields = append(fields, usagelog.FieldCacheReadTokens) + } + if m.addcache_creation_5m_tokens != nil { + fields = append(fields, usagelog.FieldCacheCreation5mTokens) + } + if m.addcache_creation_1h_tokens != nil { + fields = append(fields, usagelog.FieldCacheCreation1hTokens) + } + if m.addinput_cost != nil { + fields = append(fields, usagelog.FieldInputCost) + } + if m.addoutput_cost != nil { + fields = append(fields, usagelog.FieldOutputCost) + } + if m.addcache_creation_cost != nil { + fields = append(fields, usagelog.FieldCacheCreationCost) + } + if m.addcache_read_cost != nil { + fields = append(fields, usagelog.FieldCacheReadCost) + } + if m.addtotal_cost != nil { + fields = append(fields, usagelog.FieldTotalCost) + } + if m.addactual_cost != nil { + fields = append(fields, usagelog.FieldActualCost) + } + if m.addrate_multiplier != nil { + fields = append(fields, usagelog.FieldRateMultiplier) + } + if m.addbilling_type != nil { + fields = append(fields, usagelog.FieldBillingType) + } + if m.addduration_ms != nil { + fields = append(fields, usagelog.FieldDurationMs) + } + if m.addfirst_token_ms != nil { + fields = append(fields, usagelog.FieldFirstTokenMs) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case usagelog.FieldInputTokens: + return m.AddedInputTokens() + case usagelog.FieldOutputTokens: + return m.AddedOutputTokens() + case usagelog.FieldCacheCreationTokens: + return m.AddedCacheCreationTokens() + case usagelog.FieldCacheReadTokens: + return m.AddedCacheReadTokens() + case usagelog.FieldCacheCreation5mTokens: + return m.AddedCacheCreation5mTokens() + case usagelog.FieldCacheCreation1hTokens: + return m.AddedCacheCreation1hTokens() + case usagelog.FieldInputCost: + return m.AddedInputCost() + case usagelog.FieldOutputCost: + return m.AddedOutputCost() + case usagelog.FieldCacheCreationCost: + return m.AddedCacheCreationCost() + case usagelog.FieldCacheReadCost: + return m.AddedCacheReadCost() + case usagelog.FieldTotalCost: + return m.AddedTotalCost() + case usagelog.FieldActualCost: + return m.AddedActualCost() + case usagelog.FieldRateMultiplier: + return m.AddedRateMultiplier() + case usagelog.FieldBillingType: + return m.AddedBillingType() + case usagelog.FieldDurationMs: + return m.AddedDurationMs() + case usagelog.FieldFirstTokenMs: + return m.AddedFirstTokenMs() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *UsageLogMutation) AddField(name string, value ent.Value) error { + switch name { + case usagelog.FieldInputTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddInputTokens(v) + return nil + case usagelog.FieldOutputTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddOutputTokens(v) + return nil + case usagelog.FieldCacheCreationTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCacheCreationTokens(v) + return nil + case usagelog.FieldCacheReadTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCacheReadTokens(v) + return nil + case usagelog.FieldCacheCreation5mTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCacheCreation5mTokens(v) + return nil + case usagelog.FieldCacheCreation1hTokens: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCacheCreation1hTokens(v) + return nil + case usagelog.FieldInputCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddInputCost(v) + return nil + case usagelog.FieldOutputCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddOutputCost(v) + return nil + case usagelog.FieldCacheCreationCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCacheCreationCost(v) + return nil + case usagelog.FieldCacheReadCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCacheReadCost(v) + return nil + case usagelog.FieldTotalCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddTotalCost(v) + return nil + case usagelog.FieldActualCost: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddActualCost(v) + return nil + case usagelog.FieldRateMultiplier: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRateMultiplier(v) + return nil + case usagelog.FieldBillingType: + v, ok := value.(int8) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddBillingType(v) + return nil + case usagelog.FieldDurationMs: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDurationMs(v) + return nil + case usagelog.FieldFirstTokenMs: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFirstTokenMs(v) + return nil + } + return fmt.Errorf("unknown UsageLog numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *UsageLogMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(usagelog.FieldGroupID) { + fields = append(fields, usagelog.FieldGroupID) + } + if m.FieldCleared(usagelog.FieldSubscriptionID) { + fields = append(fields, usagelog.FieldSubscriptionID) + } + if m.FieldCleared(usagelog.FieldDurationMs) { + fields = append(fields, usagelog.FieldDurationMs) + } + if m.FieldCleared(usagelog.FieldFirstTokenMs) { + fields = append(fields, usagelog.FieldFirstTokenMs) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *UsageLogMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *UsageLogMutation) ClearField(name string) error { + switch name { + case usagelog.FieldGroupID: + m.ClearGroupID() + return nil + case usagelog.FieldSubscriptionID: + m.ClearSubscriptionID() + return nil + case usagelog.FieldDurationMs: + m.ClearDurationMs() + return nil + case usagelog.FieldFirstTokenMs: + m.ClearFirstTokenMs() + return nil + } + return fmt.Errorf("unknown UsageLog nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *UsageLogMutation) ResetField(name string) error { + switch name { + case usagelog.FieldUserID: + m.ResetUserID() + return nil + case usagelog.FieldAPIKeyID: + m.ResetAPIKeyID() + return nil + case usagelog.FieldAccountID: + m.ResetAccountID() + return nil + case usagelog.FieldRequestID: + m.ResetRequestID() + return nil + case usagelog.FieldModel: + m.ResetModel() + return nil + case usagelog.FieldGroupID: + m.ResetGroupID() + return nil + case usagelog.FieldSubscriptionID: + m.ResetSubscriptionID() + return nil + case usagelog.FieldInputTokens: + m.ResetInputTokens() + return nil + case usagelog.FieldOutputTokens: + m.ResetOutputTokens() + return nil + case usagelog.FieldCacheCreationTokens: + m.ResetCacheCreationTokens() + return nil + case usagelog.FieldCacheReadTokens: + m.ResetCacheReadTokens() + return nil + case usagelog.FieldCacheCreation5mTokens: + m.ResetCacheCreation5mTokens() + return nil + case usagelog.FieldCacheCreation1hTokens: + m.ResetCacheCreation1hTokens() + return nil + case usagelog.FieldInputCost: + m.ResetInputCost() + return nil + case usagelog.FieldOutputCost: + m.ResetOutputCost() + return nil + case usagelog.FieldCacheCreationCost: + m.ResetCacheCreationCost() + return nil + case usagelog.FieldCacheReadCost: + m.ResetCacheReadCost() + return nil + case usagelog.FieldTotalCost: + m.ResetTotalCost() + return nil + case usagelog.FieldActualCost: + m.ResetActualCost() + return nil + case usagelog.FieldRateMultiplier: + m.ResetRateMultiplier() + return nil + case usagelog.FieldBillingType: + m.ResetBillingType() + return nil + case usagelog.FieldStream: + m.ResetStream() + return nil + case usagelog.FieldDurationMs: + m.ResetDurationMs() + return nil + case usagelog.FieldFirstTokenMs: + m.ResetFirstTokenMs() + return nil + case usagelog.FieldCreatedAt: + m.ResetCreatedAt() + return nil + } + return fmt.Errorf("unknown UsageLog field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *UsageLogMutation) AddedEdges() []string { + edges := make([]string, 0, 5) + if m.user != nil { + edges = append(edges, usagelog.EdgeUser) + } + if m.api_key != nil { + edges = append(edges, usagelog.EdgeAPIKey) + } + if m.account != nil { + edges = append(edges, usagelog.EdgeAccount) + } + if m.group != nil { + edges = append(edges, usagelog.EdgeGroup) + } + if m.subscription != nil { + edges = append(edges, usagelog.EdgeSubscription) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *UsageLogMutation) AddedIDs(name string) []ent.Value { + switch name { + case usagelog.EdgeUser: + if id := m.user; id != nil { + return []ent.Value{*id} + } + case usagelog.EdgeAPIKey: + if id := m.api_key; id != nil { + return []ent.Value{*id} + } + case usagelog.EdgeAccount: + if id := m.account; id != nil { + return []ent.Value{*id} + } + case usagelog.EdgeGroup: + if id := m.group; id != nil { + return []ent.Value{*id} + } + case usagelog.EdgeSubscription: + if id := m.subscription; id != nil { + return []ent.Value{*id} + } + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *UsageLogMutation) RemovedEdges() []string { + edges := make([]string, 0, 5) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *UsageLogMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *UsageLogMutation) ClearedEdges() []string { + edges := make([]string, 0, 5) + if m.cleareduser { + edges = append(edges, usagelog.EdgeUser) + } + if m.clearedapi_key { + edges = append(edges, usagelog.EdgeAPIKey) + } + if m.clearedaccount { + edges = append(edges, usagelog.EdgeAccount) + } + if m.clearedgroup { + edges = append(edges, usagelog.EdgeGroup) + } + if m.clearedsubscription { + edges = append(edges, usagelog.EdgeSubscription) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *UsageLogMutation) EdgeCleared(name string) bool { + switch name { + case usagelog.EdgeUser: + return m.cleareduser + case usagelog.EdgeAPIKey: + return m.clearedapi_key + case usagelog.EdgeAccount: + return m.clearedaccount + case usagelog.EdgeGroup: + return m.clearedgroup + case usagelog.EdgeSubscription: + return m.clearedsubscription + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *UsageLogMutation) ClearEdge(name string) error { + switch name { + case usagelog.EdgeUser: + m.ClearUser() + return nil + case usagelog.EdgeAPIKey: + m.ClearAPIKey() + return nil + case usagelog.EdgeAccount: + m.ClearAccount() + return nil + case usagelog.EdgeGroup: + m.ClearGroup() + return nil + case usagelog.EdgeSubscription: + m.ClearSubscription() + return nil + } + return fmt.Errorf("unknown UsageLog unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *UsageLogMutation) ResetEdge(name string) error { + switch name { + case usagelog.EdgeUser: + m.ResetUser() + return nil + case usagelog.EdgeAPIKey: + m.ResetAPIKey() + return nil + case usagelog.EdgeAccount: + m.ResetAccount() + return nil + case usagelog.EdgeGroup: + m.ResetGroup() + return nil + case usagelog.EdgeSubscription: + m.ResetSubscription() + return nil + } + return fmt.Errorf("unknown UsageLog edge %s", name) +} + // UserMutation represents an operation that mutates the User nodes in the graph. type UserMutation struct { config @@ -7259,6 +10176,9 @@ type UserMutation struct { allowed_groups map[int64]struct{} removedallowed_groups map[int64]struct{} clearedallowed_groups bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool done bool oldValue func(context.Context) (*User, error) predicates []predicate.User @@ -8117,6 +11037,60 @@ func (m *UserMutation) ResetAllowedGroups() { m.removedallowed_groups = nil } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids. +func (m *UserMutation) AddUsageLogIDs(ids ...int64) { + if m.usage_logs == nil { + m.usage_logs = make(map[int64]struct{}) + } + for i := range ids { + m.usage_logs[ids[i]] = struct{}{} + } +} + +// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity. +func (m *UserMutation) ClearUsageLogs() { + m.clearedusage_logs = true +} + +// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared. +func (m *UserMutation) UsageLogsCleared() bool { + return m.clearedusage_logs +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs. +func (m *UserMutation) RemoveUsageLogIDs(ids ...int64) { + if m.removedusage_logs == nil { + m.removedusage_logs = make(map[int64]struct{}) + } + for i := range ids { + delete(m.usage_logs, ids[i]) + m.removedusage_logs[ids[i]] = struct{}{} + } +} + +// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity. +func (m *UserMutation) RemovedUsageLogsIDs() (ids []int64) { + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return +} + +// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation. +func (m *UserMutation) UsageLogsIDs() (ids []int64) { + for id := range m.usage_logs { + ids = append(ids, id) + } + return +} + +// ResetUsageLogs resets all changes to the "usage_logs" edge. +func (m *UserMutation) ResetUsageLogs() { + m.usage_logs = nil + m.clearedusage_logs = false + m.removedusage_logs = nil +} + // Where appends a list predicates to the UserMutation builder. func (m *UserMutation) Where(ps ...predicate.User) { m.predicates = append(m.predicates, ps...) @@ -8473,7 +11447,7 @@ func (m *UserMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *UserMutation) AddedEdges() []string { - edges := make([]string, 0, 5) + edges := make([]string, 0, 6) if m.api_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -8489,6 +11463,9 @@ func (m *UserMutation) AddedEdges() []string { if m.allowed_groups != nil { edges = append(edges, user.EdgeAllowedGroups) } + if m.usage_logs != nil { + edges = append(edges, user.EdgeUsageLogs) + } return edges } @@ -8526,13 +11503,19 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.usage_logs)) + for id := range m.usage_logs { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *UserMutation) RemovedEdges() []string { - edges := make([]string, 0, 5) + edges := make([]string, 0, 6) if m.removedapi_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -8548,6 +11531,9 @@ func (m *UserMutation) RemovedEdges() []string { if m.removedallowed_groups != nil { edges = append(edges, user.EdgeAllowedGroups) } + if m.removedusage_logs != nil { + edges = append(edges, user.EdgeUsageLogs) + } return edges } @@ -8585,13 +11571,19 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.removedusage_logs)) + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return ids } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *UserMutation) ClearedEdges() []string { - edges := make([]string, 0, 5) + edges := make([]string, 0, 6) if m.clearedapi_keys { edges = append(edges, user.EdgeAPIKeys) } @@ -8607,6 +11599,9 @@ func (m *UserMutation) ClearedEdges() []string { if m.clearedallowed_groups { edges = append(edges, user.EdgeAllowedGroups) } + if m.clearedusage_logs { + edges = append(edges, user.EdgeUsageLogs) + } return edges } @@ -8624,6 +11619,8 @@ func (m *UserMutation) EdgeCleared(name string) bool { return m.clearedassigned_subscriptions case user.EdgeAllowedGroups: return m.clearedallowed_groups + case user.EdgeUsageLogs: + return m.clearedusage_logs } return false } @@ -8655,6 +11652,9 @@ func (m *UserMutation) ResetEdge(name string) error { case user.EdgeAllowedGroups: m.ResetAllowedGroups() return nil + case user.EdgeUsageLogs: + m.ResetUsageLogs() + return nil } return fmt.Errorf("unknown User edge %s", name) } @@ -9084,6 +12084,7 @@ type UserSubscriptionMutation struct { id *int64 created_at *time.Time updated_at *time.Time + deleted_at *time.Time starts_at *time.Time expires_at *time.Time status *string @@ -9105,6 +12106,9 @@ type UserSubscriptionMutation struct { clearedgroup bool assigned_by_user *int64 clearedassigned_by_user bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool done bool oldValue func(context.Context) (*UserSubscription, error) predicates []predicate.UserSubscription @@ -9280,6 +12284,55 @@ func (m *UserSubscriptionMutation) ResetUpdatedAt() { m.updated_at = nil } +// SetDeletedAt sets the "deleted_at" field. +func (m *UserSubscriptionMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *UserSubscriptionMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the UserSubscription entity. +// If the UserSubscription object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserSubscriptionMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *UserSubscriptionMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[usersubscription.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *UserSubscriptionMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[usersubscription.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *UserSubscriptionMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, usersubscription.FieldDeletedAt) +} + // SetUserID sets the "user_id" field. func (m *UserSubscriptionMutation) SetUserID(i int64) { m.user = &i @@ -10003,6 +13056,60 @@ func (m *UserSubscriptionMutation) ResetAssignedByUser() { m.clearedassigned_by_user = false } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids. +func (m *UserSubscriptionMutation) AddUsageLogIDs(ids ...int64) { + if m.usage_logs == nil { + m.usage_logs = make(map[int64]struct{}) + } + for i := range ids { + m.usage_logs[ids[i]] = struct{}{} + } +} + +// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity. +func (m *UserSubscriptionMutation) ClearUsageLogs() { + m.clearedusage_logs = true +} + +// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared. +func (m *UserSubscriptionMutation) UsageLogsCleared() bool { + return m.clearedusage_logs +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs. +func (m *UserSubscriptionMutation) RemoveUsageLogIDs(ids ...int64) { + if m.removedusage_logs == nil { + m.removedusage_logs = make(map[int64]struct{}) + } + for i := range ids { + delete(m.usage_logs, ids[i]) + m.removedusage_logs[ids[i]] = struct{}{} + } +} + +// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity. +func (m *UserSubscriptionMutation) RemovedUsageLogsIDs() (ids []int64) { + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return +} + +// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation. +func (m *UserSubscriptionMutation) UsageLogsIDs() (ids []int64) { + for id := range m.usage_logs { + ids = append(ids, id) + } + return +} + +// ResetUsageLogs resets all changes to the "usage_logs" edge. +func (m *UserSubscriptionMutation) ResetUsageLogs() { + m.usage_logs = nil + m.clearedusage_logs = false + m.removedusage_logs = nil +} + // Where appends a list predicates to the UserSubscriptionMutation builder. func (m *UserSubscriptionMutation) Where(ps ...predicate.UserSubscription) { m.predicates = append(m.predicates, ps...) @@ -10037,13 +13144,16 @@ func (m *UserSubscriptionMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserSubscriptionMutation) Fields() []string { - fields := make([]string, 0, 16) + fields := make([]string, 0, 17) if m.created_at != nil { fields = append(fields, usersubscription.FieldCreatedAt) } if m.updated_at != nil { fields = append(fields, usersubscription.FieldUpdatedAt) } + if m.deleted_at != nil { + fields = append(fields, usersubscription.FieldDeletedAt) + } if m.user != nil { fields = append(fields, usersubscription.FieldUserID) } @@ -10098,6 +13208,8 @@ func (m *UserSubscriptionMutation) Field(name string) (ent.Value, bool) { return m.CreatedAt() case usersubscription.FieldUpdatedAt: return m.UpdatedAt() + case usersubscription.FieldDeletedAt: + return m.DeletedAt() case usersubscription.FieldUserID: return m.UserID() case usersubscription.FieldGroupID: @@ -10139,6 +13251,8 @@ func (m *UserSubscriptionMutation) OldField(ctx context.Context, name string) (e return m.OldCreatedAt(ctx) case usersubscription.FieldUpdatedAt: return m.OldUpdatedAt(ctx) + case usersubscription.FieldDeletedAt: + return m.OldDeletedAt(ctx) case usersubscription.FieldUserID: return m.OldUserID(ctx) case usersubscription.FieldGroupID: @@ -10190,6 +13304,13 @@ func (m *UserSubscriptionMutation) SetField(name string, value ent.Value) error } m.SetUpdatedAt(v) return nil + case usersubscription.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil case usersubscription.FieldUserID: v, ok := value.(int64) if !ok { @@ -10357,6 +13478,9 @@ func (m *UserSubscriptionMutation) AddField(name string, value ent.Value) error // mutation. func (m *UserSubscriptionMutation) ClearedFields() []string { var fields []string + if m.FieldCleared(usersubscription.FieldDeletedAt) { + fields = append(fields, usersubscription.FieldDeletedAt) + } if m.FieldCleared(usersubscription.FieldDailyWindowStart) { fields = append(fields, usersubscription.FieldDailyWindowStart) } @@ -10386,6 +13510,9 @@ func (m *UserSubscriptionMutation) FieldCleared(name string) bool { // error if the field is not defined in the schema. func (m *UserSubscriptionMutation) ClearField(name string) error { switch name { + case usersubscription.FieldDeletedAt: + m.ClearDeletedAt() + return nil case usersubscription.FieldDailyWindowStart: m.ClearDailyWindowStart() return nil @@ -10415,6 +13542,9 @@ func (m *UserSubscriptionMutation) ResetField(name string) error { case usersubscription.FieldUpdatedAt: m.ResetUpdatedAt() return nil + case usersubscription.FieldDeletedAt: + m.ResetDeletedAt() + return nil case usersubscription.FieldUserID: m.ResetUserID() return nil @@ -10463,7 +13593,7 @@ func (m *UserSubscriptionMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *UserSubscriptionMutation) AddedEdges() []string { - edges := make([]string, 0, 3) + edges := make([]string, 0, 4) if m.user != nil { edges = append(edges, usersubscription.EdgeUser) } @@ -10473,6 +13603,9 @@ func (m *UserSubscriptionMutation) AddedEdges() []string { if m.assigned_by_user != nil { edges = append(edges, usersubscription.EdgeAssignedByUser) } + if m.usage_logs != nil { + edges = append(edges, usersubscription.EdgeUsageLogs) + } return edges } @@ -10492,25 +13625,42 @@ func (m *UserSubscriptionMutation) AddedIDs(name string) []ent.Value { if id := m.assigned_by_user; id != nil { return []ent.Value{*id} } + case usersubscription.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.usage_logs)) + for id := range m.usage_logs { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *UserSubscriptionMutation) RemovedEdges() []string { - edges := make([]string, 0, 3) + edges := make([]string, 0, 4) + if m.removedusage_logs != nil { + edges = append(edges, usersubscription.EdgeUsageLogs) + } return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. func (m *UserSubscriptionMutation) RemovedIDs(name string) []ent.Value { + switch name { + case usersubscription.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.removedusage_logs)) + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return ids + } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *UserSubscriptionMutation) ClearedEdges() []string { - edges := make([]string, 0, 3) + edges := make([]string, 0, 4) if m.cleareduser { edges = append(edges, usersubscription.EdgeUser) } @@ -10520,6 +13670,9 @@ func (m *UserSubscriptionMutation) ClearedEdges() []string { if m.clearedassigned_by_user { edges = append(edges, usersubscription.EdgeAssignedByUser) } + if m.clearedusage_logs { + edges = append(edges, usersubscription.EdgeUsageLogs) + } return edges } @@ -10533,6 +13686,8 @@ func (m *UserSubscriptionMutation) EdgeCleared(name string) bool { return m.clearedgroup case usersubscription.EdgeAssignedByUser: return m.clearedassigned_by_user + case usersubscription.EdgeUsageLogs: + return m.clearedusage_logs } return false } @@ -10567,6 +13722,9 @@ func (m *UserSubscriptionMutation) ResetEdge(name string) error { case usersubscription.EdgeAssignedByUser: m.ResetAssignedByUser() return nil + case usersubscription.EdgeUsageLogs: + m.ResetUsageLogs() + return nil } return fmt.Errorf("unknown UserSubscription edge %s", name) } diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index 467dad7b..f6bdf466 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -27,6 +27,9 @@ type RedeemCode func(*sql.Selector) // Setting is the predicate function for setting builders. type Setting func(*sql.Selector) +// UsageLog is the predicate function for usagelog builders. +type UsageLog func(*sql.Selector) + // User is the predicate function for user builders. type User func(*sql.Selector) diff --git a/backend/ent/proxy.go b/backend/ent/proxy.go index eb271c7a..5228b73e 100644 --- a/backend/ent/proxy.go +++ b/backend/ent/proxy.go @@ -36,10 +36,31 @@ type Proxy struct { // Password holds the value of the "password" field. Password *string `json:"password,omitempty"` // Status holds the value of the "status" field. - Status string `json:"status,omitempty"` + Status string `json:"status,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the ProxyQuery when eager-loading is set. + Edges ProxyEdges `json:"edges"` selectValues sql.SelectValues } +// ProxyEdges holds the relations/edges for other nodes in the graph. +type ProxyEdges struct { + // Accounts holds the value of the accounts edge. + Accounts []*Account `json:"accounts,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// AccountsOrErr returns the Accounts value or an error if the edge +// was not loaded in eager-loading. +func (e ProxyEdges) AccountsOrErr() ([]*Account, error) { + if e.loadedTypes[0] { + return e.Accounts, nil + } + return nil, &NotLoadedError{edge: "accounts"} +} + // scanValues returns the types for scanning values from sql.Rows. func (*Proxy) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) @@ -148,6 +169,11 @@ func (_m *Proxy) Value(name string) (ent.Value, error) { return _m.selectValues.Get(name) } +// QueryAccounts queries the "accounts" edge of the Proxy entity. +func (_m *Proxy) QueryAccounts() *AccountQuery { + return NewProxyClient(_m.config).QueryAccounts(_m) +} + // Update returns a builder for updating this Proxy. // Note that you need to call Proxy.Unwrap() before calling this method if this Proxy // was returned from a transaction, and the transaction was committed or rolled back. diff --git a/backend/ent/proxy/proxy.go b/backend/ent/proxy/proxy.go index e5e1067c..db7abcda 100644 --- a/backend/ent/proxy/proxy.go +++ b/backend/ent/proxy/proxy.go @@ -7,6 +7,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" ) const ( @@ -34,8 +35,17 @@ const ( FieldPassword = "password" // FieldStatus holds the string denoting the status field in the database. FieldStatus = "status" + // EdgeAccounts holds the string denoting the accounts edge name in mutations. + EdgeAccounts = "accounts" // Table holds the table name of the proxy in the database. Table = "proxies" + // AccountsTable is the table that holds the accounts relation/edge. + AccountsTable = "accounts" + // AccountsInverseTable is the table name for the Account entity. + // It exists in this package in order to avoid circular dependency with the "account" package. + AccountsInverseTable = "accounts" + // AccountsColumn is the table column denoting the accounts relation/edge. + AccountsColumn = "proxy_id" ) // Columns holds all SQL columns for proxy fields. @@ -150,3 +160,24 @@ func ByPassword(opts ...sql.OrderTermOption) OrderOption { func ByStatus(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldStatus, opts...).ToFunc() } + +// ByAccountsCount orders the results by accounts count. +func ByAccountsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAccountsStep(), opts...) + } +} + +// ByAccounts orders the results by accounts terms. +func ByAccounts(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAccountsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} +func newAccountsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AccountsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, true, AccountsTable, AccountsColumn), + ) +} diff --git a/backend/ent/proxy/where.go b/backend/ent/proxy/where.go index ad92cee6..0a31ad7e 100644 --- a/backend/ent/proxy/where.go +++ b/backend/ent/proxy/where.go @@ -6,6 +6,7 @@ import ( "time" "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "github.com/Wei-Shaw/sub2api/ent/predicate" ) @@ -684,6 +685,29 @@ func StatusContainsFold(v string) predicate.Proxy { return predicate.Proxy(sql.FieldContainsFold(FieldStatus, v)) } +// HasAccounts applies the HasEdge predicate on the "accounts" edge. +func HasAccounts() predicate.Proxy { + return predicate.Proxy(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, true, AccountsTable, AccountsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAccountsWith applies the HasEdge predicate on the "accounts" edge with a given conditions (other predicates). +func HasAccountsWith(preds ...predicate.Account) predicate.Proxy { + return predicate.Proxy(func(s *sql.Selector) { + step := newAccountsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.Proxy) predicate.Proxy { return predicate.Proxy(sql.AndPredicates(predicates...)) diff --git a/backend/ent/proxy_create.go b/backend/ent/proxy_create.go index 386abaec..9687aaa2 100644 --- a/backend/ent/proxy_create.go +++ b/backend/ent/proxy_create.go @@ -11,6 +11,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/account" "github.com/Wei-Shaw/sub2api/ent/proxy" ) @@ -130,6 +131,21 @@ func (_c *ProxyCreate) SetNillableStatus(v *string) *ProxyCreate { return _c } +// AddAccountIDs adds the "accounts" edge to the Account entity by IDs. +func (_c *ProxyCreate) AddAccountIDs(ids ...int64) *ProxyCreate { + _c.mutation.AddAccountIDs(ids...) + return _c +} + +// AddAccounts adds the "accounts" edges to the Account entity. +func (_c *ProxyCreate) AddAccounts(v ...*Account) *ProxyCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddAccountIDs(ids...) +} + // Mutation returns the ProxyMutation object of the builder. func (_c *ProxyCreate) Mutation() *ProxyMutation { return _c.mutation @@ -308,6 +324,22 @@ func (_c *ProxyCreate) createSpec() (*Proxy, *sqlgraph.CreateSpec) { _spec.SetField(proxy.FieldStatus, field.TypeString, value) _node.Status = value } + if nodes := _c.mutation.AccountsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: proxy.AccountsTable, + Columns: []string{proxy.AccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } diff --git a/backend/ent/proxy_query.go b/backend/ent/proxy_query.go index b0599553..1358eed2 100644 --- a/backend/ent/proxy_query.go +++ b/backend/ent/proxy_query.go @@ -4,6 +4,7 @@ package ent import ( "context" + "database/sql/driver" "fmt" "math" @@ -11,6 +12,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/account" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/proxy" ) @@ -18,10 +20,11 @@ import ( // ProxyQuery is the builder for querying Proxy entities. type ProxyQuery struct { config - ctx *QueryContext - order []proxy.OrderOption - inters []Interceptor - predicates []predicate.Proxy + ctx *QueryContext + order []proxy.OrderOption + inters []Interceptor + predicates []predicate.Proxy + withAccounts *AccountQuery // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -58,6 +61,28 @@ func (_q *ProxyQuery) Order(o ...proxy.OrderOption) *ProxyQuery { return _q } +// QueryAccounts chains the current query on the "accounts" edge. +func (_q *ProxyQuery) QueryAccounts() *AccountQuery { + query := (&AccountClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(proxy.Table, proxy.FieldID, selector), + sqlgraph.To(account.Table, account.FieldID), + sqlgraph.Edge(sqlgraph.O2M, true, proxy.AccountsTable, proxy.AccountsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // First returns the first Proxy entity from the query. // Returns a *NotFoundError when no Proxy was found. func (_q *ProxyQuery) First(ctx context.Context) (*Proxy, error) { @@ -245,17 +270,29 @@ func (_q *ProxyQuery) Clone() *ProxyQuery { return nil } return &ProxyQuery{ - config: _q.config, - ctx: _q.ctx.Clone(), - order: append([]proxy.OrderOption{}, _q.order...), - inters: append([]Interceptor{}, _q.inters...), - predicates: append([]predicate.Proxy{}, _q.predicates...), + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]proxy.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.Proxy{}, _q.predicates...), + withAccounts: _q.withAccounts.Clone(), // clone intermediate query. sql: _q.sql.Clone(), path: _q.path, } } +// WithAccounts tells the query-builder to eager-load the nodes that are connected to +// the "accounts" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *ProxyQuery) WithAccounts(opts ...func(*AccountQuery)) *ProxyQuery { + query := (&AccountClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAccounts = query + return _q +} + // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. // @@ -332,8 +369,11 @@ func (_q *ProxyQuery) prepareQuery(ctx context.Context) error { func (_q *ProxyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Proxy, error) { var ( - nodes = []*Proxy{} - _spec = _q.querySpec() + nodes = []*Proxy{} + _spec = _q.querySpec() + loadedTypes = [1]bool{ + _q.withAccounts != nil, + } ) _spec.ScanValues = func(columns []string) ([]any, error) { return (*Proxy).scanValues(nil, columns) @@ -341,6 +381,7 @@ func (_q *ProxyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Proxy, _spec.Assign = func(columns []string, values []any) error { node := &Proxy{config: _q.config} nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes return node.assignValues(columns, values) } for i := range hooks { @@ -352,9 +393,50 @@ func (_q *ProxyQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Proxy, if len(nodes) == 0 { return nodes, nil } + if query := _q.withAccounts; query != nil { + if err := _q.loadAccounts(ctx, query, nodes, + func(n *Proxy) { n.Edges.Accounts = []*Account{} }, + func(n *Proxy, e *Account) { n.Edges.Accounts = append(n.Edges.Accounts, e) }); err != nil { + return nil, err + } + } return nodes, nil } +func (_q *ProxyQuery) loadAccounts(ctx context.Context, query *AccountQuery, nodes []*Proxy, init func(*Proxy), assign func(*Proxy, *Account)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*Proxy) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(account.FieldProxyID) + } + query.Where(predicate.Account(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(proxy.AccountsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.ProxyID + if fk == nil { + return fmt.Errorf(`foreign-key "proxy_id" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "proxy_id" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} + func (_q *ProxyQuery) sqlCount(ctx context.Context) (int, error) { _spec := _q.querySpec() _spec.Node.Columns = _q.ctx.Fields diff --git a/backend/ent/proxy_update.go b/backend/ent/proxy_update.go index 3f5e1a7f..d487857f 100644 --- a/backend/ent/proxy_update.go +++ b/backend/ent/proxy_update.go @@ -11,6 +11,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/account" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/proxy" ) @@ -171,11 +172,47 @@ func (_u *ProxyUpdate) SetNillableStatus(v *string) *ProxyUpdate { return _u } +// AddAccountIDs adds the "accounts" edge to the Account entity by IDs. +func (_u *ProxyUpdate) AddAccountIDs(ids ...int64) *ProxyUpdate { + _u.mutation.AddAccountIDs(ids...) + return _u +} + +// AddAccounts adds the "accounts" edges to the Account entity. +func (_u *ProxyUpdate) AddAccounts(v ...*Account) *ProxyUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAccountIDs(ids...) +} + // Mutation returns the ProxyMutation object of the builder. func (_u *ProxyUpdate) Mutation() *ProxyMutation { return _u.mutation } +// ClearAccounts clears all "accounts" edges to the Account entity. +func (_u *ProxyUpdate) ClearAccounts() *ProxyUpdate { + _u.mutation.ClearAccounts() + return _u +} + +// RemoveAccountIDs removes the "accounts" edge to Account entities by IDs. +func (_u *ProxyUpdate) RemoveAccountIDs(ids ...int64) *ProxyUpdate { + _u.mutation.RemoveAccountIDs(ids...) + return _u +} + +// RemoveAccounts removes "accounts" edges to Account entities. +func (_u *ProxyUpdate) RemoveAccounts(v ...*Account) *ProxyUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAccountIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *ProxyUpdate) Save(ctx context.Context) (int, error) { if err := _u.defaults(); err != nil { @@ -304,6 +341,51 @@ func (_u *ProxyUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.Status(); ok { _spec.SetField(proxy.FieldStatus, field.TypeString, value) } + if _u.mutation.AccountsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: proxy.AccountsTable, + Columns: []string{proxy.AccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAccountsIDs(); len(nodes) > 0 && !_u.mutation.AccountsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: proxy.AccountsTable, + Columns: []string{proxy.AccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AccountsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: proxy.AccountsTable, + Columns: []string{proxy.AccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{proxy.Label} @@ -467,11 +549,47 @@ func (_u *ProxyUpdateOne) SetNillableStatus(v *string) *ProxyUpdateOne { return _u } +// AddAccountIDs adds the "accounts" edge to the Account entity by IDs. +func (_u *ProxyUpdateOne) AddAccountIDs(ids ...int64) *ProxyUpdateOne { + _u.mutation.AddAccountIDs(ids...) + return _u +} + +// AddAccounts adds the "accounts" edges to the Account entity. +func (_u *ProxyUpdateOne) AddAccounts(v ...*Account) *ProxyUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAccountIDs(ids...) +} + // Mutation returns the ProxyMutation object of the builder. func (_u *ProxyUpdateOne) Mutation() *ProxyMutation { return _u.mutation } +// ClearAccounts clears all "accounts" edges to the Account entity. +func (_u *ProxyUpdateOne) ClearAccounts() *ProxyUpdateOne { + _u.mutation.ClearAccounts() + return _u +} + +// RemoveAccountIDs removes the "accounts" edge to Account entities by IDs. +func (_u *ProxyUpdateOne) RemoveAccountIDs(ids ...int64) *ProxyUpdateOne { + _u.mutation.RemoveAccountIDs(ids...) + return _u +} + +// RemoveAccounts removes "accounts" edges to Account entities. +func (_u *ProxyUpdateOne) RemoveAccounts(v ...*Account) *ProxyUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAccountIDs(ids...) +} + // Where appends a list predicates to the ProxyUpdate builder. func (_u *ProxyUpdateOne) Where(ps ...predicate.Proxy) *ProxyUpdateOne { _u.mutation.Where(ps...) @@ -630,6 +748,51 @@ func (_u *ProxyUpdateOne) sqlSave(ctx context.Context) (_node *Proxy, err error) if value, ok := _u.mutation.Status(); ok { _spec.SetField(proxy.FieldStatus, field.TypeString, value) } + if _u.mutation.AccountsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: proxy.AccountsTable, + Columns: []string{proxy.AccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAccountsIDs(); len(nodes) > 0 && !_u.mutation.AccountsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: proxy.AccountsTable, + Columns: []string{proxy.AccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AccountsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: true, + Table: proxy.AccountsTable, + Columns: []string{proxy.AccountsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &Proxy{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index ef5e6bec..da0accd7 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -13,6 +13,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/schema" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -259,6 +260,10 @@ func init() { group.DefaultSubscriptionType = groupDescSubscriptionType.Default.(string) // group.SubscriptionTypeValidator is a validator for the "subscription_type" field. It is called by the builders before save. group.SubscriptionTypeValidator = groupDescSubscriptionType.Validators[0].(func(string) error) + // groupDescDefaultValidityDays is the schema descriptor for default_validity_days field. + groupDescDefaultValidityDays := groupFields[10].Descriptor() + // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. + group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) proxyMixin := schema.Proxy{}.Mixin() proxyMixinHooks1 := proxyMixin[1].Hooks() proxy.Hooks[0] = proxyMixinHooks1[0] @@ -420,6 +425,108 @@ func init() { setting.DefaultUpdatedAt = settingDescUpdatedAt.Default.(func() time.Time) // setting.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. setting.UpdateDefaultUpdatedAt = settingDescUpdatedAt.UpdateDefault.(func() time.Time) + usagelogFields := schema.UsageLog{}.Fields() + _ = usagelogFields + // usagelogDescRequestID is the schema descriptor for request_id field. + usagelogDescRequestID := usagelogFields[3].Descriptor() + // usagelog.RequestIDValidator is a validator for the "request_id" field. It is called by the builders before save. + usagelog.RequestIDValidator = func() func(string) error { + validators := usagelogDescRequestID.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(request_id string) error { + for _, fn := range fns { + if err := fn(request_id); err != nil { + return err + } + } + return nil + } + }() + // usagelogDescModel is the schema descriptor for model field. + usagelogDescModel := usagelogFields[4].Descriptor() + // usagelog.ModelValidator is a validator for the "model" field. It is called by the builders before save. + usagelog.ModelValidator = func() func(string) error { + validators := usagelogDescModel.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(model string) error { + for _, fn := range fns { + if err := fn(model); err != nil { + return err + } + } + return nil + } + }() + // usagelogDescInputTokens is the schema descriptor for input_tokens field. + usagelogDescInputTokens := usagelogFields[7].Descriptor() + // usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field. + usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int) + // usagelogDescOutputTokens is the schema descriptor for output_tokens field. + usagelogDescOutputTokens := usagelogFields[8].Descriptor() + // usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field. + usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int) + // usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field. + usagelogDescCacheCreationTokens := usagelogFields[9].Descriptor() + // usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field. + usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int) + // usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field. + usagelogDescCacheReadTokens := usagelogFields[10].Descriptor() + // usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field. + usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int) + // usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field. + usagelogDescCacheCreation5mTokens := usagelogFields[11].Descriptor() + // usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field. + usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int) + // usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field. + usagelogDescCacheCreation1hTokens := usagelogFields[12].Descriptor() + // usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field. + usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int) + // usagelogDescInputCost is the schema descriptor for input_cost field. + usagelogDescInputCost := usagelogFields[13].Descriptor() + // usagelog.DefaultInputCost holds the default value on creation for the input_cost field. + usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64) + // usagelogDescOutputCost is the schema descriptor for output_cost field. + usagelogDescOutputCost := usagelogFields[14].Descriptor() + // usagelog.DefaultOutputCost holds the default value on creation for the output_cost field. + usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64) + // usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field. + usagelogDescCacheCreationCost := usagelogFields[15].Descriptor() + // usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field. + usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64) + // usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field. + usagelogDescCacheReadCost := usagelogFields[16].Descriptor() + // usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field. + usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64) + // usagelogDescTotalCost is the schema descriptor for total_cost field. + usagelogDescTotalCost := usagelogFields[17].Descriptor() + // usagelog.DefaultTotalCost holds the default value on creation for the total_cost field. + usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64) + // usagelogDescActualCost is the schema descriptor for actual_cost field. + usagelogDescActualCost := usagelogFields[18].Descriptor() + // usagelog.DefaultActualCost holds the default value on creation for the actual_cost field. + usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64) + // usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field. + usagelogDescRateMultiplier := usagelogFields[19].Descriptor() + // usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field. + usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64) + // usagelogDescBillingType is the schema descriptor for billing_type field. + usagelogDescBillingType := usagelogFields[20].Descriptor() + // usagelog.DefaultBillingType holds the default value on creation for the billing_type field. + usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8) + // usagelogDescStream is the schema descriptor for stream field. + usagelogDescStream := usagelogFields[21].Descriptor() + // usagelog.DefaultStream holds the default value on creation for the stream field. + usagelog.DefaultStream = usagelogDescStream.Default.(bool) + // usagelogDescCreatedAt is the schema descriptor for created_at field. + usagelogDescCreatedAt := usagelogFields[24].Descriptor() + // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. + usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() userMixinHooks1 := userMixin[1].Hooks() user.Hooks[0] = userMixinHooks1[0] @@ -518,6 +625,10 @@ func init() { // userallowedgroup.DefaultCreatedAt holds the default value on creation for the created_at field. userallowedgroup.DefaultCreatedAt = userallowedgroupDescCreatedAt.Default.(func() time.Time) usersubscriptionMixin := schema.UserSubscription{}.Mixin() + usersubscriptionMixinHooks1 := usersubscriptionMixin[1].Hooks() + usersubscription.Hooks[0] = usersubscriptionMixinHooks1[0] + usersubscriptionMixinInters1 := usersubscriptionMixin[1].Interceptors() + usersubscription.Interceptors[0] = usersubscriptionMixinInters1[0] usersubscriptionMixinFields0 := usersubscriptionMixin[0].Fields() _ = usersubscriptionMixinFields0 usersubscriptionFields := schema.UserSubscription{}.Fields() diff --git a/backend/ent/schema/account.go b/backend/ent/schema/account.go index c1dd64af..2561dc17 100644 --- a/backend/ent/schema/account.go +++ b/backend/ent/schema/account.go @@ -168,6 +168,13 @@ func (Account) Edges() []ent.Edge { // 一个账户可以属于多个分组,一个分组可以包含多个账户 edge.To("groups", Group.Type). Through("account_groups", AccountGroup.Type), + // proxy: 账户使用的代理配置(可选的一对一关系) + // 使用已有的 proxy_id 外键字段 + edge.To("proxy", Proxy.Type). + Field("proxy_id"). + Unique(), + // usage_logs: 该账户的使用日志 + edge.To("usage_logs", UsageLog.Type), } } diff --git a/backend/ent/schema/account_group.go b/backend/ent/schema/account_group.go index 66729752..aa270f08 100644 --- a/backend/ent/schema/account_group.go +++ b/backend/ent/schema/account_group.go @@ -4,6 +4,7 @@ import ( "time" "entgo.io/ent" + "entgo.io/ent/dialect" "entgo.io/ent/dialect/entsql" "entgo.io/ent/schema" "entgo.io/ent/schema/edge" @@ -33,7 +34,8 @@ func (AccountGroup) Fields() []ent.Field { Default(50), field.Time("created_at"). Immutable(). - Default(time.Now), + Default(time.Now). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), } } diff --git a/backend/ent/schema/api_key.go b/backend/ent/schema/api_key.go index 0f0f830e..f9ece05e 100644 --- a/backend/ent/schema/api_key.go +++ b/backend/ent/schema/api_key.go @@ -60,12 +60,13 @@ func (ApiKey) Edges() []ent.Edge { Ref("api_keys"). Field("group_id"). Unique(), + edge.To("usage_logs", UsageLog.Type), } } func (ApiKey) Indexes() []ent.Index { return []ent.Index{ - index.Fields("key").Unique(), + // key 字段已在 Fields() 中声明 Unique(),无需重复索引 index.Fields("user_id"), index.Fields("group_id"), index.Fields("status"), diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 2c30c979..7f3ed167 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -69,6 +69,8 @@ func (Group) Fields() []ent.Field { Optional(). Nillable(). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + field.Int("default_validity_days"). + Default(30), } } @@ -77,6 +79,7 @@ func (Group) Edges() []ent.Edge { edge.To("api_keys", ApiKey.Type), edge.To("redeem_codes", RedeemCode.Type), edge.To("subscriptions", UserSubscription.Type), + edge.To("usage_logs", UsageLog.Type), edge.From("accounts", Account.Type). Ref("groups"). Through("account_groups", AccountGroup.Type), @@ -88,7 +91,7 @@ func (Group) Edges() []ent.Edge { func (Group) Indexes() []ent.Index { return []ent.Index{ - index.Fields("name").Unique(), + // name 字段已在 Fields() 中声明 Unique(),无需重复索引 index.Fields("status"), index.Fields("platform"), index.Fields("subscription_type"), diff --git a/backend/ent/schema/proxy.go b/backend/ent/schema/proxy.go index 45608c96..46d657d3 100644 --- a/backend/ent/schema/proxy.go +++ b/backend/ent/schema/proxy.go @@ -6,6 +6,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/entsql" "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" "entgo.io/ent/schema/index" ) @@ -54,6 +55,15 @@ func (Proxy) Fields() []ent.Field { } } +// Edges 定义代理实体的关联关系。 +func (Proxy) Edges() []ent.Edge { + return []ent.Edge{ + // accounts: 使用此代理的账户(反向边) + edge.From("accounts", Account.Type). + Ref("proxy"), + } +} + func (Proxy) Indexes() []ent.Index { return []ent.Index{ index.Fields("status"), diff --git a/backend/ent/schema/redeem_code.go b/backend/ent/schema/redeem_code.go index 0ecb48b7..b4664e06 100644 --- a/backend/ent/schema/redeem_code.go +++ b/backend/ent/schema/redeem_code.go @@ -15,6 +15,14 @@ import ( ) // RedeemCode holds the schema definition for the RedeemCode entity. +// +// 删除策略:硬删除 +// RedeemCode 使用硬删除而非软删除,原因如下: +// - 兑换码具有一次性使用特性,删除后无需保留历史记录 +// - 已使用的兑换码通过 status 和 used_at 字段追踪,无需依赖软删除 +// - 减少数据库存储压力和查询复杂度 +// +// 如需审计已删除的兑换码,建议在删除前将关键信息写入审计日志表。 type RedeemCode struct { ent.Schema } @@ -78,7 +86,7 @@ func (RedeemCode) Edges() []ent.Edge { func (RedeemCode) Indexes() []ent.Index { return []ent.Index{ - index.Fields("code").Unique(), + // code 字段已在 Fields() 中声明 Unique(),无需重复索引 index.Fields("status"), index.Fields("used_by"), index.Fields("group_id"), diff --git a/backend/ent/schema/setting.go b/backend/ent/schema/setting.go index f31f2a41..3f896fab 100644 --- a/backend/ent/schema/setting.go +++ b/backend/ent/schema/setting.go @@ -8,10 +8,17 @@ import ( "entgo.io/ent/dialect/entsql" "entgo.io/ent/schema" "entgo.io/ent/schema/field" - "entgo.io/ent/schema/index" ) // Setting holds the schema definition for the Setting entity. +// +// 删除策略:硬删除 +// Setting 使用硬删除而非软删除,原因如下: +// - 系统设置是简单的键值对,删除即意味着恢复默认值 +// - 设置变更通常通过应用日志追踪,无需在数据库层面保留历史 +// - 保持表结构简洁,避免无效数据积累 +// +// 如需设置变更审计,建议在更新/删除前将变更记录写入审计日志表。 type Setting struct { ent.Schema } @@ -43,7 +50,6 @@ func (Setting) Fields() []ent.Field { } func (Setting) Indexes() []ent.Index { - return []ent.Index{ - index.Fields("key").Unique(), - } + // key 字段已在 Fields() 中声明 Unique(),无需额外索引 + return nil } diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go new file mode 100644 index 00000000..6f78e8a9 --- /dev/null +++ b/backend/ent/schema/usage_log.go @@ -0,0 +1,152 @@ +// Package schema 定义 Ent ORM 的数据库 schema。 +package schema + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// UsageLog 定义使用日志实体的 schema。 +// +// 使用日志记录每次 API 调用的详细信息,包括 token 使用量、成本计算等。 +// 这是一个只追加的表,不支持更新和删除。 +type UsageLog struct { + ent.Schema +} + +// Annotations 返回 schema 的注解配置。 +func (UsageLog) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "usage_logs"}, + } +} + +// Fields 定义使用日志实体的所有字段。 +func (UsageLog) Fields() []ent.Field { + return []ent.Field{ + // 关联字段 + field.Int64("user_id"), + field.Int64("api_key_id"), + field.Int64("account_id"), + field.String("request_id"). + MaxLen(64). + NotEmpty(), + field.String("model"). + MaxLen(100). + NotEmpty(), + field.Int64("group_id"). + Optional(). + Nillable(), + field.Int64("subscription_id"). + Optional(). + Nillable(), + + // Token 计数字段 + field.Int("input_tokens"). + Default(0), + field.Int("output_tokens"). + Default(0), + field.Int("cache_creation_tokens"). + Default(0), + field.Int("cache_read_tokens"). + Default(0), + field.Int("cache_creation_5m_tokens"). + Default(0), + field.Int("cache_creation_1h_tokens"). + Default(0), + + // 成本字段 + field.Float("input_cost"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("output_cost"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("cache_creation_cost"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("cache_read_cost"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("total_cost"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("actual_cost"). + Default(0). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,10)"}), + field.Float("rate_multiplier"). + Default(1). + SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}), + + // 其他字段 + field.Int8("billing_type"). + Default(0), + field.Bool("stream"). + Default(false), + field.Int("duration_ms"). + Optional(). + Nillable(), + field.Int("first_token_ms"). + Optional(). + Nillable(), + + // 时间戳(只有 created_at,日志不可修改) + field.Time("created_at"). + Default(time.Now). + Immutable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + } +} + +// Edges 定义使用日志实体的关联关系。 +func (UsageLog) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("user", User.Type). + Ref("usage_logs"). + Field("user_id"). + Required(). + Unique(), + edge.From("api_key", ApiKey.Type). + Ref("usage_logs"). + Field("api_key_id"). + Required(). + Unique(), + edge.From("account", Account.Type). + Ref("usage_logs"). + Field("account_id"). + Required(). + Unique(), + edge.From("group", Group.Type). + Ref("usage_logs"). + Field("group_id"). + Unique(), + edge.From("subscription", UserSubscription.Type). + Ref("usage_logs"). + Field("subscription_id"). + Unique(), + } +} + +// Indexes 定义数据库索引,优化查询性能。 +func (UsageLog) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("user_id"), + index.Fields("api_key_id"), + index.Fields("account_id"), + index.Fields("group_id"), + index.Fields("subscription_id"), + index.Fields("created_at"), + index.Fields("model"), + index.Fields("request_id"), + // 复合索引用于时间范围查询 + index.Fields("user_id", "created_at"), + index.Fields("api_key_id", "created_at"), + } +} diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index e76799ed..ba7f0ce7 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -73,12 +73,13 @@ func (User) Edges() []ent.Edge { edge.To("assigned_subscriptions", UserSubscription.Type), edge.To("allowed_groups", Group.Type). Through("user_allowed_groups", UserAllowedGroup.Type), + edge.To("usage_logs", UsageLog.Type), } } func (User) Indexes() []ent.Index { return []ent.Index{ - index.Fields("email").Unique(), + // email 字段已在 Fields() 中声明 Unique(),无需重复索引 index.Fields("status"), index.Fields("deleted_at"), } diff --git a/backend/ent/schema/user_allowed_group.go b/backend/ent/schema/user_allowed_group.go index 8fce97c2..94156219 100644 --- a/backend/ent/schema/user_allowed_group.go +++ b/backend/ent/schema/user_allowed_group.go @@ -4,6 +4,7 @@ import ( "time" "entgo.io/ent" + "entgo.io/ent/dialect" "entgo.io/ent/dialect/entsql" "entgo.io/ent/schema" "entgo.io/ent/schema/edge" @@ -31,7 +32,8 @@ func (UserAllowedGroup) Fields() []ent.Field { field.Int64("group_id"), field.Time("created_at"). Immutable(). - Default(time.Now), + Default(time.Now). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), } } diff --git a/backend/ent/schema/user_subscription.go b/backend/ent/schema/user_subscription.go index bcb0da71..88c4ea8f 100644 --- a/backend/ent/schema/user_subscription.go +++ b/backend/ent/schema/user_subscription.go @@ -29,6 +29,7 @@ func (UserSubscription) Annotations() []schema.Annotation { func (UserSubscription) Mixin() []ent.Mixin { return []ent.Mixin{ mixins.TimeMixin{}, + mixins.SoftDeleteMixin{}, } } @@ -97,6 +98,7 @@ func (UserSubscription) Edges() []ent.Edge { Ref("assigned_subscriptions"). Field("assigned_by"). Unique(), + edge.To("usage_logs", UsageLog.Type), } } @@ -108,5 +110,6 @@ func (UserSubscription) Indexes() []ent.Index { index.Fields("expires_at"), index.Fields("assigned_by"), index.Fields("user_id", "group_id").Unique(), + index.Fields("deleted_at"), } } diff --git a/backend/ent/tx.go b/backend/ent/tx.go index fbb68edf..ecb0409d 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -28,6 +28,8 @@ type Tx struct { RedeemCode *RedeemCodeClient // Setting is the client for interacting with the Setting builders. Setting *SettingClient + // UsageLog is the client for interacting with the UsageLog builders. + UsageLog *UsageLogClient // User is the client for interacting with the User builders. User *UserClient // UserAllowedGroup is the client for interacting with the UserAllowedGroup builders. @@ -172,6 +174,7 @@ func (tx *Tx) init() { tx.Proxy = NewProxyClient(tx.config) tx.RedeemCode = NewRedeemCodeClient(tx.config) tx.Setting = NewSettingClient(tx.config) + tx.UsageLog = NewUsageLogClient(tx.config) tx.User = NewUserClient(tx.config) tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config) tx.UserSubscription = NewUserSubscriptionClient(tx.config) @@ -238,7 +241,6 @@ func (tx *txDriver) Query(ctx context.Context, query string, args, v any) error var _ dialect.Driver = (*txDriver)(nil) -// ExecContext 透传到底层事务,用于在 ent 事务中执行原生 SQL(与 ent 写入保持同一事务)。 // ExecContext allows calling the underlying ExecContext method of the transaction if it is supported by it. // See, database/sql#Tx.ExecContext for more information. func (tx *txDriver) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) { @@ -251,7 +253,6 @@ func (tx *txDriver) ExecContext(ctx context.Context, query string, args ...any) return ex.ExecContext(ctx, query, args...) } -// QueryContext 透传到底层事务,用于在 ent 事务中执行原生查询并共享锁语义。 // QueryContext allows calling the underlying QueryContext method of the transaction if it is supported by it. // See, database/sql#Tx.QueryContext for more information. func (tx *txDriver) QueryContext(ctx context.Context, query string, args ...any) (*stdsql.Rows, error) { diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go new file mode 100644 index 00000000..e01780fe --- /dev/null +++ b/backend/ent/usagelog.go @@ -0,0 +1,491 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/account" + "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/usagelog" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/ent/usersubscription" +) + +// UsageLog is the model entity for the UsageLog schema. +type UsageLog struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // UserID holds the value of the "user_id" field. + UserID int64 `json:"user_id,omitempty"` + // APIKeyID holds the value of the "api_key_id" field. + APIKeyID int64 `json:"api_key_id,omitempty"` + // AccountID holds the value of the "account_id" field. + AccountID int64 `json:"account_id,omitempty"` + // RequestID holds the value of the "request_id" field. + RequestID string `json:"request_id,omitempty"` + // Model holds the value of the "model" field. + Model string `json:"model,omitempty"` + // GroupID holds the value of the "group_id" field. + GroupID *int64 `json:"group_id,omitempty"` + // SubscriptionID holds the value of the "subscription_id" field. + SubscriptionID *int64 `json:"subscription_id,omitempty"` + // InputTokens holds the value of the "input_tokens" field. + InputTokens int `json:"input_tokens,omitempty"` + // OutputTokens holds the value of the "output_tokens" field. + OutputTokens int `json:"output_tokens,omitempty"` + // CacheCreationTokens holds the value of the "cache_creation_tokens" field. + CacheCreationTokens int `json:"cache_creation_tokens,omitempty"` + // CacheReadTokens holds the value of the "cache_read_tokens" field. + CacheReadTokens int `json:"cache_read_tokens,omitempty"` + // CacheCreation5mTokens holds the value of the "cache_creation_5m_tokens" field. + CacheCreation5mTokens int `json:"cache_creation_5m_tokens,omitempty"` + // CacheCreation1hTokens holds the value of the "cache_creation_1h_tokens" field. + CacheCreation1hTokens int `json:"cache_creation_1h_tokens,omitempty"` + // InputCost holds the value of the "input_cost" field. + InputCost float64 `json:"input_cost,omitempty"` + // OutputCost holds the value of the "output_cost" field. + OutputCost float64 `json:"output_cost,omitempty"` + // CacheCreationCost holds the value of the "cache_creation_cost" field. + CacheCreationCost float64 `json:"cache_creation_cost,omitempty"` + // CacheReadCost holds the value of the "cache_read_cost" field. + CacheReadCost float64 `json:"cache_read_cost,omitempty"` + // TotalCost holds the value of the "total_cost" field. + TotalCost float64 `json:"total_cost,omitempty"` + // ActualCost holds the value of the "actual_cost" field. + ActualCost float64 `json:"actual_cost,omitempty"` + // RateMultiplier holds the value of the "rate_multiplier" field. + RateMultiplier float64 `json:"rate_multiplier,omitempty"` + // BillingType holds the value of the "billing_type" field. + BillingType int8 `json:"billing_type,omitempty"` + // Stream holds the value of the "stream" field. + Stream bool `json:"stream,omitempty"` + // DurationMs holds the value of the "duration_ms" field. + DurationMs *int `json:"duration_ms,omitempty"` + // FirstTokenMs holds the value of the "first_token_ms" field. + FirstTokenMs *int `json:"first_token_ms,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the UsageLogQuery when eager-loading is set. + Edges UsageLogEdges `json:"edges"` + selectValues sql.SelectValues +} + +// UsageLogEdges holds the relations/edges for other nodes in the graph. +type UsageLogEdges struct { + // User holds the value of the user edge. + User *User `json:"user,omitempty"` + // APIKey holds the value of the api_key edge. + APIKey *ApiKey `json:"api_key,omitempty"` + // Account holds the value of the account edge. + Account *Account `json:"account,omitempty"` + // Group holds the value of the group edge. + Group *Group `json:"group,omitempty"` + // Subscription holds the value of the subscription edge. + Subscription *UserSubscription `json:"subscription,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [5]bool +} + +// UserOrErr returns the User value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e UsageLogEdges) UserOrErr() (*User, error) { + if e.User != nil { + return e.User, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: user.Label} + } + return nil, &NotLoadedError{edge: "user"} +} + +// APIKeyOrErr returns the APIKey value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e UsageLogEdges) APIKeyOrErr() (*ApiKey, error) { + if e.APIKey != nil { + return e.APIKey, nil + } else if e.loadedTypes[1] { + return nil, &NotFoundError{label: apikey.Label} + } + return nil, &NotLoadedError{edge: "api_key"} +} + +// AccountOrErr returns the Account value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e UsageLogEdges) AccountOrErr() (*Account, error) { + if e.Account != nil { + return e.Account, nil + } else if e.loadedTypes[2] { + return nil, &NotFoundError{label: account.Label} + } + return nil, &NotLoadedError{edge: "account"} +} + +// GroupOrErr returns the Group value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e UsageLogEdges) GroupOrErr() (*Group, error) { + if e.Group != nil { + return e.Group, nil + } else if e.loadedTypes[3] { + return nil, &NotFoundError{label: group.Label} + } + return nil, &NotLoadedError{edge: "group"} +} + +// SubscriptionOrErr returns the Subscription value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e UsageLogEdges) SubscriptionOrErr() (*UserSubscription, error) { + if e.Subscription != nil { + return e.Subscription, nil + } else if e.loadedTypes[4] { + return nil, &NotFoundError{label: usersubscription.Label} + } + return nil, &NotLoadedError{edge: "subscription"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*UsageLog) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case usagelog.FieldStream: + values[i] = new(sql.NullBool) + case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier: + values[i] = new(sql.NullFloat64) + case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs: + values[i] = new(sql.NullInt64) + case usagelog.FieldRequestID, usagelog.FieldModel: + values[i] = new(sql.NullString) + case usagelog.FieldCreatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the UsageLog fields. +func (_m *UsageLog) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case usagelog.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case usagelog.FieldUserID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field user_id", values[i]) + } else if value.Valid { + _m.UserID = value.Int64 + } + case usagelog.FieldAPIKeyID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field api_key_id", values[i]) + } else if value.Valid { + _m.APIKeyID = value.Int64 + } + case usagelog.FieldAccountID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field account_id", values[i]) + } else if value.Valid { + _m.AccountID = value.Int64 + } + case usagelog.FieldRequestID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field request_id", values[i]) + } else if value.Valid { + _m.RequestID = value.String + } + case usagelog.FieldModel: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field model", values[i]) + } else if value.Valid { + _m.Model = value.String + } + case usagelog.FieldGroupID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field group_id", values[i]) + } else if value.Valid { + _m.GroupID = new(int64) + *_m.GroupID = value.Int64 + } + case usagelog.FieldSubscriptionID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field subscription_id", values[i]) + } else if value.Valid { + _m.SubscriptionID = new(int64) + *_m.SubscriptionID = value.Int64 + } + case usagelog.FieldInputTokens: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field input_tokens", values[i]) + } else if value.Valid { + _m.InputTokens = int(value.Int64) + } + case usagelog.FieldOutputTokens: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field output_tokens", values[i]) + } else if value.Valid { + _m.OutputTokens = int(value.Int64) + } + case usagelog.FieldCacheCreationTokens: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field cache_creation_tokens", values[i]) + } else if value.Valid { + _m.CacheCreationTokens = int(value.Int64) + } + case usagelog.FieldCacheReadTokens: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field cache_read_tokens", values[i]) + } else if value.Valid { + _m.CacheReadTokens = int(value.Int64) + } + case usagelog.FieldCacheCreation5mTokens: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field cache_creation_5m_tokens", values[i]) + } else if value.Valid { + _m.CacheCreation5mTokens = int(value.Int64) + } + case usagelog.FieldCacheCreation1hTokens: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field cache_creation_1h_tokens", values[i]) + } else if value.Valid { + _m.CacheCreation1hTokens = int(value.Int64) + } + case usagelog.FieldInputCost: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field input_cost", values[i]) + } else if value.Valid { + _m.InputCost = value.Float64 + } + case usagelog.FieldOutputCost: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field output_cost", values[i]) + } else if value.Valid { + _m.OutputCost = value.Float64 + } + case usagelog.FieldCacheCreationCost: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field cache_creation_cost", values[i]) + } else if value.Valid { + _m.CacheCreationCost = value.Float64 + } + case usagelog.FieldCacheReadCost: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field cache_read_cost", values[i]) + } else if value.Valid { + _m.CacheReadCost = value.Float64 + } + case usagelog.FieldTotalCost: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field total_cost", values[i]) + } else if value.Valid { + _m.TotalCost = value.Float64 + } + case usagelog.FieldActualCost: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field actual_cost", values[i]) + } else if value.Valid { + _m.ActualCost = value.Float64 + } + case usagelog.FieldRateMultiplier: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field rate_multiplier", values[i]) + } else if value.Valid { + _m.RateMultiplier = value.Float64 + } + case usagelog.FieldBillingType: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field billing_type", values[i]) + } else if value.Valid { + _m.BillingType = int8(value.Int64) + } + case usagelog.FieldStream: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field stream", values[i]) + } else if value.Valid { + _m.Stream = value.Bool + } + case usagelog.FieldDurationMs: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field duration_ms", values[i]) + } else if value.Valid { + _m.DurationMs = new(int) + *_m.DurationMs = int(value.Int64) + } + case usagelog.FieldFirstTokenMs: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field first_token_ms", values[i]) + } else if value.Valid { + _m.FirstTokenMs = new(int) + *_m.FirstTokenMs = int(value.Int64) + } + case usagelog.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the UsageLog. +// This includes values selected through modifiers, order, etc. +func (_m *UsageLog) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryUser queries the "user" edge of the UsageLog entity. +func (_m *UsageLog) QueryUser() *UserQuery { + return NewUsageLogClient(_m.config).QueryUser(_m) +} + +// QueryAPIKey queries the "api_key" edge of the UsageLog entity. +func (_m *UsageLog) QueryAPIKey() *ApiKeyQuery { + return NewUsageLogClient(_m.config).QueryAPIKey(_m) +} + +// QueryAccount queries the "account" edge of the UsageLog entity. +func (_m *UsageLog) QueryAccount() *AccountQuery { + return NewUsageLogClient(_m.config).QueryAccount(_m) +} + +// QueryGroup queries the "group" edge of the UsageLog entity. +func (_m *UsageLog) QueryGroup() *GroupQuery { + return NewUsageLogClient(_m.config).QueryGroup(_m) +} + +// QuerySubscription queries the "subscription" edge of the UsageLog entity. +func (_m *UsageLog) QuerySubscription() *UserSubscriptionQuery { + return NewUsageLogClient(_m.config).QuerySubscription(_m) +} + +// Update returns a builder for updating this UsageLog. +// Note that you need to call UsageLog.Unwrap() before calling this method if this UsageLog +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *UsageLog) Update() *UsageLogUpdateOne { + return NewUsageLogClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the UsageLog entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *UsageLog) Unwrap() *UsageLog { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: UsageLog is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *UsageLog) String() string { + var builder strings.Builder + builder.WriteString("UsageLog(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("user_id=") + builder.WriteString(fmt.Sprintf("%v", _m.UserID)) + builder.WriteString(", ") + builder.WriteString("api_key_id=") + builder.WriteString(fmt.Sprintf("%v", _m.APIKeyID)) + builder.WriteString(", ") + builder.WriteString("account_id=") + builder.WriteString(fmt.Sprintf("%v", _m.AccountID)) + builder.WriteString(", ") + builder.WriteString("request_id=") + builder.WriteString(_m.RequestID) + builder.WriteString(", ") + builder.WriteString("model=") + builder.WriteString(_m.Model) + builder.WriteString(", ") + if v := _m.GroupID; v != nil { + builder.WriteString("group_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.SubscriptionID; v != nil { + builder.WriteString("subscription_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("input_tokens=") + builder.WriteString(fmt.Sprintf("%v", _m.InputTokens)) + builder.WriteString(", ") + builder.WriteString("output_tokens=") + builder.WriteString(fmt.Sprintf("%v", _m.OutputTokens)) + builder.WriteString(", ") + builder.WriteString("cache_creation_tokens=") + builder.WriteString(fmt.Sprintf("%v", _m.CacheCreationTokens)) + builder.WriteString(", ") + builder.WriteString("cache_read_tokens=") + builder.WriteString(fmt.Sprintf("%v", _m.CacheReadTokens)) + builder.WriteString(", ") + builder.WriteString("cache_creation_5m_tokens=") + builder.WriteString(fmt.Sprintf("%v", _m.CacheCreation5mTokens)) + builder.WriteString(", ") + builder.WriteString("cache_creation_1h_tokens=") + builder.WriteString(fmt.Sprintf("%v", _m.CacheCreation1hTokens)) + builder.WriteString(", ") + builder.WriteString("input_cost=") + builder.WriteString(fmt.Sprintf("%v", _m.InputCost)) + builder.WriteString(", ") + builder.WriteString("output_cost=") + builder.WriteString(fmt.Sprintf("%v", _m.OutputCost)) + builder.WriteString(", ") + builder.WriteString("cache_creation_cost=") + builder.WriteString(fmt.Sprintf("%v", _m.CacheCreationCost)) + builder.WriteString(", ") + builder.WriteString("cache_read_cost=") + builder.WriteString(fmt.Sprintf("%v", _m.CacheReadCost)) + builder.WriteString(", ") + builder.WriteString("total_cost=") + builder.WriteString(fmt.Sprintf("%v", _m.TotalCost)) + builder.WriteString(", ") + builder.WriteString("actual_cost=") + builder.WriteString(fmt.Sprintf("%v", _m.ActualCost)) + builder.WriteString(", ") + builder.WriteString("rate_multiplier=") + builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier)) + builder.WriteString(", ") + builder.WriteString("billing_type=") + builder.WriteString(fmt.Sprintf("%v", _m.BillingType)) + builder.WriteString(", ") + builder.WriteString("stream=") + builder.WriteString(fmt.Sprintf("%v", _m.Stream)) + builder.WriteString(", ") + if v := _m.DurationMs; v != nil { + builder.WriteString("duration_ms=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.FirstTokenMs; v != nil { + builder.WriteString("first_token_ms=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// UsageLogs is a parsable slice of UsageLog. +type UsageLogs []*UsageLog diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go new file mode 100644 index 00000000..bdc6f7e6 --- /dev/null +++ b/backend/ent/usagelog/usagelog.go @@ -0,0 +1,396 @@ +// Code generated by ent, DO NOT EDIT. + +package usagelog + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the usagelog type in the database. + Label = "usage_log" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldUserID holds the string denoting the user_id field in the database. + FieldUserID = "user_id" + // FieldAPIKeyID holds the string denoting the api_key_id field in the database. + FieldAPIKeyID = "api_key_id" + // FieldAccountID holds the string denoting the account_id field in the database. + FieldAccountID = "account_id" + // FieldRequestID holds the string denoting the request_id field in the database. + FieldRequestID = "request_id" + // FieldModel holds the string denoting the model field in the database. + FieldModel = "model" + // FieldGroupID holds the string denoting the group_id field in the database. + FieldGroupID = "group_id" + // FieldSubscriptionID holds the string denoting the subscription_id field in the database. + FieldSubscriptionID = "subscription_id" + // FieldInputTokens holds the string denoting the input_tokens field in the database. + FieldInputTokens = "input_tokens" + // FieldOutputTokens holds the string denoting the output_tokens field in the database. + FieldOutputTokens = "output_tokens" + // FieldCacheCreationTokens holds the string denoting the cache_creation_tokens field in the database. + FieldCacheCreationTokens = "cache_creation_tokens" + // FieldCacheReadTokens holds the string denoting the cache_read_tokens field in the database. + FieldCacheReadTokens = "cache_read_tokens" + // FieldCacheCreation5mTokens holds the string denoting the cache_creation_5m_tokens field in the database. + FieldCacheCreation5mTokens = "cache_creation_5m_tokens" + // FieldCacheCreation1hTokens holds the string denoting the cache_creation_1h_tokens field in the database. + FieldCacheCreation1hTokens = "cache_creation_1h_tokens" + // FieldInputCost holds the string denoting the input_cost field in the database. + FieldInputCost = "input_cost" + // FieldOutputCost holds the string denoting the output_cost field in the database. + FieldOutputCost = "output_cost" + // FieldCacheCreationCost holds the string denoting the cache_creation_cost field in the database. + FieldCacheCreationCost = "cache_creation_cost" + // FieldCacheReadCost holds the string denoting the cache_read_cost field in the database. + FieldCacheReadCost = "cache_read_cost" + // FieldTotalCost holds the string denoting the total_cost field in the database. + FieldTotalCost = "total_cost" + // FieldActualCost holds the string denoting the actual_cost field in the database. + FieldActualCost = "actual_cost" + // FieldRateMultiplier holds the string denoting the rate_multiplier field in the database. + FieldRateMultiplier = "rate_multiplier" + // FieldBillingType holds the string denoting the billing_type field in the database. + FieldBillingType = "billing_type" + // FieldStream holds the string denoting the stream field in the database. + FieldStream = "stream" + // FieldDurationMs holds the string denoting the duration_ms field in the database. + FieldDurationMs = "duration_ms" + // FieldFirstTokenMs holds the string denoting the first_token_ms field in the database. + FieldFirstTokenMs = "first_token_ms" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // EdgeUser holds the string denoting the user edge name in mutations. + EdgeUser = "user" + // EdgeAPIKey holds the string denoting the api_key edge name in mutations. + EdgeAPIKey = "api_key" + // EdgeAccount holds the string denoting the account edge name in mutations. + EdgeAccount = "account" + // EdgeGroup holds the string denoting the group edge name in mutations. + EdgeGroup = "group" + // EdgeSubscription holds the string denoting the subscription edge name in mutations. + EdgeSubscription = "subscription" + // Table holds the table name of the usagelog in the database. + Table = "usage_logs" + // UserTable is the table that holds the user relation/edge. + UserTable = "usage_logs" + // UserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + UserInverseTable = "users" + // UserColumn is the table column denoting the user relation/edge. + UserColumn = "user_id" + // APIKeyTable is the table that holds the api_key relation/edge. + APIKeyTable = "usage_logs" + // APIKeyInverseTable is the table name for the ApiKey entity. + // It exists in this package in order to avoid circular dependency with the "apikey" package. + APIKeyInverseTable = "api_keys" + // APIKeyColumn is the table column denoting the api_key relation/edge. + APIKeyColumn = "api_key_id" + // AccountTable is the table that holds the account relation/edge. + AccountTable = "usage_logs" + // AccountInverseTable is the table name for the Account entity. + // It exists in this package in order to avoid circular dependency with the "account" package. + AccountInverseTable = "accounts" + // AccountColumn is the table column denoting the account relation/edge. + AccountColumn = "account_id" + // GroupTable is the table that holds the group relation/edge. + GroupTable = "usage_logs" + // GroupInverseTable is the table name for the Group entity. + // It exists in this package in order to avoid circular dependency with the "group" package. + GroupInverseTable = "groups" + // GroupColumn is the table column denoting the group relation/edge. + GroupColumn = "group_id" + // SubscriptionTable is the table that holds the subscription relation/edge. + SubscriptionTable = "usage_logs" + // SubscriptionInverseTable is the table name for the UserSubscription entity. + // It exists in this package in order to avoid circular dependency with the "usersubscription" package. + SubscriptionInverseTable = "user_subscriptions" + // SubscriptionColumn is the table column denoting the subscription relation/edge. + SubscriptionColumn = "subscription_id" +) + +// Columns holds all SQL columns for usagelog fields. +var Columns = []string{ + FieldID, + FieldUserID, + FieldAPIKeyID, + FieldAccountID, + FieldRequestID, + FieldModel, + FieldGroupID, + FieldSubscriptionID, + FieldInputTokens, + FieldOutputTokens, + FieldCacheCreationTokens, + FieldCacheReadTokens, + FieldCacheCreation5mTokens, + FieldCacheCreation1hTokens, + FieldInputCost, + FieldOutputCost, + FieldCacheCreationCost, + FieldCacheReadCost, + FieldTotalCost, + FieldActualCost, + FieldRateMultiplier, + FieldBillingType, + FieldStream, + FieldDurationMs, + FieldFirstTokenMs, + FieldCreatedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // RequestIDValidator is a validator for the "request_id" field. It is called by the builders before save. + RequestIDValidator func(string) error + // ModelValidator is a validator for the "model" field. It is called by the builders before save. + ModelValidator func(string) error + // DefaultInputTokens holds the default value on creation for the "input_tokens" field. + DefaultInputTokens int + // DefaultOutputTokens holds the default value on creation for the "output_tokens" field. + DefaultOutputTokens int + // DefaultCacheCreationTokens holds the default value on creation for the "cache_creation_tokens" field. + DefaultCacheCreationTokens int + // DefaultCacheReadTokens holds the default value on creation for the "cache_read_tokens" field. + DefaultCacheReadTokens int + // DefaultCacheCreation5mTokens holds the default value on creation for the "cache_creation_5m_tokens" field. + DefaultCacheCreation5mTokens int + // DefaultCacheCreation1hTokens holds the default value on creation for the "cache_creation_1h_tokens" field. + DefaultCacheCreation1hTokens int + // DefaultInputCost holds the default value on creation for the "input_cost" field. + DefaultInputCost float64 + // DefaultOutputCost holds the default value on creation for the "output_cost" field. + DefaultOutputCost float64 + // DefaultCacheCreationCost holds the default value on creation for the "cache_creation_cost" field. + DefaultCacheCreationCost float64 + // DefaultCacheReadCost holds the default value on creation for the "cache_read_cost" field. + DefaultCacheReadCost float64 + // DefaultTotalCost holds the default value on creation for the "total_cost" field. + DefaultTotalCost float64 + // DefaultActualCost holds the default value on creation for the "actual_cost" field. + DefaultActualCost float64 + // DefaultRateMultiplier holds the default value on creation for the "rate_multiplier" field. + DefaultRateMultiplier float64 + // DefaultBillingType holds the default value on creation for the "billing_type" field. + DefaultBillingType int8 + // DefaultStream holds the default value on creation for the "stream" field. + DefaultStream bool + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time +) + +// OrderOption defines the ordering options for the UsageLog queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByUserID orders the results by the user_id field. +func ByUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUserID, opts...).ToFunc() +} + +// ByAPIKeyID orders the results by the api_key_id field. +func ByAPIKeyID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAPIKeyID, opts...).ToFunc() +} + +// ByAccountID orders the results by the account_id field. +func ByAccountID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAccountID, opts...).ToFunc() +} + +// ByRequestID orders the results by the request_id field. +func ByRequestID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRequestID, opts...).ToFunc() +} + +// ByModel orders the results by the model field. +func ByModel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldModel, opts...).ToFunc() +} + +// ByGroupID orders the results by the group_id field. +func ByGroupID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldGroupID, opts...).ToFunc() +} + +// BySubscriptionID orders the results by the subscription_id field. +func BySubscriptionID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSubscriptionID, opts...).ToFunc() +} + +// ByInputTokens orders the results by the input_tokens field. +func ByInputTokens(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldInputTokens, opts...).ToFunc() +} + +// ByOutputTokens orders the results by the output_tokens field. +func ByOutputTokens(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOutputTokens, opts...).ToFunc() +} + +// ByCacheCreationTokens orders the results by the cache_creation_tokens field. +func ByCacheCreationTokens(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCacheCreationTokens, opts...).ToFunc() +} + +// ByCacheReadTokens orders the results by the cache_read_tokens field. +func ByCacheReadTokens(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCacheReadTokens, opts...).ToFunc() +} + +// ByCacheCreation5mTokens orders the results by the cache_creation_5m_tokens field. +func ByCacheCreation5mTokens(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCacheCreation5mTokens, opts...).ToFunc() +} + +// ByCacheCreation1hTokens orders the results by the cache_creation_1h_tokens field. +func ByCacheCreation1hTokens(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCacheCreation1hTokens, opts...).ToFunc() +} + +// ByInputCost orders the results by the input_cost field. +func ByInputCost(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldInputCost, opts...).ToFunc() +} + +// ByOutputCost orders the results by the output_cost field. +func ByOutputCost(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOutputCost, opts...).ToFunc() +} + +// ByCacheCreationCost orders the results by the cache_creation_cost field. +func ByCacheCreationCost(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCacheCreationCost, opts...).ToFunc() +} + +// ByCacheReadCost orders the results by the cache_read_cost field. +func ByCacheReadCost(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCacheReadCost, opts...).ToFunc() +} + +// ByTotalCost orders the results by the total_cost field. +func ByTotalCost(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotalCost, opts...).ToFunc() +} + +// ByActualCost orders the results by the actual_cost field. +func ByActualCost(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldActualCost, opts...).ToFunc() +} + +// ByRateMultiplier orders the results by the rate_multiplier field. +func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc() +} + +// ByBillingType orders the results by the billing_type field. +func ByBillingType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBillingType, opts...).ToFunc() +} + +// ByStream orders the results by the stream field. +func ByStream(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStream, opts...).ToFunc() +} + +// ByDurationMs orders the results by the duration_ms field. +func ByDurationMs(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDurationMs, opts...).ToFunc() +} + +// ByFirstTokenMs orders the results by the first_token_ms field. +func ByFirstTokenMs(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFirstTokenMs, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUserField orders the results by user field. +func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...)) + } +} + +// ByAPIKeyField orders the results by api_key field. +func ByAPIKeyField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAPIKeyStep(), sql.OrderByField(field, opts...)) + } +} + +// ByAccountField orders the results by account field. +func ByAccountField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAccountStep(), sql.OrderByField(field, opts...)) + } +} + +// ByGroupField orders the results by group field. +func ByGroupField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newGroupStep(), sql.OrderByField(field, opts...)) + } +} + +// BySubscriptionField orders the results by subscription field. +func BySubscriptionField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newSubscriptionStep(), sql.OrderByField(field, opts...)) + } +} +func newUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) +} +func newAPIKeyStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(APIKeyInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, APIKeyTable, APIKeyColumn), + ) +} +func newAccountStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AccountInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, AccountTable, AccountColumn), + ) +} +func newGroupStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(GroupInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, GroupTable, GroupColumn), + ) +} +func newSubscriptionStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(SubscriptionInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, SubscriptionTable, SubscriptionColumn), + ) +} diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go new file mode 100644 index 00000000..9c260433 --- /dev/null +++ b/backend/ent/usagelog/where.go @@ -0,0 +1,1271 @@ +// Code generated by ent, DO NOT EDIT. + +package usagelog + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldID, id)) +} + +// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ. +func UserID(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldUserID, v)) +} + +// APIKeyID applies equality check predicate on the "api_key_id" field. It's identical to APIKeyIDEQ. +func APIKeyID(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldAPIKeyID, v)) +} + +// AccountID applies equality check predicate on the "account_id" field. It's identical to AccountIDEQ. +func AccountID(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldAccountID, v)) +} + +// RequestID applies equality check predicate on the "request_id" field. It's identical to RequestIDEQ. +func RequestID(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldRequestID, v)) +} + +// Model applies equality check predicate on the "model" field. It's identical to ModelEQ. +func Model(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldModel, v)) +} + +// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ. +func GroupID(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) +} + +// SubscriptionID applies equality check predicate on the "subscription_id" field. It's identical to SubscriptionIDEQ. +func SubscriptionID(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldSubscriptionID, v)) +} + +// InputTokens applies equality check predicate on the "input_tokens" field. It's identical to InputTokensEQ. +func InputTokens(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldInputTokens, v)) +} + +// OutputTokens applies equality check predicate on the "output_tokens" field. It's identical to OutputTokensEQ. +func OutputTokens(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldOutputTokens, v)) +} + +// CacheCreationTokens applies equality check predicate on the "cache_creation_tokens" field. It's identical to CacheCreationTokensEQ. +func CacheCreationTokens(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheCreationTokens, v)) +} + +// CacheReadTokens applies equality check predicate on the "cache_read_tokens" field. It's identical to CacheReadTokensEQ. +func CacheReadTokens(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheReadTokens, v)) +} + +// CacheCreation5mTokens applies equality check predicate on the "cache_creation_5m_tokens" field. It's identical to CacheCreation5mTokensEQ. +func CacheCreation5mTokens(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheCreation5mTokens, v)) +} + +// CacheCreation1hTokens applies equality check predicate on the "cache_creation_1h_tokens" field. It's identical to CacheCreation1hTokensEQ. +func CacheCreation1hTokens(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheCreation1hTokens, v)) +} + +// InputCost applies equality check predicate on the "input_cost" field. It's identical to InputCostEQ. +func InputCost(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldInputCost, v)) +} + +// OutputCost applies equality check predicate on the "output_cost" field. It's identical to OutputCostEQ. +func OutputCost(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldOutputCost, v)) +} + +// CacheCreationCost applies equality check predicate on the "cache_creation_cost" field. It's identical to CacheCreationCostEQ. +func CacheCreationCost(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheCreationCost, v)) +} + +// CacheReadCost applies equality check predicate on the "cache_read_cost" field. It's identical to CacheReadCostEQ. +func CacheReadCost(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheReadCost, v)) +} + +// TotalCost applies equality check predicate on the "total_cost" field. It's identical to TotalCostEQ. +func TotalCost(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldTotalCost, v)) +} + +// ActualCost applies equality check predicate on the "actual_cost" field. It's identical to ActualCostEQ. +func ActualCost(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldActualCost, v)) +} + +// RateMultiplier applies equality check predicate on the "rate_multiplier" field. It's identical to RateMultiplierEQ. +func RateMultiplier(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldRateMultiplier, v)) +} + +// BillingType applies equality check predicate on the "billing_type" field. It's identical to BillingTypeEQ. +func BillingType(v int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v)) +} + +// Stream applies equality check predicate on the "stream" field. It's identical to StreamEQ. +func Stream(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldStream, v)) +} + +// DurationMs applies equality check predicate on the "duration_ms" field. It's identical to DurationMsEQ. +func DurationMs(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldDurationMs, v)) +} + +// FirstTokenMs applies equality check predicate on the "first_token_ms" field. It's identical to FirstTokenMsEQ. +func FirstTokenMs(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldFirstTokenMs, v)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UserIDEQ applies the EQ predicate on the "user_id" field. +func UserIDEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldUserID, v)) +} + +// UserIDNEQ applies the NEQ predicate on the "user_id" field. +func UserIDNEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldUserID, v)) +} + +// UserIDIn applies the In predicate on the "user_id" field. +func UserIDIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldUserID, vs...)) +} + +// UserIDNotIn applies the NotIn predicate on the "user_id" field. +func UserIDNotIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldUserID, vs...)) +} + +// APIKeyIDEQ applies the EQ predicate on the "api_key_id" field. +func APIKeyIDEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldAPIKeyID, v)) +} + +// APIKeyIDNEQ applies the NEQ predicate on the "api_key_id" field. +func APIKeyIDNEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldAPIKeyID, v)) +} + +// APIKeyIDIn applies the In predicate on the "api_key_id" field. +func APIKeyIDIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldAPIKeyID, vs...)) +} + +// APIKeyIDNotIn applies the NotIn predicate on the "api_key_id" field. +func APIKeyIDNotIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldAPIKeyID, vs...)) +} + +// AccountIDEQ applies the EQ predicate on the "account_id" field. +func AccountIDEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldAccountID, v)) +} + +// AccountIDNEQ applies the NEQ predicate on the "account_id" field. +func AccountIDNEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldAccountID, v)) +} + +// AccountIDIn applies the In predicate on the "account_id" field. +func AccountIDIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldAccountID, vs...)) +} + +// AccountIDNotIn applies the NotIn predicate on the "account_id" field. +func AccountIDNotIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldAccountID, vs...)) +} + +// RequestIDEQ applies the EQ predicate on the "request_id" field. +func RequestIDEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldRequestID, v)) +} + +// RequestIDNEQ applies the NEQ predicate on the "request_id" field. +func RequestIDNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldRequestID, v)) +} + +// RequestIDIn applies the In predicate on the "request_id" field. +func RequestIDIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldRequestID, vs...)) +} + +// RequestIDNotIn applies the NotIn predicate on the "request_id" field. +func RequestIDNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldRequestID, vs...)) +} + +// RequestIDGT applies the GT predicate on the "request_id" field. +func RequestIDGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldRequestID, v)) +} + +// RequestIDGTE applies the GTE predicate on the "request_id" field. +func RequestIDGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldRequestID, v)) +} + +// RequestIDLT applies the LT predicate on the "request_id" field. +func RequestIDLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldRequestID, v)) +} + +// RequestIDLTE applies the LTE predicate on the "request_id" field. +func RequestIDLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldRequestID, v)) +} + +// RequestIDContains applies the Contains predicate on the "request_id" field. +func RequestIDContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldRequestID, v)) +} + +// RequestIDHasPrefix applies the HasPrefix predicate on the "request_id" field. +func RequestIDHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldRequestID, v)) +} + +// RequestIDHasSuffix applies the HasSuffix predicate on the "request_id" field. +func RequestIDHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldRequestID, v)) +} + +// RequestIDEqualFold applies the EqualFold predicate on the "request_id" field. +func RequestIDEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldRequestID, v)) +} + +// RequestIDContainsFold applies the ContainsFold predicate on the "request_id" field. +func RequestIDContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldRequestID, v)) +} + +// ModelEQ applies the EQ predicate on the "model" field. +func ModelEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldModel, v)) +} + +// ModelNEQ applies the NEQ predicate on the "model" field. +func ModelNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldModel, v)) +} + +// ModelIn applies the In predicate on the "model" field. +func ModelIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldModel, vs...)) +} + +// ModelNotIn applies the NotIn predicate on the "model" field. +func ModelNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldModel, vs...)) +} + +// ModelGT applies the GT predicate on the "model" field. +func ModelGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldModel, v)) +} + +// ModelGTE applies the GTE predicate on the "model" field. +func ModelGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldModel, v)) +} + +// ModelLT applies the LT predicate on the "model" field. +func ModelLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldModel, v)) +} + +// ModelLTE applies the LTE predicate on the "model" field. +func ModelLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldModel, v)) +} + +// ModelContains applies the Contains predicate on the "model" field. +func ModelContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldModel, v)) +} + +// ModelHasPrefix applies the HasPrefix predicate on the "model" field. +func ModelHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldModel, v)) +} + +// ModelHasSuffix applies the HasSuffix predicate on the "model" field. +func ModelHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldModel, v)) +} + +// ModelEqualFold applies the EqualFold predicate on the "model" field. +func ModelEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldModel, v)) +} + +// ModelContainsFold applies the ContainsFold predicate on the "model" field. +func ModelContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v)) +} + +// GroupIDEQ applies the EQ predicate on the "group_id" field. +func GroupIDEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) +} + +// GroupIDNEQ applies the NEQ predicate on the "group_id" field. +func GroupIDNEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldGroupID, v)) +} + +// GroupIDIn applies the In predicate on the "group_id" field. +func GroupIDIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldGroupID, vs...)) +} + +// GroupIDNotIn applies the NotIn predicate on the "group_id" field. +func GroupIDNotIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldGroupID, vs...)) +} + +// GroupIDIsNil applies the IsNil predicate on the "group_id" field. +func GroupIDIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldGroupID)) +} + +// GroupIDNotNil applies the NotNil predicate on the "group_id" field. +func GroupIDNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldGroupID)) +} + +// SubscriptionIDEQ applies the EQ predicate on the "subscription_id" field. +func SubscriptionIDEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldSubscriptionID, v)) +} + +// SubscriptionIDNEQ applies the NEQ predicate on the "subscription_id" field. +func SubscriptionIDNEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldSubscriptionID, v)) +} + +// SubscriptionIDIn applies the In predicate on the "subscription_id" field. +func SubscriptionIDIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldSubscriptionID, vs...)) +} + +// SubscriptionIDNotIn applies the NotIn predicate on the "subscription_id" field. +func SubscriptionIDNotIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldSubscriptionID, vs...)) +} + +// SubscriptionIDIsNil applies the IsNil predicate on the "subscription_id" field. +func SubscriptionIDIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldSubscriptionID)) +} + +// SubscriptionIDNotNil applies the NotNil predicate on the "subscription_id" field. +func SubscriptionIDNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldSubscriptionID)) +} + +// InputTokensEQ applies the EQ predicate on the "input_tokens" field. +func InputTokensEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldInputTokens, v)) +} + +// InputTokensNEQ applies the NEQ predicate on the "input_tokens" field. +func InputTokensNEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldInputTokens, v)) +} + +// InputTokensIn applies the In predicate on the "input_tokens" field. +func InputTokensIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldInputTokens, vs...)) +} + +// InputTokensNotIn applies the NotIn predicate on the "input_tokens" field. +func InputTokensNotIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldInputTokens, vs...)) +} + +// InputTokensGT applies the GT predicate on the "input_tokens" field. +func InputTokensGT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldInputTokens, v)) +} + +// InputTokensGTE applies the GTE predicate on the "input_tokens" field. +func InputTokensGTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldInputTokens, v)) +} + +// InputTokensLT applies the LT predicate on the "input_tokens" field. +func InputTokensLT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldInputTokens, v)) +} + +// InputTokensLTE applies the LTE predicate on the "input_tokens" field. +func InputTokensLTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldInputTokens, v)) +} + +// OutputTokensEQ applies the EQ predicate on the "output_tokens" field. +func OutputTokensEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldOutputTokens, v)) +} + +// OutputTokensNEQ applies the NEQ predicate on the "output_tokens" field. +func OutputTokensNEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldOutputTokens, v)) +} + +// OutputTokensIn applies the In predicate on the "output_tokens" field. +func OutputTokensIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldOutputTokens, vs...)) +} + +// OutputTokensNotIn applies the NotIn predicate on the "output_tokens" field. +func OutputTokensNotIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldOutputTokens, vs...)) +} + +// OutputTokensGT applies the GT predicate on the "output_tokens" field. +func OutputTokensGT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldOutputTokens, v)) +} + +// OutputTokensGTE applies the GTE predicate on the "output_tokens" field. +func OutputTokensGTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldOutputTokens, v)) +} + +// OutputTokensLT applies the LT predicate on the "output_tokens" field. +func OutputTokensLT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldOutputTokens, v)) +} + +// OutputTokensLTE applies the LTE predicate on the "output_tokens" field. +func OutputTokensLTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldOutputTokens, v)) +} + +// CacheCreationTokensEQ applies the EQ predicate on the "cache_creation_tokens" field. +func CacheCreationTokensEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheCreationTokens, v)) +} + +// CacheCreationTokensNEQ applies the NEQ predicate on the "cache_creation_tokens" field. +func CacheCreationTokensNEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCacheCreationTokens, v)) +} + +// CacheCreationTokensIn applies the In predicate on the "cache_creation_tokens" field. +func CacheCreationTokensIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldCacheCreationTokens, vs...)) +} + +// CacheCreationTokensNotIn applies the NotIn predicate on the "cache_creation_tokens" field. +func CacheCreationTokensNotIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldCacheCreationTokens, vs...)) +} + +// CacheCreationTokensGT applies the GT predicate on the "cache_creation_tokens" field. +func CacheCreationTokensGT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldCacheCreationTokens, v)) +} + +// CacheCreationTokensGTE applies the GTE predicate on the "cache_creation_tokens" field. +func CacheCreationTokensGTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldCacheCreationTokens, v)) +} + +// CacheCreationTokensLT applies the LT predicate on the "cache_creation_tokens" field. +func CacheCreationTokensLT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldCacheCreationTokens, v)) +} + +// CacheCreationTokensLTE applies the LTE predicate on the "cache_creation_tokens" field. +func CacheCreationTokensLTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldCacheCreationTokens, v)) +} + +// CacheReadTokensEQ applies the EQ predicate on the "cache_read_tokens" field. +func CacheReadTokensEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheReadTokens, v)) +} + +// CacheReadTokensNEQ applies the NEQ predicate on the "cache_read_tokens" field. +func CacheReadTokensNEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCacheReadTokens, v)) +} + +// CacheReadTokensIn applies the In predicate on the "cache_read_tokens" field. +func CacheReadTokensIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldCacheReadTokens, vs...)) +} + +// CacheReadTokensNotIn applies the NotIn predicate on the "cache_read_tokens" field. +func CacheReadTokensNotIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldCacheReadTokens, vs...)) +} + +// CacheReadTokensGT applies the GT predicate on the "cache_read_tokens" field. +func CacheReadTokensGT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldCacheReadTokens, v)) +} + +// CacheReadTokensGTE applies the GTE predicate on the "cache_read_tokens" field. +func CacheReadTokensGTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldCacheReadTokens, v)) +} + +// CacheReadTokensLT applies the LT predicate on the "cache_read_tokens" field. +func CacheReadTokensLT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldCacheReadTokens, v)) +} + +// CacheReadTokensLTE applies the LTE predicate on the "cache_read_tokens" field. +func CacheReadTokensLTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldCacheReadTokens, v)) +} + +// CacheCreation5mTokensEQ applies the EQ predicate on the "cache_creation_5m_tokens" field. +func CacheCreation5mTokensEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheCreation5mTokens, v)) +} + +// CacheCreation5mTokensNEQ applies the NEQ predicate on the "cache_creation_5m_tokens" field. +func CacheCreation5mTokensNEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCacheCreation5mTokens, v)) +} + +// CacheCreation5mTokensIn applies the In predicate on the "cache_creation_5m_tokens" field. +func CacheCreation5mTokensIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldCacheCreation5mTokens, vs...)) +} + +// CacheCreation5mTokensNotIn applies the NotIn predicate on the "cache_creation_5m_tokens" field. +func CacheCreation5mTokensNotIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldCacheCreation5mTokens, vs...)) +} + +// CacheCreation5mTokensGT applies the GT predicate on the "cache_creation_5m_tokens" field. +func CacheCreation5mTokensGT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldCacheCreation5mTokens, v)) +} + +// CacheCreation5mTokensGTE applies the GTE predicate on the "cache_creation_5m_tokens" field. +func CacheCreation5mTokensGTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldCacheCreation5mTokens, v)) +} + +// CacheCreation5mTokensLT applies the LT predicate on the "cache_creation_5m_tokens" field. +func CacheCreation5mTokensLT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldCacheCreation5mTokens, v)) +} + +// CacheCreation5mTokensLTE applies the LTE predicate on the "cache_creation_5m_tokens" field. +func CacheCreation5mTokensLTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldCacheCreation5mTokens, v)) +} + +// CacheCreation1hTokensEQ applies the EQ predicate on the "cache_creation_1h_tokens" field. +func CacheCreation1hTokensEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheCreation1hTokens, v)) +} + +// CacheCreation1hTokensNEQ applies the NEQ predicate on the "cache_creation_1h_tokens" field. +func CacheCreation1hTokensNEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCacheCreation1hTokens, v)) +} + +// CacheCreation1hTokensIn applies the In predicate on the "cache_creation_1h_tokens" field. +func CacheCreation1hTokensIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldCacheCreation1hTokens, vs...)) +} + +// CacheCreation1hTokensNotIn applies the NotIn predicate on the "cache_creation_1h_tokens" field. +func CacheCreation1hTokensNotIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldCacheCreation1hTokens, vs...)) +} + +// CacheCreation1hTokensGT applies the GT predicate on the "cache_creation_1h_tokens" field. +func CacheCreation1hTokensGT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldCacheCreation1hTokens, v)) +} + +// CacheCreation1hTokensGTE applies the GTE predicate on the "cache_creation_1h_tokens" field. +func CacheCreation1hTokensGTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldCacheCreation1hTokens, v)) +} + +// CacheCreation1hTokensLT applies the LT predicate on the "cache_creation_1h_tokens" field. +func CacheCreation1hTokensLT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldCacheCreation1hTokens, v)) +} + +// CacheCreation1hTokensLTE applies the LTE predicate on the "cache_creation_1h_tokens" field. +func CacheCreation1hTokensLTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldCacheCreation1hTokens, v)) +} + +// InputCostEQ applies the EQ predicate on the "input_cost" field. +func InputCostEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldInputCost, v)) +} + +// InputCostNEQ applies the NEQ predicate on the "input_cost" field. +func InputCostNEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldInputCost, v)) +} + +// InputCostIn applies the In predicate on the "input_cost" field. +func InputCostIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldInputCost, vs...)) +} + +// InputCostNotIn applies the NotIn predicate on the "input_cost" field. +func InputCostNotIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldInputCost, vs...)) +} + +// InputCostGT applies the GT predicate on the "input_cost" field. +func InputCostGT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldInputCost, v)) +} + +// InputCostGTE applies the GTE predicate on the "input_cost" field. +func InputCostGTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldInputCost, v)) +} + +// InputCostLT applies the LT predicate on the "input_cost" field. +func InputCostLT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldInputCost, v)) +} + +// InputCostLTE applies the LTE predicate on the "input_cost" field. +func InputCostLTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldInputCost, v)) +} + +// OutputCostEQ applies the EQ predicate on the "output_cost" field. +func OutputCostEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldOutputCost, v)) +} + +// OutputCostNEQ applies the NEQ predicate on the "output_cost" field. +func OutputCostNEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldOutputCost, v)) +} + +// OutputCostIn applies the In predicate on the "output_cost" field. +func OutputCostIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldOutputCost, vs...)) +} + +// OutputCostNotIn applies the NotIn predicate on the "output_cost" field. +func OutputCostNotIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldOutputCost, vs...)) +} + +// OutputCostGT applies the GT predicate on the "output_cost" field. +func OutputCostGT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldOutputCost, v)) +} + +// OutputCostGTE applies the GTE predicate on the "output_cost" field. +func OutputCostGTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldOutputCost, v)) +} + +// OutputCostLT applies the LT predicate on the "output_cost" field. +func OutputCostLT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldOutputCost, v)) +} + +// OutputCostLTE applies the LTE predicate on the "output_cost" field. +func OutputCostLTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldOutputCost, v)) +} + +// CacheCreationCostEQ applies the EQ predicate on the "cache_creation_cost" field. +func CacheCreationCostEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheCreationCost, v)) +} + +// CacheCreationCostNEQ applies the NEQ predicate on the "cache_creation_cost" field. +func CacheCreationCostNEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCacheCreationCost, v)) +} + +// CacheCreationCostIn applies the In predicate on the "cache_creation_cost" field. +func CacheCreationCostIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldCacheCreationCost, vs...)) +} + +// CacheCreationCostNotIn applies the NotIn predicate on the "cache_creation_cost" field. +func CacheCreationCostNotIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldCacheCreationCost, vs...)) +} + +// CacheCreationCostGT applies the GT predicate on the "cache_creation_cost" field. +func CacheCreationCostGT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldCacheCreationCost, v)) +} + +// CacheCreationCostGTE applies the GTE predicate on the "cache_creation_cost" field. +func CacheCreationCostGTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldCacheCreationCost, v)) +} + +// CacheCreationCostLT applies the LT predicate on the "cache_creation_cost" field. +func CacheCreationCostLT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldCacheCreationCost, v)) +} + +// CacheCreationCostLTE applies the LTE predicate on the "cache_creation_cost" field. +func CacheCreationCostLTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldCacheCreationCost, v)) +} + +// CacheReadCostEQ applies the EQ predicate on the "cache_read_cost" field. +func CacheReadCostEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheReadCost, v)) +} + +// CacheReadCostNEQ applies the NEQ predicate on the "cache_read_cost" field. +func CacheReadCostNEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCacheReadCost, v)) +} + +// CacheReadCostIn applies the In predicate on the "cache_read_cost" field. +func CacheReadCostIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldCacheReadCost, vs...)) +} + +// CacheReadCostNotIn applies the NotIn predicate on the "cache_read_cost" field. +func CacheReadCostNotIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldCacheReadCost, vs...)) +} + +// CacheReadCostGT applies the GT predicate on the "cache_read_cost" field. +func CacheReadCostGT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldCacheReadCost, v)) +} + +// CacheReadCostGTE applies the GTE predicate on the "cache_read_cost" field. +func CacheReadCostGTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldCacheReadCost, v)) +} + +// CacheReadCostLT applies the LT predicate on the "cache_read_cost" field. +func CacheReadCostLT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldCacheReadCost, v)) +} + +// CacheReadCostLTE applies the LTE predicate on the "cache_read_cost" field. +func CacheReadCostLTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldCacheReadCost, v)) +} + +// TotalCostEQ applies the EQ predicate on the "total_cost" field. +func TotalCostEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldTotalCost, v)) +} + +// TotalCostNEQ applies the NEQ predicate on the "total_cost" field. +func TotalCostNEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldTotalCost, v)) +} + +// TotalCostIn applies the In predicate on the "total_cost" field. +func TotalCostIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldTotalCost, vs...)) +} + +// TotalCostNotIn applies the NotIn predicate on the "total_cost" field. +func TotalCostNotIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldTotalCost, vs...)) +} + +// TotalCostGT applies the GT predicate on the "total_cost" field. +func TotalCostGT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldTotalCost, v)) +} + +// TotalCostGTE applies the GTE predicate on the "total_cost" field. +func TotalCostGTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldTotalCost, v)) +} + +// TotalCostLT applies the LT predicate on the "total_cost" field. +func TotalCostLT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldTotalCost, v)) +} + +// TotalCostLTE applies the LTE predicate on the "total_cost" field. +func TotalCostLTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldTotalCost, v)) +} + +// ActualCostEQ applies the EQ predicate on the "actual_cost" field. +func ActualCostEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldActualCost, v)) +} + +// ActualCostNEQ applies the NEQ predicate on the "actual_cost" field. +func ActualCostNEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldActualCost, v)) +} + +// ActualCostIn applies the In predicate on the "actual_cost" field. +func ActualCostIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldActualCost, vs...)) +} + +// ActualCostNotIn applies the NotIn predicate on the "actual_cost" field. +func ActualCostNotIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldActualCost, vs...)) +} + +// ActualCostGT applies the GT predicate on the "actual_cost" field. +func ActualCostGT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldActualCost, v)) +} + +// ActualCostGTE applies the GTE predicate on the "actual_cost" field. +func ActualCostGTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldActualCost, v)) +} + +// ActualCostLT applies the LT predicate on the "actual_cost" field. +func ActualCostLT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldActualCost, v)) +} + +// ActualCostLTE applies the LTE predicate on the "actual_cost" field. +func ActualCostLTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldActualCost, v)) +} + +// RateMultiplierEQ applies the EQ predicate on the "rate_multiplier" field. +func RateMultiplierEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldRateMultiplier, v)) +} + +// RateMultiplierNEQ applies the NEQ predicate on the "rate_multiplier" field. +func RateMultiplierNEQ(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldRateMultiplier, v)) +} + +// RateMultiplierIn applies the In predicate on the "rate_multiplier" field. +func RateMultiplierIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldRateMultiplier, vs...)) +} + +// RateMultiplierNotIn applies the NotIn predicate on the "rate_multiplier" field. +func RateMultiplierNotIn(vs ...float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldRateMultiplier, vs...)) +} + +// RateMultiplierGT applies the GT predicate on the "rate_multiplier" field. +func RateMultiplierGT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldRateMultiplier, v)) +} + +// RateMultiplierGTE applies the GTE predicate on the "rate_multiplier" field. +func RateMultiplierGTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldRateMultiplier, v)) +} + +// RateMultiplierLT applies the LT predicate on the "rate_multiplier" field. +func RateMultiplierLT(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldRateMultiplier, v)) +} + +// RateMultiplierLTE applies the LTE predicate on the "rate_multiplier" field. +func RateMultiplierLTE(v float64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldRateMultiplier, v)) +} + +// BillingTypeEQ applies the EQ predicate on the "billing_type" field. +func BillingTypeEQ(v int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v)) +} + +// BillingTypeNEQ applies the NEQ predicate on the "billing_type" field. +func BillingTypeNEQ(v int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldBillingType, v)) +} + +// BillingTypeIn applies the In predicate on the "billing_type" field. +func BillingTypeIn(vs ...int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldBillingType, vs...)) +} + +// BillingTypeNotIn applies the NotIn predicate on the "billing_type" field. +func BillingTypeNotIn(vs ...int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldBillingType, vs...)) +} + +// BillingTypeGT applies the GT predicate on the "billing_type" field. +func BillingTypeGT(v int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldBillingType, v)) +} + +// BillingTypeGTE applies the GTE predicate on the "billing_type" field. +func BillingTypeGTE(v int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldBillingType, v)) +} + +// BillingTypeLT applies the LT predicate on the "billing_type" field. +func BillingTypeLT(v int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldBillingType, v)) +} + +// BillingTypeLTE applies the LTE predicate on the "billing_type" field. +func BillingTypeLTE(v int8) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldBillingType, v)) +} + +// StreamEQ applies the EQ predicate on the "stream" field. +func StreamEQ(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldStream, v)) +} + +// StreamNEQ applies the NEQ predicate on the "stream" field. +func StreamNEQ(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldStream, v)) +} + +// DurationMsEQ applies the EQ predicate on the "duration_ms" field. +func DurationMsEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldDurationMs, v)) +} + +// DurationMsNEQ applies the NEQ predicate on the "duration_ms" field. +func DurationMsNEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldDurationMs, v)) +} + +// DurationMsIn applies the In predicate on the "duration_ms" field. +func DurationMsIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldDurationMs, vs...)) +} + +// DurationMsNotIn applies the NotIn predicate on the "duration_ms" field. +func DurationMsNotIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldDurationMs, vs...)) +} + +// DurationMsGT applies the GT predicate on the "duration_ms" field. +func DurationMsGT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldDurationMs, v)) +} + +// DurationMsGTE applies the GTE predicate on the "duration_ms" field. +func DurationMsGTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldDurationMs, v)) +} + +// DurationMsLT applies the LT predicate on the "duration_ms" field. +func DurationMsLT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldDurationMs, v)) +} + +// DurationMsLTE applies the LTE predicate on the "duration_ms" field. +func DurationMsLTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldDurationMs, v)) +} + +// DurationMsIsNil applies the IsNil predicate on the "duration_ms" field. +func DurationMsIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldDurationMs)) +} + +// DurationMsNotNil applies the NotNil predicate on the "duration_ms" field. +func DurationMsNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldDurationMs)) +} + +// FirstTokenMsEQ applies the EQ predicate on the "first_token_ms" field. +func FirstTokenMsEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldFirstTokenMs, v)) +} + +// FirstTokenMsNEQ applies the NEQ predicate on the "first_token_ms" field. +func FirstTokenMsNEQ(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldFirstTokenMs, v)) +} + +// FirstTokenMsIn applies the In predicate on the "first_token_ms" field. +func FirstTokenMsIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldFirstTokenMs, vs...)) +} + +// FirstTokenMsNotIn applies the NotIn predicate on the "first_token_ms" field. +func FirstTokenMsNotIn(vs ...int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldFirstTokenMs, vs...)) +} + +// FirstTokenMsGT applies the GT predicate on the "first_token_ms" field. +func FirstTokenMsGT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldFirstTokenMs, v)) +} + +// FirstTokenMsGTE applies the GTE predicate on the "first_token_ms" field. +func FirstTokenMsGTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldFirstTokenMs, v)) +} + +// FirstTokenMsLT applies the LT predicate on the "first_token_ms" field. +func FirstTokenMsLT(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldFirstTokenMs, v)) +} + +// FirstTokenMsLTE applies the LTE predicate on the "first_token_ms" field. +func FirstTokenMsLTE(v int) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldFirstTokenMs, v)) +} + +// FirstTokenMsIsNil applies the IsNil predicate on the "first_token_ms" field. +func FirstTokenMsIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldFirstTokenMs)) +} + +// FirstTokenMsNotNil applies the NotNil predicate on the "first_token_ms" field. +func FirstTokenMsNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldFirstTokenMs)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldCreatedAt, v)) +} + +// HasUser applies the HasEdge predicate on the "user" edge. +func HasUser() predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates). +func HasUserWith(preds ...predicate.User) predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := newUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasAPIKey applies the HasEdge predicate on the "api_key" edge. +func HasAPIKey() predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, APIKeyTable, APIKeyColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAPIKeyWith applies the HasEdge predicate on the "api_key" edge with a given conditions (other predicates). +func HasAPIKeyWith(preds ...predicate.ApiKey) predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := newAPIKeyStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasAccount applies the HasEdge predicate on the "account" edge. +func HasAccount() predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, AccountTable, AccountColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAccountWith applies the HasEdge predicate on the "account" edge with a given conditions (other predicates). +func HasAccountWith(preds ...predicate.Account) predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := newAccountStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasGroup applies the HasEdge predicate on the "group" edge. +func HasGroup() predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, GroupTable, GroupColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasGroupWith applies the HasEdge predicate on the "group" edge with a given conditions (other predicates). +func HasGroupWith(preds ...predicate.Group) predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := newGroupStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasSubscription applies the HasEdge predicate on the "subscription" edge. +func HasSubscription() predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, SubscriptionTable, SubscriptionColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasSubscriptionWith applies the HasEdge predicate on the "subscription" edge with a given conditions (other predicates). +func HasSubscriptionWith(preds ...predicate.UserSubscription) predicate.UsageLog { + return predicate.UsageLog(func(s *sql.Selector) { + step := newSubscriptionStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.UsageLog) predicate.UsageLog { + return predicate.UsageLog(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.UsageLog) predicate.UsageLog { + return predicate.UsageLog(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.UsageLog) predicate.UsageLog { + return predicate.UsageLog(sql.NotPredicates(p)) +} diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go new file mode 100644 index 00000000..bcba64b1 --- /dev/null +++ b/backend/ent/usagelog_create.go @@ -0,0 +1,2431 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/account" + "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/usagelog" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/ent/usersubscription" +) + +// UsageLogCreate is the builder for creating a UsageLog entity. +type UsageLogCreate struct { + config + mutation *UsageLogMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetUserID sets the "user_id" field. +func (_c *UsageLogCreate) SetUserID(v int64) *UsageLogCreate { + _c.mutation.SetUserID(v) + return _c +} + +// SetAPIKeyID sets the "api_key_id" field. +func (_c *UsageLogCreate) SetAPIKeyID(v int64) *UsageLogCreate { + _c.mutation.SetAPIKeyID(v) + return _c +} + +// SetAccountID sets the "account_id" field. +func (_c *UsageLogCreate) SetAccountID(v int64) *UsageLogCreate { + _c.mutation.SetAccountID(v) + return _c +} + +// SetRequestID sets the "request_id" field. +func (_c *UsageLogCreate) SetRequestID(v string) *UsageLogCreate { + _c.mutation.SetRequestID(v) + return _c +} + +// SetModel sets the "model" field. +func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate { + _c.mutation.SetModel(v) + return _c +} + +// SetGroupID sets the "group_id" field. +func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate { + _c.mutation.SetGroupID(v) + return _c +} + +// SetNillableGroupID sets the "group_id" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableGroupID(v *int64) *UsageLogCreate { + if v != nil { + _c.SetGroupID(*v) + } + return _c +} + +// SetSubscriptionID sets the "subscription_id" field. +func (_c *UsageLogCreate) SetSubscriptionID(v int64) *UsageLogCreate { + _c.mutation.SetSubscriptionID(v) + return _c +} + +// SetNillableSubscriptionID sets the "subscription_id" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableSubscriptionID(v *int64) *UsageLogCreate { + if v != nil { + _c.SetSubscriptionID(*v) + } + return _c +} + +// SetInputTokens sets the "input_tokens" field. +func (_c *UsageLogCreate) SetInputTokens(v int) *UsageLogCreate { + _c.mutation.SetInputTokens(v) + return _c +} + +// SetNillableInputTokens sets the "input_tokens" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableInputTokens(v *int) *UsageLogCreate { + if v != nil { + _c.SetInputTokens(*v) + } + return _c +} + +// SetOutputTokens sets the "output_tokens" field. +func (_c *UsageLogCreate) SetOutputTokens(v int) *UsageLogCreate { + _c.mutation.SetOutputTokens(v) + return _c +} + +// SetNillableOutputTokens sets the "output_tokens" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableOutputTokens(v *int) *UsageLogCreate { + if v != nil { + _c.SetOutputTokens(*v) + } + return _c +} + +// SetCacheCreationTokens sets the "cache_creation_tokens" field. +func (_c *UsageLogCreate) SetCacheCreationTokens(v int) *UsageLogCreate { + _c.mutation.SetCacheCreationTokens(v) + return _c +} + +// SetNillableCacheCreationTokens sets the "cache_creation_tokens" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCacheCreationTokens(v *int) *UsageLogCreate { + if v != nil { + _c.SetCacheCreationTokens(*v) + } + return _c +} + +// SetCacheReadTokens sets the "cache_read_tokens" field. +func (_c *UsageLogCreate) SetCacheReadTokens(v int) *UsageLogCreate { + _c.mutation.SetCacheReadTokens(v) + return _c +} + +// SetNillableCacheReadTokens sets the "cache_read_tokens" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCacheReadTokens(v *int) *UsageLogCreate { + if v != nil { + _c.SetCacheReadTokens(*v) + } + return _c +} + +// SetCacheCreation5mTokens sets the "cache_creation_5m_tokens" field. +func (_c *UsageLogCreate) SetCacheCreation5mTokens(v int) *UsageLogCreate { + _c.mutation.SetCacheCreation5mTokens(v) + return _c +} + +// SetNillableCacheCreation5mTokens sets the "cache_creation_5m_tokens" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCacheCreation5mTokens(v *int) *UsageLogCreate { + if v != nil { + _c.SetCacheCreation5mTokens(*v) + } + return _c +} + +// SetCacheCreation1hTokens sets the "cache_creation_1h_tokens" field. +func (_c *UsageLogCreate) SetCacheCreation1hTokens(v int) *UsageLogCreate { + _c.mutation.SetCacheCreation1hTokens(v) + return _c +} + +// SetNillableCacheCreation1hTokens sets the "cache_creation_1h_tokens" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCacheCreation1hTokens(v *int) *UsageLogCreate { + if v != nil { + _c.SetCacheCreation1hTokens(*v) + } + return _c +} + +// SetInputCost sets the "input_cost" field. +func (_c *UsageLogCreate) SetInputCost(v float64) *UsageLogCreate { + _c.mutation.SetInputCost(v) + return _c +} + +// SetNillableInputCost sets the "input_cost" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableInputCost(v *float64) *UsageLogCreate { + if v != nil { + _c.SetInputCost(*v) + } + return _c +} + +// SetOutputCost sets the "output_cost" field. +func (_c *UsageLogCreate) SetOutputCost(v float64) *UsageLogCreate { + _c.mutation.SetOutputCost(v) + return _c +} + +// SetNillableOutputCost sets the "output_cost" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableOutputCost(v *float64) *UsageLogCreate { + if v != nil { + _c.SetOutputCost(*v) + } + return _c +} + +// SetCacheCreationCost sets the "cache_creation_cost" field. +func (_c *UsageLogCreate) SetCacheCreationCost(v float64) *UsageLogCreate { + _c.mutation.SetCacheCreationCost(v) + return _c +} + +// SetNillableCacheCreationCost sets the "cache_creation_cost" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCacheCreationCost(v *float64) *UsageLogCreate { + if v != nil { + _c.SetCacheCreationCost(*v) + } + return _c +} + +// SetCacheReadCost sets the "cache_read_cost" field. +func (_c *UsageLogCreate) SetCacheReadCost(v float64) *UsageLogCreate { + _c.mutation.SetCacheReadCost(v) + return _c +} + +// SetNillableCacheReadCost sets the "cache_read_cost" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCacheReadCost(v *float64) *UsageLogCreate { + if v != nil { + _c.SetCacheReadCost(*v) + } + return _c +} + +// SetTotalCost sets the "total_cost" field. +func (_c *UsageLogCreate) SetTotalCost(v float64) *UsageLogCreate { + _c.mutation.SetTotalCost(v) + return _c +} + +// SetNillableTotalCost sets the "total_cost" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableTotalCost(v *float64) *UsageLogCreate { + if v != nil { + _c.SetTotalCost(*v) + } + return _c +} + +// SetActualCost sets the "actual_cost" field. +func (_c *UsageLogCreate) SetActualCost(v float64) *UsageLogCreate { + _c.mutation.SetActualCost(v) + return _c +} + +// SetNillableActualCost sets the "actual_cost" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableActualCost(v *float64) *UsageLogCreate { + if v != nil { + _c.SetActualCost(*v) + } + return _c +} + +// SetRateMultiplier sets the "rate_multiplier" field. +func (_c *UsageLogCreate) SetRateMultiplier(v float64) *UsageLogCreate { + _c.mutation.SetRateMultiplier(v) + return _c +} + +// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableRateMultiplier(v *float64) *UsageLogCreate { + if v != nil { + _c.SetRateMultiplier(*v) + } + return _c +} + +// SetBillingType sets the "billing_type" field. +func (_c *UsageLogCreate) SetBillingType(v int8) *UsageLogCreate { + _c.mutation.SetBillingType(v) + return _c +} + +// SetNillableBillingType sets the "billing_type" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableBillingType(v *int8) *UsageLogCreate { + if v != nil { + _c.SetBillingType(*v) + } + return _c +} + +// SetStream sets the "stream" field. +func (_c *UsageLogCreate) SetStream(v bool) *UsageLogCreate { + _c.mutation.SetStream(v) + return _c +} + +// SetNillableStream sets the "stream" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableStream(v *bool) *UsageLogCreate { + if v != nil { + _c.SetStream(*v) + } + return _c +} + +// SetDurationMs sets the "duration_ms" field. +func (_c *UsageLogCreate) SetDurationMs(v int) *UsageLogCreate { + _c.mutation.SetDurationMs(v) + return _c +} + +// SetNillableDurationMs sets the "duration_ms" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableDurationMs(v *int) *UsageLogCreate { + if v != nil { + _c.SetDurationMs(*v) + } + return _c +} + +// SetFirstTokenMs sets the "first_token_ms" field. +func (_c *UsageLogCreate) SetFirstTokenMs(v int) *UsageLogCreate { + _c.mutation.SetFirstTokenMs(v) + return _c +} + +// SetNillableFirstTokenMs sets the "first_token_ms" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableFirstTokenMs(v *int) *UsageLogCreate { + if v != nil { + _c.SetFirstTokenMs(*v) + } + return _c +} + +// SetCreatedAt sets the "created_at" field. +func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCreatedAt(v *time.Time) *UsageLogCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUser sets the "user" edge to the User entity. +func (_c *UsageLogCreate) SetUser(v *User) *UsageLogCreate { + return _c.SetUserID(v.ID) +} + +// SetAPIKey sets the "api_key" edge to the ApiKey entity. +func (_c *UsageLogCreate) SetAPIKey(v *ApiKey) *UsageLogCreate { + return _c.SetAPIKeyID(v.ID) +} + +// SetAccount sets the "account" edge to the Account entity. +func (_c *UsageLogCreate) SetAccount(v *Account) *UsageLogCreate { + return _c.SetAccountID(v.ID) +} + +// SetGroup sets the "group" edge to the Group entity. +func (_c *UsageLogCreate) SetGroup(v *Group) *UsageLogCreate { + return _c.SetGroupID(v.ID) +} + +// SetSubscription sets the "subscription" edge to the UserSubscription entity. +func (_c *UsageLogCreate) SetSubscription(v *UserSubscription) *UsageLogCreate { + return _c.SetSubscriptionID(v.ID) +} + +// Mutation returns the UsageLogMutation object of the builder. +func (_c *UsageLogCreate) Mutation() *UsageLogMutation { + return _c.mutation +} + +// Save creates the UsageLog in the database. +func (_c *UsageLogCreate) Save(ctx context.Context) (*UsageLog, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *UsageLogCreate) SaveX(ctx context.Context) *UsageLog { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *UsageLogCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *UsageLogCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *UsageLogCreate) defaults() { + if _, ok := _c.mutation.InputTokens(); !ok { + v := usagelog.DefaultInputTokens + _c.mutation.SetInputTokens(v) + } + if _, ok := _c.mutation.OutputTokens(); !ok { + v := usagelog.DefaultOutputTokens + _c.mutation.SetOutputTokens(v) + } + if _, ok := _c.mutation.CacheCreationTokens(); !ok { + v := usagelog.DefaultCacheCreationTokens + _c.mutation.SetCacheCreationTokens(v) + } + if _, ok := _c.mutation.CacheReadTokens(); !ok { + v := usagelog.DefaultCacheReadTokens + _c.mutation.SetCacheReadTokens(v) + } + if _, ok := _c.mutation.CacheCreation5mTokens(); !ok { + v := usagelog.DefaultCacheCreation5mTokens + _c.mutation.SetCacheCreation5mTokens(v) + } + if _, ok := _c.mutation.CacheCreation1hTokens(); !ok { + v := usagelog.DefaultCacheCreation1hTokens + _c.mutation.SetCacheCreation1hTokens(v) + } + if _, ok := _c.mutation.InputCost(); !ok { + v := usagelog.DefaultInputCost + _c.mutation.SetInputCost(v) + } + if _, ok := _c.mutation.OutputCost(); !ok { + v := usagelog.DefaultOutputCost + _c.mutation.SetOutputCost(v) + } + if _, ok := _c.mutation.CacheCreationCost(); !ok { + v := usagelog.DefaultCacheCreationCost + _c.mutation.SetCacheCreationCost(v) + } + if _, ok := _c.mutation.CacheReadCost(); !ok { + v := usagelog.DefaultCacheReadCost + _c.mutation.SetCacheReadCost(v) + } + if _, ok := _c.mutation.TotalCost(); !ok { + v := usagelog.DefaultTotalCost + _c.mutation.SetTotalCost(v) + } + if _, ok := _c.mutation.ActualCost(); !ok { + v := usagelog.DefaultActualCost + _c.mutation.SetActualCost(v) + } + if _, ok := _c.mutation.RateMultiplier(); !ok { + v := usagelog.DefaultRateMultiplier + _c.mutation.SetRateMultiplier(v) + } + if _, ok := _c.mutation.BillingType(); !ok { + v := usagelog.DefaultBillingType + _c.mutation.SetBillingType(v) + } + if _, ok := _c.mutation.Stream(); !ok { + v := usagelog.DefaultStream + _c.mutation.SetStream(v) + } + if _, ok := _c.mutation.CreatedAt(); !ok { + v := usagelog.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *UsageLogCreate) check() error { + if _, ok := _c.mutation.UserID(); !ok { + return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "UsageLog.user_id"`)} + } + if _, ok := _c.mutation.APIKeyID(); !ok { + return &ValidationError{Name: "api_key_id", err: errors.New(`ent: missing required field "UsageLog.api_key_id"`)} + } + if _, ok := _c.mutation.AccountID(); !ok { + return &ValidationError{Name: "account_id", err: errors.New(`ent: missing required field "UsageLog.account_id"`)} + } + if _, ok := _c.mutation.RequestID(); !ok { + return &ValidationError{Name: "request_id", err: errors.New(`ent: missing required field "UsageLog.request_id"`)} + } + if v, ok := _c.mutation.RequestID(); ok { + if err := usagelog.RequestIDValidator(v); err != nil { + return &ValidationError{Name: "request_id", err: fmt.Errorf(`ent: validator failed for field "UsageLog.request_id": %w`, err)} + } + } + if _, ok := _c.mutation.Model(); !ok { + return &ValidationError{Name: "model", err: errors.New(`ent: missing required field "UsageLog.model"`)} + } + if v, ok := _c.mutation.Model(); ok { + if err := usagelog.ModelValidator(v); err != nil { + return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} + } + } + if _, ok := _c.mutation.InputTokens(); !ok { + return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)} + } + if _, ok := _c.mutation.OutputTokens(); !ok { + return &ValidationError{Name: "output_tokens", err: errors.New(`ent: missing required field "UsageLog.output_tokens"`)} + } + if _, ok := _c.mutation.CacheCreationTokens(); !ok { + return &ValidationError{Name: "cache_creation_tokens", err: errors.New(`ent: missing required field "UsageLog.cache_creation_tokens"`)} + } + if _, ok := _c.mutation.CacheReadTokens(); !ok { + return &ValidationError{Name: "cache_read_tokens", err: errors.New(`ent: missing required field "UsageLog.cache_read_tokens"`)} + } + if _, ok := _c.mutation.CacheCreation5mTokens(); !ok { + return &ValidationError{Name: "cache_creation_5m_tokens", err: errors.New(`ent: missing required field "UsageLog.cache_creation_5m_tokens"`)} + } + if _, ok := _c.mutation.CacheCreation1hTokens(); !ok { + return &ValidationError{Name: "cache_creation_1h_tokens", err: errors.New(`ent: missing required field "UsageLog.cache_creation_1h_tokens"`)} + } + if _, ok := _c.mutation.InputCost(); !ok { + return &ValidationError{Name: "input_cost", err: errors.New(`ent: missing required field "UsageLog.input_cost"`)} + } + if _, ok := _c.mutation.OutputCost(); !ok { + return &ValidationError{Name: "output_cost", err: errors.New(`ent: missing required field "UsageLog.output_cost"`)} + } + if _, ok := _c.mutation.CacheCreationCost(); !ok { + return &ValidationError{Name: "cache_creation_cost", err: errors.New(`ent: missing required field "UsageLog.cache_creation_cost"`)} + } + if _, ok := _c.mutation.CacheReadCost(); !ok { + return &ValidationError{Name: "cache_read_cost", err: errors.New(`ent: missing required field "UsageLog.cache_read_cost"`)} + } + if _, ok := _c.mutation.TotalCost(); !ok { + return &ValidationError{Name: "total_cost", err: errors.New(`ent: missing required field "UsageLog.total_cost"`)} + } + if _, ok := _c.mutation.ActualCost(); !ok { + return &ValidationError{Name: "actual_cost", err: errors.New(`ent: missing required field "UsageLog.actual_cost"`)} + } + if _, ok := _c.mutation.RateMultiplier(); !ok { + return &ValidationError{Name: "rate_multiplier", err: errors.New(`ent: missing required field "UsageLog.rate_multiplier"`)} + } + if _, ok := _c.mutation.BillingType(); !ok { + return &ValidationError{Name: "billing_type", err: errors.New(`ent: missing required field "UsageLog.billing_type"`)} + } + if _, ok := _c.mutation.Stream(); !ok { + return &ValidationError{Name: "stream", err: errors.New(`ent: missing required field "UsageLog.stream"`)} + } + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)} + } + if len(_c.mutation.UserIDs()) == 0 { + return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "UsageLog.user"`)} + } + if len(_c.mutation.APIKeyIDs()) == 0 { + return &ValidationError{Name: "api_key", err: errors.New(`ent: missing required edge "UsageLog.api_key"`)} + } + if len(_c.mutation.AccountIDs()) == 0 { + return &ValidationError{Name: "account", err: errors.New(`ent: missing required edge "UsageLog.account"`)} + } + return nil +} + +func (_c *UsageLogCreate) sqlSave(ctx context.Context) (*UsageLog, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { + var ( + _node = &UsageLog{config: _c.config} + _spec = sqlgraph.NewCreateSpec(usagelog.Table, sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.RequestID(); ok { + _spec.SetField(usagelog.FieldRequestID, field.TypeString, value) + _node.RequestID = value + } + if value, ok := _c.mutation.Model(); ok { + _spec.SetField(usagelog.FieldModel, field.TypeString, value) + _node.Model = value + } + if value, ok := _c.mutation.InputTokens(); ok { + _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) + _node.InputTokens = value + } + if value, ok := _c.mutation.OutputTokens(); ok { + _spec.SetField(usagelog.FieldOutputTokens, field.TypeInt, value) + _node.OutputTokens = value + } + if value, ok := _c.mutation.CacheCreationTokens(); ok { + _spec.SetField(usagelog.FieldCacheCreationTokens, field.TypeInt, value) + _node.CacheCreationTokens = value + } + if value, ok := _c.mutation.CacheReadTokens(); ok { + _spec.SetField(usagelog.FieldCacheReadTokens, field.TypeInt, value) + _node.CacheReadTokens = value + } + if value, ok := _c.mutation.CacheCreation5mTokens(); ok { + _spec.SetField(usagelog.FieldCacheCreation5mTokens, field.TypeInt, value) + _node.CacheCreation5mTokens = value + } + if value, ok := _c.mutation.CacheCreation1hTokens(); ok { + _spec.SetField(usagelog.FieldCacheCreation1hTokens, field.TypeInt, value) + _node.CacheCreation1hTokens = value + } + if value, ok := _c.mutation.InputCost(); ok { + _spec.SetField(usagelog.FieldInputCost, field.TypeFloat64, value) + _node.InputCost = value + } + if value, ok := _c.mutation.OutputCost(); ok { + _spec.SetField(usagelog.FieldOutputCost, field.TypeFloat64, value) + _node.OutputCost = value + } + if value, ok := _c.mutation.CacheCreationCost(); ok { + _spec.SetField(usagelog.FieldCacheCreationCost, field.TypeFloat64, value) + _node.CacheCreationCost = value + } + if value, ok := _c.mutation.CacheReadCost(); ok { + _spec.SetField(usagelog.FieldCacheReadCost, field.TypeFloat64, value) + _node.CacheReadCost = value + } + if value, ok := _c.mutation.TotalCost(); ok { + _spec.SetField(usagelog.FieldTotalCost, field.TypeFloat64, value) + _node.TotalCost = value + } + if value, ok := _c.mutation.ActualCost(); ok { + _spec.SetField(usagelog.FieldActualCost, field.TypeFloat64, value) + _node.ActualCost = value + } + if value, ok := _c.mutation.RateMultiplier(); ok { + _spec.SetField(usagelog.FieldRateMultiplier, field.TypeFloat64, value) + _node.RateMultiplier = value + } + if value, ok := _c.mutation.BillingType(); ok { + _spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value) + _node.BillingType = value + } + if value, ok := _c.mutation.Stream(); ok { + _spec.SetField(usagelog.FieldStream, field.TypeBool, value) + _node.Stream = value + } + if value, ok := _c.mutation.DurationMs(); ok { + _spec.SetField(usagelog.FieldDurationMs, field.TypeInt, value) + _node.DurationMs = &value + } + if value, ok := _c.mutation.FirstTokenMs(); ok { + _spec.SetField(usagelog.FieldFirstTokenMs, field.TypeInt, value) + _node.FirstTokenMs = &value + } + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.UserTable, + Columns: []string{usagelog.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.UserID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.APIKeyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.APIKeyTable, + Columns: []string{usagelog.APIKeyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(apikey.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.APIKeyID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.AccountIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.AccountTable, + Columns: []string{usagelog.AccountColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.AccountID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.GroupIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.GroupTable, + Columns: []string{usagelog.GroupColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.GroupID = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.SubscriptionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.SubscriptionTable, + Columns: []string{usagelog.SubscriptionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.SubscriptionID = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.UsageLog.Create(). +// SetUserID(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.UsageLogUpsert) { +// SetUserID(v+v). +// }). +// Exec(ctx) +func (_c *UsageLogCreate) OnConflict(opts ...sql.ConflictOption) *UsageLogUpsertOne { + _c.conflict = opts + return &UsageLogUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.UsageLog.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *UsageLogCreate) OnConflictColumns(columns ...string) *UsageLogUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &UsageLogUpsertOne{ + create: _c, + } +} + +type ( + // UsageLogUpsertOne is the builder for "upsert"-ing + // one UsageLog node. + UsageLogUpsertOne struct { + create *UsageLogCreate + } + + // UsageLogUpsert is the "OnConflict" setter. + UsageLogUpsert struct { + *sql.UpdateSet + } +) + +// SetUserID sets the "user_id" field. +func (u *UsageLogUpsert) SetUserID(v int64) *UsageLogUpsert { + u.Set(usagelog.FieldUserID, v) + return u +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateUserID() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldUserID) + return u +} + +// SetAPIKeyID sets the "api_key_id" field. +func (u *UsageLogUpsert) SetAPIKeyID(v int64) *UsageLogUpsert { + u.Set(usagelog.FieldAPIKeyID, v) + return u +} + +// UpdateAPIKeyID sets the "api_key_id" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateAPIKeyID() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldAPIKeyID) + return u +} + +// SetAccountID sets the "account_id" field. +func (u *UsageLogUpsert) SetAccountID(v int64) *UsageLogUpsert { + u.Set(usagelog.FieldAccountID, v) + return u +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateAccountID() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldAccountID) + return u +} + +// SetRequestID sets the "request_id" field. +func (u *UsageLogUpsert) SetRequestID(v string) *UsageLogUpsert { + u.Set(usagelog.FieldRequestID, v) + return u +} + +// UpdateRequestID sets the "request_id" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateRequestID() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldRequestID) + return u +} + +// SetModel sets the "model" field. +func (u *UsageLogUpsert) SetModel(v string) *UsageLogUpsert { + u.Set(usagelog.FieldModel, v) + return u +} + +// UpdateModel sets the "model" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldModel) + return u +} + +// SetGroupID sets the "group_id" field. +func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert { + u.Set(usagelog.FieldGroupID, v) + return u +} + +// UpdateGroupID sets the "group_id" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateGroupID() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldGroupID) + return u +} + +// ClearGroupID clears the value of the "group_id" field. +func (u *UsageLogUpsert) ClearGroupID() *UsageLogUpsert { + u.SetNull(usagelog.FieldGroupID) + return u +} + +// SetSubscriptionID sets the "subscription_id" field. +func (u *UsageLogUpsert) SetSubscriptionID(v int64) *UsageLogUpsert { + u.Set(usagelog.FieldSubscriptionID, v) + return u +} + +// UpdateSubscriptionID sets the "subscription_id" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateSubscriptionID() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldSubscriptionID) + return u +} + +// ClearSubscriptionID clears the value of the "subscription_id" field. +func (u *UsageLogUpsert) ClearSubscriptionID() *UsageLogUpsert { + u.SetNull(usagelog.FieldSubscriptionID) + return u +} + +// SetInputTokens sets the "input_tokens" field. +func (u *UsageLogUpsert) SetInputTokens(v int) *UsageLogUpsert { + u.Set(usagelog.FieldInputTokens, v) + return u +} + +// UpdateInputTokens sets the "input_tokens" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateInputTokens() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldInputTokens) + return u +} + +// AddInputTokens adds v to the "input_tokens" field. +func (u *UsageLogUpsert) AddInputTokens(v int) *UsageLogUpsert { + u.Add(usagelog.FieldInputTokens, v) + return u +} + +// SetOutputTokens sets the "output_tokens" field. +func (u *UsageLogUpsert) SetOutputTokens(v int) *UsageLogUpsert { + u.Set(usagelog.FieldOutputTokens, v) + return u +} + +// UpdateOutputTokens sets the "output_tokens" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateOutputTokens() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldOutputTokens) + return u +} + +// AddOutputTokens adds v to the "output_tokens" field. +func (u *UsageLogUpsert) AddOutputTokens(v int) *UsageLogUpsert { + u.Add(usagelog.FieldOutputTokens, v) + return u +} + +// SetCacheCreationTokens sets the "cache_creation_tokens" field. +func (u *UsageLogUpsert) SetCacheCreationTokens(v int) *UsageLogUpsert { + u.Set(usagelog.FieldCacheCreationTokens, v) + return u +} + +// UpdateCacheCreationTokens sets the "cache_creation_tokens" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateCacheCreationTokens() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldCacheCreationTokens) + return u +} + +// AddCacheCreationTokens adds v to the "cache_creation_tokens" field. +func (u *UsageLogUpsert) AddCacheCreationTokens(v int) *UsageLogUpsert { + u.Add(usagelog.FieldCacheCreationTokens, v) + return u +} + +// SetCacheReadTokens sets the "cache_read_tokens" field. +func (u *UsageLogUpsert) SetCacheReadTokens(v int) *UsageLogUpsert { + u.Set(usagelog.FieldCacheReadTokens, v) + return u +} + +// UpdateCacheReadTokens sets the "cache_read_tokens" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateCacheReadTokens() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldCacheReadTokens) + return u +} + +// AddCacheReadTokens adds v to the "cache_read_tokens" field. +func (u *UsageLogUpsert) AddCacheReadTokens(v int) *UsageLogUpsert { + u.Add(usagelog.FieldCacheReadTokens, v) + return u +} + +// SetCacheCreation5mTokens sets the "cache_creation_5m_tokens" field. +func (u *UsageLogUpsert) SetCacheCreation5mTokens(v int) *UsageLogUpsert { + u.Set(usagelog.FieldCacheCreation5mTokens, v) + return u +} + +// UpdateCacheCreation5mTokens sets the "cache_creation_5m_tokens" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateCacheCreation5mTokens() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldCacheCreation5mTokens) + return u +} + +// AddCacheCreation5mTokens adds v to the "cache_creation_5m_tokens" field. +func (u *UsageLogUpsert) AddCacheCreation5mTokens(v int) *UsageLogUpsert { + u.Add(usagelog.FieldCacheCreation5mTokens, v) + return u +} + +// SetCacheCreation1hTokens sets the "cache_creation_1h_tokens" field. +func (u *UsageLogUpsert) SetCacheCreation1hTokens(v int) *UsageLogUpsert { + u.Set(usagelog.FieldCacheCreation1hTokens, v) + return u +} + +// UpdateCacheCreation1hTokens sets the "cache_creation_1h_tokens" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateCacheCreation1hTokens() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldCacheCreation1hTokens) + return u +} + +// AddCacheCreation1hTokens adds v to the "cache_creation_1h_tokens" field. +func (u *UsageLogUpsert) AddCacheCreation1hTokens(v int) *UsageLogUpsert { + u.Add(usagelog.FieldCacheCreation1hTokens, v) + return u +} + +// SetInputCost sets the "input_cost" field. +func (u *UsageLogUpsert) SetInputCost(v float64) *UsageLogUpsert { + u.Set(usagelog.FieldInputCost, v) + return u +} + +// UpdateInputCost sets the "input_cost" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateInputCost() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldInputCost) + return u +} + +// AddInputCost adds v to the "input_cost" field. +func (u *UsageLogUpsert) AddInputCost(v float64) *UsageLogUpsert { + u.Add(usagelog.FieldInputCost, v) + return u +} + +// SetOutputCost sets the "output_cost" field. +func (u *UsageLogUpsert) SetOutputCost(v float64) *UsageLogUpsert { + u.Set(usagelog.FieldOutputCost, v) + return u +} + +// UpdateOutputCost sets the "output_cost" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateOutputCost() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldOutputCost) + return u +} + +// AddOutputCost adds v to the "output_cost" field. +func (u *UsageLogUpsert) AddOutputCost(v float64) *UsageLogUpsert { + u.Add(usagelog.FieldOutputCost, v) + return u +} + +// SetCacheCreationCost sets the "cache_creation_cost" field. +func (u *UsageLogUpsert) SetCacheCreationCost(v float64) *UsageLogUpsert { + u.Set(usagelog.FieldCacheCreationCost, v) + return u +} + +// UpdateCacheCreationCost sets the "cache_creation_cost" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateCacheCreationCost() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldCacheCreationCost) + return u +} + +// AddCacheCreationCost adds v to the "cache_creation_cost" field. +func (u *UsageLogUpsert) AddCacheCreationCost(v float64) *UsageLogUpsert { + u.Add(usagelog.FieldCacheCreationCost, v) + return u +} + +// SetCacheReadCost sets the "cache_read_cost" field. +func (u *UsageLogUpsert) SetCacheReadCost(v float64) *UsageLogUpsert { + u.Set(usagelog.FieldCacheReadCost, v) + return u +} + +// UpdateCacheReadCost sets the "cache_read_cost" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateCacheReadCost() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldCacheReadCost) + return u +} + +// AddCacheReadCost adds v to the "cache_read_cost" field. +func (u *UsageLogUpsert) AddCacheReadCost(v float64) *UsageLogUpsert { + u.Add(usagelog.FieldCacheReadCost, v) + return u +} + +// SetTotalCost sets the "total_cost" field. +func (u *UsageLogUpsert) SetTotalCost(v float64) *UsageLogUpsert { + u.Set(usagelog.FieldTotalCost, v) + return u +} + +// UpdateTotalCost sets the "total_cost" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateTotalCost() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldTotalCost) + return u +} + +// AddTotalCost adds v to the "total_cost" field. +func (u *UsageLogUpsert) AddTotalCost(v float64) *UsageLogUpsert { + u.Add(usagelog.FieldTotalCost, v) + return u +} + +// SetActualCost sets the "actual_cost" field. +func (u *UsageLogUpsert) SetActualCost(v float64) *UsageLogUpsert { + u.Set(usagelog.FieldActualCost, v) + return u +} + +// UpdateActualCost sets the "actual_cost" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateActualCost() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldActualCost) + return u +} + +// AddActualCost adds v to the "actual_cost" field. +func (u *UsageLogUpsert) AddActualCost(v float64) *UsageLogUpsert { + u.Add(usagelog.FieldActualCost, v) + return u +} + +// SetRateMultiplier sets the "rate_multiplier" field. +func (u *UsageLogUpsert) SetRateMultiplier(v float64) *UsageLogUpsert { + u.Set(usagelog.FieldRateMultiplier, v) + return u +} + +// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateRateMultiplier() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldRateMultiplier) + return u +} + +// AddRateMultiplier adds v to the "rate_multiplier" field. +func (u *UsageLogUpsert) AddRateMultiplier(v float64) *UsageLogUpsert { + u.Add(usagelog.FieldRateMultiplier, v) + return u +} + +// SetBillingType sets the "billing_type" field. +func (u *UsageLogUpsert) SetBillingType(v int8) *UsageLogUpsert { + u.Set(usagelog.FieldBillingType, v) + return u +} + +// UpdateBillingType sets the "billing_type" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateBillingType() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldBillingType) + return u +} + +// AddBillingType adds v to the "billing_type" field. +func (u *UsageLogUpsert) AddBillingType(v int8) *UsageLogUpsert { + u.Add(usagelog.FieldBillingType, v) + return u +} + +// SetStream sets the "stream" field. +func (u *UsageLogUpsert) SetStream(v bool) *UsageLogUpsert { + u.Set(usagelog.FieldStream, v) + return u +} + +// UpdateStream sets the "stream" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateStream() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldStream) + return u +} + +// SetDurationMs sets the "duration_ms" field. +func (u *UsageLogUpsert) SetDurationMs(v int) *UsageLogUpsert { + u.Set(usagelog.FieldDurationMs, v) + return u +} + +// UpdateDurationMs sets the "duration_ms" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateDurationMs() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldDurationMs) + return u +} + +// AddDurationMs adds v to the "duration_ms" field. +func (u *UsageLogUpsert) AddDurationMs(v int) *UsageLogUpsert { + u.Add(usagelog.FieldDurationMs, v) + return u +} + +// ClearDurationMs clears the value of the "duration_ms" field. +func (u *UsageLogUpsert) ClearDurationMs() *UsageLogUpsert { + u.SetNull(usagelog.FieldDurationMs) + return u +} + +// SetFirstTokenMs sets the "first_token_ms" field. +func (u *UsageLogUpsert) SetFirstTokenMs(v int) *UsageLogUpsert { + u.Set(usagelog.FieldFirstTokenMs, v) + return u +} + +// UpdateFirstTokenMs sets the "first_token_ms" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateFirstTokenMs() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldFirstTokenMs) + return u +} + +// AddFirstTokenMs adds v to the "first_token_ms" field. +func (u *UsageLogUpsert) AddFirstTokenMs(v int) *UsageLogUpsert { + u.Add(usagelog.FieldFirstTokenMs, v) + return u +} + +// ClearFirstTokenMs clears the value of the "first_token_ms" field. +func (u *UsageLogUpsert) ClearFirstTokenMs() *UsageLogUpsert { + u.SetNull(usagelog.FieldFirstTokenMs) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.UsageLog.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *UsageLogUpsertOne) UpdateNewValues() *UsageLogUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(usagelog.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.UsageLog.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *UsageLogUpsertOne) Ignore() *UsageLogUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *UsageLogUpsertOne) DoNothing() *UsageLogUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the UsageLogCreate.OnConflict +// documentation for more info. +func (u *UsageLogUpsertOne) Update(set func(*UsageLogUpsert)) *UsageLogUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&UsageLogUpsert{UpdateSet: update}) + })) + return u +} + +// SetUserID sets the "user_id" field. +func (u *UsageLogUpsertOne) SetUserID(v int64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateUserID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateUserID() + }) +} + +// SetAPIKeyID sets the "api_key_id" field. +func (u *UsageLogUpsertOne) SetAPIKeyID(v int64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetAPIKeyID(v) + }) +} + +// UpdateAPIKeyID sets the "api_key_id" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateAPIKeyID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateAPIKeyID() + }) +} + +// SetAccountID sets the "account_id" field. +func (u *UsageLogUpsertOne) SetAccountID(v int64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetAccountID(v) + }) +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateAccountID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateAccountID() + }) +} + +// SetRequestID sets the "request_id" field. +func (u *UsageLogUpsertOne) SetRequestID(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetRequestID(v) + }) +} + +// UpdateRequestID sets the "request_id" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateRequestID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateRequestID() + }) +} + +// SetModel sets the "model" field. +func (u *UsageLogUpsertOne) SetModel(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetModel(v) + }) +} + +// UpdateModel sets the "model" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateModel() + }) +} + +// SetGroupID sets the "group_id" field. +func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetGroupID(v) + }) +} + +// UpdateGroupID sets the "group_id" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateGroupID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateGroupID() + }) +} + +// ClearGroupID clears the value of the "group_id" field. +func (u *UsageLogUpsertOne) ClearGroupID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearGroupID() + }) +} + +// SetSubscriptionID sets the "subscription_id" field. +func (u *UsageLogUpsertOne) SetSubscriptionID(v int64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetSubscriptionID(v) + }) +} + +// UpdateSubscriptionID sets the "subscription_id" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateSubscriptionID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateSubscriptionID() + }) +} + +// ClearSubscriptionID clears the value of the "subscription_id" field. +func (u *UsageLogUpsertOne) ClearSubscriptionID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearSubscriptionID() + }) +} + +// SetInputTokens sets the "input_tokens" field. +func (u *UsageLogUpsertOne) SetInputTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetInputTokens(v) + }) +} + +// AddInputTokens adds v to the "input_tokens" field. +func (u *UsageLogUpsertOne) AddInputTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddInputTokens(v) + }) +} + +// UpdateInputTokens sets the "input_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateInputTokens() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateInputTokens() + }) +} + +// SetOutputTokens sets the "output_tokens" field. +func (u *UsageLogUpsertOne) SetOutputTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetOutputTokens(v) + }) +} + +// AddOutputTokens adds v to the "output_tokens" field. +func (u *UsageLogUpsertOne) AddOutputTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddOutputTokens(v) + }) +} + +// UpdateOutputTokens sets the "output_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateOutputTokens() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateOutputTokens() + }) +} + +// SetCacheCreationTokens sets the "cache_creation_tokens" field. +func (u *UsageLogUpsertOne) SetCacheCreationTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheCreationTokens(v) + }) +} + +// AddCacheCreationTokens adds v to the "cache_creation_tokens" field. +func (u *UsageLogUpsertOne) AddCacheCreationTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheCreationTokens(v) + }) +} + +// UpdateCacheCreationTokens sets the "cache_creation_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateCacheCreationTokens() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheCreationTokens() + }) +} + +// SetCacheReadTokens sets the "cache_read_tokens" field. +func (u *UsageLogUpsertOne) SetCacheReadTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheReadTokens(v) + }) +} + +// AddCacheReadTokens adds v to the "cache_read_tokens" field. +func (u *UsageLogUpsertOne) AddCacheReadTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheReadTokens(v) + }) +} + +// UpdateCacheReadTokens sets the "cache_read_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateCacheReadTokens() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheReadTokens() + }) +} + +// SetCacheCreation5mTokens sets the "cache_creation_5m_tokens" field. +func (u *UsageLogUpsertOne) SetCacheCreation5mTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheCreation5mTokens(v) + }) +} + +// AddCacheCreation5mTokens adds v to the "cache_creation_5m_tokens" field. +func (u *UsageLogUpsertOne) AddCacheCreation5mTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheCreation5mTokens(v) + }) +} + +// UpdateCacheCreation5mTokens sets the "cache_creation_5m_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateCacheCreation5mTokens() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheCreation5mTokens() + }) +} + +// SetCacheCreation1hTokens sets the "cache_creation_1h_tokens" field. +func (u *UsageLogUpsertOne) SetCacheCreation1hTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheCreation1hTokens(v) + }) +} + +// AddCacheCreation1hTokens adds v to the "cache_creation_1h_tokens" field. +func (u *UsageLogUpsertOne) AddCacheCreation1hTokens(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheCreation1hTokens(v) + }) +} + +// UpdateCacheCreation1hTokens sets the "cache_creation_1h_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateCacheCreation1hTokens() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheCreation1hTokens() + }) +} + +// SetInputCost sets the "input_cost" field. +func (u *UsageLogUpsertOne) SetInputCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetInputCost(v) + }) +} + +// AddInputCost adds v to the "input_cost" field. +func (u *UsageLogUpsertOne) AddInputCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddInputCost(v) + }) +} + +// UpdateInputCost sets the "input_cost" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateInputCost() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateInputCost() + }) +} + +// SetOutputCost sets the "output_cost" field. +func (u *UsageLogUpsertOne) SetOutputCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetOutputCost(v) + }) +} + +// AddOutputCost adds v to the "output_cost" field. +func (u *UsageLogUpsertOne) AddOutputCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddOutputCost(v) + }) +} + +// UpdateOutputCost sets the "output_cost" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateOutputCost() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateOutputCost() + }) +} + +// SetCacheCreationCost sets the "cache_creation_cost" field. +func (u *UsageLogUpsertOne) SetCacheCreationCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheCreationCost(v) + }) +} + +// AddCacheCreationCost adds v to the "cache_creation_cost" field. +func (u *UsageLogUpsertOne) AddCacheCreationCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheCreationCost(v) + }) +} + +// UpdateCacheCreationCost sets the "cache_creation_cost" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateCacheCreationCost() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheCreationCost() + }) +} + +// SetCacheReadCost sets the "cache_read_cost" field. +func (u *UsageLogUpsertOne) SetCacheReadCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheReadCost(v) + }) +} + +// AddCacheReadCost adds v to the "cache_read_cost" field. +func (u *UsageLogUpsertOne) AddCacheReadCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheReadCost(v) + }) +} + +// UpdateCacheReadCost sets the "cache_read_cost" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateCacheReadCost() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheReadCost() + }) +} + +// SetTotalCost sets the "total_cost" field. +func (u *UsageLogUpsertOne) SetTotalCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetTotalCost(v) + }) +} + +// AddTotalCost adds v to the "total_cost" field. +func (u *UsageLogUpsertOne) AddTotalCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddTotalCost(v) + }) +} + +// UpdateTotalCost sets the "total_cost" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateTotalCost() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateTotalCost() + }) +} + +// SetActualCost sets the "actual_cost" field. +func (u *UsageLogUpsertOne) SetActualCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetActualCost(v) + }) +} + +// AddActualCost adds v to the "actual_cost" field. +func (u *UsageLogUpsertOne) AddActualCost(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddActualCost(v) + }) +} + +// UpdateActualCost sets the "actual_cost" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateActualCost() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateActualCost() + }) +} + +// SetRateMultiplier sets the "rate_multiplier" field. +func (u *UsageLogUpsertOne) SetRateMultiplier(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetRateMultiplier(v) + }) +} + +// AddRateMultiplier adds v to the "rate_multiplier" field. +func (u *UsageLogUpsertOne) AddRateMultiplier(v float64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddRateMultiplier(v) + }) +} + +// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateRateMultiplier() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateRateMultiplier() + }) +} + +// SetBillingType sets the "billing_type" field. +func (u *UsageLogUpsertOne) SetBillingType(v int8) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetBillingType(v) + }) +} + +// AddBillingType adds v to the "billing_type" field. +func (u *UsageLogUpsertOne) AddBillingType(v int8) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddBillingType(v) + }) +} + +// UpdateBillingType sets the "billing_type" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateBillingType() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateBillingType() + }) +} + +// SetStream sets the "stream" field. +func (u *UsageLogUpsertOne) SetStream(v bool) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetStream(v) + }) +} + +// UpdateStream sets the "stream" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateStream() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateStream() + }) +} + +// SetDurationMs sets the "duration_ms" field. +func (u *UsageLogUpsertOne) SetDurationMs(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetDurationMs(v) + }) +} + +// AddDurationMs adds v to the "duration_ms" field. +func (u *UsageLogUpsertOne) AddDurationMs(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddDurationMs(v) + }) +} + +// UpdateDurationMs sets the "duration_ms" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateDurationMs() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateDurationMs() + }) +} + +// ClearDurationMs clears the value of the "duration_ms" field. +func (u *UsageLogUpsertOne) ClearDurationMs() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearDurationMs() + }) +} + +// SetFirstTokenMs sets the "first_token_ms" field. +func (u *UsageLogUpsertOne) SetFirstTokenMs(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetFirstTokenMs(v) + }) +} + +// AddFirstTokenMs adds v to the "first_token_ms" field. +func (u *UsageLogUpsertOne) AddFirstTokenMs(v int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddFirstTokenMs(v) + }) +} + +// UpdateFirstTokenMs sets the "first_token_ms" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateFirstTokenMs() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateFirstTokenMs() + }) +} + +// ClearFirstTokenMs clears the value of the "first_token_ms" field. +func (u *UsageLogUpsertOne) ClearFirstTokenMs() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearFirstTokenMs() + }) +} + +// Exec executes the query. +func (u *UsageLogUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for UsageLogCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *UsageLogUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *UsageLogUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *UsageLogUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// UsageLogCreateBulk is the builder for creating many UsageLog entities in bulk. +type UsageLogCreateBulk struct { + config + err error + builders []*UsageLogCreate + conflict []sql.ConflictOption +} + +// Save creates the UsageLog entities in the database. +func (_c *UsageLogCreateBulk) Save(ctx context.Context) ([]*UsageLog, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*UsageLog, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*UsageLogMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *UsageLogCreateBulk) SaveX(ctx context.Context) []*UsageLog { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *UsageLogCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *UsageLogCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.UsageLog.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.UsageLogUpsert) { +// SetUserID(v+v). +// }). +// Exec(ctx) +func (_c *UsageLogCreateBulk) OnConflict(opts ...sql.ConflictOption) *UsageLogUpsertBulk { + _c.conflict = opts + return &UsageLogUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.UsageLog.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *UsageLogCreateBulk) OnConflictColumns(columns ...string) *UsageLogUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &UsageLogUpsertBulk{ + create: _c, + } +} + +// UsageLogUpsertBulk is the builder for "upsert"-ing +// a bulk of UsageLog nodes. +type UsageLogUpsertBulk struct { + create *UsageLogCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.UsageLog.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *UsageLogUpsertBulk) UpdateNewValues() *UsageLogUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(usagelog.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.UsageLog.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *UsageLogUpsertBulk) Ignore() *UsageLogUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *UsageLogUpsertBulk) DoNothing() *UsageLogUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the UsageLogCreateBulk.OnConflict +// documentation for more info. +func (u *UsageLogUpsertBulk) Update(set func(*UsageLogUpsert)) *UsageLogUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&UsageLogUpsert{UpdateSet: update}) + })) + return u +} + +// SetUserID sets the "user_id" field. +func (u *UsageLogUpsertBulk) SetUserID(v int64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateUserID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateUserID() + }) +} + +// SetAPIKeyID sets the "api_key_id" field. +func (u *UsageLogUpsertBulk) SetAPIKeyID(v int64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetAPIKeyID(v) + }) +} + +// UpdateAPIKeyID sets the "api_key_id" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateAPIKeyID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateAPIKeyID() + }) +} + +// SetAccountID sets the "account_id" field. +func (u *UsageLogUpsertBulk) SetAccountID(v int64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetAccountID(v) + }) +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateAccountID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateAccountID() + }) +} + +// SetRequestID sets the "request_id" field. +func (u *UsageLogUpsertBulk) SetRequestID(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetRequestID(v) + }) +} + +// UpdateRequestID sets the "request_id" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateRequestID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateRequestID() + }) +} + +// SetModel sets the "model" field. +func (u *UsageLogUpsertBulk) SetModel(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetModel(v) + }) +} + +// UpdateModel sets the "model" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateModel() + }) +} + +// SetGroupID sets the "group_id" field. +func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetGroupID(v) + }) +} + +// UpdateGroupID sets the "group_id" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateGroupID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateGroupID() + }) +} + +// ClearGroupID clears the value of the "group_id" field. +func (u *UsageLogUpsertBulk) ClearGroupID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearGroupID() + }) +} + +// SetSubscriptionID sets the "subscription_id" field. +func (u *UsageLogUpsertBulk) SetSubscriptionID(v int64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetSubscriptionID(v) + }) +} + +// UpdateSubscriptionID sets the "subscription_id" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateSubscriptionID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateSubscriptionID() + }) +} + +// ClearSubscriptionID clears the value of the "subscription_id" field. +func (u *UsageLogUpsertBulk) ClearSubscriptionID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearSubscriptionID() + }) +} + +// SetInputTokens sets the "input_tokens" field. +func (u *UsageLogUpsertBulk) SetInputTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetInputTokens(v) + }) +} + +// AddInputTokens adds v to the "input_tokens" field. +func (u *UsageLogUpsertBulk) AddInputTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddInputTokens(v) + }) +} + +// UpdateInputTokens sets the "input_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateInputTokens() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateInputTokens() + }) +} + +// SetOutputTokens sets the "output_tokens" field. +func (u *UsageLogUpsertBulk) SetOutputTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetOutputTokens(v) + }) +} + +// AddOutputTokens adds v to the "output_tokens" field. +func (u *UsageLogUpsertBulk) AddOutputTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddOutputTokens(v) + }) +} + +// UpdateOutputTokens sets the "output_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateOutputTokens() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateOutputTokens() + }) +} + +// SetCacheCreationTokens sets the "cache_creation_tokens" field. +func (u *UsageLogUpsertBulk) SetCacheCreationTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheCreationTokens(v) + }) +} + +// AddCacheCreationTokens adds v to the "cache_creation_tokens" field. +func (u *UsageLogUpsertBulk) AddCacheCreationTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheCreationTokens(v) + }) +} + +// UpdateCacheCreationTokens sets the "cache_creation_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateCacheCreationTokens() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheCreationTokens() + }) +} + +// SetCacheReadTokens sets the "cache_read_tokens" field. +func (u *UsageLogUpsertBulk) SetCacheReadTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheReadTokens(v) + }) +} + +// AddCacheReadTokens adds v to the "cache_read_tokens" field. +func (u *UsageLogUpsertBulk) AddCacheReadTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheReadTokens(v) + }) +} + +// UpdateCacheReadTokens sets the "cache_read_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateCacheReadTokens() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheReadTokens() + }) +} + +// SetCacheCreation5mTokens sets the "cache_creation_5m_tokens" field. +func (u *UsageLogUpsertBulk) SetCacheCreation5mTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheCreation5mTokens(v) + }) +} + +// AddCacheCreation5mTokens adds v to the "cache_creation_5m_tokens" field. +func (u *UsageLogUpsertBulk) AddCacheCreation5mTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheCreation5mTokens(v) + }) +} + +// UpdateCacheCreation5mTokens sets the "cache_creation_5m_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateCacheCreation5mTokens() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheCreation5mTokens() + }) +} + +// SetCacheCreation1hTokens sets the "cache_creation_1h_tokens" field. +func (u *UsageLogUpsertBulk) SetCacheCreation1hTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheCreation1hTokens(v) + }) +} + +// AddCacheCreation1hTokens adds v to the "cache_creation_1h_tokens" field. +func (u *UsageLogUpsertBulk) AddCacheCreation1hTokens(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheCreation1hTokens(v) + }) +} + +// UpdateCacheCreation1hTokens sets the "cache_creation_1h_tokens" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateCacheCreation1hTokens() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheCreation1hTokens() + }) +} + +// SetInputCost sets the "input_cost" field. +func (u *UsageLogUpsertBulk) SetInputCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetInputCost(v) + }) +} + +// AddInputCost adds v to the "input_cost" field. +func (u *UsageLogUpsertBulk) AddInputCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddInputCost(v) + }) +} + +// UpdateInputCost sets the "input_cost" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateInputCost() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateInputCost() + }) +} + +// SetOutputCost sets the "output_cost" field. +func (u *UsageLogUpsertBulk) SetOutputCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetOutputCost(v) + }) +} + +// AddOutputCost adds v to the "output_cost" field. +func (u *UsageLogUpsertBulk) AddOutputCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddOutputCost(v) + }) +} + +// UpdateOutputCost sets the "output_cost" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateOutputCost() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateOutputCost() + }) +} + +// SetCacheCreationCost sets the "cache_creation_cost" field. +func (u *UsageLogUpsertBulk) SetCacheCreationCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheCreationCost(v) + }) +} + +// AddCacheCreationCost adds v to the "cache_creation_cost" field. +func (u *UsageLogUpsertBulk) AddCacheCreationCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheCreationCost(v) + }) +} + +// UpdateCacheCreationCost sets the "cache_creation_cost" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateCacheCreationCost() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheCreationCost() + }) +} + +// SetCacheReadCost sets the "cache_read_cost" field. +func (u *UsageLogUpsertBulk) SetCacheReadCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheReadCost(v) + }) +} + +// AddCacheReadCost adds v to the "cache_read_cost" field. +func (u *UsageLogUpsertBulk) AddCacheReadCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddCacheReadCost(v) + }) +} + +// UpdateCacheReadCost sets the "cache_read_cost" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateCacheReadCost() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheReadCost() + }) +} + +// SetTotalCost sets the "total_cost" field. +func (u *UsageLogUpsertBulk) SetTotalCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetTotalCost(v) + }) +} + +// AddTotalCost adds v to the "total_cost" field. +func (u *UsageLogUpsertBulk) AddTotalCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddTotalCost(v) + }) +} + +// UpdateTotalCost sets the "total_cost" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateTotalCost() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateTotalCost() + }) +} + +// SetActualCost sets the "actual_cost" field. +func (u *UsageLogUpsertBulk) SetActualCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetActualCost(v) + }) +} + +// AddActualCost adds v to the "actual_cost" field. +func (u *UsageLogUpsertBulk) AddActualCost(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddActualCost(v) + }) +} + +// UpdateActualCost sets the "actual_cost" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateActualCost() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateActualCost() + }) +} + +// SetRateMultiplier sets the "rate_multiplier" field. +func (u *UsageLogUpsertBulk) SetRateMultiplier(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetRateMultiplier(v) + }) +} + +// AddRateMultiplier adds v to the "rate_multiplier" field. +func (u *UsageLogUpsertBulk) AddRateMultiplier(v float64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddRateMultiplier(v) + }) +} + +// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateRateMultiplier() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateRateMultiplier() + }) +} + +// SetBillingType sets the "billing_type" field. +func (u *UsageLogUpsertBulk) SetBillingType(v int8) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetBillingType(v) + }) +} + +// AddBillingType adds v to the "billing_type" field. +func (u *UsageLogUpsertBulk) AddBillingType(v int8) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddBillingType(v) + }) +} + +// UpdateBillingType sets the "billing_type" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateBillingType() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateBillingType() + }) +} + +// SetStream sets the "stream" field. +func (u *UsageLogUpsertBulk) SetStream(v bool) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetStream(v) + }) +} + +// UpdateStream sets the "stream" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateStream() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateStream() + }) +} + +// SetDurationMs sets the "duration_ms" field. +func (u *UsageLogUpsertBulk) SetDurationMs(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetDurationMs(v) + }) +} + +// AddDurationMs adds v to the "duration_ms" field. +func (u *UsageLogUpsertBulk) AddDurationMs(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddDurationMs(v) + }) +} + +// UpdateDurationMs sets the "duration_ms" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateDurationMs() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateDurationMs() + }) +} + +// ClearDurationMs clears the value of the "duration_ms" field. +func (u *UsageLogUpsertBulk) ClearDurationMs() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearDurationMs() + }) +} + +// SetFirstTokenMs sets the "first_token_ms" field. +func (u *UsageLogUpsertBulk) SetFirstTokenMs(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetFirstTokenMs(v) + }) +} + +// AddFirstTokenMs adds v to the "first_token_ms" field. +func (u *UsageLogUpsertBulk) AddFirstTokenMs(v int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddFirstTokenMs(v) + }) +} + +// UpdateFirstTokenMs sets the "first_token_ms" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateFirstTokenMs() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateFirstTokenMs() + }) +} + +// ClearFirstTokenMs clears the value of the "first_token_ms" field. +func (u *UsageLogUpsertBulk) ClearFirstTokenMs() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearFirstTokenMs() + }) +} + +// Exec executes the query. +func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the UsageLogCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for UsageLogCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *UsageLogUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/usagelog_delete.go b/backend/ent/usagelog_delete.go new file mode 100644 index 00000000..73450fda --- /dev/null +++ b/backend/ent/usagelog_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/usagelog" +) + +// UsageLogDelete is the builder for deleting a UsageLog entity. +type UsageLogDelete struct { + config + hooks []Hook + mutation *UsageLogMutation +} + +// Where appends a list predicates to the UsageLogDelete builder. +func (_d *UsageLogDelete) Where(ps ...predicate.UsageLog) *UsageLogDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *UsageLogDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *UsageLogDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *UsageLogDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(usagelog.Table, sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// UsageLogDeleteOne is the builder for deleting a single UsageLog entity. +type UsageLogDeleteOne struct { + _d *UsageLogDelete +} + +// Where appends a list predicates to the UsageLogDelete builder. +func (_d *UsageLogDeleteOne) Where(ps ...predicate.UsageLog) *UsageLogDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *UsageLogDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{usagelog.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *UsageLogDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/usagelog_query.go b/backend/ent/usagelog_query.go new file mode 100644 index 00000000..8e5013cc --- /dev/null +++ b/backend/ent/usagelog_query.go @@ -0,0 +1,912 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/account" + "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/usagelog" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/ent/usersubscription" +) + +// UsageLogQuery is the builder for querying UsageLog entities. +type UsageLogQuery struct { + config + ctx *QueryContext + order []usagelog.OrderOption + inters []Interceptor + predicates []predicate.UsageLog + withUser *UserQuery + withAPIKey *ApiKeyQuery + withAccount *AccountQuery + withGroup *GroupQuery + withSubscription *UserSubscriptionQuery + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the UsageLogQuery builder. +func (_q *UsageLogQuery) Where(ps ...predicate.UsageLog) *UsageLogQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *UsageLogQuery) Limit(limit int) *UsageLogQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *UsageLogQuery) Offset(offset int) *UsageLogQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *UsageLogQuery) Unique(unique bool) *UsageLogQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *UsageLogQuery) Order(o ...usagelog.OrderOption) *UsageLogQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryUser chains the current query on the "user" edge. +func (_q *UsageLogQuery) QueryUser() *UserQuery { + query := (&UserClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.UserTable, usagelog.UserColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryAPIKey chains the current query on the "api_key" edge. +func (_q *UsageLogQuery) QueryAPIKey() *ApiKeyQuery { + query := (&ApiKeyClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, selector), + sqlgraph.To(apikey.Table, apikey.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.APIKeyTable, usagelog.APIKeyColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryAccount chains the current query on the "account" edge. +func (_q *UsageLogQuery) QueryAccount() *AccountQuery { + query := (&AccountClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, selector), + sqlgraph.To(account.Table, account.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.AccountTable, usagelog.AccountColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryGroup chains the current query on the "group" edge. +func (_q *UsageLogQuery) QueryGroup() *GroupQuery { + query := (&GroupClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, selector), + sqlgraph.To(group.Table, group.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.GroupTable, usagelog.GroupColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QuerySubscription chains the current query on the "subscription" edge. +func (_q *UsageLogQuery) QuerySubscription() *UserSubscriptionQuery { + query := (&UserSubscriptionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(usagelog.Table, usagelog.FieldID, selector), + sqlgraph.To(usersubscription.Table, usersubscription.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, usagelog.SubscriptionTable, usagelog.SubscriptionColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first UsageLog entity from the query. +// Returns a *NotFoundError when no UsageLog was found. +func (_q *UsageLogQuery) First(ctx context.Context) (*UsageLog, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{usagelog.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *UsageLogQuery) FirstX(ctx context.Context) *UsageLog { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first UsageLog ID from the query. +// Returns a *NotFoundError when no UsageLog ID was found. +func (_q *UsageLogQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{usagelog.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *UsageLogQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single UsageLog entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one UsageLog entity is found. +// Returns a *NotFoundError when no UsageLog entities are found. +func (_q *UsageLogQuery) Only(ctx context.Context) (*UsageLog, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{usagelog.Label} + default: + return nil, &NotSingularError{usagelog.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *UsageLogQuery) OnlyX(ctx context.Context) *UsageLog { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only UsageLog ID in the query. +// Returns a *NotSingularError when more than one UsageLog ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *UsageLogQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{usagelog.Label} + default: + err = &NotSingularError{usagelog.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *UsageLogQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of UsageLogs. +func (_q *UsageLogQuery) All(ctx context.Context) ([]*UsageLog, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*UsageLog, *UsageLogQuery]() + return withInterceptors[[]*UsageLog](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *UsageLogQuery) AllX(ctx context.Context) []*UsageLog { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of UsageLog IDs. +func (_q *UsageLogQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(usagelog.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *UsageLogQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *UsageLogQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*UsageLogQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *UsageLogQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *UsageLogQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *UsageLogQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the UsageLogQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *UsageLogQuery) Clone() *UsageLogQuery { + if _q == nil { + return nil + } + return &UsageLogQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]usagelog.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.UsageLog{}, _q.predicates...), + withUser: _q.withUser.Clone(), + withAPIKey: _q.withAPIKey.Clone(), + withAccount: _q.withAccount.Clone(), + withGroup: _q.withGroup.Clone(), + withSubscription: _q.withSubscription.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithUser tells the query-builder to eager-load the nodes that are connected to +// the "user" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UsageLogQuery) WithUser(opts ...func(*UserQuery)) *UsageLogQuery { + query := (&UserClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUser = query + return _q +} + +// WithAPIKey tells the query-builder to eager-load the nodes that are connected to +// the "api_key" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UsageLogQuery) WithAPIKey(opts ...func(*ApiKeyQuery)) *UsageLogQuery { + query := (&ApiKeyClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAPIKey = query + return _q +} + +// WithAccount tells the query-builder to eager-load the nodes that are connected to +// the "account" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UsageLogQuery) WithAccount(opts ...func(*AccountQuery)) *UsageLogQuery { + query := (&AccountClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAccount = query + return _q +} + +// WithGroup tells the query-builder to eager-load the nodes that are connected to +// the "group" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UsageLogQuery) WithGroup(opts ...func(*GroupQuery)) *UsageLogQuery { + query := (&GroupClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withGroup = query + return _q +} + +// WithSubscription tells the query-builder to eager-load the nodes that are connected to +// the "subscription" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UsageLogQuery) WithSubscription(opts ...func(*UserSubscriptionQuery)) *UsageLogQuery { + query := (&UserSubscriptionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withSubscription = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// UserID int64 `json:"user_id,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.UsageLog.Query(). +// GroupBy(usagelog.FieldUserID). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *UsageLogQuery) GroupBy(field string, fields ...string) *UsageLogGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &UsageLogGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = usagelog.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// UserID int64 `json:"user_id,omitempty"` +// } +// +// client.UsageLog.Query(). +// Select(usagelog.FieldUserID). +// Scan(ctx, &v) +func (_q *UsageLogQuery) Select(fields ...string) *UsageLogSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &UsageLogSelect{UsageLogQuery: _q} + sbuild.label = usagelog.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a UsageLogSelect configured with the given aggregations. +func (_q *UsageLogQuery) Aggregate(fns ...AggregateFunc) *UsageLogSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *UsageLogQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !usagelog.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *UsageLogQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UsageLog, error) { + var ( + nodes = []*UsageLog{} + _spec = _q.querySpec() + loadedTypes = [5]bool{ + _q.withUser != nil, + _q.withAPIKey != nil, + _q.withAccount != nil, + _q.withGroup != nil, + _q.withSubscription != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*UsageLog).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &UsageLog{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withUser; query != nil { + if err := _q.loadUser(ctx, query, nodes, nil, + func(n *UsageLog, e *User) { n.Edges.User = e }); err != nil { + return nil, err + } + } + if query := _q.withAPIKey; query != nil { + if err := _q.loadAPIKey(ctx, query, nodes, nil, + func(n *UsageLog, e *ApiKey) { n.Edges.APIKey = e }); err != nil { + return nil, err + } + } + if query := _q.withAccount; query != nil { + if err := _q.loadAccount(ctx, query, nodes, nil, + func(n *UsageLog, e *Account) { n.Edges.Account = e }); err != nil { + return nil, err + } + } + if query := _q.withGroup; query != nil { + if err := _q.loadGroup(ctx, query, nodes, nil, + func(n *UsageLog, e *Group) { n.Edges.Group = e }); err != nil { + return nil, err + } + } + if query := _q.withSubscription; query != nil { + if err := _q.loadSubscription(ctx, query, nodes, nil, + func(n *UsageLog, e *UserSubscription) { n.Edges.Subscription = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *UsageLogQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*UsageLog, init func(*UsageLog), assign func(*UsageLog, *User)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*UsageLog) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *UsageLogQuery) loadAPIKey(ctx context.Context, query *ApiKeyQuery, nodes []*UsageLog, init func(*UsageLog), assign func(*UsageLog, *ApiKey)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*UsageLog) + for i := range nodes { + fk := nodes[i].APIKeyID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(apikey.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "api_key_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *UsageLogQuery) loadAccount(ctx context.Context, query *AccountQuery, nodes []*UsageLog, init func(*UsageLog), assign func(*UsageLog, *Account)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*UsageLog) + for i := range nodes { + fk := nodes[i].AccountID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(account.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "account_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *UsageLogQuery) loadGroup(ctx context.Context, query *GroupQuery, nodes []*UsageLog, init func(*UsageLog), assign func(*UsageLog, *Group)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*UsageLog) + for i := range nodes { + if nodes[i].GroupID == nil { + continue + } + fk := *nodes[i].GroupID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(group.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "group_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *UsageLogQuery) loadSubscription(ctx context.Context, query *UserSubscriptionQuery, nodes []*UsageLog, init func(*UsageLog), assign func(*UsageLog, *UserSubscription)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*UsageLog) + for i := range nodes { + if nodes[i].SubscriptionID == nil { + continue + } + fk := *nodes[i].SubscriptionID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(usersubscription.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "subscription_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *UsageLogQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *UsageLogQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(usagelog.Table, usagelog.Columns, sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, usagelog.FieldID) + for i := range fields { + if fields[i] != usagelog.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withUser != nil { + _spec.Node.AddColumnOnce(usagelog.FieldUserID) + } + if _q.withAPIKey != nil { + _spec.Node.AddColumnOnce(usagelog.FieldAPIKeyID) + } + if _q.withAccount != nil { + _spec.Node.AddColumnOnce(usagelog.FieldAccountID) + } + if _q.withGroup != nil { + _spec.Node.AddColumnOnce(usagelog.FieldGroupID) + } + if _q.withSubscription != nil { + _spec.Node.AddColumnOnce(usagelog.FieldSubscriptionID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *UsageLogQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(usagelog.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = usagelog.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// UsageLogGroupBy is the group-by builder for UsageLog entities. +type UsageLogGroupBy struct { + selector + build *UsageLogQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *UsageLogGroupBy) Aggregate(fns ...AggregateFunc) *UsageLogGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *UsageLogGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*UsageLogQuery, *UsageLogGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *UsageLogGroupBy) sqlScan(ctx context.Context, root *UsageLogQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// UsageLogSelect is the builder for selecting fields of UsageLog entities. +type UsageLogSelect struct { + *UsageLogQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *UsageLogSelect) Aggregate(fns ...AggregateFunc) *UsageLogSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *UsageLogSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*UsageLogQuery, *UsageLogSelect](ctx, _s.UsageLogQuery, _s, _s.inters, v) +} + +func (_s *UsageLogSelect) sqlScan(ctx context.Context, root *UsageLogQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go new file mode 100644 index 00000000..55b8e234 --- /dev/null +++ b/backend/ent/usagelog_update.go @@ -0,0 +1,1800 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/account" + "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/usagelog" + "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/ent/usersubscription" +) + +// UsageLogUpdate is the builder for updating UsageLog entities. +type UsageLogUpdate struct { + config + hooks []Hook + mutation *UsageLogMutation +} + +// Where appends a list predicates to the UsageLogUpdate builder. +func (_u *UsageLogUpdate) Where(ps ...predicate.UsageLog) *UsageLogUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *UsageLogUpdate) SetUserID(v int64) *UsageLogUpdate { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableUserID(v *int64) *UsageLogUpdate { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetAPIKeyID sets the "api_key_id" field. +func (_u *UsageLogUpdate) SetAPIKeyID(v int64) *UsageLogUpdate { + _u.mutation.SetAPIKeyID(v) + return _u +} + +// SetNillableAPIKeyID sets the "api_key_id" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableAPIKeyID(v *int64) *UsageLogUpdate { + if v != nil { + _u.SetAPIKeyID(*v) + } + return _u +} + +// SetAccountID sets the "account_id" field. +func (_u *UsageLogUpdate) SetAccountID(v int64) *UsageLogUpdate { + _u.mutation.SetAccountID(v) + return _u +} + +// SetNillableAccountID sets the "account_id" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableAccountID(v *int64) *UsageLogUpdate { + if v != nil { + _u.SetAccountID(*v) + } + return _u +} + +// SetRequestID sets the "request_id" field. +func (_u *UsageLogUpdate) SetRequestID(v string) *UsageLogUpdate { + _u.mutation.SetRequestID(v) + return _u +} + +// SetNillableRequestID sets the "request_id" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableRequestID(v *string) *UsageLogUpdate { + if v != nil { + _u.SetRequestID(*v) + } + return _u +} + +// SetModel sets the "model" field. +func (_u *UsageLogUpdate) SetModel(v string) *UsageLogUpdate { + _u.mutation.SetModel(v) + return _u +} + +// SetNillableModel sets the "model" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate { + if v != nil { + _u.SetModel(*v) + } + return _u +} + +// SetGroupID sets the "group_id" field. +func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate { + _u.mutation.SetGroupID(v) + return _u +} + +// SetNillableGroupID sets the "group_id" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableGroupID(v *int64) *UsageLogUpdate { + if v != nil { + _u.SetGroupID(*v) + } + return _u +} + +// ClearGroupID clears the value of the "group_id" field. +func (_u *UsageLogUpdate) ClearGroupID() *UsageLogUpdate { + _u.mutation.ClearGroupID() + return _u +} + +// SetSubscriptionID sets the "subscription_id" field. +func (_u *UsageLogUpdate) SetSubscriptionID(v int64) *UsageLogUpdate { + _u.mutation.SetSubscriptionID(v) + return _u +} + +// SetNillableSubscriptionID sets the "subscription_id" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableSubscriptionID(v *int64) *UsageLogUpdate { + if v != nil { + _u.SetSubscriptionID(*v) + } + return _u +} + +// ClearSubscriptionID clears the value of the "subscription_id" field. +func (_u *UsageLogUpdate) ClearSubscriptionID() *UsageLogUpdate { + _u.mutation.ClearSubscriptionID() + return _u +} + +// SetInputTokens sets the "input_tokens" field. +func (_u *UsageLogUpdate) SetInputTokens(v int) *UsageLogUpdate { + _u.mutation.ResetInputTokens() + _u.mutation.SetInputTokens(v) + return _u +} + +// SetNillableInputTokens sets the "input_tokens" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableInputTokens(v *int) *UsageLogUpdate { + if v != nil { + _u.SetInputTokens(*v) + } + return _u +} + +// AddInputTokens adds value to the "input_tokens" field. +func (_u *UsageLogUpdate) AddInputTokens(v int) *UsageLogUpdate { + _u.mutation.AddInputTokens(v) + return _u +} + +// SetOutputTokens sets the "output_tokens" field. +func (_u *UsageLogUpdate) SetOutputTokens(v int) *UsageLogUpdate { + _u.mutation.ResetOutputTokens() + _u.mutation.SetOutputTokens(v) + return _u +} + +// SetNillableOutputTokens sets the "output_tokens" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableOutputTokens(v *int) *UsageLogUpdate { + if v != nil { + _u.SetOutputTokens(*v) + } + return _u +} + +// AddOutputTokens adds value to the "output_tokens" field. +func (_u *UsageLogUpdate) AddOutputTokens(v int) *UsageLogUpdate { + _u.mutation.AddOutputTokens(v) + return _u +} + +// SetCacheCreationTokens sets the "cache_creation_tokens" field. +func (_u *UsageLogUpdate) SetCacheCreationTokens(v int) *UsageLogUpdate { + _u.mutation.ResetCacheCreationTokens() + _u.mutation.SetCacheCreationTokens(v) + return _u +} + +// SetNillableCacheCreationTokens sets the "cache_creation_tokens" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableCacheCreationTokens(v *int) *UsageLogUpdate { + if v != nil { + _u.SetCacheCreationTokens(*v) + } + return _u +} + +// AddCacheCreationTokens adds value to the "cache_creation_tokens" field. +func (_u *UsageLogUpdate) AddCacheCreationTokens(v int) *UsageLogUpdate { + _u.mutation.AddCacheCreationTokens(v) + return _u +} + +// SetCacheReadTokens sets the "cache_read_tokens" field. +func (_u *UsageLogUpdate) SetCacheReadTokens(v int) *UsageLogUpdate { + _u.mutation.ResetCacheReadTokens() + _u.mutation.SetCacheReadTokens(v) + return _u +} + +// SetNillableCacheReadTokens sets the "cache_read_tokens" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableCacheReadTokens(v *int) *UsageLogUpdate { + if v != nil { + _u.SetCacheReadTokens(*v) + } + return _u +} + +// AddCacheReadTokens adds value to the "cache_read_tokens" field. +func (_u *UsageLogUpdate) AddCacheReadTokens(v int) *UsageLogUpdate { + _u.mutation.AddCacheReadTokens(v) + return _u +} + +// SetCacheCreation5mTokens sets the "cache_creation_5m_tokens" field. +func (_u *UsageLogUpdate) SetCacheCreation5mTokens(v int) *UsageLogUpdate { + _u.mutation.ResetCacheCreation5mTokens() + _u.mutation.SetCacheCreation5mTokens(v) + return _u +} + +// SetNillableCacheCreation5mTokens sets the "cache_creation_5m_tokens" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableCacheCreation5mTokens(v *int) *UsageLogUpdate { + if v != nil { + _u.SetCacheCreation5mTokens(*v) + } + return _u +} + +// AddCacheCreation5mTokens adds value to the "cache_creation_5m_tokens" field. +func (_u *UsageLogUpdate) AddCacheCreation5mTokens(v int) *UsageLogUpdate { + _u.mutation.AddCacheCreation5mTokens(v) + return _u +} + +// SetCacheCreation1hTokens sets the "cache_creation_1h_tokens" field. +func (_u *UsageLogUpdate) SetCacheCreation1hTokens(v int) *UsageLogUpdate { + _u.mutation.ResetCacheCreation1hTokens() + _u.mutation.SetCacheCreation1hTokens(v) + return _u +} + +// SetNillableCacheCreation1hTokens sets the "cache_creation_1h_tokens" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableCacheCreation1hTokens(v *int) *UsageLogUpdate { + if v != nil { + _u.SetCacheCreation1hTokens(*v) + } + return _u +} + +// AddCacheCreation1hTokens adds value to the "cache_creation_1h_tokens" field. +func (_u *UsageLogUpdate) AddCacheCreation1hTokens(v int) *UsageLogUpdate { + _u.mutation.AddCacheCreation1hTokens(v) + return _u +} + +// SetInputCost sets the "input_cost" field. +func (_u *UsageLogUpdate) SetInputCost(v float64) *UsageLogUpdate { + _u.mutation.ResetInputCost() + _u.mutation.SetInputCost(v) + return _u +} + +// SetNillableInputCost sets the "input_cost" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableInputCost(v *float64) *UsageLogUpdate { + if v != nil { + _u.SetInputCost(*v) + } + return _u +} + +// AddInputCost adds value to the "input_cost" field. +func (_u *UsageLogUpdate) AddInputCost(v float64) *UsageLogUpdate { + _u.mutation.AddInputCost(v) + return _u +} + +// SetOutputCost sets the "output_cost" field. +func (_u *UsageLogUpdate) SetOutputCost(v float64) *UsageLogUpdate { + _u.mutation.ResetOutputCost() + _u.mutation.SetOutputCost(v) + return _u +} + +// SetNillableOutputCost sets the "output_cost" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableOutputCost(v *float64) *UsageLogUpdate { + if v != nil { + _u.SetOutputCost(*v) + } + return _u +} + +// AddOutputCost adds value to the "output_cost" field. +func (_u *UsageLogUpdate) AddOutputCost(v float64) *UsageLogUpdate { + _u.mutation.AddOutputCost(v) + return _u +} + +// SetCacheCreationCost sets the "cache_creation_cost" field. +func (_u *UsageLogUpdate) SetCacheCreationCost(v float64) *UsageLogUpdate { + _u.mutation.ResetCacheCreationCost() + _u.mutation.SetCacheCreationCost(v) + return _u +} + +// SetNillableCacheCreationCost sets the "cache_creation_cost" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableCacheCreationCost(v *float64) *UsageLogUpdate { + if v != nil { + _u.SetCacheCreationCost(*v) + } + return _u +} + +// AddCacheCreationCost adds value to the "cache_creation_cost" field. +func (_u *UsageLogUpdate) AddCacheCreationCost(v float64) *UsageLogUpdate { + _u.mutation.AddCacheCreationCost(v) + return _u +} + +// SetCacheReadCost sets the "cache_read_cost" field. +func (_u *UsageLogUpdate) SetCacheReadCost(v float64) *UsageLogUpdate { + _u.mutation.ResetCacheReadCost() + _u.mutation.SetCacheReadCost(v) + return _u +} + +// SetNillableCacheReadCost sets the "cache_read_cost" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableCacheReadCost(v *float64) *UsageLogUpdate { + if v != nil { + _u.SetCacheReadCost(*v) + } + return _u +} + +// AddCacheReadCost adds value to the "cache_read_cost" field. +func (_u *UsageLogUpdate) AddCacheReadCost(v float64) *UsageLogUpdate { + _u.mutation.AddCacheReadCost(v) + return _u +} + +// SetTotalCost sets the "total_cost" field. +func (_u *UsageLogUpdate) SetTotalCost(v float64) *UsageLogUpdate { + _u.mutation.ResetTotalCost() + _u.mutation.SetTotalCost(v) + return _u +} + +// SetNillableTotalCost sets the "total_cost" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableTotalCost(v *float64) *UsageLogUpdate { + if v != nil { + _u.SetTotalCost(*v) + } + return _u +} + +// AddTotalCost adds value to the "total_cost" field. +func (_u *UsageLogUpdate) AddTotalCost(v float64) *UsageLogUpdate { + _u.mutation.AddTotalCost(v) + return _u +} + +// SetActualCost sets the "actual_cost" field. +func (_u *UsageLogUpdate) SetActualCost(v float64) *UsageLogUpdate { + _u.mutation.ResetActualCost() + _u.mutation.SetActualCost(v) + return _u +} + +// SetNillableActualCost sets the "actual_cost" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableActualCost(v *float64) *UsageLogUpdate { + if v != nil { + _u.SetActualCost(*v) + } + return _u +} + +// AddActualCost adds value to the "actual_cost" field. +func (_u *UsageLogUpdate) AddActualCost(v float64) *UsageLogUpdate { + _u.mutation.AddActualCost(v) + return _u +} + +// SetRateMultiplier sets the "rate_multiplier" field. +func (_u *UsageLogUpdate) SetRateMultiplier(v float64) *UsageLogUpdate { + _u.mutation.ResetRateMultiplier() + _u.mutation.SetRateMultiplier(v) + return _u +} + +// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableRateMultiplier(v *float64) *UsageLogUpdate { + if v != nil { + _u.SetRateMultiplier(*v) + } + return _u +} + +// AddRateMultiplier adds value to the "rate_multiplier" field. +func (_u *UsageLogUpdate) AddRateMultiplier(v float64) *UsageLogUpdate { + _u.mutation.AddRateMultiplier(v) + return _u +} + +// SetBillingType sets the "billing_type" field. +func (_u *UsageLogUpdate) SetBillingType(v int8) *UsageLogUpdate { + _u.mutation.ResetBillingType() + _u.mutation.SetBillingType(v) + return _u +} + +// SetNillableBillingType sets the "billing_type" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableBillingType(v *int8) *UsageLogUpdate { + if v != nil { + _u.SetBillingType(*v) + } + return _u +} + +// AddBillingType adds value to the "billing_type" field. +func (_u *UsageLogUpdate) AddBillingType(v int8) *UsageLogUpdate { + _u.mutation.AddBillingType(v) + return _u +} + +// SetStream sets the "stream" field. +func (_u *UsageLogUpdate) SetStream(v bool) *UsageLogUpdate { + _u.mutation.SetStream(v) + return _u +} + +// SetNillableStream sets the "stream" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableStream(v *bool) *UsageLogUpdate { + if v != nil { + _u.SetStream(*v) + } + return _u +} + +// SetDurationMs sets the "duration_ms" field. +func (_u *UsageLogUpdate) SetDurationMs(v int) *UsageLogUpdate { + _u.mutation.ResetDurationMs() + _u.mutation.SetDurationMs(v) + return _u +} + +// SetNillableDurationMs sets the "duration_ms" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableDurationMs(v *int) *UsageLogUpdate { + if v != nil { + _u.SetDurationMs(*v) + } + return _u +} + +// AddDurationMs adds value to the "duration_ms" field. +func (_u *UsageLogUpdate) AddDurationMs(v int) *UsageLogUpdate { + _u.mutation.AddDurationMs(v) + return _u +} + +// ClearDurationMs clears the value of the "duration_ms" field. +func (_u *UsageLogUpdate) ClearDurationMs() *UsageLogUpdate { + _u.mutation.ClearDurationMs() + return _u +} + +// SetFirstTokenMs sets the "first_token_ms" field. +func (_u *UsageLogUpdate) SetFirstTokenMs(v int) *UsageLogUpdate { + _u.mutation.ResetFirstTokenMs() + _u.mutation.SetFirstTokenMs(v) + return _u +} + +// SetNillableFirstTokenMs sets the "first_token_ms" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableFirstTokenMs(v *int) *UsageLogUpdate { + if v != nil { + _u.SetFirstTokenMs(*v) + } + return _u +} + +// AddFirstTokenMs adds value to the "first_token_ms" field. +func (_u *UsageLogUpdate) AddFirstTokenMs(v int) *UsageLogUpdate { + _u.mutation.AddFirstTokenMs(v) + return _u +} + +// ClearFirstTokenMs clears the value of the "first_token_ms" field. +func (_u *UsageLogUpdate) ClearFirstTokenMs() *UsageLogUpdate { + _u.mutation.ClearFirstTokenMs() + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate { + return _u.SetUserID(v.ID) +} + +// SetAPIKey sets the "api_key" edge to the ApiKey entity. +func (_u *UsageLogUpdate) SetAPIKey(v *ApiKey) *UsageLogUpdate { + return _u.SetAPIKeyID(v.ID) +} + +// SetAccount sets the "account" edge to the Account entity. +func (_u *UsageLogUpdate) SetAccount(v *Account) *UsageLogUpdate { + return _u.SetAccountID(v.ID) +} + +// SetGroup sets the "group" edge to the Group entity. +func (_u *UsageLogUpdate) SetGroup(v *Group) *UsageLogUpdate { + return _u.SetGroupID(v.ID) +} + +// SetSubscription sets the "subscription" edge to the UserSubscription entity. +func (_u *UsageLogUpdate) SetSubscription(v *UserSubscription) *UsageLogUpdate { + return _u.SetSubscriptionID(v.ID) +} + +// Mutation returns the UsageLogMutation object of the builder. +func (_u *UsageLogUpdate) Mutation() *UsageLogMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *UsageLogUpdate) ClearUser() *UsageLogUpdate { + _u.mutation.ClearUser() + return _u +} + +// ClearAPIKey clears the "api_key" edge to the ApiKey entity. +func (_u *UsageLogUpdate) ClearAPIKey() *UsageLogUpdate { + _u.mutation.ClearAPIKey() + return _u +} + +// ClearAccount clears the "account" edge to the Account entity. +func (_u *UsageLogUpdate) ClearAccount() *UsageLogUpdate { + _u.mutation.ClearAccount() + return _u +} + +// ClearGroup clears the "group" edge to the Group entity. +func (_u *UsageLogUpdate) ClearGroup() *UsageLogUpdate { + _u.mutation.ClearGroup() + return _u +} + +// ClearSubscription clears the "subscription" edge to the UserSubscription entity. +func (_u *UsageLogUpdate) ClearSubscription() *UsageLogUpdate { + _u.mutation.ClearSubscription() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *UsageLogUpdate) Save(ctx context.Context) (int, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *UsageLogUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *UsageLogUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *UsageLogUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *UsageLogUpdate) check() error { + if v, ok := _u.mutation.RequestID(); ok { + if err := usagelog.RequestIDValidator(v); err != nil { + return &ValidationError{Name: "request_id", err: fmt.Errorf(`ent: validator failed for field "UsageLog.request_id": %w`, err)} + } + } + if v, ok := _u.mutation.Model(); ok { + if err := usagelog.ModelValidator(v); err != nil { + return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} + } + } + if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) + } + if _u.mutation.APIKeyCleared() && len(_u.mutation.APIKeyIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "UsageLog.api_key"`) + } + if _u.mutation.AccountCleared() && len(_u.mutation.AccountIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "UsageLog.account"`) + } + return nil +} + +func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(usagelog.Table, usagelog.Columns, sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.RequestID(); ok { + _spec.SetField(usagelog.FieldRequestID, field.TypeString, value) + } + if value, ok := _u.mutation.Model(); ok { + _spec.SetField(usagelog.FieldModel, field.TypeString, value) + } + if value, ok := _u.mutation.InputTokens(); ok { + _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedInputTokens(); ok { + _spec.AddField(usagelog.FieldInputTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.OutputTokens(); ok { + _spec.SetField(usagelog.FieldOutputTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedOutputTokens(); ok { + _spec.AddField(usagelog.FieldOutputTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.CacheCreationTokens(); ok { + _spec.SetField(usagelog.FieldCacheCreationTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedCacheCreationTokens(); ok { + _spec.AddField(usagelog.FieldCacheCreationTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.CacheReadTokens(); ok { + _spec.SetField(usagelog.FieldCacheReadTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedCacheReadTokens(); ok { + _spec.AddField(usagelog.FieldCacheReadTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.CacheCreation5mTokens(); ok { + _spec.SetField(usagelog.FieldCacheCreation5mTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedCacheCreation5mTokens(); ok { + _spec.AddField(usagelog.FieldCacheCreation5mTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.CacheCreation1hTokens(); ok { + _spec.SetField(usagelog.FieldCacheCreation1hTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedCacheCreation1hTokens(); ok { + _spec.AddField(usagelog.FieldCacheCreation1hTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.InputCost(); ok { + _spec.SetField(usagelog.FieldInputCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedInputCost(); ok { + _spec.AddField(usagelog.FieldInputCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.OutputCost(); ok { + _spec.SetField(usagelog.FieldOutputCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedOutputCost(); ok { + _spec.AddField(usagelog.FieldOutputCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.CacheCreationCost(); ok { + _spec.SetField(usagelog.FieldCacheCreationCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedCacheCreationCost(); ok { + _spec.AddField(usagelog.FieldCacheCreationCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.CacheReadCost(); ok { + _spec.SetField(usagelog.FieldCacheReadCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedCacheReadCost(); ok { + _spec.AddField(usagelog.FieldCacheReadCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.TotalCost(); ok { + _spec.SetField(usagelog.FieldTotalCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedTotalCost(); ok { + _spec.AddField(usagelog.FieldTotalCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.ActualCost(); ok { + _spec.SetField(usagelog.FieldActualCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedActualCost(); ok { + _spec.AddField(usagelog.FieldActualCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.RateMultiplier(); ok { + _spec.SetField(usagelog.FieldRateMultiplier, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateMultiplier(); ok { + _spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value) + } + if value, ok := _u.mutation.BillingType(); ok { + _spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value) + } + if value, ok := _u.mutation.AddedBillingType(); ok { + _spec.AddField(usagelog.FieldBillingType, field.TypeInt8, value) + } + if value, ok := _u.mutation.Stream(); ok { + _spec.SetField(usagelog.FieldStream, field.TypeBool, value) + } + if value, ok := _u.mutation.DurationMs(); ok { + _spec.SetField(usagelog.FieldDurationMs, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedDurationMs(); ok { + _spec.AddField(usagelog.FieldDurationMs, field.TypeInt, value) + } + if _u.mutation.DurationMsCleared() { + _spec.ClearField(usagelog.FieldDurationMs, field.TypeInt) + } + if value, ok := _u.mutation.FirstTokenMs(); ok { + _spec.SetField(usagelog.FieldFirstTokenMs, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedFirstTokenMs(); ok { + _spec.AddField(usagelog.FieldFirstTokenMs, field.TypeInt, value) + } + if _u.mutation.FirstTokenMsCleared() { + _spec.ClearField(usagelog.FieldFirstTokenMs, field.TypeInt) + } + if _u.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.UserTable, + Columns: []string{usagelog.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.UserTable, + Columns: []string{usagelog.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.APIKeyCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.APIKeyTable, + Columns: []string{usagelog.APIKeyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(apikey.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.APIKeyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.APIKeyTable, + Columns: []string{usagelog.APIKeyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(apikey.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AccountCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.AccountTable, + Columns: []string{usagelog.AccountColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AccountIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.AccountTable, + Columns: []string{usagelog.AccountColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.GroupCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.GroupTable, + Columns: []string{usagelog.GroupColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.GroupIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.GroupTable, + Columns: []string{usagelog.GroupColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.SubscriptionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.SubscriptionTable, + Columns: []string{usagelog.SubscriptionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.SubscriptionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.SubscriptionTable, + Columns: []string{usagelog.SubscriptionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{usagelog.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// UsageLogUpdateOne is the builder for updating a single UsageLog entity. +type UsageLogUpdateOne struct { + config + fields []string + hooks []Hook + mutation *UsageLogMutation +} + +// SetUserID sets the "user_id" field. +func (_u *UsageLogUpdateOne) SetUserID(v int64) *UsageLogUpdateOne { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableUserID(v *int64) *UsageLogUpdateOne { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetAPIKeyID sets the "api_key_id" field. +func (_u *UsageLogUpdateOne) SetAPIKeyID(v int64) *UsageLogUpdateOne { + _u.mutation.SetAPIKeyID(v) + return _u +} + +// SetNillableAPIKeyID sets the "api_key_id" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableAPIKeyID(v *int64) *UsageLogUpdateOne { + if v != nil { + _u.SetAPIKeyID(*v) + } + return _u +} + +// SetAccountID sets the "account_id" field. +func (_u *UsageLogUpdateOne) SetAccountID(v int64) *UsageLogUpdateOne { + _u.mutation.SetAccountID(v) + return _u +} + +// SetNillableAccountID sets the "account_id" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableAccountID(v *int64) *UsageLogUpdateOne { + if v != nil { + _u.SetAccountID(*v) + } + return _u +} + +// SetRequestID sets the "request_id" field. +func (_u *UsageLogUpdateOne) SetRequestID(v string) *UsageLogUpdateOne { + _u.mutation.SetRequestID(v) + return _u +} + +// SetNillableRequestID sets the "request_id" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableRequestID(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetRequestID(*v) + } + return _u +} + +// SetModel sets the "model" field. +func (_u *UsageLogUpdateOne) SetModel(v string) *UsageLogUpdateOne { + _u.mutation.SetModel(v) + return _u +} + +// SetNillableModel sets the "model" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetModel(*v) + } + return _u +} + +// SetGroupID sets the "group_id" field. +func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne { + _u.mutation.SetGroupID(v) + return _u +} + +// SetNillableGroupID sets the "group_id" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableGroupID(v *int64) *UsageLogUpdateOne { + if v != nil { + _u.SetGroupID(*v) + } + return _u +} + +// ClearGroupID clears the value of the "group_id" field. +func (_u *UsageLogUpdateOne) ClearGroupID() *UsageLogUpdateOne { + _u.mutation.ClearGroupID() + return _u +} + +// SetSubscriptionID sets the "subscription_id" field. +func (_u *UsageLogUpdateOne) SetSubscriptionID(v int64) *UsageLogUpdateOne { + _u.mutation.SetSubscriptionID(v) + return _u +} + +// SetNillableSubscriptionID sets the "subscription_id" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableSubscriptionID(v *int64) *UsageLogUpdateOne { + if v != nil { + _u.SetSubscriptionID(*v) + } + return _u +} + +// ClearSubscriptionID clears the value of the "subscription_id" field. +func (_u *UsageLogUpdateOne) ClearSubscriptionID() *UsageLogUpdateOne { + _u.mutation.ClearSubscriptionID() + return _u +} + +// SetInputTokens sets the "input_tokens" field. +func (_u *UsageLogUpdateOne) SetInputTokens(v int) *UsageLogUpdateOne { + _u.mutation.ResetInputTokens() + _u.mutation.SetInputTokens(v) + return _u +} + +// SetNillableInputTokens sets the "input_tokens" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableInputTokens(v *int) *UsageLogUpdateOne { + if v != nil { + _u.SetInputTokens(*v) + } + return _u +} + +// AddInputTokens adds value to the "input_tokens" field. +func (_u *UsageLogUpdateOne) AddInputTokens(v int) *UsageLogUpdateOne { + _u.mutation.AddInputTokens(v) + return _u +} + +// SetOutputTokens sets the "output_tokens" field. +func (_u *UsageLogUpdateOne) SetOutputTokens(v int) *UsageLogUpdateOne { + _u.mutation.ResetOutputTokens() + _u.mutation.SetOutputTokens(v) + return _u +} + +// SetNillableOutputTokens sets the "output_tokens" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableOutputTokens(v *int) *UsageLogUpdateOne { + if v != nil { + _u.SetOutputTokens(*v) + } + return _u +} + +// AddOutputTokens adds value to the "output_tokens" field. +func (_u *UsageLogUpdateOne) AddOutputTokens(v int) *UsageLogUpdateOne { + _u.mutation.AddOutputTokens(v) + return _u +} + +// SetCacheCreationTokens sets the "cache_creation_tokens" field. +func (_u *UsageLogUpdateOne) SetCacheCreationTokens(v int) *UsageLogUpdateOne { + _u.mutation.ResetCacheCreationTokens() + _u.mutation.SetCacheCreationTokens(v) + return _u +} + +// SetNillableCacheCreationTokens sets the "cache_creation_tokens" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableCacheCreationTokens(v *int) *UsageLogUpdateOne { + if v != nil { + _u.SetCacheCreationTokens(*v) + } + return _u +} + +// AddCacheCreationTokens adds value to the "cache_creation_tokens" field. +func (_u *UsageLogUpdateOne) AddCacheCreationTokens(v int) *UsageLogUpdateOne { + _u.mutation.AddCacheCreationTokens(v) + return _u +} + +// SetCacheReadTokens sets the "cache_read_tokens" field. +func (_u *UsageLogUpdateOne) SetCacheReadTokens(v int) *UsageLogUpdateOne { + _u.mutation.ResetCacheReadTokens() + _u.mutation.SetCacheReadTokens(v) + return _u +} + +// SetNillableCacheReadTokens sets the "cache_read_tokens" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableCacheReadTokens(v *int) *UsageLogUpdateOne { + if v != nil { + _u.SetCacheReadTokens(*v) + } + return _u +} + +// AddCacheReadTokens adds value to the "cache_read_tokens" field. +func (_u *UsageLogUpdateOne) AddCacheReadTokens(v int) *UsageLogUpdateOne { + _u.mutation.AddCacheReadTokens(v) + return _u +} + +// SetCacheCreation5mTokens sets the "cache_creation_5m_tokens" field. +func (_u *UsageLogUpdateOne) SetCacheCreation5mTokens(v int) *UsageLogUpdateOne { + _u.mutation.ResetCacheCreation5mTokens() + _u.mutation.SetCacheCreation5mTokens(v) + return _u +} + +// SetNillableCacheCreation5mTokens sets the "cache_creation_5m_tokens" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableCacheCreation5mTokens(v *int) *UsageLogUpdateOne { + if v != nil { + _u.SetCacheCreation5mTokens(*v) + } + return _u +} + +// AddCacheCreation5mTokens adds value to the "cache_creation_5m_tokens" field. +func (_u *UsageLogUpdateOne) AddCacheCreation5mTokens(v int) *UsageLogUpdateOne { + _u.mutation.AddCacheCreation5mTokens(v) + return _u +} + +// SetCacheCreation1hTokens sets the "cache_creation_1h_tokens" field. +func (_u *UsageLogUpdateOne) SetCacheCreation1hTokens(v int) *UsageLogUpdateOne { + _u.mutation.ResetCacheCreation1hTokens() + _u.mutation.SetCacheCreation1hTokens(v) + return _u +} + +// SetNillableCacheCreation1hTokens sets the "cache_creation_1h_tokens" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableCacheCreation1hTokens(v *int) *UsageLogUpdateOne { + if v != nil { + _u.SetCacheCreation1hTokens(*v) + } + return _u +} + +// AddCacheCreation1hTokens adds value to the "cache_creation_1h_tokens" field. +func (_u *UsageLogUpdateOne) AddCacheCreation1hTokens(v int) *UsageLogUpdateOne { + _u.mutation.AddCacheCreation1hTokens(v) + return _u +} + +// SetInputCost sets the "input_cost" field. +func (_u *UsageLogUpdateOne) SetInputCost(v float64) *UsageLogUpdateOne { + _u.mutation.ResetInputCost() + _u.mutation.SetInputCost(v) + return _u +} + +// SetNillableInputCost sets the "input_cost" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableInputCost(v *float64) *UsageLogUpdateOne { + if v != nil { + _u.SetInputCost(*v) + } + return _u +} + +// AddInputCost adds value to the "input_cost" field. +func (_u *UsageLogUpdateOne) AddInputCost(v float64) *UsageLogUpdateOne { + _u.mutation.AddInputCost(v) + return _u +} + +// SetOutputCost sets the "output_cost" field. +func (_u *UsageLogUpdateOne) SetOutputCost(v float64) *UsageLogUpdateOne { + _u.mutation.ResetOutputCost() + _u.mutation.SetOutputCost(v) + return _u +} + +// SetNillableOutputCost sets the "output_cost" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableOutputCost(v *float64) *UsageLogUpdateOne { + if v != nil { + _u.SetOutputCost(*v) + } + return _u +} + +// AddOutputCost adds value to the "output_cost" field. +func (_u *UsageLogUpdateOne) AddOutputCost(v float64) *UsageLogUpdateOne { + _u.mutation.AddOutputCost(v) + return _u +} + +// SetCacheCreationCost sets the "cache_creation_cost" field. +func (_u *UsageLogUpdateOne) SetCacheCreationCost(v float64) *UsageLogUpdateOne { + _u.mutation.ResetCacheCreationCost() + _u.mutation.SetCacheCreationCost(v) + return _u +} + +// SetNillableCacheCreationCost sets the "cache_creation_cost" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableCacheCreationCost(v *float64) *UsageLogUpdateOne { + if v != nil { + _u.SetCacheCreationCost(*v) + } + return _u +} + +// AddCacheCreationCost adds value to the "cache_creation_cost" field. +func (_u *UsageLogUpdateOne) AddCacheCreationCost(v float64) *UsageLogUpdateOne { + _u.mutation.AddCacheCreationCost(v) + return _u +} + +// SetCacheReadCost sets the "cache_read_cost" field. +func (_u *UsageLogUpdateOne) SetCacheReadCost(v float64) *UsageLogUpdateOne { + _u.mutation.ResetCacheReadCost() + _u.mutation.SetCacheReadCost(v) + return _u +} + +// SetNillableCacheReadCost sets the "cache_read_cost" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableCacheReadCost(v *float64) *UsageLogUpdateOne { + if v != nil { + _u.SetCacheReadCost(*v) + } + return _u +} + +// AddCacheReadCost adds value to the "cache_read_cost" field. +func (_u *UsageLogUpdateOne) AddCacheReadCost(v float64) *UsageLogUpdateOne { + _u.mutation.AddCacheReadCost(v) + return _u +} + +// SetTotalCost sets the "total_cost" field. +func (_u *UsageLogUpdateOne) SetTotalCost(v float64) *UsageLogUpdateOne { + _u.mutation.ResetTotalCost() + _u.mutation.SetTotalCost(v) + return _u +} + +// SetNillableTotalCost sets the "total_cost" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableTotalCost(v *float64) *UsageLogUpdateOne { + if v != nil { + _u.SetTotalCost(*v) + } + return _u +} + +// AddTotalCost adds value to the "total_cost" field. +func (_u *UsageLogUpdateOne) AddTotalCost(v float64) *UsageLogUpdateOne { + _u.mutation.AddTotalCost(v) + return _u +} + +// SetActualCost sets the "actual_cost" field. +func (_u *UsageLogUpdateOne) SetActualCost(v float64) *UsageLogUpdateOne { + _u.mutation.ResetActualCost() + _u.mutation.SetActualCost(v) + return _u +} + +// SetNillableActualCost sets the "actual_cost" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableActualCost(v *float64) *UsageLogUpdateOne { + if v != nil { + _u.SetActualCost(*v) + } + return _u +} + +// AddActualCost adds value to the "actual_cost" field. +func (_u *UsageLogUpdateOne) AddActualCost(v float64) *UsageLogUpdateOne { + _u.mutation.AddActualCost(v) + return _u +} + +// SetRateMultiplier sets the "rate_multiplier" field. +func (_u *UsageLogUpdateOne) SetRateMultiplier(v float64) *UsageLogUpdateOne { + _u.mutation.ResetRateMultiplier() + _u.mutation.SetRateMultiplier(v) + return _u +} + +// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableRateMultiplier(v *float64) *UsageLogUpdateOne { + if v != nil { + _u.SetRateMultiplier(*v) + } + return _u +} + +// AddRateMultiplier adds value to the "rate_multiplier" field. +func (_u *UsageLogUpdateOne) AddRateMultiplier(v float64) *UsageLogUpdateOne { + _u.mutation.AddRateMultiplier(v) + return _u +} + +// SetBillingType sets the "billing_type" field. +func (_u *UsageLogUpdateOne) SetBillingType(v int8) *UsageLogUpdateOne { + _u.mutation.ResetBillingType() + _u.mutation.SetBillingType(v) + return _u +} + +// SetNillableBillingType sets the "billing_type" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableBillingType(v *int8) *UsageLogUpdateOne { + if v != nil { + _u.SetBillingType(*v) + } + return _u +} + +// AddBillingType adds value to the "billing_type" field. +func (_u *UsageLogUpdateOne) AddBillingType(v int8) *UsageLogUpdateOne { + _u.mutation.AddBillingType(v) + return _u +} + +// SetStream sets the "stream" field. +func (_u *UsageLogUpdateOne) SetStream(v bool) *UsageLogUpdateOne { + _u.mutation.SetStream(v) + return _u +} + +// SetNillableStream sets the "stream" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableStream(v *bool) *UsageLogUpdateOne { + if v != nil { + _u.SetStream(*v) + } + return _u +} + +// SetDurationMs sets the "duration_ms" field. +func (_u *UsageLogUpdateOne) SetDurationMs(v int) *UsageLogUpdateOne { + _u.mutation.ResetDurationMs() + _u.mutation.SetDurationMs(v) + return _u +} + +// SetNillableDurationMs sets the "duration_ms" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableDurationMs(v *int) *UsageLogUpdateOne { + if v != nil { + _u.SetDurationMs(*v) + } + return _u +} + +// AddDurationMs adds value to the "duration_ms" field. +func (_u *UsageLogUpdateOne) AddDurationMs(v int) *UsageLogUpdateOne { + _u.mutation.AddDurationMs(v) + return _u +} + +// ClearDurationMs clears the value of the "duration_ms" field. +func (_u *UsageLogUpdateOne) ClearDurationMs() *UsageLogUpdateOne { + _u.mutation.ClearDurationMs() + return _u +} + +// SetFirstTokenMs sets the "first_token_ms" field. +func (_u *UsageLogUpdateOne) SetFirstTokenMs(v int) *UsageLogUpdateOne { + _u.mutation.ResetFirstTokenMs() + _u.mutation.SetFirstTokenMs(v) + return _u +} + +// SetNillableFirstTokenMs sets the "first_token_ms" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableFirstTokenMs(v *int) *UsageLogUpdateOne { + if v != nil { + _u.SetFirstTokenMs(*v) + } + return _u +} + +// AddFirstTokenMs adds value to the "first_token_ms" field. +func (_u *UsageLogUpdateOne) AddFirstTokenMs(v int) *UsageLogUpdateOne { + _u.mutation.AddFirstTokenMs(v) + return _u +} + +// ClearFirstTokenMs clears the value of the "first_token_ms" field. +func (_u *UsageLogUpdateOne) ClearFirstTokenMs() *UsageLogUpdateOne { + _u.mutation.ClearFirstTokenMs() + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne { + return _u.SetUserID(v.ID) +} + +// SetAPIKey sets the "api_key" edge to the ApiKey entity. +func (_u *UsageLogUpdateOne) SetAPIKey(v *ApiKey) *UsageLogUpdateOne { + return _u.SetAPIKeyID(v.ID) +} + +// SetAccount sets the "account" edge to the Account entity. +func (_u *UsageLogUpdateOne) SetAccount(v *Account) *UsageLogUpdateOne { + return _u.SetAccountID(v.ID) +} + +// SetGroup sets the "group" edge to the Group entity. +func (_u *UsageLogUpdateOne) SetGroup(v *Group) *UsageLogUpdateOne { + return _u.SetGroupID(v.ID) +} + +// SetSubscription sets the "subscription" edge to the UserSubscription entity. +func (_u *UsageLogUpdateOne) SetSubscription(v *UserSubscription) *UsageLogUpdateOne { + return _u.SetSubscriptionID(v.ID) +} + +// Mutation returns the UsageLogMutation object of the builder. +func (_u *UsageLogUpdateOne) Mutation() *UsageLogMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *UsageLogUpdateOne) ClearUser() *UsageLogUpdateOne { + _u.mutation.ClearUser() + return _u +} + +// ClearAPIKey clears the "api_key" edge to the ApiKey entity. +func (_u *UsageLogUpdateOne) ClearAPIKey() *UsageLogUpdateOne { + _u.mutation.ClearAPIKey() + return _u +} + +// ClearAccount clears the "account" edge to the Account entity. +func (_u *UsageLogUpdateOne) ClearAccount() *UsageLogUpdateOne { + _u.mutation.ClearAccount() + return _u +} + +// ClearGroup clears the "group" edge to the Group entity. +func (_u *UsageLogUpdateOne) ClearGroup() *UsageLogUpdateOne { + _u.mutation.ClearGroup() + return _u +} + +// ClearSubscription clears the "subscription" edge to the UserSubscription entity. +func (_u *UsageLogUpdateOne) ClearSubscription() *UsageLogUpdateOne { + _u.mutation.ClearSubscription() + return _u +} + +// Where appends a list predicates to the UsageLogUpdate builder. +func (_u *UsageLogUpdateOne) Where(ps ...predicate.UsageLog) *UsageLogUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *UsageLogUpdateOne) Select(field string, fields ...string) *UsageLogUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated UsageLog entity. +func (_u *UsageLogUpdateOne) Save(ctx context.Context) (*UsageLog, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *UsageLogUpdateOne) SaveX(ctx context.Context) *UsageLog { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *UsageLogUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *UsageLogUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *UsageLogUpdateOne) check() error { + if v, ok := _u.mutation.RequestID(); ok { + if err := usagelog.RequestIDValidator(v); err != nil { + return &ValidationError{Name: "request_id", err: fmt.Errorf(`ent: validator failed for field "UsageLog.request_id": %w`, err)} + } + } + if v, ok := _u.mutation.Model(); ok { + if err := usagelog.ModelValidator(v); err != nil { + return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} + } + } + if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) + } + if _u.mutation.APIKeyCleared() && len(_u.mutation.APIKeyIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "UsageLog.api_key"`) + } + if _u.mutation.AccountCleared() && len(_u.mutation.AccountIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "UsageLog.account"`) + } + return nil +} + +func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(usagelog.Table, usagelog.Columns, sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UsageLog.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, usagelog.FieldID) + for _, f := range fields { + if !usagelog.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != usagelog.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.RequestID(); ok { + _spec.SetField(usagelog.FieldRequestID, field.TypeString, value) + } + if value, ok := _u.mutation.Model(); ok { + _spec.SetField(usagelog.FieldModel, field.TypeString, value) + } + if value, ok := _u.mutation.InputTokens(); ok { + _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedInputTokens(); ok { + _spec.AddField(usagelog.FieldInputTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.OutputTokens(); ok { + _spec.SetField(usagelog.FieldOutputTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedOutputTokens(); ok { + _spec.AddField(usagelog.FieldOutputTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.CacheCreationTokens(); ok { + _spec.SetField(usagelog.FieldCacheCreationTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedCacheCreationTokens(); ok { + _spec.AddField(usagelog.FieldCacheCreationTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.CacheReadTokens(); ok { + _spec.SetField(usagelog.FieldCacheReadTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedCacheReadTokens(); ok { + _spec.AddField(usagelog.FieldCacheReadTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.CacheCreation5mTokens(); ok { + _spec.SetField(usagelog.FieldCacheCreation5mTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedCacheCreation5mTokens(); ok { + _spec.AddField(usagelog.FieldCacheCreation5mTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.CacheCreation1hTokens(); ok { + _spec.SetField(usagelog.FieldCacheCreation1hTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedCacheCreation1hTokens(); ok { + _spec.AddField(usagelog.FieldCacheCreation1hTokens, field.TypeInt, value) + } + if value, ok := _u.mutation.InputCost(); ok { + _spec.SetField(usagelog.FieldInputCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedInputCost(); ok { + _spec.AddField(usagelog.FieldInputCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.OutputCost(); ok { + _spec.SetField(usagelog.FieldOutputCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedOutputCost(); ok { + _spec.AddField(usagelog.FieldOutputCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.CacheCreationCost(); ok { + _spec.SetField(usagelog.FieldCacheCreationCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedCacheCreationCost(); ok { + _spec.AddField(usagelog.FieldCacheCreationCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.CacheReadCost(); ok { + _spec.SetField(usagelog.FieldCacheReadCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedCacheReadCost(); ok { + _spec.AddField(usagelog.FieldCacheReadCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.TotalCost(); ok { + _spec.SetField(usagelog.FieldTotalCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedTotalCost(); ok { + _spec.AddField(usagelog.FieldTotalCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.ActualCost(); ok { + _spec.SetField(usagelog.FieldActualCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedActualCost(); ok { + _spec.AddField(usagelog.FieldActualCost, field.TypeFloat64, value) + } + if value, ok := _u.mutation.RateMultiplier(); ok { + _spec.SetField(usagelog.FieldRateMultiplier, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateMultiplier(); ok { + _spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value) + } + if value, ok := _u.mutation.BillingType(); ok { + _spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value) + } + if value, ok := _u.mutation.AddedBillingType(); ok { + _spec.AddField(usagelog.FieldBillingType, field.TypeInt8, value) + } + if value, ok := _u.mutation.Stream(); ok { + _spec.SetField(usagelog.FieldStream, field.TypeBool, value) + } + if value, ok := _u.mutation.DurationMs(); ok { + _spec.SetField(usagelog.FieldDurationMs, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedDurationMs(); ok { + _spec.AddField(usagelog.FieldDurationMs, field.TypeInt, value) + } + if _u.mutation.DurationMsCleared() { + _spec.ClearField(usagelog.FieldDurationMs, field.TypeInt) + } + if value, ok := _u.mutation.FirstTokenMs(); ok { + _spec.SetField(usagelog.FieldFirstTokenMs, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedFirstTokenMs(); ok { + _spec.AddField(usagelog.FieldFirstTokenMs, field.TypeInt, value) + } + if _u.mutation.FirstTokenMsCleared() { + _spec.ClearField(usagelog.FieldFirstTokenMs, field.TypeInt) + } + if _u.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.UserTable, + Columns: []string{usagelog.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.UserTable, + Columns: []string{usagelog.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.APIKeyCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.APIKeyTable, + Columns: []string{usagelog.APIKeyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(apikey.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.APIKeyIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.APIKeyTable, + Columns: []string{usagelog.APIKeyColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(apikey.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AccountCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.AccountTable, + Columns: []string{usagelog.AccountColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AccountIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.AccountTable, + Columns: []string{usagelog.AccountColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(account.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.GroupCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.GroupTable, + Columns: []string{usagelog.GroupColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.GroupIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.GroupTable, + Columns: []string{usagelog.GroupColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.SubscriptionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.SubscriptionTable, + Columns: []string{usagelog.SubscriptionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.SubscriptionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: usagelog.SubscriptionTable, + Columns: []string{usagelog.SubscriptionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usersubscription.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &UsageLog{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{usagelog.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/user.go b/backend/ent/user.go index 1f06eb4e..eda67c84 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -59,11 +59,13 @@ type UserEdges struct { AssignedSubscriptions []*UserSubscription `json:"assigned_subscriptions,omitempty"` // AllowedGroups holds the value of the allowed_groups edge. AllowedGroups []*Group `json:"allowed_groups,omitempty"` + // UsageLogs holds the value of the usage_logs edge. + UsageLogs []*UsageLog `json:"usage_logs,omitempty"` // UserAllowedGroups holds the value of the user_allowed_groups edge. UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [6]bool + loadedTypes [7]bool } // APIKeysOrErr returns the APIKeys value or an error if the edge @@ -111,10 +113,19 @@ func (e UserEdges) AllowedGroupsOrErr() ([]*Group, error) { return nil, &NotLoadedError{edge: "allowed_groups"} } +// UsageLogsOrErr returns the UsageLogs value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) UsageLogsOrErr() ([]*UsageLog, error) { + if e.loadedTypes[5] { + return e.UsageLogs, nil + } + return nil, &NotLoadedError{edge: "usage_logs"} +} + // UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge // was not loaded in eager-loading. func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) { - if e.loadedTypes[5] { + if e.loadedTypes[6] { return e.UserAllowedGroups, nil } return nil, &NotLoadedError{edge: "user_allowed_groups"} @@ -265,6 +276,11 @@ func (_m *User) QueryAllowedGroups() *GroupQuery { return NewUserClient(_m.config).QueryAllowedGroups(_m) } +// QueryUsageLogs queries the "usage_logs" edge of the User entity. +func (_m *User) QueryUsageLogs() *UsageLogQuery { + return NewUserClient(_m.config).QueryUsageLogs(_m) +} + // QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity. func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery { return NewUserClient(_m.config).QueryUserAllowedGroups(_m) diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index e1e6988b..9ad87890 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -49,6 +49,8 @@ const ( EdgeAssignedSubscriptions = "assigned_subscriptions" // EdgeAllowedGroups holds the string denoting the allowed_groups edge name in mutations. EdgeAllowedGroups = "allowed_groups" + // EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations. + EdgeUsageLogs = "usage_logs" // EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations. EdgeUserAllowedGroups = "user_allowed_groups" // Table holds the table name of the user in the database. @@ -86,6 +88,13 @@ const ( // AllowedGroupsInverseTable is the table name for the Group entity. // It exists in this package in order to avoid circular dependency with the "group" package. AllowedGroupsInverseTable = "groups" + // UsageLogsTable is the table that holds the usage_logs relation/edge. + UsageLogsTable = "usage_logs" + // UsageLogsInverseTable is the table name for the UsageLog entity. + // It exists in this package in order to avoid circular dependency with the "usagelog" package. + UsageLogsInverseTable = "usage_logs" + // UsageLogsColumn is the table column denoting the usage_logs relation/edge. + UsageLogsColumn = "user_id" // UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge. UserAllowedGroupsTable = "user_allowed_groups" // UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity. @@ -308,6 +317,20 @@ func ByAllowedGroups(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { } } +// ByUsageLogsCount orders the results by usage_logs count. +func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...) + } +} + +// ByUsageLogs orders the results by usage_logs terms. +func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + // ByUserAllowedGroupsCount orders the results by user_allowed_groups count. func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -356,6 +379,13 @@ func newAllowedGroupsStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.M2M, false, AllowedGroupsTable, AllowedGroupsPrimaryKey...), ) } +func newUsageLogsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UsageLogsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) +} func newUserAllowedGroupsStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index ad434c59..81959cf4 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -895,6 +895,29 @@ func HasAllowedGroupsWith(preds ...predicate.Group) predicate.User { }) } +// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge. +func HasUsageLogs() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates). +func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newUsageLogsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge. func HasUserAllowedGroups() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index 8c9caaa2..51bdc493 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -14,6 +14,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -253,6 +254,21 @@ func (_c *UserCreate) AddAllowedGroups(v ...*Group) *UserCreate { return _c.AddAllowedGroupIDs(ids...) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_c *UserCreate) AddUsageLogIDs(ids ...int64) *UserCreate { + _c.mutation.AddUsageLogIDs(ids...) + return _c +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_c *UserCreate) AddUsageLogs(v ...*UsageLog) *UserCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddUsageLogIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_c *UserCreate) Mutation() *UserMutation { return _c.mutation @@ -559,6 +575,22 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { edge.Target.Fields = specE.Fields _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.UsageLogsTable, + Columns: []string{user.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } diff --git a/backend/ent/user_query.go b/backend/ent/user_query.go index 21159a62..c172dda3 100644 --- a/backend/ent/user_query.go +++ b/backend/ent/user_query.go @@ -16,6 +16,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -33,6 +34,7 @@ type UserQuery struct { withSubscriptions *UserSubscriptionQuery withAssignedSubscriptions *UserSubscriptionQuery withAllowedGroups *GroupQuery + withUsageLogs *UsageLogQuery withUserAllowedGroups *UserAllowedGroupQuery // intermediate query (i.e. traversal path). sql *sql.Selector @@ -180,6 +182,28 @@ func (_q *UserQuery) QueryAllowedGroups() *GroupQuery { return query } +// QueryUsageLogs chains the current query on the "usage_logs" edge. +func (_q *UserQuery) QueryUsageLogs() *UsageLogQuery { + query := (&UsageLogClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.UsageLogsTable, user.UsageLogsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge. func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: _q.config}).Query() @@ -399,6 +423,7 @@ func (_q *UserQuery) Clone() *UserQuery { withSubscriptions: _q.withSubscriptions.Clone(), withAssignedSubscriptions: _q.withAssignedSubscriptions.Clone(), withAllowedGroups: _q.withAllowedGroups.Clone(), + withUsageLogs: _q.withUsageLogs.Clone(), withUserAllowedGroups: _q.withUserAllowedGroups.Clone(), // clone intermediate query. sql: _q.sql.Clone(), @@ -461,6 +486,17 @@ func (_q *UserQuery) WithAllowedGroups(opts ...func(*GroupQuery)) *UserQuery { return _q } +// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to +// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *UserQuery { + query := (&UsageLogClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUsageLogs = query + return _q +} + // WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to // the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge. func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery { @@ -550,12 +586,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e var ( nodes = []*User{} _spec = _q.querySpec() - loadedTypes = [6]bool{ + loadedTypes = [7]bool{ _q.withAPIKeys != nil, _q.withRedeemCodes != nil, _q.withSubscriptions != nil, _q.withAssignedSubscriptions != nil, _q.withAllowedGroups != nil, + _q.withUsageLogs != nil, _q.withUserAllowedGroups != nil, } ) @@ -614,6 +651,13 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e return nil, err } } + if query := _q.withUsageLogs; query != nil { + if err := _q.loadUsageLogs(ctx, query, nodes, + func(n *User) { n.Edges.UsageLogs = []*UsageLog{} }, + func(n *User, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil { + return nil, err + } + } if query := _q.withUserAllowedGroups; query != nil { if err := _q.loadUserAllowedGroups(ctx, query, nodes, func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} }, @@ -811,6 +855,36 @@ func (_q *UserQuery) loadAllowedGroups(ctx context.Context, query *GroupQuery, n } return nil } +func (_q *UserQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*User, init func(*User), assign func(*User, *UsageLog)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(usagelog.FieldUserID) + } + query.Where(predicate.UsageLog(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.UsageLogsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error { fks := make([]driver.Value, 0, len(nodes)) nodeids := make(map[int64]*User) diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index a00f9b8a..31e57a43 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -15,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -273,6 +274,21 @@ func (_u *UserUpdate) AddAllowedGroups(v ...*Group) *UserUpdate { return _u.AddAllowedGroupIDs(ids...) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *UserUpdate) AddUsageLogIDs(ids ...int64) *UserUpdate { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *UserUpdate) AddUsageLogs(v ...*UsageLog) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_u *UserUpdate) Mutation() *UserMutation { return _u.mutation @@ -383,6 +399,27 @@ func (_u *UserUpdate) RemoveAllowedGroups(v ...*Group) *UserUpdate { return _u.RemoveAllowedGroupIDs(ids...) } +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *UserUpdate) ClearUsageLogs() *UserUpdate { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *UserUpdate) RemoveUsageLogIDs(ids ...int64) *UserUpdate { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *UserUpdate) RemoveUsageLogs(v ...*UsageLog) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *UserUpdate) Save(ctx context.Context) (int, error) { if err := _u.defaults(); err != nil { @@ -751,6 +788,51 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { edge.Target.Fields = specE.Fields _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.UsageLogsTable, + Columns: []string{user.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.UsageLogsTable, + Columns: []string{user.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.UsageLogsTable, + Columns: []string{user.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{user.Label} @@ -1012,6 +1094,21 @@ func (_u *UserUpdateOne) AddAllowedGroups(v ...*Group) *UserUpdateOne { return _u.AddAllowedGroupIDs(ids...) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *UserUpdateOne) AddUsageLogIDs(ids ...int64) *UserUpdateOne { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *UserUpdateOne) AddUsageLogs(v ...*UsageLog) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_u *UserUpdateOne) Mutation() *UserMutation { return _u.mutation @@ -1122,6 +1219,27 @@ func (_u *UserUpdateOne) RemoveAllowedGroups(v ...*Group) *UserUpdateOne { return _u.RemoveAllowedGroupIDs(ids...) } +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *UserUpdateOne) ClearUsageLogs() *UserUpdateOne { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *UserUpdateOne) RemoveUsageLogIDs(ids ...int64) *UserUpdateOne { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *UserUpdateOne) RemoveUsageLogs(v ...*UsageLog) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // Where appends a list predicates to the UserUpdate builder. func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne { _u.mutation.Where(ps...) @@ -1520,6 +1638,51 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { edge.Target.Fields = specE.Fields _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.UsageLogsTable, + Columns: []string{user.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.UsageLogsTable, + Columns: []string{user.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.UsageLogsTable, + Columns: []string{user.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &User{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/backend/ent/usersubscription.go b/backend/ent/usersubscription.go index 3cfe9475..01beb2fc 100644 --- a/backend/ent/usersubscription.go +++ b/backend/ent/usersubscription.go @@ -23,6 +23,8 @@ type UserSubscription struct { CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. UpdatedAt time.Time `json:"updated_at,omitempty"` + // DeletedAt holds the value of the "deleted_at" field. + DeletedAt *time.Time `json:"deleted_at,omitempty"` // UserID holds the value of the "user_id" field. UserID int64 `json:"user_id,omitempty"` // GroupID holds the value of the "group_id" field. @@ -65,9 +67,11 @@ type UserSubscriptionEdges struct { Group *Group `json:"group,omitempty"` // AssignedByUser holds the value of the assigned_by_user edge. AssignedByUser *User `json:"assigned_by_user,omitempty"` + // UsageLogs holds the value of the usage_logs edge. + UsageLogs []*UsageLog `json:"usage_logs,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [3]bool + loadedTypes [4]bool } // UserOrErr returns the User value or an error if the edge @@ -103,6 +107,15 @@ func (e UserSubscriptionEdges) AssignedByUserOrErr() (*User, error) { return nil, &NotLoadedError{edge: "assigned_by_user"} } +// UsageLogsOrErr returns the UsageLogs value or an error if the edge +// was not loaded in eager-loading. +func (e UserSubscriptionEdges) UsageLogsOrErr() ([]*UsageLog, error) { + if e.loadedTypes[3] { + return e.UsageLogs, nil + } + return nil, &NotLoadedError{edge: "usage_logs"} +} + // scanValues returns the types for scanning values from sql.Rows. func (*UserSubscription) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) @@ -114,7 +127,7 @@ func (*UserSubscription) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullInt64) case usersubscription.FieldStatus, usersubscription.FieldNotes: values[i] = new(sql.NullString) - case usersubscription.FieldCreatedAt, usersubscription.FieldUpdatedAt, usersubscription.FieldStartsAt, usersubscription.FieldExpiresAt, usersubscription.FieldDailyWindowStart, usersubscription.FieldWeeklyWindowStart, usersubscription.FieldMonthlyWindowStart, usersubscription.FieldAssignedAt: + case usersubscription.FieldCreatedAt, usersubscription.FieldUpdatedAt, usersubscription.FieldDeletedAt, usersubscription.FieldStartsAt, usersubscription.FieldExpiresAt, usersubscription.FieldDailyWindowStart, usersubscription.FieldWeeklyWindowStart, usersubscription.FieldMonthlyWindowStart, usersubscription.FieldAssignedAt: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -149,6 +162,13 @@ func (_m *UserSubscription) assignValues(columns []string, values []any) error { } else if value.Valid { _m.UpdatedAt = value.Time } + case usersubscription.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + _m.DeletedAt = new(time.Time) + *_m.DeletedAt = value.Time + } case usersubscription.FieldUserID: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for field user_id", values[i]) @@ -266,6 +286,11 @@ func (_m *UserSubscription) QueryAssignedByUser() *UserQuery { return NewUserSubscriptionClient(_m.config).QueryAssignedByUser(_m) } +// QueryUsageLogs queries the "usage_logs" edge of the UserSubscription entity. +func (_m *UserSubscription) QueryUsageLogs() *UsageLogQuery { + return NewUserSubscriptionClient(_m.config).QueryUsageLogs(_m) +} + // Update returns a builder for updating this UserSubscription. // Note that you need to call UserSubscription.Unwrap() before calling this method if this UserSubscription // was returned from a transaction, and the transaction was committed or rolled back. @@ -295,6 +320,11 @@ func (_m *UserSubscription) String() string { builder.WriteString("updated_at=") builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) builder.WriteString(", ") + if v := _m.DeletedAt; v != nil { + builder.WriteString("deleted_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") builder.WriteString("user_id=") builder.WriteString(fmt.Sprintf("%v", _m.UserID)) builder.WriteString(", ") diff --git a/backend/ent/usersubscription/usersubscription.go b/backend/ent/usersubscription/usersubscription.go index f4f7fa82..06441646 100644 --- a/backend/ent/usersubscription/usersubscription.go +++ b/backend/ent/usersubscription/usersubscription.go @@ -5,6 +5,7 @@ package usersubscription import ( "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" ) @@ -18,6 +19,8 @@ const ( FieldCreatedAt = "created_at" // FieldUpdatedAt holds the string denoting the updated_at field in the database. FieldUpdatedAt = "updated_at" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" // FieldUserID holds the string denoting the user_id field in the database. FieldUserID = "user_id" // FieldGroupID holds the string denoting the group_id field in the database. @@ -52,6 +55,8 @@ const ( EdgeGroup = "group" // EdgeAssignedByUser holds the string denoting the assigned_by_user edge name in mutations. EdgeAssignedByUser = "assigned_by_user" + // EdgeUsageLogs holds the string denoting the usage_logs edge name in mutations. + EdgeUsageLogs = "usage_logs" // Table holds the table name of the usersubscription in the database. Table = "user_subscriptions" // UserTable is the table that holds the user relation/edge. @@ -75,6 +80,13 @@ const ( AssignedByUserInverseTable = "users" // AssignedByUserColumn is the table column denoting the assigned_by_user relation/edge. AssignedByUserColumn = "assigned_by" + // UsageLogsTable is the table that holds the usage_logs relation/edge. + UsageLogsTable = "usage_logs" + // UsageLogsInverseTable is the table name for the UsageLog entity. + // It exists in this package in order to avoid circular dependency with the "usagelog" package. + UsageLogsInverseTable = "usage_logs" + // UsageLogsColumn is the table column denoting the usage_logs relation/edge. + UsageLogsColumn = "subscription_id" ) // Columns holds all SQL columns for usersubscription fields. @@ -82,6 +94,7 @@ var Columns = []string{ FieldID, FieldCreatedAt, FieldUpdatedAt, + FieldDeletedAt, FieldUserID, FieldGroupID, FieldStartsAt, @@ -108,7 +121,14 @@ func ValidColumn(column string) bool { return false } +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/Wei-Shaw/sub2api/ent/runtime" var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. @@ -147,6 +167,11 @@ func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() } +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + // ByUserID orders the results by the user_id field. func ByUserID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldUserID, opts...).ToFunc() @@ -237,6 +262,20 @@ func ByAssignedByUserField(field string, opts ...sql.OrderTermOption) OrderOptio sqlgraph.OrderByNeighborTerms(s, newAssignedByUserStep(), sql.OrderByField(field, opts...)) } } + +// ByUsageLogsCount orders the results by usage_logs count. +func ByUsageLogsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newUsageLogsStep(), opts...) + } +} + +// ByUsageLogs orders the results by usage_logs terms. +func ByUsageLogs(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUsageLogsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} func newUserStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), @@ -258,3 +297,10 @@ func newAssignedByUserStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.M2O, true, AssignedByUserTable, AssignedByUserColumn), ) } +func newUsageLogsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UsageLogsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) +} diff --git a/backend/ent/usersubscription/where.go b/backend/ent/usersubscription/where.go index f6060d95..250e5ed5 100644 --- a/backend/ent/usersubscription/where.go +++ b/backend/ent/usersubscription/where.go @@ -65,6 +65,11 @@ func UpdatedAt(v time.Time) predicate.UserSubscription { return predicate.UserSubscription(sql.FieldEQ(FieldUpdatedAt, v)) } +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldEQ(FieldDeletedAt, v)) +} + // UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ. func UserID(v int64) predicate.UserSubscription { return predicate.UserSubscription(sql.FieldEQ(FieldUserID, v)) @@ -215,6 +220,56 @@ func UpdatedAtLTE(v time.Time) predicate.UserSubscription { return predicate.UserSubscription(sql.FieldLTE(FieldUpdatedAt, v)) } +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. +func DeletedAtEQ(v time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. +func DeletedAtIn(vs ...time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. +func DeletedAtNotIn(vs ...time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. +func DeletedAtGT(v time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. +func DeletedAtLT(v time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. +func DeletedAtLTE(v time.Time) predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. +func DeletedAtIsNil() predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.UserSubscription { + return predicate.UserSubscription(sql.FieldNotNull(FieldDeletedAt)) +} + // UserIDEQ applies the EQ predicate on the "user_id" field. func UserIDEQ(v int64) predicate.UserSubscription { return predicate.UserSubscription(sql.FieldEQ(FieldUserID, v)) @@ -884,6 +939,29 @@ func HasAssignedByUserWith(preds ...predicate.User) predicate.UserSubscription { }) } +// HasUsageLogs applies the HasEdge predicate on the "usage_logs" edge. +func HasUsageLogs() predicate.UserSubscription { + return predicate.UserSubscription(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, UsageLogsTable, UsageLogsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUsageLogsWith applies the HasEdge predicate on the "usage_logs" edge with a given conditions (other predicates). +func HasUsageLogsWith(preds ...predicate.UsageLog) predicate.UserSubscription { + return predicate.UserSubscription(func(s *sql.Selector) { + step := newUsageLogsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.UserSubscription) predicate.UserSubscription { return predicate.UserSubscription(sql.AndPredicates(predicates...)) diff --git a/backend/ent/usersubscription_create.go b/backend/ent/usersubscription_create.go index 43997f64..dd03115b 100644 --- a/backend/ent/usersubscription_create.go +++ b/backend/ent/usersubscription_create.go @@ -12,6 +12,7 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -52,6 +53,20 @@ func (_c *UserSubscriptionCreate) SetNillableUpdatedAt(v *time.Time) *UserSubscr return _c } +// SetDeletedAt sets the "deleted_at" field. +func (_c *UserSubscriptionCreate) SetDeletedAt(v time.Time) *UserSubscriptionCreate { + _c.mutation.SetDeletedAt(v) + return _c +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (_c *UserSubscriptionCreate) SetNillableDeletedAt(v *time.Time) *UserSubscriptionCreate { + if v != nil { + _c.SetDeletedAt(*v) + } + return _c +} + // SetUserID sets the "user_id" field. func (_c *UserSubscriptionCreate) SetUserID(v int64) *UserSubscriptionCreate { _c.mutation.SetUserID(v) @@ -245,6 +260,21 @@ func (_c *UserSubscriptionCreate) SetAssignedByUser(v *User) *UserSubscriptionCr return _c.SetAssignedByUserID(v.ID) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_c *UserSubscriptionCreate) AddUsageLogIDs(ids ...int64) *UserSubscriptionCreate { + _c.mutation.AddUsageLogIDs(ids...) + return _c +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_c *UserSubscriptionCreate) AddUsageLogs(v ...*UsageLog) *UserSubscriptionCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddUsageLogIDs(ids...) +} + // Mutation returns the UserSubscriptionMutation object of the builder. func (_c *UserSubscriptionCreate) Mutation() *UserSubscriptionMutation { return _c.mutation @@ -252,7 +282,9 @@ func (_c *UserSubscriptionCreate) Mutation() *UserSubscriptionMutation { // Save creates the UserSubscription in the database. func (_c *UserSubscriptionCreate) Save(ctx context.Context) (*UserSubscription, error) { - _c.defaults() + if err := _c.defaults(); err != nil { + return nil, err + } return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) } @@ -279,12 +311,18 @@ func (_c *UserSubscriptionCreate) ExecX(ctx context.Context) { } // defaults sets the default values of the builder before save. -func (_c *UserSubscriptionCreate) defaults() { +func (_c *UserSubscriptionCreate) defaults() error { if _, ok := _c.mutation.CreatedAt(); !ok { + if usersubscription.DefaultCreatedAt == nil { + return fmt.Errorf("ent: uninitialized usersubscription.DefaultCreatedAt (forgotten import ent/runtime?)") + } v := usersubscription.DefaultCreatedAt() _c.mutation.SetCreatedAt(v) } if _, ok := _c.mutation.UpdatedAt(); !ok { + if usersubscription.DefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized usersubscription.DefaultUpdatedAt (forgotten import ent/runtime?)") + } v := usersubscription.DefaultUpdatedAt() _c.mutation.SetUpdatedAt(v) } @@ -305,9 +343,13 @@ func (_c *UserSubscriptionCreate) defaults() { _c.mutation.SetMonthlyUsageUsd(v) } if _, ok := _c.mutation.AssignedAt(); !ok { + if usersubscription.DefaultAssignedAt == nil { + return fmt.Errorf("ent: uninitialized usersubscription.DefaultAssignedAt (forgotten import ent/runtime?)") + } v := usersubscription.DefaultAssignedAt() _c.mutation.SetAssignedAt(v) } + return nil } // check runs all checks and user-defined validators on the builder. @@ -391,6 +433,10 @@ func (_c *UserSubscriptionCreate) createSpec() (*UserSubscription, *sqlgraph.Cre _spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value) _node.UpdatedAt = value } + if value, ok := _c.mutation.DeletedAt(); ok { + _spec.SetField(usersubscription.FieldDeletedAt, field.TypeTime, value) + _node.DeletedAt = &value + } if value, ok := _c.mutation.StartsAt(); ok { _spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value) _node.StartsAt = value @@ -486,6 +532,22 @@ func (_c *UserSubscriptionCreate) createSpec() (*UserSubscription, *sqlgraph.Cre _node.AssignedBy = &nodes[0] _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: usersubscription.UsageLogsTable, + Columns: []string{usersubscription.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } @@ -550,6 +612,24 @@ func (u *UserSubscriptionUpsert) UpdateUpdatedAt() *UserSubscriptionUpsert { return u } +// SetDeletedAt sets the "deleted_at" field. +func (u *UserSubscriptionUpsert) SetDeletedAt(v time.Time) *UserSubscriptionUpsert { + u.Set(usersubscription.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserSubscriptionUpsert) UpdateDeletedAt() *UserSubscriptionUpsert { + u.SetExcluded(usersubscription.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserSubscriptionUpsert) ClearDeletedAt() *UserSubscriptionUpsert { + u.SetNull(usersubscription.FieldDeletedAt) + return u +} + // SetUserID sets the "user_id" field. func (u *UserSubscriptionUpsert) SetUserID(v int64) *UserSubscriptionUpsert { u.Set(usersubscription.FieldUserID, v) @@ -825,6 +905,27 @@ func (u *UserSubscriptionUpsertOne) UpdateUpdatedAt() *UserSubscriptionUpsertOne }) } +// SetDeletedAt sets the "deleted_at" field. +func (u *UserSubscriptionUpsertOne) SetDeletedAt(v time.Time) *UserSubscriptionUpsertOne { + return u.Update(func(s *UserSubscriptionUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserSubscriptionUpsertOne) UpdateDeletedAt() *UserSubscriptionUpsertOne { + return u.Update(func(s *UserSubscriptionUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserSubscriptionUpsertOne) ClearDeletedAt() *UserSubscriptionUpsertOne { + return u.Update(func(s *UserSubscriptionUpsert) { + s.ClearDeletedAt() + }) +} + // SetUserID sets the "user_id" field. func (u *UserSubscriptionUpsertOne) SetUserID(v int64) *UserSubscriptionUpsertOne { return u.Update(func(s *UserSubscriptionUpsert) { @@ -1302,6 +1403,27 @@ func (u *UserSubscriptionUpsertBulk) UpdateUpdatedAt() *UserSubscriptionUpsertBu }) } +// SetDeletedAt sets the "deleted_at" field. +func (u *UserSubscriptionUpsertBulk) SetDeletedAt(v time.Time) *UserSubscriptionUpsertBulk { + return u.Update(func(s *UserSubscriptionUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserSubscriptionUpsertBulk) UpdateDeletedAt() *UserSubscriptionUpsertBulk { + return u.Update(func(s *UserSubscriptionUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserSubscriptionUpsertBulk) ClearDeletedAt() *UserSubscriptionUpsertBulk { + return u.Update(func(s *UserSubscriptionUpsert) { + s.ClearDeletedAt() + }) +} + // SetUserID sets the "user_id" field. func (u *UserSubscriptionUpsertBulk) SetUserID(v int64) *UserSubscriptionUpsertBulk { return u.Update(func(s *UserSubscriptionUpsert) { diff --git a/backend/ent/usersubscription_query.go b/backend/ent/usersubscription_query.go index 034f29b4..967fbddb 100644 --- a/backend/ent/usersubscription_query.go +++ b/backend/ent/usersubscription_query.go @@ -4,6 +4,7 @@ package ent import ( "context" + "database/sql/driver" "fmt" "math" @@ -13,6 +14,7 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -27,6 +29,7 @@ type UserSubscriptionQuery struct { withUser *UserQuery withGroup *GroupQuery withAssignedByUser *UserQuery + withUsageLogs *UsageLogQuery // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -129,6 +132,28 @@ func (_q *UserSubscriptionQuery) QueryAssignedByUser() *UserQuery { return query } +// QueryUsageLogs chains the current query on the "usage_logs" edge. +func (_q *UserSubscriptionQuery) QueryUsageLogs() *UsageLogQuery { + query := (&UsageLogClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(usersubscription.Table, usersubscription.FieldID, selector), + sqlgraph.To(usagelog.Table, usagelog.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, usersubscription.UsageLogsTable, usersubscription.UsageLogsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // First returns the first UserSubscription entity from the query. // Returns a *NotFoundError when no UserSubscription was found. func (_q *UserSubscriptionQuery) First(ctx context.Context) (*UserSubscription, error) { @@ -324,6 +349,7 @@ func (_q *UserSubscriptionQuery) Clone() *UserSubscriptionQuery { withUser: _q.withUser.Clone(), withGroup: _q.withGroup.Clone(), withAssignedByUser: _q.withAssignedByUser.Clone(), + withUsageLogs: _q.withUsageLogs.Clone(), // clone intermediate query. sql: _q.sql.Clone(), path: _q.path, @@ -363,6 +389,17 @@ func (_q *UserSubscriptionQuery) WithAssignedByUser(opts ...func(*UserQuery)) *U return _q } +// WithUsageLogs tells the query-builder to eager-load the nodes that are connected to +// the "usage_logs" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserSubscriptionQuery) WithUsageLogs(opts ...func(*UsageLogQuery)) *UserSubscriptionQuery { + query := (&UsageLogClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUsageLogs = query + return _q +} + // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. // @@ -441,10 +478,11 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook) var ( nodes = []*UserSubscription{} _spec = _q.querySpec() - loadedTypes = [3]bool{ + loadedTypes = [4]bool{ _q.withUser != nil, _q.withGroup != nil, _q.withAssignedByUser != nil, + _q.withUsageLogs != nil, } ) _spec.ScanValues = func(columns []string) ([]any, error) { @@ -483,6 +521,13 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook) return nil, err } } + if query := _q.withUsageLogs; query != nil { + if err := _q.loadUsageLogs(ctx, query, nodes, + func(n *UserSubscription) { n.Edges.UsageLogs = []*UsageLog{} }, + func(n *UserSubscription, e *UsageLog) { n.Edges.UsageLogs = append(n.Edges.UsageLogs, e) }); err != nil { + return nil, err + } + } return nodes, nil } @@ -576,6 +621,39 @@ func (_q *UserSubscriptionQuery) loadAssignedByUser(ctx context.Context, query * } return nil } +func (_q *UserSubscriptionQuery) loadUsageLogs(ctx context.Context, query *UsageLogQuery, nodes []*UserSubscription, init func(*UserSubscription), assign func(*UserSubscription, *UsageLog)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*UserSubscription) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(usagelog.FieldSubscriptionID) + } + query.Where(predicate.UsageLog(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(usersubscription.UsageLogsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.SubscriptionID + if fk == nil { + return fmt.Errorf(`foreign-key "subscription_id" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "subscription_id" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} func (_q *UserSubscriptionQuery) sqlCount(ctx context.Context) (int, error) { _spec := _q.querySpec() diff --git a/backend/ent/usersubscription_update.go b/backend/ent/usersubscription_update.go index c0df17ff..811dae7e 100644 --- a/backend/ent/usersubscription_update.go +++ b/backend/ent/usersubscription_update.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" ) @@ -36,6 +37,26 @@ func (_u *UserSubscriptionUpdate) SetUpdatedAt(v time.Time) *UserSubscriptionUpd return _u } +// SetDeletedAt sets the "deleted_at" field. +func (_u *UserSubscriptionUpdate) SetDeletedAt(v time.Time) *UserSubscriptionUpdate { + _u.mutation.SetDeletedAt(v) + return _u +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (_u *UserSubscriptionUpdate) SetNillableDeletedAt(v *time.Time) *UserSubscriptionUpdate { + if v != nil { + _u.SetDeletedAt(*v) + } + return _u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (_u *UserSubscriptionUpdate) ClearDeletedAt() *UserSubscriptionUpdate { + _u.mutation.ClearDeletedAt() + return _u +} + // SetUserID sets the "user_id" field. func (_u *UserSubscriptionUpdate) SetUserID(v int64) *UserSubscriptionUpdate { _u.mutation.SetUserID(v) @@ -312,6 +333,21 @@ func (_u *UserSubscriptionUpdate) SetAssignedByUser(v *User) *UserSubscriptionUp return _u.SetAssignedByUserID(v.ID) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *UserSubscriptionUpdate) AddUsageLogIDs(ids ...int64) *UserSubscriptionUpdate { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *UserSubscriptionUpdate) AddUsageLogs(v ...*UsageLog) *UserSubscriptionUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // Mutation returns the UserSubscriptionMutation object of the builder. func (_u *UserSubscriptionUpdate) Mutation() *UserSubscriptionMutation { return _u.mutation @@ -335,9 +371,32 @@ func (_u *UserSubscriptionUpdate) ClearAssignedByUser() *UserSubscriptionUpdate return _u } +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *UserSubscriptionUpdate) ClearUsageLogs() *UserSubscriptionUpdate { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *UserSubscriptionUpdate) RemoveUsageLogIDs(ids ...int64) *UserSubscriptionUpdate { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *UserSubscriptionUpdate) RemoveUsageLogs(v ...*UsageLog) *UserSubscriptionUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *UserSubscriptionUpdate) Save(ctx context.Context) (int, error) { - _u.defaults() + if err := _u.defaults(); err != nil { + return 0, err + } return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) } @@ -364,11 +423,15 @@ func (_u *UserSubscriptionUpdate) ExecX(ctx context.Context) { } // defaults sets the default values of the builder before save. -func (_u *UserSubscriptionUpdate) defaults() { +func (_u *UserSubscriptionUpdate) defaults() error { if _, ok := _u.mutation.UpdatedAt(); !ok { + if usersubscription.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized usersubscription.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } v := usersubscription.UpdateDefaultUpdatedAt() _u.mutation.SetUpdatedAt(v) } + return nil } // check runs all checks and user-defined validators on the builder. @@ -402,6 +465,12 @@ func (_u *UserSubscriptionUpdate) sqlSave(ctx context.Context) (_node int, err e if value, ok := _u.mutation.UpdatedAt(); ok { _spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value) } + if value, ok := _u.mutation.DeletedAt(); ok { + _spec.SetField(usersubscription.FieldDeletedAt, field.TypeTime, value) + } + if _u.mutation.DeletedAtCleared() { + _spec.ClearField(usersubscription.FieldDeletedAt, field.TypeTime) + } if value, ok := _u.mutation.StartsAt(); ok { _spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value) } @@ -543,6 +612,51 @@ func (_u *UserSubscriptionUpdate) sqlSave(ctx context.Context) (_node int, err e } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: usersubscription.UsageLogsTable, + Columns: []string{usersubscription.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: usersubscription.UsageLogsTable, + Columns: []string{usersubscription.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: usersubscription.UsageLogsTable, + Columns: []string{usersubscription.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{usersubscription.Label} @@ -569,6 +683,26 @@ func (_u *UserSubscriptionUpdateOne) SetUpdatedAt(v time.Time) *UserSubscription return _u } +// SetDeletedAt sets the "deleted_at" field. +func (_u *UserSubscriptionUpdateOne) SetDeletedAt(v time.Time) *UserSubscriptionUpdateOne { + _u.mutation.SetDeletedAt(v) + return _u +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (_u *UserSubscriptionUpdateOne) SetNillableDeletedAt(v *time.Time) *UserSubscriptionUpdateOne { + if v != nil { + _u.SetDeletedAt(*v) + } + return _u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (_u *UserSubscriptionUpdateOne) ClearDeletedAt() *UserSubscriptionUpdateOne { + _u.mutation.ClearDeletedAt() + return _u +} + // SetUserID sets the "user_id" field. func (_u *UserSubscriptionUpdateOne) SetUserID(v int64) *UserSubscriptionUpdateOne { _u.mutation.SetUserID(v) @@ -845,6 +979,21 @@ func (_u *UserSubscriptionUpdateOne) SetAssignedByUser(v *User) *UserSubscriptio return _u.SetAssignedByUserID(v.ID) } +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by IDs. +func (_u *UserSubscriptionUpdateOne) AddUsageLogIDs(ids ...int64) *UserSubscriptionUpdateOne { + _u.mutation.AddUsageLogIDs(ids...) + return _u +} + +// AddUsageLogs adds the "usage_logs" edges to the UsageLog entity. +func (_u *UserSubscriptionUpdateOne) AddUsageLogs(v ...*UsageLog) *UserSubscriptionUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddUsageLogIDs(ids...) +} + // Mutation returns the UserSubscriptionMutation object of the builder. func (_u *UserSubscriptionUpdateOne) Mutation() *UserSubscriptionMutation { return _u.mutation @@ -868,6 +1017,27 @@ func (_u *UserSubscriptionUpdateOne) ClearAssignedByUser() *UserSubscriptionUpda return _u } +// ClearUsageLogs clears all "usage_logs" edges to the UsageLog entity. +func (_u *UserSubscriptionUpdateOne) ClearUsageLogs() *UserSubscriptionUpdateOne { + _u.mutation.ClearUsageLogs() + return _u +} + +// RemoveUsageLogIDs removes the "usage_logs" edge to UsageLog entities by IDs. +func (_u *UserSubscriptionUpdateOne) RemoveUsageLogIDs(ids ...int64) *UserSubscriptionUpdateOne { + _u.mutation.RemoveUsageLogIDs(ids...) + return _u +} + +// RemoveUsageLogs removes "usage_logs" edges to UsageLog entities. +func (_u *UserSubscriptionUpdateOne) RemoveUsageLogs(v ...*UsageLog) *UserSubscriptionUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveUsageLogIDs(ids...) +} + // Where appends a list predicates to the UserSubscriptionUpdate builder. func (_u *UserSubscriptionUpdateOne) Where(ps ...predicate.UserSubscription) *UserSubscriptionUpdateOne { _u.mutation.Where(ps...) @@ -883,7 +1053,9 @@ func (_u *UserSubscriptionUpdateOne) Select(field string, fields ...string) *Use // Save executes the query and returns the updated UserSubscription entity. func (_u *UserSubscriptionUpdateOne) Save(ctx context.Context) (*UserSubscription, error) { - _u.defaults() + if err := _u.defaults(); err != nil { + return nil, err + } return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) } @@ -910,11 +1082,15 @@ func (_u *UserSubscriptionUpdateOne) ExecX(ctx context.Context) { } // defaults sets the default values of the builder before save. -func (_u *UserSubscriptionUpdateOne) defaults() { +func (_u *UserSubscriptionUpdateOne) defaults() error { if _, ok := _u.mutation.UpdatedAt(); !ok { + if usersubscription.UpdateDefaultUpdatedAt == nil { + return fmt.Errorf("ent: uninitialized usersubscription.UpdateDefaultUpdatedAt (forgotten import ent/runtime?)") + } v := usersubscription.UpdateDefaultUpdatedAt() _u.mutation.SetUpdatedAt(v) } + return nil } // check runs all checks and user-defined validators on the builder. @@ -965,6 +1141,12 @@ func (_u *UserSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *UserSu if value, ok := _u.mutation.UpdatedAt(); ok { _spec.SetField(usersubscription.FieldUpdatedAt, field.TypeTime, value) } + if value, ok := _u.mutation.DeletedAt(); ok { + _spec.SetField(usersubscription.FieldDeletedAt, field.TypeTime, value) + } + if _u.mutation.DeletedAtCleared() { + _spec.ClearField(usersubscription.FieldDeletedAt, field.TypeTime) + } if value, ok := _u.mutation.StartsAt(); ok { _spec.SetField(usersubscription.FieldStartsAt, field.TypeTime, value) } @@ -1106,6 +1288,51 @@ func (_u *UserSubscriptionUpdateOne) sqlSave(ctx context.Context) (_node *UserSu } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: usersubscription.UsageLogsTable, + Columns: []string{usersubscription.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedUsageLogsIDs(); len(nodes) > 0 && !_u.mutation.UsageLogsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: usersubscription.UsageLogsTable, + Columns: []string{usersubscription.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UsageLogsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: usersubscription.UsageLogsTable, + Columns: []string{usersubscription.UsageLogsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(usagelog.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &UserSubscription{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 19dff447..23e85e9a 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -14,6 +14,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "strconv" "time" @@ -56,7 +57,7 @@ func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accoun func (r *accountRepository) Create(ctx context.Context, account *service.Account) error { if account == nil { - return nil + return service.ErrAccountNilInput } builder := r.client.Account.Create(). @@ -98,7 +99,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account created, err := builder.Save(ctx) if err != nil { - return err + return translatePersistenceError(err, service.ErrAccountNotFound, nil) } account.ID = created.ID @@ -231,11 +232,32 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account } func (r *accountRepository) Delete(ctx context.Context, id int64) error { - if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil { + // 使用事务保证账号与关联分组的删除原子性 + tx, err := r.client.Tx(ctx) + if err != nil && !errors.Is(err, dbent.ErrTxStarted) { return err } - _, err := r.client.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx) - return err + + var txClient *dbent.Client + if err == nil { + defer func() { _ = tx.Rollback() }() + txClient = tx.Client() + } else { + // 已处于外部事务中(ErrTxStarted),复用当前 client + txClient = r.client + } + + if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil { + return err + } + if _, err := txClient.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx); err != nil { + return err + } + + if tx != nil { + return tx.Commit() + } + return nil } func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { @@ -393,25 +415,49 @@ func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]s } func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { - if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx); err != nil { + // 使用事务保证删除旧绑定与创建新绑定的原子性 + tx, err := r.client.Tx(ctx) + if err != nil && !errors.Is(err, dbent.ErrTxStarted) { + return err + } + + var txClient *dbent.Client + if err == nil { + defer func() { _ = tx.Rollback() }() + txClient = tx.Client() + } else { + // 已处于外部事务中(ErrTxStarted),复用当前 client + txClient = r.client + } + + if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx); err != nil { return err } if len(groupIDs) == 0 { + if tx != nil { + return tx.Commit() + } return nil } builders := make([]*dbent.AccountGroupCreate, 0, len(groupIDs)) for i, groupID := range groupIDs { - builders = append(builders, r.client.AccountGroup.Create(). + builders = append(builders, txClient.AccountGroup.Create(). SetAccountID(accountID). SetGroupID(groupID). SetPriority(i+1), ) } - _, err := r.client.AccountGroup.CreateBulk(builders...).Save(ctx) - return err + if _, err := txClient.AccountGroup.CreateBulk(builders...).Save(ctx); err != nil { + return err + } + + if tx != nil { + return tx.Commit() + } + return nil } func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) { @@ -555,24 +601,30 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m return nil } - accountExtra, err := r.client.Account.Query(). - Where(dbaccount.IDEQ(id)). - Select(dbaccount.FieldExtra). - Only(ctx) + // 使用 JSONB 合并操作实现原子更新,避免读-改-写的并发丢失更新问题 + payload, err := json.Marshal(updates) if err != nil { - return translatePersistenceError(err, service.ErrAccountNotFound, nil) + return err } - extra := normalizeJSONMap(accountExtra.Extra) - for k, v := range updates { - extra[k] = v + client := clientFromContext(ctx, r.client) + result, err := client.ExecContext( + ctx, + "UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) || $1::jsonb, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL", + payload, id, + ) + if err != nil { + return err } - _, err = r.client.Account.Update(). - Where(dbaccount.IDEQ(id)). - SetExtra(extra). - Save(ctx) - return err + affected, err := result.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return service.ErrAccountNotFound + } + return nil } func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 269f0661..a3a52333 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -318,12 +318,13 @@ func groupEntityToService(g *dbent.Group) *service.Group { RateMultiplier: g.RateMultiplier, IsExclusive: g.IsExclusive, Status: g.Status, - SubscriptionType: g.SubscriptionType, - DailyLimitUSD: g.DailyLimitUsd, - WeeklyLimitUSD: g.WeeklyLimitUsd, - MonthlyLimitUSD: g.MonthlyLimitUsd, - CreatedAt: g.CreatedAt, - UpdatedAt: g.UpdatedAt, + SubscriptionType: g.SubscriptionType, + DailyLimitUSD: g.DailyLimitUsd, + WeeklyLimitUSD: g.WeeklyLimitUsd, + MonthlyLimitUSD: g.MonthlyLimitUsd, + DefaultValidityDays: g.DefaultValidityDays, + CreatedAt: g.CreatedAt, + UpdatedAt: g.UpdatedAt, } } diff --git a/backend/internal/repository/error_translate.go b/backend/internal/repository/error_translate.go index 68348830..192f9261 100644 --- a/backend/internal/repository/error_translate.go +++ b/backend/internal/repository/error_translate.go @@ -1,6 +1,7 @@ package repository import ( + "context" "database/sql" "errors" "strings" @@ -10,6 +11,25 @@ import ( "github.com/lib/pq" ) +// clientFromContext 从 context 中获取事务 client,如果不存在则返回默认 client。 +// +// 这个辅助函数支持 repository 方法在事务上下文中工作: +// - 如果 context 中存在事务(通过 ent.NewTxContext 设置),返回事务的 client +// - 否则返回传入的默认 client +// +// 使用示例: +// +// func (r *someRepo) SomeMethod(ctx context.Context) error { +// client := clientFromContext(ctx, r.client) +// return client.SomeEntity.Create().Save(ctx) +// } +func clientFromContext(ctx context.Context, defaultClient *dbent.Client) *dbent.Client { + if tx := dbent.TxFromContext(ctx); tx != nil { + return tx.Client() + } + return defaultClient +} + // translatePersistenceError 将数据库层错误翻译为业务层错误。 // // 这是 Repository 层的核心错误处理函数,确保数据库细节不会泄露到业务层。 diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 5670a69b..53085247 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -42,7 +42,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetSubscriptionType(groupIn.SubscriptionType). SetNillableDailyLimitUsd(groupIn.DailyLimitUSD). SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD). - SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD) + SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD). + SetDefaultValidityDays(groupIn.DefaultValidityDays) created, err := builder.Save(ctx) if err == nil { @@ -79,6 +80,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableDailyLimitUsd(groupIn.DailyLimitUSD). SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD). SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD). + SetDefaultValidityDays(groupIn.DefaultValidityDays). Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists) @@ -89,7 +91,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er func (r *groupRepository) Delete(ctx context.Context, id int64) error { _, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx) - return err + return translatePersistenceError(err, service.ErrGroupNotFound, nil) } func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) { @@ -239,8 +241,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, // err 为 dbent.ErrTxStarted 时,复用当前 client 参与同一事务。 // Lock the group row to avoid concurrent writes while we cascade. - // 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分“未找到”与其他错误。 - rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id) + // 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分"未找到"与其他错误。 + rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 AND deleted_at IS NULL FOR UPDATE", id) if err != nil { return nil, err } @@ -263,7 +265,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, var affectedUserIDs []int64 if groupSvc.IsSubscriptionType() { - rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1", id) + // 只查询未软删除的订阅,避免通知已取消订阅的用户 + rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1 AND deleted_at IS NULL", id) if err != nil { return nil, err } @@ -282,7 +285,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, return nil, err } - if _, err := exec.ExecContext(ctx, "DELETE FROM user_subscriptions WHERE group_id = $1", id); err != nil { + // 软删除订阅:设置 deleted_at 而非硬删除 + if _, err := exec.ExecContext(ctx, "UPDATE user_subscriptions SET deleted_at = NOW() WHERE group_id = $1 AND deleted_at IS NULL", id); err != nil { return nil, err } } @@ -297,18 +301,11 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, return nil, err } - // 3. Remove the group id from users.allowed_groups array (legacy representation). - // Phase 1 compatibility: also delete from user_allowed_groups join table when present. + // 3. Remove the group id from user_allowed_groups join table. + // Legacy users.allowed_groups 列已弃用,不再同步。 if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil { return nil, err } - if _, err := exec.ExecContext( - ctx, - "UPDATE users SET allowed_groups = array_remove(allowed_groups, $1) WHERE $1 = ANY(allowed_groups)", - id, - ); err != nil { - return nil, err - } // 4. Delete account_groups join rows. if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil { diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index a02c5f8f..b9079d7a 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -478,3 +478,58 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() { count, _ := s.repo.GetAccountCount(s.ctx, g.ID) s.Require().Zero(count) } + +// --- 软删除过滤测试 --- + +func (s *GroupRepoSuite) TestDelete_SoftDelete_NotVisibleInList() { + group := &service.Group{ + Name: "to-soft-delete", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + s.Require().NoError(s.repo.Create(s.ctx, group)) + + // 获取删除前的列表数量 + listBefore, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100}) + s.Require().NoError(err) + beforeCount := len(listBefore) + + // 软删除 + err = s.repo.Delete(s.ctx, group.ID) + s.Require().NoError(err, "Delete (soft delete)") + + // 验证列表中不再包含软删除的 group + listAfter, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100}) + s.Require().NoError(err) + s.Require().Len(listAfter, beforeCount-1, "soft deleted group should not appear in list") + + // 验证 GetByID 也无法找到 + _, err = s.repo.GetByID(s.ctx, group.ID) + s.Require().Error(err) + s.Require().ErrorIs(err, service.ErrGroupNotFound) +} + +func (s *GroupRepoSuite) TestDelete_SoftDeletedGroup_lockForUpdate() { + group := &service.Group{ + Name: "lock-soft-delete", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + s.Require().NoError(s.repo.Create(s.ctx, group)) + + // 软删除 + err := s.repo.Delete(s.ctx, group.ID) + s.Require().NoError(err) + + // 验证软删除的 group 在 GetByID 时返回 ErrGroupNotFound + // 这证明 lockForUpdate 的 deleted_at IS NULL 过滤正在工作 + _, err = s.repo.GetByID(s.ctx, group.ID) + s.Require().Error(err, "should fail to get soft-deleted group") + s.Require().ErrorIs(err, service.ErrGroupNotFound) +} diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index 80b0fad7..49d96445 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ b/backend/internal/repository/migrations_schema_integration_test.go @@ -53,6 +53,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { var uagRegclass sql.NullString require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass)) require.True(t, uagRegclass.Valid, "expected user_allowed_groups table to exist") + + // user_subscriptions: deleted_at for soft delete support (migration 012) + requireColumn(t, tx, "user_subscriptions", "deleted_at", "timestamp with time zone", 0, true) + + // orphan_allowed_groups_audit table should exist (migration 013) + var orphanAuditRegclass sql.NullString + require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.orphan_allowed_groups_audit')").Scan(&orphanAuditRegclass)) + require.True(t, orphanAuditRegclass.Valid, "expected orphan_allowed_groups_audit table to exist") + + // account_groups: created_at should be timestamptz + requireColumn(t, tx, "account_groups", "created_at", "timestamp with time zone", 0, false) + + // user_allowed_groups: created_at should be timestamptz + requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false) } func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) { diff --git a/backend/internal/repository/proxy_repo.go b/backend/internal/repository/proxy_repo.go index f9315525..c24b2e2c 100644 --- a/backend/internal/repository/proxy_repo.go +++ b/backend/internal/repository/proxy_repo.go @@ -178,7 +178,7 @@ func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in // GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) { - rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL GROUP BY proxy_id") + rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id") if err != nil { return nil, err } diff --git a/backend/internal/repository/redeem_code_repo.go b/backend/internal/repository/redeem_code_repo.go index 1429c678..ee8a01b5 100644 --- a/backend/internal/repository/redeem_code_repo.go +++ b/backend/internal/repository/redeem_code_repo.go @@ -168,7 +168,8 @@ func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemC func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error { now := time.Now() - affected, err := r.client.RedeemCode.Update(). + client := clientFromContext(ctx, r.client) + affected, err := client.RedeemCode.Update(). Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)). SetStatus(service.StatusUsed). SetUsedBy(userID). diff --git a/backend/internal/repository/soft_delete_ent_integration_test.go b/backend/internal/repository/soft_delete_ent_integration_test.go index 02176f90..e3560ab5 100644 --- a/backend/internal/repository/soft_delete_ent_integration_test.go +++ b/backend/internal/repository/soft_delete_ent_integration_test.go @@ -7,10 +7,12 @@ import ( "fmt" "strings" "testing" + "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + "github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/require" ) @@ -111,3 +113,104 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) { Only(mixins.SkipSoftDelete(ctx)) require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted") } + +// --- UserSubscription 软删除测试 --- + +func createEntGroup(t *testing.T, ctx context.Context, client *dbent.Client, name string) *dbent.Group { + t.Helper() + + g, err := client.Group.Create(). + SetName(name). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err, "create ent group") + return g +} + +func TestEntSoftDelete_UserSubscription_DefaultFilterAndSkip(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + + u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user")+"@example.com") + g := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group")) + + repo := NewUserSubscriptionRepository(client) + sub := &service.UserSubscription{ + UserID: u.ID, + GroupID: g.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + require.NoError(t, repo.Create(ctx, sub), "create user subscription") + + require.NoError(t, repo.Delete(ctx, sub.ID), "soft delete user subscription") + + _, err := repo.GetByID(ctx, sub.ID) + require.Error(t, err, "deleted rows should be hidden by default") + + _, err = client.UserSubscription.Query().Where(usersubscription.IDEQ(sub.ID)).Only(ctx) + require.Error(t, err, "default ent query should not see soft-deleted rows") + require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter") + + got, err := client.UserSubscription.Query(). + Where(usersubscription.IDEQ(sub.ID)). + Only(mixins.SkipSoftDelete(ctx)) + require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows") + require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete") +} + +func TestEntSoftDelete_UserSubscription_DeleteIdempotent(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + + u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user2")+"@example.com") + g := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group2")) + + repo := NewUserSubscriptionRepository(client) + sub := &service.UserSubscription{ + UserID: u.ID, + GroupID: g.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + require.NoError(t, repo.Create(ctx, sub), "create user subscription") + + require.NoError(t, repo.Delete(ctx, sub.ID), "first delete") + require.NoError(t, repo.Delete(ctx, sub.ID), "second delete should be idempotent") +} + +func TestEntSoftDelete_UserSubscription_ListExcludesDeleted(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + + u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user3")+"@example.com") + g1 := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group3a")) + g2 := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group3b")) + + repo := NewUserSubscriptionRepository(client) + + sub1 := &service.UserSubscription{ + UserID: u.ID, + GroupID: g1.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + require.NoError(t, repo.Create(ctx, sub1), "create subscription 1") + + sub2 := &service.UserSubscription{ + UserID: u.ID, + GroupID: g2.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + require.NoError(t, repo.Create(ctx, sub2), "create subscription 2") + + // 软删除 sub1 + require.NoError(t, repo.Delete(ctx, sub1.ID), "soft delete subscription 1") + + // ListByUserID 应只返回未删除的订阅 + subs, err := repo.ListByUserID(ctx, u.ID) + require.NoError(t, err, "ListByUserID") + require.Len(t, subs, 1, "should only return non-deleted subscriptions") + require.Equal(t, sub2.ID, subs[0].ID, "expected sub2 to be returned") +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 9a210bde..367ad430 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -1109,6 +1109,9 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs if err := rows.Close(); err != nil { return nil, err } + if err := rows.Err(); err != nil { + return nil, err + } today := timezone.Today() todayQuery := ` @@ -1135,6 +1138,9 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs if err := rows.Close(); err != nil { return nil, err } + if err := rows.Err(); err != nil { + return nil, err + } return result, nil } @@ -1177,6 +1183,9 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe if err := rows.Close(); err != nil { return nil, err } + if err := rows.Err(); err != nil { + return nil, err + } today := timezone.Today() todayQuery := ` @@ -1203,6 +1212,9 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe if err := rows.Close(); err != nil { return nil, err } + if err := rows.Err(); err != nil { + return nil, err + } return result, nil } diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 7766fe98..7294fadc 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -12,7 +12,6 @@ import ( "github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/lib/pq" ) type userRepository struct { @@ -86,10 +85,11 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, out := userEntityToService(m) groups, err := r.loadAllowedGroups(ctx, []int64{id}) - if err == nil { - if v, ok := groups[id]; ok { - out.AllowedGroups = v - } + if err != nil { + return nil, err + } + if v, ok := groups[id]; ok { + out.AllowedGroups = v } return out, nil } @@ -102,10 +102,11 @@ func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service out := userEntityToService(m) groups, err := r.loadAllowedGroups(ctx, []int64{m.ID}) - if err == nil { - if v, ok := groups[m.ID]; ok { - out.AllowedGroups = v - } + if err != nil { + return nil, err + } + if v, ok := groups[m.ID]; ok { + out.AllowedGroups = v } return out, nil } @@ -240,11 +241,12 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. } allowedGroupsByUser, err := r.loadAllowedGroups(ctx, userIDs) - if err == nil { - for id, u := range userMap { - if groups, ok := allowedGroupsByUser[id]; ok { - u.AllowedGroups = groups - } + if err != nil { + return nil, nil, err + } + for id, u := range userMap { + if groups, ok := allowedGroupsByUser[id]; ok { + u.AllowedGroups = groups } } @@ -252,12 +254,20 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. } func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { - _, err := r.client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx) - return err + client := clientFromContext(ctx, r.client) + n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx) + if err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + if n == 0 { + return service.ErrUserNotFound + } + return nil } func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error { - n, err := r.client.User.Update(). + client := clientFromContext(ctx, r.client) + n, err := client.User.Update(). Where(dbuser.IDEQ(id), dbuser.BalanceGTE(amount)). AddBalance(-amount). Save(ctx) @@ -271,8 +281,15 @@ func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount flo } func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error { - _, err := r.client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx) - return err + client := clientFromContext(ctx, r.client) + n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx) + if err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + if n == 0 { + return service.ErrUserNotFound + } + return nil } func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { @@ -280,33 +297,14 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, } func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { - exec := r.sql - if exec == nil { - // 未注入 sqlExecutor 时,退回到 ent client 的 ExecContext(支持事务)。 - exec = r.client - } - - joinAffected, err := r.client.UserAllowedGroup.Delete(). + // 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。 + affected, err := r.client.UserAllowedGroup.Delete(). Where(userallowedgroup.GroupIDEQ(groupID)). Exec(ctx) if err != nil { return 0, err } - - arrayRes, err := exec.ExecContext( - ctx, - "UPDATE users SET allowed_groups = array_remove(allowed_groups, $1), updated_at = NOW() WHERE $1 = ANY(allowed_groups)", - groupID, - ) - if err != nil { - return 0, err - } - arrayAffected, _ := arrayRes.RowsAffected() - - if int64(joinAffected) > arrayAffected { - return int64(joinAffected), nil - } - return arrayAffected, nil + return int64(affected), nil } func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) { @@ -323,10 +321,11 @@ func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, erro out := userEntityToService(m) groups, err := r.loadAllowedGroups(ctx, []int64{m.ID}) - if err == nil { - if v, ok := groups[m.ID]; ok { - out.AllowedGroups = v - } + if err != nil { + return nil, err + } + if v, ok := groups[m.ID]; ok { + out.AllowedGroups = v } return out, nil } @@ -356,8 +355,7 @@ func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64) } // syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组: -// 1) 以 user_allowed_groups 为读写源,确保新旧逻辑一致; -// 2) 额外更新 users.allowed_groups(历史字段)以保持兼容。 +// 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。 func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error { if client == nil { return nil @@ -376,12 +374,10 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl unique[id] = struct{}{} } - legacyGroups := make([]int64, 0, len(unique)) if len(unique) > 0 { creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique)) for groupID := range unique { creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID)) - legacyGroups = append(legacyGroups, groupID) } if err := client.UserAllowedGroup. CreateBulk(creates...). @@ -392,16 +388,6 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl } } - // Phase 1 兼容:保持 users.allowed_groups(数组字段)同步,避免旧查询路径读取到过期数据。 - var legacy any - if len(legacyGroups) > 0 { - sort.Slice(legacyGroups, func(i, j int) bool { return legacyGroups[i] < legacyGroups[j] }) - legacy = pq.Array(legacyGroups) - } - if _, err := client.ExecContext(ctx, "UPDATE users SET allowed_groups = $1::bigint[] WHERE id = $2", legacy, userID); err != nil { - return err - } - return nil } diff --git a/backend/internal/repository/user_repo_integration_test.go b/backend/internal/repository/user_repo_integration_test.go index a59d2312..c5c9e78c 100644 --- a/backend/internal/repository/user_repo_integration_test.go +++ b/backend/internal/repository/user_repo_integration_test.go @@ -508,3 +508,24 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch") } +// --- UpdateBalance/UpdateConcurrency 影响行数校验测试 --- + +func (s *UserRepoSuite) TestUpdateBalance_NotFound() { + err := s.repo.UpdateBalance(s.ctx, 999999, 10.0) + s.Require().Error(err, "expected error for non-existent user") + s.Require().ErrorIs(err, service.ErrUserNotFound) +} + +func (s *UserRepoSuite) TestUpdateConcurrency_NotFound() { + err := s.repo.UpdateConcurrency(s.ctx, 999999, 5) + s.Require().Error(err, "expected error for non-existent user") + s.Require().ErrorIs(err, service.ErrUserNotFound) +} + +func (s *UserRepoSuite) TestDeductBalance_NotFound() { + err := s.repo.DeductBalance(s.ctx, 999999, 5) + s.Require().Error(err, "expected error for non-existent user") + // DeductBalance 在用户不存在时返回 ErrInsufficientBalance 因为 WHERE 条件不匹配 + s.Require().ErrorIs(err, service.ErrInsufficientBalance) +} + diff --git a/backend/internal/repository/user_subscription_repo.go b/backend/internal/repository/user_subscription_repo.go index 918ccab4..2b308674 100644 --- a/backend/internal/repository/user_subscription_repo.go +++ b/backend/internal/repository/user_subscription_repo.go @@ -20,10 +20,11 @@ func NewUserSubscriptionRepository(client *dbent.Client) service.UserSubscriptio func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error { if sub == nil { - return nil + return service.ErrSubscriptionNilInput } - builder := r.client.UserSubscription.Create(). + client := clientFromContext(ctx, r.client) + builder := client.UserSubscription.Create(). SetUserID(sub.UserID). SetGroupID(sub.GroupID). SetExpiresAt(sub.ExpiresAt). @@ -57,7 +58,8 @@ func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.Us } func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) { - m, err := r.client.UserSubscription.Query(). + client := clientFromContext(ctx, r.client) + m, err := client.UserSubscription.Query(). Where(usersubscription.IDEQ(id)). WithUser(). WithGroup(). @@ -70,7 +72,8 @@ func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*se } func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { - m, err := r.client.UserSubscription.Query(). + client := clientFromContext(ctx, r.client) + m, err := client.UserSubscription.Query(). Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)). WithGroup(). Only(ctx) @@ -81,7 +84,8 @@ func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, } func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { - m, err := r.client.UserSubscription.Query(). + client := clientFromContext(ctx, r.client) + m, err := client.UserSubscription.Query(). Where( usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID), @@ -98,10 +102,11 @@ func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Con func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error { if sub == nil { - return nil + return service.ErrSubscriptionNilInput } - builder := r.client.UserSubscription.UpdateOneID(sub.ID). + client := clientFromContext(ctx, r.client) + builder := client.UserSubscription.UpdateOneID(sub.ID). SetUserID(sub.UserID). SetGroupID(sub.GroupID). SetStartsAt(sub.StartsAt). @@ -127,12 +132,14 @@ func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.Us func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error { // Match GORM semantics: deleting a missing row is not an error. - _, err := r.client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx) + client := clientFromContext(ctx, r.client) + _, err := client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx) return err } func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { - subs, err := r.client.UserSubscription.Query(). + client := clientFromContext(ctx, r.client) + subs, err := client.UserSubscription.Query(). Where(usersubscription.UserIDEQ(userID)). WithGroup(). Order(dbent.Desc(usersubscription.FieldCreatedAt)). @@ -144,7 +151,8 @@ func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID in } func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { - subs, err := r.client.UserSubscription.Query(). + client := clientFromContext(ctx, r.client) + subs, err := client.UserSubscription.Query(). Where( usersubscription.UserIDEQ(userID), usersubscription.StatusEQ(service.SubscriptionStatusActive), @@ -160,7 +168,8 @@ func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, use } func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { - q := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)) + client := clientFromContext(ctx, r.client) + q := client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)) total, err := q.Clone().Count(ctx) if err != nil { @@ -182,7 +191,8 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID } func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) { - q := r.client.UserSubscription.Query() + client := clientFromContext(ctx, r.client) + q := client.UserSubscription.Query() if userID != nil { q = q.Where(usersubscription.UserIDEQ(*userID)) } @@ -214,34 +224,39 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination } func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { - return r.client.UserSubscription.Query(). + client := clientFromContext(ctx, r.client) + return client.UserSubscription.Query(). Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)). Exist(ctx) } func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error { - _, err := r.client.UserSubscription.UpdateOneID(subscriptionID). + client := clientFromContext(ctx, r.client) + _, err := client.UserSubscription.UpdateOneID(subscriptionID). SetExpiresAt(newExpiresAt). Save(ctx) return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) } func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error { - _, err := r.client.UserSubscription.UpdateOneID(subscriptionID). + client := clientFromContext(ctx, r.client) + _, err := client.UserSubscription.UpdateOneID(subscriptionID). SetStatus(status). Save(ctx) return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) } func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error { - _, err := r.client.UserSubscription.UpdateOneID(subscriptionID). + client := clientFromContext(ctx, r.client) + _, err := client.UserSubscription.UpdateOneID(subscriptionID). SetNotes(notes). Save(ctx) return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) } func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error { - _, err := r.client.UserSubscription.UpdateOneID(id). + client := clientFromContext(ctx, r.client) + _, err := client.UserSubscription.UpdateOneID(id). SetDailyWindowStart(start). SetWeeklyWindowStart(start). SetMonthlyWindowStart(start). @@ -250,7 +265,8 @@ func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int } func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { - _, err := r.client.UserSubscription.UpdateOneID(id). + client := clientFromContext(ctx, r.client) + _, err := client.UserSubscription.UpdateOneID(id). SetDailyUsageUsd(0). SetDailyWindowStart(newWindowStart). Save(ctx) @@ -258,7 +274,8 @@ func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int } func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { - _, err := r.client.UserSubscription.UpdateOneID(id). + client := clientFromContext(ctx, r.client) + _, err := client.UserSubscription.UpdateOneID(id). SetWeeklyUsageUsd(0). SetWeeklyWindowStart(newWindowStart). Save(ctx) @@ -266,24 +283,112 @@ func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in } func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { - _, err := r.client.UserSubscription.UpdateOneID(id). + client := clientFromContext(ctx, r.client) + _, err := client.UserSubscription.UpdateOneID(id). SetMonthlyUsageUsd(0). SetMonthlyWindowStart(newWindowStart). Save(ctx) return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) } +// IncrementUsage 原子性地累加用量并校验限额。 +// 使用单条 SQL 语句同时检查 Group 的限额,如果任一限额即将超出则拒绝更新。 +// 当更新失败时,会执行额外查询确定具体超出的限额类型。 func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { - _, err := r.client.UserSubscription.UpdateOneID(id). - AddDailyUsageUsd(costUSD). - AddWeeklyUsageUsd(costUSD). - AddMonthlyUsageUsd(costUSD). - Save(ctx) - return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) + // 使用 JOIN 的原子更新:只有当所有限额条件满足时才执行累加 + // NULL 限额表示无限制 + const atomicUpdateSQL = ` + UPDATE user_subscriptions us + SET + daily_usage_usd = us.daily_usage_usd + $1, + weekly_usage_usd = us.weekly_usage_usd + $1, + monthly_usage_usd = us.monthly_usage_usd + $1, + updated_at = NOW() + FROM groups g + WHERE us.id = $2 + AND us.deleted_at IS NULL + AND us.group_id = g.id + AND g.deleted_at IS NULL + AND (g.daily_limit_usd IS NULL OR us.daily_usage_usd + $1 <= g.daily_limit_usd) + AND (g.weekly_limit_usd IS NULL OR us.weekly_usage_usd + $1 <= g.weekly_limit_usd) + AND (g.monthly_limit_usd IS NULL OR us.monthly_usage_usd + $1 <= g.monthly_limit_usd) + ` + + client := clientFromContext(ctx, r.client) + result, err := client.ExecContext(ctx, atomicUpdateSQL, costUSD, id) + if err != nil { + return err + } + + affected, err := result.RowsAffected() + if err != nil { + return err + } + + if affected > 0 { + return nil // 更新成功 + } + + // affected == 0:可能是订阅不存在、分组已删除、或限额超出 + // 执行额外查询确定具体原因 + return r.checkIncrementFailureReason(ctx, id, costUSD) +} + +// checkIncrementFailureReason 查询更新失败的具体原因 +func (r *userSubscriptionRepository) checkIncrementFailureReason(ctx context.Context, id int64, costUSD float64) error { + const checkSQL = ` + SELECT + CASE WHEN us.deleted_at IS NOT NULL THEN 'subscription_deleted' + WHEN g.id IS NULL THEN 'subscription_not_found' + WHEN g.deleted_at IS NOT NULL THEN 'group_deleted' + WHEN g.daily_limit_usd IS NOT NULL AND us.daily_usage_usd + $1 > g.daily_limit_usd THEN 'daily_exceeded' + WHEN g.weekly_limit_usd IS NOT NULL AND us.weekly_usage_usd + $1 > g.weekly_limit_usd THEN 'weekly_exceeded' + WHEN g.monthly_limit_usd IS NOT NULL AND us.monthly_usage_usd + $1 > g.monthly_limit_usd THEN 'monthly_exceeded' + ELSE 'unknown' + END AS reason + FROM user_subscriptions us + LEFT JOIN groups g ON us.group_id = g.id + WHERE us.id = $2 + ` + + client := clientFromContext(ctx, r.client) + rows, err := client.QueryContext(ctx, checkSQL, costUSD, id) + if err != nil { + return err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + return service.ErrSubscriptionNotFound + } + + var reason string + if err := rows.Scan(&reason); err != nil { + return err + } + + if err := rows.Err(); err != nil { + return err + } + + switch reason { + case "subscription_not_found", "subscription_deleted", "group_deleted": + return service.ErrSubscriptionNotFound + case "daily_exceeded": + return service.ErrDailyLimitExceeded + case "weekly_exceeded": + return service.ErrWeeklyLimitExceeded + case "monthly_exceeded": + return service.ErrMonthlyLimitExceeded + default: + // unknown 情况理论上不应发生,但作为兜底返回 + return service.ErrSubscriptionNotFound + } } func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { - n, err := r.client.UserSubscription.Update(). + client := clientFromContext(ctx, r.client) + n, err := client.UserSubscription.Update(). Where( usersubscription.StatusEQ(service.SubscriptionStatusActive), usersubscription.ExpiresAtLTE(time.Now()), @@ -296,7 +401,8 @@ func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Contex // Extra repository helpers (currently used only by integration tests). func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) { - subs, err := r.client.UserSubscription.Query(). + client := clientFromContext(ctx, r.client) + subs, err := client.UserSubscription.Query(). Where( usersubscription.StatusEQ(service.SubscriptionStatusActive), usersubscription.ExpiresAtLTE(time.Now()), @@ -309,12 +415,14 @@ func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service } func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { - count, err := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx) + client := clientFromContext(ctx, r.client) + count, err := client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx) return int64(count), err } func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) { - count, err := r.client.UserSubscription.Query(). + client := clientFromContext(ctx, r.client) + count, err := client.UserSubscription.Query(). Where( usersubscription.GroupIDEQ(groupID), usersubscription.StatusEQ(service.SubscriptionStatusActive), @@ -325,7 +433,8 @@ func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, g } func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) { - n, err := r.client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx) + client := clientFromContext(ctx, r.client) + n, err := client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx) return int64(n), err } diff --git a/backend/internal/repository/user_subscription_repo_integration_test.go b/backend/internal/repository/user_subscription_repo_integration_test.go index 282b9673..3a6c6434 100644 --- a/backend/internal/repository/user_subscription_repo_integration_test.go +++ b/backend/internal/repository/user_subscription_repo_integration_test.go @@ -4,6 +4,7 @@ package repository import ( "context" + "fmt" "testing" "time" @@ -631,3 +632,249 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba s.Require().NoError(err, "GetByID expired") s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired") } + +// --- 限额检查与软删除过滤测试 --- + +func (s *UserSubscriptionRepoSuite) mustCreateGroupWithLimits(name string, daily, weekly, monthly *float64) *service.Group { + s.T().Helper() + + create := s.client.Group.Create(). + SetName(name). + SetStatus(service.StatusActive). + SetSubscriptionType(service.SubscriptionTypeSubscription) + + if daily != nil { + create.SetDailyLimitUsd(*daily) + } + if weekly != nil { + create.SetWeeklyLimitUsd(*weekly) + } + if monthly != nil { + create.SetMonthlyLimitUsd(*monthly) + } + + g, err := create.Save(s.ctx) + s.Require().NoError(err, "create group with limits") + return groupEntityToService(g) +} + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_DailyLimitExceeded() { + user := s.mustCreateUser("dailylimit@test.com", service.RoleUser) + dailyLimit := 10.0 + group := s.mustCreateGroupWithLimits("g-dailylimit", &dailyLimit, nil, nil) + sub := s.mustCreateSubscription(user.ID, group.ID, nil) + + // 先增加 9.0,应该成功 + err := s.repo.IncrementUsage(s.ctx, sub.ID, 9.0) + s.Require().NoError(err, "first increment should succeed") + + // 再增加 2.0,会超过 10.0 限额,应该失败 + err = s.repo.IncrementUsage(s.ctx, sub.ID, 2.0) + s.Require().Error(err, "should fail when daily limit exceeded") + s.Require().ErrorIs(err, service.ErrDailyLimitExceeded) + + // 验证用量没有变化 + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + s.Require().InDelta(9.0, got.DailyUsageUSD, 1e-6, "usage should not change after failed increment") +} + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_WeeklyLimitExceeded() { + user := s.mustCreateUser("weeklylimit@test.com", service.RoleUser) + weeklyLimit := 50.0 + group := s.mustCreateGroupWithLimits("g-weeklylimit", nil, &weeklyLimit, nil) + sub := s.mustCreateSubscription(user.ID, group.ID, nil) + + // 增加 45.0,应该成功 + err := s.repo.IncrementUsage(s.ctx, sub.ID, 45.0) + s.Require().NoError(err, "first increment should succeed") + + // 再增加 10.0,会超过 50.0 限额,应该失败 + err = s.repo.IncrementUsage(s.ctx, sub.ID, 10.0) + s.Require().Error(err, "should fail when weekly limit exceeded") + s.Require().ErrorIs(err, service.ErrWeeklyLimitExceeded) +} + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_MonthlyLimitExceeded() { + user := s.mustCreateUser("monthlylimit@test.com", service.RoleUser) + monthlyLimit := 100.0 + group := s.mustCreateGroupWithLimits("g-monthlylimit", nil, nil, &monthlyLimit) + sub := s.mustCreateSubscription(user.ID, group.ID, nil) + + // 增加 90.0,应该成功 + err := s.repo.IncrementUsage(s.ctx, sub.ID, 90.0) + s.Require().NoError(err, "first increment should succeed") + + // 再增加 20.0,会超过 100.0 限额,应该失败 + err = s.repo.IncrementUsage(s.ctx, sub.ID, 20.0) + s.Require().Error(err, "should fail when monthly limit exceeded") + s.Require().ErrorIs(err, service.ErrMonthlyLimitExceeded) +} + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NoLimits() { + user := s.mustCreateUser("nolimits@test.com", service.RoleUser) + group := s.mustCreateGroupWithLimits("g-nolimits", nil, nil, nil) // 无限额 + sub := s.mustCreateSubscription(user.ID, group.ID, nil) + + // 应该可以增加任意金额 + err := s.repo.IncrementUsage(s.ctx, sub.ID, 1000000.0) + s.Require().NoError(err, "should succeed without limits") + + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + s.Require().InDelta(1000000.0, got.DailyUsageUSD, 1e-6) +} + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_AtExactLimit() { + user := s.mustCreateUser("exactlimit@test.com", service.RoleUser) + dailyLimit := 10.0 + group := s.mustCreateGroupWithLimits("g-exactlimit", &dailyLimit, nil, nil) + sub := s.mustCreateSubscription(user.ID, group.ID, nil) + + // 正好达到限额应该成功 + err := s.repo.IncrementUsage(s.ctx, sub.ID, 10.0) + s.Require().NoError(err, "should succeed at exact limit") + + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + s.Require().InDelta(10.0, got.DailyUsageUSD, 1e-6) +} + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() { + user := s.mustCreateUser("softdeleted@test.com", service.RoleUser) + group := s.mustCreateGroup("g-softdeleted") + sub := s.mustCreateSubscription(user.ID, group.ID, nil) + + // 软删除分组 + _, err := s.client.Group.UpdateOneID(group.ID).SetDeletedAt(time.Now()).Save(s.ctx) + s.Require().NoError(err, "soft delete group") + + // IncrementUsage 应该失败,因为分组已软删除 + err = s.repo.IncrementUsage(s.ctx, sub.ID, 1.0) + s.Require().Error(err, "should fail for soft-deleted group") + s.Require().ErrorIs(err, service.ErrSubscriptionNotFound) +} + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NotFound() { + err := s.repo.IncrementUsage(s.ctx, 999999, 1.0) + s.Require().Error(err, "should fail for non-existent subscription") + s.Require().ErrorIs(err, service.ErrSubscriptionNotFound) +} + +// --- nil 入参测试 --- + +func (s *UserSubscriptionRepoSuite) TestCreate_NilInput() { + err := s.repo.Create(s.ctx, nil) + s.Require().Error(err, "Create should fail with nil input") + s.Require().ErrorIs(err, service.ErrSubscriptionNilInput) +} + +func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() { + err := s.repo.Update(s.ctx, nil) + s.Require().Error(err, "Update should fail with nil input") + s.Require().ErrorIs(err, service.ErrSubscriptionNilInput) +} + +// --- 并发用量更新测试 --- + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() { + user := s.mustCreateUser("concurrent@test.com", service.RoleUser) + group := s.mustCreateGroupWithLimits("g-concurrent", nil, nil, nil) // 无限额 + sub := s.mustCreateSubscription(user.ID, group.ID, nil) + + const numGoroutines = 10 + const incrementPerGoroutine = 1.5 + + // 启动多个 goroutine 并发调用 IncrementUsage + errCh := make(chan error, numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + errCh <- s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerGoroutine) + }() + } + + // 等待所有 goroutine 完成 + for i := 0; i < numGoroutines; i++ { + err := <-errCh + s.Require().NoError(err, "IncrementUsage should succeed") + } + + // 验证累加结果正确 + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + expectedUsage := float64(numGoroutines) * incrementPerGoroutine + s.Require().InDelta(expectedUsage, got.DailyUsageUSD, 1e-6, "daily usage should be correctly accumulated") + s.Require().InDelta(expectedUsage, got.WeeklyUsageUSD, 1e-6, "weekly usage should be correctly accumulated") + s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated") +} + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_ConcurrentWithLimit() { + user := s.mustCreateUser("concurrentlimit@test.com", service.RoleUser) + dailyLimit := 5.0 + group := s.mustCreateGroupWithLimits("g-concurrentlimit", &dailyLimit, nil, nil) + sub := s.mustCreateSubscription(user.ID, group.ID, nil) + + // 注意:事务内的操作是串行的,所以这里改为顺序执行以验证限额逻辑 + // 尝试增加 10 次,每次 1.0,但限额只有 5.0 + const numAttempts = 10 + const incrementPerAttempt = 1.0 + + successCount := 0 + for i := 0; i < numAttempts; i++ { + err := s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerAttempt) + if err == nil { + successCount++ + } + } + + // 验证:应该有 5 次成功(不超过限额),5 次失败(超出限额) + s.Require().Equal(5, successCount, "exactly 5 increments should succeed (limit=5, increment=1)") + + // 验证最终用量等于限额 + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + s.Require().InDelta(dailyLimit, got.DailyUsageUSD, 1e-6, "daily usage should equal limit") +} + +func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() { + baseClient := testEntClient(s.T()) + tx, err := baseClient.Tx(context.Background()) + s.Require().NoError(err, "begin tx") + defer func() { + if tx != nil { + _ = tx.Rollback() + } + }() + + txCtx := dbent.NewTxContext(context.Background(), tx) + suffix := fmt.Sprintf("%d", time.Now().UnixNano()) + + userEnt, err := tx.Client().User.Create(). + SetEmail("tx-user-" + suffix + "@example.com"). + SetPasswordHash("test"). + Save(txCtx) + s.Require().NoError(err, "create user in tx") + + groupEnt, err := tx.Client().Group.Create(). + SetName("tx-group-" + suffix). + Save(txCtx) + s.Require().NoError(err, "create group in tx") + + repo := NewUserSubscriptionRepository(baseClient) + sub := &service.UserSubscription{ + UserID: userEnt.ID, + GroupID: groupEnt.ID, + ExpiresAt: time.Now().AddDate(0, 0, 30), + Status: service.SubscriptionStatusActive, + AssignedAt: time.Now(), + Notes: "tx", + } + s.Require().NoError(repo.Create(txCtx, sub), "create subscription in tx") + s.Require().NoError(repo.UpdateNotes(txCtx, sub.ID, "tx-note"), "update subscription in tx") + + s.Require().NoError(tx.Rollback(), "rollback tx") + tx = nil + + _, err = repo.GetByID(context.Background(), sub.ID) + s.Require().ErrorIs(err, service.ErrSubscriptionNotFound) +} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index ca3c4250..05895c8b 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -11,6 +11,7 @@ import ( var ( ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found") + ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil") ) type AccountRepository interface { diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index f1e36b89..7d6f407d 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -11,10 +11,11 @@ type Group struct { IsExclusive bool Status string - SubscriptionType string - DailyLimitUSD *float64 - WeeklyLimitUSD *float64 - MonthlyLimitUSD *float64 + SubscriptionType string + DailyLimitUSD *float64 + WeeklyLimitUSD *float64 + MonthlyLimitUSD *float64 + DefaultValidityDays int CreatedAt time.Time UpdatedAt time.Time diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index c587d212..7b0b80f5 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -9,6 +9,7 @@ import ( "strings" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) @@ -72,6 +73,7 @@ type RedeemService struct { subscriptionService *SubscriptionService cache RedeemCache billingCacheService *BillingCacheService + entClient *dbent.Client } // NewRedeemService 创建兑换码服务实例 @@ -81,6 +83,7 @@ func NewRedeemService( subscriptionService *SubscriptionService, cache RedeemCache, billingCacheService *BillingCacheService, + entClient *dbent.Client, ) *RedeemService { return &RedeemService{ redeemRepo: redeemRepo, @@ -88,6 +91,7 @@ func NewRedeemService( subscriptionService: subscriptionService, cache: cache, billingCacheService: billingCacheService, + entClient: entClient, } } @@ -248,9 +252,19 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( } _ = user // 使用变量避免未使用错误 + // 使用数据库事务保证兑换码标记与权益发放的原子性 + tx, err := s.entClient.Tx(ctx) + if err != nil { + return nil, fmt.Errorf("begin transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + // 将事务放入 context,使 repository 方法能够使用同一事务 + txCtx := dbent.NewTxContext(ctx, tx) + // 【关键】先标记兑换码为已使用,确保并发安全 // 利用数据库乐观锁(WHERE status = 'unused')保证原子性 - if err := s.redeemRepo.Use(ctx, redeemCode.ID, userID); err != nil { + if err := s.redeemRepo.Use(txCtx, redeemCode.ID, userID); err != nil { if errors.Is(err, ErrRedeemCodeNotFound) || errors.Is(err, ErrRedeemCodeUsed) { return nil, ErrRedeemCodeUsed } @@ -261,21 +275,13 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( switch redeemCode.Type { case RedeemTypeBalance: // 增加用户余额 - if err := s.userRepo.UpdateBalance(ctx, userID, redeemCode.Value); err != nil { + if err := s.userRepo.UpdateBalance(txCtx, userID, redeemCode.Value); err != nil { return nil, fmt.Errorf("update user balance: %w", err) } - // 失效余额缓存 - if s.billingCacheService != nil { - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) - }() - } case RedeemTypeConcurrency: // 增加用户并发数 - if err := s.userRepo.UpdateConcurrency(ctx, userID, int(redeemCode.Value)); err != nil { + if err := s.userRepo.UpdateConcurrency(txCtx, userID, int(redeemCode.Value)); err != nil { return nil, fmt.Errorf("update user concurrency: %w", err) } @@ -284,7 +290,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( if validityDays <= 0 { validityDays = 30 } - _, _, err := s.subscriptionService.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ + _, _, err := s.subscriptionService.AssignOrExtendSubscription(txCtx, &AssignSubscriptionInput{ UserID: userID, GroupID: *redeemCode.GroupID, ValidityDays: validityDays, @@ -294,20 +300,19 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( if err != nil { return nil, fmt.Errorf("assign or extend subscription: %w", err) } - // 失效订阅缓存 - if s.billingCacheService != nil { - groupID := *redeemCode.GroupID - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) - }() - } default: return nil, fmt.Errorf("unsupported redeem type: %s", redeemCode.Type) } + // 提交事务 + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("commit transaction: %w", err) + } + + // 事务提交成功后失效缓存 + s.invalidateRedeemCaches(ctx, userID, redeemCode) + // 重新获取更新后的兑换码 redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID) if err != nil { @@ -317,6 +322,31 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( return redeemCode, nil } +// invalidateRedeemCaches 失效兑换相关的缓存 +func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) { + if s.billingCacheService == nil { + return + } + + switch redeemCode.Type { + case RedeemTypeBalance: + go func() { + cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) + }() + case RedeemTypeSubscription: + if redeemCode.GroupID != nil { + groupID := *redeemCode.GroupID + go func() { + cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) + }() + } + } +} + // GetByID 根据ID获取兑换码 func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) { code, err := s.redeemRepo.GetByID(ctx, id) diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index fec6c147..09554c0f 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -26,6 +26,7 @@ var ( ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded") ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded") ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded") + ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil") ) // SubscriptionService 订阅服务 diff --git a/backend/migrations/011_remove_duplicate_unique_indexes.sql b/backend/migrations/011_remove_duplicate_unique_indexes.sql new file mode 100644 index 00000000..8fd62710 --- /dev/null +++ b/backend/migrations/011_remove_duplicate_unique_indexes.sql @@ -0,0 +1,39 @@ +-- 011_remove_duplicate_unique_indexes.sql +-- 移除重复的唯一索引 +-- 这些字段在 ent schema 的 Fields() 中已声明 .Unique(), +-- 因此在 Indexes() 中再次声明 index.Fields("x").Unique() 会创建重复索引。 +-- 本迁移脚本清理这些冗余索引。 + +-- 重复索引命名约定(由 Ent 自动生成/历史迁移遗留): +-- - 字段级 Unique() 创建的索引名: __key +-- - Indexes() 中的 Unique() 创建的索引名:
_ +-- - 初始化迁移中的非唯一索引: idx_
_ + +-- 仅当索引存在时才删除(幂等操作) + +-- api_keys 表: key 字段 +DROP INDEX IF EXISTS apikey_key; +DROP INDEX IF EXISTS api_keys_key; +DROP INDEX IF EXISTS idx_api_keys_key; + +-- users 表: email 字段 +DROP INDEX IF EXISTS user_email; +DROP INDEX IF EXISTS users_email; +DROP INDEX IF EXISTS idx_users_email; + +-- settings 表: key 字段 +DROP INDEX IF EXISTS settings_key; +DROP INDEX IF EXISTS idx_settings_key; + +-- redeem_codes 表: code 字段 +DROP INDEX IF EXISTS redeemcode_code; +DROP INDEX IF EXISTS redeem_codes_code; +DROP INDEX IF EXISTS idx_redeem_codes_code; + +-- groups 表: name 字段 +DROP INDEX IF EXISTS group_name; +DROP INDEX IF EXISTS groups_name; +DROP INDEX IF EXISTS idx_groups_name; + +-- 注意: 每个字段的唯一约束仍由字段级 Unique() 创建的约束保留, +-- 如 api_keys_key_key、users_email_key 等。 diff --git a/backend/migrations/012_add_user_subscription_soft_delete.sql b/backend/migrations/012_add_user_subscription_soft_delete.sql new file mode 100644 index 00000000..b6cb7366 --- /dev/null +++ b/backend/migrations/012_add_user_subscription_soft_delete.sql @@ -0,0 +1,13 @@ +-- 012: 为 user_subscriptions 表添加软删除支持 +-- 任务:fix-medium-data-hygiene 1.1 + +-- 添加 deleted_at 字段 +ALTER TABLE user_subscriptions +ADD COLUMN IF NOT EXISTS deleted_at TIMESTAMPTZ DEFAULT NULL; + +-- 添加 deleted_at 索引以优化软删除查询 +CREATE INDEX IF NOT EXISTS usersubscription_deleted_at +ON user_subscriptions (deleted_at); + +-- 注释:与其他使用软删除的实体保持一致 +COMMENT ON COLUMN user_subscriptions.deleted_at IS '软删除时间戳,NULL 表示未删除'; diff --git a/backend/migrations/013_log_orphan_allowed_groups.sql b/backend/migrations/013_log_orphan_allowed_groups.sql new file mode 100644 index 00000000..976c0aca --- /dev/null +++ b/backend/migrations/013_log_orphan_allowed_groups.sql @@ -0,0 +1,32 @@ +-- 013: 记录 users.allowed_groups 中的孤立 group_id +-- 任务:fix-medium-data-hygiene 3.1 +-- +-- 目的:在删除 legacy allowed_groups 列前,记录所有引用了不存在 group 的孤立记录 +-- 这些记录可用于审计或后续数据修复 + +-- 创建审计表存储孤立的 allowed_groups 记录 +CREATE TABLE IF NOT EXISTS orphan_allowed_groups_audit ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL, + group_id BIGINT NOT NULL, + recorded_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE (user_id, group_id) +); + +-- 记录孤立的 group_id(存在于 users.allowed_groups 但不存在于 groups 表) +INSERT INTO orphan_allowed_groups_audit (user_id, group_id) +SELECT u.id, x.group_id +FROM users u +CROSS JOIN LATERAL unnest(u.allowed_groups) AS x(group_id) +LEFT JOIN groups g ON g.id = x.group_id +WHERE u.allowed_groups IS NOT NULL + AND g.id IS NULL +ON CONFLICT (user_id, group_id) DO NOTHING; + +-- 添加索引便于查询 +CREATE INDEX IF NOT EXISTS idx_orphan_allowed_groups_audit_user_id +ON orphan_allowed_groups_audit(user_id); + +-- 记录迁移完成信息 +COMMENT ON TABLE orphan_allowed_groups_audit IS +'审计表:记录 users.allowed_groups 中引用的不存在的 group_id,用于数据清理前的审计'; diff --git a/backend/migrations/014_drop_legacy_allowed_groups.sql b/backend/migrations/014_drop_legacy_allowed_groups.sql new file mode 100644 index 00000000..2c2a3d45 --- /dev/null +++ b/backend/migrations/014_drop_legacy_allowed_groups.sql @@ -0,0 +1,15 @@ +-- 014: 删除 legacy users.allowed_groups 列 +-- 任务:fix-medium-data-hygiene 3.3 +-- +-- 前置条件: +-- - 迁移 007 已将数据回填到 user_allowed_groups 联接表 +-- - 迁移 013 已记录所有孤立的 group_id 到审计表 +-- - 应用代码已停止写入该列(3.2 完成) +-- +-- 该列现已废弃,所有读写操作均使用 user_allowed_groups 联接表。 + +-- 删除 allowed_groups 列 +ALTER TABLE users DROP COLUMN IF EXISTS allowed_groups; + +-- 添加注释记录删除原因 +COMMENT ON TABLE users IS '用户表。注:原 allowed_groups BIGINT[] 列已迁移至 user_allowed_groups 联接表'; From dbc0cf33a1641f085854ba9061ee68548652a0e3 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Wed, 31 Dec 2025 14:26:20 +0800 Subject: [PATCH 07/49] Merge branch 'main' of https://github.com/mt21625457/aicodex2api --- build_image.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 build_image.sh diff --git a/build_image.sh b/build_image.sh old mode 100644 new mode 100755 From 682f546c0e743718934ac3f0de5abc132ec6ae0e Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Wed, 31 Dec 2025 14:51:58 +0800 Subject: [PATCH 08/49] =?UTF-8?q?fix(lint):=20=E4=BF=AE=E5=A4=8D=20golangc?= =?UTF-8?q?i-lint=20=E6=8A=A5=E5=91=8A=E7=9A=84=E4=BB=A3=E7=A0=81=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - errcheck: 修复类型断言未检查返回值的问题 - pool.go: 添加 sync.Map 类型断言安全检查 - req_client_pool.go: 添加 sync.Map 类型断言安全检查 - concurrency_cache_benchmark_test.go: 显式忽略断言返回值 - gateway_service.go: 显式忽略 WriteString 返回值 - gofmt: 修复代码格式问题 - redis.go: 注释对齐 - api_key_repo.go: 结构体字段对齐 - concurrency_cache.go: 字段对齐 - http_upstream.go: 注释对齐 - unused: 删除未使用的代码 - user_repo.go: 删除未使用的 sql 字段 - usage_service.go: 删除未使用的 calculateStats 函数 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- backend/internal/infrastructure/redis.go | 4 +-- backend/internal/pkg/httpclient/pool.go | 9 +++++-- backend/internal/repository/api_key_repo.go | 14 +++++----- .../internal/repository/concurrency_cache.go | 4 +-- .../concurrency_cache_benchmark_test.go | 2 +- backend/internal/repository/http_upstream.go | 2 +- .../internal/repository/req_client_pool.go | 9 +++++-- backend/internal/repository/user_repo.go | 5 ++-- backend/internal/service/gateway_service.go | 2 +- backend/internal/service/usage_service.go | 26 ------------------- 10 files changed, 30 insertions(+), 47 deletions(-) diff --git a/backend/internal/infrastructure/redis.go b/backend/internal/infrastructure/redis.go index 5bb92d19..9f4c8770 100644 --- a/backend/internal/infrastructure/redis.go +++ b/backend/internal/infrastructure/redis.go @@ -33,7 +33,7 @@ func buildRedisOptions(cfg *config.Config) *redis.Options { DialTimeout: time.Duration(cfg.Redis.DialTimeoutSeconds) * time.Second, // 建连超时 ReadTimeout: time.Duration(cfg.Redis.ReadTimeoutSeconds) * time.Second, // 读取超时 WriteTimeout: time.Duration(cfg.Redis.WriteTimeoutSeconds) * time.Second, // 写入超时 - PoolSize: cfg.Redis.PoolSize, // 连接池大小 - MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接 + PoolSize: cfg.Redis.PoolSize, // 连接池大小 + MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接 } } diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go index f68d50a5..1028fb84 100644 --- a/backend/internal/pkg/httpclient/pool.go +++ b/backend/internal/pkg/httpclient/pool.go @@ -58,7 +58,9 @@ var sharedClients sync.Map func GetClient(opts Options) (*http.Client, error) { key := buildClientKey(opts) if cached, ok := sharedClients.Load(key); ok { - return cached.(*http.Client), nil + if client, ok := cached.(*http.Client); ok { + return client, nil + } } client, err := buildClient(opts) @@ -72,7 +74,10 @@ func GetClient(opts Options) (*http.Client, error) { } actual, _ := sharedClients.LoadOrStore(key, client) - return actual.(*http.Client), nil + if c, ok := actual.(*http.Client); ok { + return c, nil + } + return client, nil } func buildClient(opts Options) (*http.Client, error) { diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index a3a52333..3ba2fd85 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -311,13 +311,13 @@ func groupEntityToService(g *dbent.Group) *service.Group { return nil } return &service.Group{ - ID: g.ID, - Name: g.Name, - Description: derefString(g.Description), - Platform: g.Platform, - RateMultiplier: g.RateMultiplier, - IsExclusive: g.IsExclusive, - Status: g.Status, + ID: g.ID, + Name: g.Name, + Description: derefString(g.Description), + Platform: g.Platform, + RateMultiplier: g.RateMultiplier, + IsExclusive: g.IsExclusive, + Status: g.Status, SubscriptionType: g.SubscriptionType, DailyLimitUSD: g.DailyLimitUsd, WeeklyLimitUSD: g.WeeklyLimitUsd, diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index 31527f22..9205230b 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -126,7 +126,7 @@ var ( ) type concurrencyCache struct { - rdb *redis.Client + rdb *redis.Client slotTTLSeconds int // 槽位过期时间(秒) } @@ -137,7 +137,7 @@ func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.Concurre slotTTLMinutes = defaultSlotTTLMinutes } return &concurrencyCache{ - rdb: rdb, + rdb: rdb, slotTTLSeconds: slotTTLMinutes * 60, } } diff --git a/backend/internal/repository/concurrency_cache_benchmark_test.go b/backend/internal/repository/concurrency_cache_benchmark_test.go index 29cc7fbc..cafab9cb 100644 --- a/backend/internal/repository/concurrency_cache_benchmark_test.go +++ b/backend/internal/repository/concurrency_cache_benchmark_test.go @@ -22,7 +22,7 @@ func BenchmarkAccountConcurrency(b *testing.B) { _ = rdb.Close() }() - cache := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache) + cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache) ctx := context.Background() for _, size := range []int{10, 100, 1000} { diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index 061866b1..180844b5 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -572,7 +572,7 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) *http.Tran // trackedBody 带跟踪功能的响应体包装器 // 在 Close 时执行回调,用于更新请求计数 type trackedBody struct { - io.ReadCloser // 原始响应体 + io.ReadCloser // 原始响应体 once sync.Once onClose func() // 关闭时的回调函数 } diff --git a/backend/internal/repository/req_client_pool.go b/backend/internal/repository/req_client_pool.go index bfe0ccd2..b23462a4 100644 --- a/backend/internal/repository/req_client_pool.go +++ b/backend/internal/repository/req_client_pool.go @@ -35,7 +35,9 @@ var sharedReqClients sync.Map func getSharedReqClient(opts reqClientOptions) *req.Client { key := buildReqClientKey(opts) if cached, ok := sharedReqClients.Load(key); ok { - return cached.(*req.Client) + if c, ok := cached.(*req.Client); ok { + return c + } } client := req.C().SetTimeout(opts.Timeout) @@ -47,7 +49,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client { } actual, _ := sharedReqClients.LoadOrStore(key, client) - return actual.(*req.Client) + if c, ok := actual.(*req.Client); ok { + return c + } + return client } func buildReqClientKey(opts reqClientOptions) string { diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 7294fadc..8393ae7c 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -16,15 +16,14 @@ import ( type userRepository struct { client *dbent.Client - sql sqlExecutor } func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository { return newUserRepositoryWithSQL(client, sqlDB) } -func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository { - return &userRepository{client: client, sql: sqlq} +func newUserRepositoryWithSQL(client *dbent.Client, _ sqlExecutor) *userRepository { + return &userRepository{client: client} } func (r *userRepository) Create(ctx context.Context, userIn *service.User) error { diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index dd879da2..d542e9c2 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -197,7 +197,7 @@ func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { if cc, ok := partMap["cache_control"].(map[string]any); ok { if cc["type"] == "ephemeral" { if text, ok := partMap["text"].(string); ok { - builder.WriteString(text) + _, _ = builder.WriteString(text) } } } diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index f57e90eb..f653ddfe 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -235,32 +235,6 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int return stats, nil } -// calculateStats 计算统计数据 -func (s *UsageService) calculateStats(logs []UsageLog) *UsageStats { - stats := &UsageStats{} - - for _, log := range logs { - stats.TotalRequests++ - stats.TotalInputTokens += int64(log.InputTokens) - stats.TotalOutputTokens += int64(log.OutputTokens) - stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens) - stats.TotalTokens += int64(log.TotalTokens()) - stats.TotalCost += log.TotalCost - stats.TotalActualCost += log.ActualCost - - if log.DurationMs != nil { - stats.AverageDurationMs += float64(*log.DurationMs) - } - } - - // 计算平均持续时间 - if stats.TotalRequests > 0 { - stats.AverageDurationMs /= float64(stats.TotalRequests) - } - - return stats -} - // Delete 删除使用日志(管理员功能,谨慎使用) func (s *UsageService) Delete(ctx context.Context, id int64) error { if err := s.usageRepo.Delete(ctx, id); err != nil { From d77d0544d0bd646781f94abef18152affac2c035 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Wed, 31 Dec 2025 15:20:58 +0800 Subject: [PATCH 09/49] =?UTF-8?q?fix(=E4=BB=93=E5=82=A8):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E5=B9=B6=E5=8F=91=E7=BC=93=E5=AD=98=E5=89=8D=E7=BC=80?= =?UTF-8?q?=E4=B8=8E=E8=BD=AF=E5=88=A0=E9=99=A4=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 补齐 Redis ZSET 前缀处理,确保并发释放计数正确 删除时改用 Client().Mutate 走更新逻辑,保留软删除记录 测试: make test-integration --- backend/ent/schema/mixins/soft_delete.go | 4 +++- backend/internal/repository/integration_harness_test.go | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/backend/ent/schema/mixins/soft_delete.go b/backend/ent/schema/mixins/soft_delete.go index 00ef77a6..9571bc9c 100644 --- a/backend/ent/schema/mixins/soft_delete.go +++ b/backend/ent/schema/mixins/soft_delete.go @@ -12,6 +12,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/schema/field" "entgo.io/ent/schema/mixin" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/intercept" ) @@ -112,6 +113,7 @@ func (d SoftDeleteMixin) Hooks() []ent.Hook { SetOp(ent.Op) SetDeletedAt(time.Time) WhereP(...func(*sql.Selector)) + Client() *dbent.Client }) if !ok { return nil, fmt.Errorf("unexpected mutation type %T", m) @@ -122,7 +124,7 @@ func (d SoftDeleteMixin) Hooks() []ent.Hook { mx.SetOp(ent.OpUpdate) // 设置删除时间为当前时间 mx.SetDeletedAt(time.Now()) - return next.Mutate(ctx, m) + return mx.Client().Mutate(ctx, m) }) }, } diff --git a/backend/internal/repository/integration_harness_test.go b/backend/internal/repository/integration_harness_test.go index 553a581a..6ef447e1 100644 --- a/backend/internal/repository/integration_harness_test.go +++ b/backend/internal/repository/integration_harness_test.go @@ -330,7 +330,8 @@ func (h prefixHook) prefixCmd(cmd redisclient.Cmder) { switch strings.ToLower(cmd.Name()) { case "get", "set", "setnx", "setex", "psetex", "incr", "decr", "incrby", "expire", "pexpire", "ttl", "pttl", - "hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists": + "hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists", + "zadd", "zcard", "zrange", "zrangebyscore", "zrem", "zremrangebyscore", "zrevrange", "zrevrangebyscore", "zscore": prefixOne(1) case "del", "unlink": for i := 1; i < len(args); i++ { From 6f6dc3032c6dc9aacdf077cd975a757c10823a41 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Wed, 31 Dec 2025 15:31:26 +0800 Subject: [PATCH 10/49] =?UTF-8?q?fix(=E8=AE=BE=E7=BD=AE):=20=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E7=AB=99=E7=82=B9=E8=AE=BE=E7=BD=AE=E4=BF=9D=E5=AD=98?= =?UTF-8?q?=E5=A4=B1=E8=B4=A5=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 问题: 1. Setting.value 字段设置了 NotEmpty() 约束,导致保存空字符串值时验证失败 2. 数据库 settings 表缺少 key 字段的唯一约束,导致 ON CONFLICT 语句执行失败 修复: - 移除 ent/schema/setting.go 中 value 字段的 NotEmpty() 约束 - 新增迁移 015_fix_settings_unique_constraint.sql 添加缺失的唯一约束 - 添加3个回归测试确保空值保存功能正常 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- backend/ent/runtime/runtime.go | 4 -- backend/ent/schema/setting.go | 1 - backend/ent/setting/setting.go | 2 - backend/ent/setting_create.go | 5 -- backend/ent/setting_update.go | 10 ---- .../setting_repo_integration_test.go | 56 +++++++++++++++++++ .../015_fix_settings_unique_constraint.sql | 19 +++++++ 7 files changed, 75 insertions(+), 22 deletions(-) create mode 100644 backend/migrations/015_fix_settings_unique_constraint.sql diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index da0accd7..0b254b3e 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -415,10 +415,6 @@ func init() { return nil } }() - // settingDescValue is the schema descriptor for value field. - settingDescValue := settingFields[1].Descriptor() - // setting.ValueValidator is a validator for the "value" field. It is called by the builders before save. - setting.ValueValidator = settingDescValue.Validators[0].(func(string) error) // settingDescUpdatedAt is the schema descriptor for updated_at field. settingDescUpdatedAt := settingFields[2].Descriptor() // setting.DefaultUpdatedAt holds the default value on creation for the updated_at field. diff --git a/backend/ent/schema/setting.go b/backend/ent/schema/setting.go index 3f896fab..0acfde59 100644 --- a/backend/ent/schema/setting.go +++ b/backend/ent/schema/setting.go @@ -36,7 +36,6 @@ func (Setting) Fields() []ent.Field { NotEmpty(). Unique(), field.String("value"). - NotEmpty(). SchemaType(map[string]string{ dialect.Postgres: "text", }), diff --git a/backend/ent/setting/setting.go b/backend/ent/setting/setting.go index feb86b87..79abe970 100644 --- a/backend/ent/setting/setting.go +++ b/backend/ent/setting/setting.go @@ -44,8 +44,6 @@ func ValidColumn(column string) bool { var ( // KeyValidator is a validator for the "key" field. It is called by the builders before save. KeyValidator func(string) error - // ValueValidator is a validator for the "value" field. It is called by the builders before save. - ValueValidator func(string) error // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. DefaultUpdatedAt func() time.Time // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. diff --git a/backend/ent/setting_create.go b/backend/ent/setting_create.go index 66c1231e..553261e7 100644 --- a/backend/ent/setting_create.go +++ b/backend/ent/setting_create.go @@ -102,11 +102,6 @@ func (_c *SettingCreate) check() error { if _, ok := _c.mutation.Value(); !ok { return &ValidationError{Name: "value", err: errors.New(`ent: missing required field "Setting.value"`)} } - if v, ok := _c.mutation.Value(); ok { - if err := setting.ValueValidator(v); err != nil { - return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "Setting.value": %w`, err)} - } - } if _, ok := _c.mutation.UpdatedAt(); !ok { return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "Setting.updated_at"`)} } diff --git a/backend/ent/setting_update.go b/backend/ent/setting_update.go index 007fa36e..42d016d6 100644 --- a/backend/ent/setting_update.go +++ b/backend/ent/setting_update.go @@ -110,11 +110,6 @@ func (_u *SettingUpdate) check() error { return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "Setting.key": %w`, err)} } } - if v, ok := _u.mutation.Value(); ok { - if err := setting.ValueValidator(v); err != nil { - return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "Setting.value": %w`, err)} - } - } return nil } @@ -254,11 +249,6 @@ func (_u *SettingUpdateOne) check() error { return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "Setting.key": %w`, err)} } } - if v, ok := _u.mutation.Value(); ok { - if err := setting.ValueValidator(v); err != nil { - return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "Setting.value": %w`, err)} - } - } return nil } diff --git a/backend/internal/repository/setting_repo_integration_test.go b/backend/internal/repository/setting_repo_integration_test.go index 784124f4..f91c0651 100644 --- a/backend/internal/repository/setting_repo_integration_test.go +++ b/backend/internal/repository/setting_repo_integration_test.go @@ -105,3 +105,59 @@ func (s *SettingRepoSuite) TestSetMultiple_Upsert() { s.Require().NoError(err) s.Require().Equal("new_val", got2) } + +// TestSet_EmptyValue 测试保存空字符串值 +// 这是一个回归测试,确保可选设置(如站点Logo、API端点地址等)可以保存为空字符串 +func (s *SettingRepoSuite) TestSet_EmptyValue() { + // 测试 Set 方法保存空值 + s.Require().NoError(s.repo.Set(s.ctx, "empty_key", ""), "Set with empty value should succeed") + + got, err := s.repo.GetValue(s.ctx, "empty_key") + s.Require().NoError(err, "GetValue for empty value") + s.Require().Equal("", got, "empty value should be preserved") +} + +// TestSetMultiple_WithEmptyValues 测试批量保存包含空字符串的设置 +// 模拟用户保存站点设置时部分字段为空的场景 +func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() { + // 模拟保存站点设置,部分字段有值,部分字段为空 + settings := map[string]string{ + "site_name": "AICodex2API", + "site_subtitle": "Subscription to API", + "site_logo": "", // 用户未上传Logo + "api_base_url": "", // 用户未设置API地址 + "contact_info": "", // 用户未设置联系方式 + "doc_url": "", // 用户未设置文档链接 + } + + s.Require().NoError(s.repo.SetMultiple(s.ctx, settings), "SetMultiple with empty values should succeed") + + // 验证所有值都正确保存 + result, err := s.repo.GetMultiple(s.ctx, []string{"site_name", "site_subtitle", "site_logo", "api_base_url", "contact_info", "doc_url"}) + s.Require().NoError(err, "GetMultiple after SetMultiple with empty values") + + s.Require().Equal("AICodex2API", result["site_name"]) + s.Require().Equal("Subscription to API", result["site_subtitle"]) + s.Require().Equal("", result["site_logo"], "empty site_logo should be preserved") + s.Require().Equal("", result["api_base_url"], "empty api_base_url should be preserved") + s.Require().Equal("", result["contact_info"], "empty contact_info should be preserved") + s.Require().Equal("", result["doc_url"], "empty doc_url should be preserved") +} + +// TestSetMultiple_UpdateToEmpty 测试将已有值更新为空字符串 +// 确保用户可以清空之前设置的值 +func (s *SettingRepoSuite) TestSetMultiple_UpdateToEmpty() { + // 先设置非空值 + s.Require().NoError(s.repo.Set(s.ctx, "clearable_key", "initial_value")) + + got, err := s.repo.GetValue(s.ctx, "clearable_key") + s.Require().NoError(err) + s.Require().Equal("initial_value", got) + + // 更新为空值 + s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"clearable_key": ""}), "Update to empty should succeed") + + got, err = s.repo.GetValue(s.ctx, "clearable_key") + s.Require().NoError(err) + s.Require().Equal("", got, "value should be updated to empty string") +} diff --git a/backend/migrations/015_fix_settings_unique_constraint.sql b/backend/migrations/015_fix_settings_unique_constraint.sql new file mode 100644 index 00000000..60f8fcad --- /dev/null +++ b/backend/migrations/015_fix_settings_unique_constraint.sql @@ -0,0 +1,19 @@ +-- 015_fix_settings_unique_constraint.sql +-- 修复 settings 表 key 字段缺失的唯一约束 +-- 此约束是 ON CONFLICT ("key") DO UPDATE 语句所必需的 + +-- 检查并添加唯一约束(如果不存在) +DO $$ +BEGIN + -- 检查是否已存在唯一约束 + IF NOT EXISTS ( + SELECT 1 FROM pg_constraint + WHERE conrelid = 'settings'::regclass + AND contype = 'u' + AND conname = 'settings_key_key' + ) THEN + -- 添加唯一约束 + ALTER TABLE settings ADD CONSTRAINT settings_key_key UNIQUE (key); + END IF; +END +$$; From aac7dd6b08c74a9fc4f5fc1c3fbec607b49f955d Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 31 Dec 2025 15:52:02 +0800 Subject: [PATCH 11/49] style: fix gofmt formatting in test file Remove redundant alignment whitespace before comments. --- .../internal/repository/setting_repo_integration_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/backend/internal/repository/setting_repo_integration_test.go b/backend/internal/repository/setting_repo_integration_test.go index f91c0651..147313d6 100644 --- a/backend/internal/repository/setting_repo_integration_test.go +++ b/backend/internal/repository/setting_repo_integration_test.go @@ -124,10 +124,10 @@ func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() { settings := map[string]string{ "site_name": "AICodex2API", "site_subtitle": "Subscription to API", - "site_logo": "", // 用户未上传Logo - "api_base_url": "", // 用户未设置API地址 - "contact_info": "", // 用户未设置联系方式 - "doc_url": "", // 用户未设置文档链接 + "site_logo": "", // 用户未上传Logo + "api_base_url": "", // 用户未设置API地址 + "contact_info": "", // 用户未设置联系方式 + "doc_url": "", // 用户未设置文档链接 } s.Require().NoError(s.repo.SetMultiple(s.ctx, settings), "SetMultiple with empty values should succeed") From 1ef4f09df516c7a28ed1070185ee16ebac4f737f Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Wed, 31 Dec 2025 16:17:45 +0800 Subject: [PATCH 12/49] =?UTF-8?q?fix(=E7=BD=91=E5=85=B3):=20=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=20model=20=E5=8F=82=E6=95=B0=E5=BF=85=E5=A1=AB?= =?UTF-8?q?=E9=AA=8C=E8=AF=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在以下端点添加 model 参数的必填验证,缺失时直接返回 400 错误: - /v1/messages - /v1/messages/count_tokens - /openai/v1/responses 修复前:空 model 会进入账号选择流程,最终由上游 API 返回错误 修复后:入口处直接拒绝,避免浪费资源和不明确的错误信息 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- backend/internal/handler/gateway_handler.go | 12 ++++++++++++ backend/internal/handler/openai_gateway_handler.go | 6 ++++++ 2 files changed, 18 insertions(+) diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index fc92b2d8..a2f833ff 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -88,6 +88,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { reqModel := parsedReq.Model reqStream := parsedReq.Stream + // 验证 model 必填 + if reqModel == "" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + // Track if we've started streaming (for error handling) streamStarted := false @@ -517,6 +523,12 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { return } + // 验证 model 必填 + if parsedReq.Model == "" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + // 获取订阅信息(可能为nil) subscription, _ := middleware2.GetSubscriptionFromContext(c) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 7fcb329d..7c9934c6 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -80,6 +80,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { reqModel, _ := reqBody["model"].(string) reqStream, _ := reqBody["stream"].(bool) + // 验证 model 必填 + if reqModel == "" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + // For non-Codex CLI requests, set default instructions userAgent := c.GetHeader("User-Agent") if !openai.IsCodexCLIRequest(userAgent) { From 81213f23241d68007c73a2d15f5d36369a3e28f4 Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 31 Dec 2025 16:25:45 +0800 Subject: [PATCH 13/49] =?UTF-8?q?refactor(service):=20=E7=BB=9F=E4=B8=80?= =?UTF-8?q?=E6=97=B6=E9=97=B4=E6=88=B3=E8=A7=A3=E6=9E=90=EF=BC=8C=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E5=A4=9A=E7=A7=8D=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增 Account.GetCredentialAsTime 方法,统一处理凭证中的时间戳字段, 兼容 RFC3339 字符串、Unix 时间戳字符串和数字类型。 - 重构 Claude/Gemini/Antigravity TokenRefresher.NeedsRefresh - 移除重复的 parseExpiresAt/parseAntigravityExpiresAt 函数 - 简化 GetOpenAITokenExpiresAt 实现 - 新增 RFC3339 格式单元测试用例 --- backend/internal/service/account.go | 36 ++++++++++++------- .../internal/service/account_test_service.go | 6 ++-- .../service/antigravity_quota_refresher.go | 2 +- .../service/antigravity_token_provider.go | 21 ++--------- .../service/antigravity_token_refresher.go | 12 ++----- .../internal/service/gemini_token_provider.go | 21 ++--------- .../service/gemini_token_refresher.go | 12 ++----- backend/internal/service/token_refresher.go | 12 ++----- .../internal/service/token_refresher_test.go | 14 ++++++++ 9 files changed, 55 insertions(+), 81 deletions(-) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index bfe3822c..5d461b9c 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -110,6 +110,28 @@ func (a *Account) GetCredential(key string) string { } } +// GetCredentialAsTime 解析凭证中的时间戳字段,支持多种格式 +// 兼容以下格式: +// - RFC3339 字符串: "2025-01-01T00:00:00Z" +// - Unix 时间戳字符串: "1735689600" +// - Unix 时间戳数字: 1735689600 (float64/int64/json.Number) +func (a *Account) GetCredentialAsTime(key string) *time.Time { + s := a.GetCredential(key) + if s == "" { + return nil + } + // 尝试 RFC3339 格式 + if t, err := time.Parse(time.RFC3339, s); err == nil { + return &t + } + // 尝试 Unix 时间戳(纯数字字符串) + if ts, err := strconv.ParseInt(s, 10, 64); err == nil { + t := time.Unix(ts, 0) + return &t + } + return nil +} + func (a *Account) GetModelMapping() map[string]string { if a.Credentials == nil { return nil @@ -324,19 +346,7 @@ func (a *Account) GetOpenAITokenExpiresAt() *time.Time { if !a.IsOpenAIOAuth() { return nil } - expiresAtStr := a.GetCredential("expires_at") - if expiresAtStr == "" { - return nil - } - t, err := time.Parse(time.RFC3339, expiresAtStr) - if err != nil { - if v, ok := a.Credentials["expires_at"].(float64); ok { - tt := time.Unix(int64(v), 0) - return &tt - } - return nil - } - return &t + return a.GetCredentialAsTime("expires_at") } func (a *Account) IsOpenAITokenExpired() bool { diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 318be8b8..7dd451cd 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -12,7 +12,6 @@ import ( "log" "net/http" "regexp" - "strconv" "strings" "time" @@ -187,9 +186,8 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account // Check if token needs refresh needRefresh := false - if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" { - expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64) - if err == nil && time.Now().Unix()+300 > expiresAt { + if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil { + if time.Now().Add(5 * time.Minute).After(*expiresAt) { needRefresh = true } } diff --git a/backend/internal/service/antigravity_quota_refresher.go b/backend/internal/service/antigravity_quota_refresher.go index dd579ef1..c4b11d73 100644 --- a/backend/internal/service/antigravity_quota_refresher.go +++ b/backend/internal/service/antigravity_quota_refresher.go @@ -191,7 +191,7 @@ func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, acc // isTokenExpired 检查 token 是否过期 func (r *AntigravityQuotaRefresher) isTokenExpired(account *Account) bool { - expiresAt := parseAntigravityExpiresAt(account) + expiresAt := account.GetCredentialAsTime("expires_at") if expiresAt == nil { return false } diff --git a/backend/internal/service/antigravity_token_provider.go b/backend/internal/service/antigravity_token_provider.go index efd3e15f..cbd1bef4 100644 --- a/backend/internal/service/antigravity_token_provider.go +++ b/backend/internal/service/antigravity_token_provider.go @@ -55,7 +55,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * } // 2. 如果即将过期则刷新 - expiresAt := parseAntigravityExpiresAt(account) + expiresAt := account.GetCredentialAsTime("expires_at") needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew if needsRefresh && p.tokenCache != nil { locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) @@ -72,7 +72,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * if err == nil && fresh != nil { account = fresh } - expiresAt = parseAntigravityExpiresAt(account) + expiresAt = account.GetCredentialAsTime("expires_at") if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew { if p.antigravityOAuthService == nil { return "", errors.New("antigravity oauth service not configured") @@ -91,7 +91,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr) } - expiresAt = parseAntigravityExpiresAt(account) + expiresAt = account.GetCredentialAsTime("expires_at") } } } @@ -128,18 +128,3 @@ func antigravityTokenCacheKey(account *Account) string { } return "ag:account:" + strconv.FormatInt(account.ID, 10) } - -func parseAntigravityExpiresAt(account *Account) *time.Time { - raw := strings.TrimSpace(account.GetCredential("expires_at")) - if raw == "" { - return nil - } - if unixSec, err := strconv.ParseInt(raw, 10, 64); err == nil && unixSec > 0 { - t := time.Unix(unixSec, 0) - return &t - } - if t, err := time.Parse(time.RFC3339, raw); err == nil { - return &t - } - return nil -} diff --git a/backend/internal/service/antigravity_token_refresher.go b/backend/internal/service/antigravity_token_refresher.go index 8ee2d25c..b4739025 100644 --- a/backend/internal/service/antigravity_token_refresher.go +++ b/backend/internal/service/antigravity_token_refresher.go @@ -2,7 +2,6 @@ package service import ( "context" - "strconv" "time" ) @@ -34,16 +33,11 @@ func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Durati if !r.CanRefresh(account) { return false } - expiresAtStr := account.GetCredential("expires_at") - if expiresAtStr == "" { + expiresAt := account.GetCredentialAsTime("expires_at") + if expiresAt == nil { return false } - expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64) - if err != nil { - return false - } - expiryTime := time.Unix(expiresAt, 0) - return time.Until(expiryTime) < antigravityRefreshWindow + return time.Until(*expiresAt) < antigravityRefreshWindow } // Refresh 执行 token 刷新 diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index f587b500..2195ec55 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -50,7 +50,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou } // 2) Refresh if needed (pre-expiry skew). - expiresAt := parseExpiresAt(account) + expiresAt := account.GetCredentialAsTime("expires_at") needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew if needsRefresh && p.tokenCache != nil { locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) @@ -66,7 +66,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou if err == nil && fresh != nil { account = fresh } - expiresAt = parseExpiresAt(account) + expiresAt = account.GetCredentialAsTime("expires_at") if expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew { if p.geminiOAuthService == nil { return "", errors.New("gemini oauth service not configured") @@ -83,7 +83,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou } account.Credentials = newCredentials _ = p.accountRepo.Update(ctx, account) - expiresAt = parseExpiresAt(account) + expiresAt = account.GetCredentialAsTime("expires_at") } } } @@ -154,18 +154,3 @@ func geminiTokenCacheKey(account *Account) string { } return "account:" + strconv.FormatInt(account.ID, 10) } - -func parseExpiresAt(account *Account) *time.Time { - raw := strings.TrimSpace(account.GetCredential("expires_at")) - if raw == "" { - return nil - } - if unixSec, err := strconv.ParseInt(raw, 10, 64); err == nil && unixSec > 0 { - t := time.Unix(unixSec, 0) - return &t - } - if t, err := time.Parse(time.RFC3339, raw); err == nil { - return &t - } - return nil -} diff --git a/backend/internal/service/gemini_token_refresher.go b/backend/internal/service/gemini_token_refresher.go index 19ba9424..7dfc5521 100644 --- a/backend/internal/service/gemini_token_refresher.go +++ b/backend/internal/service/gemini_token_refresher.go @@ -2,7 +2,6 @@ package service import ( "context" - "strconv" "time" ) @@ -22,16 +21,11 @@ func (r *GeminiTokenRefresher) NeedsRefresh(account *Account, refreshWindow time if !r.CanRefresh(account) { return false } - expiresAtStr := account.GetCredential("expires_at") - if expiresAtStr == "" { + expiresAt := account.GetCredentialAsTime("expires_at") + if expiresAt == nil { return false } - expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64) - if err != nil { - return false - } - expiryTime := time.Unix(expiresAt, 0) - return time.Until(expiryTime) < refreshWindow + return time.Until(*expiresAt) < refreshWindow } func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) { diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 2ae3c822..214a290a 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -43,17 +43,11 @@ func (r *ClaudeTokenRefresher) CanRefresh(account *Account) bool { // NeedsRefresh 检查token是否需要刷新 // 基于 expires_at 字段判断是否在刷新窗口内 func (r *ClaudeTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool { - s := account.GetCredential("expires_at") - if s == "" { + expiresAt := account.GetCredentialAsTime("expires_at") + if expiresAt == nil { return false } - - expiresAt, err := strconv.ParseInt(s, 10, 64) - if err != nil { - return false - } - - return time.Until(time.Unix(expiresAt, 0)) < refreshWindow + return time.Until(*expiresAt) < refreshWindow } // Refresh 执行token刷新 diff --git a/backend/internal/service/token_refresher_test.go b/backend/internal/service/token_refresher_test.go index c00fcfa3..0a5135ac 100644 --- a/backend/internal/service/token_refresher_test.go +++ b/backend/internal/service/token_refresher_test.go @@ -33,6 +33,13 @@ func TestClaudeTokenRefresher_NeedsRefresh(t *testing.T) { }, wantRefresh: true, }, + { + name: "expires_at as RFC3339 - expired", + credentials: map[string]any{ + "expires_at": "1970-01-01T00:00:00Z", // RFC3339 格式,已过期 + }, + wantRefresh: true, + }, { name: "expires_at as string - far future", credentials: map[string]any{ @@ -47,6 +54,13 @@ func TestClaudeTokenRefresher_NeedsRefresh(t *testing.T) { }, wantRefresh: false, }, + { + name: "expires_at as RFC3339 - far future", + credentials: map[string]any{ + "expires_at": "2099-12-31T23:59:59Z", // RFC3339 格式,远未来 + }, + wantRefresh: false, + }, { name: "expires_at missing", credentials: map[string]any{}, From 59269dc1c1a397fa974428ea85610241f3ec1afa Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Wed, 31 Dec 2025 16:37:18 +0800 Subject: [PATCH 14/49] =?UTF-8?q?fix(=E6=95=B0=E6=8D=AE=E5=B1=82):=20?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E8=BD=AF=E5=88=A0=E9=99=A4=E4=B8=8E=E5=94=AF?= =?UTF-8?q?=E4=B8=80=E7=BA=A6=E6=9D=9F=E5=86=B2=E7=AA=81=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 问题:软删除的记录仍占用唯一约束位置,导致删后无法重建同名/同邮箱/同订阅 修复方案:使用 PostgreSQL 部分唯一索引(WHERE deleted_at IS NULL) - User.email: 移除字段级 Unique(),改用部分唯一索引 - Group.name: 移除字段级 Unique(),改用部分唯一索引 - UserSubscription.(user_id, group_id): 移除组合唯一索引,改用部分唯一索引 - ApiKey.key: 保留普通唯一约束(安全考虑,已删除的 Key 不应重用) 安全性: - 应用层已有 ExistsByXxx 检查,自动过滤软删除记录 - 数据库层部分唯一索引提供最后一道防线 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- backend/ent/migrate/schema.go | 6 +-- backend/ent/schema/group.go | 5 +- backend/ent/schema/user.go | 5 +- backend/ent/schema/user_subscription.go | 4 +- ...016_soft_delete_partial_unique_indexes.sql | 51 +++++++++++++++++++ 5 files changed, 63 insertions(+), 8 deletions(-) create mode 100644 backend/migrations/016_soft_delete_partial_unique_indexes.sql diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 848ac74c..c9a1675e 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -204,7 +204,7 @@ var ( {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, - {Name: "name", Type: field.TypeString, Unique: true, Size: 100}, + {Name: "name", Type: field.TypeString, Size: 100}, {Name: "description", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}}, {Name: "is_exclusive", Type: field.TypeBool, Default: false}, @@ -470,7 +470,7 @@ var ( {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "deleted_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, - {Name: "email", Type: field.TypeString, Unique: true, Size: 255}, + {Name: "email", Type: field.TypeString, Size: 255}, {Name: "password_hash", Type: field.TypeString, Size: 255}, {Name: "role", Type: field.TypeString, Size: 20, Default: "user"}, {Name: "balance", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, @@ -605,7 +605,7 @@ var ( }, { Name: "usersubscription_user_id_group_id", - Unique: true, + Unique: false, Columns: []*schema.Column{UserSubscriptionsColumns[16], UserSubscriptionsColumns[15]}, }, { diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 7f3ed167..7a8a5345 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -33,10 +33,11 @@ func (Group) Mixin() []ent.Mixin { func (Group) Fields() []ent.Field { return []ent.Field{ + // 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重用 + // 见迁移文件 016_soft_delete_partial_unique_indexes.sql field.String("name"). MaxLen(100). - NotEmpty(). - Unique(), + NotEmpty(), field.String("description"). Optional(). Nillable(). diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index ba7f0ce7..c1f742d1 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -33,10 +33,11 @@ func (User) Mixin() []ent.Mixin { func (User) Fields() []ent.Field { return []ent.Field{ + // 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重用 + // 见迁移文件 016_soft_delete_partial_unique_indexes.sql field.String("email"). MaxLen(255). - NotEmpty(). - Unique(), + NotEmpty(), field.String("password_hash"). MaxLen(255). NotEmpty(), diff --git a/backend/ent/schema/user_subscription.go b/backend/ent/schema/user_subscription.go index 88c4ea8f..b21f4083 100644 --- a/backend/ent/schema/user_subscription.go +++ b/backend/ent/schema/user_subscription.go @@ -109,7 +109,9 @@ func (UserSubscription) Indexes() []ent.Index { index.Fields("status"), index.Fields("expires_at"), index.Fields("assigned_by"), - index.Fields("user_id", "group_id").Unique(), + // 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重新订阅 + // 见迁移文件 016_soft_delete_partial_unique_indexes.sql + index.Fields("user_id", "group_id"), index.Fields("deleted_at"), } } diff --git a/backend/migrations/016_soft_delete_partial_unique_indexes.sql b/backend/migrations/016_soft_delete_partial_unique_indexes.sql new file mode 100644 index 00000000..b006b775 --- /dev/null +++ b/backend/migrations/016_soft_delete_partial_unique_indexes.sql @@ -0,0 +1,51 @@ +-- 016_soft_delete_partial_unique_indexes.sql +-- 修复软删除 + 唯一约束冲突问题 +-- 将普通唯一约束替换为部分唯一索引(WHERE deleted_at IS NULL) +-- 这样软删除的记录不会占用唯一约束位置,允许删后重建同名/同邮箱/同订阅关系 + +-- ============================================================================ +-- 1. users 表: email 字段 +-- ============================================================================ + +-- 删除旧的唯一约束(可能的命名方式) +ALTER TABLE users DROP CONSTRAINT IF EXISTS users_email_key; +DROP INDEX IF EXISTS users_email_key; +DROP INDEX IF EXISTS user_email_key; + +-- 创建部分唯一索引:只对未删除的记录建立唯一约束 +CREATE UNIQUE INDEX IF NOT EXISTS users_email_unique_active + ON users(email) + WHERE deleted_at IS NULL; + +-- ============================================================================ +-- 2. groups 表: name 字段 +-- ============================================================================ + +-- 删除旧的唯一约束 +ALTER TABLE groups DROP CONSTRAINT IF EXISTS groups_name_key; +DROP INDEX IF EXISTS groups_name_key; +DROP INDEX IF EXISTS group_name_key; + +-- 创建部分唯一索引 +CREATE UNIQUE INDEX IF NOT EXISTS groups_name_unique_active + ON groups(name) + WHERE deleted_at IS NULL; + +-- ============================================================================ +-- 3. user_subscriptions 表: (user_id, group_id) 组合字段 +-- ============================================================================ + +-- 删除旧的唯一约束/索引 +ALTER TABLE user_subscriptions DROP CONSTRAINT IF EXISTS user_subscriptions_user_id_group_id_key; +DROP INDEX IF EXISTS user_subscriptions_user_id_group_id_key; +DROP INDEX IF EXISTS usersubscription_user_id_group_id; + +-- 创建部分唯一索引 +CREATE UNIQUE INDEX IF NOT EXISTS user_subscriptions_user_group_unique_active + ON user_subscriptions(user_id, group_id) + WHERE deleted_at IS NULL; + +-- ============================================================================ +-- 注意: api_keys 表的 key 字段保留普通唯一约束 +-- API Key 即使软删除后也不应该重复使用(安全考虑) +-- ============================================================================ From 9aeef15d1b8f26f3a2018d5735f6a620289d3c47 Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 31 Dec 2025 17:12:09 +0800 Subject: [PATCH 15/49] =?UTF-8?q?fix(ci):=20GHCR=20=E9=95=9C=E5=83=8F?= =?UTF-8?q?=E5=90=8D=E8=BD=AC=E4=B8=BA=E5=B0=8F=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/release.yml | 2 +- .goreleaser.yaml | 34 +++++++++++++++++----------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f8976d93..55996bdf 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -168,7 +168,7 @@ jobs: VERSION=${TAG_NAME#v} REPO="${{ github.repository }}" DOCKER_IMAGE="${{ secrets.DOCKERHUB_USERNAME }}/sub2api" - GHCR_IMAGE="ghcr.io/${REPO}" + GHCR_IMAGE="ghcr.io/${REPO,,}" # ${,,} converts to lowercase # 获取 tag message 内容 TAG_MESSAGE='${{ steps.tag_message.outputs.message }}' diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 95b66f8f..5b855724 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -78,12 +78,12 @@ dockers: - "--label=org.opencontainers.image.version={{ .Version }}" - "--label=org.opencontainers.image.revision={{ .Commit }}" - # GHCR images + # GHCR images (owner must be lowercase) - id: ghcr-amd64 goos: linux goarch: amd64 image_templates: - - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64" + - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64" dockerfile: Dockerfile.goreleaser use: buildx build_flag_templates: @@ -96,7 +96,7 @@ dockers: goos: linux goarch: arm64 image_templates: - - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64" + - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64" dockerfile: Dockerfile.goreleaser use: buildx build_flag_templates: @@ -127,26 +127,26 @@ docker_manifests: - "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64" - "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64" - # GHCR manifests - - name_template: "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}" + # GHCR manifests (owner must be lowercase) + - name_template: "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}" image_templates: - - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64" - - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64" + - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64" + - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64" - - name_template: "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:latest" + - name_template: "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:latest" image_templates: - - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64" - - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64" + - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64" + - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64" - - name_template: "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Major }}.{{ .Minor }}" + - name_template: "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Major }}.{{ .Minor }}" image_templates: - - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64" - - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64" + - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64" + - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64" - - name_template: "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Major }}" + - name_template: "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Major }}" image_templates: - - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64" - - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64" + - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64" + - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64" release: github: @@ -173,7 +173,7 @@ release: docker pull {{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }} # GitHub Container Registry - docker pull ghcr.io/{{ .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }} + docker pull ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }} ``` **One-line install (Linux):** From 3fd9bd4a80da44b2ec09b1732797b9c08fe988eb Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 31 Dec 2025 17:25:43 +0800 Subject: [PATCH 16/49] =?UTF-8?q?fix(ci):=20=E4=BD=BF=E7=94=A8=E9=A2=84?= =?UTF-8?q?=E5=A4=84=E7=90=86=E7=9A=84=E5=B0=8F=E5=86=99=20owner=20?= =?UTF-8?q?=E6=9B=BF=E4=BB=A3=20lower=20=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit GoReleaser 不支持 lower 模板函数,改为: - 在 GitHub Actions 中预处理小写 owner - 传递 GITHUB_REPO_OWNER_LOWER 环境变量给 GoReleaser --- .github/workflows/release.yml | 5 +++++ .goreleaser.yaml | 30 +++++++++++++++--------------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 55996bdf..d20ed0c8 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -129,6 +129,10 @@ jobs: echo "$TAG_MESSAGE" >> $GITHUB_OUTPUT echo "EOF" >> $GITHUB_OUTPUT + - name: Set lowercase owner for GHCR + id: lowercase + run: echo "owner=$(echo '${{ github.repository_owner }}' | tr '[:upper:]' '[:lower:]')" >> $GITHUB_OUTPUT + - name: Run GoReleaser uses: goreleaser/goreleaser-action@v6 with: @@ -138,6 +142,7 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} TAG_MESSAGE: ${{ steps.tag_message.outputs.message }} GITHUB_REPO_OWNER: ${{ github.repository_owner }} + GITHUB_REPO_OWNER_LOWER: ${{ steps.lowercase.outputs.owner }} GITHUB_REPO_NAME: ${{ github.event.repository.name }} DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 5b855724..c72f7422 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -83,7 +83,7 @@ dockers: goos: linux goarch: amd64 image_templates: - - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64" + - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64" dockerfile: Dockerfile.goreleaser use: buildx build_flag_templates: @@ -96,7 +96,7 @@ dockers: goos: linux goarch: arm64 image_templates: - - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64" + - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64" dockerfile: Dockerfile.goreleaser use: buildx build_flag_templates: @@ -128,25 +128,25 @@ docker_manifests: - "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64" # GHCR manifests (owner must be lowercase) - - name_template: "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}" + - name_template: "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}" image_templates: - - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64" - - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64" + - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64" + - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64" - - name_template: "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:latest" + - name_template: "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest" image_templates: - - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64" - - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64" + - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64" + - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64" - - name_template: "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Major }}.{{ .Minor }}" + - name_template: "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Major }}.{{ .Minor }}" image_templates: - - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64" - - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64" + - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64" + - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64" - - name_template: "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Major }}" + - name_template: "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Major }}" image_templates: - - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-amd64" - - "ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }}-arm64" + - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64" + - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-arm64" release: github: @@ -173,7 +173,7 @@ release: docker pull {{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }} # GitHub Container Registry - docker pull ghcr.io/{{ lower .Env.GITHUB_REPO_OWNER }}/sub2api:{{ .Version }} + docker pull ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }} ``` **One-line install (Linux):** From 2c35f0276f4cb57151f36ade7e5e2cba186c2551 Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 31 Dec 2025 20:46:54 +0800 Subject: [PATCH 17/49] =?UTF-8?q?fix(frontend):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E6=97=A0=E9=99=90=E5=88=B6=E8=AE=A2=E9=98=85=E7=9A=84=E6=98=BE?= =?UTF-8?q?=E7=A4=BA=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../common/SubscriptionProgressMini.vue | 188 ++++++++++-------- frontend/src/i18n/locales/en.ts | 5 +- frontend/src/i18n/locales/zh.ts | 5 +- .../src/views/admin/SubscriptionsView.vue | 9 +- frontend/src/views/user/SubscriptionsView.vue | 18 +- 5 files changed, 134 insertions(+), 91 deletions(-) diff --git a/frontend/src/components/common/SubscriptionProgressMini.vue b/frontend/src/components/common/SubscriptionProgressMini.vue index b84175e9..92198c2c 100644 --- a/frontend/src/components/common/SubscriptionProgressMini.vue +++ b/frontend/src/components/common/SubscriptionProgressMini.vue @@ -69,94 +69,108 @@ - +
-
- {{ - t('subscriptionProgress.daily') - }} -
-
-
- - {{ - formatUsage(subscription.daily_usage_usd, subscription.group?.daily_limit_usd) - }} + +
+ + + {{ t('subscriptionProgress.unlimited') }}
-
- {{ - t('subscriptionProgress.weekly') - }} -
-
+ +
@@ -215,7 +229,19 @@ function getMaxUsagePercentage(sub: UserSubscription): number { return percentages.length > 0 ? Math.max(...percentages) : 0 } +function isUnlimited(sub: UserSubscription): boolean { + return ( + !sub.group?.daily_limit_usd && + !sub.group?.weekly_limit_usd && + !sub.group?.monthly_limit_usd + ) +} + function getProgressDotClass(sub: UserSubscription): string { + // Unlimited subscriptions get a special color + if (isUnlimited(sub)) { + return 'bg-emerald-500' + } const maxPercentage = getMaxUsagePercentage(sub) if (maxPercentage >= 90) return 'bg-red-500' if (maxPercentage >= 70) return 'bg-orange-500' diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index d153b553..6d1095cf 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -749,6 +749,7 @@ export default { weekly: 'Weekly', monthly: 'Monthly', noLimits: 'No limits configured', + unlimited: 'Unlimited', resetNow: 'Resetting soon', windowNotActive: 'Window not active', resetInMinutes: 'Resets in {minutes}m', @@ -1492,7 +1493,8 @@ export default { expiresToday: 'Expires today', expiresTomorrow: 'Expires tomorrow', viewAll: 'View all subscriptions', - noSubscriptions: 'No active subscriptions' + noSubscriptions: 'No active subscriptions', + unlimited: 'Unlimited' }, // Version Badge @@ -1535,6 +1537,7 @@ export default { expires: 'Expires', noExpiration: 'No expiration', unlimited: 'Unlimited', + unlimitedDesc: 'No usage limits on this subscription', daily: 'Daily', weekly: 'Weekly', monthly: 'Monthly', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index c6105683..97d57051 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -840,6 +840,7 @@ export default { weekly: '每周', monthly: '每月', noLimits: '未配置限额', + unlimited: '无限制', resetNow: '即将重置', windowNotActive: '窗口未激活', resetInMinutes: '{minutes} 分钟后重置', @@ -1689,7 +1690,8 @@ export default { expiresToday: '今天到期', expiresTomorrow: '明天到期', viewAll: '查看全部订阅', - noSubscriptions: '暂无有效订阅' + noSubscriptions: '暂无有效订阅', + unlimited: '无限制' }, // Version Badge @@ -1731,6 +1733,7 @@ export default { expires: '到期时间', noExpiration: '无到期时间', unlimited: '无限制', + unlimitedDesc: '该订阅无用量限制', daily: '每日', weekly: '每周', monthly: '每月', diff --git a/frontend/src/views/admin/SubscriptionsView.vue b/frontend/src/views/admin/SubscriptionsView.vue index bd6a17eb..679c3275 100644 --- a/frontend/src/views/admin/SubscriptionsView.vue +++ b/frontend/src/views/admin/SubscriptionsView.vue @@ -202,16 +202,19 @@
- +
- {{ t('admin.subscriptions.noLimits') }} + + + {{ t('admin.subscriptions.unlimited') }} +
diff --git a/frontend/src/views/user/SubscriptionsView.vue b/frontend/src/views/user/SubscriptionsView.vue index dc93a9c1..b03b665a 100644 --- a/frontend/src/views/user/SubscriptionsView.vue +++ b/frontend/src/views/user/SubscriptionsView.vue @@ -230,18 +230,26 @@

- +
- {{ - t('userSubscriptions.unlimited') - }} +
+ +
+

+ {{ t('userSubscriptions.unlimited') }} +

+

+ {{ t('userSubscriptions.unlimitedDesc') }} +

+
+
From 15e676e9cd7250880e56e092976e8832abf82a67 Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Wed, 31 Dec 2025 20:56:38 +0800 Subject: [PATCH 18/49] =?UTF-8?q?fix(upstream):=20=E6=94=AF=E6=8C=81=20Cla?= =?UTF-8?q?ude=20custom=20=E7=B1=BB=E5=9E=8B=E5=B7=A5=E5=85=B7=20(MCP)=20?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ClaudeTool 结构体增加 Type 和 Custom 字段 - buildTools 函数支持从 custom 字段读取 input_schema - convertClaudeToolsToGeminiTools 函数支持 MCP 工具格式 - 修复 Antigravity upstream error 400: JSON schema invalid 修复 Issue 0.2: tools.X.custom.input_schema 验证错误 --- .../internal/pkg/antigravity/claude_types.go | 13 +++++++++- .../pkg/antigravity/request_transformer.go | 18 +++++++++++-- .../service/gemini_messages_compat_service.go | 26 ++++++++++++++++--- 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index 9cab4cea..f394d7e3 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -37,8 +37,19 @@ type ClaudeMetadata struct { } // ClaudeTool Claude 工具定义 +// 支持两种格式: +// 1. 标准格式: { "name": "...", "description": "...", "input_schema": {...} } +// 2. Custom 格式 (MCP): { "type": "custom", "name": "...", "custom": { "description": "...", "input_schema": {...} } } type ClaudeTool struct { - Name string `json:"name"` + Type string `json:"type,omitempty"` // "custom" 或空(标准格式) + Name string `json:"name"` + Description string `json:"description,omitempty"` // 标准格式使用 + InputSchema map[string]any `json:"input_schema,omitempty"` // 标准格式使用 + Custom *CustomToolSpec `json:"custom,omitempty"` // custom 格式使用 +} + +// CustomToolSpec MCP custom 工具规格 +type CustomToolSpec struct { Description string `json:"description,omitempty"` InputSchema map[string]any `json:"input_schema"` } diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 2ff0ec02..51eb4299 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -379,12 +379,26 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { // 普通工具 var funcDecls []GeminiFunctionDecl for _, tool := range tools { + var description string + var inputSchema map[string]any + + // 检查是否为 custom 类型工具 (MCP) + if tool.Type == "custom" && tool.Custom != nil { + // Custom 格式: 从 custom 字段获取 description 和 input_schema + description = tool.Custom.Description + inputSchema = tool.Custom.InputSchema + } else { + // 标准格式: 从顶层字段获取 + description = tool.Description + inputSchema = tool.InputSchema + } + // 清理 JSON Schema - params := cleanJSONSchema(tool.InputSchema) + params := cleanJSONSchema(inputSchema) funcDecls = append(funcDecls, GeminiFunctionDecl{ Name: tool.Name, - Description: tool.Description, + Description: description, Parameters: params, }) } diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index ee3ade16..e55d798a 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -2245,12 +2245,32 @@ func convertClaudeToolsToGeminiTools(tools any) []any { if !ok { continue } - name, _ := tm["name"].(string) - desc, _ := tm["description"].(string) - params := tm["input_schema"] + + var name, desc string + var params any + + // 检查是否为 custom 类型工具 (MCP) + toolType, _ := tm["type"].(string) + if toolType == "custom" { + // Custom 格式: 从 custom 字段获取 description 和 input_schema + custom, ok := tm["custom"].(map[string]any) + if !ok { + continue + } + name, _ = tm["name"].(string) + desc, _ = custom["description"].(string) + params = custom["input_schema"] + } else { + // 标准格式: 从顶层字段获取 + name, _ = tm["name"].(string) + desc, _ = tm["description"].(string) + params = tm["input_schema"] + } + if name == "" { continue } + funcDecls = append(funcDecls, map[string]any{ "name": name, "description": desc, From 0b6371174e18ab03848b1358566a85694f46a2eb Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 31 Dec 2025 21:05:33 +0800 Subject: [PATCH 19/49] =?UTF-8?q?fix(settings):=20=E4=BF=9D=E5=AD=98=20Tur?= =?UTF-8?q?nstile=20=E8=AE=BE=E7=BD=AE=E6=97=B6=E9=AA=8C=E8=AF=81=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E6=9C=89=E6=95=88=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/cmd/server/wire_gen.go | 2 +- .../internal/handler/admin/setting_handler.go | 42 ++++++++++++++++--- backend/internal/server/api_contract_test.go | 2 +- backend/internal/service/turnstile_service.go | 20 +++++++++ 4 files changed, 59 insertions(+), 7 deletions(-) diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index d469dcbb..c4859383 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -109,7 +109,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService) proxyHandler := admin.NewProxyHandler(adminService) adminRedeemHandler := admin.NewRedeemHandler(adminService) - settingHandler := admin.NewSettingHandler(settingService, emailService) + settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService) updateCache := repository.NewUpdateCache(redisClient) gitHubReleaseClient := repository.NewGitHubReleaseClient() serviceBuildInfo := provideServiceBuildInfo(buildInfo) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 14b569de..e533aef1 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -10,15 +10,17 @@ import ( // SettingHandler 系统设置处理器 type SettingHandler struct { - settingService *service.SettingService - emailService *service.EmailService + settingService *service.SettingService + emailService *service.EmailService + turnstileService *service.TurnstileService } // NewSettingHandler 创建系统设置处理器 -func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService) *SettingHandler { +func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService) *SettingHandler { return &SettingHandler{ - settingService: settingService, - emailService: emailService, + settingService: settingService, + emailService: emailService, + turnstileService: turnstileService, } } @@ -108,6 +110,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { req.SmtpPort = 587 } + // Turnstile 参数验证 + if req.TurnstileEnabled { + // 检查必填字段 + if req.TurnstileSiteKey == "" { + response.BadRequest(c, "Turnstile Site Key is required when enabled") + return + } + if req.TurnstileSecretKey == "" { + response.BadRequest(c, "Turnstile Secret Key is required when enabled") + return + } + + // 获取当前设置,检查参数是否有变化 + currentSettings, err := h.settingService.GetAllSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + // 当 site_key 或 secret_key 任一变化时验证(避免配置错误导致无法登录) + siteKeyChanged := currentSettings.TurnstileSiteKey != req.TurnstileSiteKey + secretKeyChanged := currentSettings.TurnstileSecretKey != req.TurnstileSecretKey + if siteKeyChanged || secretKeyChanged { + if err := h.turnstileService.ValidateSecretKey(c.Request.Context(), req.TurnstileSecretKey); err != nil { + response.ErrorFrom(c, err) + return + } + } + } + settings := &service.SystemSettings{ RegistrationEnabled: req.RegistrationEnabled, EmailVerifyEnabled: req.EmailVerifyEnabled, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 5a243bfc..3912c8fb 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -385,7 +385,7 @@ func newContractDeps(t *testing.T) *contractDeps { authHandler := handler.NewAuthHandler(cfg, nil, userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) - adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil) + adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil) jwtAuth := func(c *gin.Context) { c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ diff --git a/backend/internal/service/turnstile_service.go b/backend/internal/service/turnstile_service.go index 2a68c11b..cfb87c57 100644 --- a/backend/internal/service/turnstile_service.go +++ b/backend/internal/service/turnstile_service.go @@ -11,6 +11,7 @@ import ( var ( ErrTurnstileVerificationFailed = infraerrors.BadRequest("TURNSTILE_VERIFICATION_FAILED", "turnstile verification failed") ErrTurnstileNotConfigured = infraerrors.ServiceUnavailable("TURNSTILE_NOT_CONFIGURED", "turnstile not configured") + ErrTurnstileInvalidSecretKey = infraerrors.BadRequest("TURNSTILE_INVALID_SECRET_KEY", "invalid turnstile secret key") ) // TurnstileVerifier 验证 Turnstile token 的接口 @@ -83,3 +84,22 @@ func (s *TurnstileService) VerifyToken(ctx context.Context, token string, remote func (s *TurnstileService) IsEnabled(ctx context.Context) bool { return s.settingService.IsTurnstileEnabled(ctx) } + +// ValidateSecretKey 验证 Turnstile Secret Key 是否有效 +func (s *TurnstileService) ValidateSecretKey(ctx context.Context, secretKey string) error { + // 发送一个测试token的验证请求来检查secret_key是否有效 + result, err := s.verifier.VerifyToken(ctx, secretKey, "test-validation", "") + if err != nil { + return fmt.Errorf("validate secret key: %w", err) + } + + // 检查是否有 invalid-input-secret 错误 + for _, code := range result.ErrorCodes { + if code == "invalid-input-secret" { + return ErrTurnstileInvalidSecretKey + } + } + + // 其他错误(如 invalid-input-response)说明 secret key 是有效的 + return nil +} From 35b768b71967740d30b9923fe7418f7eab7a4977 Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Wed, 31 Dec 2025 21:35:41 +0800 Subject: [PATCH 20/49] =?UTF-8?q?fix(upstream):=20=E8=B7=B3=E8=BF=87=20Cla?= =?UTF-8?q?ude=20=E6=A8=A1=E5=9E=8B=E6=97=A0=20signature=20=E7=9A=84=20thi?= =?UTF-8?q?nking=20block?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - buildParts 函数检测 thinking block 的 signature - Claude 模型 (allowDummyThought=false) 时跳过无 signature 的 block - 记录警告日志以便调试 - Gemini 模型继续使用 dummy signature 兼容方案 修复 Issue 0.1: Claude thinking block signature 缺失错误 --- backend/internal/pkg/antigravity/request_transformer.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 51eb4299..e5ab8ece 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -3,6 +3,7 @@ package antigravity import ( "encoding/json" "fmt" + "log" "strings" "github.com/google/uuid" @@ -205,6 +206,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu // 保留原有 signature(Claude 模型需要有效的 signature) if block.Signature != "" { part.ThoughtSignature = block.Signature + } else if !allowDummyThought { + // Claude 模型需要有效 signature,跳过无 signature 的 thinking block + log.Printf("Warning: skipping thinking block without signature for Claude model") + continue } parts = append(parts, part) From c1e25b7ecf745a97832b9a1cc8827d6e6123dc69 Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Wed, 31 Dec 2025 21:44:56 +0800 Subject: [PATCH 21/49] =?UTF-8?q?fix(upstream):=20=E5=AE=8C=E5=96=84?= =?UTF-8?q?=E8=BE=B9=E7=95=8C=E6=A3=80=E6=9F=A5=E5=92=8C=20thinking=20bloc?= =?UTF-8?q?k=20=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 基于 Gemini + Codex 审查结果的修复: 1. thinking block dummy signature 填充 - Gemini 模型现在会填充 dummyThoughtSignature - 与 tool_use 处理逻辑保持一致 2. 边界检查增强 - buildTools: 跳过空工具名称 - buildTools: 为 nil schema 提供默认值 - convertClaudeToolsToGeminiTools: 为 nil params 提供默认值 3. 防止下游 API 验证错误 - 确保所有工具都有有效的 parameters - 默认 schema: {type: 'object', properties: {}} 审查报告:Gemini 评分 95%, Codex 评分 8.2/10 --- .../pkg/antigravity/request_transformer.go | 17 +++++++++++++++++ .../service/gemini_messages_compat_service.go | 8 ++++++++ 2 files changed, 25 insertions(+) diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index e5ab8ece..e0b5b886 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -210,6 +210,9 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu // Claude 模型需要有效 signature,跳过无 signature 的 thinking block log.Printf("Warning: skipping thinking block without signature for Claude model") continue + } else { + // Gemini 模型使用 dummy signature + part.ThoughtSignature = dummyThoughtSignature } parts = append(parts, part) @@ -384,6 +387,12 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { // 普通工具 var funcDecls []GeminiFunctionDecl for _, tool := range tools { + // 跳过无效工具名称 + if tool.Name == "" { + log.Printf("Warning: skipping tool with empty name") + continue + } + var description string var inputSchema map[string]any @@ -401,6 +410,14 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { // 清理 JSON Schema params := cleanJSONSchema(inputSchema) + // 为 nil schema 提供默认值 + if params == nil { + params = map[string]any{ + "type": "OBJECT", + "properties": map[string]any{}, + } + } + funcDecls = append(funcDecls, GeminiFunctionDecl{ Name: tool.Name, Description: description, diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index e55d798a..a0bf1b6a 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -2271,6 +2271,14 @@ func convertClaudeToolsToGeminiTools(tools any) []any { continue } + // 为 nil params 提供默认值 + if params == nil { + params = map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + } + funcDecls = append(funcDecls, map[string]any{ "name": name, "description": desc, From c5b792add579e3d837d5699928ca938e64346a08 Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 31 Dec 2025 22:48:35 +0800 Subject: [PATCH 22/49] =?UTF-8?q?fix(billing):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E9=99=90=E9=A2=9D=E4=B8=BA0=E6=97=B6=E6=B6=88=E8=B4=B9?= =?UTF-8?q?=E8=AE=B0=E5=BD=95=E5=A4=B1=E8=B4=A5=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加 normalizeLimit 函数,将 0 或负数限额规范化为 nil(无限制) - 简化 IncrementUsage,移除冗余的配额检查逻辑 - 配额检查已在请求前由中间件和网关完成 - 消费记录应无条件执行,确保数据完整性 - 删除测试限额超出行为的无效集成测试 --- .../repository/user_subscription_repo.go | 74 +--------- ...user_subscription_repo_integration_test.go | 137 +----------------- backend/internal/service/admin_service.go | 27 +++- .../internal/service/subscription_service.go | 1 + 4 files changed, 31 insertions(+), 208 deletions(-) diff --git a/backend/internal/repository/user_subscription_repo.go b/backend/internal/repository/user_subscription_repo.go index 2b308674..cd3b9db6 100644 --- a/backend/internal/repository/user_subscription_repo.go +++ b/backend/internal/repository/user_subscription_repo.go @@ -291,13 +291,11 @@ func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) } -// IncrementUsage 原子性地累加用量并校验限额。 -// 使用单条 SQL 语句同时检查 Group 的限额,如果任一限额即将超出则拒绝更新。 -// 当更新失败时,会执行额外查询确定具体超出的限额类型。 +// IncrementUsage 原子性地累加订阅用量。 +// 限额检查已在请求前由 BillingCacheService.CheckBillingEligibility 完成, +// 此处仅负责记录实际消费,确保消费数据的完整性。 func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { - // 使用 JOIN 的原子更新:只有当所有限额条件满足时才执行累加 - // NULL 限额表示无限制 - const atomicUpdateSQL = ` + const updateSQL = ` UPDATE user_subscriptions us SET daily_usage_usd = us.daily_usage_usd + $1, @@ -309,13 +307,10 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6 AND us.deleted_at IS NULL AND us.group_id = g.id AND g.deleted_at IS NULL - AND (g.daily_limit_usd IS NULL OR us.daily_usage_usd + $1 <= g.daily_limit_usd) - AND (g.weekly_limit_usd IS NULL OR us.weekly_usage_usd + $1 <= g.weekly_limit_usd) - AND (g.monthly_limit_usd IS NULL OR us.monthly_usage_usd + $1 <= g.monthly_limit_usd) ` client := clientFromContext(ctx, r.client) - result, err := client.ExecContext(ctx, atomicUpdateSQL, costUSD, id) + result, err := client.ExecContext(ctx, updateSQL, costUSD, id) if err != nil { return err } @@ -326,64 +321,11 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6 } if affected > 0 { - return nil // 更新成功 + return nil } - // affected == 0:可能是订阅不存在、分组已删除、或限额超出 - // 执行额外查询确定具体原因 - return r.checkIncrementFailureReason(ctx, id, costUSD) -} - -// checkIncrementFailureReason 查询更新失败的具体原因 -func (r *userSubscriptionRepository) checkIncrementFailureReason(ctx context.Context, id int64, costUSD float64) error { - const checkSQL = ` - SELECT - CASE WHEN us.deleted_at IS NOT NULL THEN 'subscription_deleted' - WHEN g.id IS NULL THEN 'subscription_not_found' - WHEN g.deleted_at IS NOT NULL THEN 'group_deleted' - WHEN g.daily_limit_usd IS NOT NULL AND us.daily_usage_usd + $1 > g.daily_limit_usd THEN 'daily_exceeded' - WHEN g.weekly_limit_usd IS NOT NULL AND us.weekly_usage_usd + $1 > g.weekly_limit_usd THEN 'weekly_exceeded' - WHEN g.monthly_limit_usd IS NOT NULL AND us.monthly_usage_usd + $1 > g.monthly_limit_usd THEN 'monthly_exceeded' - ELSE 'unknown' - END AS reason - FROM user_subscriptions us - LEFT JOIN groups g ON us.group_id = g.id - WHERE us.id = $2 - ` - - client := clientFromContext(ctx, r.client) - rows, err := client.QueryContext(ctx, checkSQL, costUSD, id) - if err != nil { - return err - } - defer func() { _ = rows.Close() }() - - if !rows.Next() { - return service.ErrSubscriptionNotFound - } - - var reason string - if err := rows.Scan(&reason); err != nil { - return err - } - - if err := rows.Err(); err != nil { - return err - } - - switch reason { - case "subscription_not_found", "subscription_deleted", "group_deleted": - return service.ErrSubscriptionNotFound - case "daily_exceeded": - return service.ErrDailyLimitExceeded - case "weekly_exceeded": - return service.ErrWeeklyLimitExceeded - case "monthly_exceeded": - return service.ErrMonthlyLimitExceeded - default: - // unknown 情况理论上不应发生,但作为兜底返回 - return service.ErrSubscriptionNotFound - } + // affected == 0:订阅不存在或已删除 + return service.ErrSubscriptionNotFound } func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { diff --git a/backend/internal/repository/user_subscription_repo_integration_test.go b/backend/internal/repository/user_subscription_repo_integration_test.go index 3a6c6434..2099e5d8 100644 --- a/backend/internal/repository/user_subscription_repo_integration_test.go +++ b/backend/internal/repository/user_subscription_repo_integration_test.go @@ -633,112 +633,7 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired") } -// --- 限额检查与软删除过滤测试 --- - -func (s *UserSubscriptionRepoSuite) mustCreateGroupWithLimits(name string, daily, weekly, monthly *float64) *service.Group { - s.T().Helper() - - create := s.client.Group.Create(). - SetName(name). - SetStatus(service.StatusActive). - SetSubscriptionType(service.SubscriptionTypeSubscription) - - if daily != nil { - create.SetDailyLimitUsd(*daily) - } - if weekly != nil { - create.SetWeeklyLimitUsd(*weekly) - } - if monthly != nil { - create.SetMonthlyLimitUsd(*monthly) - } - - g, err := create.Save(s.ctx) - s.Require().NoError(err, "create group with limits") - return groupEntityToService(g) -} - -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_DailyLimitExceeded() { - user := s.mustCreateUser("dailylimit@test.com", service.RoleUser) - dailyLimit := 10.0 - group := s.mustCreateGroupWithLimits("g-dailylimit", &dailyLimit, nil, nil) - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 先增加 9.0,应该成功 - err := s.repo.IncrementUsage(s.ctx, sub.ID, 9.0) - s.Require().NoError(err, "first increment should succeed") - - // 再增加 2.0,会超过 10.0 限额,应该失败 - err = s.repo.IncrementUsage(s.ctx, sub.ID, 2.0) - s.Require().Error(err, "should fail when daily limit exceeded") - s.Require().ErrorIs(err, service.ErrDailyLimitExceeded) - - // 验证用量没有变化 - got, err := s.repo.GetByID(s.ctx, sub.ID) - s.Require().NoError(err) - s.Require().InDelta(9.0, got.DailyUsageUSD, 1e-6, "usage should not change after failed increment") -} - -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_WeeklyLimitExceeded() { - user := s.mustCreateUser("weeklylimit@test.com", service.RoleUser) - weeklyLimit := 50.0 - group := s.mustCreateGroupWithLimits("g-weeklylimit", nil, &weeklyLimit, nil) - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 增加 45.0,应该成功 - err := s.repo.IncrementUsage(s.ctx, sub.ID, 45.0) - s.Require().NoError(err, "first increment should succeed") - - // 再增加 10.0,会超过 50.0 限额,应该失败 - err = s.repo.IncrementUsage(s.ctx, sub.ID, 10.0) - s.Require().Error(err, "should fail when weekly limit exceeded") - s.Require().ErrorIs(err, service.ErrWeeklyLimitExceeded) -} - -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_MonthlyLimitExceeded() { - user := s.mustCreateUser("monthlylimit@test.com", service.RoleUser) - monthlyLimit := 100.0 - group := s.mustCreateGroupWithLimits("g-monthlylimit", nil, nil, &monthlyLimit) - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 增加 90.0,应该成功 - err := s.repo.IncrementUsage(s.ctx, sub.ID, 90.0) - s.Require().NoError(err, "first increment should succeed") - - // 再增加 20.0,会超过 100.0 限额,应该失败 - err = s.repo.IncrementUsage(s.ctx, sub.ID, 20.0) - s.Require().Error(err, "should fail when monthly limit exceeded") - s.Require().ErrorIs(err, service.ErrMonthlyLimitExceeded) -} - -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NoLimits() { - user := s.mustCreateUser("nolimits@test.com", service.RoleUser) - group := s.mustCreateGroupWithLimits("g-nolimits", nil, nil, nil) // 无限额 - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 应该可以增加任意金额 - err := s.repo.IncrementUsage(s.ctx, sub.ID, 1000000.0) - s.Require().NoError(err, "should succeed without limits") - - got, err := s.repo.GetByID(s.ctx, sub.ID) - s.Require().NoError(err) - s.Require().InDelta(1000000.0, got.DailyUsageUSD, 1e-6) -} - -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_AtExactLimit() { - user := s.mustCreateUser("exactlimit@test.com", service.RoleUser) - dailyLimit := 10.0 - group := s.mustCreateGroupWithLimits("g-exactlimit", &dailyLimit, nil, nil) - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 正好达到限额应该成功 - err := s.repo.IncrementUsage(s.ctx, sub.ID, 10.0) - s.Require().NoError(err, "should succeed at exact limit") - - got, err := s.repo.GetByID(s.ctx, sub.ID) - s.Require().NoError(err) - s.Require().InDelta(10.0, got.DailyUsageUSD, 1e-6) -} +// --- 软删除过滤测试 --- func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() { user := s.mustCreateUser("softdeleted@test.com", service.RoleUser) @@ -779,7 +674,7 @@ func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() { func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() { user := s.mustCreateUser("concurrent@test.com", service.RoleUser) - group := s.mustCreateGroupWithLimits("g-concurrent", nil, nil, nil) // 无限额 + group := s.mustCreateGroup("g-concurrent") sub := s.mustCreateSubscription(user.ID, group.ID, nil) const numGoroutines = 10 @@ -808,34 +703,6 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() { s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated") } -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_ConcurrentWithLimit() { - user := s.mustCreateUser("concurrentlimit@test.com", service.RoleUser) - dailyLimit := 5.0 - group := s.mustCreateGroupWithLimits("g-concurrentlimit", &dailyLimit, nil, nil) - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 注意:事务内的操作是串行的,所以这里改为顺序执行以验证限额逻辑 - // 尝试增加 10 次,每次 1.0,但限额只有 5.0 - const numAttempts = 10 - const incrementPerAttempt = 1.0 - - successCount := 0 - for i := 0; i < numAttempts; i++ { - err := s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerAttempt) - if err == nil { - successCount++ - } - } - - // 验证:应该有 5 次成功(不超过限额),5 次失败(超出限额) - s.Require().Equal(5, successCount, "exactly 5 increments should succeed (limit=5, increment=1)") - - // 验证最终用量等于限额 - got, err := s.repo.GetByID(s.ctx, sub.ID) - s.Require().NoError(err) - s.Require().InDelta(dailyLimit, got.DailyUsageUSD, 1e-6, "daily usage should equal limit") -} - func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() { baseClient := testEntClient(s.T()) tx, err := baseClient.Tx(context.Background()) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 4be09810..feeb19a0 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -488,6 +488,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn subscriptionType = SubscriptionTypeStandard } + // 限额字段:0 和 nil 都表示"无限制" + dailyLimit := normalizeLimit(input.DailyLimitUSD) + weeklyLimit := normalizeLimit(input.WeeklyLimitUSD) + monthlyLimit := normalizeLimit(input.MonthlyLimitUSD) + group := &Group{ Name: input.Name, Description: input.Description, @@ -496,9 +501,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn IsExclusive: input.IsExclusive, Status: StatusActive, SubscriptionType: subscriptionType, - DailyLimitUSD: input.DailyLimitUSD, - WeeklyLimitUSD: input.WeeklyLimitUSD, - MonthlyLimitUSD: input.MonthlyLimitUSD, + DailyLimitUSD: dailyLimit, + WeeklyLimitUSD: weeklyLimit, + MonthlyLimitUSD: monthlyLimit, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -506,6 +511,14 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn return group, nil } +// normalizeLimit 将 0 或负数转换为 nil(表示无限制) +func normalizeLimit(limit *float64) *float64 { + if limit == nil || *limit <= 0 { + return nil + } + return limit +} + func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { group, err := s.groupRepo.GetByID(ctx, id) if err != nil { @@ -535,15 +548,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.SubscriptionType != "" { group.SubscriptionType = input.SubscriptionType } - // 限额字段支持设置为nil(清除限额)或具体值 + // 限额字段:0 和 nil 都表示"无限制",正数表示具体限额 if input.DailyLimitUSD != nil { - group.DailyLimitUSD = input.DailyLimitUSD + group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD) } if input.WeeklyLimitUSD != nil { - group.WeeklyLimitUSD = input.WeeklyLimitUSD + group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD) } if input.MonthlyLimitUSD != nil { - group.MonthlyLimitUSD = input.MonthlyLimitUSD + group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD) } if err := s.groupRepo.Update(ctx, group); err != nil { diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index 09554c0f..f6aefb83 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -490,6 +490,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *Use } // CheckUsageLimits 检查使用限额(返回错误如果超限) +// 用于中间件的快速预检查,additionalCost 通常为 0 func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSubscription, group *Group, additionalCost float64) error { if !sub.CheckDailyLimit(group, additionalCost) { return ErrDailyLimitExceeded From bb7ade265da1da38154e1e44e57d060db7eb2c8e Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 31 Dec 2025 23:37:51 +0800 Subject: [PATCH 23/49] =?UTF-8?q?chore(token-refresh):=20=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=20Antigravity=20Token=20=E5=88=B7=E6=96=B0=E8=B0=83?= =?UTF-8?q?=E8=AF=95=E6=97=A5=E5=BF=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - NeedsRefresh 判断为 true 时输出 expires_at、time_until_expiry、window - 修正注释中的刷新窗口描述(10分钟 → 15分钟) --- .../internal/service/antigravity_token_refresher.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/backend/internal/service/antigravity_token_refresher.go b/backend/internal/service/antigravity_token_refresher.go index b4739025..9dd4463f 100644 --- a/backend/internal/service/antigravity_token_refresher.go +++ b/backend/internal/service/antigravity_token_refresher.go @@ -2,6 +2,7 @@ package service import ( "context" + "fmt" "time" ) @@ -28,7 +29,7 @@ func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool { } // NeedsRefresh 检查账户是否需要刷新 -// Antigravity 使用固定的10分钟刷新窗口,忽略全局配置 +// Antigravity 使用固定的15分钟刷新窗口,忽略全局配置 func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool { if !r.CanRefresh(account) { return false @@ -37,7 +38,13 @@ func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Durati if expiresAt == nil { return false } - return time.Until(*expiresAt) < antigravityRefreshWindow + timeUntilExpiry := time.Until(*expiresAt) + needsRefresh := timeUntilExpiry < antigravityRefreshWindow + if needsRefresh { + fmt.Printf("[AntigravityTokenRefresher] Account %d needs refresh: expires_at=%s, time_until_expiry=%v, window=%v\n", + account.ID, expiresAt.Format("2006-01-02 15:04:05"), timeUntilExpiry, antigravityRefreshWindow) + } + return needsRefresh } // Refresh 执行 token 刷新 From 2270a54ff6d9373b84082f20ac233f8fb419a563 Mon Sep 17 00:00:00 2001 From: NepetaLemon Date: Wed, 31 Dec 2025 23:42:01 +0800 Subject: [PATCH 24/49] =?UTF-8?q?refactor:=20=E7=A7=BB=E9=99=A4=20infrastr?= =?UTF-8?q?ucture=20=E7=9B=AE=E5=BD=95=20(#108)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: 迁移初始化 db 和 redis 到 repository * refactor: 迁移 errors 到 pkg --- backend/cmd/server/wire.go | 2 - backend/cmd/server/wire_gen.go | 7 +- backend/internal/infrastructure/wire.go | 79 ------------------- .../{infrastructure => pkg}/errors/errors.go | 0 .../errors/errors_test.go | 0 .../{infrastructure => pkg}/errors/http.go | 0 .../{infrastructure => pkg}/errors/types.go | 0 backend/internal/pkg/response/response.go | 2 +- .../internal/pkg/response/response_test.go | 14 ++-- .../{infrastructure => repository}/db_pool.go | 2 +- .../db_pool_test.go | 2 +- .../{infrastructure => repository}/ent.go | 2 +- .../internal/repository/error_translate.go | 2 +- .../repository/integration_harness_test.go | 3 +- .../migrations_runner.go | 2 +- .../migrations_schema_integration_test.go | 3 +- .../{infrastructure => repository}/redis.go | 2 +- .../redis_test.go | 2 +- backend/internal/repository/wire.go | 59 ++++++++++++++ .../internal/server/middleware/recovery.go | 2 +- .../server/middleware/recovery_test.go | 2 +- backend/internal/service/account_service.go | 2 +- backend/internal/service/api_key_service.go | 2 +- backend/internal/service/auth_service.go | 2 +- .../internal/service/billing_cache_service.go | 2 +- backend/internal/service/email_service.go | 2 +- backend/internal/service/group_service.go | 2 +- backend/internal/service/proxy_service.go | 2 +- backend/internal/service/redeem_service.go | 2 +- backend/internal/service/setting_service.go | 2 +- .../internal/service/subscription_service.go | 2 +- backend/internal/service/turnstile_service.go | 2 +- backend/internal/service/usage_service.go | 2 +- backend/internal/service/user_service.go | 2 +- backend/internal/setup/setup.go | 4 +- 35 files changed, 96 insertions(+), 121 deletions(-) delete mode 100644 backend/internal/infrastructure/wire.go rename backend/internal/{infrastructure => pkg}/errors/errors.go (100%) rename backend/internal/{infrastructure => pkg}/errors/errors_test.go (100%) rename backend/internal/{infrastructure => pkg}/errors/http.go (100%) rename backend/internal/{infrastructure => pkg}/errors/types.go (100%) rename backend/internal/{infrastructure => repository}/db_pool.go (97%) rename backend/internal/{infrastructure => repository}/db_pool_test.go (98%) rename backend/internal/{infrastructure => repository}/ent.go (99%) rename backend/internal/{infrastructure => repository}/migrations_runner.go (99%) rename backend/internal/{infrastructure => repository}/redis.go (98%) rename backend/internal/{infrastructure => repository}/redis_test.go (97%) diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index fffcd5f9..8596b8ba 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -12,7 +12,6 @@ import ( "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" - "github.com/Wei-Shaw/sub2api/internal/infrastructure" "github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/server" "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -31,7 +30,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { wire.Build( // Infrastructure layer ProviderSets config.ProviderSet, - infrastructure.ProviderSet, // Business layer ProviderSets repository.ProviderSet, diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index c4859383..83cba823 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -12,7 +12,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler/admin" - "github.com/Wei-Shaw/sub2api/internal/infrastructure" "github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/server" "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -35,18 +34,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { if err != nil { return nil, err } - client, err := infrastructure.ProvideEnt(configConfig) + client, err := repository.ProvideEnt(configConfig) if err != nil { return nil, err } - db, err := infrastructure.ProvideSQLDB(client) + db, err := repository.ProvideSQLDB(client) if err != nil { return nil, err } userRepository := repository.NewUserRepository(client, db) settingRepository := repository.NewSettingRepository(client) settingService := service.NewSettingService(settingRepository, configConfig) - redisClient := infrastructure.ProvideRedis(configConfig) + redisClient := repository.ProvideRedis(configConfig) emailCache := repository.NewEmailCache(redisClient) emailService := service.NewEmailService(settingRepository, emailCache) turnstileVerifier := repository.NewTurnstileVerifier() diff --git a/backend/internal/infrastructure/wire.go b/backend/internal/infrastructure/wire.go deleted file mode 100644 index 1e64640c..00000000 --- a/backend/internal/infrastructure/wire.go +++ /dev/null @@ -1,79 +0,0 @@ -package infrastructure - -import ( - "database/sql" - "errors" - - "github.com/Wei-Shaw/sub2api/ent" - "github.com/Wei-Shaw/sub2api/internal/config" - - "github.com/google/wire" - "github.com/redis/go-redis/v9" - - entsql "entgo.io/ent/dialect/sql" -) - -// ProviderSet 是基础设施层的 Wire 依赖提供者集合。 -// -// Wire 是 Google 开发的编译时依赖注入工具。ProviderSet 将相关的依赖提供函数 -// 组织在一起,便于在应用程序启动时自动组装依赖关系。 -// -// 包含的提供者: -// - ProvideEnt: 提供 Ent ORM 客户端 -// - ProvideSQLDB: 提供底层 SQL 数据库连接 -// - ProvideRedis: 提供 Redis 客户端 -var ProviderSet = wire.NewSet( - ProvideEnt, - ProvideSQLDB, - ProvideRedis, -) - -// ProvideEnt 为依赖注入提供 Ent 客户端。 -// -// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。 -// Wire 会在编译时分析依赖关系,自动生成初始化代码。 -// -// 依赖:config.Config -// 提供:*ent.Client -func ProvideEnt(cfg *config.Config) (*ent.Client, error) { - client, _, err := InitEnt(cfg) - return client, err -} - -// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。 -// -// 某些 Repository 需要直接执行原生 SQL(如复杂的批量更新、聚合查询), -// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。 -// -// 设计说明: -// - Ent 底层使用 sql.DB,通过 Driver 接口可以访问 -// - 这种设计允许在同一事务中混用 Ent 和原生 SQL -// -// 依赖:*ent.Client -// 提供:*sql.DB -func ProvideSQLDB(client *ent.Client) (*sql.DB, error) { - if client == nil { - return nil, errors.New("nil ent client") - } - // 从 Ent 客户端获取底层驱动 - drv, ok := client.Driver().(*entsql.Driver) - if !ok { - return nil, errors.New("ent driver does not expose *sql.DB") - } - // 返回驱动持有的 sql.DB 实例 - return drv.DB(), nil -} - -// ProvideRedis 为依赖注入提供 Redis 客户端。 -// -// Redis 用于: -// - 分布式锁(如并发控制) -// - 缓存(如用户会话、API 响应缓存) -// - 速率限制 -// - 实时统计数据 -// -// 依赖:config.Config -// 提供:*redis.Client -func ProvideRedis(cfg *config.Config) *redis.Client { - return InitRedis(cfg) -} diff --git a/backend/internal/infrastructure/errors/errors.go b/backend/internal/pkg/errors/errors.go similarity index 100% rename from backend/internal/infrastructure/errors/errors.go rename to backend/internal/pkg/errors/errors.go diff --git a/backend/internal/infrastructure/errors/errors_test.go b/backend/internal/pkg/errors/errors_test.go similarity index 100% rename from backend/internal/infrastructure/errors/errors_test.go rename to backend/internal/pkg/errors/errors_test.go diff --git a/backend/internal/infrastructure/errors/http.go b/backend/internal/pkg/errors/http.go similarity index 100% rename from backend/internal/infrastructure/errors/http.go rename to backend/internal/pkg/errors/http.go diff --git a/backend/internal/infrastructure/errors/types.go b/backend/internal/pkg/errors/types.go similarity index 100% rename from backend/internal/infrastructure/errors/types.go rename to backend/internal/pkg/errors/types.go diff --git a/backend/internal/pkg/response/response.go b/backend/internal/pkg/response/response.go index e26d2531..87dc4264 100644 --- a/backend/internal/pkg/response/response.go +++ b/backend/internal/pkg/response/response.go @@ -4,7 +4,7 @@ import ( "math" "net/http" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/gin-gonic/gin" ) diff --git a/backend/internal/pkg/response/response_test.go b/backend/internal/pkg/response/response_test.go index 13b184af..ef31ca3c 100644 --- a/backend/internal/pkg/response/response_test.go +++ b/backend/internal/pkg/response/response_test.go @@ -9,7 +9,7 @@ import ( "net/http/httptest" "testing" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + errors2 "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -82,7 +82,7 @@ func TestErrorFrom(t *testing.T) { }, { name: "application_error", - err: infraerrors.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}), + err: errors2.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}), wantWritten: true, wantHTTPCode: http.StatusForbidden, wantBody: Response{ @@ -94,7 +94,7 @@ func TestErrorFrom(t *testing.T) { }, { name: "bad_request_error", - err: infraerrors.BadRequest("INVALID_REQUEST", "invalid request"), + err: errors2.BadRequest("INVALID_REQUEST", "invalid request"), wantWritten: true, wantHTTPCode: http.StatusBadRequest, wantBody: Response{ @@ -105,7 +105,7 @@ func TestErrorFrom(t *testing.T) { }, { name: "unauthorized_error", - err: infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized"), + err: errors2.Unauthorized("UNAUTHORIZED", "unauthorized"), wantWritten: true, wantHTTPCode: http.StatusUnauthorized, wantBody: Response{ @@ -116,7 +116,7 @@ func TestErrorFrom(t *testing.T) { }, { name: "not_found_error", - err: infraerrors.NotFound("NOT_FOUND", "not found"), + err: errors2.NotFound("NOT_FOUND", "not found"), wantWritten: true, wantHTTPCode: http.StatusNotFound, wantBody: Response{ @@ -127,7 +127,7 @@ func TestErrorFrom(t *testing.T) { }, { name: "conflict_error", - err: infraerrors.Conflict("CONFLICT", "conflict"), + err: errors2.Conflict("CONFLICT", "conflict"), wantWritten: true, wantHTTPCode: http.StatusConflict, wantBody: Response{ @@ -143,7 +143,7 @@ func TestErrorFrom(t *testing.T) { wantHTTPCode: http.StatusInternalServerError, wantBody: Response{ Code: http.StatusInternalServerError, - Message: infraerrors.UnknownMessage, + Message: errors2.UnknownMessage, }, }, } diff --git a/backend/internal/infrastructure/db_pool.go b/backend/internal/repository/db_pool.go similarity index 97% rename from backend/internal/infrastructure/db_pool.go rename to backend/internal/repository/db_pool.go index 612155bf..d7116ab1 100644 --- a/backend/internal/infrastructure/db_pool.go +++ b/backend/internal/repository/db_pool.go @@ -1,4 +1,4 @@ -package infrastructure +package repository import ( "database/sql" diff --git a/backend/internal/infrastructure/db_pool_test.go b/backend/internal/repository/db_pool_test.go similarity index 98% rename from backend/internal/infrastructure/db_pool_test.go rename to backend/internal/repository/db_pool_test.go index 0f0e9716..3868106a 100644 --- a/backend/internal/infrastructure/db_pool_test.go +++ b/backend/internal/repository/db_pool_test.go @@ -1,4 +1,4 @@ -package infrastructure +package repository import ( "database/sql" diff --git a/backend/internal/infrastructure/ent.go b/backend/internal/repository/ent.go similarity index 99% rename from backend/internal/infrastructure/ent.go rename to backend/internal/repository/ent.go index b1ab9a55..d457ba72 100644 --- a/backend/internal/infrastructure/ent.go +++ b/backend/internal/repository/ent.go @@ -1,6 +1,6 @@ // Package infrastructure 提供应用程序的基础设施层组件。 // 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。 -package infrastructure +package repository import ( "context" diff --git a/backend/internal/repository/error_translate.go b/backend/internal/repository/error_translate.go index 192f9261..b8065ffe 100644 --- a/backend/internal/repository/error_translate.go +++ b/backend/internal/repository/error_translate.go @@ -7,7 +7,7 @@ import ( "strings" dbent "github.com/Wei-Shaw/sub2api/ent" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/lib/pq" ) diff --git a/backend/internal/repository/integration_harness_test.go b/backend/internal/repository/integration_harness_test.go index 6ef447e1..fb9c26c4 100644 --- a/backend/internal/repository/integration_harness_test.go +++ b/backend/internal/repository/integration_harness_test.go @@ -17,7 +17,6 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" _ "github.com/Wei-Shaw/sub2api/ent/runtime" - "github.com/Wei-Shaw/sub2api/internal/infrastructure" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -97,7 +96,7 @@ func TestMain(m *testing.M) { log.Printf("failed to open sql db: %v", err) os.Exit(1) } - if err := infrastructure.ApplyMigrations(ctx, integrationDB); err != nil { + if err := ApplyMigrations(ctx, integrationDB); err != nil { log.Printf("failed to apply db migrations: %v", err) os.Exit(1) } diff --git a/backend/internal/infrastructure/migrations_runner.go b/backend/internal/repository/migrations_runner.go similarity index 99% rename from backend/internal/infrastructure/migrations_runner.go rename to backend/internal/repository/migrations_runner.go index 8477c031..e556b9ce 100644 --- a/backend/internal/infrastructure/migrations_runner.go +++ b/backend/internal/repository/migrations_runner.go @@ -1,4 +1,4 @@ -package infrastructure +package repository import ( "context" diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index 49d96445..4c7848b2 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ b/backend/internal/repository/migrations_schema_integration_test.go @@ -7,7 +7,6 @@ import ( "database/sql" "testing" - "github.com/Wei-Shaw/sub2api/internal/infrastructure" "github.com/stretchr/testify/require" ) @@ -15,7 +14,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { tx := testTx(t) // Re-apply migrations to verify idempotency (no errors, no duplicate rows). - require.NoError(t, infrastructure.ApplyMigrations(context.Background(), integrationDB)) + require.NoError(t, ApplyMigrations(context.Background(), integrationDB)) // schema_migrations should have at least the current migration set. var applied int diff --git a/backend/internal/infrastructure/redis.go b/backend/internal/repository/redis.go similarity index 98% rename from backend/internal/infrastructure/redis.go rename to backend/internal/repository/redis.go index 9f4c8770..f3606ad9 100644 --- a/backend/internal/infrastructure/redis.go +++ b/backend/internal/repository/redis.go @@ -1,4 +1,4 @@ -package infrastructure +package repository import ( "time" diff --git a/backend/internal/infrastructure/redis_test.go b/backend/internal/repository/redis_test.go similarity index 97% rename from backend/internal/infrastructure/redis_test.go rename to backend/internal/repository/redis_test.go index 5e38e826..756a63dc 100644 --- a/backend/internal/infrastructure/redis_test.go +++ b/backend/internal/repository/redis_test.go @@ -1,4 +1,4 @@ -package infrastructure +package repository import ( "testing" diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index edeaf782..2de2d1de 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -1,6 +1,11 @@ package repository import ( + "database/sql" + "errors" + + entsql "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/google/wire" @@ -47,4 +52,58 @@ var ProviderSet = wire.NewSet( NewOpenAIOAuthClient, NewGeminiOAuthClient, NewGeminiCliCodeAssistClient, + + ProvideEnt, + ProvideSQLDB, + ProvideRedis, ) + +// ProvideEnt 为依赖注入提供 Ent 客户端。 +// +// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。 +// Wire 会在编译时分析依赖关系,自动生成初始化代码。 +// +// 依赖:config.Config +// 提供:*ent.Client +func ProvideEnt(cfg *config.Config) (*ent.Client, error) { + client, _, err := InitEnt(cfg) + return client, err +} + +// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。 +// +// 某些 Repository 需要直接执行原生 SQL(如复杂的批量更新、聚合查询), +// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。 +// +// 设计说明: +// - Ent 底层使用 sql.DB,通过 Driver 接口可以访问 +// - 这种设计允许在同一事务中混用 Ent 和原生 SQL +// +// 依赖:*ent.Client +// 提供:*sql.DB +func ProvideSQLDB(client *ent.Client) (*sql.DB, error) { + if client == nil { + return nil, errors.New("nil ent client") + } + // 从 Ent 客户端获取底层驱动 + drv, ok := client.Driver().(*entsql.Driver) + if !ok { + return nil, errors.New("ent driver does not expose *sql.DB") + } + // 返回驱动持有的 sql.DB 实例 + return drv.DB(), nil +} + +// ProvideRedis 为依赖注入提供 Redis 客户端。 +// +// Redis 用于: +// - 分布式锁(如并发控制) +// - 缓存(如用户会话、API 响应缓存) +// - 速率限制 +// - 实时统计数据 +// +// 依赖:config.Config +// 提供:*redis.Client +func ProvideRedis(cfg *config.Config) *redis.Client { + return InitRedis(cfg) +} diff --git a/backend/internal/server/middleware/recovery.go b/backend/internal/server/middleware/recovery.go index 04ea6f9d..f05154d3 100644 --- a/backend/internal/server/middleware/recovery.go +++ b/backend/internal/server/middleware/recovery.go @@ -7,7 +7,7 @@ import ( "os" "strings" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/gin-gonic/gin" ) diff --git a/backend/internal/server/middleware/recovery_test.go b/backend/internal/server/middleware/recovery_test.go index 5edb6da0..439f44cb 100644 --- a/backend/internal/server/middleware/recovery_test.go +++ b/backend/internal/server/middleware/recovery_test.go @@ -8,7 +8,7 @@ import ( "net/http/httptest" "testing" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 05895c8b..3c5841bd 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -5,7 +5,7 @@ import ( "fmt" "time" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index facf997e..f22c383a 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -8,7 +8,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" ) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 54bbfa5c..69765520 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -8,7 +8,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/golang-jwt/jwt/v5" "golang.org/x/crypto/bcrypt" diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index 58ed555a..9cdeed7b 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -9,7 +9,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) // 错误定义 diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 7b4db611..6537b01e 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -10,7 +10,7 @@ import ( "strconv" "time" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) var ( diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 886c0a3a..403636e8 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -4,7 +4,7 @@ import ( "context" "fmt" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) diff --git a/backend/internal/service/proxy_service.go b/backend/internal/service/proxy_service.go index c074b13d..044f9ffc 100644 --- a/backend/internal/service/proxy_service.go +++ b/backend/internal/service/proxy_service.go @@ -4,7 +4,7 @@ import ( "context" "fmt" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index 7b0b80f5..b6324235 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -10,7 +10,7 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 0ffe991d..b5786ece 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -9,7 +9,7 @@ import ( "strconv" "github.com/Wei-Shaw/sub2api/internal/config" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) var ( diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index f6aefb83..d960c86f 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -6,7 +6,7 @@ import ( "log" "time" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) diff --git a/backend/internal/service/turnstile_service.go b/backend/internal/service/turnstile_service.go index cfb87c57..4afcc335 100644 --- a/backend/internal/service/turnstile_service.go +++ b/backend/internal/service/turnstile_service.go @@ -5,7 +5,7 @@ import ( "fmt" "log" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) var ( diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index f653ddfe..e1e97671 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -5,7 +5,7 @@ import ( "fmt" "time" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" ) diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index c17588c6..44a94d32 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -4,7 +4,7 @@ import ( "context" "fmt" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) diff --git a/backend/internal/setup/setup.go b/backend/internal/setup/setup.go index 5565ab91..230d016f 100644 --- a/backend/internal/setup/setup.go +++ b/backend/internal/setup/setup.go @@ -11,7 +11,7 @@ import ( "strconv" "time" - "github.com/Wei-Shaw/sub2api/internal/infrastructure" + "github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/service" _ "github.com/lib/pq" @@ -262,7 +262,7 @@ func initializeDatabase(cfg *SetupConfig) error { migrationCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() - return infrastructure.ApplyMigrations(migrationCtx, db) + return repository.ApplyMigrations(migrationCtx, db) } func createAdminUser(cfg *SetupConfig) error { From 8e55ee0e2ca9c5fd00e7afa5ded757bea43d2667 Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 31 Dec 2025 23:50:15 +0800 Subject: [PATCH 25/49] style: fix gofmt formatting in claude_types.go --- backend/internal/pkg/antigravity/claude_types.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index f394d7e3..01b805cd 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -41,7 +41,7 @@ type ClaudeMetadata struct { // 1. 标准格式: { "name": "...", "description": "...", "input_schema": {...} } // 2. Custom 格式 (MCP): { "type": "custom", "name": "...", "custom": { "description": "...", "input_schema": {...} } } type ClaudeTool struct { - Type string `json:"type,omitempty"` // "custom" 或空(标准格式) + Type string `json:"type,omitempty"` // "custom" 或空(标准格式) Name string `json:"name"` Description string `json:"description,omitempty"` // 标准格式使用 InputSchema map[string]any `json:"input_schema,omitempty"` // 标准格式使用 From 7f7bbdf67797510242f41654c151d62fd70eef02 Mon Sep 17 00:00:00 2001 From: song Date: Wed, 31 Dec 2025 21:16:32 +0800 Subject: [PATCH 26/49] =?UTF-8?q?refactor(antigravity):=20=E7=AE=80?= =?UTF-8?q?=E5=8C=96=E6=A8=A1=E5=9E=8B=E6=98=A0=E5=B0=84=E9=80=BB=E8=BE=91?= =?UTF-8?q?=EF=BC=8C=E6=94=AF=E6=8C=81=E5=89=8D=E7=BC=80=E5=8C=B9=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 删除精确映射表 antigravityModelMapping,统一使用前缀映射 - 前缀映射支持模型版本号变化(如 -20251111, -thinking, -preview) - 简化 IsModelSupported 函数,所有 claude-/gemini- 前缀模型都支持 - 添加跨协议测试用例:Claude 端点调用 Gemini 模型、Gemini 端点调用 Claude 模型 --- .../internal/integration/e2e_gateway_test.go | 59 ++++++++++++++ .../service/antigravity_gateway_service.go | 76 +++++++++---------- backend/internal/service/gateway_service.go | 20 +---- 3 files changed, 96 insertions(+), 59 deletions(-) diff --git a/backend/internal/integration/e2e_gateway_test.go b/backend/internal/integration/e2e_gateway_test.go index 05cdc85f..ec0b29f7 100644 --- a/backend/internal/integration/e2e_gateway_test.go +++ b/backend/internal/integration/e2e_gateway_test.go @@ -57,6 +57,7 @@ var geminiModels = []string{ "gemini-2.5-flash-lite", "gemini-3-flash", "gemini-3-pro-low", + "gemini-3-pro-high", } func TestMain(m *testing.M) { @@ -641,6 +642,37 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) { t.Logf("✅ thinking 模式工具调用测试通过, id=%v", result["id"]) } +// TestClaudeMessagesWithGeminiModel 测试在 Claude 端点使用 Gemini 模型 +// 验证:通过 /v1/messages 端点传入 gemini 模型名的场景(含前缀映射) +// 仅在 Antigravity 模式下运行(ENDPOINT_PREFIX="/antigravity") +func TestClaudeMessagesWithGeminiModel(t *testing.T) { + if endpointPrefix != "/antigravity" { + t.Skip("仅在 Antigravity 模式下运行") + } + + // 测试通过 Claude 端点调用 Gemini 模型 + geminiViaClaude := []string{ + "gemini-3-flash", // 直接支持 + "gemini-3-pro-low", // 直接支持 + "gemini-3-pro-high", // 直接支持 + "gemini-3-pro", // 前缀映射 -> gemini-3-pro-high + "gemini-3-pro-preview", // 前缀映射 -> gemini-3-pro-high + } + + for i, model := range geminiViaClaude { + if i > 0 { + time.Sleep(testInterval) + } + t.Run(model+"_通过Claude端点", func(t *testing.T) { + testClaudeMessage(t, model, false) + }) + time.Sleep(testInterval) + t.Run(model+"_通过Claude端点_流式", func(t *testing.T) { + testClaudeMessage(t, model, true) + }) + } +} + // TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景 // 验证:Gemini 模型接受没有 signature 的 thinking block func TestClaudeMessagesWithNoSignature(t *testing.T) { @@ -738,3 +770,30 @@ func testClaudeWithNoSignature(t *testing.T, model string) { } t.Logf("✅ 无 signature thinking 处理测试通过, id=%v", result["id"]) } + +// TestGeminiEndpointWithClaudeModel 测试通过 Gemini 端点调用 Claude 模型 +// 仅在 Antigravity 模式下运行(ENDPOINT_PREFIX="/antigravity") +func TestGeminiEndpointWithClaudeModel(t *testing.T) { + if endpointPrefix != "/antigravity" { + t.Skip("仅在 Antigravity 模式下运行") + } + + // 测试通过 Gemini 端点调用 Claude 模型 + claudeViaGemini := []string{ + "claude-sonnet-4-5", + "claude-opus-4-5-thinking", + } + + for i, model := range claudeViaGemini { + if i > 0 { + time.Sleep(testInterval) + } + t.Run(model+"_通过Gemini端点", func(t *testing.T) { + testGeminiGenerate(t, model, false) + }) + time.Sleep(testInterval) + t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) { + testGeminiGenerate(t, model, true) + }) + } +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index ae2976f8..52dbe263 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -25,7 +25,7 @@ const ( antigravityRetryMaxDelay = 16 * time.Second ) -// Antigravity 直接支持的模型 +// Antigravity 直接支持的模型(精确匹配透传) var antigravitySupportedModels = map[string]bool{ "claude-opus-4-5-thinking": true, "claude-sonnet-4-5": true, @@ -36,23 +36,26 @@ var antigravitySupportedModels = map[string]bool{ "gemini-3-flash": true, "gemini-3-pro-low": true, "gemini-3-pro-high": true, - "gemini-3-pro-preview": true, "gemini-3-pro-image": true, } -// Antigravity 系统默认模型映射表(不支持 → 支持) -var antigravityModelMapping = map[string]string{ - "claude-3-5-sonnet-20241022": "claude-sonnet-4-5", - "claude-3-5-sonnet-20240620": "claude-sonnet-4-5", - "claude-sonnet-4-5-20250929": "claude-sonnet-4-5-thinking", - "claude-opus-4": "claude-opus-4-5-thinking", - "claude-opus-4-5-20251101": "claude-opus-4-5-thinking", - "claude-haiku-4": "gemini-3-flash", - "claude-haiku-4-5": "gemini-3-flash", - "claude-3-haiku-20240307": "gemini-3-flash", - "claude-haiku-4-5-20251001": "gemini-3-flash", - // 生图模型:官方名 → Antigravity 内部名 - "gemini-3-pro-image-preview": "gemini-3-pro-image", +// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先) +// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀) +var antigravityPrefixMapping = []struct { + prefix string + target string +}{ + // 长前缀优先 + {"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等 + {"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx + {"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx + {"claude-haiku-4-5", "gemini-3-flash"}, // claude-haiku-4-5-xxx + {"claude-opus-4-5", "claude-opus-4-5-thinking"}, + {"claude-3-haiku", "gemini-3-flash"}, // 旧版 claude-3-haiku-xxx + {"claude-sonnet-4", "claude-sonnet-4-5"}, + {"claude-haiku-4", "gemini-3-flash"}, + {"claude-opus-4", "claude-opus-4-5-thinking"}, + {"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等 } // AntigravityGatewayService 处理 Antigravity 平台的 API 转发 @@ -84,24 +87,27 @@ func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider } // getMappedModel 获取映射后的模型名 +// 逻辑:账户映射 → 直接支持透传 → 前缀映射 → gemini透传 → 默认值 func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string { - // 1. 优先使用账户级映射(复用现有方法) + // 1. 账户级映射(用户自定义优先) if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel { return mapped } - // 2. 系统默认映射 - if mapped, ok := antigravityModelMapping[requestedModel]; ok { - return mapped - } - - // 3. Gemini 模型透传 - if strings.HasPrefix(requestedModel, "gemini-") { + // 2. 直接支持的模型透传 + if antigravitySupportedModels[requestedModel] { return requestedModel } - // 4. Claude 前缀透传直接支持的模型 - if antigravitySupportedModels[requestedModel] { + // 3. 前缀映射(处理版本号变化,如 -20251111, -thinking, -preview) + for _, pm := range antigravityPrefixMapping { + if strings.HasPrefix(requestedModel, pm.prefix) { + return pm.target + } + } + + // 4. Gemini 模型透传(未匹配到前缀的 gemini 模型) + if strings.HasPrefix(requestedModel, "gemini-") { return requestedModel } @@ -110,24 +116,10 @@ func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedMo } // IsModelSupported 检查模型是否被支持 +// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持 func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool { - // 直接支持的模型 - if antigravitySupportedModels[requestedModel] { - return true - } - // 可映射的模型 - if _, ok := antigravityModelMapping[requestedModel]; ok { - return true - } - // Gemini 前缀透传 - if strings.HasPrefix(requestedModel, "gemini-") { - return true - } - // Claude 模型支持(通过默认映射) - if strings.HasPrefix(requestedModel, "claude-") { - return true - } - return false + return strings.HasPrefix(requestedModel, "claude-") || + strings.HasPrefix(requestedModel, "gemini-") } // TestConnectionResult 测试连接结果 diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index d542e9c2..9874751d 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -515,24 +515,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo } // IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型 +// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持 func IsAntigravityModelSupported(requestedModel string) bool { - // 直接支持的模型 - if antigravitySupportedModels[requestedModel] { - return true - } - // 可映射的模型 - if _, ok := antigravityModelMapping[requestedModel]; ok { - return true - } - // Gemini 前缀透传 - if strings.HasPrefix(requestedModel, "gemini-") { - return true - } - // Claude 模型支持(通过默认映射到 claude-sonnet-4-5) - if strings.HasPrefix(requestedModel, "claude-") { - return true - } - return false + return strings.HasPrefix(requestedModel, "claude-") || + strings.HasPrefix(requestedModel, "gemini-") } // GetAccessToken 获取账号凭证 From 85485f1702d25c4d34c7a65533990b447ccb97ee Mon Sep 17 00:00:00 2001 From: song Date: Thu, 1 Jan 2026 01:59:25 +0800 Subject: [PATCH 27/49] style: fix gofmt formatting --- backend/internal/service/antigravity_gateway_service.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 52dbe263..e9225bff 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -51,7 +51,7 @@ var antigravityPrefixMapping = []struct { {"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx {"claude-haiku-4-5", "gemini-3-flash"}, // claude-haiku-4-5-xxx {"claude-opus-4-5", "claude-opus-4-5-thinking"}, - {"claude-3-haiku", "gemini-3-flash"}, // 旧版 claude-3-haiku-xxx + {"claude-3-haiku", "gemini-3-flash"}, // 旧版 claude-3-haiku-xxx {"claude-sonnet-4", "claude-sonnet-4-5"}, {"claude-haiku-4", "gemini-3-flash"}, {"claude-opus-4", "claude-opus-4-5-thinking"}, From edee46e47f98d9e70c43fc4cddb4f203b31f1568 Mon Sep 17 00:00:00 2001 From: song Date: Thu, 1 Jan 2026 02:07:41 +0800 Subject: [PATCH 28/49] =?UTF-8?q?test:=20=E6=9B=B4=E6=96=B0=20model=20mapp?= =?UTF-8?q?ing=20=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B=E6=9C=9F=E6=9C=9B?= =?UTF-8?q?=E5=80=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/service/antigravity_model_mapping_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/internal/service/antigravity_model_mapping_test.go b/backend/internal/service/antigravity_model_mapping_test.go index b3631dfc..1e37cdc2 100644 --- a/backend/internal/service/antigravity_model_mapping_test.go +++ b/backend/internal/service/antigravity_model_mapping_test.go @@ -131,7 +131,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { name: "系统映射 - claude-sonnet-4-5-20250929", requestedModel: "claude-sonnet-4-5-20250929", accountMapping: nil, - expected: "claude-sonnet-4-5-thinking", + expected: "claude-sonnet-4-5", }, // 3. Gemini 透传 From 592d2d097875a94f02978a14aa88bdfea0aa6c91 Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Thu, 1 Jan 2026 04:01:51 +0800 Subject: [PATCH 29/49] =?UTF-8?q?feat(gateway):=20=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E8=B4=9F=E8=BD=BD=E6=84=9F=E7=9F=A5=E7=9A=84=E8=B4=A6=E5=8F=B7?= =?UTF-8?q?=E8=B0=83=E5=BA=A6=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增调度配置:粘性会话排队、兜底排队、负载计算、槽位清理 - 实现账号级等待队列和批量负载查询(Redis Lua 脚本) - 三层选择策略:粘性会话优先 → 负载感知选择 → 兜底排队 - 后台定期清理过期槽位,防止资源泄漏 - 集成到所有网关处理器(Claude/Gemini/OpenAI) --- backend/cmd/server/wire_gen.go | 6 +- backend/internal/config/config.go | 42 ++ backend/internal/config/config_test.go | 49 ++- backend/internal/handler/gateway_handler.go | 108 ++++- backend/internal/handler/gateway_helper.go | 22 +- .../internal/handler/gemini_v1beta_handler.go | 51 ++- .../handler/openai_gateway_handler.go | 49 ++- .../internal/repository/concurrency_cache.go | 185 ++++++++- .../concurrency_cache_benchmark_test.go | 2 +- .../concurrency_cache_integration_test.go | 44 +- backend/internal/repository/wire.go | 9 +- .../internal/service/concurrency_service.go | 110 +++++ .../service/gateway_multiplatform_test.go | 54 +++ backend/internal/service/gateway_service.go | 387 +++++++++++++++++- .../service/openai_gateway_service.go | 260 ++++++++++++ backend/internal/service/wire.go | 11 +- 16 files changed, 1342 insertions(+), 47 deletions(-) diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index c4859383..e3498680 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -100,7 +100,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream) accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) - concurrencyService := service.NewConcurrencyService(concurrencyCache) + concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) oAuthHandler := admin.NewOAuthHandler(oAuthService) @@ -128,10 +128,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { identityService := service.NewIdentityService(identityCache) timingWheelService := service.ProvideTimingWheelService() deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index aeeddcb4..8c154a9d 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -3,6 +3,7 @@ package config import ( "fmt" "strings" + "time" "github.com/spf13/viper" ) @@ -119,6 +120,26 @@ type GatewayConfig struct { // ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟) // 应大于最长 LLM 请求时间,防止请求完成前槽位过期 ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` + + // Scheduling: 账号调度相关配置 + Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"` +} + +// GatewaySchedulingConfig accounts scheduling configuration. +type GatewaySchedulingConfig struct { + // 粘性会话排队配置 + StickySessionMaxWaiting int `mapstructure:"sticky_session_max_waiting"` + StickySessionWaitTimeout time.Duration `mapstructure:"sticky_session_wait_timeout"` + + // 兜底排队配置 + FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"` + FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"` + + // 负载计算 + LoadBatchEnabled bool `mapstructure:"load_batch_enabled"` + + // 过期槽位清理周期(0 表示禁用) + SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"` } func (s *ServerConfig) Address() string { @@ -323,6 +344,12 @@ func setDefaults() { viper.SetDefault("gateway.max_upstream_clients", 5000) viper.SetDefault("gateway.client_idle_ttl_seconds", 900) viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求) + viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) + viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second) + viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second) + viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) + viper.SetDefault("gateway.scheduling.load_batch_enabled", true) + viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) // TokenRefresh viper.SetDefault("token_refresh.enabled", true) @@ -411,6 +438,21 @@ func (c *Config) Validate() error { if c.Gateway.ConcurrencySlotTTLMinutes <= 0 { return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive") } + if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 { + return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive") + } + if c.Gateway.Scheduling.StickySessionWaitTimeout <= 0 { + return fmt.Errorf("gateway.scheduling.sticky_session_wait_timeout must be positive") + } + if c.Gateway.Scheduling.FallbackWaitTimeout <= 0 { + return fmt.Errorf("gateway.scheduling.fallback_wait_timeout must be positive") + } + if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 { + return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive") + } + if c.Gateway.Scheduling.SlotCleanupInterval < 0 { + return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative") + } return nil } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 1f1becb8..6e722a54 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1,6 +1,11 @@ package config -import "testing" +import ( + "testing" + "time" + + "github.com/spf13/viper" +) func TestNormalizeRunMode(t *testing.T) { tests := []struct { @@ -21,3 +26,45 @@ func TestNormalizeRunMode(t *testing.T) { } } } + +func TestLoadDefaultSchedulingConfig(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 { + t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting) + } + if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 45*time.Second { + t.Fatalf("StickySessionWaitTimeout = %v, want 45s", cfg.Gateway.Scheduling.StickySessionWaitTimeout) + } + if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second { + t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout) + } + if cfg.Gateway.Scheduling.FallbackMaxWaiting != 100 { + t.Fatalf("FallbackMaxWaiting = %d, want 100", cfg.Gateway.Scheduling.FallbackMaxWaiting) + } + if !cfg.Gateway.Scheduling.LoadBatchEnabled { + t.Fatalf("LoadBatchEnabled = false, want true") + } + if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second { + t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval) + } +} + +func TestLoadSchedulingConfigFromEnv(t *testing.T) { + viper.Reset() + t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5") + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 5 { + t.Fatalf("StickySessionMaxWaiting = %d, want 5", cfg.Gateway.Scheduling.StickySessionMaxWaiting) + } +} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index a2f833ff..769e6700 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -141,6 +141,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } else if apiKey.Group != nil { platform = apiKey.Group.Platform } + sessionKey := sessionHash + if platform == service.PlatformGemini && sessionHash != "" { + sessionKey = "gemini:" + sessionHash + } if platform == service.PlatformGemini { const maxAccountSwitches = 3 @@ -149,7 +153,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { lastFailoverStatus := 0 for { - account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs) if err != nil { if len(failedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -158,9 +162,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) return } + account := selection.Account // 检查预热请求拦截(在账号选择后、转发前检查) if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { + if selection.Acquired && selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } if reqStream { sendMockWarmupStream(c, reqModel) } else { @@ -170,11 +178,44 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 3. 获取账号并发槽位 - accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) - if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return + accountReleaseFunc := selection.ReleaseFunc + var accountWaitRelease func() + if !selection.Acquired { + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + log.Printf("Increment account wait count failed: %v", err) + } else if !canWait { + log.Printf("Account wait queue full: account=%d", account.ID) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) + return + } + accountWaitRelease = func() { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + if accountWaitRelease != nil { + accountWaitRelease() + } + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { + log.Printf("Bind sticky session failed: %v", err) + } } // 转发请求 - 根据账号平台分流 @@ -187,6 +228,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } + if accountWaitRelease != nil { + accountWaitRelease() + } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { @@ -231,7 +275,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { for { // 选择支持该模型的账号 - account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs) if err != nil { if len(failedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -240,9 +284,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) return } + account := selection.Account // 检查预热请求拦截(在账号选择后、转发前检查) if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { + if selection.Acquired && selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } if reqStream { sendMockWarmupStream(c, reqModel) } else { @@ -252,11 +300,44 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 3. 获取账号并发槽位 - accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) - if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return + accountReleaseFunc := selection.ReleaseFunc + var accountWaitRelease func() + if !selection.Acquired { + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + log.Printf("Increment account wait count failed: %v", err) + } else if !canWait { + log.Printf("Account wait queue full: account=%d", account.ID) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) + return + } + accountWaitRelease = func() { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + if accountWaitRelease != nil { + accountWaitRelease() + } + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { + log.Printf("Bind sticky session failed: %v", err) + } } // 转发请求 - 根据账号平台分流 @@ -269,6 +350,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } + if accountWaitRelease != nil { + accountWaitRelease() + } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 4c7bd0f0..4e049dbb 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -83,6 +83,16 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64 h.concurrencyService.DecrementWaitCount(ctx, userID) } +// IncrementAccountWaitCount increments the wait count for an account +func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { + return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait) +} + +// DecrementAccountWaitCount decrements the wait count for an account +func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) { + h.concurrencyService.DecrementAccountWaitCount(ctx, accountID) +} + // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // For streaming requests, sends ping events during the wait. // streamStarted is updated if streaming response has begun. @@ -126,7 +136,12 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. // streamStarted pointer is updated when streaming begins (for proper error handling by caller). func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { - ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait) + return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted) +} + +// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout. +func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { + ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) defer cancel() // Determine if ping is needed (streaming + ping format defined) @@ -200,6 +215,11 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, } } +// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping). +func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { + return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted) +} + // nextBackoff 计算下一次退避时间 // 性能优化:使用指数退避 + 随机抖动,避免惊群效应 // current: 当前退避时间 diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 4e99e00d..1959c0f3 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -197,13 +197,17 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 3) select account (sticky session based on request body) parsedReq, _ := service.ParseGatewayRequest(body) sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) + sessionKey := sessionHash + if sessionHash != "" { + sessionKey = "gemini:" + sessionHash + } const maxAccountSwitches = 3 switchCount := 0 failedAccountIDs := make(map[int64]struct{}) lastFailoverStatus := 0 for { - account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs) if err != nil { if len(failedAccountIDs) == 0 { googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) @@ -212,12 +216,46 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { handleGeminiFailoverExhausted(c, lastFailoverStatus) return } + account := selection.Account // 4) account concurrency slot - accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted) - if err != nil { - googleError(c, http.StatusTooManyRequests, err.Error()) - return + accountReleaseFunc := selection.ReleaseFunc + var accountWaitRelease func() + if !selection.Acquired { + if selection.WaitPlan == nil { + googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts") + return + } + canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + log.Printf("Increment account wait count failed: %v", err) + } else if !canWait { + log.Printf("Account wait queue full: account=%d", account.ID) + googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") + return + } + accountWaitRelease = func() { + geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + + accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + stream, + &streamStarted, + ) + if err != nil { + if accountWaitRelease != nil { + accountWaitRelease() + } + googleError(c, http.StatusTooManyRequests, err.Error()) + return + } + if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { + log.Printf("Bind sticky session failed: %v", err) + } } // 5) forward (根据平台分流) @@ -230,6 +268,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } + if accountWaitRelease != nil { + accountWaitRelease() + } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 7c9934c6..c6b969bc 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -146,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { for { // Select account supporting the requested model log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel) - account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) if err != nil { log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) if len(failedAccountIDs) == 0 { @@ -156,14 +156,48 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) return } + account := selection.Account log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) // 3. Acquire account concurrency slot - accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) - if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return + accountReleaseFunc := selection.ReleaseFunc + var accountWaitRelease func() + if !selection.Acquired { + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + log.Printf("Increment account wait count failed: %v", err) + } else if !canWait { + log.Printf("Account wait queue full: account=%d", account.ID) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) + return + } + accountWaitRelease = func() { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + if accountWaitRelease != nil { + accountWaitRelease() + } + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil { + log.Printf("Bind sticky session failed: %v", err) + } } // Forward request @@ -171,6 +205,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } + if accountWaitRelease != nil { + accountWaitRelease() + } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index 9205230b..d8d6989b 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -2,7 +2,9 @@ package repository import ( "context" + "errors" "fmt" + "strconv" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/redis/go-redis/v9" @@ -27,6 +29,8 @@ const ( userSlotKeyPrefix = "concurrency:user:" // 等待队列计数器格式: concurrency:wait:{userID} waitQueueKeyPrefix = "concurrency:wait:" + // 账号级等待队列计数器格式: wait:account:{accountID} + accountWaitKeyPrefix = "wait:account:" // 默认槽位过期时间(分钟),可通过配置覆盖 defaultSlotTTLMinutes = 15 @@ -112,33 +116,112 @@ var ( redis.call('EXPIRE', KEYS[1], ARGV[2]) end - return 1 - `) + return 1 + `) + + // incrementAccountWaitScript - account-level wait queue count + incrementAccountWaitScript = redis.NewScript(` + local current = redis.call('GET', KEYS[1]) + if current == false then + current = 0 + else + current = tonumber(current) + end + + if current >= tonumber(ARGV[1]) then + return 0 + end + + local newVal = redis.call('INCR', KEYS[1]) + + -- Only set TTL on first creation to avoid refreshing zombie data + if newVal == 1 then + redis.call('EXPIRE', KEYS[1], ARGV[2]) + end + + return 1 + `) // decrementWaitScript - same as before decrementWaitScript = redis.NewScript(` - local current = redis.call('GET', KEYS[1]) - if current ~= false and tonumber(current) > 0 then - redis.call('DECR', KEYS[1]) - end - return 1 - `) + local current = redis.call('GET', KEYS[1]) + if current ~= false and tonumber(current) > 0 then + redis.call('DECR', KEYS[1]) + end + return 1 + `) + + // getAccountsLoadBatchScript - batch load query (read-only) + // ARGV[1] = slot TTL (seconds, retained for compatibility) + // ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ... + getAccountsLoadBatchScript = redis.NewScript(` + local result = {} + + local i = 2 + while i <= #ARGV do + local accountID = ARGV[i] + local maxConcurrency = tonumber(ARGV[i + 1]) + + local slotKey = 'concurrency:account:' .. accountID + local currentConcurrency = redis.call('ZCARD', slotKey) + + local waitKey = 'wait:account:' .. accountID + local waitingCount = redis.call('GET', waitKey) + if waitingCount == false then + waitingCount = 0 + else + waitingCount = tonumber(waitingCount) + end + + local loadRate = 0 + if maxConcurrency > 0 then + loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency) + end + + table.insert(result, accountID) + table.insert(result, currentConcurrency) + table.insert(result, waitingCount) + table.insert(result, loadRate) + + i = i + 2 + end + + return result + `) + + // cleanupExpiredSlotsScript - remove expired slots + // KEYS[1] = concurrency:account:{accountID} + // ARGV[1] = TTL (seconds) + cleanupExpiredSlotsScript = redis.NewScript(` + local key = KEYS[1] + local ttl = tonumber(ARGV[1]) + local timeResult = redis.call('TIME') + local now = tonumber(timeResult[1]) + local expireBefore = now - ttl + return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) + `) ) type concurrencyCache struct { - rdb *redis.Client - slotTTLSeconds int // 槽位过期时间(秒) + rdb *redis.Client + slotTTLSeconds int // 槽位过期时间(秒) + waitQueueTTLSeconds int // 等待队列过期时间(秒) } // NewConcurrencyCache 创建并发控制缓存 // slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟 -func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.ConcurrencyCache { +// waitQueueTTLSeconds: 等待队列过期时间(秒),0 或负数使用 slot TTL +func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache { if slotTTLMinutes <= 0 { slotTTLMinutes = defaultSlotTTLMinutes } + if waitQueueTTLSeconds <= 0 { + waitQueueTTLSeconds = slotTTLMinutes * 60 + } return &concurrencyCache{ - rdb: rdb, - slotTTLSeconds: slotTTLMinutes * 60, + rdb: rdb, + slotTTLSeconds: slotTTLMinutes * 60, + waitQueueTTLSeconds: waitQueueTTLSeconds, } } @@ -155,6 +238,10 @@ func waitQueueKey(userID int64) string { return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) } +func accountWaitKey(accountID int64) string { + return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) +} + // Account slot operations func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { @@ -225,3 +312,75 @@ func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() return err } + +// Account wait queue operations + +func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { + key := accountWaitKey(accountID) + result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int() + if err != nil { + return false, err + } + return result == 1, nil +} + +func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { + key := accountWaitKey(accountID) + _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() + return err +} + +func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + key := accountWaitKey(accountID) + val, err := c.rdb.Get(ctx, key).Int() + if err != nil && !errors.Is(err, redis.Nil) { + return 0, err + } + if errors.Is(err, redis.Nil) { + return 0, nil + } + return val, nil +} + +func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { + if len(accounts) == 0 { + return map[int64]*service.AccountLoadInfo{}, nil + } + + args := []interface{}{c.slotTTLSeconds} + for _, acc := range accounts { + args = append(args, acc.ID, acc.MaxConcurrency) + } + + result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice() + if err != nil { + return nil, err + } + + loadMap := make(map[int64]*service.AccountLoadInfo) + for i := 0; i < len(result); i += 4 { + if i+3 >= len(result) { + break + } + + accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64) + currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1])) + waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2])) + loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3])) + + loadMap[accountID] = &service.AccountLoadInfo{ + AccountID: accountID, + CurrentConcurrency: currentConcurrency, + WaitingCount: waitingCount, + LoadRate: loadRate, + } + } + + return loadMap, nil +} + +func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { + key := accountSlotKey(accountID) + _, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result() + return err +} diff --git a/backend/internal/repository/concurrency_cache_benchmark_test.go b/backend/internal/repository/concurrency_cache_benchmark_test.go index cafab9cb..25697ab1 100644 --- a/backend/internal/repository/concurrency_cache_benchmark_test.go +++ b/backend/internal/repository/concurrency_cache_benchmark_test.go @@ -22,7 +22,7 @@ func BenchmarkAccountConcurrency(b *testing.B) { _ = rdb.Close() }() - cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache) + cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache) ctx := context.Background() for _, size := range []int{10, 100, 1000} { diff --git a/backend/internal/repository/concurrency_cache_integration_test.go b/backend/internal/repository/concurrency_cache_integration_test.go index 6a7c83f4..f3d70ef1 100644 --- a/backend/internal/repository/concurrency_cache_integration_test.go +++ b/backend/internal/repository/concurrency_cache_integration_test.go @@ -27,7 +27,7 @@ type ConcurrencyCacheSuite struct { func (s *ConcurrencyCacheSuite) SetupTest() { s.IntegrationRedisSuite.SetupTest() - s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes) + s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds())) } func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() { @@ -218,6 +218,48 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() { require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count") } +func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() { + accountID := int64(30) + waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) + + ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2) + require.NoError(s.T(), err, "IncrementAccountWaitCount 1") + require.True(s.T(), ok) + + ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2) + require.NoError(s.T(), err, "IncrementAccountWaitCount 2") + require.True(s.T(), ok) + + ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2) + require.NoError(s.T(), err, "IncrementAccountWaitCount 3") + require.False(s.T(), ok, "expected account wait increment over max to fail") + + ttl, err := s.rdb.TTL(s.ctx, waitKey).Result() + require.NoError(s.T(), err, "TTL account waitKey") + s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL) + + require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount") + + val, err := s.rdb.Get(s.ctx, waitKey).Int() + if !errors.Is(err, redis.Nil) { + require.NoError(s.T(), err, "Get waitKey") + } + require.Equal(s.T(), 1, val, "expected account wait count 1") +} + +func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() { + accountID := int64(301) + waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) + + require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key") + + val, err := s.rdb.Get(s.ctx, waitKey).Int() + if !errors.Is(err, redis.Nil) { + require.NoError(s.T(), err, "Get waitKey") + } + require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty") +} + func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() { // When no slots exist, GetAccountConcurrency should return 0 cur, err := s.cache.GetAccountConcurrency(s.ctx, 999) diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index edeaf782..f1a8d4cf 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -10,7 +10,14 @@ import ( // ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数 // 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景 func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache { - return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes) + waitTTLSeconds := int(cfg.Gateway.Scheduling.StickySessionWaitTimeout.Seconds()) + if cfg.Gateway.Scheduling.FallbackWaitTimeout > cfg.Gateway.Scheduling.StickySessionWaitTimeout { + waitTTLSeconds = int(cfg.Gateway.Scheduling.FallbackWaitTimeout.Seconds()) + } + if waitTTLSeconds <= 0 { + waitTTLSeconds = cfg.Gateway.ConcurrencySlotTTLMinutes * 60 + } + return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds) } // ProviderSet is the Wire provider set for all repositories diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index b5229491..65ef16db 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -18,6 +18,11 @@ type ConcurrencyCache interface { ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) + // 账号等待队列(账号级) + IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) + DecrementAccountWaitCount(ctx context.Context, accountID int64) error + GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) + // 用户槽位管理 // 键格式: concurrency:user:{userID}(有序集合,成员为 requestID) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) @@ -27,6 +32,12 @@ type ConcurrencyCache interface { // 等待队列计数(只在首次创建时设置 TTL) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) DecrementWaitCount(ctx context.Context, userID int64) error + + // 批量负载查询(只读) + GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) + + // 清理过期槽位(后台任务) + CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error } // generateRequestID generates a unique request ID for concurrency slot tracking @@ -61,6 +72,18 @@ type AcquireResult struct { ReleaseFunc func() // Must be called when done (typically via defer) } +type AccountWithConcurrency struct { + ID int64 + MaxConcurrency int +} + +type AccountLoadInfo struct { + AccountID int64 + CurrentConcurrency int + WaitingCount int + LoadRate int // 0-100+ (percent) +} + // AcquireAccountSlot attempts to acquire a concurrency slot for an account. // If the account is at max concurrency, it waits until a slot is available or timeout. // Returns a release function that MUST be called when the request completes. @@ -177,6 +200,42 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6 } } +// IncrementAccountWaitCount increments the wait queue counter for an account. +func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { + if s.cache == nil { + return true, nil + } + + result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait) + if err != nil { + log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err) + return true, nil + } + return result, nil +} + +// DecrementAccountWaitCount decrements the wait queue counter for an account. +func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) { + if s.cache == nil { + return + } + + bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil { + log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err) + } +} + +// GetAccountWaitingCount gets current wait queue count for an account. +func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + if s.cache == nil { + return 0, nil + } + return s.cache.GetAccountWaitingCount(ctx, accountID) +} + // CalculateMaxWait calculates the maximum wait queue size for a user // maxWait = userConcurrency + defaultExtraWaitSlots func CalculateMaxWait(userConcurrency int) int { @@ -186,6 +245,57 @@ func CalculateMaxWait(userConcurrency int) int { return userConcurrency + defaultExtraWaitSlots } +// GetAccountsLoadBatch returns load info for multiple accounts. +func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + if s.cache == nil { + return map[int64]*AccountLoadInfo{}, nil + } + return s.cache.GetAccountsLoadBatch(ctx, accounts) +} + +// CleanupExpiredAccountSlots removes expired slots for one account (background task). +func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { + if s.cache == nil { + return nil + } + return s.cache.CleanupExpiredAccountSlots(ctx, accountID) +} + +// StartSlotCleanupWorker starts a background cleanup worker for expired account slots. +func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) { + if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 { + return + } + + runCleanup := func() { + listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + accounts, err := accountRepo.ListSchedulable(listCtx) + cancel() + if err != nil { + log.Printf("Warning: list schedulable accounts failed: %v", err) + return + } + for _, account := range accounts { + accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second) + err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID) + accountCancel() + if err != nil { + log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err) + } + } + } + + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + runCleanup() + for range ticker.C { + runCleanup() + } + }() +} + // GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts // Returns a map of accountID -> current concurrency count func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index d779bcfa..e1b61632 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -261,6 +261,34 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户") } +func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey}, + {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户") +} + // TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户 func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) { ctx := context.Background() @@ -576,6 +604,32 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) { func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { ctx := context.Background() + t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) { + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey}, + {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户") + }) + t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) { repo := &mockAccountRepoForPlatform{ accounts: []Account{ diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index d542e9c2..6c45ff0f 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -13,6 +13,7 @@ import ( "log" "net/http" "regexp" + "sort" "strings" "time" @@ -66,6 +67,20 @@ type GatewayCache interface { RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error } +type AccountWaitPlan struct { + AccountID int64 + MaxConcurrency int + Timeout time.Duration + MaxWaiting int +} + +type AccountSelectionResult struct { + Account *Account + Acquired bool + ReleaseFunc func() + WaitPlan *AccountWaitPlan // nil means no wait allowed +} + // ClaudeUsage 表示Claude API返回的usage信息 type ClaudeUsage struct { InputTokens int `json:"input_tokens"` @@ -108,6 +123,7 @@ type GatewayService struct { identityService *IdentityService httpUpstream HTTPUpstream deferredService *DeferredService + concurrencyService *ConcurrencyService } // NewGatewayService creates a new GatewayService @@ -119,6 +135,7 @@ func NewGatewayService( userSubRepo UserSubscriptionRepository, cache GatewayCache, cfg *config.Config, + concurrencyService *ConcurrencyService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, @@ -134,6 +151,7 @@ func NewGatewayService( userSubRepo: userSubRepo, cache: cache, cfg: cfg, + concurrencyService: concurrencyService, billingService: billingService, rateLimitService: rateLimitService, billingCacheService: billingCacheService, @@ -183,6 +201,14 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { return "" } +// BindStickySession sets session -> account binding with standard TTL. +func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { + if sessionHash == "" || accountID <= 0 { + return nil + } + return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL) +} + func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { if parsed == nil { return "" @@ -332,8 +358,360 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) } +// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. +func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { + cfg := s.schedulingConfig() + var stickyAccountID int64 + if sessionHash != "" && s.cache != nil { + if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil { + stickyAccountID = accountID + } + } + if s.concurrencyService == nil || !cfg.LoadBatchEnabled { + account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) + if err != nil { + return nil, err + } + result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) + if err == nil && result.Acquired { + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + } + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil + } + + platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID) + if err != nil { + return nil, err + } + preferOAuth := platform == PlatformGemini + + accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + if err != nil { + return nil, err + } + if len(accounts) == 0 { + return nil, errors.New("no available accounts") + } + + isExcluded := func(accountID int64) bool { + if excludedIDs == nil { + return false + } + _, excluded := excludedIDs[accountID] + return excluded + } + + // ============ Layer 1: 粘性会话优先 ============ + if sessionHash != "" { + accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) + if err == nil && accountID > 0 && !isExcluded(accountID) { + account, err := s.accountRepo.GetByID(ctx, accountID) + if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) && + account.IsSchedulable() && + (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if err == nil && result.Acquired { + _ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + } + } + } + + // ============ Layer 2: 负载感知选择 ============ + candidates := make([]*Account, 0, len(accounts)) + for i := range accounts { + acc := &accounts[i] + if isExcluded(acc.ID) { + continue + } + if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { + continue + } + if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { + continue + } + candidates = append(candidates, acc) + } + + if len(candidates) == 0 { + return nil, errors.New("no available accounts") + } + + accountLoads := make([]AccountWithConcurrency, 0, len(candidates)) + for _, acc := range candidates { + accountLoads = append(accountLoads, AccountWithConcurrency{ + ID: acc.ID, + MaxConcurrency: acc.Concurrency, + }) + } + + loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) + if err != nil { + if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok { + return result, nil + } + } else { + type accountWithLoad struct { + account *Account + loadInfo *AccountLoadInfo + } + var available []accountWithLoad + for _, acc := range candidates { + loadInfo := loadMap[acc.ID] + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: acc.ID} + } + if loadInfo.LoadRate < 100 { + available = append(available, accountWithLoad{ + account: acc, + loadInfo: loadInfo, + }) + } + } + + if len(available) > 0 { + sort.SliceStable(available, func(i, j int) bool { + a, b := available[i], available[j] + if a.account.Priority != b.account.Priority { + return a.account.Priority < b.account.Priority + } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return a.loadInfo.LoadRate < b.loadInfo.LoadRate + } + switch { + case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: + return true + case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: + return false + case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: + if preferOAuth && a.account.Type != b.account.Type { + return a.account.Type == AccountTypeOAuth + } + return false + default: + return a.account.LastUsedAt.Before(*b.account.LastUsedAt) + } + }) + + for _, item := range available { + result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) + if err == nil && result.Acquired { + if sessionHash != "" { + _ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL) + } + return &AccountSelectionResult{ + Account: item.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + } + } + } + + // ============ Layer 3: 兜底排队 ============ + sortAccountsByPriorityAndLastUsed(candidates, preferOAuth) + for _, acc := range candidates { + return &AccountSelectionResult{ + Account: acc, + WaitPlan: &AccountWaitPlan{ + AccountID: acc.ID, + MaxConcurrency: acc.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil + } + return nil, errors.New("no available accounts") +} + +func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { + ordered := append([]*Account(nil), candidates...) + sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) + + for _, acc := range ordered { + result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) + if err == nil && result.Acquired { + if sessionHash != "" { + _ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL) + } + return &AccountSelectionResult{ + Account: acc, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, true + } + } + + return nil, false +} + +func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { + if s.cfg != nil { + return s.cfg.Gateway.Scheduling + } + return config.GatewaySchedulingConfig{ + StickySessionMaxWaiting: 3, + StickySessionWaitTimeout: 45 * time.Second, + FallbackWaitTimeout: 30 * time.Second, + FallbackMaxWaiting: 100, + LoadBatchEnabled: true, + SlotCleanupInterval: 30 * time.Second, + } +} + +func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) { + forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) + if hasForcePlatform && forcePlatform != "" { + return forcePlatform, true, nil + } + if groupID != nil { + group, err := s.groupRepo.GetByID(ctx, *groupID) + if err != nil { + return "", false, fmt.Errorf("get group failed: %w", err) + } + return group.Platform, false, nil + } + return PlatformAnthropic, false, nil +} + +func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { + useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform + if useMixed { + platforms := []string{platform, PlatformAntigravity} + var accounts []Account + var err error + if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) + } else { + accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) + } + if err != nil { + return nil, useMixed, err + } + filtered := make([]Account, 0, len(accounts)) + for _, acc := range accounts { + if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { + continue + } + filtered = append(filtered, acc) + } + return filtered, useMixed, nil + } + + var accounts []Account + var err error + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + } else if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) + if err == nil && len(accounts) == 0 && hasForcePlatform { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + } + } else { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + } + if err != nil { + return nil, useMixed, err + } + return accounts, useMixed, nil +} + +func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool { + if account == nil { + return false + } + if useMixed { + if account.Platform == platform { + return true + } + return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() + } + return account.Platform == platform +} + +func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { + if s.concurrencyService == nil { + return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil + } + return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) +} + +func sortAccountsByPriority(accounts []*Account) { + sort.SliceStable(accounts, func(i, j int) bool { + return accounts[i].Priority < accounts[j].Priority + }) +} + +func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { + sort.SliceStable(accounts, func(i, j int) bool { + a, b := accounts[i], accounts[j] + if a.Priority != b.Priority { + return a.Priority < b.Priority + } + switch { + case a.LastUsedAt == nil && b.LastUsedAt != nil: + return true + case a.LastUsedAt != nil && b.LastUsedAt == nil: + return false + case a.LastUsedAt == nil && b.LastUsedAt == nil: + if preferOAuth && a.Type != b.Type { + return a.Type == AccountTypeOAuth + } + return false + default: + return a.LastUsedAt.Before(*b.LastUsedAt) + } + }) +} + // selectAccountForModelWithPlatform 选择单平台账户(完全隔离) func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { + preferOAuth := platform == PlatformGemini // 1. 查询粘性会话 if sessionHash != "" { accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) @@ -389,7 +767,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, case acc.LastUsedAt != nil && selected.LastUsedAt == nil: // keep selected (never used is preferred) case acc.LastUsedAt == nil && selected.LastUsedAt == nil: - // keep selected (both never used) + if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { + selected = acc + } default: if acc.LastUsedAt.Before(*selected.LastUsedAt) { selected = acc @@ -419,6 +799,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, // 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户 func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) { platforms := []string{nativePlatform, PlatformAntigravity} + preferOAuth := nativePlatform == PlatformGemini // 1. 查询粘性会话 if sessionHash != "" { @@ -478,7 +859,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g case acc.LastUsedAt != nil && selected.LastUsedAt == nil: // keep selected (never used is preferred) case acc.LastUsedAt == nil && selected.LastUsedAt == nil: - // keep selected (both never used) + if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { + selected = acc + } default: if acc.LastUsedAt.Before(*selected.LastUsedAt) { selected = acc diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 84e98679..f8eb29bd 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -13,6 +13,7 @@ import ( "log" "net/http" "regexp" + "sort" "strconv" "strings" "time" @@ -80,6 +81,7 @@ type OpenAIGatewayService struct { userSubRepo UserSubscriptionRepository cache GatewayCache cfg *config.Config + concurrencyService *ConcurrencyService billingService *BillingService rateLimitService *RateLimitService billingCacheService *BillingCacheService @@ -95,6 +97,7 @@ func NewOpenAIGatewayService( userSubRepo UserSubscriptionRepository, cache GatewayCache, cfg *config.Config, + concurrencyService *ConcurrencyService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, @@ -108,6 +111,7 @@ func NewOpenAIGatewayService( userSubRepo: userSubRepo, cache: cache, cfg: cfg, + concurrencyService: concurrencyService, billingService: billingService, rateLimitService: rateLimitService, billingCacheService: billingCacheService, @@ -126,6 +130,14 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string { return hex.EncodeToString(hash[:]) } +// BindStickySession sets session -> account binding with standard TTL. +func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { + if sessionHash == "" || accountID <= 0 { + return nil + } + return s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, accountID, openaiStickySessionTTL) +} + // SelectAccount selects an OpenAI account with sticky session support func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { return s.SelectAccountForModel(ctx, groupID, sessionHash, "") @@ -218,6 +230,254 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C return selected, nil } +// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. +func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { + cfg := s.schedulingConfig() + var stickyAccountID int64 + if sessionHash != "" && s.cache != nil { + if accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash); err == nil { + stickyAccountID = accountID + } + } + if s.concurrencyService == nil || !cfg.LoadBatchEnabled { + account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) + if err != nil { + return nil, err + } + result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) + if err == nil && result.Acquired { + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + } + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil + } + + accounts, err := s.listSchedulableAccounts(ctx, groupID) + if err != nil { + return nil, err + } + if len(accounts) == 0 { + return nil, errors.New("no available accounts") + } + + isExcluded := func(accountID int64) bool { + if excludedIDs == nil { + return false + } + _, excluded := excludedIDs[accountID] + return excluded + } + + // ============ Layer 1: Sticky session ============ + if sessionHash != "" { + accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash) + if err == nil && accountID > 0 && !isExcluded(accountID) { + account, err := s.accountRepo.GetByID(ctx, accountID) + if err == nil && account.IsSchedulable() && account.IsOpenAI() && + (requestedModel == "" || account.IsModelSupported(requestedModel)) { + result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if err == nil && result.Acquired { + _ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + } + } + } + + // ============ Layer 2: Load-aware selection ============ + candidates := make([]*Account, 0, len(accounts)) + for i := range accounts { + acc := &accounts[i] + if isExcluded(acc.ID) { + continue + } + if requestedModel != "" && !acc.IsModelSupported(requestedModel) { + continue + } + candidates = append(candidates, acc) + } + + if len(candidates) == 0 { + return nil, errors.New("no available accounts") + } + + accountLoads := make([]AccountWithConcurrency, 0, len(candidates)) + for _, acc := range candidates { + accountLoads = append(accountLoads, AccountWithConcurrency{ + ID: acc.ID, + MaxConcurrency: acc.Concurrency, + }) + } + + loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) + if err != nil { + ordered := append([]*Account(nil), candidates...) + sortAccountsByPriorityAndLastUsed(ordered, false) + for _, acc := range ordered { + result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) + if err == nil && result.Acquired { + if sessionHash != "" { + _ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, acc.ID, openaiStickySessionTTL) + } + return &AccountSelectionResult{ + Account: acc, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + } + } else { + type accountWithLoad struct { + account *Account + loadInfo *AccountLoadInfo + } + var available []accountWithLoad + for _, acc := range candidates { + loadInfo := loadMap[acc.ID] + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: acc.ID} + } + if loadInfo.LoadRate < 100 { + available = append(available, accountWithLoad{ + account: acc, + loadInfo: loadInfo, + }) + } + } + + if len(available) > 0 { + sort.SliceStable(available, func(i, j int) bool { + a, b := available[i], available[j] + if a.account.Priority != b.account.Priority { + return a.account.Priority < b.account.Priority + } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return a.loadInfo.LoadRate < b.loadInfo.LoadRate + } + switch { + case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: + return true + case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: + return false + case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: + return false + default: + return a.account.LastUsedAt.Before(*b.account.LastUsedAt) + } + }) + + for _, item := range available { + result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) + if err == nil && result.Acquired { + if sessionHash != "" { + _ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL) + } + return &AccountSelectionResult{ + Account: item.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + } + } + } + + // ============ Layer 3: Fallback wait ============ + sortAccountsByPriorityAndLastUsed(candidates, false) + for _, acc := range candidates { + return &AccountSelectionResult{ + Account: acc, + WaitPlan: &AccountWaitPlan{ + AccountID: acc.ID, + MaxConcurrency: acc.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil + } + + return nil, errors.New("no available accounts") +} + +func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) { + var accounts []Account + var err error + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) + } else if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) + } else { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) + } + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + return accounts, nil +} + +func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { + if s.concurrencyService == nil { + return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil + } + return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) +} + +func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig { + if s.cfg != nil { + return s.cfg.Gateway.Scheduling + } + return config.GatewaySchedulingConfig{ + StickySessionMaxWaiting: 3, + StickySessionWaitTimeout: 45 * time.Second, + FallbackWaitTimeout: 30 * time.Second, + FallbackMaxWaiting: 100, + LoadBatchEnabled: true, + SlotCleanupInterval: 30 * time.Second, + } +} + // GetAccessToken gets the access token for an OpenAI account func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 81e01d47..a202ccf2 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -73,6 +73,15 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh return svc } +// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker. +func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService { + svc := NewConcurrencyService(cache) + if cfg != nil { + svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval) + } + return svc +} + // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services @@ -107,7 +116,7 @@ var ProviderSet = wire.NewSet( ProvideEmailQueueService, NewTurnstileService, NewSubscriptionService, - NewConcurrencyService, + ProvideConcurrencyService, NewIdentityService, NewCRSSyncService, ProvideUpdateService, From fe31495a893e276e3192b3762ecbf3e4079cd4cf Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Thu, 1 Jan 2026 04:15:31 +0800 Subject: [PATCH 30/49] =?UTF-8?q?test(gateway):=20=E8=A1=A5=E5=85=85?= =?UTF-8?q?=E8=B4=A6=E5=8F=B7=E8=B0=83=E5=BA=A6=E4=BC=98=E5=8C=96=E7=9A=84?= =?UTF-8?q?=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加 GetAccountsLoadBatch 批量负载查询测试 - 添加 CleanupExpiredAccountSlots 过期槽位清理测试 - 添加 SelectAccountWithLoadAwareness 负载感知选择测试 - 测试覆盖降级行为、账号排除、错误处理等场景 --- .../concurrency_cache_integration_test.go | 132 +++++++++++++++ .../service/gateway_multiplatform_test.go | 157 ++++++++++++++++++ 2 files changed, 289 insertions(+) diff --git a/backend/internal/repository/concurrency_cache_integration_test.go b/backend/internal/repository/concurrency_cache_integration_test.go index f3d70ef1..707cbdab 100644 --- a/backend/internal/repository/concurrency_cache_integration_test.go +++ b/backend/internal/repository/concurrency_cache_integration_test.go @@ -274,6 +274,138 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() { require.Equal(s.T(), 0, cur) } +func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() { + // Setup: Create accounts with different load states + account1 := int64(100) + account2 := int64(101) + account3 := int64(102) + + // Account 1: 2/3 slots used, 1 waiting + ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1") + require.NoError(s.T(), err) + require.True(s.T(), ok) + ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2") + require.NoError(s.T(), err) + require.True(s.T(), ok) + ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5) + require.NoError(s.T(), err) + require.True(s.T(), ok) + + // Account 2: 1/2 slots used, 0 waiting + ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3") + require.NoError(s.T(), err) + require.True(s.T(), ok) + + // Account 3: 0/1 slots used, 0 waiting (idle) + + // Query batch load + accounts := []service.AccountWithConcurrency{ + {ID: account1, MaxConcurrency: 3}, + {ID: account2, MaxConcurrency: 2}, + {ID: account3, MaxConcurrency: 1}, + } + + loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts) + require.NoError(s.T(), err) + require.Len(s.T(), loadMap, 3) + + // Verify account1: (2 + 1) / 3 = 100% + load1 := loadMap[account1] + require.NotNil(s.T(), load1) + require.Equal(s.T(), account1, load1.AccountID) + require.Equal(s.T(), 2, load1.CurrentConcurrency) + require.Equal(s.T(), 1, load1.WaitingCount) + require.Equal(s.T(), 100, load1.LoadRate) + + // Verify account2: (1 + 0) / 2 = 50% + load2 := loadMap[account2] + require.NotNil(s.T(), load2) + require.Equal(s.T(), account2, load2.AccountID) + require.Equal(s.T(), 1, load2.CurrentConcurrency) + require.Equal(s.T(), 0, load2.WaitingCount) + require.Equal(s.T(), 50, load2.LoadRate) + + // Verify account3: (0 + 0) / 1 = 0% + load3 := loadMap[account3] + require.NotNil(s.T(), load3) + require.Equal(s.T(), account3, load3.AccountID) + require.Equal(s.T(), 0, load3.CurrentConcurrency) + require.Equal(s.T(), 0, load3.WaitingCount) + require.Equal(s.T(), 0, load3.LoadRate) +} + +func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() { + // Test with empty account list + loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{}) + require.NoError(s.T(), err) + require.Empty(s.T(), loadMap) +} + +func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() { + accountID := int64(200) + slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID) + + // Acquire 3 slots + ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1") + require.NoError(s.T(), err) + require.True(s.T(), ok) + ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2") + require.NoError(s.T(), err) + require.True(s.T(), ok) + ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3") + require.NoError(s.T(), err) + require.True(s.T(), ok) + + // Verify 3 slots exist + cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID) + require.NoError(s.T(), err) + require.Equal(s.T(), 3, cur) + + // Manually set old timestamps for req1 and req2 (simulate expired slots) + now := time.Now().Unix() + expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL + err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err() + require.NoError(s.T(), err) + err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err() + require.NoError(s.T(), err) + + // Run cleanup + err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID) + require.NoError(s.T(), err) + + // Verify only 1 slot remains (req3) + cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID) + require.NoError(s.T(), err) + require.Equal(s.T(), 1, cur) + + // Verify req3 still exists + members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result() + require.NoError(s.T(), err) + require.Len(s.T(), members, 1) + require.Equal(s.T(), "req3", members[0]) +} + +func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() { + accountID := int64(201) + + // Acquire 2 fresh slots + ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1") + require.NoError(s.T(), err) + require.True(s.T(), ok) + ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2") + require.NoError(s.T(), err) + require.True(s.T(), ok) + + // Run cleanup (should not remove anything) + err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID) + require.NoError(s.T(), err) + + // Verify both slots still exist + cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID) + require.NoError(s.T(), err) + require.Equal(s.T(), 2, cur) +} + func TestConcurrencyCacheSuite(t *testing.T) { suite.Run(t, new(ConcurrencyCacheSuite)) } diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index e1b61632..560c7767 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -837,3 +837,160 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) { }) } } + +// mockConcurrencyService for testing +type mockConcurrencyService struct { + accountLoads map[int64]*AccountLoadInfo + accountWaitCounts map[int64]int + acquireResults map[int64]bool +} + +func (m *mockConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + if m.accountLoads == nil { + return map[int64]*AccountLoadInfo{}, nil + } + result := make(map[int64]*AccountLoadInfo) + for _, acc := range accounts { + if load, ok := m.accountLoads[acc.ID]; ok { + result[acc.ID] = load + } else { + result[acc.ID] = &AccountLoadInfo{ + AccountID: acc.ID, + CurrentConcurrency: 0, + WaitingCount: 0, + LoadRate: 0, + } + } + } + return result, nil +} + +func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + if m.accountWaitCounts == nil { + return 0, nil + } + return m.accountWaitCounts[accountID], nil +} + +// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection +func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { + ctx := context.Background() + + t.Run("禁用负载批量查询-降级到传统选择", func(t *testing.T) { + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, + {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = false + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: nil, // No concurrency service + } + + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号") + }) + + t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) { + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, + {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = true + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: nil, + } + + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(2), result.Account.ID, "应选择优先级最高的账号") + }) + + t.Run("排除账号-不选择被排除的账号", func(t *testing.T) { + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, + {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = false + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: nil, + } + + excludedIDs := map[int64]struct{}{1: {}} + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号") + }) + + t.Run("无可用账号-返回错误", func(t *testing.T) { + repo := &mockAccountRepoForPlatform{ + accounts: []Account{}, + accountsByID: map[int64]*Account{}, + } + + cache := &mockGatewayCacheForPlatform{} + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = false + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: nil, + } + + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "no available accounts") + }) +} From 34c102045ac46de7b28d77530bc0e7eca120af95 Mon Sep 17 00:00:00 2001 From: IanShaw <131567472+IanShaw027@users.noreply.github.com> Date: Thu, 1 Jan 2026 04:21:18 +0800 Subject: [PATCH 31/49] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20/v1/messages?= =?UTF-8?q?=20=E9=97=B4=E6=AD=87=E6=80=A7=20400=20=E9=94=99=E8=AF=AF=20(#1?= =?UTF-8?q?8)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(upstream): 修复上游格式兼容性问题 - 跳过Claude模型无signature的thinking block - 支持custom类型工具(MCP)格式转换 - 添加ClaudeCustomToolSpec结构体支持MCP工具 - 添加Custom字段验证,跳过无效custom工具 - 在convertClaudeToolsToGeminiTools中添加schema清理 - 完整的单元测试覆盖,包含边界情况 修复: Issue 0.1 signature缺失, Issue 0.2 custom工具格式 改进: Codex审查发现的2个重要问题 测试: - TestBuildParts_ThinkingBlockWithoutSignature: 验证thinking block处理 - TestBuildTools_CustomTypeTools: 验证custom工具转换和边界情况 - TestConvertClaudeToolsToGeminiTools_CustomType: 验证service层转换 * feat(gemini): 添加Gemini限额与TierID支持 实现PR1:Gemini限额与TierID功能 后端修改: - GeminiTokenInfo结构体添加TierID字段 - fetchProjectID函数返回(projectID, tierID, error) - 从LoadCodeAssist响应中提取tierID(优先IsDefault,回退到第一个非空tier) - ExchangeCode、RefreshAccountToken、GetAccessToken函数更新以处理tierID - BuildAccountCredentials函数保存tier_id到credentials 前端修改: - AccountStatusIndicator组件添加tier显示 - 支持LEGACY/PRO/ULTRA等tier类型的友好显示 - 使用蓝色badge展示tier信息 技术细节: - tierID提取逻辑:优先选择IsDefault的tier,否则选择第一个非空tier - 所有fetchProjectID调用点已更新以处理新的返回签名 - 前端gracefully处理missing/unknown tier_id * refactor(gemini): 优化TierID实现并添加安全验证 根据并发代码审查(code-reviewer, security-auditor, gemini, codex)的反馈进行改进: 安全改进: - 添加validateTierID函数验证tier_id格式和长度(最大64字符) - 限制tier_id字符集为字母数字、下划线、连字符和斜杠 - 在BuildAccountCredentials中验证tier_id后再存储 - 静默跳过无效tier_id,不阻塞账户创建 代码质量改进: - 提取extractTierIDFromAllowedTiers辅助函数消除重复代码 - 重构fetchProjectID函数,tierID提取逻辑只执行一次 - 改进代码可读性和可维护性 审查工具: - code-reviewer agent (a09848e) - security-auditor agent (a9a149c) - gemini CLI (bcc7c81) - codex (b5d8919) 修复问题: - HIGH: 未验证的tier_id输入 - MEDIUM: 代码重复(tierID提取逻辑重复2次) * fix(format): 修复 gofmt 格式问题 - 修复 claude_types.go 中的字段对齐问题 - 修复 gemini_messages_compat_service.go 中的缩进问题 * fix(upstream): 修复上游格式兼容性问题 (#14) * fix(upstream): 修复上游格式兼容性问题 - 跳过Claude模型无signature的thinking block - 支持custom类型工具(MCP)格式转换 - 添加ClaudeCustomToolSpec结构体支持MCP工具 - 添加Custom字段验证,跳过无效custom工具 - 在convertClaudeToolsToGeminiTools中添加schema清理 - 完整的单元测试覆盖,包含边界情况 修复: Issue 0.1 signature缺失, Issue 0.2 custom工具格式 改进: Codex审查发现的2个重要问题 测试: - TestBuildParts_ThinkingBlockWithoutSignature: 验证thinking block处理 - TestBuildTools_CustomTypeTools: 验证custom工具转换和边界情况 - TestConvertClaudeToolsToGeminiTools_CustomType: 验证service层转换 * fix(format): 修复 gofmt 格式问题 - 修复 claude_types.go 中的字段对齐问题 - 修复 gemini_messages_compat_service.go 中的缩进问题 * fix(format): 修复 claude_types.go 的 gofmt 格式问题 * feat(antigravity): 优化 thinking block 和 schema 处理 - 为 dummy thinking block 添加 ThoughtSignature - 重构 thinking block 处理逻辑,在每个条件分支内创建 part - 优化 excludedSchemaKeys,移除 Gemini 实际支持的字段 (minItems, maxItems, minimum, maximum, additionalProperties, format) - 添加详细注释说明 Gemini API 支持的 schema 字段 * fix(antigravity): 增强 schema 清理的安全性 基于 Codex review 建议: - 添加 format 字段白名单过滤,只保留 Gemini 支持的 date-time/date/time - 补充更多不支持的 schema 关键字到黑名单: * 组合 schema: oneOf, anyOf, allOf, not, if/then/else * 对象验证: minProperties, maxProperties, patternProperties 等 * 定义引用: $defs, definitions - 避免不支持的 schema 字段导致 Gemini API 校验失败 * fix(lint): 修复 gemini_messages_compat_service 空分支警告 - 在 cleanToolSchema 的 if 语句中添加 continue - 移除重复的注释 * fix(antigravity): 移除 minItems/maxItems 以兼容 Claude API - 将 minItems 和 maxItems 添加到 schema 黑名单 - Claude API (Vertex AI) 不支持这些数组验证字段 - 添加调试日志记录工具 schema 转换过程 - 修复 tools.14.custom.input_schema 验证错误 * fix(antigravity): 修复 additionalProperties schema 对象问题 - 将 additionalProperties 的 schema 对象转换为布尔值 true - Claude API 只支持 additionalProperties: false,不支持 schema 对象 - 修复 tools.14.custom.input_schema 验证错误 - 参考 Claude 官方文档的 JSON Schema 限制 * fix(antigravity): 修复 Claude 模型 thinking 块兼容性问题 - 完全跳过 Claude 模型的 thinking 块以避免 signature 验证失败 - 只在 Gemini 模型中使用 dummy thought signature - 修改 additionalProperties 默认值为 false(更安全) - 添加调试日志以便排查问题 * fix(upstream): 修复跨模型切换时的 dummy signature 问题 基于 Codex review 和用户场景分析的修复: 1. 问题场景 - Gemini (thinking) → Claude (thinking) 切换时 - Gemini 返回的 thinking 块使用 dummy signature - Claude API 会拒绝 dummy signature,导致 400 错误 2. 修复内容 - request_transformer.go:262: 跳过 dummy signature - 只保留真实的 Claude signature - 支持频繁的跨模型切换 3. 其他修复(基于 Codex review) - gateway_service.go:691: 修复 io.ReadAll 错误处理 - gateway_service.go:687: 条件日志(尊重 LogUpstreamErrorBody 配置) - gateway_service.go:915: 收紧 400 failover 启发式 - request_transformer.go:188: 移除签名成功日志 4. 新增功能(默认关闭) - 阶段 1: 上游错误日志(GATEWAY_LOG_UPSTREAM_ERROR_BODY) - 阶段 2: Antigravity thinking 修复 - 阶段 3: API-key beta 注入(GATEWAY_INJECT_BETA_FOR_APIKEY) - 阶段 3: 智能 400 failover(GATEWAY_FAILOVER_ON_400) 测试:所有测试通过 * fix(lint): 修复 golangci-lint 问题 - 应用 De Morgan 定律简化条件判断 - 修复 gofmt 格式问题 - 移除未使用的 min 函数 --- backend/internal/config/config.go | 15 ++ .../internal/pkg/antigravity/claude_types.go | 3 + .../pkg/antigravity/request_transformer.go | 223 +++++++++++++----- .../antigravity/request_transformer_test.go | 179 ++++++++++++++ backend/internal/pkg/claude/constants.go | 6 + .../service/antigravity_gateway_service.go | 9 + backend/internal/service/gateway_service.go | 138 +++++++++++ .../service/gemini_messages_compat_service.go | 39 ++- .../gemini_messages_compat_service_test.go | 128 ++++++++++ .../internal/service/gemini_oauth_service.go | 104 +++++--- .../internal/service/gemini_token_provider.go | 5 +- deploy/config.example.yaml | 15 ++ frontend/package-lock.json | 10 + .../account/AccountStatusIndicator.vue | 27 +++ 14 files changed, 815 insertions(+), 86 deletions(-) create mode 100644 backend/internal/pkg/antigravity/request_transformer_test.go create mode 100644 backend/internal/service/gemini_messages_compat_service_test.go diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index aeeddcb4..d3674932 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -119,6 +119,17 @@ type GatewayConfig struct { // ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟) // 应大于最长 LLM 请求时间,防止请求完成前槽位过期 ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` + + // 是否记录上游错误响应体摘要(避免输出请求内容) + LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"` + // 上游错误响应体记录最大字节数(超过会截断) + LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"` + + // API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容) + InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"` + + // 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义) + FailoverOn400 bool `mapstructure:"failover_on_400"` } func (s *ServerConfig) Address() string { @@ -313,6 +324,10 @@ func setDefaults() { // Gateway viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久 + viper.SetDefault("gateway.log_upstream_error_body", false) + viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048) + viper.SetDefault("gateway.inject_beta_for_apikey", false) + viper.SetDefault("gateway.failover_on_400", false) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) // HTTP 上游连接池配置(针对 5000+ 并发用户优化) diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index 01b805cd..34e6b1f4 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -54,6 +54,9 @@ type CustomToolSpec struct { InputSchema map[string]any `json:"input_schema"` } +// ClaudeCustomToolSpec 兼容旧命名(MCP custom 工具规格) +type ClaudeCustomToolSpec = CustomToolSpec + // SystemBlock system prompt 数组形式的元素 type SystemBlock struct { Type string `json:"type"` diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index e0b5b886..83b87a32 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -14,13 +14,16 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st // 用于存储 tool_use id -> name 映射 toolIDToName := make(map[string]string) - // 检测是否启用 thinking - isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" - // 只有 Gemini 模型支持 dummy thought workaround // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures allowDummyThought := strings.HasPrefix(mappedModel, "gemini-") + // 检测是否启用 thinking + requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" + // 为避免 Claude 模型的 thought signature/消息块约束导致 400(上游要求 thinking 块开头等), + // 非 Gemini 模型默认不启用 thinking(除非未来支持完整签名链路)。 + isThinkingEnabled := requestedThinkingEnabled && allowDummyThought + // 1. 构建 contents contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought) if err != nil { @@ -31,7 +34,15 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model) // 3. 构建 generationConfig - generationConfig := buildGenerationConfig(claudeReq) + reqForGen := claudeReq + if requestedThinkingEnabled && !allowDummyThought { + log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel) + // shallow copy to avoid mutating caller's request + clone := *claudeReq + clone.Thinking = nil + reqForGen = &clone + } + generationConfig := buildGenerationConfig(reqForGen) // 4. 构建 tools tools := buildTools(claudeReq.Tools) @@ -148,8 +159,9 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT if !hasThoughtPart && len(parts) > 0 { // 在开头添加 dummy thinking block parts = append([]GeminiPart{{ - Text: "Thinking...", - Thought: true, + Text: "Thinking...", + Thought: true, + ThoughtSignature: dummyThoughtSignature, }}, parts...) } } @@ -171,6 +183,34 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT // 参考: https://ai.google.dev/gemini-api/docs/thought-signatures const dummyThoughtSignature = "skip_thought_signature_validator" +// isValidThoughtSignature 验证 thought signature 是否有效 +// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节 +func isValidThoughtSignature(signature string) bool { + // 空字符串无效 + if signature == "" { + return false + } + + // signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节) + // 参考 Claude API 文档和实际观察到的有效 signature + if len(signature) < 40 { + log.Printf("[Debug] Signature too short: len=%d", len(signature)) + return false + } + + // 检查是否是有效的 base64 字符 + // base64 字符集: A-Z, a-z, 0-9, +, /, = + for i, c := range signature { + if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') && + (c < '0' || c > '9') && c != '+' && c != '/' && c != '=' { + log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c) + return false + } + } + + return true +} + // buildParts 构建消息的 parts // allowDummyThought: 只有 Gemini 模型支持 dummy thought signature func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) { @@ -199,22 +239,30 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu } case "thinking": - part := GeminiPart{ - Text: block.Thinking, - Thought: true, - } - // 保留原有 signature(Claude 模型需要有效的 signature) - if block.Signature != "" { - part.ThoughtSignature = block.Signature - } else if !allowDummyThought { - // Claude 模型需要有效 signature,跳过无 signature 的 thinking block - log.Printf("Warning: skipping thinking block without signature for Claude model") + if allowDummyThought { + // Gemini 模型可以使用 dummy signature + parts = append(parts, GeminiPart{ + Text: block.Thinking, + Thought: true, + ThoughtSignature: dummyThoughtSignature, + }) continue - } else { - // Gemini 模型使用 dummy signature - part.ThoughtSignature = dummyThoughtSignature } - parts = append(parts, part) + + // Claude 模型:仅在提供有效 signature 时保留 thinking block;否则跳过以避免上游校验失败。 + signature := strings.TrimSpace(block.Signature) + if signature == "" || signature == dummyThoughtSignature { + log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)") + continue + } + if !isValidThoughtSignature(signature) { + log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature)) + } + parts = append(parts, GeminiPart{ + Text: block.Thinking, + Thought: true, + ThoughtSignature: signature, + }) case "image": if block.Source != nil && block.Source.Type == "base64" { @@ -239,10 +287,9 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ID: block.ID, }, } - // 保留原有 signature,或对 Gemini 模型使用 dummy signature - if block.Signature != "" { - part.ThoughtSignature = block.Signature - } else if allowDummyThought { + // 只有 Gemini 模型使用 dummy signature + // Claude 模型不设置 signature(避免验证问题) + if allowDummyThought { part.ThoughtSignature = dummyThoughtSignature } parts = append(parts, part) @@ -386,9 +433,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { // 普通工具 var funcDecls []GeminiFunctionDecl - for _, tool := range tools { + for i, tool := range tools { // 跳过无效工具名称 - if tool.Name == "" { + if strings.TrimSpace(tool.Name) == "" { log.Printf("Warning: skipping tool with empty name") continue } @@ -397,10 +444,18 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { var inputSchema map[string]any // 检查是否为 custom 类型工具 (MCP) - if tool.Type == "custom" && tool.Custom != nil { - // Custom 格式: 从 custom 字段获取 description 和 input_schema + if tool.Type == "custom" { + if tool.Custom == nil || tool.Custom.InputSchema == nil { + log.Printf("[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema", tool.Name) + continue + } description = tool.Custom.Description inputSchema = tool.Custom.InputSchema + + // 调试日志:记录 custom 工具的 schema + if schemaJSON, err := json.Marshal(inputSchema); err == nil { + log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON)) + } } else { // 标准格式: 从顶层字段获取 description = tool.Description @@ -409,7 +464,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { // 清理 JSON Schema params := cleanJSONSchema(inputSchema) - // 为 nil schema 提供默认值 if params == nil { params = map[string]any{ @@ -418,6 +472,11 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { } } + // 调试日志:记录清理后的 schema + if paramsJSON, err := json.Marshal(params); err == nil { + log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON)) + } + funcDecls = append(funcDecls, GeminiFunctionDecl{ Name: tool.Name, Description: description, @@ -479,31 +538,64 @@ func cleanJSONSchema(schema map[string]any) map[string]any { } // excludedSchemaKeys 不支持的 schema 字段 +// 基于 Claude API (Vertex AI) 的实际支持情况 +// 支持: type, description, enum, properties, required, additionalProperties, items +// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段 var excludedSchemaKeys = map[string]bool{ - "$schema": true, - "$id": true, - "$ref": true, - "additionalProperties": true, - "minLength": true, - "maxLength": true, - "minItems": true, - "maxItems": true, - "uniqueItems": true, - "minimum": true, - "maximum": true, - "exclusiveMinimum": true, - "exclusiveMaximum": true, - "pattern": true, - "format": true, - "default": true, - "strict": true, - "const": true, - "examples": true, - "deprecated": true, - "readOnly": true, - "writeOnly": true, - "contentMediaType": true, - "contentEncoding": true, + // 元 schema 字段 + "$schema": true, + "$id": true, + "$ref": true, + + // 字符串验证(Gemini 不支持) + "minLength": true, + "maxLength": true, + "pattern": true, + + // 数字验证(Claude API 通过 Vertex AI 不支持这些字段) + "minimum": true, + "maximum": true, + "exclusiveMinimum": true, + "exclusiveMaximum": true, + "multipleOf": true, + + // 数组验证(Claude API 通过 Vertex AI 不支持这些字段) + "uniqueItems": true, + "minItems": true, + "maxItems": true, + + // 组合 schema(Gemini 不支持) + "oneOf": true, + "anyOf": true, + "allOf": true, + "not": true, + "if": true, + "then": true, + "else": true, + "$defs": true, + "definitions": true, + + // 对象验证(仅保留 properties/required/additionalProperties) + "minProperties": true, + "maxProperties": true, + "patternProperties": true, + "propertyNames": true, + "dependencies": true, + "dependentSchemas": true, + "dependentRequired": true, + + // 其他不支持的字段 + "default": true, + "const": true, + "examples": true, + "deprecated": true, + "readOnly": true, + "writeOnly": true, + "contentMediaType": true, + "contentEncoding": true, + + // Claude 特有字段 + "strict": true, } // cleanSchemaValue 递归清理 schema 值 @@ -523,6 +615,31 @@ func cleanSchemaValue(value any) any { continue } + // 特殊处理 format 字段:只保留 Gemini 支持的 format 值 + if k == "format" { + if formatStr, ok := val.(string); ok { + // Gemini 只支持 date-time, date, time + if formatStr == "date-time" || formatStr == "date" || formatStr == "time" { + result[k] = val + } + // 其他 format 值直接跳过 + } + continue + } + + // 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象 + if k == "additionalProperties" { + if boolVal, ok := val.(bool); ok { + result[k] = boolVal + log.Printf("[Debug] additionalProperties is bool: %v", boolVal) + } else { + // 如果是 schema 对象,转换为 false(更安全的默认值) + result[k] = false + log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val) + } + continue + } + // 递归清理所有值 result[k] = cleanSchemaValue(val) } diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go new file mode 100644 index 00000000..56eebad0 --- /dev/null +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -0,0 +1,179 @@ +package antigravity + +import ( + "encoding/json" + "testing" +) + +// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理 +func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) { + tests := []struct { + name string + content string + allowDummyThought bool + expectedParts int + description string + }{ + { + name: "Claude model - skip thinking block without signature", + content: `[ + {"type": "text", "text": "Hello"}, + {"type": "thinking", "thinking": "Let me think...", "signature": ""}, + {"type": "text", "text": "World"} + ]`, + allowDummyThought: false, + expectedParts: 2, // 只有两个text block + description: "Claude模型应该跳过无signature的thinking block", + }, + { + name: "Claude model - keep thinking block with signature", + content: `[ + {"type": "text", "text": "Hello"}, + {"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"}, + {"type": "text", "text": "World"} + ]`, + allowDummyThought: false, + expectedParts: 3, // 三个block都保留 + description: "Claude模型应该保留有signature的thinking block", + }, + { + name: "Gemini model - use dummy signature", + content: `[ + {"type": "text", "text": "Hello"}, + {"type": "thinking", "thinking": "Let me think...", "signature": ""}, + {"type": "text", "text": "World"} + ]`, + allowDummyThought: true, + expectedParts: 3, // 三个block都保留,thinking使用dummy signature + description: "Gemini模型应该为无signature的thinking block使用dummy signature", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + toolIDToName := make(map[string]string) + parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought) + + if err != nil { + t.Fatalf("buildParts() error = %v", err) + } + + if len(parts) != tt.expectedParts { + t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts) + } + }) + } +} + +// TestBuildTools_CustomTypeTools 测试custom类型工具转换 +func TestBuildTools_CustomTypeTools(t *testing.T) { + tests := []struct { + name string + tools []ClaudeTool + expectedLen int + description string + }{ + { + name: "Standard tool format", + tools: []ClaudeTool{ + { + Name: "get_weather", + Description: "Get weather information", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{"type": "string"}, + }, + }, + }, + }, + expectedLen: 1, + description: "标准工具格式应该正常转换", + }, + { + name: "Custom type tool (MCP format)", + tools: []ClaudeTool{ + { + Type: "custom", + Name: "mcp_tool", + Custom: &ClaudeCustomToolSpec{ + Description: "MCP tool description", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "param": map[string]any{"type": "string"}, + }, + }, + }, + }, + }, + expectedLen: 1, + description: "Custom类型工具应该从Custom字段读取description和input_schema", + }, + { + name: "Mixed standard and custom tools", + tools: []ClaudeTool{ + { + Name: "standard_tool", + Description: "Standard tool", + InputSchema: map[string]any{"type": "object"}, + }, + { + Type: "custom", + Name: "custom_tool", + Custom: &ClaudeCustomToolSpec{ + Description: "Custom tool", + InputSchema: map[string]any{"type": "object"}, + }, + }, + }, + expectedLen: 1, // 返回一个GeminiToolDeclaration,包含2个function declarations + description: "混合标准和custom工具应该都能正确转换", + }, + { + name: "Invalid custom tool - nil Custom field", + tools: []ClaudeTool{ + { + Type: "custom", + Name: "invalid_custom", + // Custom 为 nil + }, + }, + expectedLen: 0, // 应该被跳过 + description: "Custom字段为nil的custom工具应该被跳过", + }, + { + name: "Invalid custom tool - nil InputSchema", + tools: []ClaudeTool{ + { + Type: "custom", + Name: "invalid_custom", + Custom: &ClaudeCustomToolSpec{ + Description: "Invalid", + // InputSchema 为 nil + }, + }, + }, + expectedLen: 0, // 应该被跳过 + description: "InputSchema为nil的custom工具应该被跳过", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildTools(tt.tools) + + if len(result) != tt.expectedLen { + t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen) + } + + // 验证function declarations存在 + if len(result) > 0 && result[0].FunctionDeclarations != nil { + if len(result[0].FunctionDeclarations) != len(tt.tools) { + t.Errorf("%s: got %d function declarations, want %d", + tt.description, len(result[0].FunctionDeclarations), len(tt.tools)) + } + } + }) + } +} diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index 97ad6c83..0db3ed4a 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -16,6 +16,12 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav // HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta) const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking +// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth) +const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + +// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code) +const ApiKeyHaikuBetaHeader = BetaInterleavedThinking + // Claude Code 客户端默认请求头 var DefaultHeaders = map[string]string{ "User-Agent": "claude-cli/2.0.62 (external, cli)", diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index ae2976f8..5b3bf565 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -358,6 +358,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return nil, fmt.Errorf("transform request: %w", err) } + // 调试:记录转换后的请求体(仅记录前 2000 字符) + if bodyJSON, err := json.Marshal(geminiBody); err == nil { + truncated := string(bodyJSON) + if len(truncated) > 2000 { + truncated = truncated[:2000] + "..." + } + log.Printf("[Debug] Transformed Gemini request: %s", truncated) + } + // 构建上游 action action := "generateContent" if claudeReq.Stream { diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index d542e9c2..5884602d 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -19,6 +19,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/tidwall/gjson" "github.com/tidwall/sjson" "github.com/gin-gonic/gin" @@ -684,6 +685,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 处理错误响应(不可重试的错误) if resp.StatusCode >= 400 { + // 可选:对部分 400 触发 failover(默认关闭以保持语义) + if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { + respBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + // ReadAll failed, fall back to normal error handling without consuming the stream + return s.handleErrorResponse(ctx, resp, c, account) + } + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + if s.shouldFailoverOn400(respBody) { + if s.cfg.Gateway.LogUpstreamErrorBody { + log.Printf( + "Account %d: 400 error, attempting failover: %s", + account.ID, + truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } else { + log.Printf("Account %d: 400 error, attempting failover", account.ID) + } + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } + } return s.handleErrorResponse(ctx, resp, c, account) } @@ -786,6 +811,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // 处理anthropic-beta header(OAuth账号需要特殊处理) if tokenType == "oauth" { req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) + } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { + // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) + if requestNeedsBetaFeatures(body) { + if beta := defaultApiKeyBetaHeader(body); beta != "" { + req.Header.Set("anthropic-beta", beta) + } + } } return req, nil @@ -838,6 +870,83 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) return claude.DefaultBetaHeader } +func requestNeedsBetaFeatures(body []byte) bool { + tools := gjson.GetBytes(body, "tools") + if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 { + return true + } + if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") { + return true + } + return false +} + +func defaultApiKeyBetaHeader(body []byte) string { + modelID := gjson.GetBytes(body, "model").String() + if strings.Contains(strings.ToLower(modelID), "haiku") { + return claude.ApiKeyHaikuBetaHeader + } + return claude.ApiKeyBetaHeader +} + +func truncateForLog(b []byte, maxBytes int) string { + if maxBytes <= 0 { + maxBytes = 2048 + } + if len(b) > maxBytes { + b = b[:maxBytes] + } + s := string(b) + // 保持一行,避免污染日志格式 + s = strings.ReplaceAll(s, "\n", "\\n") + s = strings.ReplaceAll(s, "\r", "\\r") + return s +} + +func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { + // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。 + // 默认保守:无法识别则不切换。 + msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + if msg == "" { + return false + } + + // 缺少/错误的 beta header:换账号/链路可能成功(尤其是混合调度时)。 + // 更精确匹配 beta 相关的兼容性问题,避免误触发切换。 + if strings.Contains(msg, "anthropic-beta") || + strings.Contains(msg, "beta feature") || + strings.Contains(msg, "requires beta") { + return true + } + + // thinking/tool streaming 等兼容性约束(常见于中间转换链路) + if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") { + return true + } + if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") { + return true + } + + return false +} + +func extractUpstreamErrorMessage(body []byte) string { + // Claude 风格:{"type":"error","error":{"type":"...","message":"..."}} + if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" { + inner := strings.TrimSpace(m) + // 有些上游会把完整 JSON 作为字符串塞进 message + if strings.HasPrefix(inner, "{") { + if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" { + return innerMsg + } + } + return m + } + + // 兜底:尝试顶层 message + return gjson.GetBytes(body, "message").String() +} + func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { body, _ := io.ReadAll(resp.Body) @@ -850,6 +959,16 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res switch resp.StatusCode { case 400: + // 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开 + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + log.Printf( + "Upstream 400 error (account=%d platform=%s type=%s): %s", + account.ID, + account.Platform, + account.Type, + truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } c.Data(http.StatusBadRequest, "application/json", body) return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) case 401: @@ -1329,6 +1448,18 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, // 标记账号状态(429/529等) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + // 记录上游错误摘要便于排障(不回显请求内容) + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + log.Printf( + "count_tokens upstream error %d (account=%d platform=%s type=%s): %s", + resp.StatusCode, + account.ID, + account.Platform, + account.Type, + truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } + // 返回简化的错误响应 errMsg := "Upstream request failed" switch resp.StatusCode { @@ -1409,6 +1540,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) + } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { + // API-key:与 messages 同步的按需 beta 注入(默认关闭) + if requestNeedsBetaFeatures(body) { + if beta := defaultApiKeyBetaHeader(body); beta != "" { + req.Header.Set("anthropic-beta", beta) + } + } } return req, nil diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index a0bf1b6a..b1877800 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -2278,11 +2278,13 @@ func convertClaudeToolsToGeminiTools(tools any) []any { "properties": map[string]any{}, } } + // 清理 JSON Schema + cleanedParams := cleanToolSchema(params) funcDecls = append(funcDecls, map[string]any{ "name": name, "description": desc, - "parameters": params, + "parameters": cleanedParams, }) } @@ -2296,6 +2298,41 @@ func convertClaudeToolsToGeminiTools(tools any) []any { } } +// cleanToolSchema 清理工具的 JSON Schema,移除 Gemini 不支持的字段 +func cleanToolSchema(schema any) any { + if schema == nil { + return nil + } + + switch v := schema.(type) { + case map[string]any: + cleaned := make(map[string]any) + for key, value := range v { + // 跳过不支持的字段 + if key == "$schema" || key == "$id" || key == "$ref" || + key == "additionalProperties" || key == "minLength" || + key == "maxLength" || key == "minItems" || key == "maxItems" { + continue + } + // 递归清理嵌套对象 + cleaned[key] = cleanToolSchema(value) + } + // 规范化 type 字段为大写 + if typeVal, ok := cleaned["type"].(string); ok { + cleaned["type"] = strings.ToUpper(typeVal) + } + return cleaned + case []any: + cleaned := make([]any, len(v)) + for i, item := range v { + cleaned[i] = cleanToolSchema(item) + } + return cleaned + default: + return v + } +} + func convertClaudeGenerationConfig(req map[string]any) map[string]any { out := make(map[string]any) if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 { diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go new file mode 100644 index 00000000..d49f2eb3 --- /dev/null +++ b/backend/internal/service/gemini_messages_compat_service_test.go @@ -0,0 +1,128 @@ +package service + +import ( + "testing" +) + +// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换 +func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { + tests := []struct { + name string + tools any + expectedLen int + description string + }{ + { + name: "Standard tools", + tools: []any{ + map[string]any{ + "name": "get_weather", + "description": "Get weather info", + "input_schema": map[string]any{"type": "object"}, + }, + }, + expectedLen: 1, + description: "标准工具格式应该正常转换", + }, + { + name: "Custom type tool (MCP format)", + tools: []any{ + map[string]any{ + "type": "custom", + "name": "mcp_tool", + "custom": map[string]any{ + "description": "MCP tool description", + "input_schema": map[string]any{"type": "object"}, + }, + }, + }, + expectedLen: 1, + description: "Custom类型工具应该从custom字段读取", + }, + { + name: "Mixed standard and custom tools", + tools: []any{ + map[string]any{ + "name": "standard_tool", + "description": "Standard", + "input_schema": map[string]any{"type": "object"}, + }, + map[string]any{ + "type": "custom", + "name": "custom_tool", + "custom": map[string]any{ + "description": "Custom", + "input_schema": map[string]any{"type": "object"}, + }, + }, + }, + expectedLen: 1, + description: "混合工具应该都能正确转换", + }, + { + name: "Custom tool without custom field", + tools: []any{ + map[string]any{ + "type": "custom", + "name": "invalid_custom", + // 缺少 custom 字段 + }, + }, + expectedLen: 0, // 应该被跳过 + description: "缺少custom字段的custom工具应该被跳过", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertClaudeToolsToGeminiTools(tt.tools) + + if tt.expectedLen == 0 { + if result != nil { + t.Errorf("%s: expected nil result, got %v", tt.description, result) + } + return + } + + if result == nil { + t.Fatalf("%s: expected non-nil result", tt.description) + } + + if len(result) != 1 { + t.Errorf("%s: expected 1 tool declaration, got %d", tt.description, len(result)) + return + } + + toolDecl, ok := result[0].(map[string]any) + if !ok { + t.Fatalf("%s: result[0] is not map[string]any", tt.description) + } + + funcDecls, ok := toolDecl["functionDeclarations"].([]any) + if !ok { + t.Fatalf("%s: functionDeclarations is not []any", tt.description) + } + + toolsArr, _ := tt.tools.([]any) + expectedFuncCount := 0 + for _, tool := range toolsArr { + toolMap, _ := tool.(map[string]any) + if toolMap["name"] != "" { + // 检查是否为有效的custom工具 + if toolMap["type"] == "custom" { + if toolMap["custom"] != nil { + expectedFuncCount++ + } + } else { + expectedFuncCount++ + } + } + } + + if len(funcDecls) != expectedFuncCount { + t.Errorf("%s: expected %d function declarations, got %d", + tt.description, expectedFuncCount, len(funcDecls)) + } + }) + } +} diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index e4bda5f8..221bd0f2 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "regexp" "strconv" "strings" "time" @@ -163,6 +164,45 @@ type GeminiTokenInfo struct { Scope string `json:"scope,omitempty"` ProjectID string `json:"project_id,omitempty"` OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio" + TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA +} + +// validateTierID validates tier_id format and length +func validateTierID(tierID string) error { + if tierID == "" { + return nil // Empty is allowed + } + if len(tierID) > 64 { + return fmt.Errorf("tier_id exceeds maximum length of 64 characters") + } + // Allow alphanumeric, underscore, hyphen, and slash (for tier paths) + if !regexp.MustCompile(`^[a-zA-Z0-9_/-]+$`).MatchString(tierID) { + return fmt.Errorf("tier_id contains invalid characters") + } + return nil +} + +// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response +// Prioritizes IsDefault tier, falls back to first non-empty tier +func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string { + tierID := "LEGACY" + // First pass: look for default tier + for _, tier := range allowedTiers { + if tier.IsDefault && strings.TrimSpace(tier.ID) != "" { + tierID = strings.TrimSpace(tier.ID) + break + } + } + // Second pass: if still LEGACY, take first non-empty tier + if tierID == "LEGACY" { + for _, tier := range allowedTiers { + if strings.TrimSpace(tier.ID) != "" { + tierID = strings.TrimSpace(tier.ID) + break + } + } + } + return tierID } func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { @@ -223,13 +263,14 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 projectID := sessionProjectID + var tierID string // 对于 code_assist 模式,project_id 是必需的 // 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API) if oauthType == "code_assist" { if projectID == "" { var err error - projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) + projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) if err != nil { // 记录警告但不阻断流程,允许后续补充 project_id fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err) @@ -248,6 +289,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch ExpiresAt: expiresAt, Scope: tokenResp.Scope, ProjectID: projectID, + TierID: tierID, OAuthType: oauthType, }, nil } @@ -357,7 +399,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A // For Code Assist, project_id is required. Auto-detect if missing. // For AI Studio OAuth, project_id is optional and should not block refresh. if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" { - projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) + projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) if err != nil { return nil, fmt.Errorf("failed to auto-detect project_id: %w", err) } @@ -366,6 +408,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A return nil, fmt.Errorf("failed to auto-detect project_id: empty result") } tokenInfo.ProjectID = projectID + tokenInfo.TierID = tierID } return tokenInfo, nil @@ -388,6 +431,13 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) if tokenInfo.ProjectID != "" { creds["project_id"] = tokenInfo.ProjectID } + if tokenInfo.TierID != "" { + // Validate tier_id before storing + if err := validateTierID(tokenInfo.TierID); err == nil { + creds["tier_id"] = tokenInfo.TierID + } + // Silently skip invalid tier_id (don't block account creation) + } if tokenInfo.OAuthType != "" { creds["oauth_type"] = tokenInfo.OAuthType } @@ -398,34 +448,26 @@ func (s *GeminiOAuthService) Stop() { s.sessionStore.Stop() } -func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) { +func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, string, error) { if s.codeAssist == nil { - return "", errors.New("code assist client not configured") + return "", "", errors.New("code assist client not configured") } loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil) + + // Extract tierID from response (works whether CloudAICompanionProject is set or not) + tierID := "LEGACY" + if loadResp != nil { + tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers) + } + + // If LoadCodeAssist returned a project, use it if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" { - return strings.TrimSpace(loadResp.CloudAICompanionProject), nil + return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil } // Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID. - tierID := "LEGACY" - if loadResp != nil { - for _, tier := range loadResp.AllowedTiers { - if tier.IsDefault && strings.TrimSpace(tier.ID) != "" { - tierID = strings.TrimSpace(tier.ID) - break - } - } - if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" { - for _, tier := range loadResp.AllowedTiers { - if strings.TrimSpace(tier.ID) != "" { - tierID = strings.TrimSpace(tier.ID) - break - } - } - } - } + // (tierID already extracted above, reuse it) req := &geminicli.OnboardUserRequest{ TierID: tierID, @@ -443,39 +485,39 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr // If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects. fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - return strings.TrimSpace(fallback), nil + return strings.TrimSpace(fallback), tierID, nil } - return "", err + return "", "", err } if resp.Done { if resp.Response != nil && resp.Response.CloudAICompanionProject != nil { switch v := resp.Response.CloudAICompanionProject.(type) { case string: - return strings.TrimSpace(v), nil + return strings.TrimSpace(v), tierID, nil case map[string]any: if id, ok := v["id"].(string); ok { - return strings.TrimSpace(id), nil + return strings.TrimSpace(id), tierID, nil } } } fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - return strings.TrimSpace(fallback), nil + return strings.TrimSpace(fallback), tierID, nil } - return "", errors.New("onboardUser completed but no project_id returned") + return "", "", errors.New("onboardUser completed but no project_id returned") } time.Sleep(2 * time.Second) } fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - return strings.TrimSpace(fallback), nil + return strings.TrimSpace(fallback), tierID, nil } if loadErr != nil { - return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) + return "", "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) } - return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) + return "", "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) } type googleCloudProject struct { diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index 2195ec55..5f369de5 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -112,7 +112,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou } } - detected, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL) + detected, tierID, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL) if err != nil { log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err) return accessToken, nil @@ -123,6 +123,9 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou account.Credentials = make(map[string]any) } account.Credentials["project_id"] = detected + if tierID != "" { + account.Credentials["tier_id"] = tierID + } _ = p.accountRepo.Update(ctx, account) } } diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 5bd85d7d..5478d151 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -122,6 +122,21 @@ pricing: # Hash check interval in minutes hash_check_interval_minutes: 10 +# ============================================================================= +# Gateway (Optional) +# ============================================================================= +gateway: + # Wait time (in seconds) for upstream response headers (streaming body not affected) + response_header_timeout: 300 + # Log upstream error response body summary (safe/truncated; does not log request content) + log_upstream_error_body: false + # Max bytes to log from upstream error body + log_upstream_error_body_max_bytes: 2048 + # Auto inject anthropic-beta for API-key accounts when needed (default off) + inject_beta_for_apikey: false + # Allow failover on selected 400 errors (default off) + failover_on_400: false + # ============================================================================= # Gemini OAuth (Required for Gemini accounts) # ============================================================================= diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 6563ee0c..1770a985 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -952,6 +952,7 @@ "integrity": "sha512-N2clP5pJhB2YnZJ3PIHFk5RkygRX5WO/5f0WC08tp0wd+sv0rsJk3MqWn3CbNmT2J505a5336jaQj4ph1AdMug==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "undici-types": "~6.21.0" } @@ -1367,6 +1368,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "baseline-browser-mapping": "^2.9.0", "caniuse-lite": "^1.0.30001759", @@ -1443,6 +1445,7 @@ "resolved": "https://registry.npmmirror.com/chart.js/-/chart.js-4.5.1.tgz", "integrity": "sha512-GIjfiT9dbmHRiYi6Nl2yFCq7kkwdkp1W/lp2J99rX0yo9tgJGn3lKQATztIjb5tVtevcBtIdICNWqlq5+E8/Pw==", "license": "MIT", + "peer": true, "dependencies": { "@kurkle/color": "^0.3.0" }, @@ -2040,6 +2043,7 @@ "integrity": "sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==", "dev": true, "license": "MIT", + "peer": true, "bin": { "jiti": "bin/jiti.js" } @@ -2348,6 +2352,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1", @@ -2821,6 +2826,7 @@ "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", + "peer": true, "engines": { "node": ">=12" }, @@ -2854,6 +2860,7 @@ "integrity": "sha512-hjcS1mhfuyi4WW8IWtjP7brDrG2cuDZukyrYrSauoXGNgx0S7zceP07adYkJycEr56BOUTNPzbInooiN3fn1qw==", "devOptional": true, "license": "Apache-2.0", + "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -2926,6 +2933,7 @@ "integrity": "sha512-o5a9xKjbtuhY6Bi5S3+HvbRERmouabWbyUcpXXUA1u+GNUKoROi9byOJ8M0nHbHYHkYICiMlqxkg1KkYmm25Sw==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "esbuild": "^0.21.3", "postcss": "^8.4.43", @@ -3097,6 +3105,7 @@ "resolved": "https://registry.npmmirror.com/vue/-/vue-3.5.25.tgz", "integrity": "sha512-YLVdgv2K13WJ6n+kD5owehKtEXwdwXuj2TTyJMsO7pSeKw2bfRNZGjhB7YzrpbMYj5b5QsUebHpOqR3R3ziy/g==", "license": "MIT", + "peer": true, "dependencies": { "@vue/compiler-dom": "3.5.25", "@vue/compiler-sfc": "3.5.25", @@ -3190,6 +3199,7 @@ "integrity": "sha512-P7OP77b2h/Pmk+lZdJ0YWs+5tJ6J2+uOQPo7tlBnY44QqQSPYvS0qVT4wqDJgwrZaLe47etJLLQRFia71GYITw==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@volar/typescript": "2.4.15", "@vue/language-core": "2.2.12" diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue index c1ca08fa..914678a5 100644 --- a/frontend/src/components/account/AccountStatusIndicator.vue +++ b/frontend/src/components/account/AccountStatusIndicator.vue @@ -83,6 +83,14 @@ > + + + + {{ tierDisplay }} + @@ -140,4 +148,23 @@ const statusText = computed(() => { return props.account.status }) +// Computed: tier display +const tierDisplay = computed(() => { + const credentials = props.account.credentials as Record | undefined + const tierId = credentials?.tier_id + if (!tierId || tierId === 'unknown') return null + + const tierMap: Record = { + 'free': 'Free', + 'payg': 'Pay-as-you-go', + 'pay-as-you-go': 'Pay-as-you-go', + 'enterprise': 'Enterprise', + 'LEGACY': 'Legacy', + 'PRO': 'Pro', + 'ULTRA': 'Ultra' + } + + return tierMap[tierId] || tierId +}) + From 9c88980483fabaca21ce41574a7355085ec73c3b Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Thu, 1 Jan 2026 04:26:01 +0800 Subject: [PATCH 32/49] =?UTF-8?q?fix(lint):=20=E4=BF=AE=E5=A4=8D=20golangc?= =?UTF-8?q?i-lint=20=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修复 gofmt 格式问题 - 修复 staticcheck SA4031 nil check 问题(只在成功时设置 release 函数) - 删除未使用的 sortAccountsByPriority 函数 --- backend/internal/handler/gateway_handler.go | 16 ++++++++++------ .../internal/handler/gemini_v1beta_handler.go | 8 +++++--- backend/internal/pkg/antigravity/claude_types.go | 2 +- backend/internal/service/gateway_service.go | 6 ------ 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 769e6700..70b42ffe 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -192,9 +192,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { log.Printf("Account wait queue full: account=%d", account.ID) h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) return - } - accountWaitRelease = func() { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } else { + // Only set release function if increment succeeded + accountWaitRelease = func() { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } } accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( @@ -314,9 +316,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { log.Printf("Account wait queue full: account=%d", account.ID) h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) return - } - accountWaitRelease = func() { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } else { + // Only set release function if increment succeeded + accountWaitRelease = func() { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } } accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 1959c0f3..93ab23c9 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -233,9 +233,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { log.Printf("Account wait queue full: account=%d", account.ID) googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") return - } - accountWaitRelease = func() { - geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } else { + // Only set release function if increment succeeded + accountWaitRelease = func() { + geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } } accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout( diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index f394d7e3..01b805cd 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -41,7 +41,7 @@ type ClaudeMetadata struct { // 1. 标准格式: { "name": "...", "description": "...", "input_schema": {...} } // 2. Custom 格式 (MCP): { "type": "custom", "name": "...", "custom": { "description": "...", "input_schema": {...} } } type ClaudeTool struct { - Type string `json:"type,omitempty"` // "custom" 或空(标准格式) + Type string `json:"type,omitempty"` // "custom" 或空(标准格式) Name string `json:"name"` Description string `json:"description,omitempty"` // 标准格式使用 InputSchema map[string]any `json:"input_schema,omitempty"` // 标准格式使用 diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 6c45ff0f..af9342b1 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -681,12 +681,6 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) } -func sortAccountsByPriority(accounts []*Account) { - sort.SliceStable(accounts, func(i, j int) bool { - return accounts[i].Priority < accounts[j].Priority - }) -} - func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { sort.SliceStable(accounts, func(i, j int) bool { a, b := accounts[i], accounts[j] From e49281774d6e654e570ccc55ecd81878c5d28d01 Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Wed, 31 Dec 2025 23:57:01 +0800 Subject: [PATCH 33/49] =?UTF-8?q?fix(gemini):=20=E4=BF=AE=E5=A4=8D=20P0/P1?= =?UTF-8?q?=20=E7=BA=A7=E5=88=AB=E9=97=AE=E9=A2=98=EF=BC=88429=E8=AF=AF?= =?UTF-8?q?=E5=88=A4/Tier=E4=B8=A2=E5=A4=B1/expires=5Fat/=E5=89=8D?= =?UTF-8?q?=E7=AB=AF=E4=B8=80=E8=87=B4=E6=80=A7=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit P0 修复(Critical - 影响生产稳定性): - 修复 429 判断逻辑:使用 project_id 判断而非 account.Type 防止 AI Studio OAuth 被误判为 Code Assist 5分钟窗口 - 修复 Tier ID 丢失:刷新时始终保留旧值,默认 LEGACY 防止 fetchProjectID 失败导致 tier_id 被清空 - 修复 expires_at 下界:添加 minTTL=30s 保护 防止 expires_in <= 300 时生成过去时间引发刷新风暴 P1 修复(Important - 行为一致性): - 前端 isCodeAssist 判断与后端一致(支持 legacy) - 前端日期解析添加 NaN 保护 - 迁移脚本覆盖 legacy 账号 前端功能(新增): - AccountQuotaInfo 组件:Tier Badge + 二元进度条 + 倒计时 - 定时器动态管理:watch 监听限流状态 - 类型定义:GeminiCredentials 接口 测试: - ✅ TypeScript 类型检查通过 - ✅ 前端构建成功(3.33s) - ✅ Gemini + Codex 双 AI 审查通过 Refs: #gemini-quota --- .../service/gemini_messages_compat_service.go | 38 +++- .../internal/service/gemini_oauth_service.go | 81 ++++++-- backend/migrations/017_add_gemini_tier_id.sql | 30 +++ .../components/account/AccountQuotaInfo.vue | 194 ++++++++++++++++++ .../components/account/AccountUsageCell.vue | 12 +- frontend/src/types/index.ts | 16 ++ 6 files changed, 349 insertions(+), 22 deletions(-) create mode 100644 backend/migrations/017_add_gemini_tier_id.sql create mode 100644 frontend/src/components/account/AccountQuotaInfo.vue diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index b1877800..111ff462 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -1886,13 +1886,47 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont if statusCode != 429 { return } + + // 获取账号的 oauth_type、tier_id 和 project_id + oauthType := strings.TrimSpace(account.GetCredential("oauth_type")) + tierID := strings.TrimSpace(account.GetCredential("tier_id")) + projectID := strings.TrimSpace(account.GetCredential("project_id")) + + // 判断是否为 Code Assist:以 project_id 是否存在为准(更可靠) + isCodeAssist := projectID != "" + // Legacy 兼容:oauth_type 为空但 project_id 存在时视为 code_assist + if oauthType == "" && isCodeAssist { + oauthType = "code_assist" + } + resetAt := ParseGeminiRateLimitResetTime(body) if resetAt == nil { - ra := time.Now().Add(5 * time.Minute) + // 根据账号类型使用不同的默认重置时间 + var ra time.Time + if isCodeAssist { + // Code Assist: 5 分钟滚动窗口 + ra = time.Now().Add(5 * time.Minute) + log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, reset in 5min", account.ID, tierID, projectID) + } else { + // API Key / AI Studio OAuth: PST 午夜 + if ts := nextGeminiDailyResetUnix(); ts != nil { + ra = time.Unix(*ts, 0) + log.Printf("[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra) + } else { + // 兜底:5 分钟 + ra = time.Now().Add(5 * time.Minute) + log.Printf("[Gemini 429] Account %d rate limited, fallback to 5min", account.ID) + } + } _ = s.accountRepo.SetRateLimited(ctx, account.ID, ra) return } - _ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0)) + + // 使用解析到的重置时间 + resetTime := time.Unix(*resetAt, 0) + _ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime) + log.Printf("[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)", + account.ID, resetTime, oauthType, tierID) } // ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳 diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index 221bd0f2..d1c1c5f6 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -259,8 +259,15 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch sessionProjectID := strings.TrimSpace(session.ProjectID) s.sessionStore.Delete(input.SessionID) - // 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差 - expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 + // 计算过期时间:减去 5 分钟安全时间窗口(考虑网络延迟和时钟偏差) + // 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴) + const safetyWindow = 300 // 5 minutes + const minTTL = 30 // minimum 30 seconds + expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow + minExpiresAt := time.Now().Unix() + minTTL + if expiresAt < minExpiresAt { + expiresAt = minExpiresAt + } projectID := sessionProjectID var tierID string @@ -275,10 +282,22 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch // 记录警告但不阻断流程,允许后续补充 project_id fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err) } + } else { + // 用户手动填了 project_id,仍需调用 LoadCodeAssist 获取 tierID + _, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) + if err != nil { + fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err) + } else { + tierID = fetchedTierID + } } if strings.TrimSpace(projectID) == "" { return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project") } + // tierID 缺失时使用默认值 + if tierID == "" { + tierID = "LEGACY" + } } return &GeminiTokenInfo{ @@ -308,8 +327,15 @@ func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refres tokenResp, err := s.oauthClient.RefreshToken(ctx, oauthType, refreshToken, proxyURL) if err == nil { - // 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差 - expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 + // 计算过期时间:减去 5 分钟安全时间窗口(考虑网络延迟和时钟偏差) + // 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴) + const safetyWindow = 300 // 5 minutes + const minTTL = 30 // minimum 30 seconds + expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow + minExpiresAt := time.Now().Unix() + minTTL + if expiresAt < minExpiresAt { + expiresAt = minExpiresAt + } return &GeminiTokenInfo{ AccessToken: tokenResp.AccessToken, RefreshToken: tokenResp.RefreshToken, @@ -396,19 +422,39 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A tokenInfo.ProjectID = existingProjectID } + // 尝试从账号凭证获取 tierID(向后兼容) + existingTierID := strings.TrimSpace(account.GetCredential("tier_id")) + // For Code Assist, project_id is required. Auto-detect if missing. // For AI Studio OAuth, project_id is optional and should not block refresh. - if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" { - projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) - if err != nil { - return nil, fmt.Errorf("failed to auto-detect project_id: %w", err) + if oauthType == "code_assist" { + // 先设置默认值或保留旧值,确保 tier_id 始终有值 + if existingTierID != "" { + tokenInfo.TierID = existingTierID + } else { + tokenInfo.TierID = "LEGACY" // 默认值 } - projectID = strings.TrimSpace(projectID) - if projectID == "" { + + // 尝试自动探测 project_id 和 tier_id + needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == "" + if needDetect { + projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) + if err != nil { + fmt.Printf("[GeminiOAuth] Warning: failed to auto-detect project/tier: %v\n", err) + } else { + if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" { + tokenInfo.ProjectID = projectID + } + // 只有当原来没有 tier_id 且探测成功时才更新 + if existingTierID == "" && tierID != "" { + tokenInfo.TierID = tierID + } + } + } + + if strings.TrimSpace(tokenInfo.ProjectID) == "" { return nil, fmt.Errorf("failed to auto-detect project_id: empty result") } - tokenInfo.ProjectID = projectID - tokenInfo.TierID = tierID } return tokenInfo, nil @@ -466,9 +512,6 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil } - // Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID. - // (tierID already extracted above, reuse it) - req := &geminicli.OnboardUserRequest{ TierID: tierID, Metadata: geminicli.LoadCodeAssistMetadata{ @@ -487,7 +530,7 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr if fbErr == nil && strings.TrimSpace(fallback) != "" { return strings.TrimSpace(fallback), tierID, nil } - return "", "", err + return "", tierID, err } if resp.Done { if resp.Response != nil && resp.Response.CloudAICompanionProject != nil { @@ -505,7 +548,7 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr if fbErr == nil && strings.TrimSpace(fallback) != "" { return strings.TrimSpace(fallback), tierID, nil } - return "", "", errors.New("onboardUser completed but no project_id returned") + return "", tierID, errors.New("onboardUser completed but no project_id returned") } time.Sleep(2 * time.Second) } @@ -515,9 +558,9 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr return strings.TrimSpace(fallback), tierID, nil } if loadErr != nil { - return "", "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) + return "", tierID, fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) } - return "", "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) + return "", tierID, fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) } type googleCloudProject struct { diff --git a/backend/migrations/017_add_gemini_tier_id.sql b/backend/migrations/017_add_gemini_tier_id.sql new file mode 100644 index 00000000..0388a412 --- /dev/null +++ b/backend/migrations/017_add_gemini_tier_id.sql @@ -0,0 +1,30 @@ +-- +goose Up +-- +goose StatementBegin +-- 为 Gemini Code Assist OAuth 账号添加默认 tier_id +-- 包括显式标记为 code_assist 的账号,以及 legacy 账号(oauth_type 为空但 project_id 存在) +UPDATE accounts +SET credentials = jsonb_set( + credentials, + '{tier_id}', + '"LEGACY"', + true +) +WHERE platform = 'gemini' + AND type = 'oauth' + AND jsonb_typeof(credentials) = 'object' + AND credentials->>'tier_id' IS NULL + AND ( + credentials->>'oauth_type' = 'code_assist' + OR (credentials->>'oauth_type' IS NULL AND credentials->>'project_id' IS NOT NULL) + ); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +-- 回滚:删除 tier_id 字段 +UPDATE accounts +SET credentials = credentials - 'tier_id' +WHERE platform = 'gemini' + AND type = 'oauth' + AND credentials->>'oauth_type' = 'code_assist'; +-- +goose StatementEnd diff --git a/frontend/src/components/account/AccountQuotaInfo.vue b/frontend/src/components/account/AccountQuotaInfo.vue new file mode 100644 index 00000000..44fe1b41 --- /dev/null +++ b/frontend/src/components/account/AccountQuotaInfo.vue @@ -0,0 +1,194 @@ + + + diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue index d064c55a..d457c2ff 100644 --- a/frontend/src/components/account/AccountUsageCell.vue +++ b/frontend/src/components/account/AccountUsageCell.vue @@ -169,6 +169,11 @@
-
+ + + From c5c12d4c8b44cbfecf2ee22ae3fd7810f724c638 Mon Sep 17 00:00:00 2001 From: Wesley Liddick Date: Wed, 31 Dec 2025 21:45:42 -0500 Subject: [PATCH 48/49] =?UTF-8?q?Revert=20"feat(gateway):=20=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=E8=B4=9F=E8=BD=BD=E6=84=9F=E7=9F=A5=E7=9A=84=E8=B4=A6?= =?UTF-8?q?=E5=8F=B7=E8=B0=83=E5=BA=A6=E4=BC=98=E5=8C=96=20(#114)"=20(#117?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 8d252303fc4a6325956234079ce3fb676f680595. --- backend/cmd/server/wire_gen.go | 6 +- backend/internal/config/config.go | 57 -- backend/internal/config/config_test.go | 49 +- backend/internal/handler/gateway_handler.go | 112 +--- backend/internal/handler/gateway_helper.go | 22 +- .../internal/handler/gemini_v1beta_handler.go | 53 +- .../handler/openai_gateway_handler.go | 51 +- .../internal/pkg/antigravity/claude_types.go | 3 - .../pkg/antigravity/request_transformer.go | 223 ++------ .../antigravity/request_transformer_test.go | 179 ------ backend/internal/pkg/claude/constants.go | 6 - .../internal/repository/concurrency_cache.go | 185 +------ .../concurrency_cache_benchmark_test.go | 2 +- .../concurrency_cache_integration_test.go | 177 +----- backend/internal/repository/wire.go | 9 +- .../service/antigravity_gateway_service.go | 9 - .../internal/service/concurrency_service.go | 110 ---- .../service/gateway_multiplatform_test.go | 211 ------- backend/internal/service/gateway_service.go | 519 +----------------- .../service/gemini_messages_compat_service.go | 39 +- .../gemini_messages_compat_service_test.go | 128 ----- .../internal/service/gemini_oauth_service.go | 104 ++-- .../internal/service/gemini_token_provider.go | 5 +- .../service/openai_gateway_service.go | 260 --------- backend/internal/service/wire.go | 11 +- deploy/config.example.yaml | 15 - deploy/flow.md | 222 -------- frontend/package-lock.json | 10 - .../account/AccountStatusIndicator.vue | 27 - 29 files changed, 133 insertions(+), 2671 deletions(-) delete mode 100644 backend/internal/pkg/antigravity/request_transformer_test.go delete mode 100644 backend/internal/service/gemini_messages_compat_service_test.go delete mode 100644 deploy/flow.md diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 1adabefe..83cba823 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -99,7 +99,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream) accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) - concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) + concurrencyService := service.NewConcurrencyService(concurrencyCache) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) oAuthHandler := admin.NewOAuthHandler(oAuthService) @@ -127,10 +127,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { identityService := service.NewIdentityService(identityCache) timingWheelService := service.ProvideTimingWheelService() deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 7927fec5..aeeddcb4 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -3,7 +3,6 @@ package config import ( "fmt" "strings" - "time" "github.com/spf13/viper" ) @@ -120,37 +119,6 @@ type GatewayConfig struct { // ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟) // 应大于最长 LLM 请求时间,防止请求完成前槽位过期 ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` - - // 是否记录上游错误响应体摘要(避免输出请求内容) - LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"` - // 上游错误响应体记录最大字节数(超过会截断) - LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"` - - // API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容) - InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"` - - // 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义) - FailoverOn400 bool `mapstructure:"failover_on_400"` - - // Scheduling: 账号调度相关配置 - Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"` -} - -// GatewaySchedulingConfig accounts scheduling configuration. -type GatewaySchedulingConfig struct { - // 粘性会话排队配置 - StickySessionMaxWaiting int `mapstructure:"sticky_session_max_waiting"` - StickySessionWaitTimeout time.Duration `mapstructure:"sticky_session_wait_timeout"` - - // 兜底排队配置 - FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"` - FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"` - - // 负载计算 - LoadBatchEnabled bool `mapstructure:"load_batch_enabled"` - - // 过期槽位清理周期(0 表示禁用) - SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"` } func (s *ServerConfig) Address() string { @@ -345,10 +313,6 @@ func setDefaults() { // Gateway viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久 - viper.SetDefault("gateway.log_upstream_error_body", false) - viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048) - viper.SetDefault("gateway.inject_beta_for_apikey", false) - viper.SetDefault("gateway.failover_on_400", false) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) // HTTP 上游连接池配置(针对 5000+ 并发用户优化) @@ -359,12 +323,6 @@ func setDefaults() { viper.SetDefault("gateway.max_upstream_clients", 5000) viper.SetDefault("gateway.client_idle_ttl_seconds", 900) viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求) - viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) - viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second) - viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second) - viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) - viper.SetDefault("gateway.scheduling.load_batch_enabled", true) - viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) // TokenRefresh viper.SetDefault("token_refresh.enabled", true) @@ -453,21 +411,6 @@ func (c *Config) Validate() error { if c.Gateway.ConcurrencySlotTTLMinutes <= 0 { return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive") } - if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 { - return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive") - } - if c.Gateway.Scheduling.StickySessionWaitTimeout <= 0 { - return fmt.Errorf("gateway.scheduling.sticky_session_wait_timeout must be positive") - } - if c.Gateway.Scheduling.FallbackWaitTimeout <= 0 { - return fmt.Errorf("gateway.scheduling.fallback_wait_timeout must be positive") - } - if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 { - return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive") - } - if c.Gateway.Scheduling.SlotCleanupInterval < 0 { - return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative") - } return nil } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 6e722a54..1f1becb8 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1,11 +1,6 @@ package config -import ( - "testing" - "time" - - "github.com/spf13/viper" -) +import "testing" func TestNormalizeRunMode(t *testing.T) { tests := []struct { @@ -26,45 +21,3 @@ func TestNormalizeRunMode(t *testing.T) { } } } - -func TestLoadDefaultSchedulingConfig(t *testing.T) { - viper.Reset() - - cfg, err := Load() - if err != nil { - t.Fatalf("Load() error: %v", err) - } - - if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 { - t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting) - } - if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 45*time.Second { - t.Fatalf("StickySessionWaitTimeout = %v, want 45s", cfg.Gateway.Scheduling.StickySessionWaitTimeout) - } - if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second { - t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout) - } - if cfg.Gateway.Scheduling.FallbackMaxWaiting != 100 { - t.Fatalf("FallbackMaxWaiting = %d, want 100", cfg.Gateway.Scheduling.FallbackMaxWaiting) - } - if !cfg.Gateway.Scheduling.LoadBatchEnabled { - t.Fatalf("LoadBatchEnabled = false, want true") - } - if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second { - t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval) - } -} - -func TestLoadSchedulingConfigFromEnv(t *testing.T) { - viper.Reset() - t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5") - - cfg, err := Load() - if err != nil { - t.Fatalf("Load() error: %v", err) - } - - if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 5 { - t.Fatalf("StickySessionMaxWaiting = %d, want 5", cfg.Gateway.Scheduling.StickySessionMaxWaiting) - } -} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 70b42ffe..a2f833ff 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -141,10 +141,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } else if apiKey.Group != nil { platform = apiKey.Group.Platform } - sessionKey := sessionHash - if platform == service.PlatformGemini && sessionHash != "" { - sessionKey = "gemini:" + sessionHash - } if platform == service.PlatformGemini { const maxAccountSwitches = 3 @@ -153,7 +149,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { lastFailoverStatus := 0 for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs) + account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) if err != nil { if len(failedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -162,13 +158,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) return } - account := selection.Account // 检查预热请求拦截(在账号选择后、转发前检查) if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { - if selection.Acquired && selection.ReleaseFunc != nil { - selection.ReleaseFunc() - } if reqStream { sendMockWarmupStream(c, reqModel) } else { @@ -178,46 +170,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 3. 获取账号并发槽位 - accountReleaseFunc := selection.ReleaseFunc - var accountWaitRelease func() - if !selection.Acquired { - if selection.WaitPlan == nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) - return - } - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) - if err != nil { - log.Printf("Increment account wait count failed: %v", err) - } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } else { - // Only set release function if increment succeeded - accountWaitRelease = func() { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - } - - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - reqStream, - &streamStarted, - ) - if err != nil { - if accountWaitRelease != nil { - accountWaitRelease() - } - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) - } + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return } // 转发请求 - 根据账号平台分流 @@ -230,9 +187,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } - if accountWaitRelease != nil { - accountWaitRelease() - } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { @@ -277,7 +231,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { for { // 选择支持该模型的账号 - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs) + account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) if err != nil { if len(failedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -286,13 +240,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) return } - account := selection.Account // 检查预热请求拦截(在账号选择后、转发前检查) if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { - if selection.Acquired && selection.ReleaseFunc != nil { - selection.ReleaseFunc() - } if reqStream { sendMockWarmupStream(c, reqModel) } else { @@ -302,46 +252,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 3. 获取账号并发槽位 - accountReleaseFunc := selection.ReleaseFunc - var accountWaitRelease func() - if !selection.Acquired { - if selection.WaitPlan == nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) - return - } - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) - if err != nil { - log.Printf("Increment account wait count failed: %v", err) - } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } else { - // Only set release function if increment succeeded - accountWaitRelease = func() { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - } - - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - reqStream, - &streamStarted, - ) - if err != nil { - if accountWaitRelease != nil { - accountWaitRelease() - } - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) - } + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return } // 转发请求 - 根据账号平台分流 @@ -354,9 +269,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } - if accountWaitRelease != nil { - accountWaitRelease() - } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 4e049dbb..4c7bd0f0 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -83,16 +83,6 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64 h.concurrencyService.DecrementWaitCount(ctx, userID) } -// IncrementAccountWaitCount increments the wait count for an account -func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { - return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait) -} - -// DecrementAccountWaitCount decrements the wait count for an account -func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) { - h.concurrencyService.DecrementAccountWaitCount(ctx, accountID) -} - // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // For streaming requests, sends ping events during the wait. // streamStarted is updated if streaming response has begun. @@ -136,12 +126,7 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. // streamStarted pointer is updated when streaming begins (for proper error handling by caller). func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { - return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted) -} - -// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout. -func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { - ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) + ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait) defer cancel() // Determine if ping is needed (streaming + ping format defined) @@ -215,11 +200,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType } } -// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping). -func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { - return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted) -} - // nextBackoff 计算下一次退避时间 // 性能优化:使用指数退避 + 随机抖动,避免惊群效应 // current: 当前退避时间 diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 93ab23c9..4e99e00d 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -197,17 +197,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 3) select account (sticky session based on request body) parsedReq, _ := service.ParseGatewayRequest(body) sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) - sessionKey := sessionHash - if sessionHash != "" { - sessionKey = "gemini:" + sessionHash - } const maxAccountSwitches = 3 switchCount := 0 failedAccountIDs := make(map[int64]struct{}) lastFailoverStatus := 0 for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs) + account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs) if err != nil { if len(failedAccountIDs) == 0 { googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) @@ -216,48 +212,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { handleGeminiFailoverExhausted(c, lastFailoverStatus) return } - account := selection.Account // 4) account concurrency slot - accountReleaseFunc := selection.ReleaseFunc - var accountWaitRelease func() - if !selection.Acquired { - if selection.WaitPlan == nil { - googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts") - return - } - canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) - if err != nil { - log.Printf("Increment account wait count failed: %v", err) - } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) - googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") - return - } else { - // Only set release function if increment succeeded - accountWaitRelease = func() { - geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - } - - accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - stream, - &streamStarted, - ) - if err != nil { - if accountWaitRelease != nil { - accountWaitRelease() - } - googleError(c, http.StatusTooManyRequests, err.Error()) - return - } - if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) - } + accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted) + if err != nil { + googleError(c, http.StatusTooManyRequests, err.Error()) + return } // 5) forward (根据平台分流) @@ -270,9 +230,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } - if accountWaitRelease != nil { - accountWaitRelease() - } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 9931052d..7c9934c6 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -146,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { for { // Select account supporting the requested model log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel) - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) + account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) if err != nil { log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) if len(failedAccountIDs) == 0 { @@ -156,50 +156,14 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) return } - account := selection.Account log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) // 3. Acquire account concurrency slot - accountReleaseFunc := selection.ReleaseFunc - var accountWaitRelease func() - if !selection.Acquired { - if selection.WaitPlan == nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) - return - } - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) - if err != nil { - log.Printf("Increment account wait count failed: %v", err) - } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } else { - // Only set release function if increment succeeded - accountWaitRelease = func() { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - } - - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - reqStream, - &streamStarted, - ) - if err != nil { - if accountWaitRelease != nil { - accountWaitRelease() - } - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) - } + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return } // Forward request @@ -207,9 +171,6 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } - if accountWaitRelease != nil { - accountWaitRelease() - } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index 34e6b1f4..01b805cd 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -54,9 +54,6 @@ type CustomToolSpec struct { InputSchema map[string]any `json:"input_schema"` } -// ClaudeCustomToolSpec 兼容旧命名(MCP custom 工具规格) -type ClaudeCustomToolSpec = CustomToolSpec - // SystemBlock system prompt 数组形式的元素 type SystemBlock struct { Type string `json:"type"` diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 83b87a32..e0b5b886 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -14,16 +14,13 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st // 用于存储 tool_use id -> name 映射 toolIDToName := make(map[string]string) + // 检测是否启用 thinking + isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" + // 只有 Gemini 模型支持 dummy thought workaround // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures allowDummyThought := strings.HasPrefix(mappedModel, "gemini-") - // 检测是否启用 thinking - requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" - // 为避免 Claude 模型的 thought signature/消息块约束导致 400(上游要求 thinking 块开头等), - // 非 Gemini 模型默认不启用 thinking(除非未来支持完整签名链路)。 - isThinkingEnabled := requestedThinkingEnabled && allowDummyThought - // 1. 构建 contents contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought) if err != nil { @@ -34,15 +31,7 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model) // 3. 构建 generationConfig - reqForGen := claudeReq - if requestedThinkingEnabled && !allowDummyThought { - log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel) - // shallow copy to avoid mutating caller's request - clone := *claudeReq - clone.Thinking = nil - reqForGen = &clone - } - generationConfig := buildGenerationConfig(reqForGen) + generationConfig := buildGenerationConfig(claudeReq) // 4. 构建 tools tools := buildTools(claudeReq.Tools) @@ -159,9 +148,8 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT if !hasThoughtPart && len(parts) > 0 { // 在开头添加 dummy thinking block parts = append([]GeminiPart{{ - Text: "Thinking...", - Thought: true, - ThoughtSignature: dummyThoughtSignature, + Text: "Thinking...", + Thought: true, }}, parts...) } } @@ -183,34 +171,6 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT // 参考: https://ai.google.dev/gemini-api/docs/thought-signatures const dummyThoughtSignature = "skip_thought_signature_validator" -// isValidThoughtSignature 验证 thought signature 是否有效 -// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节 -func isValidThoughtSignature(signature string) bool { - // 空字符串无效 - if signature == "" { - return false - } - - // signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节) - // 参考 Claude API 文档和实际观察到的有效 signature - if len(signature) < 40 { - log.Printf("[Debug] Signature too short: len=%d", len(signature)) - return false - } - - // 检查是否是有效的 base64 字符 - // base64 字符集: A-Z, a-z, 0-9, +, /, = - for i, c := range signature { - if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') && - (c < '0' || c > '9') && c != '+' && c != '/' && c != '=' { - log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c) - return false - } - } - - return true -} - // buildParts 构建消息的 parts // allowDummyThought: 只有 Gemini 模型支持 dummy thought signature func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) { @@ -239,30 +199,22 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu } case "thinking": - if allowDummyThought { - // Gemini 模型可以使用 dummy signature - parts = append(parts, GeminiPart{ - Text: block.Thinking, - Thought: true, - ThoughtSignature: dummyThoughtSignature, - }) + part := GeminiPart{ + Text: block.Thinking, + Thought: true, + } + // 保留原有 signature(Claude 模型需要有效的 signature) + if block.Signature != "" { + part.ThoughtSignature = block.Signature + } else if !allowDummyThought { + // Claude 模型需要有效 signature,跳过无 signature 的 thinking block + log.Printf("Warning: skipping thinking block without signature for Claude model") continue + } else { + // Gemini 模型使用 dummy signature + part.ThoughtSignature = dummyThoughtSignature } - - // Claude 模型:仅在提供有效 signature 时保留 thinking block;否则跳过以避免上游校验失败。 - signature := strings.TrimSpace(block.Signature) - if signature == "" || signature == dummyThoughtSignature { - log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)") - continue - } - if !isValidThoughtSignature(signature) { - log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature)) - } - parts = append(parts, GeminiPart{ - Text: block.Thinking, - Thought: true, - ThoughtSignature: signature, - }) + parts = append(parts, part) case "image": if block.Source != nil && block.Source.Type == "base64" { @@ -287,9 +239,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ID: block.ID, }, } - // 只有 Gemini 模型使用 dummy signature - // Claude 模型不设置 signature(避免验证问题) - if allowDummyThought { + // 保留原有 signature,或对 Gemini 模型使用 dummy signature + if block.Signature != "" { + part.ThoughtSignature = block.Signature + } else if allowDummyThought { part.ThoughtSignature = dummyThoughtSignature } parts = append(parts, part) @@ -433,9 +386,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { // 普通工具 var funcDecls []GeminiFunctionDecl - for i, tool := range tools { + for _, tool := range tools { // 跳过无效工具名称 - if strings.TrimSpace(tool.Name) == "" { + if tool.Name == "" { log.Printf("Warning: skipping tool with empty name") continue } @@ -444,18 +397,10 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { var inputSchema map[string]any // 检查是否为 custom 类型工具 (MCP) - if tool.Type == "custom" { - if tool.Custom == nil || tool.Custom.InputSchema == nil { - log.Printf("[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema", tool.Name) - continue - } + if tool.Type == "custom" && tool.Custom != nil { + // Custom 格式: 从 custom 字段获取 description 和 input_schema description = tool.Custom.Description inputSchema = tool.Custom.InputSchema - - // 调试日志:记录 custom 工具的 schema - if schemaJSON, err := json.Marshal(inputSchema); err == nil { - log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON)) - } } else { // 标准格式: 从顶层字段获取 description = tool.Description @@ -464,6 +409,7 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { // 清理 JSON Schema params := cleanJSONSchema(inputSchema) + // 为 nil schema 提供默认值 if params == nil { params = map[string]any{ @@ -472,11 +418,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { } } - // 调试日志:记录清理后的 schema - if paramsJSON, err := json.Marshal(params); err == nil { - log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON)) - } - funcDecls = append(funcDecls, GeminiFunctionDecl{ Name: tool.Name, Description: description, @@ -538,64 +479,31 @@ func cleanJSONSchema(schema map[string]any) map[string]any { } // excludedSchemaKeys 不支持的 schema 字段 -// 基于 Claude API (Vertex AI) 的实际支持情况 -// 支持: type, description, enum, properties, required, additionalProperties, items -// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段 var excludedSchemaKeys = map[string]bool{ - // 元 schema 字段 - "$schema": true, - "$id": true, - "$ref": true, - - // 字符串验证(Gemini 不支持) - "minLength": true, - "maxLength": true, - "pattern": true, - - // 数字验证(Claude API 通过 Vertex AI 不支持这些字段) - "minimum": true, - "maximum": true, - "exclusiveMinimum": true, - "exclusiveMaximum": true, - "multipleOf": true, - - // 数组验证(Claude API 通过 Vertex AI 不支持这些字段) - "uniqueItems": true, - "minItems": true, - "maxItems": true, - - // 组合 schema(Gemini 不支持) - "oneOf": true, - "anyOf": true, - "allOf": true, - "not": true, - "if": true, - "then": true, - "else": true, - "$defs": true, - "definitions": true, - - // 对象验证(仅保留 properties/required/additionalProperties) - "minProperties": true, - "maxProperties": true, - "patternProperties": true, - "propertyNames": true, - "dependencies": true, - "dependentSchemas": true, - "dependentRequired": true, - - // 其他不支持的字段 - "default": true, - "const": true, - "examples": true, - "deprecated": true, - "readOnly": true, - "writeOnly": true, - "contentMediaType": true, - "contentEncoding": true, - - // Claude 特有字段 - "strict": true, + "$schema": true, + "$id": true, + "$ref": true, + "additionalProperties": true, + "minLength": true, + "maxLength": true, + "minItems": true, + "maxItems": true, + "uniqueItems": true, + "minimum": true, + "maximum": true, + "exclusiveMinimum": true, + "exclusiveMaximum": true, + "pattern": true, + "format": true, + "default": true, + "strict": true, + "const": true, + "examples": true, + "deprecated": true, + "readOnly": true, + "writeOnly": true, + "contentMediaType": true, + "contentEncoding": true, } // cleanSchemaValue 递归清理 schema 值 @@ -615,31 +523,6 @@ func cleanSchemaValue(value any) any { continue } - // 特殊处理 format 字段:只保留 Gemini 支持的 format 值 - if k == "format" { - if formatStr, ok := val.(string); ok { - // Gemini 只支持 date-time, date, time - if formatStr == "date-time" || formatStr == "date" || formatStr == "time" { - result[k] = val - } - // 其他 format 值直接跳过 - } - continue - } - - // 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象 - if k == "additionalProperties" { - if boolVal, ok := val.(bool); ok { - result[k] = boolVal - log.Printf("[Debug] additionalProperties is bool: %v", boolVal) - } else { - // 如果是 schema 对象,转换为 false(更安全的默认值) - result[k] = false - log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val) - } - continue - } - // 递归清理所有值 result[k] = cleanSchemaValue(val) } diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go deleted file mode 100644 index 56eebad0..00000000 --- a/backend/internal/pkg/antigravity/request_transformer_test.go +++ /dev/null @@ -1,179 +0,0 @@ -package antigravity - -import ( - "encoding/json" - "testing" -) - -// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理 -func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) { - tests := []struct { - name string - content string - allowDummyThought bool - expectedParts int - description string - }{ - { - name: "Claude model - skip thinking block without signature", - content: `[ - {"type": "text", "text": "Hello"}, - {"type": "thinking", "thinking": "Let me think...", "signature": ""}, - {"type": "text", "text": "World"} - ]`, - allowDummyThought: false, - expectedParts: 2, // 只有两个text block - description: "Claude模型应该跳过无signature的thinking block", - }, - { - name: "Claude model - keep thinking block with signature", - content: `[ - {"type": "text", "text": "Hello"}, - {"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"}, - {"type": "text", "text": "World"} - ]`, - allowDummyThought: false, - expectedParts: 3, // 三个block都保留 - description: "Claude模型应该保留有signature的thinking block", - }, - { - name: "Gemini model - use dummy signature", - content: `[ - {"type": "text", "text": "Hello"}, - {"type": "thinking", "thinking": "Let me think...", "signature": ""}, - {"type": "text", "text": "World"} - ]`, - allowDummyThought: true, - expectedParts: 3, // 三个block都保留,thinking使用dummy signature - description: "Gemini模型应该为无signature的thinking block使用dummy signature", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - toolIDToName := make(map[string]string) - parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought) - - if err != nil { - t.Fatalf("buildParts() error = %v", err) - } - - if len(parts) != tt.expectedParts { - t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts) - } - }) - } -} - -// TestBuildTools_CustomTypeTools 测试custom类型工具转换 -func TestBuildTools_CustomTypeTools(t *testing.T) { - tests := []struct { - name string - tools []ClaudeTool - expectedLen int - description string - }{ - { - name: "Standard tool format", - tools: []ClaudeTool{ - { - Name: "get_weather", - Description: "Get weather information", - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "location": map[string]any{"type": "string"}, - }, - }, - }, - }, - expectedLen: 1, - description: "标准工具格式应该正常转换", - }, - { - name: "Custom type tool (MCP format)", - tools: []ClaudeTool{ - { - Type: "custom", - Name: "mcp_tool", - Custom: &ClaudeCustomToolSpec{ - Description: "MCP tool description", - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{ - "param": map[string]any{"type": "string"}, - }, - }, - }, - }, - }, - expectedLen: 1, - description: "Custom类型工具应该从Custom字段读取description和input_schema", - }, - { - name: "Mixed standard and custom tools", - tools: []ClaudeTool{ - { - Name: "standard_tool", - Description: "Standard tool", - InputSchema: map[string]any{"type": "object"}, - }, - { - Type: "custom", - Name: "custom_tool", - Custom: &ClaudeCustomToolSpec{ - Description: "Custom tool", - InputSchema: map[string]any{"type": "object"}, - }, - }, - }, - expectedLen: 1, // 返回一个GeminiToolDeclaration,包含2个function declarations - description: "混合标准和custom工具应该都能正确转换", - }, - { - name: "Invalid custom tool - nil Custom field", - tools: []ClaudeTool{ - { - Type: "custom", - Name: "invalid_custom", - // Custom 为 nil - }, - }, - expectedLen: 0, // 应该被跳过 - description: "Custom字段为nil的custom工具应该被跳过", - }, - { - name: "Invalid custom tool - nil InputSchema", - tools: []ClaudeTool{ - { - Type: "custom", - Name: "invalid_custom", - Custom: &ClaudeCustomToolSpec{ - Description: "Invalid", - // InputSchema 为 nil - }, - }, - }, - expectedLen: 0, // 应该被跳过 - description: "InputSchema为nil的custom工具应该被跳过", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := buildTools(tt.tools) - - if len(result) != tt.expectedLen { - t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen) - } - - // 验证function declarations存在 - if len(result) > 0 && result[0].FunctionDeclarations != nil { - if len(result[0].FunctionDeclarations) != len(tt.tools) { - t.Errorf("%s: got %d function declarations, want %d", - tt.description, len(result[0].FunctionDeclarations), len(tt.tools)) - } - } - }) - } -} diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index 0db3ed4a..97ad6c83 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -16,12 +16,6 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav // HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta) const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking -// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth) -const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming - -// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code) -const ApiKeyHaikuBetaHeader = BetaInterleavedThinking - // Claude Code 客户端默认请求头 var DefaultHeaders = map[string]string{ "User-Agent": "claude-cli/2.0.62 (external, cli)", diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index 35296497..9205230b 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -2,9 +2,7 @@ package repository import ( "context" - "errors" "fmt" - "strconv" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/redis/go-redis/v9" @@ -29,8 +27,6 @@ const ( userSlotKeyPrefix = "concurrency:user:" // 等待队列计数器格式: concurrency:wait:{userID} waitQueueKeyPrefix = "concurrency:wait:" - // 账号级等待队列计数器格式: wait:account:{accountID} - accountWaitKeyPrefix = "wait:account:" // 默认槽位过期时间(分钟),可通过配置覆盖 defaultSlotTTLMinutes = 15 @@ -116,112 +112,33 @@ var ( redis.call('EXPIRE', KEYS[1], ARGV[2]) end - return 1 - `) - - // incrementAccountWaitScript - account-level wait queue count - incrementAccountWaitScript = redis.NewScript(` - local current = redis.call('GET', KEYS[1]) - if current == false then - current = 0 - else - current = tonumber(current) - end - - if current >= tonumber(ARGV[1]) then - return 0 - end - - local newVal = redis.call('INCR', KEYS[1]) - - -- Only set TTL on first creation to avoid refreshing zombie data - if newVal == 1 then - redis.call('EXPIRE', KEYS[1], ARGV[2]) - end - - return 1 - `) + return 1 + `) // decrementWaitScript - same as before decrementWaitScript = redis.NewScript(` - local current = redis.call('GET', KEYS[1]) - if current ~= false and tonumber(current) > 0 then - redis.call('DECR', KEYS[1]) - end - return 1 - `) - - // getAccountsLoadBatchScript - batch load query (read-only) - // ARGV[1] = slot TTL (seconds, retained for compatibility) - // ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ... - getAccountsLoadBatchScript = redis.NewScript(` - local result = {} - - local i = 2 - while i <= #ARGV do - local accountID = ARGV[i] - local maxConcurrency = tonumber(ARGV[i + 1]) - - local slotKey = 'concurrency:account:' .. accountID - local currentConcurrency = redis.call('ZCARD', slotKey) - - local waitKey = 'wait:account:' .. accountID - local waitingCount = redis.call('GET', waitKey) - if waitingCount == false then - waitingCount = 0 - else - waitingCount = tonumber(waitingCount) - end - - local loadRate = 0 - if maxConcurrency > 0 then - loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency) - end - - table.insert(result, accountID) - table.insert(result, currentConcurrency) - table.insert(result, waitingCount) - table.insert(result, loadRate) - - i = i + 2 - end - - return result - `) - - // cleanupExpiredSlotsScript - remove expired slots - // KEYS[1] = concurrency:account:{accountID} - // ARGV[1] = TTL (seconds) - cleanupExpiredSlotsScript = redis.NewScript(` - local key = KEYS[1] - local ttl = tonumber(ARGV[1]) - local timeResult = redis.call('TIME') - local now = tonumber(timeResult[1]) - local expireBefore = now - ttl - return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) - `) + local current = redis.call('GET', KEYS[1]) + if current ~= false and tonumber(current) > 0 then + redis.call('DECR', KEYS[1]) + end + return 1 + `) ) type concurrencyCache struct { - rdb *redis.Client - slotTTLSeconds int // 槽位过期时间(秒) - waitQueueTTLSeconds int // 等待队列过期时间(秒) + rdb *redis.Client + slotTTLSeconds int // 槽位过期时间(秒) } // NewConcurrencyCache 创建并发控制缓存 // slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟 -// waitQueueTTLSeconds: 等待队列过期时间(秒),0 或负数使用 slot TTL -func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache { +func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.ConcurrencyCache { if slotTTLMinutes <= 0 { slotTTLMinutes = defaultSlotTTLMinutes } - if waitQueueTTLSeconds <= 0 { - waitQueueTTLSeconds = slotTTLMinutes * 60 - } return &concurrencyCache{ - rdb: rdb, - slotTTLSeconds: slotTTLMinutes * 60, - waitQueueTTLSeconds: waitQueueTTLSeconds, + rdb: rdb, + slotTTLSeconds: slotTTLMinutes * 60, } } @@ -238,10 +155,6 @@ func waitQueueKey(userID int64) string { return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) } -func accountWaitKey(accountID int64) string { - return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) -} - // Account slot operations func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { @@ -312,75 +225,3 @@ func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() return err } - -// Account wait queue operations - -func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { - key := accountWaitKey(accountID) - result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int() - if err != nil { - return false, err - } - return result == 1, nil -} - -func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { - key := accountWaitKey(accountID) - _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() - return err -} - -func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { - key := accountWaitKey(accountID) - val, err := c.rdb.Get(ctx, key).Int() - if err != nil && !errors.Is(err, redis.Nil) { - return 0, err - } - if errors.Is(err, redis.Nil) { - return 0, nil - } - return val, nil -} - -func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { - if len(accounts) == 0 { - return map[int64]*service.AccountLoadInfo{}, nil - } - - args := []any{c.slotTTLSeconds} - for _, acc := range accounts { - args = append(args, acc.ID, acc.MaxConcurrency) - } - - result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice() - if err != nil { - return nil, err - } - - loadMap := make(map[int64]*service.AccountLoadInfo) - for i := 0; i < len(result); i += 4 { - if i+3 >= len(result) { - break - } - - accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64) - currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1])) - waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2])) - loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3])) - - loadMap[accountID] = &service.AccountLoadInfo{ - AccountID: accountID, - CurrentConcurrency: currentConcurrency, - WaitingCount: waitingCount, - LoadRate: loadRate, - } - } - - return loadMap, nil -} - -func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { - key := accountSlotKey(accountID) - _, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result() - return err -} diff --git a/backend/internal/repository/concurrency_cache_benchmark_test.go b/backend/internal/repository/concurrency_cache_benchmark_test.go index 25697ab1..cafab9cb 100644 --- a/backend/internal/repository/concurrency_cache_benchmark_test.go +++ b/backend/internal/repository/concurrency_cache_benchmark_test.go @@ -22,7 +22,7 @@ func BenchmarkAccountConcurrency(b *testing.B) { _ = rdb.Close() }() - cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache) + cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache) ctx := context.Background() for _, size := range []int{10, 100, 1000} { diff --git a/backend/internal/repository/concurrency_cache_integration_test.go b/backend/internal/repository/concurrency_cache_integration_test.go index 5983c832..6a7c83f4 100644 --- a/backend/internal/repository/concurrency_cache_integration_test.go +++ b/backend/internal/repository/concurrency_cache_integration_test.go @@ -27,7 +27,7 @@ type ConcurrencyCacheSuite struct { func (s *ConcurrencyCacheSuite) SetupTest() { s.IntegrationRedisSuite.SetupTest() - s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds())) + s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes) } func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() { @@ -218,48 +218,6 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() { require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count") } -func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() { - accountID := int64(30) - waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) - - ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2) - require.NoError(s.T(), err, "IncrementAccountWaitCount 1") - require.True(s.T(), ok) - - ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2) - require.NoError(s.T(), err, "IncrementAccountWaitCount 2") - require.True(s.T(), ok) - - ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2) - require.NoError(s.T(), err, "IncrementAccountWaitCount 3") - require.False(s.T(), ok, "expected account wait increment over max to fail") - - ttl, err := s.rdb.TTL(s.ctx, waitKey).Result() - require.NoError(s.T(), err, "TTL account waitKey") - s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL) - - require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount") - - val, err := s.rdb.Get(s.ctx, waitKey).Int() - if !errors.Is(err, redis.Nil) { - require.NoError(s.T(), err, "Get waitKey") - } - require.Equal(s.T(), 1, val, "expected account wait count 1") -} - -func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() { - accountID := int64(301) - waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) - - require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key") - - val, err := s.rdb.Get(s.ctx, waitKey).Int() - if !errors.Is(err, redis.Nil) { - require.NoError(s.T(), err, "Get waitKey") - } - require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty") -} - func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() { // When no slots exist, GetAccountConcurrency should return 0 cur, err := s.cache.GetAccountConcurrency(s.ctx, 999) @@ -274,139 +232,6 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() { require.Equal(s.T(), 0, cur) } -func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() { - s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI") - // Setup: Create accounts with different load states - account1 := int64(100) - account2 := int64(101) - account3 := int64(102) - - // Account 1: 2/3 slots used, 1 waiting - ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1") - require.NoError(s.T(), err) - require.True(s.T(), ok) - ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2") - require.NoError(s.T(), err) - require.True(s.T(), ok) - ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5) - require.NoError(s.T(), err) - require.True(s.T(), ok) - - // Account 2: 1/2 slots used, 0 waiting - ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3") - require.NoError(s.T(), err) - require.True(s.T(), ok) - - // Account 3: 0/1 slots used, 0 waiting (idle) - - // Query batch load - accounts := []service.AccountWithConcurrency{ - {ID: account1, MaxConcurrency: 3}, - {ID: account2, MaxConcurrency: 2}, - {ID: account3, MaxConcurrency: 1}, - } - - loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts) - require.NoError(s.T(), err) - require.Len(s.T(), loadMap, 3) - - // Verify account1: (2 + 1) / 3 = 100% - load1 := loadMap[account1] - require.NotNil(s.T(), load1) - require.Equal(s.T(), account1, load1.AccountID) - require.Equal(s.T(), 2, load1.CurrentConcurrency) - require.Equal(s.T(), 1, load1.WaitingCount) - require.Equal(s.T(), 100, load1.LoadRate) - - // Verify account2: (1 + 0) / 2 = 50% - load2 := loadMap[account2] - require.NotNil(s.T(), load2) - require.Equal(s.T(), account2, load2.AccountID) - require.Equal(s.T(), 1, load2.CurrentConcurrency) - require.Equal(s.T(), 0, load2.WaitingCount) - require.Equal(s.T(), 50, load2.LoadRate) - - // Verify account3: (0 + 0) / 1 = 0% - load3 := loadMap[account3] - require.NotNil(s.T(), load3) - require.Equal(s.T(), account3, load3.AccountID) - require.Equal(s.T(), 0, load3.CurrentConcurrency) - require.Equal(s.T(), 0, load3.WaitingCount) - require.Equal(s.T(), 0, load3.LoadRate) -} - -func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() { - // Test with empty account list - loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{}) - require.NoError(s.T(), err) - require.Empty(s.T(), loadMap) -} - -func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() { - accountID := int64(200) - slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID) - - // Acquire 3 slots - ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1") - require.NoError(s.T(), err) - require.True(s.T(), ok) - ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2") - require.NoError(s.T(), err) - require.True(s.T(), ok) - ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3") - require.NoError(s.T(), err) - require.True(s.T(), ok) - - // Verify 3 slots exist - cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID) - require.NoError(s.T(), err) - require.Equal(s.T(), 3, cur) - - // Manually set old timestamps for req1 and req2 (simulate expired slots) - now := time.Now().Unix() - expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL - err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err() - require.NoError(s.T(), err) - err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err() - require.NoError(s.T(), err) - - // Run cleanup - err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID) - require.NoError(s.T(), err) - - // Verify only 1 slot remains (req3) - cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID) - require.NoError(s.T(), err) - require.Equal(s.T(), 1, cur) - - // Verify req3 still exists - members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result() - require.NoError(s.T(), err) - require.Len(s.T(), members, 1) - require.Equal(s.T(), "req3", members[0]) -} - -func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() { - accountID := int64(201) - - // Acquire 2 fresh slots - ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1") - require.NoError(s.T(), err) - require.True(s.T(), ok) - ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2") - require.NoError(s.T(), err) - require.True(s.T(), ok) - - // Run cleanup (should not remove anything) - err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID) - require.NoError(s.T(), err) - - // Verify both slots still exist - cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID) - require.NoError(s.T(), err) - require.Equal(s.T(), 2, cur) -} - func TestConcurrencyCacheSuite(t *testing.T) { suite.Run(t, new(ConcurrencyCacheSuite)) } diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 0d579b23..2de2d1de 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -15,14 +15,7 @@ import ( // ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数 // 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景 func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache { - waitTTLSeconds := int(cfg.Gateway.Scheduling.StickySessionWaitTimeout.Seconds()) - if cfg.Gateway.Scheduling.FallbackWaitTimeout > cfg.Gateway.Scheduling.StickySessionWaitTimeout { - waitTTLSeconds = int(cfg.Gateway.Scheduling.FallbackWaitTimeout.Seconds()) - } - if waitTTLSeconds <= 0 { - waitTTLSeconds = cfg.Gateway.ConcurrencySlotTTLMinutes * 60 - } - return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds) + return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes) } // ProviderSet is the Wire provider set for all repositories diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 5b3bf565..ae2976f8 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -358,15 +358,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return nil, fmt.Errorf("transform request: %w", err) } - // 调试:记录转换后的请求体(仅记录前 2000 字符) - if bodyJSON, err := json.Marshal(geminiBody); err == nil { - truncated := string(bodyJSON) - if len(truncated) > 2000 { - truncated = truncated[:2000] + "..." - } - log.Printf("[Debug] Transformed Gemini request: %s", truncated) - } - // 构建上游 action action := "generateContent" if claudeReq.Stream { diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index 65ef16db..b5229491 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -18,11 +18,6 @@ type ConcurrencyCache interface { ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) - // 账号等待队列(账号级) - IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) - DecrementAccountWaitCount(ctx context.Context, accountID int64) error - GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) - // 用户槽位管理 // 键格式: concurrency:user:{userID}(有序集合,成员为 requestID) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) @@ -32,12 +27,6 @@ type ConcurrencyCache interface { // 等待队列计数(只在首次创建时设置 TTL) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) DecrementWaitCount(ctx context.Context, userID int64) error - - // 批量负载查询(只读) - GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) - - // 清理过期槽位(后台任务) - CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error } // generateRequestID generates a unique request ID for concurrency slot tracking @@ -72,18 +61,6 @@ type AcquireResult struct { ReleaseFunc func() // Must be called when done (typically via defer) } -type AccountWithConcurrency struct { - ID int64 - MaxConcurrency int -} - -type AccountLoadInfo struct { - AccountID int64 - CurrentConcurrency int - WaitingCount int - LoadRate int // 0-100+ (percent) -} - // AcquireAccountSlot attempts to acquire a concurrency slot for an account. // If the account is at max concurrency, it waits until a slot is available or timeout. // Returns a release function that MUST be called when the request completes. @@ -200,42 +177,6 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6 } } -// IncrementAccountWaitCount increments the wait queue counter for an account. -func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { - if s.cache == nil { - return true, nil - } - - result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait) - if err != nil { - log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err) - return true, nil - } - return result, nil -} - -// DecrementAccountWaitCount decrements the wait queue counter for an account. -func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) { - if s.cache == nil { - return - } - - bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil { - log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err) - } -} - -// GetAccountWaitingCount gets current wait queue count for an account. -func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { - if s.cache == nil { - return 0, nil - } - return s.cache.GetAccountWaitingCount(ctx, accountID) -} - // CalculateMaxWait calculates the maximum wait queue size for a user // maxWait = userConcurrency + defaultExtraWaitSlots func CalculateMaxWait(userConcurrency int) int { @@ -245,57 +186,6 @@ func CalculateMaxWait(userConcurrency int) int { return userConcurrency + defaultExtraWaitSlots } -// GetAccountsLoadBatch returns load info for multiple accounts. -func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { - if s.cache == nil { - return map[int64]*AccountLoadInfo{}, nil - } - return s.cache.GetAccountsLoadBatch(ctx, accounts) -} - -// CleanupExpiredAccountSlots removes expired slots for one account (background task). -func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { - if s.cache == nil { - return nil - } - return s.cache.CleanupExpiredAccountSlots(ctx, accountID) -} - -// StartSlotCleanupWorker starts a background cleanup worker for expired account slots. -func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) { - if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 { - return - } - - runCleanup := func() { - listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - accounts, err := accountRepo.ListSchedulable(listCtx) - cancel() - if err != nil { - log.Printf("Warning: list schedulable accounts failed: %v", err) - return - } - for _, account := range accounts { - accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second) - err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID) - accountCancel() - if err != nil { - log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err) - } - } - } - - go func() { - ticker := time.NewTicker(interval) - defer ticker.Stop() - - runCleanup() - for range ticker.C { - runCleanup() - } - }() -} - // GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts // Returns a map of accountID -> current concurrency count func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 560c7767..d779bcfa 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -261,34 +261,6 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户") } -func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(t *testing.T) { - ctx := context.Background() - - repo := &mockAccountRepoForPlatform{ - accounts: []Account{ - {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey}, - {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, - }, - accountsByID: map[int64]*Account{}, - } - for i := range repo.accounts { - repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] - } - - cache := &mockGatewayCacheForPlatform{} - - svc := &GatewayService{ - accountRepo: repo, - cache: cache, - cfg: testConfig(), - } - - acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini) - require.NoError(t, err) - require.NotNil(t, acc) - require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户") -} - // TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户 func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) { ctx := context.Background() @@ -604,32 +576,6 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) { func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { ctx := context.Background() - t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) { - repo := &mockAccountRepoForPlatform{ - accounts: []Account{ - {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey}, - {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth}, - }, - accountsByID: map[int64]*Account{}, - } - for i := range repo.accounts { - repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] - } - - cache := &mockGatewayCacheForPlatform{} - - svc := &GatewayService{ - accountRepo: repo, - cache: cache, - cfg: testConfig(), - } - - acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini) - require.NoError(t, err) - require.NotNil(t, acc) - require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户") - }) - t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) { repo := &mockAccountRepoForPlatform{ accounts: []Account{ @@ -837,160 +783,3 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) { }) } } - -// mockConcurrencyService for testing -type mockConcurrencyService struct { - accountLoads map[int64]*AccountLoadInfo - accountWaitCounts map[int64]int - acquireResults map[int64]bool -} - -func (m *mockConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { - if m.accountLoads == nil { - return map[int64]*AccountLoadInfo{}, nil - } - result := make(map[int64]*AccountLoadInfo) - for _, acc := range accounts { - if load, ok := m.accountLoads[acc.ID]; ok { - result[acc.ID] = load - } else { - result[acc.ID] = &AccountLoadInfo{ - AccountID: acc.ID, - CurrentConcurrency: 0, - WaitingCount: 0, - LoadRate: 0, - } - } - } - return result, nil -} - -func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { - if m.accountWaitCounts == nil { - return 0, nil - } - return m.accountWaitCounts[accountID], nil -} - -// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection -func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { - ctx := context.Background() - - t.Run("禁用负载批量查询-降级到传统选择", func(t *testing.T) { - repo := &mockAccountRepoForPlatform{ - accounts: []Account{ - {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, - {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, - }, - accountsByID: map[int64]*Account{}, - } - for i := range repo.accounts { - repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] - } - - cache := &mockGatewayCacheForPlatform{} - - cfg := testConfig() - cfg.Gateway.Scheduling.LoadBatchEnabled = false - - svc := &GatewayService{ - accountRepo: repo, - cache: cache, - cfg: cfg, - concurrencyService: nil, // No concurrency service - } - - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) - require.NoError(t, err) - require.NotNil(t, result) - require.NotNil(t, result.Account) - require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号") - }) - - t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) { - repo := &mockAccountRepoForPlatform{ - accounts: []Account{ - {ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, - {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, - }, - accountsByID: map[int64]*Account{}, - } - for i := range repo.accounts { - repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] - } - - cache := &mockGatewayCacheForPlatform{} - - cfg := testConfig() - cfg.Gateway.Scheduling.LoadBatchEnabled = true - - svc := &GatewayService{ - accountRepo: repo, - cache: cache, - cfg: cfg, - concurrencyService: nil, - } - - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) - require.NoError(t, err) - require.NotNil(t, result) - require.NotNil(t, result.Account) - require.Equal(t, int64(2), result.Account.ID, "应选择优先级最高的账号") - }) - - t.Run("排除账号-不选择被排除的账号", func(t *testing.T) { - repo := &mockAccountRepoForPlatform{ - accounts: []Account{ - {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, - {ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, - }, - accountsByID: map[int64]*Account{}, - } - for i := range repo.accounts { - repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] - } - - cache := &mockGatewayCacheForPlatform{} - - cfg := testConfig() - cfg.Gateway.Scheduling.LoadBatchEnabled = false - - svc := &GatewayService{ - accountRepo: repo, - cache: cache, - cfg: cfg, - concurrencyService: nil, - } - - excludedIDs := map[int64]struct{}{1: {}} - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs) - require.NoError(t, err) - require.NotNil(t, result) - require.NotNil(t, result.Account) - require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号") - }) - - t.Run("无可用账号-返回错误", func(t *testing.T) { - repo := &mockAccountRepoForPlatform{ - accounts: []Account{}, - accountsByID: map[int64]*Account{}, - } - - cache := &mockGatewayCacheForPlatform{} - - cfg := testConfig() - cfg.Gateway.Scheduling.LoadBatchEnabled = false - - svc := &GatewayService{ - accountRepo: repo, - cache: cache, - cfg: cfg, - concurrencyService: nil, - } - - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) - require.Error(t, err) - require.Nil(t, result) - require.Contains(t, err.Error(), "no available accounts") - }) -} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index cb60131b..d542e9c2 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -13,14 +13,12 @@ import ( "log" "net/http" "regexp" - "sort" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" - "github.com/tidwall/gjson" "github.com/tidwall/sjson" "github.com/gin-gonic/gin" @@ -68,20 +66,6 @@ type GatewayCache interface { RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error } -type AccountWaitPlan struct { - AccountID int64 - MaxConcurrency int - Timeout time.Duration - MaxWaiting int -} - -type AccountSelectionResult struct { - Account *Account - Acquired bool - ReleaseFunc func() - WaitPlan *AccountWaitPlan // nil means no wait allowed -} - // ClaudeUsage 表示Claude API返回的usage信息 type ClaudeUsage struct { InputTokens int `json:"input_tokens"` @@ -124,7 +108,6 @@ type GatewayService struct { identityService *IdentityService httpUpstream HTTPUpstream deferredService *DeferredService - concurrencyService *ConcurrencyService } // NewGatewayService creates a new GatewayService @@ -136,7 +119,6 @@ func NewGatewayService( userSubRepo UserSubscriptionRepository, cache GatewayCache, cfg *config.Config, - concurrencyService *ConcurrencyService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, @@ -152,7 +134,6 @@ func NewGatewayService( userSubRepo: userSubRepo, cache: cache, cfg: cfg, - concurrencyService: concurrencyService, billingService: billingService, rateLimitService: rateLimitService, billingCacheService: billingCacheService, @@ -202,14 +183,6 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { return "" } -// BindStickySession sets session -> account binding with standard TTL. -func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { - if sessionHash == "" || accountID <= 0 { - return nil - } - return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL) -} - func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { if parsed == nil { return "" @@ -359,354 +332,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) } -// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. -func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { - cfg := s.schedulingConfig() - var stickyAccountID int64 - if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil { - stickyAccountID = accountID - } - } - if s.concurrencyService == nil || !cfg.LoadBatchEnabled { - account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) - if err != nil { - return nil, err - } - result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) - if err == nil && result.Acquired { - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) - if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil - } - } - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil - } - - platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID) - if err != nil { - return nil, err - } - preferOAuth := platform == PlatformGemini - - accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) - if err != nil { - return nil, err - } - if len(accounts) == 0 { - return nil, errors.New("no available accounts") - } - - isExcluded := func(accountID int64) bool { - if excludedIDs == nil { - return false - } - _, excluded := excludedIDs[accountID] - return excluded - } - - // ============ Layer 1: 粘性会话优先 ============ - if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) - if err == nil && accountID > 0 && !isExcluded(accountID) { - account, err := s.accountRepo.GetByID(ctx, accountID) - if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) && - account.IsSchedulable() && - (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { - result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) - if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL) - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) - if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil - } - } - } - } - - // ============ Layer 2: 负载感知选择 ============ - candidates := make([]*Account, 0, len(accounts)) - for i := range accounts { - acc := &accounts[i] - if isExcluded(acc.ID) { - continue - } - if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { - continue - } - if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { - continue - } - candidates = append(candidates, acc) - } - - if len(candidates) == 0 { - return nil, errors.New("no available accounts") - } - - accountLoads := make([]AccountWithConcurrency, 0, len(candidates)) - for _, acc := range candidates { - accountLoads = append(accountLoads, AccountWithConcurrency{ - ID: acc.ID, - MaxConcurrency: acc.Concurrency, - }) - } - - loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) - if err != nil { - if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok { - return result, nil - } - } else { - type accountWithLoad struct { - account *Account - loadInfo *AccountLoadInfo - } - var available []accountWithLoad - for _, acc := range candidates { - loadInfo := loadMap[acc.ID] - if loadInfo == nil { - loadInfo = &AccountLoadInfo{AccountID: acc.ID} - } - if loadInfo.LoadRate < 100 { - available = append(available, accountWithLoad{ - account: acc, - loadInfo: loadInfo, - }) - } - } - - if len(available) > 0 { - sort.SliceStable(available, func(i, j int) bool { - a, b := available[i], available[j] - if a.account.Priority != b.account.Priority { - return a.account.Priority < b.account.Priority - } - if a.loadInfo.LoadRate != b.loadInfo.LoadRate { - return a.loadInfo.LoadRate < b.loadInfo.LoadRate - } - switch { - case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: - return true - case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: - return false - case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: - if preferOAuth && a.account.Type != b.account.Type { - return a.account.Type == AccountTypeOAuth - } - return false - default: - return a.account.LastUsedAt.Before(*b.account.LastUsedAt) - } - }) - - for _, item := range available { - result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) - if err == nil && result.Acquired { - if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL) - } - return &AccountSelectionResult{ - Account: item.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - } - } - } - - // ============ Layer 3: 兜底排队 ============ - sortAccountsByPriorityAndLastUsed(candidates, preferOAuth) - for _, acc := range candidates { - return &AccountSelectionResult{ - Account: acc, - WaitPlan: &AccountWaitPlan{ - AccountID: acc.ID, - MaxConcurrency: acc.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil - } - return nil, errors.New("no available accounts") -} - -func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { - ordered := append([]*Account(nil), candidates...) - sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) - - for _, acc := range ordered { - result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) - if err == nil && result.Acquired { - if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL) - } - return &AccountSelectionResult{ - Account: acc, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, true - } - } - - return nil, false -} - -func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { - if s.cfg != nil { - return s.cfg.Gateway.Scheduling - } - return config.GatewaySchedulingConfig{ - StickySessionMaxWaiting: 3, - StickySessionWaitTimeout: 45 * time.Second, - FallbackWaitTimeout: 30 * time.Second, - FallbackMaxWaiting: 100, - LoadBatchEnabled: true, - SlotCleanupInterval: 30 * time.Second, - } -} - -func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) { - forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) - if hasForcePlatform && forcePlatform != "" { - return forcePlatform, true, nil - } - if groupID != nil { - group, err := s.groupRepo.GetByID(ctx, *groupID) - if err != nil { - return "", false, fmt.Errorf("get group failed: %w", err) - } - return group.Platform, false, nil - } - return PlatformAnthropic, false, nil -} - -func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { - useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform - if useMixed { - platforms := []string{platform, PlatformAntigravity} - var accounts []Account - var err error - if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) - } else { - accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) - } - if err != nil { - return nil, useMixed, err - } - filtered := make([]Account, 0, len(accounts)) - for _, acc := range accounts { - if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { - continue - } - filtered = append(filtered, acc) - } - return filtered, useMixed, nil - } - - var accounts []Account - var err error - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) - } else if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) - if err == nil && len(accounts) == 0 && hasForcePlatform { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) - } - } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) - } - if err != nil { - return nil, useMixed, err - } - return accounts, useMixed, nil -} - -func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool { - if account == nil { - return false - } - if useMixed { - if account.Platform == platform { - return true - } - return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() - } - return account.Platform == platform -} - -func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { - if s.concurrencyService == nil { - return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil - } - return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) -} - -func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { - sort.SliceStable(accounts, func(i, j int) bool { - a, b := accounts[i], accounts[j] - if a.Priority != b.Priority { - return a.Priority < b.Priority - } - switch { - case a.LastUsedAt == nil && b.LastUsedAt != nil: - return true - case a.LastUsedAt != nil && b.LastUsedAt == nil: - return false - case a.LastUsedAt == nil && b.LastUsedAt == nil: - if preferOAuth && a.Type != b.Type { - return a.Type == AccountTypeOAuth - } - return false - default: - return a.LastUsedAt.Before(*b.LastUsedAt) - } - }) -} - // selectAccountForModelWithPlatform 选择单平台账户(完全隔离) func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { - preferOAuth := platform == PlatformGemini // 1. 查询粘性会话 if sessionHash != "" { accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) @@ -762,9 +389,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, case acc.LastUsedAt != nil && selected.LastUsedAt == nil: // keep selected (never used is preferred) case acc.LastUsedAt == nil && selected.LastUsedAt == nil: - if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { - selected = acc - } + // keep selected (both never used) default: if acc.LastUsedAt.Before(*selected.LastUsedAt) { selected = acc @@ -794,7 +419,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, // 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户 func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) { platforms := []string{nativePlatform, PlatformAntigravity} - preferOAuth := nativePlatform == PlatformGemini // 1. 查询粘性会话 if sessionHash != "" { @@ -854,9 +478,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g case acc.LastUsedAt != nil && selected.LastUsedAt == nil: // keep selected (never used is preferred) case acc.LastUsedAt == nil && selected.LastUsedAt == nil: - if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { - selected = acc - } + // keep selected (both never used) default: if acc.LastUsedAt.Before(*selected.LastUsedAt) { selected = acc @@ -1062,30 +684,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 处理错误响应(不可重试的错误) if resp.StatusCode >= 400 { - // 可选:对部分 400 触发 failover(默认关闭以保持语义) - if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { - respBody, readErr := io.ReadAll(resp.Body) - if readErr != nil { - // ReadAll failed, fall back to normal error handling without consuming the stream - return s.handleErrorResponse(ctx, resp, c, account) - } - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) - - if s.shouldFailoverOn400(respBody) { - if s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( - "Account %d: 400 error, attempting failover: %s", - account.ID, - truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), - ) - } else { - log.Printf("Account %d: 400 error, attempting failover", account.ID) - } - s.handleFailoverSideEffects(ctx, resp, account) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} - } - } return s.handleErrorResponse(ctx, resp, c, account) } @@ -1188,13 +786,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // 处理anthropic-beta header(OAuth账号需要特殊处理) if tokenType == "oauth" { req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) - } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { - // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) - if requestNeedsBetaFeatures(body) { - if beta := defaultApiKeyBetaHeader(body); beta != "" { - req.Header.Set("anthropic-beta", beta) - } - } } return req, nil @@ -1247,83 +838,6 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) return claude.DefaultBetaHeader } -func requestNeedsBetaFeatures(body []byte) bool { - tools := gjson.GetBytes(body, "tools") - if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 { - return true - } - if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") { - return true - } - return false -} - -func defaultApiKeyBetaHeader(body []byte) string { - modelID := gjson.GetBytes(body, "model").String() - if strings.Contains(strings.ToLower(modelID), "haiku") { - return claude.ApiKeyHaikuBetaHeader - } - return claude.ApiKeyBetaHeader -} - -func truncateForLog(b []byte, maxBytes int) string { - if maxBytes <= 0 { - maxBytes = 2048 - } - if len(b) > maxBytes { - b = b[:maxBytes] - } - s := string(b) - // 保持一行,避免污染日志格式 - s = strings.ReplaceAll(s, "\n", "\\n") - s = strings.ReplaceAll(s, "\r", "\\r") - return s -} - -func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { - // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。 - // 默认保守:无法识别则不切换。 - msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) - if msg == "" { - return false - } - - // 缺少/错误的 beta header:换账号/链路可能成功(尤其是混合调度时)。 - // 更精确匹配 beta 相关的兼容性问题,避免误触发切换。 - if strings.Contains(msg, "anthropic-beta") || - strings.Contains(msg, "beta feature") || - strings.Contains(msg, "requires beta") { - return true - } - - // thinking/tool streaming 等兼容性约束(常见于中间转换链路) - if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") { - return true - } - if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") { - return true - } - - return false -} - -func extractUpstreamErrorMessage(body []byte) string { - // Claude 风格:{"type":"error","error":{"type":"...","message":"..."}} - if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" { - inner := strings.TrimSpace(m) - // 有些上游会把完整 JSON 作为字符串塞进 message - if strings.HasPrefix(inner, "{") { - if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" { - return innerMsg - } - } - return m - } - - // 兜底:尝试顶层 message - return gjson.GetBytes(body, "message").String() -} - func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { body, _ := io.ReadAll(resp.Body) @@ -1336,16 +850,6 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res switch resp.StatusCode { case 400: - // 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开 - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( - "Upstream 400 error (account=%d platform=%s type=%s): %s", - account.ID, - account.Platform, - account.Type, - truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), - ) - } c.Data(http.StatusBadRequest, "application/json", body) return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) case 401: @@ -1825,18 +1329,6 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, // 标记账号状态(429/529等) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - // 记录上游错误摘要便于排障(不回显请求内容) - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( - "count_tokens upstream error %d (account=%d platform=%s type=%s): %s", - resp.StatusCode, - account.ID, - account.Platform, - account.Type, - truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), - ) - } - // 返回简化的错误响应 errMsg := "Upstream request failed" switch resp.StatusCode { @@ -1917,13 +1409,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) - } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { - // API-key:与 messages 同步的按需 beta 注入(默认关闭) - if requestNeedsBetaFeatures(body) { - if beta := defaultApiKeyBetaHeader(body); beta != "" { - req.Header.Set("anthropic-beta", beta) - } - } } return req, nil diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index b1877800..a0bf1b6a 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -2278,13 +2278,11 @@ func convertClaudeToolsToGeminiTools(tools any) []any { "properties": map[string]any{}, } } - // 清理 JSON Schema - cleanedParams := cleanToolSchema(params) funcDecls = append(funcDecls, map[string]any{ "name": name, "description": desc, - "parameters": cleanedParams, + "parameters": params, }) } @@ -2298,41 +2296,6 @@ func convertClaudeToolsToGeminiTools(tools any) []any { } } -// cleanToolSchema 清理工具的 JSON Schema,移除 Gemini 不支持的字段 -func cleanToolSchema(schema any) any { - if schema == nil { - return nil - } - - switch v := schema.(type) { - case map[string]any: - cleaned := make(map[string]any) - for key, value := range v { - // 跳过不支持的字段 - if key == "$schema" || key == "$id" || key == "$ref" || - key == "additionalProperties" || key == "minLength" || - key == "maxLength" || key == "minItems" || key == "maxItems" { - continue - } - // 递归清理嵌套对象 - cleaned[key] = cleanToolSchema(value) - } - // 规范化 type 字段为大写 - if typeVal, ok := cleaned["type"].(string); ok { - cleaned["type"] = strings.ToUpper(typeVal) - } - return cleaned - case []any: - cleaned := make([]any, len(v)) - for i, item := range v { - cleaned[i] = cleanToolSchema(item) - } - return cleaned - default: - return v - } -} - func convertClaudeGenerationConfig(req map[string]any) map[string]any { out := make(map[string]any) if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 { diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go deleted file mode 100644 index d49f2eb3..00000000 --- a/backend/internal/service/gemini_messages_compat_service_test.go +++ /dev/null @@ -1,128 +0,0 @@ -package service - -import ( - "testing" -) - -// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换 -func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { - tests := []struct { - name string - tools any - expectedLen int - description string - }{ - { - name: "Standard tools", - tools: []any{ - map[string]any{ - "name": "get_weather", - "description": "Get weather info", - "input_schema": map[string]any{"type": "object"}, - }, - }, - expectedLen: 1, - description: "标准工具格式应该正常转换", - }, - { - name: "Custom type tool (MCP format)", - tools: []any{ - map[string]any{ - "type": "custom", - "name": "mcp_tool", - "custom": map[string]any{ - "description": "MCP tool description", - "input_schema": map[string]any{"type": "object"}, - }, - }, - }, - expectedLen: 1, - description: "Custom类型工具应该从custom字段读取", - }, - { - name: "Mixed standard and custom tools", - tools: []any{ - map[string]any{ - "name": "standard_tool", - "description": "Standard", - "input_schema": map[string]any{"type": "object"}, - }, - map[string]any{ - "type": "custom", - "name": "custom_tool", - "custom": map[string]any{ - "description": "Custom", - "input_schema": map[string]any{"type": "object"}, - }, - }, - }, - expectedLen: 1, - description: "混合工具应该都能正确转换", - }, - { - name: "Custom tool without custom field", - tools: []any{ - map[string]any{ - "type": "custom", - "name": "invalid_custom", - // 缺少 custom 字段 - }, - }, - expectedLen: 0, // 应该被跳过 - description: "缺少custom字段的custom工具应该被跳过", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := convertClaudeToolsToGeminiTools(tt.tools) - - if tt.expectedLen == 0 { - if result != nil { - t.Errorf("%s: expected nil result, got %v", tt.description, result) - } - return - } - - if result == nil { - t.Fatalf("%s: expected non-nil result", tt.description) - } - - if len(result) != 1 { - t.Errorf("%s: expected 1 tool declaration, got %d", tt.description, len(result)) - return - } - - toolDecl, ok := result[0].(map[string]any) - if !ok { - t.Fatalf("%s: result[0] is not map[string]any", tt.description) - } - - funcDecls, ok := toolDecl["functionDeclarations"].([]any) - if !ok { - t.Fatalf("%s: functionDeclarations is not []any", tt.description) - } - - toolsArr, _ := tt.tools.([]any) - expectedFuncCount := 0 - for _, tool := range toolsArr { - toolMap, _ := tool.(map[string]any) - if toolMap["name"] != "" { - // 检查是否为有效的custom工具 - if toolMap["type"] == "custom" { - if toolMap["custom"] != nil { - expectedFuncCount++ - } - } else { - expectedFuncCount++ - } - } - } - - if len(funcDecls) != expectedFuncCount { - t.Errorf("%s: expected %d function declarations, got %d", - tt.description, expectedFuncCount, len(funcDecls)) - } - }) - } -} diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index 221bd0f2..e4bda5f8 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "net/http" - "regexp" "strconv" "strings" "time" @@ -164,45 +163,6 @@ type GeminiTokenInfo struct { Scope string `json:"scope,omitempty"` ProjectID string `json:"project_id,omitempty"` OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio" - TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA -} - -// validateTierID validates tier_id format and length -func validateTierID(tierID string) error { - if tierID == "" { - return nil // Empty is allowed - } - if len(tierID) > 64 { - return fmt.Errorf("tier_id exceeds maximum length of 64 characters") - } - // Allow alphanumeric, underscore, hyphen, and slash (for tier paths) - if !regexp.MustCompile(`^[a-zA-Z0-9_/-]+$`).MatchString(tierID) { - return fmt.Errorf("tier_id contains invalid characters") - } - return nil -} - -// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response -// Prioritizes IsDefault tier, falls back to first non-empty tier -func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string { - tierID := "LEGACY" - // First pass: look for default tier - for _, tier := range allowedTiers { - if tier.IsDefault && strings.TrimSpace(tier.ID) != "" { - tierID = strings.TrimSpace(tier.ID) - break - } - } - // Second pass: if still LEGACY, take first non-empty tier - if tierID == "LEGACY" { - for _, tier := range allowedTiers { - if strings.TrimSpace(tier.ID) != "" { - tierID = strings.TrimSpace(tier.ID) - break - } - } - } - return tierID } func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { @@ -263,14 +223,13 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 projectID := sessionProjectID - var tierID string // 对于 code_assist 模式,project_id 是必需的 // 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API) if oauthType == "code_assist" { if projectID == "" { var err error - projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) + projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) if err != nil { // 记录警告但不阻断流程,允许后续补充 project_id fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err) @@ -289,7 +248,6 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch ExpiresAt: expiresAt, Scope: tokenResp.Scope, ProjectID: projectID, - TierID: tierID, OAuthType: oauthType, }, nil } @@ -399,7 +357,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A // For Code Assist, project_id is required. Auto-detect if missing. // For AI Studio OAuth, project_id is optional and should not block refresh. if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" { - projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) + projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) if err != nil { return nil, fmt.Errorf("failed to auto-detect project_id: %w", err) } @@ -408,7 +366,6 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A return nil, fmt.Errorf("failed to auto-detect project_id: empty result") } tokenInfo.ProjectID = projectID - tokenInfo.TierID = tierID } return tokenInfo, nil @@ -431,13 +388,6 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) if tokenInfo.ProjectID != "" { creds["project_id"] = tokenInfo.ProjectID } - if tokenInfo.TierID != "" { - // Validate tier_id before storing - if err := validateTierID(tokenInfo.TierID); err == nil { - creds["tier_id"] = tokenInfo.TierID - } - // Silently skip invalid tier_id (don't block account creation) - } if tokenInfo.OAuthType != "" { creds["oauth_type"] = tokenInfo.OAuthType } @@ -448,26 +398,34 @@ func (s *GeminiOAuthService) Stop() { s.sessionStore.Stop() } -func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, string, error) { +func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) { if s.codeAssist == nil { - return "", "", errors.New("code assist client not configured") + return "", errors.New("code assist client not configured") } loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil) - - // Extract tierID from response (works whether CloudAICompanionProject is set or not) - tierID := "LEGACY" - if loadResp != nil { - tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers) - } - - // If LoadCodeAssist returned a project, use it if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" { - return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil + return strings.TrimSpace(loadResp.CloudAICompanionProject), nil } // Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID. - // (tierID already extracted above, reuse it) + tierID := "LEGACY" + if loadResp != nil { + for _, tier := range loadResp.AllowedTiers { + if tier.IsDefault && strings.TrimSpace(tier.ID) != "" { + tierID = strings.TrimSpace(tier.ID) + break + } + } + if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" { + for _, tier := range loadResp.AllowedTiers { + if strings.TrimSpace(tier.ID) != "" { + tierID = strings.TrimSpace(tier.ID) + break + } + } + } + } req := &geminicli.OnboardUserRequest{ TierID: tierID, @@ -485,39 +443,39 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr // If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects. fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - return strings.TrimSpace(fallback), tierID, nil + return strings.TrimSpace(fallback), nil } - return "", "", err + return "", err } if resp.Done { if resp.Response != nil && resp.Response.CloudAICompanionProject != nil { switch v := resp.Response.CloudAICompanionProject.(type) { case string: - return strings.TrimSpace(v), tierID, nil + return strings.TrimSpace(v), nil case map[string]any: if id, ok := v["id"].(string); ok { - return strings.TrimSpace(id), tierID, nil + return strings.TrimSpace(id), nil } } } fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - return strings.TrimSpace(fallback), tierID, nil + return strings.TrimSpace(fallback), nil } - return "", "", errors.New("onboardUser completed but no project_id returned") + return "", errors.New("onboardUser completed but no project_id returned") } time.Sleep(2 * time.Second) } fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - return strings.TrimSpace(fallback), tierID, nil + return strings.TrimSpace(fallback), nil } if loadErr != nil { - return "", "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) + return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) } - return "", "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) + return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) } type googleCloudProject struct { diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index 5f369de5..2195ec55 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -112,7 +112,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou } } - detected, tierID, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL) + detected, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL) if err != nil { log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err) return accessToken, nil @@ -123,9 +123,6 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou account.Credentials = make(map[string]any) } account.Credentials["project_id"] = detected - if tierID != "" { - account.Credentials["tier_id"] = tierID - } _ = p.accountRepo.Update(ctx, account) } } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f8eb29bd..84e98679 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -13,7 +13,6 @@ import ( "log" "net/http" "regexp" - "sort" "strconv" "strings" "time" @@ -81,7 +80,6 @@ type OpenAIGatewayService struct { userSubRepo UserSubscriptionRepository cache GatewayCache cfg *config.Config - concurrencyService *ConcurrencyService billingService *BillingService rateLimitService *RateLimitService billingCacheService *BillingCacheService @@ -97,7 +95,6 @@ func NewOpenAIGatewayService( userSubRepo UserSubscriptionRepository, cache GatewayCache, cfg *config.Config, - concurrencyService *ConcurrencyService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, @@ -111,7 +108,6 @@ func NewOpenAIGatewayService( userSubRepo: userSubRepo, cache: cache, cfg: cfg, - concurrencyService: concurrencyService, billingService: billingService, rateLimitService: rateLimitService, billingCacheService: billingCacheService, @@ -130,14 +126,6 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string { return hex.EncodeToString(hash[:]) } -// BindStickySession sets session -> account binding with standard TTL. -func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { - if sessionHash == "" || accountID <= 0 { - return nil - } - return s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, accountID, openaiStickySessionTTL) -} - // SelectAccount selects an OpenAI account with sticky session support func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { return s.SelectAccountForModel(ctx, groupID, sessionHash, "") @@ -230,254 +218,6 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C return selected, nil } -// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. -func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { - cfg := s.schedulingConfig() - var stickyAccountID int64 - if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash); err == nil { - stickyAccountID = accountID - } - } - if s.concurrencyService == nil || !cfg.LoadBatchEnabled { - account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) - if err != nil { - return nil, err - } - result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) - if err == nil && result.Acquired { - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) - if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil - } - } - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil - } - - accounts, err := s.listSchedulableAccounts(ctx, groupID) - if err != nil { - return nil, err - } - if len(accounts) == 0 { - return nil, errors.New("no available accounts") - } - - isExcluded := func(accountID int64) bool { - if excludedIDs == nil { - return false - } - _, excluded := excludedIDs[accountID] - return excluded - } - - // ============ Layer 1: Sticky session ============ - if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash) - if err == nil && accountID > 0 && !isExcluded(accountID) { - account, err := s.accountRepo.GetByID(ctx, accountID) - if err == nil && account.IsSchedulable() && account.IsOpenAI() && - (requestedModel == "" || account.IsModelSupported(requestedModel)) { - result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) - if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL) - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) - if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil - } - } - } - } - - // ============ Layer 2: Load-aware selection ============ - candidates := make([]*Account, 0, len(accounts)) - for i := range accounts { - acc := &accounts[i] - if isExcluded(acc.ID) { - continue - } - if requestedModel != "" && !acc.IsModelSupported(requestedModel) { - continue - } - candidates = append(candidates, acc) - } - - if len(candidates) == 0 { - return nil, errors.New("no available accounts") - } - - accountLoads := make([]AccountWithConcurrency, 0, len(candidates)) - for _, acc := range candidates { - accountLoads = append(accountLoads, AccountWithConcurrency{ - ID: acc.ID, - MaxConcurrency: acc.Concurrency, - }) - } - - loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) - if err != nil { - ordered := append([]*Account(nil), candidates...) - sortAccountsByPriorityAndLastUsed(ordered, false) - for _, acc := range ordered { - result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) - if err == nil && result.Acquired { - if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, acc.ID, openaiStickySessionTTL) - } - return &AccountSelectionResult{ - Account: acc, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - } - } else { - type accountWithLoad struct { - account *Account - loadInfo *AccountLoadInfo - } - var available []accountWithLoad - for _, acc := range candidates { - loadInfo := loadMap[acc.ID] - if loadInfo == nil { - loadInfo = &AccountLoadInfo{AccountID: acc.ID} - } - if loadInfo.LoadRate < 100 { - available = append(available, accountWithLoad{ - account: acc, - loadInfo: loadInfo, - }) - } - } - - if len(available) > 0 { - sort.SliceStable(available, func(i, j int) bool { - a, b := available[i], available[j] - if a.account.Priority != b.account.Priority { - return a.account.Priority < b.account.Priority - } - if a.loadInfo.LoadRate != b.loadInfo.LoadRate { - return a.loadInfo.LoadRate < b.loadInfo.LoadRate - } - switch { - case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: - return true - case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: - return false - case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: - return false - default: - return a.account.LastUsedAt.Before(*b.account.LastUsedAt) - } - }) - - for _, item := range available { - result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) - if err == nil && result.Acquired { - if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL) - } - return &AccountSelectionResult{ - Account: item.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - } - } - } - - // ============ Layer 3: Fallback wait ============ - sortAccountsByPriorityAndLastUsed(candidates, false) - for _, acc := range candidates { - return &AccountSelectionResult{ - Account: acc, - WaitPlan: &AccountWaitPlan{ - AccountID: acc.ID, - MaxConcurrency: acc.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil - } - - return nil, errors.New("no available accounts") -} - -func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) { - var accounts []Account - var err error - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) - } else if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) - } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) - } - if err != nil { - return nil, fmt.Errorf("query accounts failed: %w", err) - } - return accounts, nil -} - -func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { - if s.concurrencyService == nil { - return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil - } - return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) -} - -func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig { - if s.cfg != nil { - return s.cfg.Gateway.Scheduling - } - return config.GatewaySchedulingConfig{ - StickySessionMaxWaiting: 3, - StickySessionWaitTimeout: 45 * time.Second, - FallbackWaitTimeout: 30 * time.Second, - FallbackMaxWaiting: 100, - LoadBatchEnabled: true, - SlotCleanupInterval: 30 * time.Second, - } -} - // GetAccessToken gets the access token for an OpenAI account func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index a202ccf2..81e01d47 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -73,15 +73,6 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh return svc } -// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker. -func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService { - svc := NewConcurrencyService(cache) - if cfg != nil { - svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval) - } - return svc -} - // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services @@ -116,7 +107,7 @@ var ProviderSet = wire.NewSet( ProvideEmailQueueService, NewTurnstileService, NewSubscriptionService, - ProvideConcurrencyService, + NewConcurrencyService, NewIdentityService, NewCRSSyncService, ProvideUpdateService, diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 5478d151..5bd85d7d 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -122,21 +122,6 @@ pricing: # Hash check interval in minutes hash_check_interval_minutes: 10 -# ============================================================================= -# Gateway (Optional) -# ============================================================================= -gateway: - # Wait time (in seconds) for upstream response headers (streaming body not affected) - response_header_timeout: 300 - # Log upstream error response body summary (safe/truncated; does not log request content) - log_upstream_error_body: false - # Max bytes to log from upstream error body - log_upstream_error_body_max_bytes: 2048 - # Auto inject anthropic-beta for API-key accounts when needed (default off) - inject_beta_for_apikey: false - # Allow failover on selected 400 errors (default off) - failover_on_400: false - # ============================================================================= # Gemini OAuth (Required for Gemini accounts) # ============================================================================= diff --git a/deploy/flow.md b/deploy/flow.md deleted file mode 100644 index 0904c72f..00000000 --- a/deploy/flow.md +++ /dev/null @@ -1,222 +0,0 @@ -```mermaid -flowchart TD - %% Master dispatch - A[HTTP Request] --> B{Route} - B -->|v1 messages| GA0 - B -->|openai v1 responses| OA0 - B -->|v1beta models model action| GM0 - B -->|v1 messages count tokens| GT0 - B -->|v1beta models list or get| GL0 - - %% ========================= - %% FLOW A: Claude Gateway - %% ========================= - subgraph FLOW_A["v1 messages Claude Gateway"] - GA0[Auth middleware] --> GA1[Read body] - GA1 -->|empty| GA1E[400 invalid_request_error] - GA1 --> GA2[ParseGatewayRequest] - GA2 -->|parse error| GA2E[400 invalid_request_error] - GA2 --> GA3{model present} - GA3 -->|no| GA3E[400 invalid_request_error] - GA3 --> GA4[streamStarted false] - GA4 --> GA5[IncrementWaitCount user] - GA5 -->|queue full| GA5E[429 rate_limit_error] - GA5 --> GA6[AcquireUserSlotWithWait] - GA6 -->|timeout or fail| GA6E[429 rate_limit_error] - GA6 --> GA7[BillingEligibility check post wait] - GA7 -->|fail| GA7E[403 billing_error] - GA7 --> GA8[Generate sessionHash] - GA8 --> GA9[Resolve platform] - GA9 --> GA10{platform gemini} - GA10 -->|yes| GA10Y[sessionKey gemini hash] - GA10 -->|no| GA10N[sessionKey hash] - GA10Y --> GA11 - GA10N --> GA11 - - GA11[SelectAccountWithLoadAwareness] -->|err and no failed| GA11E1[503 no available accounts] - GA11 -->|err and failed| GA11E2[map failover error] - GA11 --> GA12[Warmup intercept] - GA12 -->|yes| GA12Y[return mock and release if held] - GA12 -->|no| GA13[Acquire account slot or wait] - GA13 -->|wait queue full| GA13E1[429 rate_limit_error] - GA13 -->|wait timeout| GA13E2[429 concurrency limit] - GA13 --> GA14[BindStickySession if waited] - GA14 --> GA15{account platform antigravity} - GA15 -->|yes| GA15Y[ForwardGemini antigravity] - GA15 -->|no| GA15N[Forward Claude] - GA15Y --> GA16[Release account slot and dec account wait] - GA15N --> GA16 - GA16 --> GA17{UpstreamFailoverError} - GA17 -->|yes| GA18[mark failedAccountIDs and map error if exceed] - GA18 -->|loop| GA11 - GA17 -->|no| GA19[success async RecordUsage and return] - GA19 --> GA20[defer release user slot and dec wait count] - end - - %% ========================= - %% FLOW B: OpenAI - %% ========================= - subgraph FLOW_B["openai v1 responses"] - OA0[Auth middleware] --> OA1[Read body] - OA1 -->|empty| OA1E[400 invalid_request_error] - OA1 --> OA2[json Unmarshal body] - OA2 -->|parse error| OA2E[400 invalid_request_error] - OA2 --> OA3{model present} - OA3 -->|no| OA3E[400 invalid_request_error] - OA3 --> OA4{User Agent Codex CLI} - OA4 -->|no| OA4N[set default instructions] - OA4 -->|yes| OA4Y[no change] - OA4N --> OA5 - OA4Y --> OA5 - OA5[streamStarted false] --> OA6[IncrementWaitCount user] - OA6 -->|queue full| OA6E[429 rate_limit_error] - OA6 --> OA7[AcquireUserSlotWithWait] - OA7 -->|timeout or fail| OA7E[429 rate_limit_error] - OA7 --> OA8[BillingEligibility check post wait] - OA8 -->|fail| OA8E[403 billing_error] - OA8 --> OA9[sessionHash sha256 session_id] - OA9 --> OA10[SelectAccountWithLoadAwareness] - OA10 -->|err and no failed| OA10E1[503 no available accounts] - OA10 -->|err and failed| OA10E2[map failover error] - OA10 --> OA11[Acquire account slot or wait] - OA11 -->|wait queue full| OA11E1[429 rate_limit_error] - OA11 -->|wait timeout| OA11E2[429 concurrency limit] - OA11 --> OA12[BindStickySession openai hash if waited] - OA12 --> OA13[Forward OpenAI upstream] - OA13 --> OA14[Release account slot and dec account wait] - OA14 --> OA15{UpstreamFailoverError} - OA15 -->|yes| OA16[mark failedAccountIDs and map error if exceed] - OA16 -->|loop| OA10 - OA15 -->|no| OA17[success async RecordUsage and return] - OA17 --> OA18[defer release user slot and dec wait count] - end - - %% ========================= - %% FLOW C: Gemini Native - %% ========================= - subgraph FLOW_C["v1beta models model action Gemini Native"] - GM0[Auth middleware] --> GM1[Validate platform] - GM1 -->|invalid| GM1E[400 googleError] - GM1 --> GM2[Parse path modelName action] - GM2 -->|invalid| GM2E[400 googleError] - GM2 --> GM3{action supported} - GM3 -->|no| GM3E[404 googleError] - GM3 --> GM4[Read body] - GM4 -->|empty| GM4E[400 googleError] - GM4 --> GM5[streamStarted false] - GM5 --> GM6[IncrementWaitCount user] - GM6 -->|queue full| GM6E[429 googleError] - GM6 --> GM7[AcquireUserSlotWithWait] - GM7 -->|timeout or fail| GM7E[429 googleError] - GM7 --> GM8[BillingEligibility check post wait] - GM8 -->|fail| GM8E[403 googleError] - GM8 --> GM9[Generate sessionHash] - GM9 --> GM10[sessionKey gemini hash] - GM10 --> GM11[SelectAccountWithLoadAwareness] - GM11 -->|err and no failed| GM11E1[503 googleError] - GM11 -->|err and failed| GM11E2[mapGeminiUpstreamError] - GM11 --> GM12[Acquire account slot or wait] - GM12 -->|wait queue full| GM12E1[429 googleError] - GM12 -->|wait timeout| GM12E2[429 googleError] - GM12 --> GM13[BindStickySession if waited] - GM13 --> GM14{account platform antigravity} - GM14 -->|yes| GM14Y[ForwardGemini antigravity] - GM14 -->|no| GM14N[ForwardNative] - GM14Y --> GM15[Release account slot and dec account wait] - GM14N --> GM15 - GM15 --> GM16{UpstreamFailoverError} - GM16 -->|yes| GM17[mark failedAccountIDs and map error if exceed] - GM17 -->|loop| GM11 - GM16 -->|no| GM18[success async RecordUsage and return] - GM18 --> GM19[defer release user slot and dec wait count] - end - - %% ========================= - %% FLOW D: CountTokens - %% ========================= - subgraph FLOW_D["v1 messages count tokens"] - GT0[Auth middleware] --> GT1[Read body] - GT1 -->|empty| GT1E[400 invalid_request_error] - GT1 --> GT2[ParseGatewayRequest] - GT2 -->|parse error| GT2E[400 invalid_request_error] - GT2 --> GT3{model present} - GT3 -->|no| GT3E[400 invalid_request_error] - GT3 --> GT4[BillingEligibility check] - GT4 -->|fail| GT4E[403 billing_error] - GT4 --> GT5[ForwardCountTokens] - end - - %% ========================= - %% FLOW E: Gemini Models List Get - %% ========================= - subgraph FLOW_E["v1beta models list or get"] - GL0[Auth middleware] --> GL1[Validate platform] - GL1 -->|invalid| GL1E[400 googleError] - GL1 --> GL2{force platform antigravity} - GL2 -->|yes| GL2Y[return static fallback models] - GL2 -->|no| GL3[SelectAccountForAIStudioEndpoints] - GL3 -->|no gemini and has antigravity| GL3Y[return fallback models] - GL3 -->|no accounts| GL3E[503 googleError] - GL3 --> GL4[ForwardAIStudioGET] - GL4 -->|error| GL4E[502 googleError] - GL4 --> GL5[Passthrough response or fallback] - end - - %% ========================= - %% SHARED: Account Selection - %% ========================= - subgraph SELECT["SelectAccountWithLoadAwareness detail"] - S0[Start] --> S1{concurrencyService nil OR load batch disabled} - S1 -->|yes| S2[SelectAccountForModelWithExclusions legacy] - S2 --> S3[tryAcquireAccountSlot] - S3 -->|acquired| S3Y[SelectionResult Acquired true ReleaseFunc] - S3 -->|not acquired| S3N[WaitPlan FallbackTimeout MaxWaiting] - S1 -->|no| S4[Resolve platform] - S4 --> S5[List schedulable accounts] - S5 --> S6[Layer1 Sticky session] - S6 -->|hit and valid| S6A[tryAcquireAccountSlot] - S6A -->|acquired| S6AY[SelectionResult Acquired true] - S6A -->|not acquired and waitingCount < StickyMax| S6AN[WaitPlan StickyTimeout Max] - S6 --> S7[Layer2 Load aware] - S7 --> S7A[Load batch concurrency plus wait to loadRate] - S7A --> S7B[Sort priority load LRU OAuth prefer for Gemini] - S7B --> S7C[tryAcquireAccountSlot in order] - S7C -->|first success| S7CY[SelectionResult Acquired true] - S7C -->|none| S8[Layer3 Fallback wait] - S8 --> S8A[Sort priority LRU] - S8A --> S8B[WaitPlan FallbackTimeout Max] - end - - %% ========================= - %% SHARED: Wait Acquire - %% ========================= - subgraph WAIT["AcquireXSlotWithWait detail"] - W0[Try AcquireXSlot immediately] -->|acquired| W1[return ReleaseFunc] - W0 -->|not acquired| W2[Wait loop with timeout] - W2 --> W3[Backoff 100ms x1.5 jitter max2s] - W2 --> W4[If streaming and ping format send SSE ping] - W2 --> W5[Retry AcquireXSlot on timer] - W5 -->|acquired| W1 - W2 -->|timeout| W6[ConcurrencyError IsTimeout true] - end - - %% ========================= - %% SHARED: Account Wait Queue - %% ========================= - subgraph AQ["Account Wait Queue Redis Lua"] - Q1[IncrementAccountWaitCount] --> Q2{current >= max} - Q2 -->|yes| Q2Y[return false] - Q2 -->|no| Q3[INCR and if first set TTL] - Q3 --> Q4[return true] - Q5[DecrementAccountWaitCount] --> Q6[if current > 0 then DECR] - end - - %% ========================= - %% SHARED: Background cleanup - %% ========================= - subgraph CLEANUP["Slot Cleanup Worker"] - C0[StartSlotCleanupWorker interval] --> C1[List schedulable accounts] - C1 --> C2[CleanupExpiredAccountSlots per account] - C2 --> C3[Repeat every interval] - end -``` diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 1770a985..6563ee0c 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -952,7 +952,6 @@ "integrity": "sha512-N2clP5pJhB2YnZJ3PIHFk5RkygRX5WO/5f0WC08tp0wd+sv0rsJk3MqWn3CbNmT2J505a5336jaQj4ph1AdMug==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "undici-types": "~6.21.0" } @@ -1368,7 +1367,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "baseline-browser-mapping": "^2.9.0", "caniuse-lite": "^1.0.30001759", @@ -1445,7 +1443,6 @@ "resolved": "https://registry.npmmirror.com/chart.js/-/chart.js-4.5.1.tgz", "integrity": "sha512-GIjfiT9dbmHRiYi6Nl2yFCq7kkwdkp1W/lp2J99rX0yo9tgJGn3lKQATztIjb5tVtevcBtIdICNWqlq5+E8/Pw==", "license": "MIT", - "peer": true, "dependencies": { "@kurkle/color": "^0.3.0" }, @@ -2043,7 +2040,6 @@ "integrity": "sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==", "dev": true, "license": "MIT", - "peer": true, "bin": { "jiti": "bin/jiti.js" } @@ -2352,7 +2348,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1", @@ -2826,7 +2821,6 @@ "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", - "peer": true, "engines": { "node": ">=12" }, @@ -2860,7 +2854,6 @@ "integrity": "sha512-hjcS1mhfuyi4WW8IWtjP7brDrG2cuDZukyrYrSauoXGNgx0S7zceP07adYkJycEr56BOUTNPzbInooiN3fn1qw==", "devOptional": true, "license": "Apache-2.0", - "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -2933,7 +2926,6 @@ "integrity": "sha512-o5a9xKjbtuhY6Bi5S3+HvbRERmouabWbyUcpXXUA1u+GNUKoROi9byOJ8M0nHbHYHkYICiMlqxkg1KkYmm25Sw==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "esbuild": "^0.21.3", "postcss": "^8.4.43", @@ -3105,7 +3097,6 @@ "resolved": "https://registry.npmmirror.com/vue/-/vue-3.5.25.tgz", "integrity": "sha512-YLVdgv2K13WJ6n+kD5owehKtEXwdwXuj2TTyJMsO7pSeKw2bfRNZGjhB7YzrpbMYj5b5QsUebHpOqR3R3ziy/g==", "license": "MIT", - "peer": true, "dependencies": { "@vue/compiler-dom": "3.5.25", "@vue/compiler-sfc": "3.5.25", @@ -3199,7 +3190,6 @@ "integrity": "sha512-P7OP77b2h/Pmk+lZdJ0YWs+5tJ6J2+uOQPo7tlBnY44QqQSPYvS0qVT4wqDJgwrZaLe47etJLLQRFia71GYITw==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@volar/typescript": "2.4.15", "@vue/language-core": "2.2.12" diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue index 914678a5..c1ca08fa 100644 --- a/frontend/src/components/account/AccountStatusIndicator.vue +++ b/frontend/src/components/account/AccountStatusIndicator.vue @@ -83,14 +83,6 @@ > - - - - {{ tierDisplay }} - @@ -148,23 +140,4 @@ const statusText = computed(() => { return props.account.status }) -// Computed: tier display -const tierDisplay = computed(() => { - const credentials = props.account.credentials as Record | undefined - const tierId = credentials?.tier_id - if (!tierId || tierId === 'unknown') return null - - const tierMap: Record = { - 'free': 'Free', - 'payg': 'Pay-as-you-go', - 'pay-as-you-go': 'Pay-as-you-go', - 'enterprise': 'Enterprise', - 'LEGACY': 'Legacy', - 'PRO': 'Pro', - 'ULTRA': 'Ultra' - } - - return tierMap[tierId] || tierId -}) - From b6d1e7a0846d8946ce3a9dbd3d3606db2e410d55 Mon Sep 17 00:00:00 2001 From: IanShaw <131567472+IanShaw027@users.noreply.github.com> Date: Thu, 1 Jan 2026 10:45:57 +0800 Subject: [PATCH 49/49] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20/v1/messages?= =?UTF-8?q?=20=E9=97=B4=E6=AD=87=E6=80=A7=20400=20=E9=94=99=E8=AF=AF=20(#1?= =?UTF-8?q?12)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(upstream): 修复上游格式兼容性问题 - 跳过Claude模型无signature的thinking block - 支持custom类型工具(MCP)格式转换 - 添加ClaudeCustomToolSpec结构体支持MCP工具 - 添加Custom字段验证,跳过无效custom工具 - 在convertClaudeToolsToGeminiTools中添加schema清理 - 完整的单元测试覆盖,包含边界情况 修复: Issue 0.1 signature缺失, Issue 0.2 custom工具格式 改进: Codex审查发现的2个重要问题 测试: - TestBuildParts_ThinkingBlockWithoutSignature: 验证thinking block处理 - TestBuildTools_CustomTypeTools: 验证custom工具转换和边界情况 - TestConvertClaudeToolsToGeminiTools_CustomType: 验证service层转换 * feat(gemini): 添加Gemini限额与TierID支持 实现PR1:Gemini限额与TierID功能 后端修改: - GeminiTokenInfo结构体添加TierID字段 - fetchProjectID函数返回(projectID, tierID, error) - 从LoadCodeAssist响应中提取tierID(优先IsDefault,回退到第一个非空tier) - ExchangeCode、RefreshAccountToken、GetAccessToken函数更新以处理tierID - BuildAccountCredentials函数保存tier_id到credentials 前端修改: - AccountStatusIndicator组件添加tier显示 - 支持LEGACY/PRO/ULTRA等tier类型的友好显示 - 使用蓝色badge展示tier信息 技术细节: - tierID提取逻辑:优先选择IsDefault的tier,否则选择第一个非空tier - 所有fetchProjectID调用点已更新以处理新的返回签名 - 前端gracefully处理missing/unknown tier_id * refactor(gemini): 优化TierID实现并添加安全验证 根据并发代码审查(code-reviewer, security-auditor, gemini, codex)的反馈进行改进: 安全改进: - 添加validateTierID函数验证tier_id格式和长度(最大64字符) - 限制tier_id字符集为字母数字、下划线、连字符和斜杠 - 在BuildAccountCredentials中验证tier_id后再存储 - 静默跳过无效tier_id,不阻塞账户创建 代码质量改进: - 提取extractTierIDFromAllowedTiers辅助函数消除重复代码 - 重构fetchProjectID函数,tierID提取逻辑只执行一次 - 改进代码可读性和可维护性 审查工具: - code-reviewer agent (a09848e) - security-auditor agent (a9a149c) - gemini CLI (bcc7c81) - codex (b5d8919) 修复问题: - HIGH: 未验证的tier_id输入 - MEDIUM: 代码重复(tierID提取逻辑重复2次) * fix(format): 修复 gofmt 格式问题 - 修复 claude_types.go 中的字段对齐问题 - 修复 gemini_messages_compat_service.go 中的缩进问题 * fix(upstream): 修复上游格式兼容性问题 (#14) * fix(upstream): 修复上游格式兼容性问题 - 跳过Claude模型无signature的thinking block - 支持custom类型工具(MCP)格式转换 - 添加ClaudeCustomToolSpec结构体支持MCP工具 - 添加Custom字段验证,跳过无效custom工具 - 在convertClaudeToolsToGeminiTools中添加schema清理 - 完整的单元测试覆盖,包含边界情况 修复: Issue 0.1 signature缺失, Issue 0.2 custom工具格式 改进: Codex审查发现的2个重要问题 测试: - TestBuildParts_ThinkingBlockWithoutSignature: 验证thinking block处理 - TestBuildTools_CustomTypeTools: 验证custom工具转换和边界情况 - TestConvertClaudeToolsToGeminiTools_CustomType: 验证service层转换 * fix(format): 修复 gofmt 格式问题 - 修复 claude_types.go 中的字段对齐问题 - 修复 gemini_messages_compat_service.go 中的缩进问题 * fix(format): 修复 claude_types.go 的 gofmt 格式问题 * feat(antigravity): 优化 thinking block 和 schema 处理 - 为 dummy thinking block 添加 ThoughtSignature - 重构 thinking block 处理逻辑,在每个条件分支内创建 part - 优化 excludedSchemaKeys,移除 Gemini 实际支持的字段 (minItems, maxItems, minimum, maximum, additionalProperties, format) - 添加详细注释说明 Gemini API 支持的 schema 字段 * fix(antigravity): 增强 schema 清理的安全性 基于 Codex review 建议: - 添加 format 字段白名单过滤,只保留 Gemini 支持的 date-time/date/time - 补充更多不支持的 schema 关键字到黑名单: * 组合 schema: oneOf, anyOf, allOf, not, if/then/else * 对象验证: minProperties, maxProperties, patternProperties 等 * 定义引用: $defs, definitions - 避免不支持的 schema 字段导致 Gemini API 校验失败 * fix(lint): 修复 gemini_messages_compat_service 空分支警告 - 在 cleanToolSchema 的 if 语句中添加 continue - 移除重复的注释 * fix(antigravity): 移除 minItems/maxItems 以兼容 Claude API - 将 minItems 和 maxItems 添加到 schema 黑名单 - Claude API (Vertex AI) 不支持这些数组验证字段 - 添加调试日志记录工具 schema 转换过程 - 修复 tools.14.custom.input_schema 验证错误 * fix(antigravity): 修复 additionalProperties schema 对象问题 - 将 additionalProperties 的 schema 对象转换为布尔值 true - Claude API 只支持 additionalProperties: false,不支持 schema 对象 - 修复 tools.14.custom.input_schema 验证错误 - 参考 Claude 官方文档的 JSON Schema 限制 * fix(antigravity): 修复 Claude 模型 thinking 块兼容性问题 - 完全跳过 Claude 模型的 thinking 块以避免 signature 验证失败 - 只在 Gemini 模型中使用 dummy thought signature - 修改 additionalProperties 默认值为 false(更安全) - 添加调试日志以便排查问题 * fix(upstream): 修复跨模型切换时的 dummy signature 问题 基于 Codex review 和用户场景分析的修复: 1. 问题场景 - Gemini (thinking) → Claude (thinking) 切换时 - Gemini 返回的 thinking 块使用 dummy signature - Claude API 会拒绝 dummy signature,导致 400 错误 2. 修复内容 - request_transformer.go:262: 跳过 dummy signature - 只保留真实的 Claude signature - 支持频繁的跨模型切换 3. 其他修复(基于 Codex review) - gateway_service.go:691: 修复 io.ReadAll 错误处理 - gateway_service.go:687: 条件日志(尊重 LogUpstreamErrorBody 配置) - gateway_service.go:915: 收紧 400 failover 启发式 - request_transformer.go:188: 移除签名成功日志 4. 新增功能(默认关闭) - 阶段 1: 上游错误日志(GATEWAY_LOG_UPSTREAM_ERROR_BODY) - 阶段 2: Antigravity thinking 修复 - 阶段 3: API-key beta 注入(GATEWAY_INJECT_BETA_FOR_APIKEY) - 阶段 3: 智能 400 failover(GATEWAY_FAILOVER_ON_400) 测试:所有测试通过 * fix(lint): 修复 golangci-lint 问题 - 应用 De Morgan 定律简化条件判断 - 修复 gofmt 格式问题 - 移除未使用的 min 函数 --- backend/internal/config/config.go | 15 ++ .../internal/pkg/antigravity/claude_types.go | 3 + .../pkg/antigravity/request_transformer.go | 223 +++++++++++++----- .../antigravity/request_transformer_test.go | 179 ++++++++++++++ backend/internal/pkg/claude/constants.go | 6 + .../service/antigravity_gateway_service.go | 9 + backend/internal/service/gateway_service.go | 138 +++++++++++ .../service/gemini_messages_compat_service.go | 39 ++- .../gemini_messages_compat_service_test.go | 128 ++++++++++ .../internal/service/gemini_oauth_service.go | 104 +++++--- .../internal/service/gemini_token_provider.go | 5 +- deploy/config.example.yaml | 15 ++ frontend/package-lock.json | 10 + .../account/AccountStatusIndicator.vue | 27 +++ 14 files changed, 815 insertions(+), 86 deletions(-) create mode 100644 backend/internal/pkg/antigravity/request_transformer_test.go create mode 100644 backend/internal/service/gemini_messages_compat_service_test.go diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index aeeddcb4..d3674932 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -119,6 +119,17 @@ type GatewayConfig struct { // ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟) // 应大于最长 LLM 请求时间,防止请求完成前槽位过期 ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` + + // 是否记录上游错误响应体摘要(避免输出请求内容) + LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"` + // 上游错误响应体记录最大字节数(超过会截断) + LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"` + + // API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容) + InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"` + + // 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义) + FailoverOn400 bool `mapstructure:"failover_on_400"` } func (s *ServerConfig) Address() string { @@ -313,6 +324,10 @@ func setDefaults() { // Gateway viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久 + viper.SetDefault("gateway.log_upstream_error_body", false) + viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048) + viper.SetDefault("gateway.inject_beta_for_apikey", false) + viper.SetDefault("gateway.failover_on_400", false) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) // HTTP 上游连接池配置(针对 5000+ 并发用户优化) diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index 01b805cd..34e6b1f4 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -54,6 +54,9 @@ type CustomToolSpec struct { InputSchema map[string]any `json:"input_schema"` } +// ClaudeCustomToolSpec 兼容旧命名(MCP custom 工具规格) +type ClaudeCustomToolSpec = CustomToolSpec + // SystemBlock system prompt 数组形式的元素 type SystemBlock struct { Type string `json:"type"` diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index e0b5b886..83b87a32 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -14,13 +14,16 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st // 用于存储 tool_use id -> name 映射 toolIDToName := make(map[string]string) - // 检测是否启用 thinking - isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" - // 只有 Gemini 模型支持 dummy thought workaround // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures allowDummyThought := strings.HasPrefix(mappedModel, "gemini-") + // 检测是否启用 thinking + requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" + // 为避免 Claude 模型的 thought signature/消息块约束导致 400(上游要求 thinking 块开头等), + // 非 Gemini 模型默认不启用 thinking(除非未来支持完整签名链路)。 + isThinkingEnabled := requestedThinkingEnabled && allowDummyThought + // 1. 构建 contents contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought) if err != nil { @@ -31,7 +34,15 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model) // 3. 构建 generationConfig - generationConfig := buildGenerationConfig(claudeReq) + reqForGen := claudeReq + if requestedThinkingEnabled && !allowDummyThought { + log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel) + // shallow copy to avoid mutating caller's request + clone := *claudeReq + clone.Thinking = nil + reqForGen = &clone + } + generationConfig := buildGenerationConfig(reqForGen) // 4. 构建 tools tools := buildTools(claudeReq.Tools) @@ -148,8 +159,9 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT if !hasThoughtPart && len(parts) > 0 { // 在开头添加 dummy thinking block parts = append([]GeminiPart{{ - Text: "Thinking...", - Thought: true, + Text: "Thinking...", + Thought: true, + ThoughtSignature: dummyThoughtSignature, }}, parts...) } } @@ -171,6 +183,34 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT // 参考: https://ai.google.dev/gemini-api/docs/thought-signatures const dummyThoughtSignature = "skip_thought_signature_validator" +// isValidThoughtSignature 验证 thought signature 是否有效 +// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节 +func isValidThoughtSignature(signature string) bool { + // 空字符串无效 + if signature == "" { + return false + } + + // signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节) + // 参考 Claude API 文档和实际观察到的有效 signature + if len(signature) < 40 { + log.Printf("[Debug] Signature too short: len=%d", len(signature)) + return false + } + + // 检查是否是有效的 base64 字符 + // base64 字符集: A-Z, a-z, 0-9, +, /, = + for i, c := range signature { + if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') && + (c < '0' || c > '9') && c != '+' && c != '/' && c != '=' { + log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c) + return false + } + } + + return true +} + // buildParts 构建消息的 parts // allowDummyThought: 只有 Gemini 模型支持 dummy thought signature func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) { @@ -199,22 +239,30 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu } case "thinking": - part := GeminiPart{ - Text: block.Thinking, - Thought: true, - } - // 保留原有 signature(Claude 模型需要有效的 signature) - if block.Signature != "" { - part.ThoughtSignature = block.Signature - } else if !allowDummyThought { - // Claude 模型需要有效 signature,跳过无 signature 的 thinking block - log.Printf("Warning: skipping thinking block without signature for Claude model") + if allowDummyThought { + // Gemini 模型可以使用 dummy signature + parts = append(parts, GeminiPart{ + Text: block.Thinking, + Thought: true, + ThoughtSignature: dummyThoughtSignature, + }) continue - } else { - // Gemini 模型使用 dummy signature - part.ThoughtSignature = dummyThoughtSignature } - parts = append(parts, part) + + // Claude 模型:仅在提供有效 signature 时保留 thinking block;否则跳过以避免上游校验失败。 + signature := strings.TrimSpace(block.Signature) + if signature == "" || signature == dummyThoughtSignature { + log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)") + continue + } + if !isValidThoughtSignature(signature) { + log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature)) + } + parts = append(parts, GeminiPart{ + Text: block.Thinking, + Thought: true, + ThoughtSignature: signature, + }) case "image": if block.Source != nil && block.Source.Type == "base64" { @@ -239,10 +287,9 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ID: block.ID, }, } - // 保留原有 signature,或对 Gemini 模型使用 dummy signature - if block.Signature != "" { - part.ThoughtSignature = block.Signature - } else if allowDummyThought { + // 只有 Gemini 模型使用 dummy signature + // Claude 模型不设置 signature(避免验证问题) + if allowDummyThought { part.ThoughtSignature = dummyThoughtSignature } parts = append(parts, part) @@ -386,9 +433,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { // 普通工具 var funcDecls []GeminiFunctionDecl - for _, tool := range tools { + for i, tool := range tools { // 跳过无效工具名称 - if tool.Name == "" { + if strings.TrimSpace(tool.Name) == "" { log.Printf("Warning: skipping tool with empty name") continue } @@ -397,10 +444,18 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { var inputSchema map[string]any // 检查是否为 custom 类型工具 (MCP) - if tool.Type == "custom" && tool.Custom != nil { - // Custom 格式: 从 custom 字段获取 description 和 input_schema + if tool.Type == "custom" { + if tool.Custom == nil || tool.Custom.InputSchema == nil { + log.Printf("[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema", tool.Name) + continue + } description = tool.Custom.Description inputSchema = tool.Custom.InputSchema + + // 调试日志:记录 custom 工具的 schema + if schemaJSON, err := json.Marshal(inputSchema); err == nil { + log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON)) + } } else { // 标准格式: 从顶层字段获取 description = tool.Description @@ -409,7 +464,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { // 清理 JSON Schema params := cleanJSONSchema(inputSchema) - // 为 nil schema 提供默认值 if params == nil { params = map[string]any{ @@ -418,6 +472,11 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { } } + // 调试日志:记录清理后的 schema + if paramsJSON, err := json.Marshal(params); err == nil { + log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON)) + } + funcDecls = append(funcDecls, GeminiFunctionDecl{ Name: tool.Name, Description: description, @@ -479,31 +538,64 @@ func cleanJSONSchema(schema map[string]any) map[string]any { } // excludedSchemaKeys 不支持的 schema 字段 +// 基于 Claude API (Vertex AI) 的实际支持情况 +// 支持: type, description, enum, properties, required, additionalProperties, items +// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段 var excludedSchemaKeys = map[string]bool{ - "$schema": true, - "$id": true, - "$ref": true, - "additionalProperties": true, - "minLength": true, - "maxLength": true, - "minItems": true, - "maxItems": true, - "uniqueItems": true, - "minimum": true, - "maximum": true, - "exclusiveMinimum": true, - "exclusiveMaximum": true, - "pattern": true, - "format": true, - "default": true, - "strict": true, - "const": true, - "examples": true, - "deprecated": true, - "readOnly": true, - "writeOnly": true, - "contentMediaType": true, - "contentEncoding": true, + // 元 schema 字段 + "$schema": true, + "$id": true, + "$ref": true, + + // 字符串验证(Gemini 不支持) + "minLength": true, + "maxLength": true, + "pattern": true, + + // 数字验证(Claude API 通过 Vertex AI 不支持这些字段) + "minimum": true, + "maximum": true, + "exclusiveMinimum": true, + "exclusiveMaximum": true, + "multipleOf": true, + + // 数组验证(Claude API 通过 Vertex AI 不支持这些字段) + "uniqueItems": true, + "minItems": true, + "maxItems": true, + + // 组合 schema(Gemini 不支持) + "oneOf": true, + "anyOf": true, + "allOf": true, + "not": true, + "if": true, + "then": true, + "else": true, + "$defs": true, + "definitions": true, + + // 对象验证(仅保留 properties/required/additionalProperties) + "minProperties": true, + "maxProperties": true, + "patternProperties": true, + "propertyNames": true, + "dependencies": true, + "dependentSchemas": true, + "dependentRequired": true, + + // 其他不支持的字段 + "default": true, + "const": true, + "examples": true, + "deprecated": true, + "readOnly": true, + "writeOnly": true, + "contentMediaType": true, + "contentEncoding": true, + + // Claude 特有字段 + "strict": true, } // cleanSchemaValue 递归清理 schema 值 @@ -523,6 +615,31 @@ func cleanSchemaValue(value any) any { continue } + // 特殊处理 format 字段:只保留 Gemini 支持的 format 值 + if k == "format" { + if formatStr, ok := val.(string); ok { + // Gemini 只支持 date-time, date, time + if formatStr == "date-time" || formatStr == "date" || formatStr == "time" { + result[k] = val + } + // 其他 format 值直接跳过 + } + continue + } + + // 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象 + if k == "additionalProperties" { + if boolVal, ok := val.(bool); ok { + result[k] = boolVal + log.Printf("[Debug] additionalProperties is bool: %v", boolVal) + } else { + // 如果是 schema 对象,转换为 false(更安全的默认值) + result[k] = false + log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val) + } + continue + } + // 递归清理所有值 result[k] = cleanSchemaValue(val) } diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go new file mode 100644 index 00000000..56eebad0 --- /dev/null +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -0,0 +1,179 @@ +package antigravity + +import ( + "encoding/json" + "testing" +) + +// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理 +func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) { + tests := []struct { + name string + content string + allowDummyThought bool + expectedParts int + description string + }{ + { + name: "Claude model - skip thinking block without signature", + content: `[ + {"type": "text", "text": "Hello"}, + {"type": "thinking", "thinking": "Let me think...", "signature": ""}, + {"type": "text", "text": "World"} + ]`, + allowDummyThought: false, + expectedParts: 2, // 只有两个text block + description: "Claude模型应该跳过无signature的thinking block", + }, + { + name: "Claude model - keep thinking block with signature", + content: `[ + {"type": "text", "text": "Hello"}, + {"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"}, + {"type": "text", "text": "World"} + ]`, + allowDummyThought: false, + expectedParts: 3, // 三个block都保留 + description: "Claude模型应该保留有signature的thinking block", + }, + { + name: "Gemini model - use dummy signature", + content: `[ + {"type": "text", "text": "Hello"}, + {"type": "thinking", "thinking": "Let me think...", "signature": ""}, + {"type": "text", "text": "World"} + ]`, + allowDummyThought: true, + expectedParts: 3, // 三个block都保留,thinking使用dummy signature + description: "Gemini模型应该为无signature的thinking block使用dummy signature", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + toolIDToName := make(map[string]string) + parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought) + + if err != nil { + t.Fatalf("buildParts() error = %v", err) + } + + if len(parts) != tt.expectedParts { + t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts) + } + }) + } +} + +// TestBuildTools_CustomTypeTools 测试custom类型工具转换 +func TestBuildTools_CustomTypeTools(t *testing.T) { + tests := []struct { + name string + tools []ClaudeTool + expectedLen int + description string + }{ + { + name: "Standard tool format", + tools: []ClaudeTool{ + { + Name: "get_weather", + Description: "Get weather information", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{"type": "string"}, + }, + }, + }, + }, + expectedLen: 1, + description: "标准工具格式应该正常转换", + }, + { + name: "Custom type tool (MCP format)", + tools: []ClaudeTool{ + { + Type: "custom", + Name: "mcp_tool", + Custom: &ClaudeCustomToolSpec{ + Description: "MCP tool description", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "param": map[string]any{"type": "string"}, + }, + }, + }, + }, + }, + expectedLen: 1, + description: "Custom类型工具应该从Custom字段读取description和input_schema", + }, + { + name: "Mixed standard and custom tools", + tools: []ClaudeTool{ + { + Name: "standard_tool", + Description: "Standard tool", + InputSchema: map[string]any{"type": "object"}, + }, + { + Type: "custom", + Name: "custom_tool", + Custom: &ClaudeCustomToolSpec{ + Description: "Custom tool", + InputSchema: map[string]any{"type": "object"}, + }, + }, + }, + expectedLen: 1, // 返回一个GeminiToolDeclaration,包含2个function declarations + description: "混合标准和custom工具应该都能正确转换", + }, + { + name: "Invalid custom tool - nil Custom field", + tools: []ClaudeTool{ + { + Type: "custom", + Name: "invalid_custom", + // Custom 为 nil + }, + }, + expectedLen: 0, // 应该被跳过 + description: "Custom字段为nil的custom工具应该被跳过", + }, + { + name: "Invalid custom tool - nil InputSchema", + tools: []ClaudeTool{ + { + Type: "custom", + Name: "invalid_custom", + Custom: &ClaudeCustomToolSpec{ + Description: "Invalid", + // InputSchema 为 nil + }, + }, + }, + expectedLen: 0, // 应该被跳过 + description: "InputSchema为nil的custom工具应该被跳过", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildTools(tt.tools) + + if len(result) != tt.expectedLen { + t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen) + } + + // 验证function declarations存在 + if len(result) > 0 && result[0].FunctionDeclarations != nil { + if len(result[0].FunctionDeclarations) != len(tt.tools) { + t.Errorf("%s: got %d function declarations, want %d", + tt.description, len(result[0].FunctionDeclarations), len(tt.tools)) + } + } + }) + } +} diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index 97ad6c83..0db3ed4a 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -16,6 +16,12 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav // HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta) const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking +// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth) +const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + +// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code) +const ApiKeyHaikuBetaHeader = BetaInterleavedThinking + // Claude Code 客户端默认请求头 var DefaultHeaders = map[string]string{ "User-Agent": "claude-cli/2.0.62 (external, cli)", diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index ae2976f8..5b3bf565 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -358,6 +358,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return nil, fmt.Errorf("transform request: %w", err) } + // 调试:记录转换后的请求体(仅记录前 2000 字符) + if bodyJSON, err := json.Marshal(geminiBody); err == nil { + truncated := string(bodyJSON) + if len(truncated) > 2000 { + truncated = truncated[:2000] + "..." + } + log.Printf("[Debug] Transformed Gemini request: %s", truncated) + } + // 构建上游 action action := "generateContent" if claudeReq.Stream { diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index d542e9c2..5884602d 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -19,6 +19,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/tidwall/gjson" "github.com/tidwall/sjson" "github.com/gin-gonic/gin" @@ -684,6 +685,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 处理错误响应(不可重试的错误) if resp.StatusCode >= 400 { + // 可选:对部分 400 触发 failover(默认关闭以保持语义) + if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { + respBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + // ReadAll failed, fall back to normal error handling without consuming the stream + return s.handleErrorResponse(ctx, resp, c, account) + } + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + if s.shouldFailoverOn400(respBody) { + if s.cfg.Gateway.LogUpstreamErrorBody { + log.Printf( + "Account %d: 400 error, attempting failover: %s", + account.ID, + truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } else { + log.Printf("Account %d: 400 error, attempting failover", account.ID) + } + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } + } return s.handleErrorResponse(ctx, resp, c, account) } @@ -786,6 +811,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // 处理anthropic-beta header(OAuth账号需要特殊处理) if tokenType == "oauth" { req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) + } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { + // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) + if requestNeedsBetaFeatures(body) { + if beta := defaultApiKeyBetaHeader(body); beta != "" { + req.Header.Set("anthropic-beta", beta) + } + } } return req, nil @@ -838,6 +870,83 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) return claude.DefaultBetaHeader } +func requestNeedsBetaFeatures(body []byte) bool { + tools := gjson.GetBytes(body, "tools") + if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 { + return true + } + if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") { + return true + } + return false +} + +func defaultApiKeyBetaHeader(body []byte) string { + modelID := gjson.GetBytes(body, "model").String() + if strings.Contains(strings.ToLower(modelID), "haiku") { + return claude.ApiKeyHaikuBetaHeader + } + return claude.ApiKeyBetaHeader +} + +func truncateForLog(b []byte, maxBytes int) string { + if maxBytes <= 0 { + maxBytes = 2048 + } + if len(b) > maxBytes { + b = b[:maxBytes] + } + s := string(b) + // 保持一行,避免污染日志格式 + s = strings.ReplaceAll(s, "\n", "\\n") + s = strings.ReplaceAll(s, "\r", "\\r") + return s +} + +func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { + // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。 + // 默认保守:无法识别则不切换。 + msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + if msg == "" { + return false + } + + // 缺少/错误的 beta header:换账号/链路可能成功(尤其是混合调度时)。 + // 更精确匹配 beta 相关的兼容性问题,避免误触发切换。 + if strings.Contains(msg, "anthropic-beta") || + strings.Contains(msg, "beta feature") || + strings.Contains(msg, "requires beta") { + return true + } + + // thinking/tool streaming 等兼容性约束(常见于中间转换链路) + if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") { + return true + } + if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") { + return true + } + + return false +} + +func extractUpstreamErrorMessage(body []byte) string { + // Claude 风格:{"type":"error","error":{"type":"...","message":"..."}} + if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" { + inner := strings.TrimSpace(m) + // 有些上游会把完整 JSON 作为字符串塞进 message + if strings.HasPrefix(inner, "{") { + if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" { + return innerMsg + } + } + return m + } + + // 兜底:尝试顶层 message + return gjson.GetBytes(body, "message").String() +} + func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { body, _ := io.ReadAll(resp.Body) @@ -850,6 +959,16 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res switch resp.StatusCode { case 400: + // 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开 + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + log.Printf( + "Upstream 400 error (account=%d platform=%s type=%s): %s", + account.ID, + account.Platform, + account.Type, + truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } c.Data(http.StatusBadRequest, "application/json", body) return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) case 401: @@ -1329,6 +1448,18 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, // 标记账号状态(429/529等) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + // 记录上游错误摘要便于排障(不回显请求内容) + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + log.Printf( + "count_tokens upstream error %d (account=%d platform=%s type=%s): %s", + resp.StatusCode, + account.ID, + account.Platform, + account.Type, + truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } + // 返回简化的错误响应 errMsg := "Upstream request failed" switch resp.StatusCode { @@ -1409,6 +1540,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) + } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { + // API-key:与 messages 同步的按需 beta 注入(默认关闭) + if requestNeedsBetaFeatures(body) { + if beta := defaultApiKeyBetaHeader(body); beta != "" { + req.Header.Set("anthropic-beta", beta) + } + } } return req, nil diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index a0bf1b6a..b1877800 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -2278,11 +2278,13 @@ func convertClaudeToolsToGeminiTools(tools any) []any { "properties": map[string]any{}, } } + // 清理 JSON Schema + cleanedParams := cleanToolSchema(params) funcDecls = append(funcDecls, map[string]any{ "name": name, "description": desc, - "parameters": params, + "parameters": cleanedParams, }) } @@ -2296,6 +2298,41 @@ func convertClaudeToolsToGeminiTools(tools any) []any { } } +// cleanToolSchema 清理工具的 JSON Schema,移除 Gemini 不支持的字段 +func cleanToolSchema(schema any) any { + if schema == nil { + return nil + } + + switch v := schema.(type) { + case map[string]any: + cleaned := make(map[string]any) + for key, value := range v { + // 跳过不支持的字段 + if key == "$schema" || key == "$id" || key == "$ref" || + key == "additionalProperties" || key == "minLength" || + key == "maxLength" || key == "minItems" || key == "maxItems" { + continue + } + // 递归清理嵌套对象 + cleaned[key] = cleanToolSchema(value) + } + // 规范化 type 字段为大写 + if typeVal, ok := cleaned["type"].(string); ok { + cleaned["type"] = strings.ToUpper(typeVal) + } + return cleaned + case []any: + cleaned := make([]any, len(v)) + for i, item := range v { + cleaned[i] = cleanToolSchema(item) + } + return cleaned + default: + return v + } +} + func convertClaudeGenerationConfig(req map[string]any) map[string]any { out := make(map[string]any) if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 { diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go new file mode 100644 index 00000000..d49f2eb3 --- /dev/null +++ b/backend/internal/service/gemini_messages_compat_service_test.go @@ -0,0 +1,128 @@ +package service + +import ( + "testing" +) + +// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换 +func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { + tests := []struct { + name string + tools any + expectedLen int + description string + }{ + { + name: "Standard tools", + tools: []any{ + map[string]any{ + "name": "get_weather", + "description": "Get weather info", + "input_schema": map[string]any{"type": "object"}, + }, + }, + expectedLen: 1, + description: "标准工具格式应该正常转换", + }, + { + name: "Custom type tool (MCP format)", + tools: []any{ + map[string]any{ + "type": "custom", + "name": "mcp_tool", + "custom": map[string]any{ + "description": "MCP tool description", + "input_schema": map[string]any{"type": "object"}, + }, + }, + }, + expectedLen: 1, + description: "Custom类型工具应该从custom字段读取", + }, + { + name: "Mixed standard and custom tools", + tools: []any{ + map[string]any{ + "name": "standard_tool", + "description": "Standard", + "input_schema": map[string]any{"type": "object"}, + }, + map[string]any{ + "type": "custom", + "name": "custom_tool", + "custom": map[string]any{ + "description": "Custom", + "input_schema": map[string]any{"type": "object"}, + }, + }, + }, + expectedLen: 1, + description: "混合工具应该都能正确转换", + }, + { + name: "Custom tool without custom field", + tools: []any{ + map[string]any{ + "type": "custom", + "name": "invalid_custom", + // 缺少 custom 字段 + }, + }, + expectedLen: 0, // 应该被跳过 + description: "缺少custom字段的custom工具应该被跳过", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertClaudeToolsToGeminiTools(tt.tools) + + if tt.expectedLen == 0 { + if result != nil { + t.Errorf("%s: expected nil result, got %v", tt.description, result) + } + return + } + + if result == nil { + t.Fatalf("%s: expected non-nil result", tt.description) + } + + if len(result) != 1 { + t.Errorf("%s: expected 1 tool declaration, got %d", tt.description, len(result)) + return + } + + toolDecl, ok := result[0].(map[string]any) + if !ok { + t.Fatalf("%s: result[0] is not map[string]any", tt.description) + } + + funcDecls, ok := toolDecl["functionDeclarations"].([]any) + if !ok { + t.Fatalf("%s: functionDeclarations is not []any", tt.description) + } + + toolsArr, _ := tt.tools.([]any) + expectedFuncCount := 0 + for _, tool := range toolsArr { + toolMap, _ := tool.(map[string]any) + if toolMap["name"] != "" { + // 检查是否为有效的custom工具 + if toolMap["type"] == "custom" { + if toolMap["custom"] != nil { + expectedFuncCount++ + } + } else { + expectedFuncCount++ + } + } + } + + if len(funcDecls) != expectedFuncCount { + t.Errorf("%s: expected %d function declarations, got %d", + tt.description, expectedFuncCount, len(funcDecls)) + } + }) + } +} diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index e4bda5f8..221bd0f2 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "regexp" "strconv" "strings" "time" @@ -163,6 +164,45 @@ type GeminiTokenInfo struct { Scope string `json:"scope,omitempty"` ProjectID string `json:"project_id,omitempty"` OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio" + TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA +} + +// validateTierID validates tier_id format and length +func validateTierID(tierID string) error { + if tierID == "" { + return nil // Empty is allowed + } + if len(tierID) > 64 { + return fmt.Errorf("tier_id exceeds maximum length of 64 characters") + } + // Allow alphanumeric, underscore, hyphen, and slash (for tier paths) + if !regexp.MustCompile(`^[a-zA-Z0-9_/-]+$`).MatchString(tierID) { + return fmt.Errorf("tier_id contains invalid characters") + } + return nil +} + +// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response +// Prioritizes IsDefault tier, falls back to first non-empty tier +func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string { + tierID := "LEGACY" + // First pass: look for default tier + for _, tier := range allowedTiers { + if tier.IsDefault && strings.TrimSpace(tier.ID) != "" { + tierID = strings.TrimSpace(tier.ID) + break + } + } + // Second pass: if still LEGACY, take first non-empty tier + if tierID == "LEGACY" { + for _, tier := range allowedTiers { + if strings.TrimSpace(tier.ID) != "" { + tierID = strings.TrimSpace(tier.ID) + break + } + } + } + return tierID } func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { @@ -223,13 +263,14 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 projectID := sessionProjectID + var tierID string // 对于 code_assist 模式,project_id 是必需的 // 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API) if oauthType == "code_assist" { if projectID == "" { var err error - projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) + projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) if err != nil { // 记录警告但不阻断流程,允许后续补充 project_id fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err) @@ -248,6 +289,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch ExpiresAt: expiresAt, Scope: tokenResp.Scope, ProjectID: projectID, + TierID: tierID, OAuthType: oauthType, }, nil } @@ -357,7 +399,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A // For Code Assist, project_id is required. Auto-detect if missing. // For AI Studio OAuth, project_id is optional and should not block refresh. if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" { - projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) + projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) if err != nil { return nil, fmt.Errorf("failed to auto-detect project_id: %w", err) } @@ -366,6 +408,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A return nil, fmt.Errorf("failed to auto-detect project_id: empty result") } tokenInfo.ProjectID = projectID + tokenInfo.TierID = tierID } return tokenInfo, nil @@ -388,6 +431,13 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) if tokenInfo.ProjectID != "" { creds["project_id"] = tokenInfo.ProjectID } + if tokenInfo.TierID != "" { + // Validate tier_id before storing + if err := validateTierID(tokenInfo.TierID); err == nil { + creds["tier_id"] = tokenInfo.TierID + } + // Silently skip invalid tier_id (don't block account creation) + } if tokenInfo.OAuthType != "" { creds["oauth_type"] = tokenInfo.OAuthType } @@ -398,34 +448,26 @@ func (s *GeminiOAuthService) Stop() { s.sessionStore.Stop() } -func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) { +func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, string, error) { if s.codeAssist == nil { - return "", errors.New("code assist client not configured") + return "", "", errors.New("code assist client not configured") } loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil) + + // Extract tierID from response (works whether CloudAICompanionProject is set or not) + tierID := "LEGACY" + if loadResp != nil { + tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers) + } + + // If LoadCodeAssist returned a project, use it if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" { - return strings.TrimSpace(loadResp.CloudAICompanionProject), nil + return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil } // Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID. - tierID := "LEGACY" - if loadResp != nil { - for _, tier := range loadResp.AllowedTiers { - if tier.IsDefault && strings.TrimSpace(tier.ID) != "" { - tierID = strings.TrimSpace(tier.ID) - break - } - } - if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" { - for _, tier := range loadResp.AllowedTiers { - if strings.TrimSpace(tier.ID) != "" { - tierID = strings.TrimSpace(tier.ID) - break - } - } - } - } + // (tierID already extracted above, reuse it) req := &geminicli.OnboardUserRequest{ TierID: tierID, @@ -443,39 +485,39 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr // If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects. fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - return strings.TrimSpace(fallback), nil + return strings.TrimSpace(fallback), tierID, nil } - return "", err + return "", "", err } if resp.Done { if resp.Response != nil && resp.Response.CloudAICompanionProject != nil { switch v := resp.Response.CloudAICompanionProject.(type) { case string: - return strings.TrimSpace(v), nil + return strings.TrimSpace(v), tierID, nil case map[string]any: if id, ok := v["id"].(string); ok { - return strings.TrimSpace(id), nil + return strings.TrimSpace(id), tierID, nil } } } fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - return strings.TrimSpace(fallback), nil + return strings.TrimSpace(fallback), tierID, nil } - return "", errors.New("onboardUser completed but no project_id returned") + return "", "", errors.New("onboardUser completed but no project_id returned") } time.Sleep(2 * time.Second) } fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - return strings.TrimSpace(fallback), nil + return strings.TrimSpace(fallback), tierID, nil } if loadErr != nil { - return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) + return "", "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) } - return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) + return "", "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) } type googleCloudProject struct { diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index 2195ec55..5f369de5 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -112,7 +112,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou } } - detected, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL) + detected, tierID, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL) if err != nil { log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err) return accessToken, nil @@ -123,6 +123,9 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou account.Credentials = make(map[string]any) } account.Credentials["project_id"] = detected + if tierID != "" { + account.Credentials["tier_id"] = tierID + } _ = p.accountRepo.Update(ctx, account) } } diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 5bd85d7d..5478d151 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -122,6 +122,21 @@ pricing: # Hash check interval in minutes hash_check_interval_minutes: 10 +# ============================================================================= +# Gateway (Optional) +# ============================================================================= +gateway: + # Wait time (in seconds) for upstream response headers (streaming body not affected) + response_header_timeout: 300 + # Log upstream error response body summary (safe/truncated; does not log request content) + log_upstream_error_body: false + # Max bytes to log from upstream error body + log_upstream_error_body_max_bytes: 2048 + # Auto inject anthropic-beta for API-key accounts when needed (default off) + inject_beta_for_apikey: false + # Allow failover on selected 400 errors (default off) + failover_on_400: false + # ============================================================================= # Gemini OAuth (Required for Gemini accounts) # ============================================================================= diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 6563ee0c..1770a985 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -952,6 +952,7 @@ "integrity": "sha512-N2clP5pJhB2YnZJ3PIHFk5RkygRX5WO/5f0WC08tp0wd+sv0rsJk3MqWn3CbNmT2J505a5336jaQj4ph1AdMug==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "undici-types": "~6.21.0" } @@ -1367,6 +1368,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "baseline-browser-mapping": "^2.9.0", "caniuse-lite": "^1.0.30001759", @@ -1443,6 +1445,7 @@ "resolved": "https://registry.npmmirror.com/chart.js/-/chart.js-4.5.1.tgz", "integrity": "sha512-GIjfiT9dbmHRiYi6Nl2yFCq7kkwdkp1W/lp2J99rX0yo9tgJGn3lKQATztIjb5tVtevcBtIdICNWqlq5+E8/Pw==", "license": "MIT", + "peer": true, "dependencies": { "@kurkle/color": "^0.3.0" }, @@ -2040,6 +2043,7 @@ "integrity": "sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==", "dev": true, "license": "MIT", + "peer": true, "bin": { "jiti": "bin/jiti.js" } @@ -2348,6 +2352,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1", @@ -2821,6 +2826,7 @@ "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", + "peer": true, "engines": { "node": ">=12" }, @@ -2854,6 +2860,7 @@ "integrity": "sha512-hjcS1mhfuyi4WW8IWtjP7brDrG2cuDZukyrYrSauoXGNgx0S7zceP07adYkJycEr56BOUTNPzbInooiN3fn1qw==", "devOptional": true, "license": "Apache-2.0", + "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -2926,6 +2933,7 @@ "integrity": "sha512-o5a9xKjbtuhY6Bi5S3+HvbRERmouabWbyUcpXXUA1u+GNUKoROi9byOJ8M0nHbHYHkYICiMlqxkg1KkYmm25Sw==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "esbuild": "^0.21.3", "postcss": "^8.4.43", @@ -3097,6 +3105,7 @@ "resolved": "https://registry.npmmirror.com/vue/-/vue-3.5.25.tgz", "integrity": "sha512-YLVdgv2K13WJ6n+kD5owehKtEXwdwXuj2TTyJMsO7pSeKw2bfRNZGjhB7YzrpbMYj5b5QsUebHpOqR3R3ziy/g==", "license": "MIT", + "peer": true, "dependencies": { "@vue/compiler-dom": "3.5.25", "@vue/compiler-sfc": "3.5.25", @@ -3190,6 +3199,7 @@ "integrity": "sha512-P7OP77b2h/Pmk+lZdJ0YWs+5tJ6J2+uOQPo7tlBnY44QqQSPYvS0qVT4wqDJgwrZaLe47etJLLQRFia71GYITw==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@volar/typescript": "2.4.15", "@vue/language-core": "2.2.12" diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue index c1ca08fa..914678a5 100644 --- a/frontend/src/components/account/AccountStatusIndicator.vue +++ b/frontend/src/components/account/AccountStatusIndicator.vue @@ -83,6 +83,14 @@ > + + + + {{ tierDisplay }} + @@ -140,4 +148,23 @@ const statusText = computed(() => { return props.account.status }) +// Computed: tier display +const tierDisplay = computed(() => { + const credentials = props.account.credentials as Record | undefined + const tierId = credentials?.tier_id + if (!tierId || tierId === 'unknown') return null + + const tierMap: Record = { + 'free': 'Free', + 'payg': 'Pay-as-you-go', + 'pay-as-you-go': 'Pay-as-you-go', + 'enterprise': 'Enterprise', + 'LEGACY': 'Legacy', + 'PRO': 'Pro', + 'ULTRA': 'Ultra' + } + + return tierMap[tierId] || tierId +}) +