From bece1b52012333dd6cf3aa8cea24b019563d1233 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sat, 24 Jan 2026 20:01:03 +0800 Subject: [PATCH 001/363] =?UTF-8?q?perf(=E6=9C=8D=E5=8A=A1=E7=AB=AF):=20?= =?UTF-8?q?=E5=90=AF=E7=94=A8=20h2c=20=E5=B9=B6=E4=BF=9D=E7=95=99=20HTTP/1?= =?UTF-8?q?.1=20=E5=9B=9E=E9=80=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README_CN.md | 23 +++++++++++++++++++++++ backend/cmd/server/main.go | 11 ++++++++++- backend/internal/server/http.go | 5 ++++- deploy/Caddyfile | 1 + 4 files changed, 38 insertions(+), 2 deletions(-) diff --git a/README_CN.md b/README_CN.md index 41d399d5..8129c3b2 100644 --- a/README_CN.md +++ b/README_CN.md @@ -358,6 +358,29 @@ Invalid base URL: invalid url scheme: http ./sub2api ``` +#### HTTP/2 (h2c) 与 HTTP/1.1 回退 + +后端明文端口默认支持 h2c,并保留 HTTP/1.1 回退用于 WebSocket 与旧客户端。浏览器通常不支持 h2c,性能收益主要在反向代理或内网链路。 + +**反向代理示例(Caddy):** + +```caddyfile +transport http { + versions h2c h1 +} +``` + +**验证:** + +```bash +# h2c prior knowledge +curl --http2-prior-knowledge -I http://localhost:8080/health +# HTTP/1.1 回退 +curl --http1.1 -I http://localhost:8080/health +# WebSocket 回退验证(需管理员 token) +websocat -H="Sec-WebSocket-Protocol: sub2api-admin, jwt." 
ws://localhost:8080/api/v1/admin/ops/ws/qps +``` + #### 开发模式 ```bash diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index f8a7d313..65b8c659 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -24,6 +24,8 @@ import ( "github.com/Wei-Shaw/sub2api/internal/web" "github.com/gin-gonic/gin" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) //go:embed VERSION @@ -122,7 +124,14 @@ func runSetupServer() { log.Printf("Setup wizard available at http://%s", addr) log.Println("Complete the setup wizard to configure Sub2API") - if err := r.Run(addr); err != nil { + server := &http.Server{ + Addr: addr, + Handler: h2c.NewHandler(r, &http2.Server{}), + ReadHeaderTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + } + + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("Failed to start setup server: %v", err) } } diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index 52d5c926..f3e15006 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -14,6 +14,8 @@ import ( "github.com/gin-gonic/gin" "github.com/google/wire" "github.com/redis/go-redis/v9" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) // ProviderSet 提供服务器层的依赖 @@ -56,9 +58,10 @@ func ProvideRouter( // ProvideHTTPServer 提供 HTTP 服务器 func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server { + handler := h2c.NewHandler(router, &http2.Server{}) return &http.Server{ Addr: cfg.Server.Address(), - Handler: router, + Handler: handler, // ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击 ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second, // IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源 diff --git a/deploy/Caddyfile b/deploy/Caddyfile index e5636213..d4144057 100644 --- a/deploy/Caddyfile +++ b/deploy/Caddyfile @@ -87,6 +87,7 @@ example.com { # 连接池优化 transport http { + versions h2c h1 keepalive 120s keepalive_idle_conns 256 
read_buffer 16KB From 13262a569845014ba23390577f4ca19b716db448 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Thu, 29 Jan 2026 16:18:38 +0800 Subject: [PATCH 002/363] =?UTF-8?q?feat(sora):=20=E6=96=B0=E5=A2=9E=20Sora?= =?UTF-8?q?=20=E5=B9=B3=E5=8F=B0=E6=94=AF=E6=8C=81=E5=B9=B6=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E9=AB=98=E5=8D=B1=E5=AE=89=E5=85=A8=E5=92=8C=E6=80=A7?= =?UTF-8?q?=E8=83=BD=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增功能: - 新增 Sora 账号管理和 OAuth 认证 - 新增 Sora 视频/图片生成 API 网关 - 新增 Sora 任务调度和缓存机制 - 新增 Sora 使用统计和计费支持 - 前端增加 Sora 平台配置界面 安全修复(代码审核): - [SEC-001] 限制媒体下载响应体大小(图片 20MB、视频 200MB),防止 DoS 攻击 - [SEC-002] 限制 SDK API 响应大小(1MB),防止内存耗尽 - [SEC-003] 修复 SSRF 风险,添加 URL 验证并强制使用代理配置 BUG 修复(代码审核): - [BUG-001] 修复 for 循环内 defer 累积导致的资源泄漏 - [BUG-002] 修复图片并发槽位获取失败时已持有锁未释放的永久泄漏 性能优化(代码审核): - [PERF-001] 添加 Sentinel Token 缓存(3 分钟有效期),减少 PoW 计算开销 技术细节: - 使用 io.LimitReader 限制所有外部输入的大小 - 添加 urlvalidator 验证防止 SSRF 攻击 - 使用 sync.Map 实现线程安全的包级缓存 - 优化并发槽位管理,添加 releaseAll 模式防止泄漏 影响范围: - 后端:新增 Sora 相关数据模型、服务、网关和管理接口 - 前端:新增 Sora 平台配置、账号管理和监控界面 - 配置:新增 Sora 相关配置项和环境变量 Co-Authored-By: Claude Sonnet 4.5 --- backend/cmd/server/wire.go | 14 + backend/cmd/server/wire_gen.go | 30 +- backend/ent/client.go | 586 +- backend/ent/ent.go | 8 + backend/ent/hook/hook.go | 48 + backend/ent/intercept/intercept.go | 120 + backend/ent/migrate/schema.go | 182 + backend/ent/mutation.go | 5306 +++++++++++++++++ backend/ent/predicate/predicate.go | 12 + backend/ent/runtime/runtime.go | 148 + backend/ent/schema/sora_account.go | 115 + backend/ent/schema/sora_cache_file.go | 60 + backend/ent/schema/sora_task.go | 70 + backend/ent/schema/sora_usage_stat.go | 71 + backend/ent/soraaccount.go | 422 ++ backend/ent/soraaccount/soraaccount.go | 278 + backend/ent/soraaccount/where.go | 1500 +++++ backend/ent/soraaccount_create.go | 2367 ++++++++ backend/ent/soraaccount_delete.go | 88 + backend/ent/soraaccount_query.go | 564 ++ 
backend/ent/soraaccount_update.go | 1402 +++++ backend/ent/soracachefile.go | 197 + backend/ent/soracachefile/soracachefile.go | 124 + backend/ent/soracachefile/where.go | 610 ++ backend/ent/soracachefile_create.go | 1004 ++++ backend/ent/soracachefile_delete.go | 88 + backend/ent/soracachefile_query.go | 564 ++ backend/ent/soracachefile_update.go | 596 ++ backend/ent/soratask.go | 227 + backend/ent/soratask/soratask.go | 146 + backend/ent/soratask/where.go | 745 +++ backend/ent/soratask_create.go | 1189 ++++ backend/ent/soratask_delete.go | 88 + backend/ent/soratask_query.go | 564 ++ backend/ent/soratask_update.go | 710 +++ backend/ent/sorausagestat.go | 231 + backend/ent/sorausagestat/sorausagestat.go | 160 + backend/ent/sorausagestat/where.go | 630 ++ backend/ent/sorausagestat_create.go | 1334 +++++ backend/ent/sorausagestat_delete.go | 88 + backend/ent/sorausagestat_query.go | 564 ++ backend/ent/sorausagestat_update.go | 748 +++ backend/ent/tx.go | 12 + backend/internal/config/config.go | 51 + .../internal/handler/admin/group_handler.go | 4 +- .../internal/handler/admin/setting_handler.go | 228 +- .../handler/admin/sora_account_handler.go | 355 ++ backend/internal/handler/dto/mappers.go | 66 + backend/internal/handler/dto/settings.go | 19 + backend/internal/handler/dto/types.go | 50 + backend/internal/handler/gateway_handler.go | 8 + backend/internal/handler/handler.go | 2 + backend/internal/handler/ops_error_logger.go | 2 + .../internal/handler/sora_gateway_handler.go | 364 ++ backend/internal/handler/wire.go | 6 + backend/internal/pkg/sora/character.go | 148 + backend/internal/pkg/sora/client.go | 612 ++ backend/internal/pkg/sora/models.go | 263 + backend/internal/pkg/sora/prompt.go | 63 + backend/internal/pkg/uuidv7/uuidv7.go | 31 + backend/internal/repository/sora_repo.go | 498 ++ backend/internal/repository/wire.go | 4 + backend/internal/server/router.go | 19 + backend/internal/server/routes/admin.go | 14 + backend/internal/server/routes/gateway.go | 1 + 
backend/internal/service/domain_constants.go | 23 + .../service/scheduler_snapshot_service.go | 4 +- backend/internal/service/setting_service.go | 230 +- backend/internal/service/settings_view.go | 19 + .../service/sora_cache_cleanup_service.go | 156 + .../internal/service/sora_cache_service.go | 246 + backend/internal/service/sora_cache_utils.go | 28 + .../internal/service/sora_gateway_service.go | 853 +++ backend/internal/service/sora_repository.go | 113 + .../service/sora_token_refresh_service.go | 313 + backend/internal/service/wire.go | 28 + backend/migrations/044_add_sora_tables.sql | 94 + config.yaml | 60 + deploy/.env.example | 22 + deploy/config.example.yaml | 60 + frontend/src/api/admin/settings.ts | 36 + .../components/account/CreateAccountModal.vue | 39 +- .../components/account/EditAccountModal.vue | 35 +- .../account/OAuthAuthorizationFlow.vue | 2 +- .../admin/account/AccountTableFilters.vue | 2 +- frontend/src/components/common/GroupBadge.vue | 8 + .../src/components/common/PlatformIcon.vue | 4 + .../components/common/PlatformTypeBadge.vue | 7 + frontend/src/components/keys/UseKeyModal.vue | 9 +- frontend/src/composables/useModelWhitelist.ts | 35 + frontend/src/i18n/locales/en.ts | 48 + frontend/src/i18n/locales/zh.ts | 48 + frontend/src/types/index.ts | 4 +- frontend/src/views/admin/GroupsView.vue | 2 + frontend/src/views/admin/SettingsView.vue | 261 +- .../ops/components/OpsDashboardHeader.vue | 1 + frontend/src/views/user/KeysView.vue | 1 + 97 files changed, 29541 insertions(+), 68 deletions(-) create mode 100644 backend/ent/schema/sora_account.go create mode 100644 backend/ent/schema/sora_cache_file.go create mode 100644 backend/ent/schema/sora_task.go create mode 100644 backend/ent/schema/sora_usage_stat.go create mode 100644 backend/ent/soraaccount.go create mode 100644 backend/ent/soraaccount/soraaccount.go create mode 100644 backend/ent/soraaccount/where.go create mode 100644 backend/ent/soraaccount_create.go create mode 100644 
backend/ent/soraaccount_delete.go create mode 100644 backend/ent/soraaccount_query.go create mode 100644 backend/ent/soraaccount_update.go create mode 100644 backend/ent/soracachefile.go create mode 100644 backend/ent/soracachefile/soracachefile.go create mode 100644 backend/ent/soracachefile/where.go create mode 100644 backend/ent/soracachefile_create.go create mode 100644 backend/ent/soracachefile_delete.go create mode 100644 backend/ent/soracachefile_query.go create mode 100644 backend/ent/soracachefile_update.go create mode 100644 backend/ent/soratask.go create mode 100644 backend/ent/soratask/soratask.go create mode 100644 backend/ent/soratask/where.go create mode 100644 backend/ent/soratask_create.go create mode 100644 backend/ent/soratask_delete.go create mode 100644 backend/ent/soratask_query.go create mode 100644 backend/ent/soratask_update.go create mode 100644 backend/ent/sorausagestat.go create mode 100644 backend/ent/sorausagestat/sorausagestat.go create mode 100644 backend/ent/sorausagestat/where.go create mode 100644 backend/ent/sorausagestat_create.go create mode 100644 backend/ent/sorausagestat_delete.go create mode 100644 backend/ent/sorausagestat_query.go create mode 100644 backend/ent/sorausagestat_update.go create mode 100644 backend/internal/handler/admin/sora_account_handler.go create mode 100644 backend/internal/handler/sora_gateway_handler.go create mode 100644 backend/internal/pkg/sora/character.go create mode 100644 backend/internal/pkg/sora/client.go create mode 100644 backend/internal/pkg/sora/models.go create mode 100644 backend/internal/pkg/sora/prompt.go create mode 100644 backend/internal/pkg/uuidv7/uuidv7.go create mode 100644 backend/internal/repository/sora_repo.go create mode 100644 backend/internal/service/sora_cache_cleanup_service.go create mode 100644 backend/internal/service/sora_cache_service.go create mode 100644 backend/internal/service/sora_cache_utils.go create mode 100644 
backend/internal/service/sora_gateway_service.go create mode 100644 backend/internal/service/sora_repository.go create mode 100644 backend/internal/service/sora_token_refresh_service.go create mode 100644 backend/migrations/044_add_sora_tables.sql diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 5ef04a66..d9d45793 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -69,6 +69,8 @@ func provideCleanup( opsScheduledReport *service.OpsScheduledReportService, schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, + soraTokenRefresh *service.SoraTokenRefreshService, + soraCacheCleanup *service.SoraCacheCleanupService, accountExpiry *service.AccountExpiryService, usageCleanup *service.UsageCleanupService, pricing *service.PricingService, @@ -134,6 +136,18 @@ func provideCleanup( tokenRefresh.Stop() return nil }}, + {"SoraTokenRefreshService", func() error { + if soraTokenRefresh != nil { + soraTokenRefresh.Stop() + } + return nil + }}, + {"SoraCacheCleanupService", func() error { + if soraCacheCleanup != nil { + soraCacheCleanup.Stop() + } + return nil + }}, {"AccountExpiryService", func() error { accountExpiry.Stop() return nil diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 7b22a31e..befa93d7 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -129,6 +129,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { proxyHandler := admin.NewProxyHandler(adminService) adminRedeemHandler := admin.NewRedeemHandler(adminService) promoHandler := admin.NewPromoHandler(promoService) + soraAccountRepository := repository.NewSoraAccountRepository(client) + soraUsageStatRepository := repository.NewSoraUsageStatRepository(client, db) + soraAccountHandler := admin.NewSoraAccountHandler(adminService, soraAccountRepository, soraUsageStatRepository) opsRepository := repository.NewOpsRepository(db) 
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) @@ -161,11 +164,16 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, soraAccountHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig) + soraTaskRepository := repository.NewSoraTaskRepository(client) + soraCacheFileRepository := repository.NewSoraCacheFileRepository(client) + soraCacheService := service.NewSoraCacheService(configConfig, soraCacheFileRepository, settingService, accountRepository, httpUpstream) + soraGatewayService := 
service.NewSoraGatewayService(accountRepository, soraAccountRepository, soraUsageStatRepository, soraTaskRepository, soraCacheService, settingService, concurrencyService, configConfig, httpUpstream) + soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) @@ -177,8 +185,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig) + soraTokenRefreshService := service.ProvideSoraTokenRefreshService(accountRepository, soraAccountRepository, settingService, httpUpstream, configConfig) + soraCacheCleanupService := service.ProvideSoraCacheCleanupService(soraCacheFileRepository, settingService, configConfig) accountExpiryService := 
service.ProvideAccountExpiryService(accountRepository) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, soraTokenRefreshService, soraCacheCleanupService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) application := &Application{ Server: httpServer, Cleanup: v, @@ -210,6 +220,8 @@ func provideCleanup( opsScheduledReport *service.OpsScheduledReportService, schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, + soraTokenRefresh *service.SoraTokenRefreshService, + soraCacheCleanup *service.SoraCacheCleanupService, accountExpiry *service.AccountExpiryService, usageCleanup *service.UsageCleanupService, pricing *service.PricingService, @@ -274,6 +286,18 @@ func provideCleanup( tokenRefresh.Stop() return nil }}, + {"SoraTokenRefreshService", func() error { + if soraTokenRefresh != nil { + soraTokenRefresh.Stop() + } + return nil + }}, + {"SoraCacheCleanupService", func() error { + if soraCacheCleanup != nil { + soraCacheCleanup.Stop() + } + return nil + }}, {"AccountExpiryService", func() error { accountExpiry.Stop() return nil diff --git a/backend/ent/client.go b/backend/ent/client.go index f6c13e84..58302850 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -24,6 +24,10 @@ import ( "github.com/Wei-Shaw/sub2api/ent/proxy" 
"github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/soraaccount" + "github.com/Wei-Shaw/sub2api/ent/soracachefile" + "github.com/Wei-Shaw/sub2api/ent/soratask" + "github.com/Wei-Shaw/sub2api/ent/sorausagestat" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" @@ -58,6 +62,14 @@ type Client struct { RedeemCode *RedeemCodeClient // Setting is the client for interacting with the Setting builders. Setting *SettingClient + // SoraAccount is the client for interacting with the SoraAccount builders. + SoraAccount *SoraAccountClient + // SoraCacheFile is the client for interacting with the SoraCacheFile builders. + SoraCacheFile *SoraCacheFileClient + // SoraTask is the client for interacting with the SoraTask builders. + SoraTask *SoraTaskClient + // SoraUsageStat is the client for interacting with the SoraUsageStat builders. + SoraUsageStat *SoraUsageStatClient // UsageCleanupTask is the client for interacting with the UsageCleanupTask builders. UsageCleanupTask *UsageCleanupTaskClient // UsageLog is the client for interacting with the UsageLog builders. 
@@ -92,6 +104,10 @@ func (c *Client) init() { c.Proxy = NewProxyClient(c.config) c.RedeemCode = NewRedeemCodeClient(c.config) c.Setting = NewSettingClient(c.config) + c.SoraAccount = NewSoraAccountClient(c.config) + c.SoraCacheFile = NewSoraCacheFileClient(c.config) + c.SoraTask = NewSoraTaskClient(c.config) + c.SoraUsageStat = NewSoraUsageStatClient(c.config) c.UsageCleanupTask = NewUsageCleanupTaskClient(c.config) c.UsageLog = NewUsageLogClient(c.config) c.User = NewUserClient(c.config) @@ -200,6 +216,10 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { Proxy: NewProxyClient(cfg), RedeemCode: NewRedeemCodeClient(cfg), Setting: NewSettingClient(cfg), + SoraAccount: NewSoraAccountClient(cfg), + SoraCacheFile: NewSoraCacheFileClient(cfg), + SoraTask: NewSoraTaskClient(cfg), + SoraUsageStat: NewSoraUsageStatClient(cfg), UsageCleanupTask: NewUsageCleanupTaskClient(cfg), UsageLog: NewUsageLogClient(cfg), User: NewUserClient(cfg), @@ -235,6 +255,10 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) Proxy: NewProxyClient(cfg), RedeemCode: NewRedeemCodeClient(cfg), Setting: NewSettingClient(cfg), + SoraAccount: NewSoraAccountClient(cfg), + SoraCacheFile: NewSoraCacheFileClient(cfg), + SoraTask: NewSoraTaskClient(cfg), + SoraUsageStat: NewSoraUsageStatClient(cfg), UsageCleanupTask: NewUsageCleanupTaskClient(cfg), UsageLog: NewUsageLogClient(cfg), User: NewUserClient(cfg), @@ -272,9 +296,9 @@ func (c *Client) Close() error { func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage, - c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, - c.UserSubscription, + c.Proxy, c.RedeemCode, c.Setting, c.SoraAccount, c.SoraCacheFile, c.SoraTask, + c.SoraUsageStat, c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, + c.UserAttributeDefinition, 
c.UserAttributeValue, c.UserSubscription, } { n.Use(hooks...) } @@ -285,9 +309,9 @@ func (c *Client) Use(hooks ...Hook) { func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage, - c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, - c.UserSubscription, + c.Proxy, c.RedeemCode, c.Setting, c.SoraAccount, c.SoraCacheFile, c.SoraTask, + c.SoraUsageStat, c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, + c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Intercept(interceptors...) } @@ -314,6 +338,14 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.RedeemCode.mutate(ctx, m) case *SettingMutation: return c.Setting.mutate(ctx, m) + case *SoraAccountMutation: + return c.SoraAccount.mutate(ctx, m) + case *SoraCacheFileMutation: + return c.SoraCacheFile.mutate(ctx, m) + case *SoraTaskMutation: + return c.SoraTask.mutate(ctx, m) + case *SoraUsageStatMutation: + return c.SoraUsageStat.mutate(ctx, m) case *UsageCleanupTaskMutation: return c.UsageCleanupTask.mutate(ctx, m) case *UsageLogMutation: @@ -1857,6 +1889,538 @@ func (c *SettingClient) mutate(ctx context.Context, m *SettingMutation) (Value, } } +// SoraAccountClient is a client for the SoraAccount schema. +type SoraAccountClient struct { + config +} + +// NewSoraAccountClient returns a client for the SoraAccount from the given config. +func NewSoraAccountClient(c config) *SoraAccountClient { + return &SoraAccountClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `soraaccount.Hooks(f(g(h())))`. +func (c *SoraAccountClient) Use(hooks ...Hook) { + c.hooks.SoraAccount = append(c.hooks.SoraAccount, hooks...) 
+} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `soraaccount.Intercept(f(g(h())))`. +func (c *SoraAccountClient) Intercept(interceptors ...Interceptor) { + c.inters.SoraAccount = append(c.inters.SoraAccount, interceptors...) +} + +// Create returns a builder for creating a SoraAccount entity. +func (c *SoraAccountClient) Create() *SoraAccountCreate { + mutation := newSoraAccountMutation(c.config, OpCreate) + return &SoraAccountCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of SoraAccount entities. +func (c *SoraAccountClient) CreateBulk(builders ...*SoraAccountCreate) *SoraAccountCreateBulk { + return &SoraAccountCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *SoraAccountClient) MapCreateBulk(slice any, setFunc func(*SoraAccountCreate, int)) *SoraAccountCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &SoraAccountCreateBulk{err: fmt.Errorf("calling to SoraAccountClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*SoraAccountCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &SoraAccountCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for SoraAccount. +func (c *SoraAccountClient) Update() *SoraAccountUpdate { + mutation := newSoraAccountMutation(c.config, OpUpdate) + return &SoraAccountUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. 
+func (c *SoraAccountClient) UpdateOne(_m *SoraAccount) *SoraAccountUpdateOne { + mutation := newSoraAccountMutation(c.config, OpUpdateOne, withSoraAccount(_m)) + return &SoraAccountUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *SoraAccountClient) UpdateOneID(id int64) *SoraAccountUpdateOne { + mutation := newSoraAccountMutation(c.config, OpUpdateOne, withSoraAccountID(id)) + return &SoraAccountUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for SoraAccount. +func (c *SoraAccountClient) Delete() *SoraAccountDelete { + mutation := newSoraAccountMutation(c.config, OpDelete) + return &SoraAccountDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *SoraAccountClient) DeleteOne(_m *SoraAccount) *SoraAccountDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *SoraAccountClient) DeleteOneID(id int64) *SoraAccountDeleteOne { + builder := c.Delete().Where(soraaccount.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &SoraAccountDeleteOne{builder} +} + +// Query returns a query builder for SoraAccount. +func (c *SoraAccountClient) Query() *SoraAccountQuery { + return &SoraAccountQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeSoraAccount}, + inters: c.Interceptors(), + } +} + +// Get returns a SoraAccount entity by its id. +func (c *SoraAccountClient) Get(ctx context.Context, id int64) (*SoraAccount, error) { + return c.Query().Where(soraaccount.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *SoraAccountClient) GetX(ctx context.Context, id int64) *SoraAccount { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. 
+func (c *SoraAccountClient) Hooks() []Hook { + return c.hooks.SoraAccount +} + +// Interceptors returns the client interceptors. +func (c *SoraAccountClient) Interceptors() []Interceptor { + return c.inters.SoraAccount +} + +func (c *SoraAccountClient) mutate(ctx context.Context, m *SoraAccountMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&SoraAccountCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&SoraAccountUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&SoraAccountUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&SoraAccountDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown SoraAccount mutation op: %q", m.Op()) + } +} + +// SoraCacheFileClient is a client for the SoraCacheFile schema. +type SoraCacheFileClient struct { + config +} + +// NewSoraCacheFileClient returns a client for the SoraCacheFile from the given config. +func NewSoraCacheFileClient(c config) *SoraCacheFileClient { + return &SoraCacheFileClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `soracachefile.Hooks(f(g(h())))`. +func (c *SoraCacheFileClient) Use(hooks ...Hook) { + c.hooks.SoraCacheFile = append(c.hooks.SoraCacheFile, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `soracachefile.Intercept(f(g(h())))`. +func (c *SoraCacheFileClient) Intercept(interceptors ...Interceptor) { + c.inters.SoraCacheFile = append(c.inters.SoraCacheFile, interceptors...) +} + +// Create returns a builder for creating a SoraCacheFile entity. 
+func (c *SoraCacheFileClient) Create() *SoraCacheFileCreate { + mutation := newSoraCacheFileMutation(c.config, OpCreate) + return &SoraCacheFileCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of SoraCacheFile entities. +func (c *SoraCacheFileClient) CreateBulk(builders ...*SoraCacheFileCreate) *SoraCacheFileCreateBulk { + return &SoraCacheFileCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *SoraCacheFileClient) MapCreateBulk(slice any, setFunc func(*SoraCacheFileCreate, int)) *SoraCacheFileCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &SoraCacheFileCreateBulk{err: fmt.Errorf("calling to SoraCacheFileClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*SoraCacheFileCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &SoraCacheFileCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for SoraCacheFile. +func (c *SoraCacheFileClient) Update() *SoraCacheFileUpdate { + mutation := newSoraCacheFileMutation(c.config, OpUpdate) + return &SoraCacheFileUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *SoraCacheFileClient) UpdateOne(_m *SoraCacheFile) *SoraCacheFileUpdateOne { + mutation := newSoraCacheFileMutation(c.config, OpUpdateOne, withSoraCacheFile(_m)) + return &SoraCacheFileUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. 
+func (c *SoraCacheFileClient) UpdateOneID(id int64) *SoraCacheFileUpdateOne { + mutation := newSoraCacheFileMutation(c.config, OpUpdateOne, withSoraCacheFileID(id)) + return &SoraCacheFileUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for SoraCacheFile. +func (c *SoraCacheFileClient) Delete() *SoraCacheFileDelete { + mutation := newSoraCacheFileMutation(c.config, OpDelete) + return &SoraCacheFileDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *SoraCacheFileClient) DeleteOne(_m *SoraCacheFile) *SoraCacheFileDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *SoraCacheFileClient) DeleteOneID(id int64) *SoraCacheFileDeleteOne { + builder := c.Delete().Where(soracachefile.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &SoraCacheFileDeleteOne{builder} +} + +// Query returns a query builder for SoraCacheFile. +func (c *SoraCacheFileClient) Query() *SoraCacheFileQuery { + return &SoraCacheFileQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeSoraCacheFile}, + inters: c.Interceptors(), + } +} + +// Get returns a SoraCacheFile entity by its id. +func (c *SoraCacheFileClient) Get(ctx context.Context, id int64) (*SoraCacheFile, error) { + return c.Query().Where(soracachefile.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *SoraCacheFileClient) GetX(ctx context.Context, id int64) *SoraCacheFile { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *SoraCacheFileClient) Hooks() []Hook { + return c.hooks.SoraCacheFile +} + +// Interceptors returns the client interceptors. 
+func (c *SoraCacheFileClient) Interceptors() []Interceptor { + return c.inters.SoraCacheFile +} + +func (c *SoraCacheFileClient) mutate(ctx context.Context, m *SoraCacheFileMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&SoraCacheFileCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&SoraCacheFileUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&SoraCacheFileUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&SoraCacheFileDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown SoraCacheFile mutation op: %q", m.Op()) + } +} + +// SoraTaskClient is a client for the SoraTask schema. +type SoraTaskClient struct { + config +} + +// NewSoraTaskClient returns a client for the SoraTask from the given config. +func NewSoraTaskClient(c config) *SoraTaskClient { + return &SoraTaskClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `soratask.Hooks(f(g(h())))`. +func (c *SoraTaskClient) Use(hooks ...Hook) { + c.hooks.SoraTask = append(c.hooks.SoraTask, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `soratask.Intercept(f(g(h())))`. +func (c *SoraTaskClient) Intercept(interceptors ...Interceptor) { + c.inters.SoraTask = append(c.inters.SoraTask, interceptors...) +} + +// Create returns a builder for creating a SoraTask entity. +func (c *SoraTaskClient) Create() *SoraTaskCreate { + mutation := newSoraTaskMutation(c.config, OpCreate) + return &SoraTaskCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of SoraTask entities. 
+func (c *SoraTaskClient) CreateBulk(builders ...*SoraTaskCreate) *SoraTaskCreateBulk { + return &SoraTaskCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *SoraTaskClient) MapCreateBulk(slice any, setFunc func(*SoraTaskCreate, int)) *SoraTaskCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &SoraTaskCreateBulk{err: fmt.Errorf("calling to SoraTaskClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*SoraTaskCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &SoraTaskCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for SoraTask. +func (c *SoraTaskClient) Update() *SoraTaskUpdate { + mutation := newSoraTaskMutation(c.config, OpUpdate) + return &SoraTaskUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *SoraTaskClient) UpdateOne(_m *SoraTask) *SoraTaskUpdateOne { + mutation := newSoraTaskMutation(c.config, OpUpdateOne, withSoraTask(_m)) + return &SoraTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *SoraTaskClient) UpdateOneID(id int64) *SoraTaskUpdateOne { + mutation := newSoraTaskMutation(c.config, OpUpdateOne, withSoraTaskID(id)) + return &SoraTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for SoraTask. +func (c *SoraTaskClient) Delete() *SoraTaskDelete { + mutation := newSoraTaskMutation(c.config, OpDelete) + return &SoraTaskDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. 
+func (c *SoraTaskClient) DeleteOne(_m *SoraTask) *SoraTaskDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *SoraTaskClient) DeleteOneID(id int64) *SoraTaskDeleteOne { + builder := c.Delete().Where(soratask.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &SoraTaskDeleteOne{builder} +} + +// Query returns a query builder for SoraTask. +func (c *SoraTaskClient) Query() *SoraTaskQuery { + return &SoraTaskQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeSoraTask}, + inters: c.Interceptors(), + } +} + +// Get returns a SoraTask entity by its id. +func (c *SoraTaskClient) Get(ctx context.Context, id int64) (*SoraTask, error) { + return c.Query().Where(soratask.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *SoraTaskClient) GetX(ctx context.Context, id int64) *SoraTask { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *SoraTaskClient) Hooks() []Hook { + return c.hooks.SoraTask +} + +// Interceptors returns the client interceptors. +func (c *SoraTaskClient) Interceptors() []Interceptor { + return c.inters.SoraTask +} + +func (c *SoraTaskClient) mutate(ctx context.Context, m *SoraTaskMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&SoraTaskCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&SoraTaskUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&SoraTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&SoraTaskDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown SoraTask mutation op: %q", m.Op()) + } +} + +// SoraUsageStatClient is a client for the SoraUsageStat schema. 
+type SoraUsageStatClient struct { + config +} + +// NewSoraUsageStatClient returns a client for the SoraUsageStat from the given config. +func NewSoraUsageStatClient(c config) *SoraUsageStatClient { + return &SoraUsageStatClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `sorausagestat.Hooks(f(g(h())))`. +func (c *SoraUsageStatClient) Use(hooks ...Hook) { + c.hooks.SoraUsageStat = append(c.hooks.SoraUsageStat, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `sorausagestat.Intercept(f(g(h())))`. +func (c *SoraUsageStatClient) Intercept(interceptors ...Interceptor) { + c.inters.SoraUsageStat = append(c.inters.SoraUsageStat, interceptors...) +} + +// Create returns a builder for creating a SoraUsageStat entity. +func (c *SoraUsageStatClient) Create() *SoraUsageStatCreate { + mutation := newSoraUsageStatMutation(c.config, OpCreate) + return &SoraUsageStatCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of SoraUsageStat entities. +func (c *SoraUsageStatClient) CreateBulk(builders ...*SoraUsageStatCreate) *SoraUsageStatCreateBulk { + return &SoraUsageStatCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. 
+func (c *SoraUsageStatClient) MapCreateBulk(slice any, setFunc func(*SoraUsageStatCreate, int)) *SoraUsageStatCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &SoraUsageStatCreateBulk{err: fmt.Errorf("calling to SoraUsageStatClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*SoraUsageStatCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &SoraUsageStatCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for SoraUsageStat. +func (c *SoraUsageStatClient) Update() *SoraUsageStatUpdate { + mutation := newSoraUsageStatMutation(c.config, OpUpdate) + return &SoraUsageStatUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *SoraUsageStatClient) UpdateOne(_m *SoraUsageStat) *SoraUsageStatUpdateOne { + mutation := newSoraUsageStatMutation(c.config, OpUpdateOne, withSoraUsageStat(_m)) + return &SoraUsageStatUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *SoraUsageStatClient) UpdateOneID(id int64) *SoraUsageStatUpdateOne { + mutation := newSoraUsageStatMutation(c.config, OpUpdateOne, withSoraUsageStatID(id)) + return &SoraUsageStatUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for SoraUsageStat. +func (c *SoraUsageStatClient) Delete() *SoraUsageStatDelete { + mutation := newSoraUsageStatMutation(c.config, OpDelete) + return &SoraUsageStatDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *SoraUsageStatClient) DeleteOne(_m *SoraUsageStat) *SoraUsageStatDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. 
+func (c *SoraUsageStatClient) DeleteOneID(id int64) *SoraUsageStatDeleteOne { + builder := c.Delete().Where(sorausagestat.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &SoraUsageStatDeleteOne{builder} +} + +// Query returns a query builder for SoraUsageStat. +func (c *SoraUsageStatClient) Query() *SoraUsageStatQuery { + return &SoraUsageStatQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeSoraUsageStat}, + inters: c.Interceptors(), + } +} + +// Get returns a SoraUsageStat entity by its id. +func (c *SoraUsageStatClient) Get(ctx context.Context, id int64) (*SoraUsageStat, error) { + return c.Query().Where(sorausagestat.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *SoraUsageStatClient) GetX(ctx context.Context, id int64) *SoraUsageStat { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *SoraUsageStatClient) Hooks() []Hook { + return c.hooks.SoraUsageStat +} + +// Interceptors returns the client interceptors. +func (c *SoraUsageStatClient) Interceptors() []Interceptor { + return c.inters.SoraUsageStat +} + +func (c *SoraUsageStatClient) mutate(ctx context.Context, m *SoraUsageStatMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&SoraUsageStatCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&SoraUsageStatUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&SoraUsageStatUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&SoraUsageStatDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown SoraUsageStat mutation op: %q", m.Op()) + } +} + // UsageCleanupTaskClient is a client for the UsageCleanupTask schema. 
type UsageCleanupTaskClient struct { config @@ -3117,13 +3681,15 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription type ( hooks struct { APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy, - RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook + RedeemCode, Setting, SoraAccount, SoraCacheFile, SoraTask, SoraUsageStat, + UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, + UserAttributeValue, UserSubscription []ent.Hook } inters struct { APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy, - RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor + RedeemCode, Setting, SoraAccount, SoraCacheFile, SoraTask, SoraUsageStat, + UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, + UserAttributeValue, UserSubscription []ent.Interceptor } ) diff --git a/backend/ent/ent.go b/backend/ent/ent.go index 4bcc2642..e0b0a6cf 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -21,6 +21,10 @@ import ( "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/soraaccount" + "github.com/Wei-Shaw/sub2api/ent/soracachefile" + "github.com/Wei-Shaw/sub2api/ent/soratask" + "github.com/Wei-Shaw/sub2api/ent/sorausagestat" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" @@ -97,6 +101,10 @@ func checkColumn(t, c string) error { proxy.Table: proxy.ValidColumn, redeemcode.Table: redeemcode.ValidColumn, setting.Table: setting.ValidColumn, + soraaccount.Table: soraaccount.ValidColumn, + soracachefile.Table: soracachefile.ValidColumn, + soratask.Table: soratask.ValidColumn, + 
sorausagestat.Table: sorausagestat.ValidColumn, usagecleanuptask.Table: usagecleanuptask.ValidColumn, usagelog.Table: usagelog.ValidColumn, user.Table: user.ValidColumn, diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index edd84f5e..311b2cdc 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -117,6 +117,54 @@ func (f SettingFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, err return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SettingMutation", m) } +// The SoraAccountFunc type is an adapter to allow the use of ordinary +// function as SoraAccount mutator. +type SoraAccountFunc func(context.Context, *ent.SoraAccountMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f SoraAccountFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.SoraAccountMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SoraAccountMutation", m) +} + +// The SoraCacheFileFunc type is an adapter to allow the use of ordinary +// function as SoraCacheFile mutator. +type SoraCacheFileFunc func(context.Context, *ent.SoraCacheFileMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f SoraCacheFileFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.SoraCacheFileMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SoraCacheFileMutation", m) +} + +// The SoraTaskFunc type is an adapter to allow the use of ordinary +// function as SoraTask mutator. +type SoraTaskFunc func(context.Context, *ent.SoraTaskMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f SoraTaskFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.SoraTaskMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. 
expect *ent.SoraTaskMutation", m) +} + +// The SoraUsageStatFunc type is an adapter to allow the use of ordinary +// function as SoraUsageStat mutator. +type SoraUsageStatFunc func(context.Context, *ent.SoraUsageStatMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f SoraUsageStatFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.SoraUsageStatMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SoraUsageStatMutation", m) +} + // The UsageCleanupTaskFunc type is an adapter to allow the use of ordinary // function as UsageCleanupTask mutator. type UsageCleanupTaskFunc func(context.Context, *ent.UsageCleanupTaskMutation) (ent.Value, error) diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index f18c0624..8181e70e 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -18,6 +18,10 @@ import ( "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/soraaccount" + "github.com/Wei-Shaw/sub2api/ent/soracachefile" + "github.com/Wei-Shaw/sub2api/ent/soratask" + "github.com/Wei-Shaw/sub2api/ent/sorausagestat" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" @@ -326,6 +330,114 @@ func (f TraverseSetting) Traverse(ctx context.Context, q ent.Query) error { return fmt.Errorf("unexpected query type %T. expect *ent.SettingQuery", q) } +// The SoraAccountFunc type is an adapter to allow the use of ordinary function as a Querier. +type SoraAccountFunc func(context.Context, *ent.SoraAccountQuery) (ent.Value, error) + +// Query calls f(ctx, q). 
+func (f SoraAccountFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.SoraAccountQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.SoraAccountQuery", q) +} + +// The TraverseSoraAccount type is an adapter to allow the use of ordinary function as Traverser. +type TraverseSoraAccount func(context.Context, *ent.SoraAccountQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseSoraAccount) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseSoraAccount) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.SoraAccountQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.SoraAccountQuery", q) +} + +// The SoraCacheFileFunc type is an adapter to allow the use of ordinary function as a Querier. +type SoraCacheFileFunc func(context.Context, *ent.SoraCacheFileQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f SoraCacheFileFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.SoraCacheFileQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.SoraCacheFileQuery", q) +} + +// The TraverseSoraCacheFile type is an adapter to allow the use of ordinary function as Traverser. +type TraverseSoraCacheFile func(context.Context, *ent.SoraCacheFileQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseSoraCacheFile) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseSoraCacheFile) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.SoraCacheFileQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. 
expect *ent.SoraCacheFileQuery", q) +} + +// The SoraTaskFunc type is an adapter to allow the use of ordinary function as a Querier. +type SoraTaskFunc func(context.Context, *ent.SoraTaskQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f SoraTaskFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.SoraTaskQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.SoraTaskQuery", q) +} + +// The TraverseSoraTask type is an adapter to allow the use of ordinary function as Traverser. +type TraverseSoraTask func(context.Context, *ent.SoraTaskQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseSoraTask) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseSoraTask) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.SoraTaskQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.SoraTaskQuery", q) +} + +// The SoraUsageStatFunc type is an adapter to allow the use of ordinary function as a Querier. +type SoraUsageStatFunc func(context.Context, *ent.SoraUsageStatQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f SoraUsageStatFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.SoraUsageStatQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.SoraUsageStatQuery", q) +} + +// The TraverseSoraUsageStat type is an adapter to allow the use of ordinary function as Traverser. +type TraverseSoraUsageStat func(context.Context, *ent.SoraUsageStatQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseSoraUsageStat) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). 
+func (f TraverseSoraUsageStat) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.SoraUsageStatQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.SoraUsageStatQuery", q) +} + // The UsageCleanupTaskFunc type is an adapter to allow the use of ordinary function as a Querier. type UsageCleanupTaskFunc func(context.Context, *ent.UsageCleanupTaskQuery) (ent.Value, error) @@ -536,6 +648,14 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.RedeemCodeQuery, predicate.RedeemCode, redeemcode.OrderOption]{typ: ent.TypeRedeemCode, tq: q}, nil case *ent.SettingQuery: return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil + case *ent.SoraAccountQuery: + return &query[*ent.SoraAccountQuery, predicate.SoraAccount, soraaccount.OrderOption]{typ: ent.TypeSoraAccount, tq: q}, nil + case *ent.SoraCacheFileQuery: + return &query[*ent.SoraCacheFileQuery, predicate.SoraCacheFile, soracachefile.OrderOption]{typ: ent.TypeSoraCacheFile, tq: q}, nil + case *ent.SoraTaskQuery: + return &query[*ent.SoraTaskQuery, predicate.SoraTask, soratask.OrderOption]{typ: ent.TypeSoraTask, tq: q}, nil + case *ent.SoraUsageStatQuery: + return &query[*ent.SoraUsageStatQuery, predicate.SoraUsageStat, sorausagestat.OrderOption]{typ: ent.TypeSoraUsageStat, tq: q}, nil case *ent.UsageCleanupTaskQuery: return &query[*ent.UsageCleanupTaskQuery, predicate.UsageCleanupTask, usagecleanuptask.OrderOption]{typ: ent.TypeUsageCleanupTask, tq: q}, nil case *ent.UsageLogQuery: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index d1f05186..a8a247b5 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -434,6 +434,172 @@ var ( Columns: SettingsColumns, PrimaryKey: []*schema.Column{SettingsColumns[0]}, } + // SoraAccountsColumns holds the columns for the "sora_accounts" table. 
+ SoraAccountsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "account_id", Type: field.TypeInt64}, + {Name: "access_token", Type: field.TypeString, Nullable: true}, + {Name: "session_token", Type: field.TypeString, Nullable: true}, + {Name: "refresh_token", Type: field.TypeString, Nullable: true}, + {Name: "client_id", Type: field.TypeString, Nullable: true}, + {Name: "email", Type: field.TypeString, Nullable: true}, + {Name: "username", Type: field.TypeString, Nullable: true}, + {Name: "remark", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "use_count", Type: field.TypeInt, Default: 0}, + {Name: "plan_type", Type: field.TypeString, Nullable: true}, + {Name: "plan_title", Type: field.TypeString, Nullable: true}, + {Name: "subscription_end", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "sora_supported", Type: field.TypeBool, Default: false}, + {Name: "sora_invite_code", Type: field.TypeString, Nullable: true}, + {Name: "sora_redeemed_count", Type: field.TypeInt, Default: 0}, + {Name: "sora_remaining_count", Type: field.TypeInt, Default: 0}, + {Name: "sora_total_count", Type: field.TypeInt, Default: 0}, + {Name: "sora_cooldown_until", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "cooled_until", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "image_enabled", Type: field.TypeBool, Default: true}, + {Name: "video_enabled", Type: field.TypeBool, Default: true}, + {Name: "image_concurrency", Type: field.TypeInt, Default: -1}, + {Name: "video_concurrency", Type: field.TypeInt, Default: -1}, + 
{Name: "is_expired", Type: field.TypeBool, Default: false}, + } + // SoraAccountsTable holds the schema information for the "sora_accounts" table. + SoraAccountsTable = &schema.Table{ + Name: "sora_accounts", + Columns: SoraAccountsColumns, + PrimaryKey: []*schema.Column{SoraAccountsColumns[0]}, + Indexes: []*schema.Index{ + { + Name: "soraaccount_account_id", + Unique: true, + Columns: []*schema.Column{SoraAccountsColumns[3]}, + }, + { + Name: "soraaccount_plan_type", + Unique: false, + Columns: []*schema.Column{SoraAccountsColumns[12]}, + }, + { + Name: "soraaccount_sora_supported", + Unique: false, + Columns: []*schema.Column{SoraAccountsColumns[15]}, + }, + { + Name: "soraaccount_image_enabled", + Unique: false, + Columns: []*schema.Column{SoraAccountsColumns[22]}, + }, + { + Name: "soraaccount_video_enabled", + Unique: false, + Columns: []*schema.Column{SoraAccountsColumns[23]}, + }, + }, + } + // SoraCacheFilesColumns holds the columns for the "sora_cache_files" table. + SoraCacheFilesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "task_id", Type: field.TypeString, Nullable: true, Size: 120}, + {Name: "account_id", Type: field.TypeInt64}, + {Name: "user_id", Type: field.TypeInt64}, + {Name: "media_type", Type: field.TypeString, Size: 32}, + {Name: "original_url", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "cache_path", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "cache_url", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "size_bytes", Type: field.TypeInt64, Default: 0}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + } + // SoraCacheFilesTable holds the schema information for the "sora_cache_files" table. 
+ SoraCacheFilesTable = &schema.Table{ + Name: "sora_cache_files", + Columns: SoraCacheFilesColumns, + PrimaryKey: []*schema.Column{SoraCacheFilesColumns[0]}, + Indexes: []*schema.Index{ + { + Name: "soracachefile_account_id", + Unique: false, + Columns: []*schema.Column{SoraCacheFilesColumns[2]}, + }, + { + Name: "soracachefile_user_id", + Unique: false, + Columns: []*schema.Column{SoraCacheFilesColumns[3]}, + }, + { + Name: "soracachefile_media_type", + Unique: false, + Columns: []*schema.Column{SoraCacheFilesColumns[4]}, + }, + }, + } + // SoraTasksColumns holds the columns for the "sora_tasks" table. + SoraTasksColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "task_id", Type: field.TypeString, Unique: true, Size: 120}, + {Name: "account_id", Type: field.TypeInt64}, + {Name: "model", Type: field.TypeString, Size: 120}, + {Name: "prompt", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "status", Type: field.TypeString, Size: 32, Default: "processing"}, + {Name: "progress", Type: field.TypeFloat64, Default: 0}, + {Name: "result_urls", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "error_message", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "retry_count", Type: field.TypeInt, Default: 0}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "completed_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + } + // SoraTasksTable holds the schema information for the "sora_tasks" table. 
+ SoraTasksTable = &schema.Table{ + Name: "sora_tasks", + Columns: SoraTasksColumns, + PrimaryKey: []*schema.Column{SoraTasksColumns[0]}, + Indexes: []*schema.Index{ + { + Name: "soratask_account_id", + Unique: false, + Columns: []*schema.Column{SoraTasksColumns[2]}, + }, + { + Name: "soratask_status", + Unique: false, + Columns: []*schema.Column{SoraTasksColumns[5]}, + }, + }, + } + // SoraUsageStatsColumns holds the columns for the "sora_usage_stats" table. + SoraUsageStatsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "account_id", Type: field.TypeInt64}, + {Name: "image_count", Type: field.TypeInt, Default: 0}, + {Name: "video_count", Type: field.TypeInt, Default: 0}, + {Name: "error_count", Type: field.TypeInt, Default: 0}, + {Name: "last_error_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "today_image_count", Type: field.TypeInt, Default: 0}, + {Name: "today_video_count", Type: field.TypeInt, Default: 0}, + {Name: "today_error_count", Type: field.TypeInt, Default: 0}, + {Name: "today_date", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "date"}}, + {Name: "consecutive_error_count", Type: field.TypeInt, Default: 0}, + } + // SoraUsageStatsTable holds the schema information for the "sora_usage_stats" table. 
+ SoraUsageStatsTable = &schema.Table{ + Name: "sora_usage_stats", + Columns: SoraUsageStatsColumns, + PrimaryKey: []*schema.Column{SoraUsageStatsColumns[0]}, + Indexes: []*schema.Index{ + { + Name: "sorausagestat_account_id", + Unique: true, + Columns: []*schema.Column{SoraUsageStatsColumns[3]}, + }, + { + Name: "sorausagestat_today_date", + Unique: false, + Columns: []*schema.Column{SoraUsageStatsColumns[11]}, + }, + }, + } // UsageCleanupTasksColumns holds the columns for the "usage_cleanup_tasks" table. UsageCleanupTasksColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -843,6 +1009,10 @@ var ( ProxiesTable, RedeemCodesTable, SettingsTable, + SoraAccountsTable, + SoraCacheFilesTable, + SoraTasksTable, + SoraUsageStatsTable, UsageCleanupTasksTable, UsageLogsTable, UsersTable, @@ -890,6 +1060,18 @@ func init() { SettingsTable.Annotation = &entsql.Annotation{ Table: "settings", } + SoraAccountsTable.Annotation = &entsql.Annotation{ + Table: "sora_accounts", + } + SoraCacheFilesTable.Annotation = &entsql.Annotation{ + Table: "sora_cache_files", + } + SoraTasksTable.Annotation = &entsql.Annotation{ + Table: "sora_tasks", + } + SoraUsageStatsTable.Annotation = &entsql.Annotation{ + Table: "sora_usage_stats", + } UsageCleanupTasksTable.Annotation = &entsql.Annotation{ Table: "usage_cleanup_tasks", } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 9b330616..eaf5b483 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -22,6 +22,10 @@ import ( "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/soraaccount" + "github.com/Wei-Shaw/sub2api/ent/soracachefile" + "github.com/Wei-Shaw/sub2api/ent/soratask" + "github.com/Wei-Shaw/sub2api/ent/sorausagestat" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" @@ -49,6 
+53,10 @@ const ( TypeProxy = "Proxy" TypeRedeemCode = "RedeemCode" TypeSetting = "Setting" + TypeSoraAccount = "SoraAccount" + TypeSoraCacheFile = "SoraCacheFile" + TypeSoraTask = "SoraTask" + TypeSoraUsageStat = "SoraUsageStat" TypeUsageCleanupTask = "UsageCleanupTask" TypeUsageLog = "UsageLog" TypeUser = "User" @@ -10373,6 +10381,5304 @@ func (m *SettingMutation) ResetEdge(name string) error { return fmt.Errorf("unknown Setting edge %s", name) } +// SoraAccountMutation represents an operation that mutates the SoraAccount nodes in the graph. +type SoraAccountMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + account_id *int64 + addaccount_id *int64 + access_token *string + session_token *string + refresh_token *string + client_id *string + email *string + username *string + remark *string + use_count *int + adduse_count *int + plan_type *string + plan_title *string + subscription_end *time.Time + sora_supported *bool + sora_invite_code *string + sora_redeemed_count *int + addsora_redeemed_count *int + sora_remaining_count *int + addsora_remaining_count *int + sora_total_count *int + addsora_total_count *int + sora_cooldown_until *time.Time + cooled_until *time.Time + image_enabled *bool + video_enabled *bool + image_concurrency *int + addimage_concurrency *int + video_concurrency *int + addvideo_concurrency *int + is_expired *bool + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*SoraAccount, error) + predicates []predicate.SoraAccount +} + +var _ ent.Mutation = (*SoraAccountMutation)(nil) + +// soraaccountOption allows management of the mutation configuration using functional options. +type soraaccountOption func(*SoraAccountMutation) + +// newSoraAccountMutation creates new mutation for the SoraAccount entity. 
+func newSoraAccountMutation(c config, op Op, opts ...soraaccountOption) *SoraAccountMutation { + m := &SoraAccountMutation{ + config: c, + op: op, + typ: TypeSoraAccount, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withSoraAccountID sets the ID field of the mutation. +func withSoraAccountID(id int64) soraaccountOption { + return func(m *SoraAccountMutation) { + var ( + err error + once sync.Once + value *SoraAccount + ) + m.oldValue = func(ctx context.Context) (*SoraAccount, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().SoraAccount.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withSoraAccount sets the old SoraAccount of the mutation. +func withSoraAccount(node *SoraAccount) soraaccountOption { + return func(m *SoraAccountMutation) { + m.oldValue = func(context.Context) (*SoraAccount, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m SoraAccountMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m SoraAccountMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. 
+func (m *SoraAccountMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *SoraAccountMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().SoraAccount.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *SoraAccountMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *SoraAccountMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraAccountMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *SoraAccountMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *SoraAccountMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *SoraAccountMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraAccountMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *SoraAccountMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetAccountID sets the "account_id" field. 
+func (m *SoraAccountMutation) SetAccountID(i int64) { + m.account_id = &i + m.addaccount_id = nil +} + +// AccountID returns the value of the "account_id" field in the mutation. +func (m *SoraAccountMutation) AccountID() (r int64, exists bool) { + v := m.account_id + if v == nil { + return + } + return *v, true +} + +// OldAccountID returns the old "account_id" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraAccountMutation) OldAccountID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAccountID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAccountID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAccountID: %w", err) + } + return oldValue.AccountID, nil +} + +// AddAccountID adds i to the "account_id" field. +func (m *SoraAccountMutation) AddAccountID(i int64) { + if m.addaccount_id != nil { + *m.addaccount_id += i + } else { + m.addaccount_id = &i + } +} + +// AddedAccountID returns the value that was added to the "account_id" field in this mutation. +func (m *SoraAccountMutation) AddedAccountID() (r int64, exists bool) { + v := m.addaccount_id + if v == nil { + return + } + return *v, true +} + +// ResetAccountID resets all changes to the "account_id" field. +func (m *SoraAccountMutation) ResetAccountID() { + m.account_id = nil + m.addaccount_id = nil +} + +// SetAccessToken sets the "access_token" field. +func (m *SoraAccountMutation) SetAccessToken(s string) { + m.access_token = &s +} + +// AccessToken returns the value of the "access_token" field in the mutation. 
+func (m *SoraAccountMutation) AccessToken() (r string, exists bool) { + v := m.access_token + if v == nil { + return + } + return *v, true +} + +// OldAccessToken returns the old "access_token" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraAccountMutation) OldAccessToken(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAccessToken is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAccessToken requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAccessToken: %w", err) + } + return oldValue.AccessToken, nil +} + +// ClearAccessToken clears the value of the "access_token" field. +func (m *SoraAccountMutation) ClearAccessToken() { + m.access_token = nil + m.clearedFields[soraaccount.FieldAccessToken] = struct{}{} +} + +// AccessTokenCleared returns if the "access_token" field was cleared in this mutation. +func (m *SoraAccountMutation) AccessTokenCleared() bool { + _, ok := m.clearedFields[soraaccount.FieldAccessToken] + return ok +} + +// ResetAccessToken resets all changes to the "access_token" field. +func (m *SoraAccountMutation) ResetAccessToken() { + m.access_token = nil + delete(m.clearedFields, soraaccount.FieldAccessToken) +} + +// SetSessionToken sets the "session_token" field. +func (m *SoraAccountMutation) SetSessionToken(s string) { + m.session_token = &s +} + +// SessionToken returns the value of the "session_token" field in the mutation. 
+func (m *SoraAccountMutation) SessionToken() (r string, exists bool) { + v := m.session_token + if v == nil { + return + } + return *v, true +} + +// OldSessionToken returns the old "session_token" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraAccountMutation) OldSessionToken(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSessionToken is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSessionToken requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSessionToken: %w", err) + } + return oldValue.SessionToken, nil +} + +// ClearSessionToken clears the value of the "session_token" field. +func (m *SoraAccountMutation) ClearSessionToken() { + m.session_token = nil + m.clearedFields[soraaccount.FieldSessionToken] = struct{}{} +} + +// SessionTokenCleared returns if the "session_token" field was cleared in this mutation. +func (m *SoraAccountMutation) SessionTokenCleared() bool { + _, ok := m.clearedFields[soraaccount.FieldSessionToken] + return ok +} + +// ResetSessionToken resets all changes to the "session_token" field. +func (m *SoraAccountMutation) ResetSessionToken() { + m.session_token = nil + delete(m.clearedFields, soraaccount.FieldSessionToken) +} + +// SetRefreshToken sets the "refresh_token" field. +func (m *SoraAccountMutation) SetRefreshToken(s string) { + m.refresh_token = &s +} + +// RefreshToken returns the value of the "refresh_token" field in the mutation. 
+func (m *SoraAccountMutation) RefreshToken() (r string, exists bool) { + v := m.refresh_token + if v == nil { + return + } + return *v, true +} + +// OldRefreshToken returns the old "refresh_token" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraAccountMutation) OldRefreshToken(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefreshToken is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefreshToken requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefreshToken: %w", err) + } + return oldValue.RefreshToken, nil +} + +// ClearRefreshToken clears the value of the "refresh_token" field. +func (m *SoraAccountMutation) ClearRefreshToken() { + m.refresh_token = nil + m.clearedFields[soraaccount.FieldRefreshToken] = struct{}{} +} + +// RefreshTokenCleared returns if the "refresh_token" field was cleared in this mutation. +func (m *SoraAccountMutation) RefreshTokenCleared() bool { + _, ok := m.clearedFields[soraaccount.FieldRefreshToken] + return ok +} + +// ResetRefreshToken resets all changes to the "refresh_token" field. +func (m *SoraAccountMutation) ResetRefreshToken() { + m.refresh_token = nil + delete(m.clearedFields, soraaccount.FieldRefreshToken) +} + +// SetClientID sets the "client_id" field. +func (m *SoraAccountMutation) SetClientID(s string) { + m.client_id = &s +} + +// ClientID returns the value of the "client_id" field in the mutation. 
+func (m *SoraAccountMutation) ClientID() (r string, exists bool) { + v := m.client_id + if v == nil { + return + } + return *v, true +} + +// OldClientID returns the old "client_id" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraAccountMutation) OldClientID(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldClientID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldClientID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldClientID: %w", err) + } + return oldValue.ClientID, nil +} + +// ClearClientID clears the value of the "client_id" field. +func (m *SoraAccountMutation) ClearClientID() { + m.client_id = nil + m.clearedFields[soraaccount.FieldClientID] = struct{}{} +} + +// ClientIDCleared returns if the "client_id" field was cleared in this mutation. +func (m *SoraAccountMutation) ClientIDCleared() bool { + _, ok := m.clearedFields[soraaccount.FieldClientID] + return ok +} + +// ResetClientID resets all changes to the "client_id" field. +func (m *SoraAccountMutation) ResetClientID() { + m.client_id = nil + delete(m.clearedFields, soraaccount.FieldClientID) +} + +// SetEmail sets the "email" field. +func (m *SoraAccountMutation) SetEmail(s string) { + m.email = &s +} + +// Email returns the value of the "email" field in the mutation. +func (m *SoraAccountMutation) Email() (r string, exists bool) { + v := m.email + if v == nil { + return + } + return *v, true +} + +// OldEmail returns the old "email" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. 
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraAccountMutation) OldEmail(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldEmail is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldEmail requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldEmail: %w", err) + } + return oldValue.Email, nil +} + +// ClearEmail clears the value of the "email" field. +func (m *SoraAccountMutation) ClearEmail() { + m.email = nil + m.clearedFields[soraaccount.FieldEmail] = struct{}{} +} + +// EmailCleared returns if the "email" field was cleared in this mutation. +func (m *SoraAccountMutation) EmailCleared() bool { + _, ok := m.clearedFields[soraaccount.FieldEmail] + return ok +} + +// ResetEmail resets all changes to the "email" field. +func (m *SoraAccountMutation) ResetEmail() { + m.email = nil + delete(m.clearedFields, soraaccount.FieldEmail) +} + +// SetUsername sets the "username" field. +func (m *SoraAccountMutation) SetUsername(s string) { + m.username = &s +} + +// Username returns the value of the "username" field in the mutation. +func (m *SoraAccountMutation) Username() (r string, exists bool) { + v := m.username + if v == nil { + return + } + return *v, true +} + +// OldUsername returns the old "username" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraAccountMutation) OldUsername(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUsername is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUsername requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUsername: %w", err) + } + return oldValue.Username, nil +} + +// ClearUsername clears the value of the "username" field. +func (m *SoraAccountMutation) ClearUsername() { + m.username = nil + m.clearedFields[soraaccount.FieldUsername] = struct{}{} +} + +// UsernameCleared returns if the "username" field was cleared in this mutation. +func (m *SoraAccountMutation) UsernameCleared() bool { + _, ok := m.clearedFields[soraaccount.FieldUsername] + return ok +} + +// ResetUsername resets all changes to the "username" field. +func (m *SoraAccountMutation) ResetUsername() { + m.username = nil + delete(m.clearedFields, soraaccount.FieldUsername) +} + +// SetRemark sets the "remark" field. +func (m *SoraAccountMutation) SetRemark(s string) { + m.remark = &s +} + +// Remark returns the value of the "remark" field in the mutation. +func (m *SoraAccountMutation) Remark() (r string, exists bool) { + v := m.remark + if v == nil { + return + } + return *v, true +} + +// OldRemark returns the old "remark" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraAccountMutation) OldRemark(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRemark is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRemark requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRemark: %w", err) + } + return oldValue.Remark, nil +} + +// ClearRemark clears the value of the "remark" field. +func (m *SoraAccountMutation) ClearRemark() { + m.remark = nil + m.clearedFields[soraaccount.FieldRemark] = struct{}{} +} + +// RemarkCleared returns if the "remark" field was cleared in this mutation. +func (m *SoraAccountMutation) RemarkCleared() bool { + _, ok := m.clearedFields[soraaccount.FieldRemark] + return ok +} + +// ResetRemark resets all changes to the "remark" field. +func (m *SoraAccountMutation) ResetRemark() { + m.remark = nil + delete(m.clearedFields, soraaccount.FieldRemark) +} + +// SetUseCount sets the "use_count" field. +func (m *SoraAccountMutation) SetUseCount(i int) { + m.use_count = &i + m.adduse_count = nil +} + +// UseCount returns the value of the "use_count" field in the mutation. +func (m *SoraAccountMutation) UseCount() (r int, exists bool) { + v := m.use_count + if v == nil { + return + } + return *v, true +} + +// OldUseCount returns the old "use_count" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraAccountMutation) OldUseCount(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUseCount is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUseCount requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUseCount: %w", err) + } + return oldValue.UseCount, nil +} + +// AddUseCount adds i to the "use_count" field. +func (m *SoraAccountMutation) AddUseCount(i int) { + if m.adduse_count != nil { + *m.adduse_count += i + } else { + m.adduse_count = &i + } +} + +// AddedUseCount returns the value that was added to the "use_count" field in this mutation. +func (m *SoraAccountMutation) AddedUseCount() (r int, exists bool) { + v := m.adduse_count + if v == nil { + return + } + return *v, true +} + +// ResetUseCount resets all changes to the "use_count" field. +func (m *SoraAccountMutation) ResetUseCount() { + m.use_count = nil + m.adduse_count = nil +} + +// SetPlanType sets the "plan_type" field. +func (m *SoraAccountMutation) SetPlanType(s string) { + m.plan_type = &s +} + +// PlanType returns the value of the "plan_type" field in the mutation. +func (m *SoraAccountMutation) PlanType() (r string, exists bool) { + v := m.plan_type + if v == nil { + return + } + return *v, true +} + +// OldPlanType returns the old "plan_type" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraAccountMutation) OldPlanType(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPlanType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPlanType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPlanType: %w", err) + } + return oldValue.PlanType, nil +} + +// ClearPlanType clears the value of the "plan_type" field. +func (m *SoraAccountMutation) ClearPlanType() { + m.plan_type = nil + m.clearedFields[soraaccount.FieldPlanType] = struct{}{} +} + +// PlanTypeCleared returns if the "plan_type" field was cleared in this mutation. +func (m *SoraAccountMutation) PlanTypeCleared() bool { + _, ok := m.clearedFields[soraaccount.FieldPlanType] + return ok +} + +// ResetPlanType resets all changes to the "plan_type" field. +func (m *SoraAccountMutation) ResetPlanType() { + m.plan_type = nil + delete(m.clearedFields, soraaccount.FieldPlanType) +} + +// SetPlanTitle sets the "plan_title" field. +func (m *SoraAccountMutation) SetPlanTitle(s string) { + m.plan_title = &s +} + +// PlanTitle returns the value of the "plan_title" field in the mutation. +func (m *SoraAccountMutation) PlanTitle() (r string, exists bool) { + v := m.plan_title + if v == nil { + return + } + return *v, true +} + +// OldPlanTitle returns the old "plan_title" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraAccountMutation) OldPlanTitle(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPlanTitle is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPlanTitle requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPlanTitle: %w", err) + } + return oldValue.PlanTitle, nil +} + +// ClearPlanTitle clears the value of the "plan_title" field. +func (m *SoraAccountMutation) ClearPlanTitle() { + m.plan_title = nil + m.clearedFields[soraaccount.FieldPlanTitle] = struct{}{} +} + +// PlanTitleCleared returns if the "plan_title" field was cleared in this mutation. +func (m *SoraAccountMutation) PlanTitleCleared() bool { + _, ok := m.clearedFields[soraaccount.FieldPlanTitle] + return ok +} + +// ResetPlanTitle resets all changes to the "plan_title" field. +func (m *SoraAccountMutation) ResetPlanTitle() { + m.plan_title = nil + delete(m.clearedFields, soraaccount.FieldPlanTitle) +} + +// SetSubscriptionEnd sets the "subscription_end" field. +func (m *SoraAccountMutation) SetSubscriptionEnd(t time.Time) { + m.subscription_end = &t +} + +// SubscriptionEnd returns the value of the "subscription_end" field in the mutation. +func (m *SoraAccountMutation) SubscriptionEnd() (r time.Time, exists bool) { + v := m.subscription_end + if v == nil { + return + } + return *v, true +} + +// OldSubscriptionEnd returns the old "subscription_end" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraAccountMutation) OldSubscriptionEnd(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSubscriptionEnd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSubscriptionEnd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSubscriptionEnd: %w", err) + } + return oldValue.SubscriptionEnd, nil +} + +// ClearSubscriptionEnd clears the value of the "subscription_end" field. +func (m *SoraAccountMutation) ClearSubscriptionEnd() { + m.subscription_end = nil + m.clearedFields[soraaccount.FieldSubscriptionEnd] = struct{}{} +} + +// SubscriptionEndCleared returns if the "subscription_end" field was cleared in this mutation. +func (m *SoraAccountMutation) SubscriptionEndCleared() bool { + _, ok := m.clearedFields[soraaccount.FieldSubscriptionEnd] + return ok +} + +// ResetSubscriptionEnd resets all changes to the "subscription_end" field. +func (m *SoraAccountMutation) ResetSubscriptionEnd() { + m.subscription_end = nil + delete(m.clearedFields, soraaccount.FieldSubscriptionEnd) +} + +// SetSoraSupported sets the "sora_supported" field. +func (m *SoraAccountMutation) SetSoraSupported(b bool) { + m.sora_supported = &b +} + +// SoraSupported returns the value of the "sora_supported" field in the mutation. +func (m *SoraAccountMutation) SoraSupported() (r bool, exists bool) { + v := m.sora_supported + if v == nil { + return + } + return *v, true +} + +// OldSoraSupported returns the old "sora_supported" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraAccountMutation) OldSoraSupported(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraSupported is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraSupported requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraSupported: %w", err) + } + return oldValue.SoraSupported, nil +} + +// ResetSoraSupported resets all changes to the "sora_supported" field. +func (m *SoraAccountMutation) ResetSoraSupported() { + m.sora_supported = nil +} + +// SetSoraInviteCode sets the "sora_invite_code" field. +func (m *SoraAccountMutation) SetSoraInviteCode(s string) { + m.sora_invite_code = &s +} + +// SoraInviteCode returns the value of the "sora_invite_code" field in the mutation. +func (m *SoraAccountMutation) SoraInviteCode() (r string, exists bool) { + v := m.sora_invite_code + if v == nil { + return + } + return *v, true +} + +// OldSoraInviteCode returns the old "sora_invite_code" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraAccountMutation) OldSoraInviteCode(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraInviteCode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraInviteCode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraInviteCode: %w", err) + } + return oldValue.SoraInviteCode, nil +} + +// ClearSoraInviteCode clears the value of the "sora_invite_code" field. 
+func (m *SoraAccountMutation) ClearSoraInviteCode() { + m.sora_invite_code = nil + m.clearedFields[soraaccount.FieldSoraInviteCode] = struct{}{} +} + +// SoraInviteCodeCleared returns if the "sora_invite_code" field was cleared in this mutation. +func (m *SoraAccountMutation) SoraInviteCodeCleared() bool { + _, ok := m.clearedFields[soraaccount.FieldSoraInviteCode] + return ok +} + +// ResetSoraInviteCode resets all changes to the "sora_invite_code" field. +func (m *SoraAccountMutation) ResetSoraInviteCode() { + m.sora_invite_code = nil + delete(m.clearedFields, soraaccount.FieldSoraInviteCode) +} + +// SetSoraRedeemedCount sets the "sora_redeemed_count" field. +func (m *SoraAccountMutation) SetSoraRedeemedCount(i int) { + m.sora_redeemed_count = &i + m.addsora_redeemed_count = nil +} + +// SoraRedeemedCount returns the value of the "sora_redeemed_count" field in the mutation. +func (m *SoraAccountMutation) SoraRedeemedCount() (r int, exists bool) { + v := m.sora_redeemed_count + if v == nil { + return + } + return *v, true +} + +// OldSoraRedeemedCount returns the old "sora_redeemed_count" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraAccountMutation) OldSoraRedeemedCount(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraRedeemedCount is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraRedeemedCount requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraRedeemedCount: %w", err) + } + return oldValue.SoraRedeemedCount, nil +} + +// AddSoraRedeemedCount adds i to the "sora_redeemed_count" field. 
+func (m *SoraAccountMutation) AddSoraRedeemedCount(i int) { + if m.addsora_redeemed_count != nil { + *m.addsora_redeemed_count += i + } else { + m.addsora_redeemed_count = &i + } +} + +// AddedSoraRedeemedCount returns the value that was added to the "sora_redeemed_count" field in this mutation. +func (m *SoraAccountMutation) AddedSoraRedeemedCount() (r int, exists bool) { + v := m.addsora_redeemed_count + if v == nil { + return + } + return *v, true +} + +// ResetSoraRedeemedCount resets all changes to the "sora_redeemed_count" field. +func (m *SoraAccountMutation) ResetSoraRedeemedCount() { + m.sora_redeemed_count = nil + m.addsora_redeemed_count = nil +} + +// SetSoraRemainingCount sets the "sora_remaining_count" field. +func (m *SoraAccountMutation) SetSoraRemainingCount(i int) { + m.sora_remaining_count = &i + m.addsora_remaining_count = nil +} + +// SoraRemainingCount returns the value of the "sora_remaining_count" field in the mutation. +func (m *SoraAccountMutation) SoraRemainingCount() (r int, exists bool) { + v := m.sora_remaining_count + if v == nil { + return + } + return *v, true +} + +// OldSoraRemainingCount returns the old "sora_remaining_count" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraAccountMutation) OldSoraRemainingCount(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraRemainingCount is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraRemainingCount requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraRemainingCount: %w", err) + } + return oldValue.SoraRemainingCount, nil +} + +// AddSoraRemainingCount adds i to the "sora_remaining_count" field. +func (m *SoraAccountMutation) AddSoraRemainingCount(i int) { + if m.addsora_remaining_count != nil { + *m.addsora_remaining_count += i + } else { + m.addsora_remaining_count = &i + } +} + +// AddedSoraRemainingCount returns the value that was added to the "sora_remaining_count" field in this mutation. +func (m *SoraAccountMutation) AddedSoraRemainingCount() (r int, exists bool) { + v := m.addsora_remaining_count + if v == nil { + return + } + return *v, true +} + +// ResetSoraRemainingCount resets all changes to the "sora_remaining_count" field. +func (m *SoraAccountMutation) ResetSoraRemainingCount() { + m.sora_remaining_count = nil + m.addsora_remaining_count = nil +} + +// SetSoraTotalCount sets the "sora_total_count" field. +func (m *SoraAccountMutation) SetSoraTotalCount(i int) { + m.sora_total_count = &i + m.addsora_total_count = nil +} + +// SoraTotalCount returns the value of the "sora_total_count" field in the mutation. +func (m *SoraAccountMutation) SoraTotalCount() (r int, exists bool) { + v := m.sora_total_count + if v == nil { + return + } + return *v, true +} + +// OldSoraTotalCount returns the old "sora_total_count" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraAccountMutation) OldSoraTotalCount(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraTotalCount is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraTotalCount requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraTotalCount: %w", err) + } + return oldValue.SoraTotalCount, nil +} + +// AddSoraTotalCount adds i to the "sora_total_count" field. +func (m *SoraAccountMutation) AddSoraTotalCount(i int) { + if m.addsora_total_count != nil { + *m.addsora_total_count += i + } else { + m.addsora_total_count = &i + } +} + +// AddedSoraTotalCount returns the value that was added to the "sora_total_count" field in this mutation. +func (m *SoraAccountMutation) AddedSoraTotalCount() (r int, exists bool) { + v := m.addsora_total_count + if v == nil { + return + } + return *v, true +} + +// ResetSoraTotalCount resets all changes to the "sora_total_count" field. +func (m *SoraAccountMutation) ResetSoraTotalCount() { + m.sora_total_count = nil + m.addsora_total_count = nil +} + +// SetSoraCooldownUntil sets the "sora_cooldown_until" field. +func (m *SoraAccountMutation) SetSoraCooldownUntil(t time.Time) { + m.sora_cooldown_until = &t +} + +// SoraCooldownUntil returns the value of the "sora_cooldown_until" field in the mutation. +func (m *SoraAccountMutation) SoraCooldownUntil() (r time.Time, exists bool) { + v := m.sora_cooldown_until + if v == nil { + return + } + return *v, true +} + +// OldSoraCooldownUntil returns the old "sora_cooldown_until" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraAccountMutation) OldSoraCooldownUntil(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraCooldownUntil is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraCooldownUntil requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraCooldownUntil: %w", err) + } + return oldValue.SoraCooldownUntil, nil +} + +// ClearSoraCooldownUntil clears the value of the "sora_cooldown_until" field. +func (m *SoraAccountMutation) ClearSoraCooldownUntil() { + m.sora_cooldown_until = nil + m.clearedFields[soraaccount.FieldSoraCooldownUntil] = struct{}{} +} + +// SoraCooldownUntilCleared returns if the "sora_cooldown_until" field was cleared in this mutation. +func (m *SoraAccountMutation) SoraCooldownUntilCleared() bool { + _, ok := m.clearedFields[soraaccount.FieldSoraCooldownUntil] + return ok +} + +// ResetSoraCooldownUntil resets all changes to the "sora_cooldown_until" field. +func (m *SoraAccountMutation) ResetSoraCooldownUntil() { + m.sora_cooldown_until = nil + delete(m.clearedFields, soraaccount.FieldSoraCooldownUntil) +} + +// SetCooledUntil sets the "cooled_until" field. +func (m *SoraAccountMutation) SetCooledUntil(t time.Time) { + m.cooled_until = &t +} + +// CooledUntil returns the value of the "cooled_until" field in the mutation. +func (m *SoraAccountMutation) CooledUntil() (r time.Time, exists bool) { + v := m.cooled_until + if v == nil { + return + } + return *v, true +} + +// OldCooledUntil returns the old "cooled_until" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraAccountMutation) OldCooledUntil(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCooledUntil is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCooledUntil requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCooledUntil: %w", err) + } + return oldValue.CooledUntil, nil +} + +// ClearCooledUntil clears the value of the "cooled_until" field. +func (m *SoraAccountMutation) ClearCooledUntil() { + m.cooled_until = nil + m.clearedFields[soraaccount.FieldCooledUntil] = struct{}{} +} + +// CooledUntilCleared returns if the "cooled_until" field was cleared in this mutation. +func (m *SoraAccountMutation) CooledUntilCleared() bool { + _, ok := m.clearedFields[soraaccount.FieldCooledUntil] + return ok +} + +// ResetCooledUntil resets all changes to the "cooled_until" field. +func (m *SoraAccountMutation) ResetCooledUntil() { + m.cooled_until = nil + delete(m.clearedFields, soraaccount.FieldCooledUntil) +} + +// SetImageEnabled sets the "image_enabled" field. +func (m *SoraAccountMutation) SetImageEnabled(b bool) { + m.image_enabled = &b +} + +// ImageEnabled returns the value of the "image_enabled" field in the mutation. +func (m *SoraAccountMutation) ImageEnabled() (r bool, exists bool) { + v := m.image_enabled + if v == nil { + return + } + return *v, true +} + +// OldImageEnabled returns the old "image_enabled" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraAccountMutation) OldImageEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldImageEnabled is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldImageEnabled requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldImageEnabled: %w", err) + } + return oldValue.ImageEnabled, nil +} + +// ResetImageEnabled resets all changes to the "image_enabled" field. +func (m *SoraAccountMutation) ResetImageEnabled() { + m.image_enabled = nil +} + +// SetVideoEnabled sets the "video_enabled" field. +func (m *SoraAccountMutation) SetVideoEnabled(b bool) { + m.video_enabled = &b +} + +// VideoEnabled returns the value of the "video_enabled" field in the mutation. +func (m *SoraAccountMutation) VideoEnabled() (r bool, exists bool) { + v := m.video_enabled + if v == nil { + return + } + return *v, true +} + +// OldVideoEnabled returns the old "video_enabled" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraAccountMutation) OldVideoEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldVideoEnabled is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldVideoEnabled requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldVideoEnabled: %w", err) + } + return oldValue.VideoEnabled, nil +} + +// ResetVideoEnabled resets all changes to the "video_enabled" field. 
+func (m *SoraAccountMutation) ResetVideoEnabled() { + m.video_enabled = nil +} + +// SetImageConcurrency sets the "image_concurrency" field. +func (m *SoraAccountMutation) SetImageConcurrency(i int) { + m.image_concurrency = &i + m.addimage_concurrency = nil +} + +// ImageConcurrency returns the value of the "image_concurrency" field in the mutation. +func (m *SoraAccountMutation) ImageConcurrency() (r int, exists bool) { + v := m.image_concurrency + if v == nil { + return + } + return *v, true +} + +// OldImageConcurrency returns the old "image_concurrency" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraAccountMutation) OldImageConcurrency(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldImageConcurrency is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldImageConcurrency requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldImageConcurrency: %w", err) + } + return oldValue.ImageConcurrency, nil +} + +// AddImageConcurrency adds i to the "image_concurrency" field. +func (m *SoraAccountMutation) AddImageConcurrency(i int) { + if m.addimage_concurrency != nil { + *m.addimage_concurrency += i + } else { + m.addimage_concurrency = &i + } +} + +// AddedImageConcurrency returns the value that was added to the "image_concurrency" field in this mutation. +func (m *SoraAccountMutation) AddedImageConcurrency() (r int, exists bool) { + v := m.addimage_concurrency + if v == nil { + return + } + return *v, true +} + +// ResetImageConcurrency resets all changes to the "image_concurrency" field. 
+func (m *SoraAccountMutation) ResetImageConcurrency() { + m.image_concurrency = nil + m.addimage_concurrency = nil +} + +// SetVideoConcurrency sets the "video_concurrency" field. +func (m *SoraAccountMutation) SetVideoConcurrency(i int) { + m.video_concurrency = &i + m.addvideo_concurrency = nil +} + +// VideoConcurrency returns the value of the "video_concurrency" field in the mutation. +func (m *SoraAccountMutation) VideoConcurrency() (r int, exists bool) { + v := m.video_concurrency + if v == nil { + return + } + return *v, true +} + +// OldVideoConcurrency returns the old "video_concurrency" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraAccountMutation) OldVideoConcurrency(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldVideoConcurrency is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldVideoConcurrency requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldVideoConcurrency: %w", err) + } + return oldValue.VideoConcurrency, nil +} + +// AddVideoConcurrency adds i to the "video_concurrency" field. +func (m *SoraAccountMutation) AddVideoConcurrency(i int) { + if m.addvideo_concurrency != nil { + *m.addvideo_concurrency += i + } else { + m.addvideo_concurrency = &i + } +} + +// AddedVideoConcurrency returns the value that was added to the "video_concurrency" field in this mutation. +func (m *SoraAccountMutation) AddedVideoConcurrency() (r int, exists bool) { + v := m.addvideo_concurrency + if v == nil { + return + } + return *v, true +} + +// ResetVideoConcurrency resets all changes to the "video_concurrency" field. 
+func (m *SoraAccountMutation) ResetVideoConcurrency() { + m.video_concurrency = nil + m.addvideo_concurrency = nil +} + +// SetIsExpired sets the "is_expired" field. +func (m *SoraAccountMutation) SetIsExpired(b bool) { + m.is_expired = &b +} + +// IsExpired returns the value of the "is_expired" field in the mutation. +func (m *SoraAccountMutation) IsExpired() (r bool, exists bool) { + v := m.is_expired + if v == nil { + return + } + return *v, true +} + +// OldIsExpired returns the old "is_expired" field's value of the SoraAccount entity. +// If the SoraAccount object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraAccountMutation) OldIsExpired(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIsExpired is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIsExpired requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIsExpired: %w", err) + } + return oldValue.IsExpired, nil +} + +// ResetIsExpired resets all changes to the "is_expired" field. +func (m *SoraAccountMutation) ResetIsExpired() { + m.is_expired = nil +} + +// Where appends a list predicates to the SoraAccountMutation builder. +func (m *SoraAccountMutation) Where(ps ...predicate.SoraAccount) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the SoraAccountMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *SoraAccountMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.SoraAccount, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. 
+func (m *SoraAccountMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *SoraAccountMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (SoraAccount). +func (m *SoraAccountMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *SoraAccountMutation) Fields() []string { + fields := make([]string, 0, 26) + if m.created_at != nil { + fields = append(fields, soraaccount.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, soraaccount.FieldUpdatedAt) + } + if m.account_id != nil { + fields = append(fields, soraaccount.FieldAccountID) + } + if m.access_token != nil { + fields = append(fields, soraaccount.FieldAccessToken) + } + if m.session_token != nil { + fields = append(fields, soraaccount.FieldSessionToken) + } + if m.refresh_token != nil { + fields = append(fields, soraaccount.FieldRefreshToken) + } + if m.client_id != nil { + fields = append(fields, soraaccount.FieldClientID) + } + if m.email != nil { + fields = append(fields, soraaccount.FieldEmail) + } + if m.username != nil { + fields = append(fields, soraaccount.FieldUsername) + } + if m.remark != nil { + fields = append(fields, soraaccount.FieldRemark) + } + if m.use_count != nil { + fields = append(fields, soraaccount.FieldUseCount) + } + if m.plan_type != nil { + fields = append(fields, soraaccount.FieldPlanType) + } + if m.plan_title != nil { + fields = append(fields, soraaccount.FieldPlanTitle) + } + if m.subscription_end != nil { + fields = append(fields, soraaccount.FieldSubscriptionEnd) + } + if m.sora_supported != nil { + fields = append(fields, soraaccount.FieldSoraSupported) + } + if m.sora_invite_code != nil { + fields = append(fields, soraaccount.FieldSoraInviteCode) + } + if m.sora_redeemed_count != nil { + fields = 
append(fields, soraaccount.FieldSoraRedeemedCount) + } + if m.sora_remaining_count != nil { + fields = append(fields, soraaccount.FieldSoraRemainingCount) + } + if m.sora_total_count != nil { + fields = append(fields, soraaccount.FieldSoraTotalCount) + } + if m.sora_cooldown_until != nil { + fields = append(fields, soraaccount.FieldSoraCooldownUntil) + } + if m.cooled_until != nil { + fields = append(fields, soraaccount.FieldCooledUntil) + } + if m.image_enabled != nil { + fields = append(fields, soraaccount.FieldImageEnabled) + } + if m.video_enabled != nil { + fields = append(fields, soraaccount.FieldVideoEnabled) + } + if m.image_concurrency != nil { + fields = append(fields, soraaccount.FieldImageConcurrency) + } + if m.video_concurrency != nil { + fields = append(fields, soraaccount.FieldVideoConcurrency) + } + if m.is_expired != nil { + fields = append(fields, soraaccount.FieldIsExpired) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. 
+func (m *SoraAccountMutation) Field(name string) (ent.Value, bool) { + switch name { + case soraaccount.FieldCreatedAt: + return m.CreatedAt() + case soraaccount.FieldUpdatedAt: + return m.UpdatedAt() + case soraaccount.FieldAccountID: + return m.AccountID() + case soraaccount.FieldAccessToken: + return m.AccessToken() + case soraaccount.FieldSessionToken: + return m.SessionToken() + case soraaccount.FieldRefreshToken: + return m.RefreshToken() + case soraaccount.FieldClientID: + return m.ClientID() + case soraaccount.FieldEmail: + return m.Email() + case soraaccount.FieldUsername: + return m.Username() + case soraaccount.FieldRemark: + return m.Remark() + case soraaccount.FieldUseCount: + return m.UseCount() + case soraaccount.FieldPlanType: + return m.PlanType() + case soraaccount.FieldPlanTitle: + return m.PlanTitle() + case soraaccount.FieldSubscriptionEnd: + return m.SubscriptionEnd() + case soraaccount.FieldSoraSupported: + return m.SoraSupported() + case soraaccount.FieldSoraInviteCode: + return m.SoraInviteCode() + case soraaccount.FieldSoraRedeemedCount: + return m.SoraRedeemedCount() + case soraaccount.FieldSoraRemainingCount: + return m.SoraRemainingCount() + case soraaccount.FieldSoraTotalCount: + return m.SoraTotalCount() + case soraaccount.FieldSoraCooldownUntil: + return m.SoraCooldownUntil() + case soraaccount.FieldCooledUntil: + return m.CooledUntil() + case soraaccount.FieldImageEnabled: + return m.ImageEnabled() + case soraaccount.FieldVideoEnabled: + return m.VideoEnabled() + case soraaccount.FieldImageConcurrency: + return m.ImageConcurrency() + case soraaccount.FieldVideoConcurrency: + return m.VideoConcurrency() + case soraaccount.FieldIsExpired: + return m.IsExpired() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
+func (m *SoraAccountMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case soraaccount.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case soraaccount.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case soraaccount.FieldAccountID: + return m.OldAccountID(ctx) + case soraaccount.FieldAccessToken: + return m.OldAccessToken(ctx) + case soraaccount.FieldSessionToken: + return m.OldSessionToken(ctx) + case soraaccount.FieldRefreshToken: + return m.OldRefreshToken(ctx) + case soraaccount.FieldClientID: + return m.OldClientID(ctx) + case soraaccount.FieldEmail: + return m.OldEmail(ctx) + case soraaccount.FieldUsername: + return m.OldUsername(ctx) + case soraaccount.FieldRemark: + return m.OldRemark(ctx) + case soraaccount.FieldUseCount: + return m.OldUseCount(ctx) + case soraaccount.FieldPlanType: + return m.OldPlanType(ctx) + case soraaccount.FieldPlanTitle: + return m.OldPlanTitle(ctx) + case soraaccount.FieldSubscriptionEnd: + return m.OldSubscriptionEnd(ctx) + case soraaccount.FieldSoraSupported: + return m.OldSoraSupported(ctx) + case soraaccount.FieldSoraInviteCode: + return m.OldSoraInviteCode(ctx) + case soraaccount.FieldSoraRedeemedCount: + return m.OldSoraRedeemedCount(ctx) + case soraaccount.FieldSoraRemainingCount: + return m.OldSoraRemainingCount(ctx) + case soraaccount.FieldSoraTotalCount: + return m.OldSoraTotalCount(ctx) + case soraaccount.FieldSoraCooldownUntil: + return m.OldSoraCooldownUntil(ctx) + case soraaccount.FieldCooledUntil: + return m.OldCooledUntil(ctx) + case soraaccount.FieldImageEnabled: + return m.OldImageEnabled(ctx) + case soraaccount.FieldVideoEnabled: + return m.OldVideoEnabled(ctx) + case soraaccount.FieldImageConcurrency: + return m.OldImageConcurrency(ctx) + case soraaccount.FieldVideoConcurrency: + return m.OldVideoConcurrency(ctx) + case soraaccount.FieldIsExpired: + return m.OldIsExpired(ctx) + } + return nil, fmt.Errorf("unknown SoraAccount field %s", name) +} + +// SetField sets the 
value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *SoraAccountMutation) SetField(name string, value ent.Value) error { + switch name { + case soraaccount.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case soraaccount.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case soraaccount.FieldAccountID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAccountID(v) + return nil + case soraaccount.FieldAccessToken: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAccessToken(v) + return nil + case soraaccount.FieldSessionToken: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSessionToken(v) + return nil + case soraaccount.FieldRefreshToken: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefreshToken(v) + return nil + case soraaccount.FieldClientID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetClientID(v) + return nil + case soraaccount.FieldEmail: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetEmail(v) + return nil + case soraaccount.FieldUsername: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUsername(v) + return nil + case soraaccount.FieldRemark: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + 
m.SetRemark(v) + return nil + case soraaccount.FieldUseCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUseCount(v) + return nil + case soraaccount.FieldPlanType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPlanType(v) + return nil + case soraaccount.FieldPlanTitle: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPlanTitle(v) + return nil + case soraaccount.FieldSubscriptionEnd: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSubscriptionEnd(v) + return nil + case soraaccount.FieldSoraSupported: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraSupported(v) + return nil + case soraaccount.FieldSoraInviteCode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraInviteCode(v) + return nil + case soraaccount.FieldSoraRedeemedCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraRedeemedCount(v) + return nil + case soraaccount.FieldSoraRemainingCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraRemainingCount(v) + return nil + case soraaccount.FieldSoraTotalCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraTotalCount(v) + return nil + case soraaccount.FieldSoraCooldownUntil: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraCooldownUntil(v) + return nil + case soraaccount.FieldCooledUntil: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for 
field %s", value, name) + } + m.SetCooledUntil(v) + return nil + case soraaccount.FieldImageEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImageEnabled(v) + return nil + case soraaccount.FieldVideoEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetVideoEnabled(v) + return nil + case soraaccount.FieldImageConcurrency: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImageConcurrency(v) + return nil + case soraaccount.FieldVideoConcurrency: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetVideoConcurrency(v) + return nil + case soraaccount.FieldIsExpired: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIsExpired(v) + return nil + } + return fmt.Errorf("unknown SoraAccount field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. 
+func (m *SoraAccountMutation) AddedFields() []string { + var fields []string + if m.addaccount_id != nil { + fields = append(fields, soraaccount.FieldAccountID) + } + if m.adduse_count != nil { + fields = append(fields, soraaccount.FieldUseCount) + } + if m.addsora_redeemed_count != nil { + fields = append(fields, soraaccount.FieldSoraRedeemedCount) + } + if m.addsora_remaining_count != nil { + fields = append(fields, soraaccount.FieldSoraRemainingCount) + } + if m.addsora_total_count != nil { + fields = append(fields, soraaccount.FieldSoraTotalCount) + } + if m.addimage_concurrency != nil { + fields = append(fields, soraaccount.FieldImageConcurrency) + } + if m.addvideo_concurrency != nil { + fields = append(fields, soraaccount.FieldVideoConcurrency) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *SoraAccountMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case soraaccount.FieldAccountID: + return m.AddedAccountID() + case soraaccount.FieldUseCount: + return m.AddedUseCount() + case soraaccount.FieldSoraRedeemedCount: + return m.AddedSoraRedeemedCount() + case soraaccount.FieldSoraRemainingCount: + return m.AddedSoraRemainingCount() + case soraaccount.FieldSoraTotalCount: + return m.AddedSoraTotalCount() + case soraaccount.FieldImageConcurrency: + return m.AddedImageConcurrency() + case soraaccount.FieldVideoConcurrency: + return m.AddedVideoConcurrency() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *SoraAccountMutation) AddField(name string, value ent.Value) error { + switch name { + case soraaccount.FieldAccountID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddAccountID(v) + return nil + case soraaccount.FieldUseCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddUseCount(v) + return nil + case soraaccount.FieldSoraRedeemedCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraRedeemedCount(v) + return nil + case soraaccount.FieldSoraRemainingCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraRemainingCount(v) + return nil + case soraaccount.FieldSoraTotalCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraTotalCount(v) + return nil + case soraaccount.FieldImageConcurrency: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddImageConcurrency(v) + return nil + case soraaccount.FieldVideoConcurrency: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddVideoConcurrency(v) + return nil + } + return fmt.Errorf("unknown SoraAccount numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. 
+func (m *SoraAccountMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(soraaccount.FieldAccessToken) { + fields = append(fields, soraaccount.FieldAccessToken) + } + if m.FieldCleared(soraaccount.FieldSessionToken) { + fields = append(fields, soraaccount.FieldSessionToken) + } + if m.FieldCleared(soraaccount.FieldRefreshToken) { + fields = append(fields, soraaccount.FieldRefreshToken) + } + if m.FieldCleared(soraaccount.FieldClientID) { + fields = append(fields, soraaccount.FieldClientID) + } + if m.FieldCleared(soraaccount.FieldEmail) { + fields = append(fields, soraaccount.FieldEmail) + } + if m.FieldCleared(soraaccount.FieldUsername) { + fields = append(fields, soraaccount.FieldUsername) + } + if m.FieldCleared(soraaccount.FieldRemark) { + fields = append(fields, soraaccount.FieldRemark) + } + if m.FieldCleared(soraaccount.FieldPlanType) { + fields = append(fields, soraaccount.FieldPlanType) + } + if m.FieldCleared(soraaccount.FieldPlanTitle) { + fields = append(fields, soraaccount.FieldPlanTitle) + } + if m.FieldCleared(soraaccount.FieldSubscriptionEnd) { + fields = append(fields, soraaccount.FieldSubscriptionEnd) + } + if m.FieldCleared(soraaccount.FieldSoraInviteCode) { + fields = append(fields, soraaccount.FieldSoraInviteCode) + } + if m.FieldCleared(soraaccount.FieldSoraCooldownUntil) { + fields = append(fields, soraaccount.FieldSoraCooldownUntil) + } + if m.FieldCleared(soraaccount.FieldCooledUntil) { + fields = append(fields, soraaccount.FieldCooledUntil) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *SoraAccountMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. 
+func (m *SoraAccountMutation) ClearField(name string) error { + switch name { + case soraaccount.FieldAccessToken: + m.ClearAccessToken() + return nil + case soraaccount.FieldSessionToken: + m.ClearSessionToken() + return nil + case soraaccount.FieldRefreshToken: + m.ClearRefreshToken() + return nil + case soraaccount.FieldClientID: + m.ClearClientID() + return nil + case soraaccount.FieldEmail: + m.ClearEmail() + return nil + case soraaccount.FieldUsername: + m.ClearUsername() + return nil + case soraaccount.FieldRemark: + m.ClearRemark() + return nil + case soraaccount.FieldPlanType: + m.ClearPlanType() + return nil + case soraaccount.FieldPlanTitle: + m.ClearPlanTitle() + return nil + case soraaccount.FieldSubscriptionEnd: + m.ClearSubscriptionEnd() + return nil + case soraaccount.FieldSoraInviteCode: + m.ClearSoraInviteCode() + return nil + case soraaccount.FieldSoraCooldownUntil: + m.ClearSoraCooldownUntil() + return nil + case soraaccount.FieldCooledUntil: + m.ClearCooledUntil() + return nil + } + return fmt.Errorf("unknown SoraAccount nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. 
+func (m *SoraAccountMutation) ResetField(name string) error { + switch name { + case soraaccount.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case soraaccount.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case soraaccount.FieldAccountID: + m.ResetAccountID() + return nil + case soraaccount.FieldAccessToken: + m.ResetAccessToken() + return nil + case soraaccount.FieldSessionToken: + m.ResetSessionToken() + return nil + case soraaccount.FieldRefreshToken: + m.ResetRefreshToken() + return nil + case soraaccount.FieldClientID: + m.ResetClientID() + return nil + case soraaccount.FieldEmail: + m.ResetEmail() + return nil + case soraaccount.FieldUsername: + m.ResetUsername() + return nil + case soraaccount.FieldRemark: + m.ResetRemark() + return nil + case soraaccount.FieldUseCount: + m.ResetUseCount() + return nil + case soraaccount.FieldPlanType: + m.ResetPlanType() + return nil + case soraaccount.FieldPlanTitle: + m.ResetPlanTitle() + return nil + case soraaccount.FieldSubscriptionEnd: + m.ResetSubscriptionEnd() + return nil + case soraaccount.FieldSoraSupported: + m.ResetSoraSupported() + return nil + case soraaccount.FieldSoraInviteCode: + m.ResetSoraInviteCode() + return nil + case soraaccount.FieldSoraRedeemedCount: + m.ResetSoraRedeemedCount() + return nil + case soraaccount.FieldSoraRemainingCount: + m.ResetSoraRemainingCount() + return nil + case soraaccount.FieldSoraTotalCount: + m.ResetSoraTotalCount() + return nil + case soraaccount.FieldSoraCooldownUntil: + m.ResetSoraCooldownUntil() + return nil + case soraaccount.FieldCooledUntil: + m.ResetCooledUntil() + return nil + case soraaccount.FieldImageEnabled: + m.ResetImageEnabled() + return nil + case soraaccount.FieldVideoEnabled: + m.ResetVideoEnabled() + return nil + case soraaccount.FieldImageConcurrency: + m.ResetImageConcurrency() + return nil + case soraaccount.FieldVideoConcurrency: + m.ResetVideoConcurrency() + return nil + case soraaccount.FieldIsExpired: + m.ResetIsExpired() + return 
nil + } + return fmt.Errorf("unknown SoraAccount field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *SoraAccountMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *SoraAccountMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *SoraAccountMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *SoraAccountMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *SoraAccountMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *SoraAccountMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *SoraAccountMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown SoraAccount unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *SoraAccountMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown SoraAccount edge %s", name) +} + +// SoraCacheFileMutation represents an operation that mutates the SoraCacheFile nodes in the graph. 
+type SoraCacheFileMutation struct { + config + op Op + typ string + id *int64 + task_id *string + account_id *int64 + addaccount_id *int64 + user_id *int64 + adduser_id *int64 + media_type *string + original_url *string + cache_path *string + cache_url *string + size_bytes *int64 + addsize_bytes *int64 + created_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*SoraCacheFile, error) + predicates []predicate.SoraCacheFile +} + +var _ ent.Mutation = (*SoraCacheFileMutation)(nil) + +// soracachefileOption allows management of the mutation configuration using functional options. +type soracachefileOption func(*SoraCacheFileMutation) + +// newSoraCacheFileMutation creates new mutation for the SoraCacheFile entity. +func newSoraCacheFileMutation(c config, op Op, opts ...soracachefileOption) *SoraCacheFileMutation { + m := &SoraCacheFileMutation{ + config: c, + op: op, + typ: TypeSoraCacheFile, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withSoraCacheFileID sets the ID field of the mutation. +func withSoraCacheFileID(id int64) soracachefileOption { + return func(m *SoraCacheFileMutation) { + var ( + err error + once sync.Once + value *SoraCacheFile + ) + m.oldValue = func(ctx context.Context) (*SoraCacheFile, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().SoraCacheFile.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withSoraCacheFile sets the old SoraCacheFile of the mutation. +func withSoraCacheFile(node *SoraCacheFile) soracachefileOption { + return func(m *SoraCacheFileMutation) { + m.oldValue = func(context.Context) (*SoraCacheFile, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. 
If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m SoraCacheFileMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m SoraCacheFileMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *SoraCacheFileMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *SoraCacheFileMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().SoraCacheFile.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetTaskID sets the "task_id" field. +func (m *SoraCacheFileMutation) SetTaskID(s string) { + m.task_id = &s +} + +// TaskID returns the value of the "task_id" field in the mutation. +func (m *SoraCacheFileMutation) TaskID() (r string, exists bool) { + v := m.task_id + if v == nil { + return + } + return *v, true +} + +// OldTaskID returns the old "task_id" field's value of the SoraCacheFile entity. 
+// If the SoraCacheFile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraCacheFileMutation) OldTaskID(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTaskID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTaskID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTaskID: %w", err) + } + return oldValue.TaskID, nil +} + +// ClearTaskID clears the value of the "task_id" field. +func (m *SoraCacheFileMutation) ClearTaskID() { + m.task_id = nil + m.clearedFields[soracachefile.FieldTaskID] = struct{}{} +} + +// TaskIDCleared returns if the "task_id" field was cleared in this mutation. +func (m *SoraCacheFileMutation) TaskIDCleared() bool { + _, ok := m.clearedFields[soracachefile.FieldTaskID] + return ok +} + +// ResetTaskID resets all changes to the "task_id" field. +func (m *SoraCacheFileMutation) ResetTaskID() { + m.task_id = nil + delete(m.clearedFields, soracachefile.FieldTaskID) +} + +// SetAccountID sets the "account_id" field. +func (m *SoraCacheFileMutation) SetAccountID(i int64) { + m.account_id = &i + m.addaccount_id = nil +} + +// AccountID returns the value of the "account_id" field in the mutation. +func (m *SoraCacheFileMutation) AccountID() (r int64, exists bool) { + v := m.account_id + if v == nil { + return + } + return *v, true +} + +// OldAccountID returns the old "account_id" field's value of the SoraCacheFile entity. +// If the SoraCacheFile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraCacheFileMutation) OldAccountID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAccountID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAccountID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAccountID: %w", err) + } + return oldValue.AccountID, nil +} + +// AddAccountID adds i to the "account_id" field. +func (m *SoraCacheFileMutation) AddAccountID(i int64) { + if m.addaccount_id != nil { + *m.addaccount_id += i + } else { + m.addaccount_id = &i + } +} + +// AddedAccountID returns the value that was added to the "account_id" field in this mutation. +func (m *SoraCacheFileMutation) AddedAccountID() (r int64, exists bool) { + v := m.addaccount_id + if v == nil { + return + } + return *v, true +} + +// ResetAccountID resets all changes to the "account_id" field. +func (m *SoraCacheFileMutation) ResetAccountID() { + m.account_id = nil + m.addaccount_id = nil +} + +// SetUserID sets the "user_id" field. +func (m *SoraCacheFileMutation) SetUserID(i int64) { + m.user_id = &i + m.adduser_id = nil +} + +// UserID returns the value of the "user_id" field in the mutation. +func (m *SoraCacheFileMutation) UserID() (r int64, exists bool) { + v := m.user_id + if v == nil { + return + } + return *v, true +} + +// OldUserID returns the old "user_id" field's value of the SoraCacheFile entity. +// If the SoraCacheFile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraCacheFileMutation) OldUserID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUserID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUserID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUserID: %w", err) + } + return oldValue.UserID, nil +} + +// AddUserID adds i to the "user_id" field. +func (m *SoraCacheFileMutation) AddUserID(i int64) { + if m.adduser_id != nil { + *m.adduser_id += i + } else { + m.adduser_id = &i + } +} + +// AddedUserID returns the value that was added to the "user_id" field in this mutation. +func (m *SoraCacheFileMutation) AddedUserID() (r int64, exists bool) { + v := m.adduser_id + if v == nil { + return + } + return *v, true +} + +// ResetUserID resets all changes to the "user_id" field. +func (m *SoraCacheFileMutation) ResetUserID() { + m.user_id = nil + m.adduser_id = nil +} + +// SetMediaType sets the "media_type" field. +func (m *SoraCacheFileMutation) SetMediaType(s string) { + m.media_type = &s +} + +// MediaType returns the value of the "media_type" field in the mutation. +func (m *SoraCacheFileMutation) MediaType() (r string, exists bool) { + v := m.media_type + if v == nil { + return + } + return *v, true +} + +// OldMediaType returns the old "media_type" field's value of the SoraCacheFile entity. +// If the SoraCacheFile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraCacheFileMutation) OldMediaType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMediaType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMediaType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMediaType: %w", err) + } + return oldValue.MediaType, nil +} + +// ResetMediaType resets all changes to the "media_type" field. +func (m *SoraCacheFileMutation) ResetMediaType() { + m.media_type = nil +} + +// SetOriginalURL sets the "original_url" field. +func (m *SoraCacheFileMutation) SetOriginalURL(s string) { + m.original_url = &s +} + +// OriginalURL returns the value of the "original_url" field in the mutation. +func (m *SoraCacheFileMutation) OriginalURL() (r string, exists bool) { + v := m.original_url + if v == nil { + return + } + return *v, true +} + +// OldOriginalURL returns the old "original_url" field's value of the SoraCacheFile entity. +// If the SoraCacheFile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraCacheFileMutation) OldOriginalURL(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOriginalURL is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOriginalURL requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOriginalURL: %w", err) + } + return oldValue.OriginalURL, nil +} + +// ResetOriginalURL resets all changes to the "original_url" field. +func (m *SoraCacheFileMutation) ResetOriginalURL() { + m.original_url = nil +} + +// SetCachePath sets the "cache_path" field. 
+func (m *SoraCacheFileMutation) SetCachePath(s string) { + m.cache_path = &s +} + +// CachePath returns the value of the "cache_path" field in the mutation. +func (m *SoraCacheFileMutation) CachePath() (r string, exists bool) { + v := m.cache_path + if v == nil { + return + } + return *v, true +} + +// OldCachePath returns the old "cache_path" field's value of the SoraCacheFile entity. +// If the SoraCacheFile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraCacheFileMutation) OldCachePath(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCachePath is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCachePath requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCachePath: %w", err) + } + return oldValue.CachePath, nil +} + +// ResetCachePath resets all changes to the "cache_path" field. +func (m *SoraCacheFileMutation) ResetCachePath() { + m.cache_path = nil +} + +// SetCacheURL sets the "cache_url" field. +func (m *SoraCacheFileMutation) SetCacheURL(s string) { + m.cache_url = &s +} + +// CacheURL returns the value of the "cache_url" field in the mutation. +func (m *SoraCacheFileMutation) CacheURL() (r string, exists bool) { + v := m.cache_url + if v == nil { + return + } + return *v, true +} + +// OldCacheURL returns the old "cache_url" field's value of the SoraCacheFile entity. +// If the SoraCacheFile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraCacheFileMutation) OldCacheURL(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCacheURL is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCacheURL requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCacheURL: %w", err) + } + return oldValue.CacheURL, nil +} + +// ResetCacheURL resets all changes to the "cache_url" field. +func (m *SoraCacheFileMutation) ResetCacheURL() { + m.cache_url = nil +} + +// SetSizeBytes sets the "size_bytes" field. +func (m *SoraCacheFileMutation) SetSizeBytes(i int64) { + m.size_bytes = &i + m.addsize_bytes = nil +} + +// SizeBytes returns the value of the "size_bytes" field in the mutation. +func (m *SoraCacheFileMutation) SizeBytes() (r int64, exists bool) { + v := m.size_bytes + if v == nil { + return + } + return *v, true +} + +// OldSizeBytes returns the old "size_bytes" field's value of the SoraCacheFile entity. +// If the SoraCacheFile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraCacheFileMutation) OldSizeBytes(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSizeBytes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSizeBytes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSizeBytes: %w", err) + } + return oldValue.SizeBytes, nil +} + +// AddSizeBytes adds i to the "size_bytes" field. 
+func (m *SoraCacheFileMutation) AddSizeBytes(i int64) { + if m.addsize_bytes != nil { + *m.addsize_bytes += i + } else { + m.addsize_bytes = &i + } +} + +// AddedSizeBytes returns the value that was added to the "size_bytes" field in this mutation. +func (m *SoraCacheFileMutation) AddedSizeBytes() (r int64, exists bool) { + v := m.addsize_bytes + if v == nil { + return + } + return *v, true +} + +// ResetSizeBytes resets all changes to the "size_bytes" field. +func (m *SoraCacheFileMutation) ResetSizeBytes() { + m.size_bytes = nil + m.addsize_bytes = nil +} + +// SetCreatedAt sets the "created_at" field. +func (m *SoraCacheFileMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *SoraCacheFileMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the SoraCacheFile entity. +// If the SoraCacheFile object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraCacheFileMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *SoraCacheFileMutation) ResetCreatedAt() { + m.created_at = nil +} + +// Where appends a list predicates to the SoraCacheFileMutation builder. 
+func (m *SoraCacheFileMutation) Where(ps ...predicate.SoraCacheFile) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the SoraCacheFileMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *SoraCacheFileMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.SoraCacheFile, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *SoraCacheFileMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *SoraCacheFileMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (SoraCacheFile). +func (m *SoraCacheFileMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *SoraCacheFileMutation) Fields() []string { + fields := make([]string, 0, 9) + if m.task_id != nil { + fields = append(fields, soracachefile.FieldTaskID) + } + if m.account_id != nil { + fields = append(fields, soracachefile.FieldAccountID) + } + if m.user_id != nil { + fields = append(fields, soracachefile.FieldUserID) + } + if m.media_type != nil { + fields = append(fields, soracachefile.FieldMediaType) + } + if m.original_url != nil { + fields = append(fields, soracachefile.FieldOriginalURL) + } + if m.cache_path != nil { + fields = append(fields, soracachefile.FieldCachePath) + } + if m.cache_url != nil { + fields = append(fields, soracachefile.FieldCacheURL) + } + if m.size_bytes != nil { + fields = append(fields, soracachefile.FieldSizeBytes) + } + if m.created_at != nil { + fields = append(fields, soracachefile.FieldCreatedAt) + } + return fields +} + +// Field returns the value of a field with the given name. 
The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *SoraCacheFileMutation) Field(name string) (ent.Value, bool) { + switch name { + case soracachefile.FieldTaskID: + return m.TaskID() + case soracachefile.FieldAccountID: + return m.AccountID() + case soracachefile.FieldUserID: + return m.UserID() + case soracachefile.FieldMediaType: + return m.MediaType() + case soracachefile.FieldOriginalURL: + return m.OriginalURL() + case soracachefile.FieldCachePath: + return m.CachePath() + case soracachefile.FieldCacheURL: + return m.CacheURL() + case soracachefile.FieldSizeBytes: + return m.SizeBytes() + case soracachefile.FieldCreatedAt: + return m.CreatedAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *SoraCacheFileMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case soracachefile.FieldTaskID: + return m.OldTaskID(ctx) + case soracachefile.FieldAccountID: + return m.OldAccountID(ctx) + case soracachefile.FieldUserID: + return m.OldUserID(ctx) + case soracachefile.FieldMediaType: + return m.OldMediaType(ctx) + case soracachefile.FieldOriginalURL: + return m.OldOriginalURL(ctx) + case soracachefile.FieldCachePath: + return m.OldCachePath(ctx) + case soracachefile.FieldCacheURL: + return m.OldCacheURL(ctx) + case soracachefile.FieldSizeBytes: + return m.OldSizeBytes(ctx) + case soracachefile.FieldCreatedAt: + return m.OldCreatedAt(ctx) + } + return nil, fmt.Errorf("unknown SoraCacheFile field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *SoraCacheFileMutation) SetField(name string, value ent.Value) error { + switch name { + case soracachefile.FieldTaskID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTaskID(v) + return nil + case soracachefile.FieldAccountID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAccountID(v) + return nil + case soracachefile.FieldUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserID(v) + return nil + case soracachefile.FieldMediaType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMediaType(v) + return nil + case soracachefile.FieldOriginalURL: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOriginalURL(v) + return nil + case soracachefile.FieldCachePath: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCachePath(v) + return nil + case soracachefile.FieldCacheURL: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCacheURL(v) + return nil + case soracachefile.FieldSizeBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSizeBytes(v) + return nil + case soracachefile.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + } + return fmt.Errorf("unknown SoraCacheFile field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. 
+func (m *SoraCacheFileMutation) AddedFields() []string { + var fields []string + if m.addaccount_id != nil { + fields = append(fields, soracachefile.FieldAccountID) + } + if m.adduser_id != nil { + fields = append(fields, soracachefile.FieldUserID) + } + if m.addsize_bytes != nil { + fields = append(fields, soracachefile.FieldSizeBytes) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *SoraCacheFileMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case soracachefile.FieldAccountID: + return m.AddedAccountID() + case soracachefile.FieldUserID: + return m.AddedUserID() + case soracachefile.FieldSizeBytes: + return m.AddedSizeBytes() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *SoraCacheFileMutation) AddField(name string, value ent.Value) error { + switch name { + case soracachefile.FieldAccountID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddAccountID(v) + return nil + case soracachefile.FieldUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddUserID(v) + return nil + case soracachefile.FieldSizeBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSizeBytes(v) + return nil + } + return fmt.Errorf("unknown SoraCacheFile numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. 
+func (m *SoraCacheFileMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(soracachefile.FieldTaskID) { + fields = append(fields, soracachefile.FieldTaskID) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *SoraCacheFileMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *SoraCacheFileMutation) ClearField(name string) error { + switch name { + case soracachefile.FieldTaskID: + m.ClearTaskID() + return nil + } + return fmt.Errorf("unknown SoraCacheFile nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *SoraCacheFileMutation) ResetField(name string) error { + switch name { + case soracachefile.FieldTaskID: + m.ResetTaskID() + return nil + case soracachefile.FieldAccountID: + m.ResetAccountID() + return nil + case soracachefile.FieldUserID: + m.ResetUserID() + return nil + case soracachefile.FieldMediaType: + m.ResetMediaType() + return nil + case soracachefile.FieldOriginalURL: + m.ResetOriginalURL() + return nil + case soracachefile.FieldCachePath: + m.ResetCachePath() + return nil + case soracachefile.FieldCacheURL: + m.ResetCacheURL() + return nil + case soracachefile.FieldSizeBytes: + m.ResetSizeBytes() + return nil + case soracachefile.FieldCreatedAt: + m.ResetCreatedAt() + return nil + } + return fmt.Errorf("unknown SoraCacheFile field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. 
+func (m *SoraCacheFileMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *SoraCacheFileMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *SoraCacheFileMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *SoraCacheFileMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *SoraCacheFileMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *SoraCacheFileMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *SoraCacheFileMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown SoraCacheFile unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *SoraCacheFileMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown SoraCacheFile edge %s", name) +} + +// SoraTaskMutation represents an operation that mutates the SoraTask nodes in the graph. 
+type SoraTaskMutation struct { + config + op Op + typ string + id *int64 + task_id *string + account_id *int64 + addaccount_id *int64 + model *string + prompt *string + status *string + progress *float64 + addprogress *float64 + result_urls *string + error_message *string + retry_count *int + addretry_count *int + created_at *time.Time + completed_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*SoraTask, error) + predicates []predicate.SoraTask +} + +var _ ent.Mutation = (*SoraTaskMutation)(nil) + +// sorataskOption allows management of the mutation configuration using functional options. +type sorataskOption func(*SoraTaskMutation) + +// newSoraTaskMutation creates new mutation for the SoraTask entity. +func newSoraTaskMutation(c config, op Op, opts ...sorataskOption) *SoraTaskMutation { + m := &SoraTaskMutation{ + config: c, + op: op, + typ: TypeSoraTask, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withSoraTaskID sets the ID field of the mutation. +func withSoraTaskID(id int64) sorataskOption { + return func(m *SoraTaskMutation) { + var ( + err error + once sync.Once + value *SoraTask + ) + m.oldValue = func(ctx context.Context) (*SoraTask, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().SoraTask.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withSoraTask sets the old SoraTask of the mutation. +func withSoraTask(node *SoraTask) sorataskOption { + return func(m *SoraTaskMutation) { + m.oldValue = func(context.Context) (*SoraTask, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. 
+func (m SoraTaskMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m SoraTaskMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *SoraTaskMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *SoraTaskMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().SoraTask.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetTaskID sets the "task_id" field. +func (m *SoraTaskMutation) SetTaskID(s string) { + m.task_id = &s +} + +// TaskID returns the value of the "task_id" field in the mutation. +func (m *SoraTaskMutation) TaskID() (r string, exists bool) { + v := m.task_id + if v == nil { + return + } + return *v, true +} + +// OldTaskID returns the old "task_id" field's value of the SoraTask entity. +// If the SoraTask object wasn't provided to the builder, the object is fetched from the database. 
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraTaskMutation) OldTaskID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTaskID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTaskID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTaskID: %w", err) + } + return oldValue.TaskID, nil +} + +// ResetTaskID resets all changes to the "task_id" field. +func (m *SoraTaskMutation) ResetTaskID() { + m.task_id = nil +} + +// SetAccountID sets the "account_id" field. +func (m *SoraTaskMutation) SetAccountID(i int64) { + m.account_id = &i + m.addaccount_id = nil +} + +// AccountID returns the value of the "account_id" field in the mutation. +func (m *SoraTaskMutation) AccountID() (r int64, exists bool) { + v := m.account_id + if v == nil { + return + } + return *v, true +} + +// OldAccountID returns the old "account_id" field's value of the SoraTask entity. +// If the SoraTask object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraTaskMutation) OldAccountID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAccountID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAccountID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAccountID: %w", err) + } + return oldValue.AccountID, nil +} + +// AddAccountID adds i to the "account_id" field. 
+func (m *SoraTaskMutation) AddAccountID(i int64) { + if m.addaccount_id != nil { + *m.addaccount_id += i + } else { + m.addaccount_id = &i + } +} + +// AddedAccountID returns the value that was added to the "account_id" field in this mutation. +func (m *SoraTaskMutation) AddedAccountID() (r int64, exists bool) { + v := m.addaccount_id + if v == nil { + return + } + return *v, true +} + +// ResetAccountID resets all changes to the "account_id" field. +func (m *SoraTaskMutation) ResetAccountID() { + m.account_id = nil + m.addaccount_id = nil +} + +// SetModel sets the "model" field. +func (m *SoraTaskMutation) SetModel(s string) { + m.model = &s +} + +// Model returns the value of the "model" field in the mutation. +func (m *SoraTaskMutation) Model() (r string, exists bool) { + v := m.model + if v == nil { + return + } + return *v, true +} + +// OldModel returns the old "model" field's value of the SoraTask entity. +// If the SoraTask object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraTaskMutation) OldModel(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldModel is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldModel requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldModel: %w", err) + } + return oldValue.Model, nil +} + +// ResetModel resets all changes to the "model" field. +func (m *SoraTaskMutation) ResetModel() { + m.model = nil +} + +// SetPrompt sets the "prompt" field. +func (m *SoraTaskMutation) SetPrompt(s string) { + m.prompt = &s +} + +// Prompt returns the value of the "prompt" field in the mutation. 
+func (m *SoraTaskMutation) Prompt() (r string, exists bool) { + v := m.prompt + if v == nil { + return + } + return *v, true +} + +// OldPrompt returns the old "prompt" field's value of the SoraTask entity. +// If the SoraTask object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraTaskMutation) OldPrompt(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPrompt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPrompt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPrompt: %w", err) + } + return oldValue.Prompt, nil +} + +// ResetPrompt resets all changes to the "prompt" field. +func (m *SoraTaskMutation) ResetPrompt() { + m.prompt = nil +} + +// SetStatus sets the "status" field. +func (m *SoraTaskMutation) SetStatus(s string) { + m.status = &s +} + +// Status returns the value of the "status" field in the mutation. +func (m *SoraTaskMutation) Status() (r string, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true +} + +// OldStatus returns the old "status" field's value of the SoraTask entity. +// If the SoraTask object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraTaskMutation) OldStatus(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil +} + +// ResetStatus resets all changes to the "status" field. +func (m *SoraTaskMutation) ResetStatus() { + m.status = nil +} + +// SetProgress sets the "progress" field. +func (m *SoraTaskMutation) SetProgress(f float64) { + m.progress = &f + m.addprogress = nil +} + +// Progress returns the value of the "progress" field in the mutation. +func (m *SoraTaskMutation) Progress() (r float64, exists bool) { + v := m.progress + if v == nil { + return + } + return *v, true +} + +// OldProgress returns the old "progress" field's value of the SoraTask entity. +// If the SoraTask object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraTaskMutation) OldProgress(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProgress is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProgress requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProgress: %w", err) + } + return oldValue.Progress, nil +} + +// AddProgress adds f to the "progress" field. +func (m *SoraTaskMutation) AddProgress(f float64) { + if m.addprogress != nil { + *m.addprogress += f + } else { + m.addprogress = &f + } +} + +// AddedProgress returns the value that was added to the "progress" field in this mutation. 
+func (m *SoraTaskMutation) AddedProgress() (r float64, exists bool) { + v := m.addprogress + if v == nil { + return + } + return *v, true +} + +// ResetProgress resets all changes to the "progress" field. +func (m *SoraTaskMutation) ResetProgress() { + m.progress = nil + m.addprogress = nil +} + +// SetResultUrls sets the "result_urls" field. +func (m *SoraTaskMutation) SetResultUrls(s string) { + m.result_urls = &s +} + +// ResultUrls returns the value of the "result_urls" field in the mutation. +func (m *SoraTaskMutation) ResultUrls() (r string, exists bool) { + v := m.result_urls + if v == nil { + return + } + return *v, true +} + +// OldResultUrls returns the old "result_urls" field's value of the SoraTask entity. +// If the SoraTask object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraTaskMutation) OldResultUrls(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResultUrls is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResultUrls requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResultUrls: %w", err) + } + return oldValue.ResultUrls, nil +} + +// ClearResultUrls clears the value of the "result_urls" field. +func (m *SoraTaskMutation) ClearResultUrls() { + m.result_urls = nil + m.clearedFields[soratask.FieldResultUrls] = struct{}{} +} + +// ResultUrlsCleared returns if the "result_urls" field was cleared in this mutation. +func (m *SoraTaskMutation) ResultUrlsCleared() bool { + _, ok := m.clearedFields[soratask.FieldResultUrls] + return ok +} + +// ResetResultUrls resets all changes to the "result_urls" field. 
+func (m *SoraTaskMutation) ResetResultUrls() { + m.result_urls = nil + delete(m.clearedFields, soratask.FieldResultUrls) +} + +// SetErrorMessage sets the "error_message" field. +func (m *SoraTaskMutation) SetErrorMessage(s string) { + m.error_message = &s +} + +// ErrorMessage returns the value of the "error_message" field in the mutation. +func (m *SoraTaskMutation) ErrorMessage() (r string, exists bool) { + v := m.error_message + if v == nil { + return + } + return *v, true +} + +// OldErrorMessage returns the old "error_message" field's value of the SoraTask entity. +// If the SoraTask object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraTaskMutation) OldErrorMessage(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldErrorMessage is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldErrorMessage requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldErrorMessage: %w", err) + } + return oldValue.ErrorMessage, nil +} + +// ClearErrorMessage clears the value of the "error_message" field. +func (m *SoraTaskMutation) ClearErrorMessage() { + m.error_message = nil + m.clearedFields[soratask.FieldErrorMessage] = struct{}{} +} + +// ErrorMessageCleared returns if the "error_message" field was cleared in this mutation. +func (m *SoraTaskMutation) ErrorMessageCleared() bool { + _, ok := m.clearedFields[soratask.FieldErrorMessage] + return ok +} + +// ResetErrorMessage resets all changes to the "error_message" field. +func (m *SoraTaskMutation) ResetErrorMessage() { + m.error_message = nil + delete(m.clearedFields, soratask.FieldErrorMessage) +} + +// SetRetryCount sets the "retry_count" field. 
+func (m *SoraTaskMutation) SetRetryCount(i int) { + m.retry_count = &i + m.addretry_count = nil +} + +// RetryCount returns the value of the "retry_count" field in the mutation. +func (m *SoraTaskMutation) RetryCount() (r int, exists bool) { + v := m.retry_count + if v == nil { + return + } + return *v, true +} + +// OldRetryCount returns the old "retry_count" field's value of the SoraTask entity. +// If the SoraTask object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraTaskMutation) OldRetryCount(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRetryCount is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRetryCount requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRetryCount: %w", err) + } + return oldValue.RetryCount, nil +} + +// AddRetryCount adds i to the "retry_count" field. +func (m *SoraTaskMutation) AddRetryCount(i int) { + if m.addretry_count != nil { + *m.addretry_count += i + } else { + m.addretry_count = &i + } +} + +// AddedRetryCount returns the value that was added to the "retry_count" field in this mutation. +func (m *SoraTaskMutation) AddedRetryCount() (r int, exists bool) { + v := m.addretry_count + if v == nil { + return + } + return *v, true +} + +// ResetRetryCount resets all changes to the "retry_count" field. +func (m *SoraTaskMutation) ResetRetryCount() { + m.retry_count = nil + m.addretry_count = nil +} + +// SetCreatedAt sets the "created_at" field. +func (m *SoraTaskMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. 
+func (m *SoraTaskMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the SoraTask entity. +// If the SoraTask object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SoraTaskMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *SoraTaskMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetCompletedAt sets the "completed_at" field. +func (m *SoraTaskMutation) SetCompletedAt(t time.Time) { + m.completed_at = &t +} + +// CompletedAt returns the value of the "completed_at" field in the mutation. +func (m *SoraTaskMutation) CompletedAt() (r time.Time, exists bool) { + v := m.completed_at + if v == nil { + return + } + return *v, true +} + +// OldCompletedAt returns the old "completed_at" field's value of the SoraTask entity. +// If the SoraTask object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SoraTaskMutation) OldCompletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCompletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCompletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCompletedAt: %w", err) + } + return oldValue.CompletedAt, nil +} + +// ClearCompletedAt clears the value of the "completed_at" field. +func (m *SoraTaskMutation) ClearCompletedAt() { + m.completed_at = nil + m.clearedFields[soratask.FieldCompletedAt] = struct{}{} +} + +// CompletedAtCleared returns if the "completed_at" field was cleared in this mutation. +func (m *SoraTaskMutation) CompletedAtCleared() bool { + _, ok := m.clearedFields[soratask.FieldCompletedAt] + return ok +} + +// ResetCompletedAt resets all changes to the "completed_at" field. +func (m *SoraTaskMutation) ResetCompletedAt() { + m.completed_at = nil + delete(m.clearedFields, soratask.FieldCompletedAt) +} + +// Where appends a list predicates to the SoraTaskMutation builder. +func (m *SoraTaskMutation) Where(ps ...predicate.SoraTask) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the SoraTaskMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *SoraTaskMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.SoraTask, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *SoraTaskMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *SoraTaskMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (SoraTask). 
+func (m *SoraTaskMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *SoraTaskMutation) Fields() []string { + fields := make([]string, 0, 11) + if m.task_id != nil { + fields = append(fields, soratask.FieldTaskID) + } + if m.account_id != nil { + fields = append(fields, soratask.FieldAccountID) + } + if m.model != nil { + fields = append(fields, soratask.FieldModel) + } + if m.prompt != nil { + fields = append(fields, soratask.FieldPrompt) + } + if m.status != nil { + fields = append(fields, soratask.FieldStatus) + } + if m.progress != nil { + fields = append(fields, soratask.FieldProgress) + } + if m.result_urls != nil { + fields = append(fields, soratask.FieldResultUrls) + } + if m.error_message != nil { + fields = append(fields, soratask.FieldErrorMessage) + } + if m.retry_count != nil { + fields = append(fields, soratask.FieldRetryCount) + } + if m.created_at != nil { + fields = append(fields, soratask.FieldCreatedAt) + } + if m.completed_at != nil { + fields = append(fields, soratask.FieldCompletedAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. 
+func (m *SoraTaskMutation) Field(name string) (ent.Value, bool) { + switch name { + case soratask.FieldTaskID: + return m.TaskID() + case soratask.FieldAccountID: + return m.AccountID() + case soratask.FieldModel: + return m.Model() + case soratask.FieldPrompt: + return m.Prompt() + case soratask.FieldStatus: + return m.Status() + case soratask.FieldProgress: + return m.Progress() + case soratask.FieldResultUrls: + return m.ResultUrls() + case soratask.FieldErrorMessage: + return m.ErrorMessage() + case soratask.FieldRetryCount: + return m.RetryCount() + case soratask.FieldCreatedAt: + return m.CreatedAt() + case soratask.FieldCompletedAt: + return m.CompletedAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *SoraTaskMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case soratask.FieldTaskID: + return m.OldTaskID(ctx) + case soratask.FieldAccountID: + return m.OldAccountID(ctx) + case soratask.FieldModel: + return m.OldModel(ctx) + case soratask.FieldPrompt: + return m.OldPrompt(ctx) + case soratask.FieldStatus: + return m.OldStatus(ctx) + case soratask.FieldProgress: + return m.OldProgress(ctx) + case soratask.FieldResultUrls: + return m.OldResultUrls(ctx) + case soratask.FieldErrorMessage: + return m.OldErrorMessage(ctx) + case soratask.FieldRetryCount: + return m.OldRetryCount(ctx) + case soratask.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case soratask.FieldCompletedAt: + return m.OldCompletedAt(ctx) + } + return nil, fmt.Errorf("unknown SoraTask field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *SoraTaskMutation) SetField(name string, value ent.Value) error { + switch name { + case soratask.FieldTaskID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTaskID(v) + return nil + case soratask.FieldAccountID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAccountID(v) + return nil + case soratask.FieldModel: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetModel(v) + return nil + case soratask.FieldPrompt: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPrompt(v) + return nil + case soratask.FieldStatus: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case soratask.FieldProgress: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProgress(v) + return nil + case soratask.FieldResultUrls: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResultUrls(v) + return nil + case soratask.FieldErrorMessage: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetErrorMessage(v) + return nil + case soratask.FieldRetryCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRetryCount(v) + return nil + case soratask.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case soratask.FieldCompletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCompletedAt(v) + return nil + } + 
return fmt.Errorf("unknown SoraTask field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *SoraTaskMutation) AddedFields() []string { + var fields []string + if m.addaccount_id != nil { + fields = append(fields, soratask.FieldAccountID) + } + if m.addprogress != nil { + fields = append(fields, soratask.FieldProgress) + } + if m.addretry_count != nil { + fields = append(fields, soratask.FieldRetryCount) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *SoraTaskMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case soratask.FieldAccountID: + return m.AddedAccountID() + case soratask.FieldProgress: + return m.AddedProgress() + case soratask.FieldRetryCount: + return m.AddedRetryCount() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *SoraTaskMutation) AddField(name string, value ent.Value) error { + switch name { + case soratask.FieldAccountID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddAccountID(v) + return nil + case soratask.FieldProgress: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddProgress(v) + return nil + case soratask.FieldRetryCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRetryCount(v) + return nil + } + return fmt.Errorf("unknown SoraTask numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. 
+func (m *SoraTaskMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(soratask.FieldResultUrls) { + fields = append(fields, soratask.FieldResultUrls) + } + if m.FieldCleared(soratask.FieldErrorMessage) { + fields = append(fields, soratask.FieldErrorMessage) + } + if m.FieldCleared(soratask.FieldCompletedAt) { + fields = append(fields, soratask.FieldCompletedAt) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *SoraTaskMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *SoraTaskMutation) ClearField(name string) error { + switch name { + case soratask.FieldResultUrls: + m.ClearResultUrls() + return nil + case soratask.FieldErrorMessage: + m.ClearErrorMessage() + return nil + case soratask.FieldCompletedAt: + m.ClearCompletedAt() + return nil + } + return fmt.Errorf("unknown SoraTask nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. 
+func (m *SoraTaskMutation) ResetField(name string) error { + switch name { + case soratask.FieldTaskID: + m.ResetTaskID() + return nil + case soratask.FieldAccountID: + m.ResetAccountID() + return nil + case soratask.FieldModel: + m.ResetModel() + return nil + case soratask.FieldPrompt: + m.ResetPrompt() + return nil + case soratask.FieldStatus: + m.ResetStatus() + return nil + case soratask.FieldProgress: + m.ResetProgress() + return nil + case soratask.FieldResultUrls: + m.ResetResultUrls() + return nil + case soratask.FieldErrorMessage: + m.ResetErrorMessage() + return nil + case soratask.FieldRetryCount: + m.ResetRetryCount() + return nil + case soratask.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case soratask.FieldCompletedAt: + m.ResetCompletedAt() + return nil + } + return fmt.Errorf("unknown SoraTask field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *SoraTaskMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *SoraTaskMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *SoraTaskMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *SoraTaskMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *SoraTaskMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. 
func (m *SoraTaskMutation) EdgeCleared(name string) bool {
	return false
}

// ClearEdge clears the value of the edge with the given name. It returns an error
// if that edge is not defined in the schema.
func (m *SoraTaskMutation) ClearEdge(name string) error {
	return fmt.Errorf("unknown SoraTask unique edge %s", name)
}

// ResetEdge resets all changes to the edge with the given name in this mutation.
// It returns an error if the edge is not defined in the schema.
func (m *SoraTaskMutation) ResetEdge(name string) error {
	return fmt.Errorf("unknown SoraTask edge %s", name)
}

// SoraUsageStatMutation represents an operation that mutates the SoraUsageStat nodes in the graph.
// Each pointer field holds a pending "set" value (nil = untouched); the matching
// add<field> pointer holds a pending numeric delta applied instead of/after a set.
type SoraUsageStatMutation struct {
	config
	op                         Op
	typ                        string
	id                         *int64
	created_at                 *time.Time
	updated_at                 *time.Time
	account_id                 *int64
	addaccount_id              *int64
	image_count                *int
	addimage_count             *int
	video_count                *int
	addvideo_count             *int
	error_count                *int
	adderror_count             *int
	last_error_at              *time.Time
	today_image_count          *int
	addtoday_image_count       *int
	today_video_count          *int
	addtoday_video_count       *int
	today_error_count          *int
	addtoday_error_count       *int
	today_date                 *time.Time
	consecutive_error_count    *int
	addconsecutive_error_count *int
	clearedFields              map[string]struct{}
	done                       bool
	oldValue                   func(context.Context) (*SoraUsageStat, error)
	predicates                 []predicate.SoraUsageStat
}

// Compile-time assertion that the mutation satisfies the ent.Mutation interface.
var _ ent.Mutation = (*SoraUsageStatMutation)(nil)

// sorausagestatOption allows management of the mutation configuration using functional options.
type sorausagestatOption func(*SoraUsageStatMutation)

// newSoraUsageStatMutation creates new mutation for the SoraUsageStat entity.
func newSoraUsageStatMutation(c config, op Op, opts ...sorausagestatOption) *SoraUsageStatMutation {
	m := &SoraUsageStatMutation{
		config:        c,
		op:            op,
		typ:           TypeSoraUsageStat,
		clearedFields: make(map[string]struct{}),
	}
	for _, opt := range opts {
		opt(m)
	}
	return m
}

// withSoraUsageStatID sets the ID field of the mutation.
func withSoraUsageStatID(id int64) sorausagestatOption {
	return func(m *SoraUsageStatMutation) {
		var (
			err   error
			once  sync.Once
			value *SoraUsageStat
		)
		// Lazily fetch the pre-mutation row at most once; after the mutation
		// completes (m.done) old values may no longer be consistent, so refuse.
		m.oldValue = func(ctx context.Context) (*SoraUsageStat, error) {
			once.Do(func() {
				if m.done {
					err = errors.New("querying old values post mutation is not allowed")
				} else {
					value, err = m.Client().SoraUsageStat.Get(ctx, id)
				}
			})
			return value, err
		}
		m.id = &id
	}
}

// withSoraUsageStat sets the old SoraUsageStat of the mutation.
func withSoraUsageStat(node *SoraUsageStat) sorausagestatOption {
	return func(m *SoraUsageStatMutation) {
		m.oldValue = func(context.Context) (*SoraUsageStat, error) {
			return node, nil
		}
		m.id = &node.ID
	}
}

// Client returns a new `ent.Client` from the mutation. If the mutation was
// executed in a transaction (ent.Tx), a transactional client is returned.
func (m SoraUsageStatMutation) Client() *Client {
	client := &Client{config: m.config}
	client.init()
	return client
}

// Tx returns an `ent.Tx` for mutations that were executed in transactions;
// it returns an error otherwise.
func (m SoraUsageStatMutation) Tx() (*Tx, error) {
	if _, ok := m.driver.(*txDriver); !ok {
		return nil, errors.New("ent: mutation is not running in a transaction")
	}
	tx := &Tx{config: m.config}
	tx.init()
	return tx, nil
}

// ID returns the ID value in the mutation. Note that the ID is only available
// if it was provided to the builder or after it was returned from the database.
func (m *SoraUsageStatMutation) ID() (id int64, exists bool) {
	if m.id == nil {
		return
	}
	return *m.id, true
}

// IDs queries the database and returns the entity ids that match the mutation's predicate.
// That means, if the mutation is applied within a transaction with an isolation level such
// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
// or updated by the mutation.
func (m *SoraUsageStatMutation) IDs(ctx context.Context) ([]int64, error) {
	switch {
	case m.op.Is(OpUpdateOne | OpDeleteOne):
		id, exists := m.ID()
		if exists {
			return []int64{id}, nil
		}
		// No explicit ID on a *One op: fall through and resolve via predicates.
		fallthrough
	case m.op.Is(OpUpdate | OpDelete):
		return m.Client().SoraUsageStat.Query().Where(m.predicates...).IDs(ctx)
	default:
		return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
	}
}

// SetCreatedAt sets the "created_at" field.
func (m *SoraUsageStatMutation) SetCreatedAt(t time.Time) {
	m.created_at = &t
}

// CreatedAt returns the value of the "created_at" field in the mutation.
func (m *SoraUsageStatMutation) CreatedAt() (r time.Time, exists bool) {
	v := m.created_at
	if v == nil {
		return
	}
	return *v, true
}

// OldCreatedAt returns the old "created_at" field's value of the SoraUsageStat entity.
// If the SoraUsageStat object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *SoraUsageStatMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldCreatedAt requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
	}
	return oldValue.CreatedAt, nil
}

// ResetCreatedAt resets all changes to the "created_at" field.
func (m *SoraUsageStatMutation) ResetCreatedAt() {
	m.created_at = nil
}

// SetUpdatedAt sets the "updated_at" field.
func (m *SoraUsageStatMutation) SetUpdatedAt(t time.Time) {
	m.updated_at = &t
}

// UpdatedAt returns the value of the "updated_at" field in the mutation.
func (m *SoraUsageStatMutation) UpdatedAt() (r time.Time, exists bool) {
	v := m.updated_at
	if v == nil {
		return
	}
	return *v, true
}

// OldUpdatedAt returns the old "updated_at" field's value of the SoraUsageStat entity.
// If the SoraUsageStat object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *SoraUsageStatMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
	}
	return oldValue.UpdatedAt, nil
}

// ResetUpdatedAt resets all changes to the "updated_at" field.
func (m *SoraUsageStatMutation) ResetUpdatedAt() {
	m.updated_at = nil
}

// SetAccountID sets the "account_id" field.
func (m *SoraUsageStatMutation) SetAccountID(i int64) {
	m.account_id = &i
	// An absolute set discards any pending additive delta.
	m.addaccount_id = nil
}

// AccountID returns the value of the "account_id" field in the mutation.
func (m *SoraUsageStatMutation) AccountID() (r int64, exists bool) {
	v := m.account_id
	if v == nil {
		return
	}
	return *v, true
}

// OldAccountID returns the old "account_id" field's value of the SoraUsageStat entity.
// If the SoraUsageStat object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *SoraUsageStatMutation) OldAccountID(ctx context.Context) (v int64, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldAccountID is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldAccountID requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldAccountID: %w", err)
	}
	return oldValue.AccountID, nil
}

// AddAccountID adds i to the "account_id" field.
func (m *SoraUsageStatMutation) AddAccountID(i int64) {
	if m.addaccount_id != nil {
		*m.addaccount_id += i
	} else {
		m.addaccount_id = &i
	}
}

// AddedAccountID returns the value that was added to the "account_id" field in this mutation.
func (m *SoraUsageStatMutation) AddedAccountID() (r int64, exists bool) {
	v := m.addaccount_id
	if v == nil {
		return
	}
	return *v, true
}

// ResetAccountID resets all changes to the "account_id" field.
func (m *SoraUsageStatMutation) ResetAccountID() {
	m.account_id = nil
	m.addaccount_id = nil
}

// SetImageCount sets the "image_count" field.
func (m *SoraUsageStatMutation) SetImageCount(i int) {
	m.image_count = &i
	m.addimage_count = nil
}

// ImageCount returns the value of the "image_count" field in the mutation.
func (m *SoraUsageStatMutation) ImageCount() (r int, exists bool) {
	v := m.image_count
	if v == nil {
		return
	}
	return *v, true
}

// OldImageCount returns the old "image_count" field's value of the SoraUsageStat entity.
// If the SoraUsageStat object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *SoraUsageStatMutation) OldImageCount(ctx context.Context) (v int, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldImageCount is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldImageCount requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldImageCount: %w", err)
	}
	return oldValue.ImageCount, nil
}

// AddImageCount adds i to the "image_count" field.
// Repeated calls accumulate into a single delta.
func (m *SoraUsageStatMutation) AddImageCount(i int) {
	if m.addimage_count != nil {
		*m.addimage_count += i
	} else {
		m.addimage_count = &i
	}
}

// AddedImageCount returns the value that was added to the "image_count" field in this mutation.
func (m *SoraUsageStatMutation) AddedImageCount() (r int, exists bool) {
	v := m.addimage_count
	if v == nil {
		return
	}
	return *v, true
}

// ResetImageCount resets all changes to the "image_count" field.
func (m *SoraUsageStatMutation) ResetImageCount() {
	m.image_count = nil
	m.addimage_count = nil
}

// SetVideoCount sets the "video_count" field.
func (m *SoraUsageStatMutation) SetVideoCount(i int) {
	m.video_count = &i
	m.addvideo_count = nil
}

// VideoCount returns the value of the "video_count" field in the mutation.
func (m *SoraUsageStatMutation) VideoCount() (r int, exists bool) {
	v := m.video_count
	if v == nil {
		return
	}
	return *v, true
}

// OldVideoCount returns the old "video_count" field's value of the SoraUsageStat entity.
// If the SoraUsageStat object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *SoraUsageStatMutation) OldVideoCount(ctx context.Context) (v int, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldVideoCount is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldVideoCount requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldVideoCount: %w", err)
	}
	return oldValue.VideoCount, nil
}

// AddVideoCount adds i to the "video_count" field.
// Repeated calls accumulate into a single delta.
func (m *SoraUsageStatMutation) AddVideoCount(i int) {
	if m.addvideo_count != nil {
		*m.addvideo_count += i
	} else {
		m.addvideo_count = &i
	}
}

// AddedVideoCount returns the value that was added to the "video_count" field in this mutation.
func (m *SoraUsageStatMutation) AddedVideoCount() (r int, exists bool) {
	v := m.addvideo_count
	if v == nil {
		return
	}
	return *v, true
}

// ResetVideoCount resets all changes to the "video_count" field.
func (m *SoraUsageStatMutation) ResetVideoCount() {
	m.video_count = nil
	m.addvideo_count = nil
}

// SetErrorCount sets the "error_count" field.
func (m *SoraUsageStatMutation) SetErrorCount(i int) {
	m.error_count = &i
	m.adderror_count = nil
}

// ErrorCount returns the value of the "error_count" field in the mutation.
func (m *SoraUsageStatMutation) ErrorCount() (r int, exists bool) {
	v := m.error_count
	if v == nil {
		return
	}
	return *v, true
}

// OldErrorCount returns the old "error_count" field's value of the SoraUsageStat entity.
// If the SoraUsageStat object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *SoraUsageStatMutation) OldErrorCount(ctx context.Context) (v int, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldErrorCount is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldErrorCount requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldErrorCount: %w", err)
	}
	return oldValue.ErrorCount, nil
}

// AddErrorCount adds i to the "error_count" field.
// Repeated calls accumulate into a single delta.
func (m *SoraUsageStatMutation) AddErrorCount(i int) {
	if m.adderror_count != nil {
		*m.adderror_count += i
	} else {
		m.adderror_count = &i
	}
}

// AddedErrorCount returns the value that was added to the "error_count" field in this mutation.
func (m *SoraUsageStatMutation) AddedErrorCount() (r int, exists bool) {
	v := m.adderror_count
	if v == nil {
		return
	}
	return *v, true
}

// ResetErrorCount resets all changes to the "error_count" field.
func (m *SoraUsageStatMutation) ResetErrorCount() {
	m.error_count = nil
	m.adderror_count = nil
}

// SetLastErrorAt sets the "last_error_at" field.
func (m *SoraUsageStatMutation) SetLastErrorAt(t time.Time) {
	m.last_error_at = &t
}

// LastErrorAt returns the value of the "last_error_at" field in the mutation.
func (m *SoraUsageStatMutation) LastErrorAt() (r time.Time, exists bool) {
	v := m.last_error_at
	if v == nil {
		return
	}
	return *v, true
}

// OldLastErrorAt returns the old "last_error_at" field's value of the SoraUsageStat entity.
// If the SoraUsageStat object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
// Note the *time.Time return: last_error_at is nullable in the schema.
func (m *SoraUsageStatMutation) OldLastErrorAt(ctx context.Context) (v *time.Time, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldLastErrorAt is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldLastErrorAt requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldLastErrorAt: %w", err)
	}
	return oldValue.LastErrorAt, nil
}

// ClearLastErrorAt clears the value of the "last_error_at" field.
func (m *SoraUsageStatMutation) ClearLastErrorAt() {
	m.last_error_at = nil
	m.clearedFields[sorausagestat.FieldLastErrorAt] = struct{}{}
}

// LastErrorAtCleared returns if the "last_error_at" field was cleared in this mutation.
func (m *SoraUsageStatMutation) LastErrorAtCleared() bool {
	_, ok := m.clearedFields[sorausagestat.FieldLastErrorAt]
	return ok
}

// ResetLastErrorAt resets all changes to the "last_error_at" field.
func (m *SoraUsageStatMutation) ResetLastErrorAt() {
	m.last_error_at = nil
	delete(m.clearedFields, sorausagestat.FieldLastErrorAt)
}

// SetTodayImageCount sets the "today_image_count" field.
func (m *SoraUsageStatMutation) SetTodayImageCount(i int) {
	m.today_image_count = &i
	m.addtoday_image_count = nil
}

// TodayImageCount returns the value of the "today_image_count" field in the mutation.
func (m *SoraUsageStatMutation) TodayImageCount() (r int, exists bool) {
	v := m.today_image_count
	if v == nil {
		return
	}
	return *v, true
}

// OldTodayImageCount returns the old "today_image_count" field's value of the SoraUsageStat entity.
// If the SoraUsageStat object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *SoraUsageStatMutation) OldTodayImageCount(ctx context.Context) (v int, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldTodayImageCount is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldTodayImageCount requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldTodayImageCount: %w", err)
	}
	return oldValue.TodayImageCount, nil
}

// AddTodayImageCount adds i to the "today_image_count" field.
// Repeated calls accumulate into a single delta.
func (m *SoraUsageStatMutation) AddTodayImageCount(i int) {
	if m.addtoday_image_count != nil {
		*m.addtoday_image_count += i
	} else {
		m.addtoday_image_count = &i
	}
}

// AddedTodayImageCount returns the value that was added to the "today_image_count" field in this mutation.
func (m *SoraUsageStatMutation) AddedTodayImageCount() (r int, exists bool) {
	v := m.addtoday_image_count
	if v == nil {
		return
	}
	return *v, true
}

// ResetTodayImageCount resets all changes to the "today_image_count" field.
func (m *SoraUsageStatMutation) ResetTodayImageCount() {
	m.today_image_count = nil
	m.addtoday_image_count = nil
}

// SetTodayVideoCount sets the "today_video_count" field.
func (m *SoraUsageStatMutation) SetTodayVideoCount(i int) {
	m.today_video_count = &i
	m.addtoday_video_count = nil
}

// TodayVideoCount returns the value of the "today_video_count" field in the mutation.
func (m *SoraUsageStatMutation) TodayVideoCount() (r int, exists bool) {
	v := m.today_video_count
	if v == nil {
		return
	}
	return *v, true
}

// OldTodayVideoCount returns the old "today_video_count" field's value of the SoraUsageStat entity.
// If the SoraUsageStat object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *SoraUsageStatMutation) OldTodayVideoCount(ctx context.Context) (v int, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldTodayVideoCount is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldTodayVideoCount requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldTodayVideoCount: %w", err)
	}
	return oldValue.TodayVideoCount, nil
}

// AddTodayVideoCount adds i to the "today_video_count" field.
// Repeated calls accumulate into a single delta.
func (m *SoraUsageStatMutation) AddTodayVideoCount(i int) {
	if m.addtoday_video_count != nil {
		*m.addtoday_video_count += i
	} else {
		m.addtoday_video_count = &i
	}
}

// AddedTodayVideoCount returns the value that was added to the "today_video_count" field in this mutation.
func (m *SoraUsageStatMutation) AddedTodayVideoCount() (r int, exists bool) {
	v := m.addtoday_video_count
	if v == nil {
		return
	}
	return *v, true
}

// ResetTodayVideoCount resets all changes to the "today_video_count" field.
func (m *SoraUsageStatMutation) ResetTodayVideoCount() {
	m.today_video_count = nil
	m.addtoday_video_count = nil
}

// SetTodayErrorCount sets the "today_error_count" field.
func (m *SoraUsageStatMutation) SetTodayErrorCount(i int) {
	m.today_error_count = &i
	m.addtoday_error_count = nil
}

// TodayErrorCount returns the value of the "today_error_count" field in the mutation.
func (m *SoraUsageStatMutation) TodayErrorCount() (r int, exists bool) {
	v := m.today_error_count
	if v == nil {
		return
	}
	return *v, true
}

// OldTodayErrorCount returns the old "today_error_count" field's value of the SoraUsageStat entity.
// If the SoraUsageStat object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *SoraUsageStatMutation) OldTodayErrorCount(ctx context.Context) (v int, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldTodayErrorCount is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldTodayErrorCount requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldTodayErrorCount: %w", err)
	}
	return oldValue.TodayErrorCount, nil
}

// AddTodayErrorCount adds i to the "today_error_count" field.
// Repeated calls accumulate into a single delta.
func (m *SoraUsageStatMutation) AddTodayErrorCount(i int) {
	if m.addtoday_error_count != nil {
		*m.addtoday_error_count += i
	} else {
		m.addtoday_error_count = &i
	}
}

// AddedTodayErrorCount returns the value that was added to the "today_error_count" field in this mutation.
func (m *SoraUsageStatMutation) AddedTodayErrorCount() (r int, exists bool) {
	v := m.addtoday_error_count
	if v == nil {
		return
	}
	return *v, true
}

// ResetTodayErrorCount resets all changes to the "today_error_count" field.
func (m *SoraUsageStatMutation) ResetTodayErrorCount() {
	m.today_error_count = nil
	m.addtoday_error_count = nil
}

// SetTodayDate sets the "today_date" field.
func (m *SoraUsageStatMutation) SetTodayDate(t time.Time) {
	m.today_date = &t
}

// TodayDate returns the value of the "today_date" field in the mutation.
func (m *SoraUsageStatMutation) TodayDate() (r time.Time, exists bool) {
	v := m.today_date
	if v == nil {
		return
	}
	return *v, true
}

// OldTodayDate returns the old "today_date" field's value of the SoraUsageStat entity.
// If the SoraUsageStat object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
// Note the *time.Time return: today_date is nullable in the schema.
func (m *SoraUsageStatMutation) OldTodayDate(ctx context.Context) (v *time.Time, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldTodayDate is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldTodayDate requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldTodayDate: %w", err)
	}
	return oldValue.TodayDate, nil
}

// ClearTodayDate clears the value of the "today_date" field.
func (m *SoraUsageStatMutation) ClearTodayDate() {
	m.today_date = nil
	m.clearedFields[sorausagestat.FieldTodayDate] = struct{}{}
}

// TodayDateCleared returns if the "today_date" field was cleared in this mutation.
func (m *SoraUsageStatMutation) TodayDateCleared() bool {
	_, ok := m.clearedFields[sorausagestat.FieldTodayDate]
	return ok
}

// ResetTodayDate resets all changes to the "today_date" field.
func (m *SoraUsageStatMutation) ResetTodayDate() {
	m.today_date = nil
	delete(m.clearedFields, sorausagestat.FieldTodayDate)
}

// SetConsecutiveErrorCount sets the "consecutive_error_count" field.
func (m *SoraUsageStatMutation) SetConsecutiveErrorCount(i int) {
	m.consecutive_error_count = &i
	m.addconsecutive_error_count = nil
}

// ConsecutiveErrorCount returns the value of the "consecutive_error_count" field in the mutation.
func (m *SoraUsageStatMutation) ConsecutiveErrorCount() (r int, exists bool) {
	v := m.consecutive_error_count
	if v == nil {
		return
	}
	return *v, true
}

// OldConsecutiveErrorCount returns the old "consecutive_error_count" field's value of the SoraUsageStat entity.
// If the SoraUsageStat object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *SoraUsageStatMutation) OldConsecutiveErrorCount(ctx context.Context) (v int, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldConsecutiveErrorCount is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldConsecutiveErrorCount requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldConsecutiveErrorCount: %w", err)
	}
	return oldValue.ConsecutiveErrorCount, nil
}

// AddConsecutiveErrorCount adds i to the "consecutive_error_count" field.
// Repeated calls accumulate into a single delta.
func (m *SoraUsageStatMutation) AddConsecutiveErrorCount(i int) {
	if m.addconsecutive_error_count != nil {
		*m.addconsecutive_error_count += i
	} else {
		m.addconsecutive_error_count = &i
	}
}

// AddedConsecutiveErrorCount returns the value that was added to the "consecutive_error_count" field in this mutation.
func (m *SoraUsageStatMutation) AddedConsecutiveErrorCount() (r int, exists bool) {
	v := m.addconsecutive_error_count
	if v == nil {
		return
	}
	return *v, true
}

// ResetConsecutiveErrorCount resets all changes to the "consecutive_error_count" field.
func (m *SoraUsageStatMutation) ResetConsecutiveErrorCount() {
	m.consecutive_error_count = nil
	m.addconsecutive_error_count = nil
}

// Where appends a list predicates to the SoraUsageStatMutation builder.
func (m *SoraUsageStatMutation) Where(ps ...predicate.SoraUsageStat) {
	m.predicates = append(m.predicates, ps...)
}

// WhereP appends storage-level predicates to the SoraUsageStatMutation builder. Using this method,
// users can use type-assertion to append predicates that do not depend on any generated package.
func (m *SoraUsageStatMutation) WhereP(ps ...func(*sql.Selector)) {
	// Wrap each raw selector function as a typed predicate before appending.
	p := make([]predicate.SoraUsageStat, len(ps))
	for i := range ps {
		p[i] = ps[i]
	}
	m.Where(p...)
}

// Op returns the operation name.
func (m *SoraUsageStatMutation) Op() Op {
	return m.op
}

// SetOp allows setting the mutation operation.
func (m *SoraUsageStatMutation) SetOp(op Op) {
	m.op = op
}

// Type returns the node type of this mutation (SoraUsageStat).
func (m *SoraUsageStatMutation) Type() string {
	return m.typ
}

// Fields returns all fields that were changed during this mutation. Note that in
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *SoraUsageStatMutation) Fields() []string {
	// Capacity 12 = total number of SoraUsageStat schema fields.
	fields := make([]string, 0, 12)
	if m.created_at != nil {
		fields = append(fields, sorausagestat.FieldCreatedAt)
	}
	if m.updated_at != nil {
		fields = append(fields, sorausagestat.FieldUpdatedAt)
	}
	if m.account_id != nil {
		fields = append(fields, sorausagestat.FieldAccountID)
	}
	if m.image_count != nil {
		fields = append(fields, sorausagestat.FieldImageCount)
	}
	if m.video_count != nil {
		fields = append(fields, sorausagestat.FieldVideoCount)
	}
	if m.error_count != nil {
		fields = append(fields, sorausagestat.FieldErrorCount)
	}
	if m.last_error_at != nil {
		fields = append(fields, sorausagestat.FieldLastErrorAt)
	}
	if m.today_image_count != nil {
		fields = append(fields, sorausagestat.FieldTodayImageCount)
	}
	if m.today_video_count != nil {
		fields = append(fields, sorausagestat.FieldTodayVideoCount)
	}
	if m.today_error_count != nil {
		fields = append(fields, sorausagestat.FieldTodayErrorCount)
	}
	if m.today_date != nil {
		fields = append(fields, sorausagestat.FieldTodayDate)
	}
	if m.consecutive_error_count != nil {
		fields = append(fields, sorausagestat.FieldConsecutiveErrorCount)
	}
	return fields
}

// Field returns the value of a field with the given name. The second boolean
// return value indicates that this field was not set, or was not defined in the
// schema.
func (m *SoraUsageStatMutation) Field(name string) (ent.Value, bool) {
	switch name {
	case sorausagestat.FieldCreatedAt:
		return m.CreatedAt()
	case sorausagestat.FieldUpdatedAt:
		return m.UpdatedAt()
	case sorausagestat.FieldAccountID:
		return m.AccountID()
	case sorausagestat.FieldImageCount:
		return m.ImageCount()
	case sorausagestat.FieldVideoCount:
		return m.VideoCount()
	case sorausagestat.FieldErrorCount:
		return m.ErrorCount()
	case sorausagestat.FieldLastErrorAt:
		return m.LastErrorAt()
	case sorausagestat.FieldTodayImageCount:
		return m.TodayImageCount()
	case sorausagestat.FieldTodayVideoCount:
		return m.TodayVideoCount()
	case sorausagestat.FieldTodayErrorCount:
		return m.TodayErrorCount()
	case sorausagestat.FieldTodayDate:
		return m.TodayDate()
	case sorausagestat.FieldConsecutiveErrorCount:
		return m.ConsecutiveErrorCount()
	}
	// Unknown field name.
	return nil, false
}

// OldField returns the old value of the field from the database. An error is
// returned if the mutation operation is not UpdateOne, or the query to the
// database failed.
func (m *SoraUsageStatMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
	switch name {
	case sorausagestat.FieldCreatedAt:
		return m.OldCreatedAt(ctx)
	case sorausagestat.FieldUpdatedAt:
		return m.OldUpdatedAt(ctx)
	case sorausagestat.FieldAccountID:
		return m.OldAccountID(ctx)
	case sorausagestat.FieldImageCount:
		return m.OldImageCount(ctx)
	case sorausagestat.FieldVideoCount:
		return m.OldVideoCount(ctx)
	case sorausagestat.FieldErrorCount:
		return m.OldErrorCount(ctx)
	case sorausagestat.FieldLastErrorAt:
		return m.OldLastErrorAt(ctx)
	case sorausagestat.FieldTodayImageCount:
		return m.OldTodayImageCount(ctx)
	case sorausagestat.FieldTodayVideoCount:
		return m.OldTodayVideoCount(ctx)
	case sorausagestat.FieldTodayErrorCount:
		return m.OldTodayErrorCount(ctx)
	case sorausagestat.FieldTodayDate:
		return m.OldTodayDate(ctx)
	case sorausagestat.FieldConsecutiveErrorCount:
		return m.OldConsecutiveErrorCount(ctx)
	}
	return nil, fmt.Errorf("unknown SoraUsageStat field %s", name)
}

// SetField sets the value of a field with the given name. It returns an error if
// the field is not defined in the schema, or if the type mismatched the field
// type.
func (m *SoraUsageStatMutation) SetField(name string, value ent.Value) error {
	switch name {
	case sorausagestat.FieldCreatedAt:
		v, ok := value.(time.Time)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetCreatedAt(v)
		return nil
	case sorausagestat.FieldUpdatedAt:
		v, ok := value.(time.Time)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetUpdatedAt(v)
		return nil
	case sorausagestat.FieldAccountID:
		v, ok := value.(int64)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetAccountID(v)
		return nil
	case sorausagestat.FieldImageCount:
		v, ok := value.(int)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetImageCount(v)
		return nil
	case sorausagestat.FieldVideoCount:
		v, ok := value.(int)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetVideoCount(v)
		return nil
	case sorausagestat.FieldErrorCount:
		v, ok := value.(int)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetErrorCount(v)
		return nil
	case sorausagestat.FieldLastErrorAt:
		v, ok := value.(time.Time)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetLastErrorAt(v)
		return nil
	case sorausagestat.FieldTodayImageCount:
		v, ok := value.(int)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetTodayImageCount(v)
		return nil
	case sorausagestat.FieldTodayVideoCount:
		v, ok := value.(int)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetTodayVideoCount(v)
		return nil
	case sorausagestat.FieldTodayErrorCount:
		v, ok := value.(int)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetTodayErrorCount(v)
		return nil
	case sorausagestat.FieldTodayDate:
		v, ok := value.(time.Time)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetTodayDate(v)
		return nil
	case sorausagestat.FieldConsecutiveErrorCount:
		v, ok := value.(int)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetConsecutiveErrorCount(v)
		return nil
	}
	return fmt.Errorf("unknown SoraUsageStat field %s", name)
}

// AddedFields returns all numeric fields that were incremented/decremented during
// this mutation.
func (m *SoraUsageStatMutation) AddedFields() []string {
	var fields []string
	if m.addaccount_id != nil {
		fields = append(fields, sorausagestat.FieldAccountID)
	}
	if m.addimage_count != nil {
		fields = append(fields, sorausagestat.FieldImageCount)
	}
	if m.addvideo_count != nil {
		fields = append(fields, sorausagestat.FieldVideoCount)
	}
	if m.adderror_count != nil {
		fields = append(fields, sorausagestat.FieldErrorCount)
	}
	if m.addtoday_image_count != nil {
		fields = append(fields, sorausagestat.FieldTodayImageCount)
	}
	if m.addtoday_video_count != nil {
		fields = append(fields, sorausagestat.FieldTodayVideoCount)
	}
	if m.addtoday_error_count != nil {
		fields = append(fields, sorausagestat.FieldTodayErrorCount)
	}
	if m.addconsecutive_error_count != nil {
		fields = append(fields, sorausagestat.FieldConsecutiveErrorCount)
	}
	return fields
}

// AddedField returns the numeric value that was incremented/decremented on a field
// with the given name. The second boolean return value indicates that this field
// was not set, or was not defined in the schema.
+func (m *SoraUsageStatMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case sorausagestat.FieldAccountID: + return m.AddedAccountID() + case sorausagestat.FieldImageCount: + return m.AddedImageCount() + case sorausagestat.FieldVideoCount: + return m.AddedVideoCount() + case sorausagestat.FieldErrorCount: + return m.AddedErrorCount() + case sorausagestat.FieldTodayImageCount: + return m.AddedTodayImageCount() + case sorausagestat.FieldTodayVideoCount: + return m.AddedTodayVideoCount() + case sorausagestat.FieldTodayErrorCount: + return m.AddedTodayErrorCount() + case sorausagestat.FieldConsecutiveErrorCount: + return m.AddedConsecutiveErrorCount() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *SoraUsageStatMutation) AddField(name string, value ent.Value) error { + switch name { + case sorausagestat.FieldAccountID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddAccountID(v) + return nil + case sorausagestat.FieldImageCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddImageCount(v) + return nil + case sorausagestat.FieldVideoCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddVideoCount(v) + return nil + case sorausagestat.FieldErrorCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddErrorCount(v) + return nil + case sorausagestat.FieldTodayImageCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddTodayImageCount(v) + return nil + case sorausagestat.FieldTodayVideoCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for 
field %s", value, name) + } + m.AddTodayVideoCount(v) + return nil + case sorausagestat.FieldTodayErrorCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddTodayErrorCount(v) + return nil + case sorausagestat.FieldConsecutiveErrorCount: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddConsecutiveErrorCount(v) + return nil + } + return fmt.Errorf("unknown SoraUsageStat numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *SoraUsageStatMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(sorausagestat.FieldLastErrorAt) { + fields = append(fields, sorausagestat.FieldLastErrorAt) + } + if m.FieldCleared(sorausagestat.FieldTodayDate) { + fields = append(fields, sorausagestat.FieldTodayDate) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *SoraUsageStatMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *SoraUsageStatMutation) ClearField(name string) error { + switch name { + case sorausagestat.FieldLastErrorAt: + m.ClearLastErrorAt() + return nil + case sorausagestat.FieldTodayDate: + m.ClearTodayDate() + return nil + } + return fmt.Errorf("unknown SoraUsageStat nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. 
+func (m *SoraUsageStatMutation) ResetField(name string) error { + switch name { + case sorausagestat.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case sorausagestat.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case sorausagestat.FieldAccountID: + m.ResetAccountID() + return nil + case sorausagestat.FieldImageCount: + m.ResetImageCount() + return nil + case sorausagestat.FieldVideoCount: + m.ResetVideoCount() + return nil + case sorausagestat.FieldErrorCount: + m.ResetErrorCount() + return nil + case sorausagestat.FieldLastErrorAt: + m.ResetLastErrorAt() + return nil + case sorausagestat.FieldTodayImageCount: + m.ResetTodayImageCount() + return nil + case sorausagestat.FieldTodayVideoCount: + m.ResetTodayVideoCount() + return nil + case sorausagestat.FieldTodayErrorCount: + m.ResetTodayErrorCount() + return nil + case sorausagestat.FieldTodayDate: + m.ResetTodayDate() + return nil + case sorausagestat.FieldConsecutiveErrorCount: + m.ResetConsecutiveErrorCount() + return nil + } + return fmt.Errorf("unknown SoraUsageStat field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *SoraUsageStatMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *SoraUsageStatMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *SoraUsageStatMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *SoraUsageStatMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. 
+func (m *SoraUsageStatMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *SoraUsageStatMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *SoraUsageStatMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown SoraUsageStat unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *SoraUsageStatMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown SoraUsageStat edge %s", name) +} + // UsageCleanupTaskMutation represents an operation that mutates the UsageCleanupTask nodes in the graph. type UsageCleanupTaskMutation struct { config diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index 785cb4e6..2ad57ac3 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -33,6 +33,18 @@ type RedeemCode func(*sql.Selector) // Setting is the predicate function for setting builders. type Setting func(*sql.Selector) +// SoraAccount is the predicate function for soraaccount builders. +type SoraAccount func(*sql.Selector) + +// SoraCacheFile is the predicate function for soracachefile builders. +type SoraCacheFile func(*sql.Selector) + +// SoraTask is the predicate function for soratask builders. +type SoraTask func(*sql.Selector) + +// SoraUsageStat is the predicate function for sorausagestat builders. +type SoraUsageStat func(*sql.Selector) + // UsageCleanupTask is the predicate function for usagecleanuptask builders. 
type UsageCleanupTask func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 1e3f4cbe..31e88e46 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -15,6 +15,10 @@ import ( "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/schema" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/soraaccount" + "github.com/Wei-Shaw/sub2api/ent/soracachefile" + "github.com/Wei-Shaw/sub2api/ent/soratask" + "github.com/Wei-Shaw/sub2api/ent/sorausagestat" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" @@ -496,6 +500,150 @@ func init() { setting.DefaultUpdatedAt = settingDescUpdatedAt.Default.(func() time.Time) // setting.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. setting.UpdateDefaultUpdatedAt = settingDescUpdatedAt.UpdateDefault.(func() time.Time) + soraaccountMixin := schema.SoraAccount{}.Mixin() + soraaccountMixinFields0 := soraaccountMixin[0].Fields() + _ = soraaccountMixinFields0 + soraaccountFields := schema.SoraAccount{}.Fields() + _ = soraaccountFields + // soraaccountDescCreatedAt is the schema descriptor for created_at field. + soraaccountDescCreatedAt := soraaccountMixinFields0[0].Descriptor() + // soraaccount.DefaultCreatedAt holds the default value on creation for the created_at field. + soraaccount.DefaultCreatedAt = soraaccountDescCreatedAt.Default.(func() time.Time) + // soraaccountDescUpdatedAt is the schema descriptor for updated_at field. + soraaccountDescUpdatedAt := soraaccountMixinFields0[1].Descriptor() + // soraaccount.DefaultUpdatedAt holds the default value on creation for the updated_at field. + soraaccount.DefaultUpdatedAt = soraaccountDescUpdatedAt.Default.(func() time.Time) + // soraaccount.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. 
+ soraaccount.UpdateDefaultUpdatedAt = soraaccountDescUpdatedAt.UpdateDefault.(func() time.Time) + // soraaccountDescUseCount is the schema descriptor for use_count field. + soraaccountDescUseCount := soraaccountFields[8].Descriptor() + // soraaccount.DefaultUseCount holds the default value on creation for the use_count field. + soraaccount.DefaultUseCount = soraaccountDescUseCount.Default.(int) + // soraaccountDescSoraSupported is the schema descriptor for sora_supported field. + soraaccountDescSoraSupported := soraaccountFields[12].Descriptor() + // soraaccount.DefaultSoraSupported holds the default value on creation for the sora_supported field. + soraaccount.DefaultSoraSupported = soraaccountDescSoraSupported.Default.(bool) + // soraaccountDescSoraRedeemedCount is the schema descriptor for sora_redeemed_count field. + soraaccountDescSoraRedeemedCount := soraaccountFields[14].Descriptor() + // soraaccount.DefaultSoraRedeemedCount holds the default value on creation for the sora_redeemed_count field. + soraaccount.DefaultSoraRedeemedCount = soraaccountDescSoraRedeemedCount.Default.(int) + // soraaccountDescSoraRemainingCount is the schema descriptor for sora_remaining_count field. + soraaccountDescSoraRemainingCount := soraaccountFields[15].Descriptor() + // soraaccount.DefaultSoraRemainingCount holds the default value on creation for the sora_remaining_count field. + soraaccount.DefaultSoraRemainingCount = soraaccountDescSoraRemainingCount.Default.(int) + // soraaccountDescSoraTotalCount is the schema descriptor for sora_total_count field. + soraaccountDescSoraTotalCount := soraaccountFields[16].Descriptor() + // soraaccount.DefaultSoraTotalCount holds the default value on creation for the sora_total_count field. + soraaccount.DefaultSoraTotalCount = soraaccountDescSoraTotalCount.Default.(int) + // soraaccountDescImageEnabled is the schema descriptor for image_enabled field. 
+ soraaccountDescImageEnabled := soraaccountFields[19].Descriptor() + // soraaccount.DefaultImageEnabled holds the default value on creation for the image_enabled field. + soraaccount.DefaultImageEnabled = soraaccountDescImageEnabled.Default.(bool) + // soraaccountDescVideoEnabled is the schema descriptor for video_enabled field. + soraaccountDescVideoEnabled := soraaccountFields[20].Descriptor() + // soraaccount.DefaultVideoEnabled holds the default value on creation for the video_enabled field. + soraaccount.DefaultVideoEnabled = soraaccountDescVideoEnabled.Default.(bool) + // soraaccountDescImageConcurrency is the schema descriptor for image_concurrency field. + soraaccountDescImageConcurrency := soraaccountFields[21].Descriptor() + // soraaccount.DefaultImageConcurrency holds the default value on creation for the image_concurrency field. + soraaccount.DefaultImageConcurrency = soraaccountDescImageConcurrency.Default.(int) + // soraaccountDescVideoConcurrency is the schema descriptor for video_concurrency field. + soraaccountDescVideoConcurrency := soraaccountFields[22].Descriptor() + // soraaccount.DefaultVideoConcurrency holds the default value on creation for the video_concurrency field. + soraaccount.DefaultVideoConcurrency = soraaccountDescVideoConcurrency.Default.(int) + // soraaccountDescIsExpired is the schema descriptor for is_expired field. + soraaccountDescIsExpired := soraaccountFields[23].Descriptor() + // soraaccount.DefaultIsExpired holds the default value on creation for the is_expired field. + soraaccount.DefaultIsExpired = soraaccountDescIsExpired.Default.(bool) + soracachefileFields := schema.SoraCacheFile{}.Fields() + _ = soracachefileFields + // soracachefileDescTaskID is the schema descriptor for task_id field. + soracachefileDescTaskID := soracachefileFields[0].Descriptor() + // soracachefile.TaskIDValidator is a validator for the "task_id" field. It is called by the builders before save. 
+ soracachefile.TaskIDValidator = soracachefileDescTaskID.Validators[0].(func(string) error) + // soracachefileDescMediaType is the schema descriptor for media_type field. + soracachefileDescMediaType := soracachefileFields[3].Descriptor() + // soracachefile.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. + soracachefile.MediaTypeValidator = soracachefileDescMediaType.Validators[0].(func(string) error) + // soracachefileDescSizeBytes is the schema descriptor for size_bytes field. + soracachefileDescSizeBytes := soracachefileFields[7].Descriptor() + // soracachefile.DefaultSizeBytes holds the default value on creation for the size_bytes field. + soracachefile.DefaultSizeBytes = soracachefileDescSizeBytes.Default.(int64) + // soracachefileDescCreatedAt is the schema descriptor for created_at field. + soracachefileDescCreatedAt := soracachefileFields[8].Descriptor() + // soracachefile.DefaultCreatedAt holds the default value on creation for the created_at field. + soracachefile.DefaultCreatedAt = soracachefileDescCreatedAt.Default.(func() time.Time) + sorataskFields := schema.SoraTask{}.Fields() + _ = sorataskFields + // sorataskDescTaskID is the schema descriptor for task_id field. + sorataskDescTaskID := sorataskFields[0].Descriptor() + // soratask.TaskIDValidator is a validator for the "task_id" field. It is called by the builders before save. + soratask.TaskIDValidator = sorataskDescTaskID.Validators[0].(func(string) error) + // sorataskDescModel is the schema descriptor for model field. + sorataskDescModel := sorataskFields[2].Descriptor() + // soratask.ModelValidator is a validator for the "model" field. It is called by the builders before save. + soratask.ModelValidator = sorataskDescModel.Validators[0].(func(string) error) + // sorataskDescStatus is the schema descriptor for status field. 
+ sorataskDescStatus := sorataskFields[4].Descriptor() + // soratask.DefaultStatus holds the default value on creation for the status field. + soratask.DefaultStatus = sorataskDescStatus.Default.(string) + // soratask.StatusValidator is a validator for the "status" field. It is called by the builders before save. + soratask.StatusValidator = sorataskDescStatus.Validators[0].(func(string) error) + // sorataskDescProgress is the schema descriptor for progress field. + sorataskDescProgress := sorataskFields[5].Descriptor() + // soratask.DefaultProgress holds the default value on creation for the progress field. + soratask.DefaultProgress = sorataskDescProgress.Default.(float64) + // sorataskDescRetryCount is the schema descriptor for retry_count field. + sorataskDescRetryCount := sorataskFields[8].Descriptor() + // soratask.DefaultRetryCount holds the default value on creation for the retry_count field. + soratask.DefaultRetryCount = sorataskDescRetryCount.Default.(int) + // sorataskDescCreatedAt is the schema descriptor for created_at field. + sorataskDescCreatedAt := sorataskFields[9].Descriptor() + // soratask.DefaultCreatedAt holds the default value on creation for the created_at field. + soratask.DefaultCreatedAt = sorataskDescCreatedAt.Default.(func() time.Time) + sorausagestatMixin := schema.SoraUsageStat{}.Mixin() + sorausagestatMixinFields0 := sorausagestatMixin[0].Fields() + _ = sorausagestatMixinFields0 + sorausagestatFields := schema.SoraUsageStat{}.Fields() + _ = sorausagestatFields + // sorausagestatDescCreatedAt is the schema descriptor for created_at field. + sorausagestatDescCreatedAt := sorausagestatMixinFields0[0].Descriptor() + // sorausagestat.DefaultCreatedAt holds the default value on creation for the created_at field. + sorausagestat.DefaultCreatedAt = sorausagestatDescCreatedAt.Default.(func() time.Time) + // sorausagestatDescUpdatedAt is the schema descriptor for updated_at field. 
+ sorausagestatDescUpdatedAt := sorausagestatMixinFields0[1].Descriptor() + // sorausagestat.DefaultUpdatedAt holds the default value on creation for the updated_at field. + sorausagestat.DefaultUpdatedAt = sorausagestatDescUpdatedAt.Default.(func() time.Time) + // sorausagestat.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + sorausagestat.UpdateDefaultUpdatedAt = sorausagestatDescUpdatedAt.UpdateDefault.(func() time.Time) + // sorausagestatDescImageCount is the schema descriptor for image_count field. + sorausagestatDescImageCount := sorausagestatFields[1].Descriptor() + // sorausagestat.DefaultImageCount holds the default value on creation for the image_count field. + sorausagestat.DefaultImageCount = sorausagestatDescImageCount.Default.(int) + // sorausagestatDescVideoCount is the schema descriptor for video_count field. + sorausagestatDescVideoCount := sorausagestatFields[2].Descriptor() + // sorausagestat.DefaultVideoCount holds the default value on creation for the video_count field. + sorausagestat.DefaultVideoCount = sorausagestatDescVideoCount.Default.(int) + // sorausagestatDescErrorCount is the schema descriptor for error_count field. + sorausagestatDescErrorCount := sorausagestatFields[3].Descriptor() + // sorausagestat.DefaultErrorCount holds the default value on creation for the error_count field. + sorausagestat.DefaultErrorCount = sorausagestatDescErrorCount.Default.(int) + // sorausagestatDescTodayImageCount is the schema descriptor for today_image_count field. + sorausagestatDescTodayImageCount := sorausagestatFields[5].Descriptor() + // sorausagestat.DefaultTodayImageCount holds the default value on creation for the today_image_count field. + sorausagestat.DefaultTodayImageCount = sorausagestatDescTodayImageCount.Default.(int) + // sorausagestatDescTodayVideoCount is the schema descriptor for today_video_count field. 
+ sorausagestatDescTodayVideoCount := sorausagestatFields[6].Descriptor() + // sorausagestat.DefaultTodayVideoCount holds the default value on creation for the today_video_count field. + sorausagestat.DefaultTodayVideoCount = sorausagestatDescTodayVideoCount.Default.(int) + // sorausagestatDescTodayErrorCount is the schema descriptor for today_error_count field. + sorausagestatDescTodayErrorCount := sorausagestatFields[7].Descriptor() + // sorausagestat.DefaultTodayErrorCount holds the default value on creation for the today_error_count field. + sorausagestat.DefaultTodayErrorCount = sorausagestatDescTodayErrorCount.Default.(int) + // sorausagestatDescConsecutiveErrorCount is the schema descriptor for consecutive_error_count field. + sorausagestatDescConsecutiveErrorCount := sorausagestatFields[9].Descriptor() + // sorausagestat.DefaultConsecutiveErrorCount holds the default value on creation for the consecutive_error_count field. + sorausagestat.DefaultConsecutiveErrorCount = sorausagestatDescConsecutiveErrorCount.Default.(int) usagecleanuptaskMixin := schema.UsageCleanupTask{}.Mixin() usagecleanuptaskMixinFields0 := usagecleanuptaskMixin[0].Fields() _ = usagecleanuptaskMixinFields0 diff --git a/backend/ent/schema/sora_account.go b/backend/ent/schema/sora_account.go new file mode 100644 index 00000000..c40b4e3c --- /dev/null +++ b/backend/ent/schema/sora_account.go @@ -0,0 +1,115 @@ +// Package schema 定义 Ent ORM 的数据库 schema。 +// 每个文件对应一个数据库实体(表),定义其字段、边(关联)和索引。 +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// SoraAccount 定义 Sora 账号扩展表。 +type SoraAccount struct { + ent.Schema +} + +// Annotations 返回 schema 的注解配置。 +func (SoraAccount) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "sora_accounts"}, + } +} + +// Mixin 返回该 schema 
使用的混入组件。 +func (SoraAccount) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +// Fields 定义 SoraAccount 的字段。 +func (SoraAccount) Fields() []ent.Field { + return []ent.Field{ + field.Int64("account_id"). + Comment("关联 accounts 表的 ID"), + field.String("access_token"). + Optional(). + Nillable(), + field.String("session_token"). + Optional(). + Nillable(), + field.String("refresh_token"). + Optional(). + Nillable(), + field.String("client_id"). + Optional(). + Nillable(), + field.String("email"). + Optional(). + Nillable(), + field.String("username"). + Optional(). + Nillable(), + field.String("remark"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Int("use_count"). + Default(0), + field.String("plan_type"). + Optional(). + Nillable(), + field.String("plan_title"). + Optional(). + Nillable(), + field.Time("subscription_end"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Bool("sora_supported"). + Default(false), + field.String("sora_invite_code"). + Optional(). + Nillable(), + field.Int("sora_redeemed_count"). + Default(0), + field.Int("sora_remaining_count"). + Default(0), + field.Int("sora_total_count"). + Default(0), + field.Time("sora_cooldown_until"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("cooled_until"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Bool("image_enabled"). + Default(true), + field.Bool("video_enabled"). + Default(true), + field.Int("image_concurrency"). + Default(-1), + field.Int("video_concurrency"). + Default(-1), + field.Bool("is_expired"). 
+ Default(false), + } +} + +// Indexes 定义索引。 +func (SoraAccount) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("account_id").Unique(), + index.Fields("plan_type"), + index.Fields("sora_supported"), + index.Fields("image_enabled"), + index.Fields("video_enabled"), + } +} diff --git a/backend/ent/schema/sora_cache_file.go b/backend/ent/schema/sora_cache_file.go new file mode 100644 index 00000000..df398565 --- /dev/null +++ b/backend/ent/schema/sora_cache_file.go @@ -0,0 +1,60 @@ +// Package schema 定义 Ent ORM 的数据库 schema。 +// 每个文件对应一个数据库实体(表),定义其字段、边(关联)和索引。 +package schema + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// SoraCacheFile 定义 Sora 缓存文件表。 +type SoraCacheFile struct { + ent.Schema +} + +// Annotations 返回 schema 的注解配置。 +func (SoraCacheFile) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "sora_cache_files"}, + } +} + +// Fields 定义 SoraCacheFile 的字段。 +func (SoraCacheFile) Fields() []ent.Field { + return []ent.Field{ + field.String("task_id"). + MaxLen(120). + Optional(). + Nillable(), + field.Int64("account_id"), + field.Int64("user_id"), + field.String("media_type"). + MaxLen(32), + field.String("original_url"). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("cache_path"). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("cache_url"). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Int64("size_bytes"). + Default(0), + field.Time("created_at"). + Default(time.Now). 
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + } +} + +// Indexes 定义索引。 +func (SoraCacheFile) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("account_id"), + index.Fields("user_id"), + index.Fields("media_type"), + } +} diff --git a/backend/ent/schema/sora_task.go b/backend/ent/schema/sora_task.go new file mode 100644 index 00000000..476580f2 --- /dev/null +++ b/backend/ent/schema/sora_task.go @@ -0,0 +1,70 @@ +// Package schema 定义 Ent ORM 的数据库 schema。 +// 每个文件对应一个数据库实体(表),定义其字段、边(关联)和索引。 +package schema + +import ( + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// SoraTask 定义 Sora 任务记录表。 +type SoraTask struct { + ent.Schema +} + +// Annotations 返回 schema 的注解配置。 +func (SoraTask) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "sora_tasks"}, + } +} + +// Fields 定义 SoraTask 的字段。 +func (SoraTask) Fields() []ent.Field { + return []ent.Field{ + field.String("task_id"). + MaxLen(120). + Unique(), + field.Int64("account_id"), + field.String("model"). + MaxLen(120), + field.String("prompt"). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("status"). + MaxLen(32). + Default("processing"), + field.Float("progress"). + Default(0), + field.String("result_urls"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("error_message"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Int("retry_count"). + Default(0), + field.Time("created_at"). + Default(time.Now). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("completed_at"). + Optional(). + Nillable(). 
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + } +} + +// Indexes 定义索引。 +func (SoraTask) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("account_id"), + index.Fields("status"), + } +} diff --git a/backend/ent/schema/sora_usage_stat.go b/backend/ent/schema/sora_usage_stat.go new file mode 100644 index 00000000..2604e868 --- /dev/null +++ b/backend/ent/schema/sora_usage_stat.go @@ -0,0 +1,71 @@ +// Package schema 定义 Ent ORM 的数据库 schema。 +// 每个文件对应一个数据库实体(表),定义其字段、边(关联)和索引。 +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// SoraUsageStat 定义 Sora 调用统计表。 +type SoraUsageStat struct { + ent.Schema +} + +// Annotations 返回 schema 的注解配置。 +func (SoraUsageStat) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "sora_usage_stats"}, + } +} + +// Mixin 返回该 schema 使用的混入组件。 +func (SoraUsageStat) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +// Fields 定义 SoraUsageStat 的字段。 +func (SoraUsageStat) Fields() []ent.Field { + return []ent.Field{ + field.Int64("account_id"). + Comment("关联 accounts 表的 ID"), + field.Int("image_count"). + Default(0), + field.Int("video_count"). + Default(0), + field.Int("error_count"). + Default(0), + field.Time("last_error_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Int("today_image_count"). + Default(0), + field.Int("today_video_count"). + Default(0), + field.Int("today_error_count"). + Default(0), + field.Time("today_date"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "date"}), + field.Int("consecutive_error_count"). 
+ Default(0), + } +} + +// Indexes 定义索引。 +func (SoraUsageStat) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("account_id").Unique(), + index.Fields("today_date"), + } +} diff --git a/backend/ent/soraaccount.go b/backend/ent/soraaccount.go new file mode 100644 index 00000000..952eb638 --- /dev/null +++ b/backend/ent/soraaccount.go @@ -0,0 +1,422 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/soraaccount" +) + +// SoraAccount is the model entity for the SoraAccount schema. +type SoraAccount struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // 关联 accounts 表的 ID + AccountID int64 `json:"account_id,omitempty"` + // AccessToken holds the value of the "access_token" field. + AccessToken *string `json:"access_token,omitempty"` + // SessionToken holds the value of the "session_token" field. + SessionToken *string `json:"session_token,omitempty"` + // RefreshToken holds the value of the "refresh_token" field. + RefreshToken *string `json:"refresh_token,omitempty"` + // ClientID holds the value of the "client_id" field. + ClientID *string `json:"client_id,omitempty"` + // Email holds the value of the "email" field. + Email *string `json:"email,omitempty"` + // Username holds the value of the "username" field. + Username *string `json:"username,omitempty"` + // Remark holds the value of the "remark" field. + Remark *string `json:"remark,omitempty"` + // UseCount holds the value of the "use_count" field. + UseCount int `json:"use_count,omitempty"` + // PlanType holds the value of the "plan_type" field. 
+ PlanType *string `json:"plan_type,omitempty"` + // PlanTitle holds the value of the "plan_title" field. + PlanTitle *string `json:"plan_title,omitempty"` + // SubscriptionEnd holds the value of the "subscription_end" field. + SubscriptionEnd *time.Time `json:"subscription_end,omitempty"` + // SoraSupported holds the value of the "sora_supported" field. + SoraSupported bool `json:"sora_supported,omitempty"` + // SoraInviteCode holds the value of the "sora_invite_code" field. + SoraInviteCode *string `json:"sora_invite_code,omitempty"` + // SoraRedeemedCount holds the value of the "sora_redeemed_count" field. + SoraRedeemedCount int `json:"sora_redeemed_count,omitempty"` + // SoraRemainingCount holds the value of the "sora_remaining_count" field. + SoraRemainingCount int `json:"sora_remaining_count,omitempty"` + // SoraTotalCount holds the value of the "sora_total_count" field. + SoraTotalCount int `json:"sora_total_count,omitempty"` + // SoraCooldownUntil holds the value of the "sora_cooldown_until" field. + SoraCooldownUntil *time.Time `json:"sora_cooldown_until,omitempty"` + // CooledUntil holds the value of the "cooled_until" field. + CooledUntil *time.Time `json:"cooled_until,omitempty"` + // ImageEnabled holds the value of the "image_enabled" field. + ImageEnabled bool `json:"image_enabled,omitempty"` + // VideoEnabled holds the value of the "video_enabled" field. + VideoEnabled bool `json:"video_enabled,omitempty"` + // ImageConcurrency holds the value of the "image_concurrency" field. + ImageConcurrency int `json:"image_concurrency,omitempty"` + // VideoConcurrency holds the value of the "video_concurrency" field. + VideoConcurrency int `json:"video_concurrency,omitempty"` + // IsExpired holds the value of the "is_expired" field. + IsExpired bool `json:"is_expired,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. 
+func (*SoraAccount) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case soraaccount.FieldSoraSupported, soraaccount.FieldImageEnabled, soraaccount.FieldVideoEnabled, soraaccount.FieldIsExpired: + values[i] = new(sql.NullBool) + case soraaccount.FieldID, soraaccount.FieldAccountID, soraaccount.FieldUseCount, soraaccount.FieldSoraRedeemedCount, soraaccount.FieldSoraRemainingCount, soraaccount.FieldSoraTotalCount, soraaccount.FieldImageConcurrency, soraaccount.FieldVideoConcurrency: + values[i] = new(sql.NullInt64) + case soraaccount.FieldAccessToken, soraaccount.FieldSessionToken, soraaccount.FieldRefreshToken, soraaccount.FieldClientID, soraaccount.FieldEmail, soraaccount.FieldUsername, soraaccount.FieldRemark, soraaccount.FieldPlanType, soraaccount.FieldPlanTitle, soraaccount.FieldSoraInviteCode: + values[i] = new(sql.NullString) + case soraaccount.FieldCreatedAt, soraaccount.FieldUpdatedAt, soraaccount.FieldSubscriptionEnd, soraaccount.FieldSoraCooldownUntil, soraaccount.FieldCooledUntil: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the SoraAccount fields. 
+func (_m *SoraAccount) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case soraaccount.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case soraaccount.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case soraaccount.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case soraaccount.FieldAccountID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field account_id", values[i]) + } else if value.Valid { + _m.AccountID = value.Int64 + } + case soraaccount.FieldAccessToken: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field access_token", values[i]) + } else if value.Valid { + _m.AccessToken = new(string) + *_m.AccessToken = value.String + } + case soraaccount.FieldSessionToken: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field session_token", values[i]) + } else if value.Valid { + _m.SessionToken = new(string) + *_m.SessionToken = value.String + } + case soraaccount.FieldRefreshToken: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field refresh_token", values[i]) + } else if value.Valid { + _m.RefreshToken = new(string) + *_m.RefreshToken = value.String + } + case soraaccount.FieldClientID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for 
field client_id", values[i]) + } else if value.Valid { + _m.ClientID = new(string) + *_m.ClientID = value.String + } + case soraaccount.FieldEmail: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field email", values[i]) + } else if value.Valid { + _m.Email = new(string) + *_m.Email = value.String + } + case soraaccount.FieldUsername: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field username", values[i]) + } else if value.Valid { + _m.Username = new(string) + *_m.Username = value.String + } + case soraaccount.FieldRemark: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field remark", values[i]) + } else if value.Valid { + _m.Remark = new(string) + *_m.Remark = value.String + } + case soraaccount.FieldUseCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field use_count", values[i]) + } else if value.Valid { + _m.UseCount = int(value.Int64) + } + case soraaccount.FieldPlanType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field plan_type", values[i]) + } else if value.Valid { + _m.PlanType = new(string) + *_m.PlanType = value.String + } + case soraaccount.FieldPlanTitle: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field plan_title", values[i]) + } else if value.Valid { + _m.PlanTitle = new(string) + *_m.PlanTitle = value.String + } + case soraaccount.FieldSubscriptionEnd: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field subscription_end", values[i]) + } else if value.Valid { + _m.SubscriptionEnd = new(time.Time) + *_m.SubscriptionEnd = value.Time + } + case soraaccount.FieldSoraSupported: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field sora_supported", 
values[i]) + } else if value.Valid { + _m.SoraSupported = value.Bool + } + case soraaccount.FieldSoraInviteCode: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field sora_invite_code", values[i]) + } else if value.Valid { + _m.SoraInviteCode = new(string) + *_m.SoraInviteCode = value.String + } + case soraaccount.FieldSoraRedeemedCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sora_redeemed_count", values[i]) + } else if value.Valid { + _m.SoraRedeemedCount = int(value.Int64) + } + case soraaccount.FieldSoraRemainingCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sora_remaining_count", values[i]) + } else if value.Valid { + _m.SoraRemainingCount = int(value.Int64) + } + case soraaccount.FieldSoraTotalCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sora_total_count", values[i]) + } else if value.Valid { + _m.SoraTotalCount = int(value.Int64) + } + case soraaccount.FieldSoraCooldownUntil: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field sora_cooldown_until", values[i]) + } else if value.Valid { + _m.SoraCooldownUntil = new(time.Time) + *_m.SoraCooldownUntil = value.Time + } + case soraaccount.FieldCooledUntil: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field cooled_until", values[i]) + } else if value.Valid { + _m.CooledUntil = new(time.Time) + *_m.CooledUntil = value.Time + } + case soraaccount.FieldImageEnabled: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field image_enabled", values[i]) + } else if value.Valid { + _m.ImageEnabled = value.Bool + } + case soraaccount.FieldVideoEnabled: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field 
video_enabled", values[i]) + } else if value.Valid { + _m.VideoEnabled = value.Bool + } + case soraaccount.FieldImageConcurrency: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field image_concurrency", values[i]) + } else if value.Valid { + _m.ImageConcurrency = int(value.Int64) + } + case soraaccount.FieldVideoConcurrency: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field video_concurrency", values[i]) + } else if value.Valid { + _m.VideoConcurrency = int(value.Int64) + } + case soraaccount.FieldIsExpired: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field is_expired", values[i]) + } else if value.Valid { + _m.IsExpired = value.Bool + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the SoraAccount. +// This includes values selected through modifiers, order, etc. +func (_m *SoraAccount) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this SoraAccount. +// Note that you need to call SoraAccount.Unwrap() before calling this method if this SoraAccount +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *SoraAccount) Update() *SoraAccountUpdateOne { + return NewSoraAccountClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the SoraAccount entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *SoraAccount) Unwrap() *SoraAccount { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: SoraAccount is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. 
+func (_m *SoraAccount) String() string { + var builder strings.Builder + builder.WriteString("SoraAccount(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("account_id=") + builder.WriteString(fmt.Sprintf("%v", _m.AccountID)) + builder.WriteString(", ") + if v := _m.AccessToken; v != nil { + builder.WriteString("access_token=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.SessionToken; v != nil { + builder.WriteString("session_token=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.RefreshToken; v != nil { + builder.WriteString("refresh_token=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.ClientID; v != nil { + builder.WriteString("client_id=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.Email; v != nil { + builder.WriteString("email=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.Username; v != nil { + builder.WriteString("username=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.Remark; v != nil { + builder.WriteString("remark=") + builder.WriteString(*v) + } + builder.WriteString(", ") + builder.WriteString("use_count=") + builder.WriteString(fmt.Sprintf("%v", _m.UseCount)) + builder.WriteString(", ") + if v := _m.PlanType; v != nil { + builder.WriteString("plan_type=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.PlanTitle; v != nil { + builder.WriteString("plan_title=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.SubscriptionEnd; v != nil { + builder.WriteString("subscription_end=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + 
builder.WriteString("sora_supported=") + builder.WriteString(fmt.Sprintf("%v", _m.SoraSupported)) + builder.WriteString(", ") + if v := _m.SoraInviteCode; v != nil { + builder.WriteString("sora_invite_code=") + builder.WriteString(*v) + } + builder.WriteString(", ") + builder.WriteString("sora_redeemed_count=") + builder.WriteString(fmt.Sprintf("%v", _m.SoraRedeemedCount)) + builder.WriteString(", ") + builder.WriteString("sora_remaining_count=") + builder.WriteString(fmt.Sprintf("%v", _m.SoraRemainingCount)) + builder.WriteString(", ") + builder.WriteString("sora_total_count=") + builder.WriteString(fmt.Sprintf("%v", _m.SoraTotalCount)) + builder.WriteString(", ") + if v := _m.SoraCooldownUntil; v != nil { + builder.WriteString("sora_cooldown_until=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.CooledUntil; v != nil { + builder.WriteString("cooled_until=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("image_enabled=") + builder.WriteString(fmt.Sprintf("%v", _m.ImageEnabled)) + builder.WriteString(", ") + builder.WriteString("video_enabled=") + builder.WriteString(fmt.Sprintf("%v", _m.VideoEnabled)) + builder.WriteString(", ") + builder.WriteString("image_concurrency=") + builder.WriteString(fmt.Sprintf("%v", _m.ImageConcurrency)) + builder.WriteString(", ") + builder.WriteString("video_concurrency=") + builder.WriteString(fmt.Sprintf("%v", _m.VideoConcurrency)) + builder.WriteString(", ") + builder.WriteString("is_expired=") + builder.WriteString(fmt.Sprintf("%v", _m.IsExpired)) + builder.WriteByte(')') + return builder.String() +} + +// SoraAccounts is a parsable slice of SoraAccount. 
+type SoraAccounts []*SoraAccount diff --git a/backend/ent/soraaccount/soraaccount.go b/backend/ent/soraaccount/soraaccount.go new file mode 100644 index 00000000..8f11c5e3 --- /dev/null +++ b/backend/ent/soraaccount/soraaccount.go @@ -0,0 +1,278 @@ +// Code generated by ent, DO NOT EDIT. + +package soraaccount + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the soraaccount type in the database. + Label = "sora_account" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldAccountID holds the string denoting the account_id field in the database. + FieldAccountID = "account_id" + // FieldAccessToken holds the string denoting the access_token field in the database. + FieldAccessToken = "access_token" + // FieldSessionToken holds the string denoting the session_token field in the database. + FieldSessionToken = "session_token" + // FieldRefreshToken holds the string denoting the refresh_token field in the database. + FieldRefreshToken = "refresh_token" + // FieldClientID holds the string denoting the client_id field in the database. + FieldClientID = "client_id" + // FieldEmail holds the string denoting the email field in the database. + FieldEmail = "email" + // FieldUsername holds the string denoting the username field in the database. + FieldUsername = "username" + // FieldRemark holds the string denoting the remark field in the database. + FieldRemark = "remark" + // FieldUseCount holds the string denoting the use_count field in the database. + FieldUseCount = "use_count" + // FieldPlanType holds the string denoting the plan_type field in the database. 
+ FieldPlanType = "plan_type" + // FieldPlanTitle holds the string denoting the plan_title field in the database. + FieldPlanTitle = "plan_title" + // FieldSubscriptionEnd holds the string denoting the subscription_end field in the database. + FieldSubscriptionEnd = "subscription_end" + // FieldSoraSupported holds the string denoting the sora_supported field in the database. + FieldSoraSupported = "sora_supported" + // FieldSoraInviteCode holds the string denoting the sora_invite_code field in the database. + FieldSoraInviteCode = "sora_invite_code" + // FieldSoraRedeemedCount holds the string denoting the sora_redeemed_count field in the database. + FieldSoraRedeemedCount = "sora_redeemed_count" + // FieldSoraRemainingCount holds the string denoting the sora_remaining_count field in the database. + FieldSoraRemainingCount = "sora_remaining_count" + // FieldSoraTotalCount holds the string denoting the sora_total_count field in the database. + FieldSoraTotalCount = "sora_total_count" + // FieldSoraCooldownUntil holds the string denoting the sora_cooldown_until field in the database. + FieldSoraCooldownUntil = "sora_cooldown_until" + // FieldCooledUntil holds the string denoting the cooled_until field in the database. + FieldCooledUntil = "cooled_until" + // FieldImageEnabled holds the string denoting the image_enabled field in the database. + FieldImageEnabled = "image_enabled" + // FieldVideoEnabled holds the string denoting the video_enabled field in the database. + FieldVideoEnabled = "video_enabled" + // FieldImageConcurrency holds the string denoting the image_concurrency field in the database. + FieldImageConcurrency = "image_concurrency" + // FieldVideoConcurrency holds the string denoting the video_concurrency field in the database. + FieldVideoConcurrency = "video_concurrency" + // FieldIsExpired holds the string denoting the is_expired field in the database. + FieldIsExpired = "is_expired" + // Table holds the table name of the soraaccount in the database. 
+ Table = "sora_accounts" +) + +// Columns holds all SQL columns for soraaccount fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldAccountID, + FieldAccessToken, + FieldSessionToken, + FieldRefreshToken, + FieldClientID, + FieldEmail, + FieldUsername, + FieldRemark, + FieldUseCount, + FieldPlanType, + FieldPlanTitle, + FieldSubscriptionEnd, + FieldSoraSupported, + FieldSoraInviteCode, + FieldSoraRedeemedCount, + FieldSoraRemainingCount, + FieldSoraTotalCount, + FieldSoraCooldownUntil, + FieldCooledUntil, + FieldImageEnabled, + FieldVideoEnabled, + FieldImageConcurrency, + FieldVideoConcurrency, + FieldIsExpired, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // DefaultUseCount holds the default value on creation for the "use_count" field. + DefaultUseCount int + // DefaultSoraSupported holds the default value on creation for the "sora_supported" field. + DefaultSoraSupported bool + // DefaultSoraRedeemedCount holds the default value on creation for the "sora_redeemed_count" field. + DefaultSoraRedeemedCount int + // DefaultSoraRemainingCount holds the default value on creation for the "sora_remaining_count" field. + DefaultSoraRemainingCount int + // DefaultSoraTotalCount holds the default value on creation for the "sora_total_count" field. + DefaultSoraTotalCount int + // DefaultImageEnabled holds the default value on creation for the "image_enabled" field. 
+ DefaultImageEnabled bool + // DefaultVideoEnabled holds the default value on creation for the "video_enabled" field. + DefaultVideoEnabled bool + // DefaultImageConcurrency holds the default value on creation for the "image_concurrency" field. + DefaultImageConcurrency int + // DefaultVideoConcurrency holds the default value on creation for the "video_concurrency" field. + DefaultVideoConcurrency int + // DefaultIsExpired holds the default value on creation for the "is_expired" field. + DefaultIsExpired bool +) + +// OrderOption defines the ordering options for the SoraAccount queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByAccountID orders the results by the account_id field. +func ByAccountID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAccountID, opts...).ToFunc() +} + +// ByAccessToken orders the results by the access_token field. +func ByAccessToken(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAccessToken, opts...).ToFunc() +} + +// BySessionToken orders the results by the session_token field. +func BySessionToken(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSessionToken, opts...).ToFunc() +} + +// ByRefreshToken orders the results by the refresh_token field. +func ByRefreshToken(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRefreshToken, opts...).ToFunc() +} + +// ByClientID orders the results by the client_id field. 
+func ByClientID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldClientID, opts...).ToFunc() +} + +// ByEmail orders the results by the email field. +func ByEmail(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEmail, opts...).ToFunc() +} + +// ByUsername orders the results by the username field. +func ByUsername(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUsername, opts...).ToFunc() +} + +// ByRemark orders the results by the remark field. +func ByRemark(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRemark, opts...).ToFunc() +} + +// ByUseCount orders the results by the use_count field. +func ByUseCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUseCount, opts...).ToFunc() +} + +// ByPlanType orders the results by the plan_type field. +func ByPlanType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPlanType, opts...).ToFunc() +} + +// ByPlanTitle orders the results by the plan_title field. +func ByPlanTitle(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPlanTitle, opts...).ToFunc() +} + +// BySubscriptionEnd orders the results by the subscription_end field. +func BySubscriptionEnd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSubscriptionEnd, opts...).ToFunc() +} + +// BySoraSupported orders the results by the sora_supported field. +func BySoraSupported(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraSupported, opts...).ToFunc() +} + +// BySoraInviteCode orders the results by the sora_invite_code field. +func BySoraInviteCode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraInviteCode, opts...).ToFunc() +} + +// BySoraRedeemedCount orders the results by the sora_redeemed_count field. 
+func BySoraRedeemedCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraRedeemedCount, opts...).ToFunc() +} + +// BySoraRemainingCount orders the results by the sora_remaining_count field. +func BySoraRemainingCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraRemainingCount, opts...).ToFunc() +} + +// BySoraTotalCount orders the results by the sora_total_count field. +func BySoraTotalCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraTotalCount, opts...).ToFunc() +} + +// BySoraCooldownUntil orders the results by the sora_cooldown_until field. +func BySoraCooldownUntil(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraCooldownUntil, opts...).ToFunc() +} + +// ByCooledUntil orders the results by the cooled_until field. +func ByCooledUntil(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCooledUntil, opts...).ToFunc() +} + +// ByImageEnabled orders the results by the image_enabled field. +func ByImageEnabled(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldImageEnabled, opts...).ToFunc() +} + +// ByVideoEnabled orders the results by the video_enabled field. +func ByVideoEnabled(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldVideoEnabled, opts...).ToFunc() +} + +// ByImageConcurrency orders the results by the image_concurrency field. +func ByImageConcurrency(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldImageConcurrency, opts...).ToFunc() +} + +// ByVideoConcurrency orders the results by the video_concurrency field. +func ByVideoConcurrency(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldVideoConcurrency, opts...).ToFunc() +} + +// ByIsExpired orders the results by the is_expired field. 
+func ByIsExpired(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIsExpired, opts...).ToFunc() +} diff --git a/backend/ent/soraaccount/where.go b/backend/ent/soraaccount/where.go new file mode 100644 index 00000000..3cc12398 --- /dev/null +++ b/backend/ent/soraaccount/where.go @@ -0,0 +1,1500 @@ +// Code generated by ent, DO NOT EDIT. + +package soraaccount + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. 
+func IDLTE(id int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// AccountID applies equality check predicate on the "account_id" field. It's identical to AccountIDEQ. +func AccountID(v int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldAccountID, v)) +} + +// AccessToken applies equality check predicate on the "access_token" field. It's identical to AccessTokenEQ. +func AccessToken(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldAccessToken, v)) +} + +// SessionToken applies equality check predicate on the "session_token" field. It's identical to SessionTokenEQ. +func SessionToken(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldSessionToken, v)) +} + +// RefreshToken applies equality check predicate on the "refresh_token" field. It's identical to RefreshTokenEQ. +func RefreshToken(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldRefreshToken, v)) +} + +// ClientID applies equality check predicate on the "client_id" field. It's identical to ClientIDEQ. +func ClientID(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldClientID, v)) +} + +// Email applies equality check predicate on the "email" field. It's identical to EmailEQ. +func Email(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldEmail, v)) +} + +// Username applies equality check predicate on the "username" field. 
It's identical to UsernameEQ. +func Username(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldUsername, v)) +} + +// Remark applies equality check predicate on the "remark" field. It's identical to RemarkEQ. +func Remark(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldRemark, v)) +} + +// UseCount applies equality check predicate on the "use_count" field. It's identical to UseCountEQ. +func UseCount(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldUseCount, v)) +} + +// PlanType applies equality check predicate on the "plan_type" field. It's identical to PlanTypeEQ. +func PlanType(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldPlanType, v)) +} + +// PlanTitle applies equality check predicate on the "plan_title" field. It's identical to PlanTitleEQ. +func PlanTitle(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldPlanTitle, v)) +} + +// SubscriptionEnd applies equality check predicate on the "subscription_end" field. It's identical to SubscriptionEndEQ. +func SubscriptionEnd(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldSubscriptionEnd, v)) +} + +// SoraSupported applies equality check predicate on the "sora_supported" field. It's identical to SoraSupportedEQ. +func SoraSupported(v bool) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldSoraSupported, v)) +} + +// SoraInviteCode applies equality check predicate on the "sora_invite_code" field. It's identical to SoraInviteCodeEQ. +func SoraInviteCode(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldSoraInviteCode, v)) +} + +// SoraRedeemedCount applies equality check predicate on the "sora_redeemed_count" field. It's identical to SoraRedeemedCountEQ. 
+func SoraRedeemedCount(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldSoraRedeemedCount, v)) +} + +// SoraRemainingCount applies equality check predicate on the "sora_remaining_count" field. It's identical to SoraRemainingCountEQ. +func SoraRemainingCount(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldSoraRemainingCount, v)) +} + +// SoraTotalCount applies equality check predicate on the "sora_total_count" field. It's identical to SoraTotalCountEQ. +func SoraTotalCount(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldSoraTotalCount, v)) +} + +// SoraCooldownUntil applies equality check predicate on the "sora_cooldown_until" field. It's identical to SoraCooldownUntilEQ. +func SoraCooldownUntil(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldSoraCooldownUntil, v)) +} + +// CooledUntil applies equality check predicate on the "cooled_until" field. It's identical to CooledUntilEQ. +func CooledUntil(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldCooledUntil, v)) +} + +// ImageEnabled applies equality check predicate on the "image_enabled" field. It's identical to ImageEnabledEQ. +func ImageEnabled(v bool) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldImageEnabled, v)) +} + +// VideoEnabled applies equality check predicate on the "video_enabled" field. It's identical to VideoEnabledEQ. +func VideoEnabled(v bool) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldVideoEnabled, v)) +} + +// ImageConcurrency applies equality check predicate on the "image_concurrency" field. It's identical to ImageConcurrencyEQ. +func ImageConcurrency(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldImageConcurrency, v)) +} + +// VideoConcurrency applies equality check predicate on the "video_concurrency" field. It's identical to VideoConcurrencyEQ. 
+func VideoConcurrency(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldVideoConcurrency, v)) +} + +// IsExpired applies equality check predicate on the "is_expired" field. It's identical to IsExpiredEQ. +func IsExpired(v bool) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldIsExpired, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. 
+func UpdatedAtEQ(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// AccountIDEQ applies the EQ predicate on the "account_id" field. +func AccountIDEQ(v int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldAccountID, v)) +} + +// AccountIDNEQ applies the NEQ predicate on the "account_id" field. +func AccountIDNEQ(v int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldAccountID, v)) +} + +// AccountIDIn applies the In predicate on the "account_id" field. 
+func AccountIDIn(vs ...int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldAccountID, vs...)) +} + +// AccountIDNotIn applies the NotIn predicate on the "account_id" field. +func AccountIDNotIn(vs ...int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldAccountID, vs...)) +} + +// AccountIDGT applies the GT predicate on the "account_id" field. +func AccountIDGT(v int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldAccountID, v)) +} + +// AccountIDGTE applies the GTE predicate on the "account_id" field. +func AccountIDGTE(v int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldAccountID, v)) +} + +// AccountIDLT applies the LT predicate on the "account_id" field. +func AccountIDLT(v int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldAccountID, v)) +} + +// AccountIDLTE applies the LTE predicate on the "account_id" field. +func AccountIDLTE(v int64) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldAccountID, v)) +} + +// AccessTokenEQ applies the EQ predicate on the "access_token" field. +func AccessTokenEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldAccessToken, v)) +} + +// AccessTokenNEQ applies the NEQ predicate on the "access_token" field. +func AccessTokenNEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldAccessToken, v)) +} + +// AccessTokenIn applies the In predicate on the "access_token" field. +func AccessTokenIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldAccessToken, vs...)) +} + +// AccessTokenNotIn applies the NotIn predicate on the "access_token" field. +func AccessTokenNotIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldAccessToken, vs...)) +} + +// AccessTokenGT applies the GT predicate on the "access_token" field. 
+func AccessTokenGT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldAccessToken, v)) +} + +// AccessTokenGTE applies the GTE predicate on the "access_token" field. +func AccessTokenGTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldAccessToken, v)) +} + +// AccessTokenLT applies the LT predicate on the "access_token" field. +func AccessTokenLT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldAccessToken, v)) +} + +// AccessTokenLTE applies the LTE predicate on the "access_token" field. +func AccessTokenLTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldAccessToken, v)) +} + +// AccessTokenContains applies the Contains predicate on the "access_token" field. +func AccessTokenContains(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContains(FieldAccessToken, v)) +} + +// AccessTokenHasPrefix applies the HasPrefix predicate on the "access_token" field. +func AccessTokenHasPrefix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasPrefix(FieldAccessToken, v)) +} + +// AccessTokenHasSuffix applies the HasSuffix predicate on the "access_token" field. +func AccessTokenHasSuffix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasSuffix(FieldAccessToken, v)) +} + +// AccessTokenIsNil applies the IsNil predicate on the "access_token" field. +func AccessTokenIsNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIsNull(FieldAccessToken)) +} + +// AccessTokenNotNil applies the NotNil predicate on the "access_token" field. +func AccessTokenNotNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotNull(FieldAccessToken)) +} + +// AccessTokenEqualFold applies the EqualFold predicate on the "access_token" field. 
+func AccessTokenEqualFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEqualFold(FieldAccessToken, v)) +} + +// AccessTokenContainsFold applies the ContainsFold predicate on the "access_token" field. +func AccessTokenContainsFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContainsFold(FieldAccessToken, v)) +} + +// SessionTokenEQ applies the EQ predicate on the "session_token" field. +func SessionTokenEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldSessionToken, v)) +} + +// SessionTokenNEQ applies the NEQ predicate on the "session_token" field. +func SessionTokenNEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldSessionToken, v)) +} + +// SessionTokenIn applies the In predicate on the "session_token" field. +func SessionTokenIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldSessionToken, vs...)) +} + +// SessionTokenNotIn applies the NotIn predicate on the "session_token" field. +func SessionTokenNotIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldSessionToken, vs...)) +} + +// SessionTokenGT applies the GT predicate on the "session_token" field. +func SessionTokenGT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldSessionToken, v)) +} + +// SessionTokenGTE applies the GTE predicate on the "session_token" field. +func SessionTokenGTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldSessionToken, v)) +} + +// SessionTokenLT applies the LT predicate on the "session_token" field. +func SessionTokenLT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldSessionToken, v)) +} + +// SessionTokenLTE applies the LTE predicate on the "session_token" field. 
+func SessionTokenLTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldSessionToken, v)) +} + +// SessionTokenContains applies the Contains predicate on the "session_token" field. +func SessionTokenContains(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContains(FieldSessionToken, v)) +} + +// SessionTokenHasPrefix applies the HasPrefix predicate on the "session_token" field. +func SessionTokenHasPrefix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasPrefix(FieldSessionToken, v)) +} + +// SessionTokenHasSuffix applies the HasSuffix predicate on the "session_token" field. +func SessionTokenHasSuffix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasSuffix(FieldSessionToken, v)) +} + +// SessionTokenIsNil applies the IsNil predicate on the "session_token" field. +func SessionTokenIsNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIsNull(FieldSessionToken)) +} + +// SessionTokenNotNil applies the NotNil predicate on the "session_token" field. +func SessionTokenNotNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotNull(FieldSessionToken)) +} + +// SessionTokenEqualFold applies the EqualFold predicate on the "session_token" field. +func SessionTokenEqualFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEqualFold(FieldSessionToken, v)) +} + +// SessionTokenContainsFold applies the ContainsFold predicate on the "session_token" field. +func SessionTokenContainsFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContainsFold(FieldSessionToken, v)) +} + +// RefreshTokenEQ applies the EQ predicate on the "refresh_token" field. +func RefreshTokenEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldRefreshToken, v)) +} + +// RefreshTokenNEQ applies the NEQ predicate on the "refresh_token" field. 
+func RefreshTokenNEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldRefreshToken, v)) +} + +// RefreshTokenIn applies the In predicate on the "refresh_token" field. +func RefreshTokenIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldRefreshToken, vs...)) +} + +// RefreshTokenNotIn applies the NotIn predicate on the "refresh_token" field. +func RefreshTokenNotIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldRefreshToken, vs...)) +} + +// RefreshTokenGT applies the GT predicate on the "refresh_token" field. +func RefreshTokenGT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldRefreshToken, v)) +} + +// RefreshTokenGTE applies the GTE predicate on the "refresh_token" field. +func RefreshTokenGTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldRefreshToken, v)) +} + +// RefreshTokenLT applies the LT predicate on the "refresh_token" field. +func RefreshTokenLT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldRefreshToken, v)) +} + +// RefreshTokenLTE applies the LTE predicate on the "refresh_token" field. +func RefreshTokenLTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldRefreshToken, v)) +} + +// RefreshTokenContains applies the Contains predicate on the "refresh_token" field. +func RefreshTokenContains(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContains(FieldRefreshToken, v)) +} + +// RefreshTokenHasPrefix applies the HasPrefix predicate on the "refresh_token" field. +func RefreshTokenHasPrefix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasPrefix(FieldRefreshToken, v)) +} + +// RefreshTokenHasSuffix applies the HasSuffix predicate on the "refresh_token" field. 
+func RefreshTokenHasSuffix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasSuffix(FieldRefreshToken, v)) +} + +// RefreshTokenIsNil applies the IsNil predicate on the "refresh_token" field. +func RefreshTokenIsNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIsNull(FieldRefreshToken)) +} + +// RefreshTokenNotNil applies the NotNil predicate on the "refresh_token" field. +func RefreshTokenNotNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotNull(FieldRefreshToken)) +} + +// RefreshTokenEqualFold applies the EqualFold predicate on the "refresh_token" field. +func RefreshTokenEqualFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEqualFold(FieldRefreshToken, v)) +} + +// RefreshTokenContainsFold applies the ContainsFold predicate on the "refresh_token" field. +func RefreshTokenContainsFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContainsFold(FieldRefreshToken, v)) +} + +// ClientIDEQ applies the EQ predicate on the "client_id" field. +func ClientIDEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldClientID, v)) +} + +// ClientIDNEQ applies the NEQ predicate on the "client_id" field. +func ClientIDNEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldClientID, v)) +} + +// ClientIDIn applies the In predicate on the "client_id" field. +func ClientIDIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldClientID, vs...)) +} + +// ClientIDNotIn applies the NotIn predicate on the "client_id" field. +func ClientIDNotIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldClientID, vs...)) +} + +// ClientIDGT applies the GT predicate on the "client_id" field. 
+func ClientIDGT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldClientID, v)) +} + +// ClientIDGTE applies the GTE predicate on the "client_id" field. +func ClientIDGTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldClientID, v)) +} + +// ClientIDLT applies the LT predicate on the "client_id" field. +func ClientIDLT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldClientID, v)) +} + +// ClientIDLTE applies the LTE predicate on the "client_id" field. +func ClientIDLTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldClientID, v)) +} + +// ClientIDContains applies the Contains predicate on the "client_id" field. +func ClientIDContains(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContains(FieldClientID, v)) +} + +// ClientIDHasPrefix applies the HasPrefix predicate on the "client_id" field. +func ClientIDHasPrefix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasPrefix(FieldClientID, v)) +} + +// ClientIDHasSuffix applies the HasSuffix predicate on the "client_id" field. +func ClientIDHasSuffix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasSuffix(FieldClientID, v)) +} + +// ClientIDIsNil applies the IsNil predicate on the "client_id" field. +func ClientIDIsNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIsNull(FieldClientID)) +} + +// ClientIDNotNil applies the NotNil predicate on the "client_id" field. +func ClientIDNotNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotNull(FieldClientID)) +} + +// ClientIDEqualFold applies the EqualFold predicate on the "client_id" field. +func ClientIDEqualFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEqualFold(FieldClientID, v)) +} + +// ClientIDContainsFold applies the ContainsFold predicate on the "client_id" field. 
+func ClientIDContainsFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContainsFold(FieldClientID, v)) +} + +// EmailEQ applies the EQ predicate on the "email" field. +func EmailEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldEmail, v)) +} + +// EmailNEQ applies the NEQ predicate on the "email" field. +func EmailNEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldEmail, v)) +} + +// EmailIn applies the In predicate on the "email" field. +func EmailIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldEmail, vs...)) +} + +// EmailNotIn applies the NotIn predicate on the "email" field. +func EmailNotIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldEmail, vs...)) +} + +// EmailGT applies the GT predicate on the "email" field. +func EmailGT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldEmail, v)) +} + +// EmailGTE applies the GTE predicate on the "email" field. +func EmailGTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldEmail, v)) +} + +// EmailLT applies the LT predicate on the "email" field. +func EmailLT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldEmail, v)) +} + +// EmailLTE applies the LTE predicate on the "email" field. +func EmailLTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldEmail, v)) +} + +// EmailContains applies the Contains predicate on the "email" field. +func EmailContains(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContains(FieldEmail, v)) +} + +// EmailHasPrefix applies the HasPrefix predicate on the "email" field. 
+func EmailHasPrefix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasPrefix(FieldEmail, v)) +} + +// EmailHasSuffix applies the HasSuffix predicate on the "email" field. +func EmailHasSuffix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasSuffix(FieldEmail, v)) +} + +// EmailIsNil applies the IsNil predicate on the "email" field. +func EmailIsNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIsNull(FieldEmail)) +} + +// EmailNotNil applies the NotNil predicate on the "email" field. +func EmailNotNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotNull(FieldEmail)) +} + +// EmailEqualFold applies the EqualFold predicate on the "email" field. +func EmailEqualFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEqualFold(FieldEmail, v)) +} + +// EmailContainsFold applies the ContainsFold predicate on the "email" field. +func EmailContainsFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContainsFold(FieldEmail, v)) +} + +// UsernameEQ applies the EQ predicate on the "username" field. +func UsernameEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldUsername, v)) +} + +// UsernameNEQ applies the NEQ predicate on the "username" field. +func UsernameNEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldUsername, v)) +} + +// UsernameIn applies the In predicate on the "username" field. +func UsernameIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldUsername, vs...)) +} + +// UsernameNotIn applies the NotIn predicate on the "username" field. +func UsernameNotIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldUsername, vs...)) +} + +// UsernameGT applies the GT predicate on the "username" field. 
+func UsernameGT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldUsername, v)) +} + +// UsernameGTE applies the GTE predicate on the "username" field. +func UsernameGTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldUsername, v)) +} + +// UsernameLT applies the LT predicate on the "username" field. +func UsernameLT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldUsername, v)) +} + +// UsernameLTE applies the LTE predicate on the "username" field. +func UsernameLTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldUsername, v)) +} + +// UsernameContains applies the Contains predicate on the "username" field. +func UsernameContains(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContains(FieldUsername, v)) +} + +// UsernameHasPrefix applies the HasPrefix predicate on the "username" field. +func UsernameHasPrefix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasPrefix(FieldUsername, v)) +} + +// UsernameHasSuffix applies the HasSuffix predicate on the "username" field. +func UsernameHasSuffix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasSuffix(FieldUsername, v)) +} + +// UsernameIsNil applies the IsNil predicate on the "username" field. +func UsernameIsNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIsNull(FieldUsername)) +} + +// UsernameNotNil applies the NotNil predicate on the "username" field. +func UsernameNotNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotNull(FieldUsername)) +} + +// UsernameEqualFold applies the EqualFold predicate on the "username" field. +func UsernameEqualFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEqualFold(FieldUsername, v)) +} + +// UsernameContainsFold applies the ContainsFold predicate on the "username" field. 
+func UsernameContainsFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContainsFold(FieldUsername, v)) +} + +// RemarkEQ applies the EQ predicate on the "remark" field. +func RemarkEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldRemark, v)) +} + +// RemarkNEQ applies the NEQ predicate on the "remark" field. +func RemarkNEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldRemark, v)) +} + +// RemarkIn applies the In predicate on the "remark" field. +func RemarkIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldRemark, vs...)) +} + +// RemarkNotIn applies the NotIn predicate on the "remark" field. +func RemarkNotIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldRemark, vs...)) +} + +// RemarkGT applies the GT predicate on the "remark" field. +func RemarkGT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldRemark, v)) +} + +// RemarkGTE applies the GTE predicate on the "remark" field. +func RemarkGTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldRemark, v)) +} + +// RemarkLT applies the LT predicate on the "remark" field. +func RemarkLT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldRemark, v)) +} + +// RemarkLTE applies the LTE predicate on the "remark" field. +func RemarkLTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldRemark, v)) +} + +// RemarkContains applies the Contains predicate on the "remark" field. +func RemarkContains(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContains(FieldRemark, v)) +} + +// RemarkHasPrefix applies the HasPrefix predicate on the "remark" field. 
+func RemarkHasPrefix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasPrefix(FieldRemark, v)) +} + +// RemarkHasSuffix applies the HasSuffix predicate on the "remark" field. +func RemarkHasSuffix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasSuffix(FieldRemark, v)) +} + +// RemarkIsNil applies the IsNil predicate on the "remark" field. +func RemarkIsNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIsNull(FieldRemark)) +} + +// RemarkNotNil applies the NotNil predicate on the "remark" field. +func RemarkNotNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotNull(FieldRemark)) +} + +// RemarkEqualFold applies the EqualFold predicate on the "remark" field. +func RemarkEqualFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEqualFold(FieldRemark, v)) +} + +// RemarkContainsFold applies the ContainsFold predicate on the "remark" field. +func RemarkContainsFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContainsFold(FieldRemark, v)) +} + +// UseCountEQ applies the EQ predicate on the "use_count" field. +func UseCountEQ(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldUseCount, v)) +} + +// UseCountNEQ applies the NEQ predicate on the "use_count" field. +func UseCountNEQ(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldUseCount, v)) +} + +// UseCountIn applies the In predicate on the "use_count" field. +func UseCountIn(vs ...int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldUseCount, vs...)) +} + +// UseCountNotIn applies the NotIn predicate on the "use_count" field. +func UseCountNotIn(vs ...int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldUseCount, vs...)) +} + +// UseCountGT applies the GT predicate on the "use_count" field. 
+func UseCountGT(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldUseCount, v)) +} + +// UseCountGTE applies the GTE predicate on the "use_count" field. +func UseCountGTE(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldUseCount, v)) +} + +// UseCountLT applies the LT predicate on the "use_count" field. +func UseCountLT(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldUseCount, v)) +} + +// UseCountLTE applies the LTE predicate on the "use_count" field. +func UseCountLTE(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldUseCount, v)) +} + +// PlanTypeEQ applies the EQ predicate on the "plan_type" field. +func PlanTypeEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldPlanType, v)) +} + +// PlanTypeNEQ applies the NEQ predicate on the "plan_type" field. +func PlanTypeNEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldPlanType, v)) +} + +// PlanTypeIn applies the In predicate on the "plan_type" field. +func PlanTypeIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldPlanType, vs...)) +} + +// PlanTypeNotIn applies the NotIn predicate on the "plan_type" field. +func PlanTypeNotIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldPlanType, vs...)) +} + +// PlanTypeGT applies the GT predicate on the "plan_type" field. +func PlanTypeGT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldPlanType, v)) +} + +// PlanTypeGTE applies the GTE predicate on the "plan_type" field. +func PlanTypeGTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldPlanType, v)) +} + +// PlanTypeLT applies the LT predicate on the "plan_type" field. 
+func PlanTypeLT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldPlanType, v)) +} + +// PlanTypeLTE applies the LTE predicate on the "plan_type" field. +func PlanTypeLTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldPlanType, v)) +} + +// PlanTypeContains applies the Contains predicate on the "plan_type" field. +func PlanTypeContains(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContains(FieldPlanType, v)) +} + +// PlanTypeHasPrefix applies the HasPrefix predicate on the "plan_type" field. +func PlanTypeHasPrefix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasPrefix(FieldPlanType, v)) +} + +// PlanTypeHasSuffix applies the HasSuffix predicate on the "plan_type" field. +func PlanTypeHasSuffix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasSuffix(FieldPlanType, v)) +} + +// PlanTypeIsNil applies the IsNil predicate on the "plan_type" field. +func PlanTypeIsNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIsNull(FieldPlanType)) +} + +// PlanTypeNotNil applies the NotNil predicate on the "plan_type" field. +func PlanTypeNotNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotNull(FieldPlanType)) +} + +// PlanTypeEqualFold applies the EqualFold predicate on the "plan_type" field. +func PlanTypeEqualFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEqualFold(FieldPlanType, v)) +} + +// PlanTypeContainsFold applies the ContainsFold predicate on the "plan_type" field. +func PlanTypeContainsFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContainsFold(FieldPlanType, v)) +} + +// PlanTitleEQ applies the EQ predicate on the "plan_title" field. 
+func PlanTitleEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldPlanTitle, v)) +} + +// PlanTitleNEQ applies the NEQ predicate on the "plan_title" field. +func PlanTitleNEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldPlanTitle, v)) +} + +// PlanTitleIn applies the In predicate on the "plan_title" field. +func PlanTitleIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldPlanTitle, vs...)) +} + +// PlanTitleNotIn applies the NotIn predicate on the "plan_title" field. +func PlanTitleNotIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldPlanTitle, vs...)) +} + +// PlanTitleGT applies the GT predicate on the "plan_title" field. +func PlanTitleGT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldPlanTitle, v)) +} + +// PlanTitleGTE applies the GTE predicate on the "plan_title" field. +func PlanTitleGTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldPlanTitle, v)) +} + +// PlanTitleLT applies the LT predicate on the "plan_title" field. +func PlanTitleLT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldPlanTitle, v)) +} + +// PlanTitleLTE applies the LTE predicate on the "plan_title" field. +func PlanTitleLTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldPlanTitle, v)) +} + +// PlanTitleContains applies the Contains predicate on the "plan_title" field. +func PlanTitleContains(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContains(FieldPlanTitle, v)) +} + +// PlanTitleHasPrefix applies the HasPrefix predicate on the "plan_title" field. +func PlanTitleHasPrefix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasPrefix(FieldPlanTitle, v)) +} + +// PlanTitleHasSuffix applies the HasSuffix predicate on the "plan_title" field. 
+func PlanTitleHasSuffix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasSuffix(FieldPlanTitle, v)) +} + +// PlanTitleIsNil applies the IsNil predicate on the "plan_title" field. +func PlanTitleIsNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIsNull(FieldPlanTitle)) +} + +// PlanTitleNotNil applies the NotNil predicate on the "plan_title" field. +func PlanTitleNotNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotNull(FieldPlanTitle)) +} + +// PlanTitleEqualFold applies the EqualFold predicate on the "plan_title" field. +func PlanTitleEqualFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEqualFold(FieldPlanTitle, v)) +} + +// PlanTitleContainsFold applies the ContainsFold predicate on the "plan_title" field. +func PlanTitleContainsFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContainsFold(FieldPlanTitle, v)) +} + +// SubscriptionEndEQ applies the EQ predicate on the "subscription_end" field. +func SubscriptionEndEQ(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldSubscriptionEnd, v)) +} + +// SubscriptionEndNEQ applies the NEQ predicate on the "subscription_end" field. +func SubscriptionEndNEQ(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldSubscriptionEnd, v)) +} + +// SubscriptionEndIn applies the In predicate on the "subscription_end" field. +func SubscriptionEndIn(vs ...time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldSubscriptionEnd, vs...)) +} + +// SubscriptionEndNotIn applies the NotIn predicate on the "subscription_end" field. +func SubscriptionEndNotIn(vs ...time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldSubscriptionEnd, vs...)) +} + +// SubscriptionEndGT applies the GT predicate on the "subscription_end" field. 
+func SubscriptionEndGT(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldSubscriptionEnd, v)) +} + +// SubscriptionEndGTE applies the GTE predicate on the "subscription_end" field. +func SubscriptionEndGTE(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldSubscriptionEnd, v)) +} + +// SubscriptionEndLT applies the LT predicate on the "subscription_end" field. +func SubscriptionEndLT(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldSubscriptionEnd, v)) +} + +// SubscriptionEndLTE applies the LTE predicate on the "subscription_end" field. +func SubscriptionEndLTE(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldSubscriptionEnd, v)) +} + +// SubscriptionEndIsNil applies the IsNil predicate on the "subscription_end" field. +func SubscriptionEndIsNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIsNull(FieldSubscriptionEnd)) +} + +// SubscriptionEndNotNil applies the NotNil predicate on the "subscription_end" field. +func SubscriptionEndNotNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotNull(FieldSubscriptionEnd)) +} + +// SoraSupportedEQ applies the EQ predicate on the "sora_supported" field. +func SoraSupportedEQ(v bool) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldSoraSupported, v)) +} + +// SoraSupportedNEQ applies the NEQ predicate on the "sora_supported" field. +func SoraSupportedNEQ(v bool) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldSoraSupported, v)) +} + +// SoraInviteCodeEQ applies the EQ predicate on the "sora_invite_code" field. +func SoraInviteCodeEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldSoraInviteCode, v)) +} + +// SoraInviteCodeNEQ applies the NEQ predicate on the "sora_invite_code" field. 
+func SoraInviteCodeNEQ(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldSoraInviteCode, v)) +} + +// SoraInviteCodeIn applies the In predicate on the "sora_invite_code" field. +func SoraInviteCodeIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldSoraInviteCode, vs...)) +} + +// SoraInviteCodeNotIn applies the NotIn predicate on the "sora_invite_code" field. +func SoraInviteCodeNotIn(vs ...string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldSoraInviteCode, vs...)) +} + +// SoraInviteCodeGT applies the GT predicate on the "sora_invite_code" field. +func SoraInviteCodeGT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldSoraInviteCode, v)) +} + +// SoraInviteCodeGTE applies the GTE predicate on the "sora_invite_code" field. +func SoraInviteCodeGTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldSoraInviteCode, v)) +} + +// SoraInviteCodeLT applies the LT predicate on the "sora_invite_code" field. +func SoraInviteCodeLT(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldSoraInviteCode, v)) +} + +// SoraInviteCodeLTE applies the LTE predicate on the "sora_invite_code" field. +func SoraInviteCodeLTE(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldSoraInviteCode, v)) +} + +// SoraInviteCodeContains applies the Contains predicate on the "sora_invite_code" field. +func SoraInviteCodeContains(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContains(FieldSoraInviteCode, v)) +} + +// SoraInviteCodeHasPrefix applies the HasPrefix predicate on the "sora_invite_code" field. +func SoraInviteCodeHasPrefix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasPrefix(FieldSoraInviteCode, v)) +} + +// SoraInviteCodeHasSuffix applies the HasSuffix predicate on the "sora_invite_code" field. 
+func SoraInviteCodeHasSuffix(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldHasSuffix(FieldSoraInviteCode, v)) +} + +// SoraInviteCodeIsNil applies the IsNil predicate on the "sora_invite_code" field. +func SoraInviteCodeIsNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIsNull(FieldSoraInviteCode)) +} + +// SoraInviteCodeNotNil applies the NotNil predicate on the "sora_invite_code" field. +func SoraInviteCodeNotNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotNull(FieldSoraInviteCode)) +} + +// SoraInviteCodeEqualFold applies the EqualFold predicate on the "sora_invite_code" field. +func SoraInviteCodeEqualFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEqualFold(FieldSoraInviteCode, v)) +} + +// SoraInviteCodeContainsFold applies the ContainsFold predicate on the "sora_invite_code" field. +func SoraInviteCodeContainsFold(v string) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldContainsFold(FieldSoraInviteCode, v)) +} + +// SoraRedeemedCountEQ applies the EQ predicate on the "sora_redeemed_count" field. +func SoraRedeemedCountEQ(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldSoraRedeemedCount, v)) +} + +// SoraRedeemedCountNEQ applies the NEQ predicate on the "sora_redeemed_count" field. +func SoraRedeemedCountNEQ(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldSoraRedeemedCount, v)) +} + +// SoraRedeemedCountIn applies the In predicate on the "sora_redeemed_count" field. +func SoraRedeemedCountIn(vs ...int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldSoraRedeemedCount, vs...)) +} + +// SoraRedeemedCountNotIn applies the NotIn predicate on the "sora_redeemed_count" field. 
+func SoraRedeemedCountNotIn(vs ...int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldSoraRedeemedCount, vs...)) +} + +// SoraRedeemedCountGT applies the GT predicate on the "sora_redeemed_count" field. +func SoraRedeemedCountGT(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldSoraRedeemedCount, v)) +} + +// SoraRedeemedCountGTE applies the GTE predicate on the "sora_redeemed_count" field. +func SoraRedeemedCountGTE(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldSoraRedeemedCount, v)) +} + +// SoraRedeemedCountLT applies the LT predicate on the "sora_redeemed_count" field. +func SoraRedeemedCountLT(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldSoraRedeemedCount, v)) +} + +// SoraRedeemedCountLTE applies the LTE predicate on the "sora_redeemed_count" field. +func SoraRedeemedCountLTE(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldSoraRedeemedCount, v)) +} + +// SoraRemainingCountEQ applies the EQ predicate on the "sora_remaining_count" field. +func SoraRemainingCountEQ(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldSoraRemainingCount, v)) +} + +// SoraRemainingCountNEQ applies the NEQ predicate on the "sora_remaining_count" field. +func SoraRemainingCountNEQ(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldSoraRemainingCount, v)) +} + +// SoraRemainingCountIn applies the In predicate on the "sora_remaining_count" field. +func SoraRemainingCountIn(vs ...int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldSoraRemainingCount, vs...)) +} + +// SoraRemainingCountNotIn applies the NotIn predicate on the "sora_remaining_count" field. 
+func SoraRemainingCountNotIn(vs ...int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldSoraRemainingCount, vs...)) +} + +// SoraRemainingCountGT applies the GT predicate on the "sora_remaining_count" field. +func SoraRemainingCountGT(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldSoraRemainingCount, v)) +} + +// SoraRemainingCountGTE applies the GTE predicate on the "sora_remaining_count" field. +func SoraRemainingCountGTE(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldSoraRemainingCount, v)) +} + +// SoraRemainingCountLT applies the LT predicate on the "sora_remaining_count" field. +func SoraRemainingCountLT(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldSoraRemainingCount, v)) +} + +// SoraRemainingCountLTE applies the LTE predicate on the "sora_remaining_count" field. +func SoraRemainingCountLTE(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldSoraRemainingCount, v)) +} + +// SoraTotalCountEQ applies the EQ predicate on the "sora_total_count" field. +func SoraTotalCountEQ(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldSoraTotalCount, v)) +} + +// SoraTotalCountNEQ applies the NEQ predicate on the "sora_total_count" field. +func SoraTotalCountNEQ(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldSoraTotalCount, v)) +} + +// SoraTotalCountIn applies the In predicate on the "sora_total_count" field. +func SoraTotalCountIn(vs ...int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldSoraTotalCount, vs...)) +} + +// SoraTotalCountNotIn applies the NotIn predicate on the "sora_total_count" field. +func SoraTotalCountNotIn(vs ...int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldSoraTotalCount, vs...)) +} + +// SoraTotalCountGT applies the GT predicate on the "sora_total_count" field. 
+func SoraTotalCountGT(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldSoraTotalCount, v)) +} + +// SoraTotalCountGTE applies the GTE predicate on the "sora_total_count" field. +func SoraTotalCountGTE(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldSoraTotalCount, v)) +} + +// SoraTotalCountLT applies the LT predicate on the "sora_total_count" field. +func SoraTotalCountLT(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldSoraTotalCount, v)) +} + +// SoraTotalCountLTE applies the LTE predicate on the "sora_total_count" field. +func SoraTotalCountLTE(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldSoraTotalCount, v)) +} + +// SoraCooldownUntilEQ applies the EQ predicate on the "sora_cooldown_until" field. +func SoraCooldownUntilEQ(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldSoraCooldownUntil, v)) +} + +// SoraCooldownUntilNEQ applies the NEQ predicate on the "sora_cooldown_until" field. +func SoraCooldownUntilNEQ(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldSoraCooldownUntil, v)) +} + +// SoraCooldownUntilIn applies the In predicate on the "sora_cooldown_until" field. +func SoraCooldownUntilIn(vs ...time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldSoraCooldownUntil, vs...)) +} + +// SoraCooldownUntilNotIn applies the NotIn predicate on the "sora_cooldown_until" field. +func SoraCooldownUntilNotIn(vs ...time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldSoraCooldownUntil, vs...)) +} + +// SoraCooldownUntilGT applies the GT predicate on the "sora_cooldown_until" field. +func SoraCooldownUntilGT(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldSoraCooldownUntil, v)) +} + +// SoraCooldownUntilGTE applies the GTE predicate on the "sora_cooldown_until" field. 
+func SoraCooldownUntilGTE(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldSoraCooldownUntil, v)) +} + +// SoraCooldownUntilLT applies the LT predicate on the "sora_cooldown_until" field. +func SoraCooldownUntilLT(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldSoraCooldownUntil, v)) +} + +// SoraCooldownUntilLTE applies the LTE predicate on the "sora_cooldown_until" field. +func SoraCooldownUntilLTE(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldSoraCooldownUntil, v)) +} + +// SoraCooldownUntilIsNil applies the IsNil predicate on the "sora_cooldown_until" field. +func SoraCooldownUntilIsNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIsNull(FieldSoraCooldownUntil)) +} + +// SoraCooldownUntilNotNil applies the NotNil predicate on the "sora_cooldown_until" field. +func SoraCooldownUntilNotNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotNull(FieldSoraCooldownUntil)) +} + +// CooledUntilEQ applies the EQ predicate on the "cooled_until" field. +func CooledUntilEQ(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldCooledUntil, v)) +} + +// CooledUntilNEQ applies the NEQ predicate on the "cooled_until" field. +func CooledUntilNEQ(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldCooledUntil, v)) +} + +// CooledUntilIn applies the In predicate on the "cooled_until" field. +func CooledUntilIn(vs ...time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldCooledUntil, vs...)) +} + +// CooledUntilNotIn applies the NotIn predicate on the "cooled_until" field. +func CooledUntilNotIn(vs ...time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldCooledUntil, vs...)) +} + +// CooledUntilGT applies the GT predicate on the "cooled_until" field. 
+func CooledUntilGT(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldCooledUntil, v)) +} + +// CooledUntilGTE applies the GTE predicate on the "cooled_until" field. +func CooledUntilGTE(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldCooledUntil, v)) +} + +// CooledUntilLT applies the LT predicate on the "cooled_until" field. +func CooledUntilLT(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldCooledUntil, v)) +} + +// CooledUntilLTE applies the LTE predicate on the "cooled_until" field. +func CooledUntilLTE(v time.Time) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldCooledUntil, v)) +} + +// CooledUntilIsNil applies the IsNil predicate on the "cooled_until" field. +func CooledUntilIsNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIsNull(FieldCooledUntil)) +} + +// CooledUntilNotNil applies the NotNil predicate on the "cooled_until" field. +func CooledUntilNotNil() predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotNull(FieldCooledUntil)) +} + +// ImageEnabledEQ applies the EQ predicate on the "image_enabled" field. +func ImageEnabledEQ(v bool) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldImageEnabled, v)) +} + +// ImageEnabledNEQ applies the NEQ predicate on the "image_enabled" field. +func ImageEnabledNEQ(v bool) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldImageEnabled, v)) +} + +// VideoEnabledEQ applies the EQ predicate on the "video_enabled" field. +func VideoEnabledEQ(v bool) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldVideoEnabled, v)) +} + +// VideoEnabledNEQ applies the NEQ predicate on the "video_enabled" field. 
+func VideoEnabledNEQ(v bool) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldVideoEnabled, v)) +} + +// ImageConcurrencyEQ applies the EQ predicate on the "image_concurrency" field. +func ImageConcurrencyEQ(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldImageConcurrency, v)) +} + +// ImageConcurrencyNEQ applies the NEQ predicate on the "image_concurrency" field. +func ImageConcurrencyNEQ(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldImageConcurrency, v)) +} + +// ImageConcurrencyIn applies the In predicate on the "image_concurrency" field. +func ImageConcurrencyIn(vs ...int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldImageConcurrency, vs...)) +} + +// ImageConcurrencyNotIn applies the NotIn predicate on the "image_concurrency" field. +func ImageConcurrencyNotIn(vs ...int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldImageConcurrency, vs...)) +} + +// ImageConcurrencyGT applies the GT predicate on the "image_concurrency" field. +func ImageConcurrencyGT(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldImageConcurrency, v)) +} + +// ImageConcurrencyGTE applies the GTE predicate on the "image_concurrency" field. +func ImageConcurrencyGTE(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldImageConcurrency, v)) +} + +// ImageConcurrencyLT applies the LT predicate on the "image_concurrency" field. +func ImageConcurrencyLT(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldImageConcurrency, v)) +} + +// ImageConcurrencyLTE applies the LTE predicate on the "image_concurrency" field. +func ImageConcurrencyLTE(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldImageConcurrency, v)) +} + +// VideoConcurrencyEQ applies the EQ predicate on the "video_concurrency" field. 
+func VideoConcurrencyEQ(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldVideoConcurrency, v)) +} + +// VideoConcurrencyNEQ applies the NEQ predicate on the "video_concurrency" field. +func VideoConcurrencyNEQ(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldVideoConcurrency, v)) +} + +// VideoConcurrencyIn applies the In predicate on the "video_concurrency" field. +func VideoConcurrencyIn(vs ...int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldIn(FieldVideoConcurrency, vs...)) +} + +// VideoConcurrencyNotIn applies the NotIn predicate on the "video_concurrency" field. +func VideoConcurrencyNotIn(vs ...int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNotIn(FieldVideoConcurrency, vs...)) +} + +// VideoConcurrencyGT applies the GT predicate on the "video_concurrency" field. +func VideoConcurrencyGT(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGT(FieldVideoConcurrency, v)) +} + +// VideoConcurrencyGTE applies the GTE predicate on the "video_concurrency" field. +func VideoConcurrencyGTE(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldGTE(FieldVideoConcurrency, v)) +} + +// VideoConcurrencyLT applies the LT predicate on the "video_concurrency" field. +func VideoConcurrencyLT(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLT(FieldVideoConcurrency, v)) +} + +// VideoConcurrencyLTE applies the LTE predicate on the "video_concurrency" field. +func VideoConcurrencyLTE(v int) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldLTE(FieldVideoConcurrency, v)) +} + +// IsExpiredEQ applies the EQ predicate on the "is_expired" field. +func IsExpiredEQ(v bool) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldEQ(FieldIsExpired, v)) +} + +// IsExpiredNEQ applies the NEQ predicate on the "is_expired" field. 
+func IsExpiredNEQ(v bool) predicate.SoraAccount { + return predicate.SoraAccount(sql.FieldNEQ(FieldIsExpired, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.SoraAccount) predicate.SoraAccount { + return predicate.SoraAccount(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.SoraAccount) predicate.SoraAccount { + return predicate.SoraAccount(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.SoraAccount) predicate.SoraAccount { + return predicate.SoraAccount(sql.NotPredicates(p)) +} diff --git a/backend/ent/soraaccount_create.go b/backend/ent/soraaccount_create.go new file mode 100644 index 00000000..3aa03a1e --- /dev/null +++ b/backend/ent/soraaccount_create.go @@ -0,0 +1,2367 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/soraaccount" +) + +// SoraAccountCreate is the builder for creating a SoraAccount entity. +type SoraAccountCreate struct { + config + mutation *SoraAccountMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *SoraAccountCreate) SetCreatedAt(v time.Time) *SoraAccountCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableCreatedAt(v *time.Time) *SoraAccountCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. 
+func (_c *SoraAccountCreate) SetUpdatedAt(v time.Time) *SoraAccountCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableUpdatedAt(v *time.Time) *SoraAccountCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetAccountID sets the "account_id" field. +func (_c *SoraAccountCreate) SetAccountID(v int64) *SoraAccountCreate { + _c.mutation.SetAccountID(v) + return _c +} + +// SetAccessToken sets the "access_token" field. +func (_c *SoraAccountCreate) SetAccessToken(v string) *SoraAccountCreate { + _c.mutation.SetAccessToken(v) + return _c +} + +// SetNillableAccessToken sets the "access_token" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableAccessToken(v *string) *SoraAccountCreate { + if v != nil { + _c.SetAccessToken(*v) + } + return _c +} + +// SetSessionToken sets the "session_token" field. +func (_c *SoraAccountCreate) SetSessionToken(v string) *SoraAccountCreate { + _c.mutation.SetSessionToken(v) + return _c +} + +// SetNillableSessionToken sets the "session_token" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableSessionToken(v *string) *SoraAccountCreate { + if v != nil { + _c.SetSessionToken(*v) + } + return _c +} + +// SetRefreshToken sets the "refresh_token" field. +func (_c *SoraAccountCreate) SetRefreshToken(v string) *SoraAccountCreate { + _c.mutation.SetRefreshToken(v) + return _c +} + +// SetNillableRefreshToken sets the "refresh_token" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableRefreshToken(v *string) *SoraAccountCreate { + if v != nil { + _c.SetRefreshToken(*v) + } + return _c +} + +// SetClientID sets the "client_id" field. 
+func (_c *SoraAccountCreate) SetClientID(v string) *SoraAccountCreate { + _c.mutation.SetClientID(v) + return _c +} + +// SetNillableClientID sets the "client_id" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableClientID(v *string) *SoraAccountCreate { + if v != nil { + _c.SetClientID(*v) + } + return _c +} + +// SetEmail sets the "email" field. +func (_c *SoraAccountCreate) SetEmail(v string) *SoraAccountCreate { + _c.mutation.SetEmail(v) + return _c +} + +// SetNillableEmail sets the "email" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableEmail(v *string) *SoraAccountCreate { + if v != nil { + _c.SetEmail(*v) + } + return _c +} + +// SetUsername sets the "username" field. +func (_c *SoraAccountCreate) SetUsername(v string) *SoraAccountCreate { + _c.mutation.SetUsername(v) + return _c +} + +// SetNillableUsername sets the "username" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableUsername(v *string) *SoraAccountCreate { + if v != nil { + _c.SetUsername(*v) + } + return _c +} + +// SetRemark sets the "remark" field. +func (_c *SoraAccountCreate) SetRemark(v string) *SoraAccountCreate { + _c.mutation.SetRemark(v) + return _c +} + +// SetNillableRemark sets the "remark" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableRemark(v *string) *SoraAccountCreate { + if v != nil { + _c.SetRemark(*v) + } + return _c +} + +// SetUseCount sets the "use_count" field. +func (_c *SoraAccountCreate) SetUseCount(v int) *SoraAccountCreate { + _c.mutation.SetUseCount(v) + return _c +} + +// SetNillableUseCount sets the "use_count" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableUseCount(v *int) *SoraAccountCreate { + if v != nil { + _c.SetUseCount(*v) + } + return _c +} + +// SetPlanType sets the "plan_type" field. 
+func (_c *SoraAccountCreate) SetPlanType(v string) *SoraAccountCreate { + _c.mutation.SetPlanType(v) + return _c +} + +// SetNillablePlanType sets the "plan_type" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillablePlanType(v *string) *SoraAccountCreate { + if v != nil { + _c.SetPlanType(*v) + } + return _c +} + +// SetPlanTitle sets the "plan_title" field. +func (_c *SoraAccountCreate) SetPlanTitle(v string) *SoraAccountCreate { + _c.mutation.SetPlanTitle(v) + return _c +} + +// SetNillablePlanTitle sets the "plan_title" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillablePlanTitle(v *string) *SoraAccountCreate { + if v != nil { + _c.SetPlanTitle(*v) + } + return _c +} + +// SetSubscriptionEnd sets the "subscription_end" field. +func (_c *SoraAccountCreate) SetSubscriptionEnd(v time.Time) *SoraAccountCreate { + _c.mutation.SetSubscriptionEnd(v) + return _c +} + +// SetNillableSubscriptionEnd sets the "subscription_end" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableSubscriptionEnd(v *time.Time) *SoraAccountCreate { + if v != nil { + _c.SetSubscriptionEnd(*v) + } + return _c +} + +// SetSoraSupported sets the "sora_supported" field. +func (_c *SoraAccountCreate) SetSoraSupported(v bool) *SoraAccountCreate { + _c.mutation.SetSoraSupported(v) + return _c +} + +// SetNillableSoraSupported sets the "sora_supported" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableSoraSupported(v *bool) *SoraAccountCreate { + if v != nil { + _c.SetSoraSupported(*v) + } + return _c +} + +// SetSoraInviteCode sets the "sora_invite_code" field. +func (_c *SoraAccountCreate) SetSoraInviteCode(v string) *SoraAccountCreate { + _c.mutation.SetSoraInviteCode(v) + return _c +} + +// SetNillableSoraInviteCode sets the "sora_invite_code" field if the given value is not nil. 
+func (_c *SoraAccountCreate) SetNillableSoraInviteCode(v *string) *SoraAccountCreate { + if v != nil { + _c.SetSoraInviteCode(*v) + } + return _c +} + +// SetSoraRedeemedCount sets the "sora_redeemed_count" field. +func (_c *SoraAccountCreate) SetSoraRedeemedCount(v int) *SoraAccountCreate { + _c.mutation.SetSoraRedeemedCount(v) + return _c +} + +// SetNillableSoraRedeemedCount sets the "sora_redeemed_count" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableSoraRedeemedCount(v *int) *SoraAccountCreate { + if v != nil { + _c.SetSoraRedeemedCount(*v) + } + return _c +} + +// SetSoraRemainingCount sets the "sora_remaining_count" field. +func (_c *SoraAccountCreate) SetSoraRemainingCount(v int) *SoraAccountCreate { + _c.mutation.SetSoraRemainingCount(v) + return _c +} + +// SetNillableSoraRemainingCount sets the "sora_remaining_count" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableSoraRemainingCount(v *int) *SoraAccountCreate { + if v != nil { + _c.SetSoraRemainingCount(*v) + } + return _c +} + +// SetSoraTotalCount sets the "sora_total_count" field. +func (_c *SoraAccountCreate) SetSoraTotalCount(v int) *SoraAccountCreate { + _c.mutation.SetSoraTotalCount(v) + return _c +} + +// SetNillableSoraTotalCount sets the "sora_total_count" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableSoraTotalCount(v *int) *SoraAccountCreate { + if v != nil { + _c.SetSoraTotalCount(*v) + } + return _c +} + +// SetSoraCooldownUntil sets the "sora_cooldown_until" field. +func (_c *SoraAccountCreate) SetSoraCooldownUntil(v time.Time) *SoraAccountCreate { + _c.mutation.SetSoraCooldownUntil(v) + return _c +} + +// SetNillableSoraCooldownUntil sets the "sora_cooldown_until" field if the given value is not nil. 
+func (_c *SoraAccountCreate) SetNillableSoraCooldownUntil(v *time.Time) *SoraAccountCreate { + if v != nil { + _c.SetSoraCooldownUntil(*v) + } + return _c +} + +// SetCooledUntil sets the "cooled_until" field. +func (_c *SoraAccountCreate) SetCooledUntil(v time.Time) *SoraAccountCreate { + _c.mutation.SetCooledUntil(v) + return _c +} + +// SetNillableCooledUntil sets the "cooled_until" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableCooledUntil(v *time.Time) *SoraAccountCreate { + if v != nil { + _c.SetCooledUntil(*v) + } + return _c +} + +// SetImageEnabled sets the "image_enabled" field. +func (_c *SoraAccountCreate) SetImageEnabled(v bool) *SoraAccountCreate { + _c.mutation.SetImageEnabled(v) + return _c +} + +// SetNillableImageEnabled sets the "image_enabled" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableImageEnabled(v *bool) *SoraAccountCreate { + if v != nil { + _c.SetImageEnabled(*v) + } + return _c +} + +// SetVideoEnabled sets the "video_enabled" field. +func (_c *SoraAccountCreate) SetVideoEnabled(v bool) *SoraAccountCreate { + _c.mutation.SetVideoEnabled(v) + return _c +} + +// SetNillableVideoEnabled sets the "video_enabled" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableVideoEnabled(v *bool) *SoraAccountCreate { + if v != nil { + _c.SetVideoEnabled(*v) + } + return _c +} + +// SetImageConcurrency sets the "image_concurrency" field. +func (_c *SoraAccountCreate) SetImageConcurrency(v int) *SoraAccountCreate { + _c.mutation.SetImageConcurrency(v) + return _c +} + +// SetNillableImageConcurrency sets the "image_concurrency" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableImageConcurrency(v *int) *SoraAccountCreate { + if v != nil { + _c.SetImageConcurrency(*v) + } + return _c +} + +// SetVideoConcurrency sets the "video_concurrency" field. 
+func (_c *SoraAccountCreate) SetVideoConcurrency(v int) *SoraAccountCreate { + _c.mutation.SetVideoConcurrency(v) + return _c +} + +// SetNillableVideoConcurrency sets the "video_concurrency" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableVideoConcurrency(v *int) *SoraAccountCreate { + if v != nil { + _c.SetVideoConcurrency(*v) + } + return _c +} + +// SetIsExpired sets the "is_expired" field. +func (_c *SoraAccountCreate) SetIsExpired(v bool) *SoraAccountCreate { + _c.mutation.SetIsExpired(v) + return _c +} + +// SetNillableIsExpired sets the "is_expired" field if the given value is not nil. +func (_c *SoraAccountCreate) SetNillableIsExpired(v *bool) *SoraAccountCreate { + if v != nil { + _c.SetIsExpired(*v) + } + return _c +} + +// Mutation returns the SoraAccountMutation object of the builder. +func (_c *SoraAccountCreate) Mutation() *SoraAccountMutation { + return _c.mutation +} + +// Save creates the SoraAccount in the database. +func (_c *SoraAccountCreate) Save(ctx context.Context) (*SoraAccount, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *SoraAccountCreate) SaveX(ctx context.Context) *SoraAccount { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *SoraAccountCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *SoraAccountCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (_c *SoraAccountCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := soraaccount.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := soraaccount.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.UseCount(); !ok { + v := soraaccount.DefaultUseCount + _c.mutation.SetUseCount(v) + } + if _, ok := _c.mutation.SoraSupported(); !ok { + v := soraaccount.DefaultSoraSupported + _c.mutation.SetSoraSupported(v) + } + if _, ok := _c.mutation.SoraRedeemedCount(); !ok { + v := soraaccount.DefaultSoraRedeemedCount + _c.mutation.SetSoraRedeemedCount(v) + } + if _, ok := _c.mutation.SoraRemainingCount(); !ok { + v := soraaccount.DefaultSoraRemainingCount + _c.mutation.SetSoraRemainingCount(v) + } + if _, ok := _c.mutation.SoraTotalCount(); !ok { + v := soraaccount.DefaultSoraTotalCount + _c.mutation.SetSoraTotalCount(v) + } + if _, ok := _c.mutation.ImageEnabled(); !ok { + v := soraaccount.DefaultImageEnabled + _c.mutation.SetImageEnabled(v) + } + if _, ok := _c.mutation.VideoEnabled(); !ok { + v := soraaccount.DefaultVideoEnabled + _c.mutation.SetVideoEnabled(v) + } + if _, ok := _c.mutation.ImageConcurrency(); !ok { + v := soraaccount.DefaultImageConcurrency + _c.mutation.SetImageConcurrency(v) + } + if _, ok := _c.mutation.VideoConcurrency(); !ok { + v := soraaccount.DefaultVideoConcurrency + _c.mutation.SetVideoConcurrency(v) + } + if _, ok := _c.mutation.IsExpired(); !ok { + v := soraaccount.DefaultIsExpired + _c.mutation.SetIsExpired(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_c *SoraAccountCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "SoraAccount.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "SoraAccount.updated_at"`)} + } + if _, ok := _c.mutation.AccountID(); !ok { + return &ValidationError{Name: "account_id", err: errors.New(`ent: missing required field "SoraAccount.account_id"`)} + } + if _, ok := _c.mutation.UseCount(); !ok { + return &ValidationError{Name: "use_count", err: errors.New(`ent: missing required field "SoraAccount.use_count"`)} + } + if _, ok := _c.mutation.SoraSupported(); !ok { + return &ValidationError{Name: "sora_supported", err: errors.New(`ent: missing required field "SoraAccount.sora_supported"`)} + } + if _, ok := _c.mutation.SoraRedeemedCount(); !ok { + return &ValidationError{Name: "sora_redeemed_count", err: errors.New(`ent: missing required field "SoraAccount.sora_redeemed_count"`)} + } + if _, ok := _c.mutation.SoraRemainingCount(); !ok { + return &ValidationError{Name: "sora_remaining_count", err: errors.New(`ent: missing required field "SoraAccount.sora_remaining_count"`)} + } + if _, ok := _c.mutation.SoraTotalCount(); !ok { + return &ValidationError{Name: "sora_total_count", err: errors.New(`ent: missing required field "SoraAccount.sora_total_count"`)} + } + if _, ok := _c.mutation.ImageEnabled(); !ok { + return &ValidationError{Name: "image_enabled", err: errors.New(`ent: missing required field "SoraAccount.image_enabled"`)} + } + if _, ok := _c.mutation.VideoEnabled(); !ok { + return &ValidationError{Name: "video_enabled", err: errors.New(`ent: missing required field "SoraAccount.video_enabled"`)} + } + if _, ok := _c.mutation.ImageConcurrency(); !ok { + return &ValidationError{Name: "image_concurrency", err: errors.New(`ent: missing required field 
"SoraAccount.image_concurrency"`)} + } + if _, ok := _c.mutation.VideoConcurrency(); !ok { + return &ValidationError{Name: "video_concurrency", err: errors.New(`ent: missing required field "SoraAccount.video_concurrency"`)} + } + if _, ok := _c.mutation.IsExpired(); !ok { + return &ValidationError{Name: "is_expired", err: errors.New(`ent: missing required field "SoraAccount.is_expired"`)} + } + return nil +} + +func (_c *SoraAccountCreate) sqlSave(ctx context.Context) (*SoraAccount, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *SoraAccountCreate) createSpec() (*SoraAccount, *sqlgraph.CreateSpec) { + var ( + _node = &SoraAccount{config: _c.config} + _spec = sqlgraph.NewCreateSpec(soraaccount.Table, sqlgraph.NewFieldSpec(soraaccount.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(soraaccount.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(soraaccount.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.AccountID(); ok { + _spec.SetField(soraaccount.FieldAccountID, field.TypeInt64, value) + _node.AccountID = value + } + if value, ok := _c.mutation.AccessToken(); ok { + _spec.SetField(soraaccount.FieldAccessToken, field.TypeString, value) + _node.AccessToken = &value + } + if value, ok := _c.mutation.SessionToken(); ok { + _spec.SetField(soraaccount.FieldSessionToken, field.TypeString, value) + _node.SessionToken = &value + } + if value, ok := _c.mutation.RefreshToken(); ok { + 
_spec.SetField(soraaccount.FieldRefreshToken, field.TypeString, value) + _node.RefreshToken = &value + } + if value, ok := _c.mutation.ClientID(); ok { + _spec.SetField(soraaccount.FieldClientID, field.TypeString, value) + _node.ClientID = &value + } + if value, ok := _c.mutation.Email(); ok { + _spec.SetField(soraaccount.FieldEmail, field.TypeString, value) + _node.Email = &value + } + if value, ok := _c.mutation.Username(); ok { + _spec.SetField(soraaccount.FieldUsername, field.TypeString, value) + _node.Username = &value + } + if value, ok := _c.mutation.Remark(); ok { + _spec.SetField(soraaccount.FieldRemark, field.TypeString, value) + _node.Remark = &value + } + if value, ok := _c.mutation.UseCount(); ok { + _spec.SetField(soraaccount.FieldUseCount, field.TypeInt, value) + _node.UseCount = value + } + if value, ok := _c.mutation.PlanType(); ok { + _spec.SetField(soraaccount.FieldPlanType, field.TypeString, value) + _node.PlanType = &value + } + if value, ok := _c.mutation.PlanTitle(); ok { + _spec.SetField(soraaccount.FieldPlanTitle, field.TypeString, value) + _node.PlanTitle = &value + } + if value, ok := _c.mutation.SubscriptionEnd(); ok { + _spec.SetField(soraaccount.FieldSubscriptionEnd, field.TypeTime, value) + _node.SubscriptionEnd = &value + } + if value, ok := _c.mutation.SoraSupported(); ok { + _spec.SetField(soraaccount.FieldSoraSupported, field.TypeBool, value) + _node.SoraSupported = value + } + if value, ok := _c.mutation.SoraInviteCode(); ok { + _spec.SetField(soraaccount.FieldSoraInviteCode, field.TypeString, value) + _node.SoraInviteCode = &value + } + if value, ok := _c.mutation.SoraRedeemedCount(); ok { + _spec.SetField(soraaccount.FieldSoraRedeemedCount, field.TypeInt, value) + _node.SoraRedeemedCount = value + } + if value, ok := _c.mutation.SoraRemainingCount(); ok { + _spec.SetField(soraaccount.FieldSoraRemainingCount, field.TypeInt, value) + _node.SoraRemainingCount = value + } + if value, ok := _c.mutation.SoraTotalCount(); ok { + 
_spec.SetField(soraaccount.FieldSoraTotalCount, field.TypeInt, value) + _node.SoraTotalCount = value + } + if value, ok := _c.mutation.SoraCooldownUntil(); ok { + _spec.SetField(soraaccount.FieldSoraCooldownUntil, field.TypeTime, value) + _node.SoraCooldownUntil = &value + } + if value, ok := _c.mutation.CooledUntil(); ok { + _spec.SetField(soraaccount.FieldCooledUntil, field.TypeTime, value) + _node.CooledUntil = &value + } + if value, ok := _c.mutation.ImageEnabled(); ok { + _spec.SetField(soraaccount.FieldImageEnabled, field.TypeBool, value) + _node.ImageEnabled = value + } + if value, ok := _c.mutation.VideoEnabled(); ok { + _spec.SetField(soraaccount.FieldVideoEnabled, field.TypeBool, value) + _node.VideoEnabled = value + } + if value, ok := _c.mutation.ImageConcurrency(); ok { + _spec.SetField(soraaccount.FieldImageConcurrency, field.TypeInt, value) + _node.ImageConcurrency = value + } + if value, ok := _c.mutation.VideoConcurrency(); ok { + _spec.SetField(soraaccount.FieldVideoConcurrency, field.TypeInt, value) + _node.VideoConcurrency = value + } + if value, ok := _c.mutation.IsExpired(); ok { + _spec.SetField(soraaccount.FieldIsExpired, field.TypeBool, value) + _node.IsExpired = value + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.SoraAccount.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.SoraAccountUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *SoraAccountCreate) OnConflict(opts ...sql.ConflictOption) *SoraAccountUpsertOne { + _c.conflict = opts + return &SoraAccountUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. 
Using this option is equivalent to using: +// +// client.SoraAccount.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *SoraAccountCreate) OnConflictColumns(columns ...string) *SoraAccountUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &SoraAccountUpsertOne{ + create: _c, + } +} + +type ( + // SoraAccountUpsertOne is the builder for "upsert"-ing + // one SoraAccount node. + SoraAccountUpsertOne struct { + create *SoraAccountCreate + } + + // SoraAccountUpsert is the "OnConflict" setter. + SoraAccountUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *SoraAccountUpsert) SetUpdatedAt(v time.Time) *SoraAccountUpsert { + u.Set(soraaccount.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateUpdatedAt() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldUpdatedAt) + return u +} + +// SetAccountID sets the "account_id" field. +func (u *SoraAccountUpsert) SetAccountID(v int64) *SoraAccountUpsert { + u.Set(soraaccount.FieldAccountID, v) + return u +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateAccountID() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldAccountID) + return u +} + +// AddAccountID adds v to the "account_id" field. +func (u *SoraAccountUpsert) AddAccountID(v int64) *SoraAccountUpsert { + u.Add(soraaccount.FieldAccountID, v) + return u +} + +// SetAccessToken sets the "access_token" field. +func (u *SoraAccountUpsert) SetAccessToken(v string) *SoraAccountUpsert { + u.Set(soraaccount.FieldAccessToken, v) + return u +} + +// UpdateAccessToken sets the "access_token" field to the value that was provided on create. 
+func (u *SoraAccountUpsert) UpdateAccessToken() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldAccessToken) + return u +} + +// ClearAccessToken clears the value of the "access_token" field. +func (u *SoraAccountUpsert) ClearAccessToken() *SoraAccountUpsert { + u.SetNull(soraaccount.FieldAccessToken) + return u +} + +// SetSessionToken sets the "session_token" field. +func (u *SoraAccountUpsert) SetSessionToken(v string) *SoraAccountUpsert { + u.Set(soraaccount.FieldSessionToken, v) + return u +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateSessionToken() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldSessionToken) + return u +} + +// ClearSessionToken clears the value of the "session_token" field. +func (u *SoraAccountUpsert) ClearSessionToken() *SoraAccountUpsert { + u.SetNull(soraaccount.FieldSessionToken) + return u +} + +// SetRefreshToken sets the "refresh_token" field. +func (u *SoraAccountUpsert) SetRefreshToken(v string) *SoraAccountUpsert { + u.Set(soraaccount.FieldRefreshToken, v) + return u +} + +// UpdateRefreshToken sets the "refresh_token" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateRefreshToken() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldRefreshToken) + return u +} + +// ClearRefreshToken clears the value of the "refresh_token" field. +func (u *SoraAccountUpsert) ClearRefreshToken() *SoraAccountUpsert { + u.SetNull(soraaccount.FieldRefreshToken) + return u +} + +// SetClientID sets the "client_id" field. +func (u *SoraAccountUpsert) SetClientID(v string) *SoraAccountUpsert { + u.Set(soraaccount.FieldClientID, v) + return u +} + +// UpdateClientID sets the "client_id" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateClientID() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldClientID) + return u +} + +// ClearClientID clears the value of the "client_id" field. 
+func (u *SoraAccountUpsert) ClearClientID() *SoraAccountUpsert { + u.SetNull(soraaccount.FieldClientID) + return u +} + +// SetEmail sets the "email" field. +func (u *SoraAccountUpsert) SetEmail(v string) *SoraAccountUpsert { + u.Set(soraaccount.FieldEmail, v) + return u +} + +// UpdateEmail sets the "email" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateEmail() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldEmail) + return u +} + +// ClearEmail clears the value of the "email" field. +func (u *SoraAccountUpsert) ClearEmail() *SoraAccountUpsert { + u.SetNull(soraaccount.FieldEmail) + return u +} + +// SetUsername sets the "username" field. +func (u *SoraAccountUpsert) SetUsername(v string) *SoraAccountUpsert { + u.Set(soraaccount.FieldUsername, v) + return u +} + +// UpdateUsername sets the "username" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateUsername() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldUsername) + return u +} + +// ClearUsername clears the value of the "username" field. +func (u *SoraAccountUpsert) ClearUsername() *SoraAccountUpsert { + u.SetNull(soraaccount.FieldUsername) + return u +} + +// SetRemark sets the "remark" field. +func (u *SoraAccountUpsert) SetRemark(v string) *SoraAccountUpsert { + u.Set(soraaccount.FieldRemark, v) + return u +} + +// UpdateRemark sets the "remark" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateRemark() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldRemark) + return u +} + +// ClearRemark clears the value of the "remark" field. +func (u *SoraAccountUpsert) ClearRemark() *SoraAccountUpsert { + u.SetNull(soraaccount.FieldRemark) + return u +} + +// SetUseCount sets the "use_count" field. 
+func (u *SoraAccountUpsert) SetUseCount(v int) *SoraAccountUpsert { + u.Set(soraaccount.FieldUseCount, v) + return u +} + +// UpdateUseCount sets the "use_count" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateUseCount() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldUseCount) + return u +} + +// AddUseCount adds v to the "use_count" field. +func (u *SoraAccountUpsert) AddUseCount(v int) *SoraAccountUpsert { + u.Add(soraaccount.FieldUseCount, v) + return u +} + +// SetPlanType sets the "plan_type" field. +func (u *SoraAccountUpsert) SetPlanType(v string) *SoraAccountUpsert { + u.Set(soraaccount.FieldPlanType, v) + return u +} + +// UpdatePlanType sets the "plan_type" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdatePlanType() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldPlanType) + return u +} + +// ClearPlanType clears the value of the "plan_type" field. +func (u *SoraAccountUpsert) ClearPlanType() *SoraAccountUpsert { + u.SetNull(soraaccount.FieldPlanType) + return u +} + +// SetPlanTitle sets the "plan_title" field. +func (u *SoraAccountUpsert) SetPlanTitle(v string) *SoraAccountUpsert { + u.Set(soraaccount.FieldPlanTitle, v) + return u +} + +// UpdatePlanTitle sets the "plan_title" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdatePlanTitle() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldPlanTitle) + return u +} + +// ClearPlanTitle clears the value of the "plan_title" field. +func (u *SoraAccountUpsert) ClearPlanTitle() *SoraAccountUpsert { + u.SetNull(soraaccount.FieldPlanTitle) + return u +} + +// SetSubscriptionEnd sets the "subscription_end" field. +func (u *SoraAccountUpsert) SetSubscriptionEnd(v time.Time) *SoraAccountUpsert { + u.Set(soraaccount.FieldSubscriptionEnd, v) + return u +} + +// UpdateSubscriptionEnd sets the "subscription_end" field to the value that was provided on create. 
+func (u *SoraAccountUpsert) UpdateSubscriptionEnd() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldSubscriptionEnd) + return u +} + +// ClearSubscriptionEnd clears the value of the "subscription_end" field. +func (u *SoraAccountUpsert) ClearSubscriptionEnd() *SoraAccountUpsert { + u.SetNull(soraaccount.FieldSubscriptionEnd) + return u +} + +// SetSoraSupported sets the "sora_supported" field. +func (u *SoraAccountUpsert) SetSoraSupported(v bool) *SoraAccountUpsert { + u.Set(soraaccount.FieldSoraSupported, v) + return u +} + +// UpdateSoraSupported sets the "sora_supported" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateSoraSupported() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldSoraSupported) + return u +} + +// SetSoraInviteCode sets the "sora_invite_code" field. +func (u *SoraAccountUpsert) SetSoraInviteCode(v string) *SoraAccountUpsert { + u.Set(soraaccount.FieldSoraInviteCode, v) + return u +} + +// UpdateSoraInviteCode sets the "sora_invite_code" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateSoraInviteCode() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldSoraInviteCode) + return u +} + +// ClearSoraInviteCode clears the value of the "sora_invite_code" field. +func (u *SoraAccountUpsert) ClearSoraInviteCode() *SoraAccountUpsert { + u.SetNull(soraaccount.FieldSoraInviteCode) + return u +} + +// SetSoraRedeemedCount sets the "sora_redeemed_count" field. +func (u *SoraAccountUpsert) SetSoraRedeemedCount(v int) *SoraAccountUpsert { + u.Set(soraaccount.FieldSoraRedeemedCount, v) + return u +} + +// UpdateSoraRedeemedCount sets the "sora_redeemed_count" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateSoraRedeemedCount() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldSoraRedeemedCount) + return u +} + +// AddSoraRedeemedCount adds v to the "sora_redeemed_count" field. 
+func (u *SoraAccountUpsert) AddSoraRedeemedCount(v int) *SoraAccountUpsert { + u.Add(soraaccount.FieldSoraRedeemedCount, v) + return u +} + +// SetSoraRemainingCount sets the "sora_remaining_count" field. +func (u *SoraAccountUpsert) SetSoraRemainingCount(v int) *SoraAccountUpsert { + u.Set(soraaccount.FieldSoraRemainingCount, v) + return u +} + +// UpdateSoraRemainingCount sets the "sora_remaining_count" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateSoraRemainingCount() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldSoraRemainingCount) + return u +} + +// AddSoraRemainingCount adds v to the "sora_remaining_count" field. +func (u *SoraAccountUpsert) AddSoraRemainingCount(v int) *SoraAccountUpsert { + u.Add(soraaccount.FieldSoraRemainingCount, v) + return u +} + +// SetSoraTotalCount sets the "sora_total_count" field. +func (u *SoraAccountUpsert) SetSoraTotalCount(v int) *SoraAccountUpsert { + u.Set(soraaccount.FieldSoraTotalCount, v) + return u +} + +// UpdateSoraTotalCount sets the "sora_total_count" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateSoraTotalCount() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldSoraTotalCount) + return u +} + +// AddSoraTotalCount adds v to the "sora_total_count" field. +func (u *SoraAccountUpsert) AddSoraTotalCount(v int) *SoraAccountUpsert { + u.Add(soraaccount.FieldSoraTotalCount, v) + return u +} + +// SetSoraCooldownUntil sets the "sora_cooldown_until" field. +func (u *SoraAccountUpsert) SetSoraCooldownUntil(v time.Time) *SoraAccountUpsert { + u.Set(soraaccount.FieldSoraCooldownUntil, v) + return u +} + +// UpdateSoraCooldownUntil sets the "sora_cooldown_until" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateSoraCooldownUntil() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldSoraCooldownUntil) + return u +} + +// ClearSoraCooldownUntil clears the value of the "sora_cooldown_until" field. 
+func (u *SoraAccountUpsert) ClearSoraCooldownUntil() *SoraAccountUpsert { + u.SetNull(soraaccount.FieldSoraCooldownUntil) + return u +} + +// SetCooledUntil sets the "cooled_until" field. +func (u *SoraAccountUpsert) SetCooledUntil(v time.Time) *SoraAccountUpsert { + u.Set(soraaccount.FieldCooledUntil, v) + return u +} + +// UpdateCooledUntil sets the "cooled_until" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateCooledUntil() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldCooledUntil) + return u +} + +// ClearCooledUntil clears the value of the "cooled_until" field. +func (u *SoraAccountUpsert) ClearCooledUntil() *SoraAccountUpsert { + u.SetNull(soraaccount.FieldCooledUntil) + return u +} + +// SetImageEnabled sets the "image_enabled" field. +func (u *SoraAccountUpsert) SetImageEnabled(v bool) *SoraAccountUpsert { + u.Set(soraaccount.FieldImageEnabled, v) + return u +} + +// UpdateImageEnabled sets the "image_enabled" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateImageEnabled() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldImageEnabled) + return u +} + +// SetVideoEnabled sets the "video_enabled" field. +func (u *SoraAccountUpsert) SetVideoEnabled(v bool) *SoraAccountUpsert { + u.Set(soraaccount.FieldVideoEnabled, v) + return u +} + +// UpdateVideoEnabled sets the "video_enabled" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateVideoEnabled() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldVideoEnabled) + return u +} + +// SetImageConcurrency sets the "image_concurrency" field. +func (u *SoraAccountUpsert) SetImageConcurrency(v int) *SoraAccountUpsert { + u.Set(soraaccount.FieldImageConcurrency, v) + return u +} + +// UpdateImageConcurrency sets the "image_concurrency" field to the value that was provided on create. 
+func (u *SoraAccountUpsert) UpdateImageConcurrency() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldImageConcurrency) + return u +} + +// AddImageConcurrency adds v to the "image_concurrency" field. +func (u *SoraAccountUpsert) AddImageConcurrency(v int) *SoraAccountUpsert { + u.Add(soraaccount.FieldImageConcurrency, v) + return u +} + +// SetVideoConcurrency sets the "video_concurrency" field. +func (u *SoraAccountUpsert) SetVideoConcurrency(v int) *SoraAccountUpsert { + u.Set(soraaccount.FieldVideoConcurrency, v) + return u +} + +// UpdateVideoConcurrency sets the "video_concurrency" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateVideoConcurrency() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldVideoConcurrency) + return u +} + +// AddVideoConcurrency adds v to the "video_concurrency" field. +func (u *SoraAccountUpsert) AddVideoConcurrency(v int) *SoraAccountUpsert { + u.Add(soraaccount.FieldVideoConcurrency, v) + return u +} + +// SetIsExpired sets the "is_expired" field. +func (u *SoraAccountUpsert) SetIsExpired(v bool) *SoraAccountUpsert { + u.Set(soraaccount.FieldIsExpired, v) + return u +} + +// UpdateIsExpired sets the "is_expired" field to the value that was provided on create. +func (u *SoraAccountUpsert) UpdateIsExpired() *SoraAccountUpsert { + u.SetExcluded(soraaccount.FieldIsExpired) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.SoraAccount.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). 
+// Exec(ctx) +func (u *SoraAccountUpsertOne) UpdateNewValues() *SoraAccountUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(soraaccount.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.SoraAccount.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *SoraAccountUpsertOne) Ignore() *SoraAccountUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *SoraAccountUpsertOne) DoNothing() *SoraAccountUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the SoraAccountCreate.OnConflict +// documentation for more info. +func (u *SoraAccountUpsertOne) Update(set func(*SoraAccountUpsert)) *SoraAccountUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&SoraAccountUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *SoraAccountUpsertOne) SetUpdatedAt(v time.Time) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateUpdatedAt() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetAccountID sets the "account_id" field. 
+func (u *SoraAccountUpsertOne) SetAccountID(v int64) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetAccountID(v) + }) +} + +// AddAccountID adds v to the "account_id" field. +func (u *SoraAccountUpsertOne) AddAccountID(v int64) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.AddAccountID(v) + }) +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateAccountID() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateAccountID() + }) +} + +// SetAccessToken sets the "access_token" field. +func (u *SoraAccountUpsertOne) SetAccessToken(v string) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetAccessToken(v) + }) +} + +// UpdateAccessToken sets the "access_token" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateAccessToken() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateAccessToken() + }) +} + +// ClearAccessToken clears the value of the "access_token" field. +func (u *SoraAccountUpsertOne) ClearAccessToken() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearAccessToken() + }) +} + +// SetSessionToken sets the "session_token" field. +func (u *SoraAccountUpsertOne) SetSessionToken(v string) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetSessionToken(v) + }) +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateSessionToken() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateSessionToken() + }) +} + +// ClearSessionToken clears the value of the "session_token" field. 
+func (u *SoraAccountUpsertOne) ClearSessionToken() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearSessionToken() + }) +} + +// SetRefreshToken sets the "refresh_token" field. +func (u *SoraAccountUpsertOne) SetRefreshToken(v string) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetRefreshToken(v) + }) +} + +// UpdateRefreshToken sets the "refresh_token" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateRefreshToken() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateRefreshToken() + }) +} + +// ClearRefreshToken clears the value of the "refresh_token" field. +func (u *SoraAccountUpsertOne) ClearRefreshToken() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearRefreshToken() + }) +} + +// SetClientID sets the "client_id" field. +func (u *SoraAccountUpsertOne) SetClientID(v string) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetClientID(v) + }) +} + +// UpdateClientID sets the "client_id" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateClientID() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateClientID() + }) +} + +// ClearClientID clears the value of the "client_id" field. +func (u *SoraAccountUpsertOne) ClearClientID() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearClientID() + }) +} + +// SetEmail sets the "email" field. +func (u *SoraAccountUpsertOne) SetEmail(v string) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetEmail(v) + }) +} + +// UpdateEmail sets the "email" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateEmail() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateEmail() + }) +} + +// ClearEmail clears the value of the "email" field. 
+func (u *SoraAccountUpsertOne) ClearEmail() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearEmail() + }) +} + +// SetUsername sets the "username" field. +func (u *SoraAccountUpsertOne) SetUsername(v string) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetUsername(v) + }) +} + +// UpdateUsername sets the "username" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateUsername() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateUsername() + }) +} + +// ClearUsername clears the value of the "username" field. +func (u *SoraAccountUpsertOne) ClearUsername() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearUsername() + }) +} + +// SetRemark sets the "remark" field. +func (u *SoraAccountUpsertOne) SetRemark(v string) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetRemark(v) + }) +} + +// UpdateRemark sets the "remark" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateRemark() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateRemark() + }) +} + +// ClearRemark clears the value of the "remark" field. +func (u *SoraAccountUpsertOne) ClearRemark() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearRemark() + }) +} + +// SetUseCount sets the "use_count" field. +func (u *SoraAccountUpsertOne) SetUseCount(v int) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetUseCount(v) + }) +} + +// AddUseCount adds v to the "use_count" field. +func (u *SoraAccountUpsertOne) AddUseCount(v int) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.AddUseCount(v) + }) +} + +// UpdateUseCount sets the "use_count" field to the value that was provided on create. 
+func (u *SoraAccountUpsertOne) UpdateUseCount() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateUseCount() + }) +} + +// SetPlanType sets the "plan_type" field. +func (u *SoraAccountUpsertOne) SetPlanType(v string) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetPlanType(v) + }) +} + +// UpdatePlanType sets the "plan_type" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdatePlanType() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdatePlanType() + }) +} + +// ClearPlanType clears the value of the "plan_type" field. +func (u *SoraAccountUpsertOne) ClearPlanType() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearPlanType() + }) +} + +// SetPlanTitle sets the "plan_title" field. +func (u *SoraAccountUpsertOne) SetPlanTitle(v string) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetPlanTitle(v) + }) +} + +// UpdatePlanTitle sets the "plan_title" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdatePlanTitle() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdatePlanTitle() + }) +} + +// ClearPlanTitle clears the value of the "plan_title" field. +func (u *SoraAccountUpsertOne) ClearPlanTitle() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearPlanTitle() + }) +} + +// SetSubscriptionEnd sets the "subscription_end" field. +func (u *SoraAccountUpsertOne) SetSubscriptionEnd(v time.Time) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetSubscriptionEnd(v) + }) +} + +// UpdateSubscriptionEnd sets the "subscription_end" field to the value that was provided on create. 
+func (u *SoraAccountUpsertOne) UpdateSubscriptionEnd() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateSubscriptionEnd() + }) +} + +// ClearSubscriptionEnd clears the value of the "subscription_end" field. +func (u *SoraAccountUpsertOne) ClearSubscriptionEnd() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearSubscriptionEnd() + }) +} + +// SetSoraSupported sets the "sora_supported" field. +func (u *SoraAccountUpsertOne) SetSoraSupported(v bool) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetSoraSupported(v) + }) +} + +// UpdateSoraSupported sets the "sora_supported" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateSoraSupported() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateSoraSupported() + }) +} + +// SetSoraInviteCode sets the "sora_invite_code" field. +func (u *SoraAccountUpsertOne) SetSoraInviteCode(v string) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetSoraInviteCode(v) + }) +} + +// UpdateSoraInviteCode sets the "sora_invite_code" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateSoraInviteCode() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateSoraInviteCode() + }) +} + +// ClearSoraInviteCode clears the value of the "sora_invite_code" field. +func (u *SoraAccountUpsertOne) ClearSoraInviteCode() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearSoraInviteCode() + }) +} + +// SetSoraRedeemedCount sets the "sora_redeemed_count" field. +func (u *SoraAccountUpsertOne) SetSoraRedeemedCount(v int) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetSoraRedeemedCount(v) + }) +} + +// AddSoraRedeemedCount adds v to the "sora_redeemed_count" field. 
+func (u *SoraAccountUpsertOne) AddSoraRedeemedCount(v int) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.AddSoraRedeemedCount(v) + }) +} + +// UpdateSoraRedeemedCount sets the "sora_redeemed_count" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateSoraRedeemedCount() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateSoraRedeemedCount() + }) +} + +// SetSoraRemainingCount sets the "sora_remaining_count" field. +func (u *SoraAccountUpsertOne) SetSoraRemainingCount(v int) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetSoraRemainingCount(v) + }) +} + +// AddSoraRemainingCount adds v to the "sora_remaining_count" field. +func (u *SoraAccountUpsertOne) AddSoraRemainingCount(v int) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.AddSoraRemainingCount(v) + }) +} + +// UpdateSoraRemainingCount sets the "sora_remaining_count" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateSoraRemainingCount() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateSoraRemainingCount() + }) +} + +// SetSoraTotalCount sets the "sora_total_count" field. +func (u *SoraAccountUpsertOne) SetSoraTotalCount(v int) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetSoraTotalCount(v) + }) +} + +// AddSoraTotalCount adds v to the "sora_total_count" field. +func (u *SoraAccountUpsertOne) AddSoraTotalCount(v int) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.AddSoraTotalCount(v) + }) +} + +// UpdateSoraTotalCount sets the "sora_total_count" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateSoraTotalCount() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateSoraTotalCount() + }) +} + +// SetSoraCooldownUntil sets the "sora_cooldown_until" field. 
+func (u *SoraAccountUpsertOne) SetSoraCooldownUntil(v time.Time) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetSoraCooldownUntil(v) + }) +} + +// UpdateSoraCooldownUntil sets the "sora_cooldown_until" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateSoraCooldownUntil() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateSoraCooldownUntil() + }) +} + +// ClearSoraCooldownUntil clears the value of the "sora_cooldown_until" field. +func (u *SoraAccountUpsertOne) ClearSoraCooldownUntil() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearSoraCooldownUntil() + }) +} + +// SetCooledUntil sets the "cooled_until" field. +func (u *SoraAccountUpsertOne) SetCooledUntil(v time.Time) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetCooledUntil(v) + }) +} + +// UpdateCooledUntil sets the "cooled_until" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateCooledUntil() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateCooledUntil() + }) +} + +// ClearCooledUntil clears the value of the "cooled_until" field. +func (u *SoraAccountUpsertOne) ClearCooledUntil() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearCooledUntil() + }) +} + +// SetImageEnabled sets the "image_enabled" field. +func (u *SoraAccountUpsertOne) SetImageEnabled(v bool) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetImageEnabled(v) + }) +} + +// UpdateImageEnabled sets the "image_enabled" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateImageEnabled() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateImageEnabled() + }) +} + +// SetVideoEnabled sets the "video_enabled" field. 
+func (u *SoraAccountUpsertOne) SetVideoEnabled(v bool) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetVideoEnabled(v) + }) +} + +// UpdateVideoEnabled sets the "video_enabled" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateVideoEnabled() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateVideoEnabled() + }) +} + +// SetImageConcurrency sets the "image_concurrency" field. +func (u *SoraAccountUpsertOne) SetImageConcurrency(v int) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetImageConcurrency(v) + }) +} + +// AddImageConcurrency adds v to the "image_concurrency" field. +func (u *SoraAccountUpsertOne) AddImageConcurrency(v int) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.AddImageConcurrency(v) + }) +} + +// UpdateImageConcurrency sets the "image_concurrency" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateImageConcurrency() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateImageConcurrency() + }) +} + +// SetVideoConcurrency sets the "video_concurrency" field. +func (u *SoraAccountUpsertOne) SetVideoConcurrency(v int) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetVideoConcurrency(v) + }) +} + +// AddVideoConcurrency adds v to the "video_concurrency" field. +func (u *SoraAccountUpsertOne) AddVideoConcurrency(v int) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.AddVideoConcurrency(v) + }) +} + +// UpdateVideoConcurrency sets the "video_concurrency" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateVideoConcurrency() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateVideoConcurrency() + }) +} + +// SetIsExpired sets the "is_expired" field. 
+func (u *SoraAccountUpsertOne) SetIsExpired(v bool) *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.SetIsExpired(v) + }) +} + +// UpdateIsExpired sets the "is_expired" field to the value that was provided on create. +func (u *SoraAccountUpsertOne) UpdateIsExpired() *SoraAccountUpsertOne { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateIsExpired() + }) +} + +// Exec executes the query. +func (u *SoraAccountUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for SoraAccountCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *SoraAccountUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *SoraAccountUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *SoraAccountUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// SoraAccountCreateBulk is the builder for creating many SoraAccount entities in bulk. +type SoraAccountCreateBulk struct { + config + err error + builders []*SoraAccountCreate + conflict []sql.ConflictOption +} + +// Save creates the SoraAccount entities in the database. 
+func (_c *SoraAccountCreateBulk) Save(ctx context.Context) ([]*SoraAccount, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*SoraAccount, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*SoraAccountMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *SoraAccountCreateBulk) SaveX(ctx context.Context) []*SoraAccount { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. 
+func (_c *SoraAccountCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *SoraAccountCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.SoraAccount.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.SoraAccountUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *SoraAccountCreateBulk) OnConflict(opts ...sql.ConflictOption) *SoraAccountUpsertBulk { + _c.conflict = opts + return &SoraAccountUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.SoraAccount.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *SoraAccountCreateBulk) OnConflictColumns(columns ...string) *SoraAccountUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &SoraAccountUpsertBulk{ + create: _c, + } +} + +// SoraAccountUpsertBulk is the builder for "upsert"-ing +// a bulk of SoraAccount nodes. +type SoraAccountUpsertBulk struct { + create *SoraAccountCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.SoraAccount.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). 
+// Exec(ctx) +func (u *SoraAccountUpsertBulk) UpdateNewValues() *SoraAccountUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(soraaccount.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.SoraAccount.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *SoraAccountUpsertBulk) Ignore() *SoraAccountUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *SoraAccountUpsertBulk) DoNothing() *SoraAccountUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the SoraAccountCreateBulk.OnConflict +// documentation for more info. +func (u *SoraAccountUpsertBulk) Update(set func(*SoraAccountUpsert)) *SoraAccountUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&SoraAccountUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *SoraAccountUpsertBulk) SetUpdatedAt(v time.Time) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateUpdatedAt() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetAccountID sets the "account_id" field. 
+func (u *SoraAccountUpsertBulk) SetAccountID(v int64) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetAccountID(v) + }) +} + +// AddAccountID adds v to the "account_id" field. +func (u *SoraAccountUpsertBulk) AddAccountID(v int64) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.AddAccountID(v) + }) +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateAccountID() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateAccountID() + }) +} + +// SetAccessToken sets the "access_token" field. +func (u *SoraAccountUpsertBulk) SetAccessToken(v string) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetAccessToken(v) + }) +} + +// UpdateAccessToken sets the "access_token" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateAccessToken() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateAccessToken() + }) +} + +// ClearAccessToken clears the value of the "access_token" field. +func (u *SoraAccountUpsertBulk) ClearAccessToken() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearAccessToken() + }) +} + +// SetSessionToken sets the "session_token" field. +func (u *SoraAccountUpsertBulk) SetSessionToken(v string) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetSessionToken(v) + }) +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateSessionToken() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateSessionToken() + }) +} + +// ClearSessionToken clears the value of the "session_token" field. 
+func (u *SoraAccountUpsertBulk) ClearSessionToken() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearSessionToken() + }) +} + +// SetRefreshToken sets the "refresh_token" field. +func (u *SoraAccountUpsertBulk) SetRefreshToken(v string) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetRefreshToken(v) + }) +} + +// UpdateRefreshToken sets the "refresh_token" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateRefreshToken() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateRefreshToken() + }) +} + +// ClearRefreshToken clears the value of the "refresh_token" field. +func (u *SoraAccountUpsertBulk) ClearRefreshToken() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearRefreshToken() + }) +} + +// SetClientID sets the "client_id" field. +func (u *SoraAccountUpsertBulk) SetClientID(v string) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetClientID(v) + }) +} + +// UpdateClientID sets the "client_id" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateClientID() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateClientID() + }) +} + +// ClearClientID clears the value of the "client_id" field. +func (u *SoraAccountUpsertBulk) ClearClientID() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearClientID() + }) +} + +// SetEmail sets the "email" field. +func (u *SoraAccountUpsertBulk) SetEmail(v string) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetEmail(v) + }) +} + +// UpdateEmail sets the "email" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateEmail() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateEmail() + }) +} + +// ClearEmail clears the value of the "email" field. 
+func (u *SoraAccountUpsertBulk) ClearEmail() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearEmail() + }) +} + +// SetUsername sets the "username" field. +func (u *SoraAccountUpsertBulk) SetUsername(v string) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetUsername(v) + }) +} + +// UpdateUsername sets the "username" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateUsername() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateUsername() + }) +} + +// ClearUsername clears the value of the "username" field. +func (u *SoraAccountUpsertBulk) ClearUsername() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearUsername() + }) +} + +// SetRemark sets the "remark" field. +func (u *SoraAccountUpsertBulk) SetRemark(v string) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetRemark(v) + }) +} + +// UpdateRemark sets the "remark" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateRemark() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateRemark() + }) +} + +// ClearRemark clears the value of the "remark" field. +func (u *SoraAccountUpsertBulk) ClearRemark() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearRemark() + }) +} + +// SetUseCount sets the "use_count" field. +func (u *SoraAccountUpsertBulk) SetUseCount(v int) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetUseCount(v) + }) +} + +// AddUseCount adds v to the "use_count" field. +func (u *SoraAccountUpsertBulk) AddUseCount(v int) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.AddUseCount(v) + }) +} + +// UpdateUseCount sets the "use_count" field to the value that was provided on create. 
+func (u *SoraAccountUpsertBulk) UpdateUseCount() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateUseCount() + }) +} + +// SetPlanType sets the "plan_type" field. +func (u *SoraAccountUpsertBulk) SetPlanType(v string) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetPlanType(v) + }) +} + +// UpdatePlanType sets the "plan_type" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdatePlanType() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdatePlanType() + }) +} + +// ClearPlanType clears the value of the "plan_type" field. +func (u *SoraAccountUpsertBulk) ClearPlanType() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearPlanType() + }) +} + +// SetPlanTitle sets the "plan_title" field. +func (u *SoraAccountUpsertBulk) SetPlanTitle(v string) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetPlanTitle(v) + }) +} + +// UpdatePlanTitle sets the "plan_title" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdatePlanTitle() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdatePlanTitle() + }) +} + +// ClearPlanTitle clears the value of the "plan_title" field. +func (u *SoraAccountUpsertBulk) ClearPlanTitle() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearPlanTitle() + }) +} + +// SetSubscriptionEnd sets the "subscription_end" field. +func (u *SoraAccountUpsertBulk) SetSubscriptionEnd(v time.Time) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetSubscriptionEnd(v) + }) +} + +// UpdateSubscriptionEnd sets the "subscription_end" field to the value that was provided on create. 
+func (u *SoraAccountUpsertBulk) UpdateSubscriptionEnd() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateSubscriptionEnd() + }) +} + +// ClearSubscriptionEnd clears the value of the "subscription_end" field. +func (u *SoraAccountUpsertBulk) ClearSubscriptionEnd() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearSubscriptionEnd() + }) +} + +// SetSoraSupported sets the "sora_supported" field. +func (u *SoraAccountUpsertBulk) SetSoraSupported(v bool) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetSoraSupported(v) + }) +} + +// UpdateSoraSupported sets the "sora_supported" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateSoraSupported() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateSoraSupported() + }) +} + +// SetSoraInviteCode sets the "sora_invite_code" field. +func (u *SoraAccountUpsertBulk) SetSoraInviteCode(v string) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetSoraInviteCode(v) + }) +} + +// UpdateSoraInviteCode sets the "sora_invite_code" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateSoraInviteCode() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateSoraInviteCode() + }) +} + +// ClearSoraInviteCode clears the value of the "sora_invite_code" field. +func (u *SoraAccountUpsertBulk) ClearSoraInviteCode() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearSoraInviteCode() + }) +} + +// SetSoraRedeemedCount sets the "sora_redeemed_count" field. +func (u *SoraAccountUpsertBulk) SetSoraRedeemedCount(v int) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetSoraRedeemedCount(v) + }) +} + +// AddSoraRedeemedCount adds v to the "sora_redeemed_count" field. 
+func (u *SoraAccountUpsertBulk) AddSoraRedeemedCount(v int) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.AddSoraRedeemedCount(v) + }) +} + +// UpdateSoraRedeemedCount sets the "sora_redeemed_count" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateSoraRedeemedCount() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateSoraRedeemedCount() + }) +} + +// SetSoraRemainingCount sets the "sora_remaining_count" field. +func (u *SoraAccountUpsertBulk) SetSoraRemainingCount(v int) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetSoraRemainingCount(v) + }) +} + +// AddSoraRemainingCount adds v to the "sora_remaining_count" field. +func (u *SoraAccountUpsertBulk) AddSoraRemainingCount(v int) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.AddSoraRemainingCount(v) + }) +} + +// UpdateSoraRemainingCount sets the "sora_remaining_count" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateSoraRemainingCount() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateSoraRemainingCount() + }) +} + +// SetSoraTotalCount sets the "sora_total_count" field. +func (u *SoraAccountUpsertBulk) SetSoraTotalCount(v int) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetSoraTotalCount(v) + }) +} + +// AddSoraTotalCount adds v to the "sora_total_count" field. +func (u *SoraAccountUpsertBulk) AddSoraTotalCount(v int) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.AddSoraTotalCount(v) + }) +} + +// UpdateSoraTotalCount sets the "sora_total_count" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateSoraTotalCount() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateSoraTotalCount() + }) +} + +// SetSoraCooldownUntil sets the "sora_cooldown_until" field. 
+func (u *SoraAccountUpsertBulk) SetSoraCooldownUntil(v time.Time) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetSoraCooldownUntil(v) + }) +} + +// UpdateSoraCooldownUntil sets the "sora_cooldown_until" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateSoraCooldownUntil() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateSoraCooldownUntil() + }) +} + +// ClearSoraCooldownUntil clears the value of the "sora_cooldown_until" field. +func (u *SoraAccountUpsertBulk) ClearSoraCooldownUntil() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearSoraCooldownUntil() + }) +} + +// SetCooledUntil sets the "cooled_until" field. +func (u *SoraAccountUpsertBulk) SetCooledUntil(v time.Time) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetCooledUntil(v) + }) +} + +// UpdateCooledUntil sets the "cooled_until" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateCooledUntil() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateCooledUntil() + }) +} + +// ClearCooledUntil clears the value of the "cooled_until" field. +func (u *SoraAccountUpsertBulk) ClearCooledUntil() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.ClearCooledUntil() + }) +} + +// SetImageEnabled sets the "image_enabled" field. +func (u *SoraAccountUpsertBulk) SetImageEnabled(v bool) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetImageEnabled(v) + }) +} + +// UpdateImageEnabled sets the "image_enabled" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateImageEnabled() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateImageEnabled() + }) +} + +// SetVideoEnabled sets the "video_enabled" field. 
+func (u *SoraAccountUpsertBulk) SetVideoEnabled(v bool) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetVideoEnabled(v) + }) +} + +// UpdateVideoEnabled sets the "video_enabled" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateVideoEnabled() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateVideoEnabled() + }) +} + +// SetImageConcurrency sets the "image_concurrency" field. +func (u *SoraAccountUpsertBulk) SetImageConcurrency(v int) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetImageConcurrency(v) + }) +} + +// AddImageConcurrency adds v to the "image_concurrency" field. +func (u *SoraAccountUpsertBulk) AddImageConcurrency(v int) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.AddImageConcurrency(v) + }) +} + +// UpdateImageConcurrency sets the "image_concurrency" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateImageConcurrency() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateImageConcurrency() + }) +} + +// SetVideoConcurrency sets the "video_concurrency" field. +func (u *SoraAccountUpsertBulk) SetVideoConcurrency(v int) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetVideoConcurrency(v) + }) +} + +// AddVideoConcurrency adds v to the "video_concurrency" field. +func (u *SoraAccountUpsertBulk) AddVideoConcurrency(v int) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.AddVideoConcurrency(v) + }) +} + +// UpdateVideoConcurrency sets the "video_concurrency" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateVideoConcurrency() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateVideoConcurrency() + }) +} + +// SetIsExpired sets the "is_expired" field. 
+func (u *SoraAccountUpsertBulk) SetIsExpired(v bool) *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.SetIsExpired(v) + }) +} + +// UpdateIsExpired sets the "is_expired" field to the value that was provided on create. +func (u *SoraAccountUpsertBulk) UpdateIsExpired() *SoraAccountUpsertBulk { + return u.Update(func(s *SoraAccountUpsert) { + s.UpdateIsExpired() + }) +} + +// Exec executes the query. +func (u *SoraAccountUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the SoraAccountCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for SoraAccountCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *SoraAccountUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/soraaccount_delete.go b/backend/ent/soraaccount_delete.go new file mode 100644 index 00000000..bed347ac --- /dev/null +++ b/backend/ent/soraaccount_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/soraaccount" +) + +// SoraAccountDelete is the builder for deleting a SoraAccount entity. +type SoraAccountDelete struct { + config + hooks []Hook + mutation *SoraAccountMutation +} + +// Where appends a list predicates to the SoraAccountDelete builder. +func (_d *SoraAccountDelete) Where(ps ...predicate.SoraAccount) *SoraAccountDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. 
+func (_d *SoraAccountDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *SoraAccountDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *SoraAccountDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(soraaccount.Table, sqlgraph.NewFieldSpec(soraaccount.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// SoraAccountDeleteOne is the builder for deleting a single SoraAccount entity. +type SoraAccountDeleteOne struct { + _d *SoraAccountDelete +} + +// Where appends a list predicates to the SoraAccountDelete builder. +func (_d *SoraAccountDeleteOne) Where(ps ...predicate.SoraAccount) *SoraAccountDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *SoraAccountDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{soraaccount.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *SoraAccountDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/soraaccount_query.go b/backend/ent/soraaccount_query.go new file mode 100644 index 00000000..cf819243 --- /dev/null +++ b/backend/ent/soraaccount_query.go @@ -0,0 +1,564 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/soraaccount" +) + +// SoraAccountQuery is the builder for querying SoraAccount entities. +type SoraAccountQuery struct { + config + ctx *QueryContext + order []soraaccount.OrderOption + inters []Interceptor + predicates []predicate.SoraAccount + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the SoraAccountQuery builder. +func (_q *SoraAccountQuery) Where(ps ...predicate.SoraAccount) *SoraAccountQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *SoraAccountQuery) Limit(limit int) *SoraAccountQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *SoraAccountQuery) Offset(offset int) *SoraAccountQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *SoraAccountQuery) Unique(unique bool) *SoraAccountQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *SoraAccountQuery) Order(o ...soraaccount.OrderOption) *SoraAccountQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first SoraAccount entity from the query. +// Returns a *NotFoundError when no SoraAccount was found. 
+func (_q *SoraAccountQuery) First(ctx context.Context) (*SoraAccount, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{soraaccount.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *SoraAccountQuery) FirstX(ctx context.Context) *SoraAccount { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first SoraAccount ID from the query. +// Returns a *NotFoundError when no SoraAccount ID was found. +func (_q *SoraAccountQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{soraaccount.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *SoraAccountQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single SoraAccount entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one SoraAccount entity is found. +// Returns a *NotFoundError when no SoraAccount entities are found. +func (_q *SoraAccountQuery) Only(ctx context.Context) (*SoraAccount, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{soraaccount.Label} + default: + return nil, &NotSingularError{soraaccount.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. 
+func (_q *SoraAccountQuery) OnlyX(ctx context.Context) *SoraAccount { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only SoraAccount ID in the query. +// Returns a *NotSingularError when more than one SoraAccount ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *SoraAccountQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{soraaccount.Label} + default: + err = &NotSingularError{soraaccount.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *SoraAccountQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of SoraAccounts. +func (_q *SoraAccountQuery) All(ctx context.Context) ([]*SoraAccount, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*SoraAccount, *SoraAccountQuery]() + return withInterceptors[[]*SoraAccount](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *SoraAccountQuery) AllX(ctx context.Context) []*SoraAccount { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of SoraAccount IDs. +func (_q *SoraAccountQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(soraaccount.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. 
+func (_q *SoraAccountQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *SoraAccountQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*SoraAccountQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *SoraAccountQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *SoraAccountQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *SoraAccountQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the SoraAccountQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *SoraAccountQuery) Clone() *SoraAccountQuery { + if _q == nil { + return nil + } + return &SoraAccountQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]soraaccount.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.SoraAccount{}, _q.predicates...), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. 
+// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.SoraAccount.Query(). +// GroupBy(soraaccount.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *SoraAccountQuery) GroupBy(field string, fields ...string) *SoraAccountGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &SoraAccountGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = soraaccount.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.SoraAccount.Query(). +// Select(soraaccount.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *SoraAccountQuery) Select(fields ...string) *SoraAccountSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &SoraAccountSelect{SoraAccountQuery: _q} + sbuild.label = soraaccount.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a SoraAccountSelect configured with the given aggregations. +func (_q *SoraAccountQuery) Aggregate(fns ...AggregateFunc) *SoraAccountSelect { + return _q.Select().Aggregate(fns...) 
+} + +func (_q *SoraAccountQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !soraaccount.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *SoraAccountQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*SoraAccount, error) { + var ( + nodes = []*SoraAccount{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*SoraAccount).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &SoraAccount{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (_q *SoraAccountQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *SoraAccountQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(soraaccount.Table, soraaccount.Columns, sqlgraph.NewFieldSpec(soraaccount.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + 
_spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, soraaccount.FieldID) + for i := range fields { + if fields[i] != soraaccount.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *SoraAccountQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(soraaccount.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = soraaccount.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. 
+func (_q *SoraAccountQuery) ForUpdate(opts ...sql.LockOption) *SoraAccountQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *SoraAccountQuery) ForShare(opts ...sql.LockOption) *SoraAccountQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// SoraAccountGroupBy is the group-by builder for SoraAccount entities. +type SoraAccountGroupBy struct { + selector + build *SoraAccountQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *SoraAccountGroupBy) Aggregate(fns ...AggregateFunc) *SoraAccountGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *SoraAccountGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*SoraAccountQuery, *SoraAccountGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *SoraAccountGroupBy) sqlScan(ctx context.Context, root *SoraAccountQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) 
+ selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// SoraAccountSelect is the builder for selecting fields of SoraAccount entities. +type SoraAccountSelect struct { + *SoraAccountQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *SoraAccountSelect) Aggregate(fns ...AggregateFunc) *SoraAccountSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *SoraAccountSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*SoraAccountQuery, *SoraAccountSelect](ctx, _s.SoraAccountQuery, _s, _s.inters, v) +} + +func (_s *SoraAccountSelect) sqlScan(ctx context.Context, root *SoraAccountQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/soraaccount_update.go b/backend/ent/soraaccount_update.go new file mode 100644 index 00000000..dcc62853 --- /dev/null +++ b/backend/ent/soraaccount_update.go @@ -0,0 +1,1402 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/soraaccount" +) + +// SoraAccountUpdate is the builder for updating SoraAccount entities. +type SoraAccountUpdate struct { + config + hooks []Hook + mutation *SoraAccountMutation +} + +// Where appends a list predicates to the SoraAccountUpdate builder. +func (_u *SoraAccountUpdate) Where(ps ...predicate.SoraAccount) *SoraAccountUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *SoraAccountUpdate) SetUpdatedAt(v time.Time) *SoraAccountUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetAccountID sets the "account_id" field. +func (_u *SoraAccountUpdate) SetAccountID(v int64) *SoraAccountUpdate { + _u.mutation.ResetAccountID() + _u.mutation.SetAccountID(v) + return _u +} + +// SetNillableAccountID sets the "account_id" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableAccountID(v *int64) *SoraAccountUpdate { + if v != nil { + _u.SetAccountID(*v) + } + return _u +} + +// AddAccountID adds value to the "account_id" field. +func (_u *SoraAccountUpdate) AddAccountID(v int64) *SoraAccountUpdate { + _u.mutation.AddAccountID(v) + return _u +} + +// SetAccessToken sets the "access_token" field. +func (_u *SoraAccountUpdate) SetAccessToken(v string) *SoraAccountUpdate { + _u.mutation.SetAccessToken(v) + return _u +} + +// SetNillableAccessToken sets the "access_token" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableAccessToken(v *string) *SoraAccountUpdate { + if v != nil { + _u.SetAccessToken(*v) + } + return _u +} + +// ClearAccessToken clears the value of the "access_token" field. 
+func (_u *SoraAccountUpdate) ClearAccessToken() *SoraAccountUpdate { + _u.mutation.ClearAccessToken() + return _u +} + +// SetSessionToken sets the "session_token" field. +func (_u *SoraAccountUpdate) SetSessionToken(v string) *SoraAccountUpdate { + _u.mutation.SetSessionToken(v) + return _u +} + +// SetNillableSessionToken sets the "session_token" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableSessionToken(v *string) *SoraAccountUpdate { + if v != nil { + _u.SetSessionToken(*v) + } + return _u +} + +// ClearSessionToken clears the value of the "session_token" field. +func (_u *SoraAccountUpdate) ClearSessionToken() *SoraAccountUpdate { + _u.mutation.ClearSessionToken() + return _u +} + +// SetRefreshToken sets the "refresh_token" field. +func (_u *SoraAccountUpdate) SetRefreshToken(v string) *SoraAccountUpdate { + _u.mutation.SetRefreshToken(v) + return _u +} + +// SetNillableRefreshToken sets the "refresh_token" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableRefreshToken(v *string) *SoraAccountUpdate { + if v != nil { + _u.SetRefreshToken(*v) + } + return _u +} + +// ClearRefreshToken clears the value of the "refresh_token" field. +func (_u *SoraAccountUpdate) ClearRefreshToken() *SoraAccountUpdate { + _u.mutation.ClearRefreshToken() + return _u +} + +// SetClientID sets the "client_id" field. +func (_u *SoraAccountUpdate) SetClientID(v string) *SoraAccountUpdate { + _u.mutation.SetClientID(v) + return _u +} + +// SetNillableClientID sets the "client_id" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableClientID(v *string) *SoraAccountUpdate { + if v != nil { + _u.SetClientID(*v) + } + return _u +} + +// ClearClientID clears the value of the "client_id" field. +func (_u *SoraAccountUpdate) ClearClientID() *SoraAccountUpdate { + _u.mutation.ClearClientID() + return _u +} + +// SetEmail sets the "email" field. 
+func (_u *SoraAccountUpdate) SetEmail(v string) *SoraAccountUpdate { + _u.mutation.SetEmail(v) + return _u +} + +// SetNillableEmail sets the "email" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableEmail(v *string) *SoraAccountUpdate { + if v != nil { + _u.SetEmail(*v) + } + return _u +} + +// ClearEmail clears the value of the "email" field. +func (_u *SoraAccountUpdate) ClearEmail() *SoraAccountUpdate { + _u.mutation.ClearEmail() + return _u +} + +// SetUsername sets the "username" field. +func (_u *SoraAccountUpdate) SetUsername(v string) *SoraAccountUpdate { + _u.mutation.SetUsername(v) + return _u +} + +// SetNillableUsername sets the "username" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableUsername(v *string) *SoraAccountUpdate { + if v != nil { + _u.SetUsername(*v) + } + return _u +} + +// ClearUsername clears the value of the "username" field. +func (_u *SoraAccountUpdate) ClearUsername() *SoraAccountUpdate { + _u.mutation.ClearUsername() + return _u +} + +// SetRemark sets the "remark" field. +func (_u *SoraAccountUpdate) SetRemark(v string) *SoraAccountUpdate { + _u.mutation.SetRemark(v) + return _u +} + +// SetNillableRemark sets the "remark" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableRemark(v *string) *SoraAccountUpdate { + if v != nil { + _u.SetRemark(*v) + } + return _u +} + +// ClearRemark clears the value of the "remark" field. +func (_u *SoraAccountUpdate) ClearRemark() *SoraAccountUpdate { + _u.mutation.ClearRemark() + return _u +} + +// SetUseCount sets the "use_count" field. +func (_u *SoraAccountUpdate) SetUseCount(v int) *SoraAccountUpdate { + _u.mutation.ResetUseCount() + _u.mutation.SetUseCount(v) + return _u +} + +// SetNillableUseCount sets the "use_count" field if the given value is not nil. 
+func (_u *SoraAccountUpdate) SetNillableUseCount(v *int) *SoraAccountUpdate { + if v != nil { + _u.SetUseCount(*v) + } + return _u +} + +// AddUseCount adds value to the "use_count" field. +func (_u *SoraAccountUpdate) AddUseCount(v int) *SoraAccountUpdate { + _u.mutation.AddUseCount(v) + return _u +} + +// SetPlanType sets the "plan_type" field. +func (_u *SoraAccountUpdate) SetPlanType(v string) *SoraAccountUpdate { + _u.mutation.SetPlanType(v) + return _u +} + +// SetNillablePlanType sets the "plan_type" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillablePlanType(v *string) *SoraAccountUpdate { + if v != nil { + _u.SetPlanType(*v) + } + return _u +} + +// ClearPlanType clears the value of the "plan_type" field. +func (_u *SoraAccountUpdate) ClearPlanType() *SoraAccountUpdate { + _u.mutation.ClearPlanType() + return _u +} + +// SetPlanTitle sets the "plan_title" field. +func (_u *SoraAccountUpdate) SetPlanTitle(v string) *SoraAccountUpdate { + _u.mutation.SetPlanTitle(v) + return _u +} + +// SetNillablePlanTitle sets the "plan_title" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillablePlanTitle(v *string) *SoraAccountUpdate { + if v != nil { + _u.SetPlanTitle(*v) + } + return _u +} + +// ClearPlanTitle clears the value of the "plan_title" field. +func (_u *SoraAccountUpdate) ClearPlanTitle() *SoraAccountUpdate { + _u.mutation.ClearPlanTitle() + return _u +} + +// SetSubscriptionEnd sets the "subscription_end" field. +func (_u *SoraAccountUpdate) SetSubscriptionEnd(v time.Time) *SoraAccountUpdate { + _u.mutation.SetSubscriptionEnd(v) + return _u +} + +// SetNillableSubscriptionEnd sets the "subscription_end" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableSubscriptionEnd(v *time.Time) *SoraAccountUpdate { + if v != nil { + _u.SetSubscriptionEnd(*v) + } + return _u +} + +// ClearSubscriptionEnd clears the value of the "subscription_end" field. 
+func (_u *SoraAccountUpdate) ClearSubscriptionEnd() *SoraAccountUpdate { + _u.mutation.ClearSubscriptionEnd() + return _u +} + +// SetSoraSupported sets the "sora_supported" field. +func (_u *SoraAccountUpdate) SetSoraSupported(v bool) *SoraAccountUpdate { + _u.mutation.SetSoraSupported(v) + return _u +} + +// SetNillableSoraSupported sets the "sora_supported" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableSoraSupported(v *bool) *SoraAccountUpdate { + if v != nil { + _u.SetSoraSupported(*v) + } + return _u +} + +// SetSoraInviteCode sets the "sora_invite_code" field. +func (_u *SoraAccountUpdate) SetSoraInviteCode(v string) *SoraAccountUpdate { + _u.mutation.SetSoraInviteCode(v) + return _u +} + +// SetNillableSoraInviteCode sets the "sora_invite_code" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableSoraInviteCode(v *string) *SoraAccountUpdate { + if v != nil { + _u.SetSoraInviteCode(*v) + } + return _u +} + +// ClearSoraInviteCode clears the value of the "sora_invite_code" field. +func (_u *SoraAccountUpdate) ClearSoraInviteCode() *SoraAccountUpdate { + _u.mutation.ClearSoraInviteCode() + return _u +} + +// SetSoraRedeemedCount sets the "sora_redeemed_count" field. +func (_u *SoraAccountUpdate) SetSoraRedeemedCount(v int) *SoraAccountUpdate { + _u.mutation.ResetSoraRedeemedCount() + _u.mutation.SetSoraRedeemedCount(v) + return _u +} + +// SetNillableSoraRedeemedCount sets the "sora_redeemed_count" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableSoraRedeemedCount(v *int) *SoraAccountUpdate { + if v != nil { + _u.SetSoraRedeemedCount(*v) + } + return _u +} + +// AddSoraRedeemedCount adds value to the "sora_redeemed_count" field. +func (_u *SoraAccountUpdate) AddSoraRedeemedCount(v int) *SoraAccountUpdate { + _u.mutation.AddSoraRedeemedCount(v) + return _u +} + +// SetSoraRemainingCount sets the "sora_remaining_count" field. 
+func (_u *SoraAccountUpdate) SetSoraRemainingCount(v int) *SoraAccountUpdate { + _u.mutation.ResetSoraRemainingCount() + _u.mutation.SetSoraRemainingCount(v) + return _u +} + +// SetNillableSoraRemainingCount sets the "sora_remaining_count" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableSoraRemainingCount(v *int) *SoraAccountUpdate { + if v != nil { + _u.SetSoraRemainingCount(*v) + } + return _u +} + +// AddSoraRemainingCount adds value to the "sora_remaining_count" field. +func (_u *SoraAccountUpdate) AddSoraRemainingCount(v int) *SoraAccountUpdate { + _u.mutation.AddSoraRemainingCount(v) + return _u +} + +// SetSoraTotalCount sets the "sora_total_count" field. +func (_u *SoraAccountUpdate) SetSoraTotalCount(v int) *SoraAccountUpdate { + _u.mutation.ResetSoraTotalCount() + _u.mutation.SetSoraTotalCount(v) + return _u +} + +// SetNillableSoraTotalCount sets the "sora_total_count" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableSoraTotalCount(v *int) *SoraAccountUpdate { + if v != nil { + _u.SetSoraTotalCount(*v) + } + return _u +} + +// AddSoraTotalCount adds value to the "sora_total_count" field. +func (_u *SoraAccountUpdate) AddSoraTotalCount(v int) *SoraAccountUpdate { + _u.mutation.AddSoraTotalCount(v) + return _u +} + +// SetSoraCooldownUntil sets the "sora_cooldown_until" field. +func (_u *SoraAccountUpdate) SetSoraCooldownUntil(v time.Time) *SoraAccountUpdate { + _u.mutation.SetSoraCooldownUntil(v) + return _u +} + +// SetNillableSoraCooldownUntil sets the "sora_cooldown_until" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableSoraCooldownUntil(v *time.Time) *SoraAccountUpdate { + if v != nil { + _u.SetSoraCooldownUntil(*v) + } + return _u +} + +// ClearSoraCooldownUntil clears the value of the "sora_cooldown_until" field. 
+func (_u *SoraAccountUpdate) ClearSoraCooldownUntil() *SoraAccountUpdate { + _u.mutation.ClearSoraCooldownUntil() + return _u +} + +// SetCooledUntil sets the "cooled_until" field. +func (_u *SoraAccountUpdate) SetCooledUntil(v time.Time) *SoraAccountUpdate { + _u.mutation.SetCooledUntil(v) + return _u +} + +// SetNillableCooledUntil sets the "cooled_until" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableCooledUntil(v *time.Time) *SoraAccountUpdate { + if v != nil { + _u.SetCooledUntil(*v) + } + return _u +} + +// ClearCooledUntil clears the value of the "cooled_until" field. +func (_u *SoraAccountUpdate) ClearCooledUntil() *SoraAccountUpdate { + _u.mutation.ClearCooledUntil() + return _u +} + +// SetImageEnabled sets the "image_enabled" field. +func (_u *SoraAccountUpdate) SetImageEnabled(v bool) *SoraAccountUpdate { + _u.mutation.SetImageEnabled(v) + return _u +} + +// SetNillableImageEnabled sets the "image_enabled" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableImageEnabled(v *bool) *SoraAccountUpdate { + if v != nil { + _u.SetImageEnabled(*v) + } + return _u +} + +// SetVideoEnabled sets the "video_enabled" field. +func (_u *SoraAccountUpdate) SetVideoEnabled(v bool) *SoraAccountUpdate { + _u.mutation.SetVideoEnabled(v) + return _u +} + +// SetNillableVideoEnabled sets the "video_enabled" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableVideoEnabled(v *bool) *SoraAccountUpdate { + if v != nil { + _u.SetVideoEnabled(*v) + } + return _u +} + +// SetImageConcurrency sets the "image_concurrency" field. +func (_u *SoraAccountUpdate) SetImageConcurrency(v int) *SoraAccountUpdate { + _u.mutation.ResetImageConcurrency() + _u.mutation.SetImageConcurrency(v) + return _u +} + +// SetNillableImageConcurrency sets the "image_concurrency" field if the given value is not nil. 
+func (_u *SoraAccountUpdate) SetNillableImageConcurrency(v *int) *SoraAccountUpdate { + if v != nil { + _u.SetImageConcurrency(*v) + } + return _u +} + +// AddImageConcurrency adds value to the "image_concurrency" field. +func (_u *SoraAccountUpdate) AddImageConcurrency(v int) *SoraAccountUpdate { + _u.mutation.AddImageConcurrency(v) + return _u +} + +// SetVideoConcurrency sets the "video_concurrency" field. +func (_u *SoraAccountUpdate) SetVideoConcurrency(v int) *SoraAccountUpdate { + _u.mutation.ResetVideoConcurrency() + _u.mutation.SetVideoConcurrency(v) + return _u +} + +// SetNillableVideoConcurrency sets the "video_concurrency" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableVideoConcurrency(v *int) *SoraAccountUpdate { + if v != nil { + _u.SetVideoConcurrency(*v) + } + return _u +} + +// AddVideoConcurrency adds value to the "video_concurrency" field. +func (_u *SoraAccountUpdate) AddVideoConcurrency(v int) *SoraAccountUpdate { + _u.mutation.AddVideoConcurrency(v) + return _u +} + +// SetIsExpired sets the "is_expired" field. +func (_u *SoraAccountUpdate) SetIsExpired(v bool) *SoraAccountUpdate { + _u.mutation.SetIsExpired(v) + return _u +} + +// SetNillableIsExpired sets the "is_expired" field if the given value is not nil. +func (_u *SoraAccountUpdate) SetNillableIsExpired(v *bool) *SoraAccountUpdate { + if v != nil { + _u.SetIsExpired(*v) + } + return _u +} + +// Mutation returns the SoraAccountMutation object of the builder. +func (_u *SoraAccountUpdate) Mutation() *SoraAccountMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *SoraAccountUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. 
+func (_u *SoraAccountUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *SoraAccountUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *SoraAccountUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *SoraAccountUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := soraaccount.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +func (_u *SoraAccountUpdate) sqlSave(ctx context.Context) (_node int, err error) { + _spec := sqlgraph.NewUpdateSpec(soraaccount.Table, soraaccount.Columns, sqlgraph.NewFieldSpec(soraaccount.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(soraaccount.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.AccountID(); ok { + _spec.SetField(soraaccount.FieldAccountID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedAccountID(); ok { + _spec.AddField(soraaccount.FieldAccountID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AccessToken(); ok { + _spec.SetField(soraaccount.FieldAccessToken, field.TypeString, value) + } + if _u.mutation.AccessTokenCleared() { + _spec.ClearField(soraaccount.FieldAccessToken, field.TypeString) + } + if value, ok := _u.mutation.SessionToken(); ok { + _spec.SetField(soraaccount.FieldSessionToken, field.TypeString, value) + } + if _u.mutation.SessionTokenCleared() { + _spec.ClearField(soraaccount.FieldSessionToken, field.TypeString) + } + if value, ok := _u.mutation.RefreshToken(); ok { + 
_spec.SetField(soraaccount.FieldRefreshToken, field.TypeString, value) + } + if _u.mutation.RefreshTokenCleared() { + _spec.ClearField(soraaccount.FieldRefreshToken, field.TypeString) + } + if value, ok := _u.mutation.ClientID(); ok { + _spec.SetField(soraaccount.FieldClientID, field.TypeString, value) + } + if _u.mutation.ClientIDCleared() { + _spec.ClearField(soraaccount.FieldClientID, field.TypeString) + } + if value, ok := _u.mutation.Email(); ok { + _spec.SetField(soraaccount.FieldEmail, field.TypeString, value) + } + if _u.mutation.EmailCleared() { + _spec.ClearField(soraaccount.FieldEmail, field.TypeString) + } + if value, ok := _u.mutation.Username(); ok { + _spec.SetField(soraaccount.FieldUsername, field.TypeString, value) + } + if _u.mutation.UsernameCleared() { + _spec.ClearField(soraaccount.FieldUsername, field.TypeString) + } + if value, ok := _u.mutation.Remark(); ok { + _spec.SetField(soraaccount.FieldRemark, field.TypeString, value) + } + if _u.mutation.RemarkCleared() { + _spec.ClearField(soraaccount.FieldRemark, field.TypeString) + } + if value, ok := _u.mutation.UseCount(); ok { + _spec.SetField(soraaccount.FieldUseCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedUseCount(); ok { + _spec.AddField(soraaccount.FieldUseCount, field.TypeInt, value) + } + if value, ok := _u.mutation.PlanType(); ok { + _spec.SetField(soraaccount.FieldPlanType, field.TypeString, value) + } + if _u.mutation.PlanTypeCleared() { + _spec.ClearField(soraaccount.FieldPlanType, field.TypeString) + } + if value, ok := _u.mutation.PlanTitle(); ok { + _spec.SetField(soraaccount.FieldPlanTitle, field.TypeString, value) + } + if _u.mutation.PlanTitleCleared() { + _spec.ClearField(soraaccount.FieldPlanTitle, field.TypeString) + } + if value, ok := _u.mutation.SubscriptionEnd(); ok { + _spec.SetField(soraaccount.FieldSubscriptionEnd, field.TypeTime, value) + } + if _u.mutation.SubscriptionEndCleared() { + _spec.ClearField(soraaccount.FieldSubscriptionEnd, 
field.TypeTime) + } + if value, ok := _u.mutation.SoraSupported(); ok { + _spec.SetField(soraaccount.FieldSoraSupported, field.TypeBool, value) + } + if value, ok := _u.mutation.SoraInviteCode(); ok { + _spec.SetField(soraaccount.FieldSoraInviteCode, field.TypeString, value) + } + if _u.mutation.SoraInviteCodeCleared() { + _spec.ClearField(soraaccount.FieldSoraInviteCode, field.TypeString) + } + if value, ok := _u.mutation.SoraRedeemedCount(); ok { + _spec.SetField(soraaccount.FieldSoraRedeemedCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedSoraRedeemedCount(); ok { + _spec.AddField(soraaccount.FieldSoraRedeemedCount, field.TypeInt, value) + } + if value, ok := _u.mutation.SoraRemainingCount(); ok { + _spec.SetField(soraaccount.FieldSoraRemainingCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedSoraRemainingCount(); ok { + _spec.AddField(soraaccount.FieldSoraRemainingCount, field.TypeInt, value) + } + if value, ok := _u.mutation.SoraTotalCount(); ok { + _spec.SetField(soraaccount.FieldSoraTotalCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedSoraTotalCount(); ok { + _spec.AddField(soraaccount.FieldSoraTotalCount, field.TypeInt, value) + } + if value, ok := _u.mutation.SoraCooldownUntil(); ok { + _spec.SetField(soraaccount.FieldSoraCooldownUntil, field.TypeTime, value) + } + if _u.mutation.SoraCooldownUntilCleared() { + _spec.ClearField(soraaccount.FieldSoraCooldownUntil, field.TypeTime) + } + if value, ok := _u.mutation.CooledUntil(); ok { + _spec.SetField(soraaccount.FieldCooledUntil, field.TypeTime, value) + } + if _u.mutation.CooledUntilCleared() { + _spec.ClearField(soraaccount.FieldCooledUntil, field.TypeTime) + } + if value, ok := _u.mutation.ImageEnabled(); ok { + _spec.SetField(soraaccount.FieldImageEnabled, field.TypeBool, value) + } + if value, ok := _u.mutation.VideoEnabled(); ok { + _spec.SetField(soraaccount.FieldVideoEnabled, field.TypeBool, value) + } + if value, ok := 
_u.mutation.ImageConcurrency(); ok { + _spec.SetField(soraaccount.FieldImageConcurrency, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedImageConcurrency(); ok { + _spec.AddField(soraaccount.FieldImageConcurrency, field.TypeInt, value) + } + if value, ok := _u.mutation.VideoConcurrency(); ok { + _spec.SetField(soraaccount.FieldVideoConcurrency, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedVideoConcurrency(); ok { + _spec.AddField(soraaccount.FieldVideoConcurrency, field.TypeInt, value) + } + if value, ok := _u.mutation.IsExpired(); ok { + _spec.SetField(soraaccount.FieldIsExpired, field.TypeBool, value) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{soraaccount.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// SoraAccountUpdateOne is the builder for updating a single SoraAccount entity. +type SoraAccountUpdateOne struct { + config + fields []string + hooks []Hook + mutation *SoraAccountMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *SoraAccountUpdateOne) SetUpdatedAt(v time.Time) *SoraAccountUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetAccountID sets the "account_id" field. +func (_u *SoraAccountUpdateOne) SetAccountID(v int64) *SoraAccountUpdateOne { + _u.mutation.ResetAccountID() + _u.mutation.SetAccountID(v) + return _u +} + +// SetNillableAccountID sets the "account_id" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableAccountID(v *int64) *SoraAccountUpdateOne { + if v != nil { + _u.SetAccountID(*v) + } + return _u +} + +// AddAccountID adds value to the "account_id" field. 
+func (_u *SoraAccountUpdateOne) AddAccountID(v int64) *SoraAccountUpdateOne { + _u.mutation.AddAccountID(v) + return _u +} + +// SetAccessToken sets the "access_token" field. +func (_u *SoraAccountUpdateOne) SetAccessToken(v string) *SoraAccountUpdateOne { + _u.mutation.SetAccessToken(v) + return _u +} + +// SetNillableAccessToken sets the "access_token" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableAccessToken(v *string) *SoraAccountUpdateOne { + if v != nil { + _u.SetAccessToken(*v) + } + return _u +} + +// ClearAccessToken clears the value of the "access_token" field. +func (_u *SoraAccountUpdateOne) ClearAccessToken() *SoraAccountUpdateOne { + _u.mutation.ClearAccessToken() + return _u +} + +// SetSessionToken sets the "session_token" field. +func (_u *SoraAccountUpdateOne) SetSessionToken(v string) *SoraAccountUpdateOne { + _u.mutation.SetSessionToken(v) + return _u +} + +// SetNillableSessionToken sets the "session_token" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableSessionToken(v *string) *SoraAccountUpdateOne { + if v != nil { + _u.SetSessionToken(*v) + } + return _u +} + +// ClearSessionToken clears the value of the "session_token" field. +func (_u *SoraAccountUpdateOne) ClearSessionToken() *SoraAccountUpdateOne { + _u.mutation.ClearSessionToken() + return _u +} + +// SetRefreshToken sets the "refresh_token" field. +func (_u *SoraAccountUpdateOne) SetRefreshToken(v string) *SoraAccountUpdateOne { + _u.mutation.SetRefreshToken(v) + return _u +} + +// SetNillableRefreshToken sets the "refresh_token" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableRefreshToken(v *string) *SoraAccountUpdateOne { + if v != nil { + _u.SetRefreshToken(*v) + } + return _u +} + +// ClearRefreshToken clears the value of the "refresh_token" field. 
+func (_u *SoraAccountUpdateOne) ClearRefreshToken() *SoraAccountUpdateOne { + _u.mutation.ClearRefreshToken() + return _u +} + +// SetClientID sets the "client_id" field. +func (_u *SoraAccountUpdateOne) SetClientID(v string) *SoraAccountUpdateOne { + _u.mutation.SetClientID(v) + return _u +} + +// SetNillableClientID sets the "client_id" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableClientID(v *string) *SoraAccountUpdateOne { + if v != nil { + _u.SetClientID(*v) + } + return _u +} + +// ClearClientID clears the value of the "client_id" field. +func (_u *SoraAccountUpdateOne) ClearClientID() *SoraAccountUpdateOne { + _u.mutation.ClearClientID() + return _u +} + +// SetEmail sets the "email" field. +func (_u *SoraAccountUpdateOne) SetEmail(v string) *SoraAccountUpdateOne { + _u.mutation.SetEmail(v) + return _u +} + +// SetNillableEmail sets the "email" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableEmail(v *string) *SoraAccountUpdateOne { + if v != nil { + _u.SetEmail(*v) + } + return _u +} + +// ClearEmail clears the value of the "email" field. +func (_u *SoraAccountUpdateOne) ClearEmail() *SoraAccountUpdateOne { + _u.mutation.ClearEmail() + return _u +} + +// SetUsername sets the "username" field. +func (_u *SoraAccountUpdateOne) SetUsername(v string) *SoraAccountUpdateOne { + _u.mutation.SetUsername(v) + return _u +} + +// SetNillableUsername sets the "username" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableUsername(v *string) *SoraAccountUpdateOne { + if v != nil { + _u.SetUsername(*v) + } + return _u +} + +// ClearUsername clears the value of the "username" field. +func (_u *SoraAccountUpdateOne) ClearUsername() *SoraAccountUpdateOne { + _u.mutation.ClearUsername() + return _u +} + +// SetRemark sets the "remark" field. 
+func (_u *SoraAccountUpdateOne) SetRemark(v string) *SoraAccountUpdateOne { + _u.mutation.SetRemark(v) + return _u +} + +// SetNillableRemark sets the "remark" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableRemark(v *string) *SoraAccountUpdateOne { + if v != nil { + _u.SetRemark(*v) + } + return _u +} + +// ClearRemark clears the value of the "remark" field. +func (_u *SoraAccountUpdateOne) ClearRemark() *SoraAccountUpdateOne { + _u.mutation.ClearRemark() + return _u +} + +// SetUseCount sets the "use_count" field. +func (_u *SoraAccountUpdateOne) SetUseCount(v int) *SoraAccountUpdateOne { + _u.mutation.ResetUseCount() + _u.mutation.SetUseCount(v) + return _u +} + +// SetNillableUseCount sets the "use_count" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableUseCount(v *int) *SoraAccountUpdateOne { + if v != nil { + _u.SetUseCount(*v) + } + return _u +} + +// AddUseCount adds value to the "use_count" field. +func (_u *SoraAccountUpdateOne) AddUseCount(v int) *SoraAccountUpdateOne { + _u.mutation.AddUseCount(v) + return _u +} + +// SetPlanType sets the "plan_type" field. +func (_u *SoraAccountUpdateOne) SetPlanType(v string) *SoraAccountUpdateOne { + _u.mutation.SetPlanType(v) + return _u +} + +// SetNillablePlanType sets the "plan_type" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillablePlanType(v *string) *SoraAccountUpdateOne { + if v != nil { + _u.SetPlanType(*v) + } + return _u +} + +// ClearPlanType clears the value of the "plan_type" field. +func (_u *SoraAccountUpdateOne) ClearPlanType() *SoraAccountUpdateOne { + _u.mutation.ClearPlanType() + return _u +} + +// SetPlanTitle sets the "plan_title" field. +func (_u *SoraAccountUpdateOne) SetPlanTitle(v string) *SoraAccountUpdateOne { + _u.mutation.SetPlanTitle(v) + return _u +} + +// SetNillablePlanTitle sets the "plan_title" field if the given value is not nil. 
+func (_u *SoraAccountUpdateOne) SetNillablePlanTitle(v *string) *SoraAccountUpdateOne { + if v != nil { + _u.SetPlanTitle(*v) + } + return _u +} + +// ClearPlanTitle clears the value of the "plan_title" field. +func (_u *SoraAccountUpdateOne) ClearPlanTitle() *SoraAccountUpdateOne { + _u.mutation.ClearPlanTitle() + return _u +} + +// SetSubscriptionEnd sets the "subscription_end" field. +func (_u *SoraAccountUpdateOne) SetSubscriptionEnd(v time.Time) *SoraAccountUpdateOne { + _u.mutation.SetSubscriptionEnd(v) + return _u +} + +// SetNillableSubscriptionEnd sets the "subscription_end" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableSubscriptionEnd(v *time.Time) *SoraAccountUpdateOne { + if v != nil { + _u.SetSubscriptionEnd(*v) + } + return _u +} + +// ClearSubscriptionEnd clears the value of the "subscription_end" field. +func (_u *SoraAccountUpdateOne) ClearSubscriptionEnd() *SoraAccountUpdateOne { + _u.mutation.ClearSubscriptionEnd() + return _u +} + +// SetSoraSupported sets the "sora_supported" field. +func (_u *SoraAccountUpdateOne) SetSoraSupported(v bool) *SoraAccountUpdateOne { + _u.mutation.SetSoraSupported(v) + return _u +} + +// SetNillableSoraSupported sets the "sora_supported" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableSoraSupported(v *bool) *SoraAccountUpdateOne { + if v != nil { + _u.SetSoraSupported(*v) + } + return _u +} + +// SetSoraInviteCode sets the "sora_invite_code" field. +func (_u *SoraAccountUpdateOne) SetSoraInviteCode(v string) *SoraAccountUpdateOne { + _u.mutation.SetSoraInviteCode(v) + return _u +} + +// SetNillableSoraInviteCode sets the "sora_invite_code" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableSoraInviteCode(v *string) *SoraAccountUpdateOne { + if v != nil { + _u.SetSoraInviteCode(*v) + } + return _u +} + +// ClearSoraInviteCode clears the value of the "sora_invite_code" field. 
+func (_u *SoraAccountUpdateOne) ClearSoraInviteCode() *SoraAccountUpdateOne { + _u.mutation.ClearSoraInviteCode() + return _u +} + +// SetSoraRedeemedCount sets the "sora_redeemed_count" field. +func (_u *SoraAccountUpdateOne) SetSoraRedeemedCount(v int) *SoraAccountUpdateOne { + _u.mutation.ResetSoraRedeemedCount() + _u.mutation.SetSoraRedeemedCount(v) + return _u +} + +// SetNillableSoraRedeemedCount sets the "sora_redeemed_count" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableSoraRedeemedCount(v *int) *SoraAccountUpdateOne { + if v != nil { + _u.SetSoraRedeemedCount(*v) + } + return _u +} + +// AddSoraRedeemedCount adds value to the "sora_redeemed_count" field. +func (_u *SoraAccountUpdateOne) AddSoraRedeemedCount(v int) *SoraAccountUpdateOne { + _u.mutation.AddSoraRedeemedCount(v) + return _u +} + +// SetSoraRemainingCount sets the "sora_remaining_count" field. +func (_u *SoraAccountUpdateOne) SetSoraRemainingCount(v int) *SoraAccountUpdateOne { + _u.mutation.ResetSoraRemainingCount() + _u.mutation.SetSoraRemainingCount(v) + return _u +} + +// SetNillableSoraRemainingCount sets the "sora_remaining_count" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableSoraRemainingCount(v *int) *SoraAccountUpdateOne { + if v != nil { + _u.SetSoraRemainingCount(*v) + } + return _u +} + +// AddSoraRemainingCount adds value to the "sora_remaining_count" field. +func (_u *SoraAccountUpdateOne) AddSoraRemainingCount(v int) *SoraAccountUpdateOne { + _u.mutation.AddSoraRemainingCount(v) + return _u +} + +// SetSoraTotalCount sets the "sora_total_count" field. +func (_u *SoraAccountUpdateOne) SetSoraTotalCount(v int) *SoraAccountUpdateOne { + _u.mutation.ResetSoraTotalCount() + _u.mutation.SetSoraTotalCount(v) + return _u +} + +// SetNillableSoraTotalCount sets the "sora_total_count" field if the given value is not nil. 
+func (_u *SoraAccountUpdateOne) SetNillableSoraTotalCount(v *int) *SoraAccountUpdateOne { + if v != nil { + _u.SetSoraTotalCount(*v) + } + return _u +} + +// AddSoraTotalCount adds value to the "sora_total_count" field. +func (_u *SoraAccountUpdateOne) AddSoraTotalCount(v int) *SoraAccountUpdateOne { + _u.mutation.AddSoraTotalCount(v) + return _u +} + +// SetSoraCooldownUntil sets the "sora_cooldown_until" field. +func (_u *SoraAccountUpdateOne) SetSoraCooldownUntil(v time.Time) *SoraAccountUpdateOne { + _u.mutation.SetSoraCooldownUntil(v) + return _u +} + +// SetNillableSoraCooldownUntil sets the "sora_cooldown_until" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableSoraCooldownUntil(v *time.Time) *SoraAccountUpdateOne { + if v != nil { + _u.SetSoraCooldownUntil(*v) + } + return _u +} + +// ClearSoraCooldownUntil clears the value of the "sora_cooldown_until" field. +func (_u *SoraAccountUpdateOne) ClearSoraCooldownUntil() *SoraAccountUpdateOne { + _u.mutation.ClearSoraCooldownUntil() + return _u +} + +// SetCooledUntil sets the "cooled_until" field. +func (_u *SoraAccountUpdateOne) SetCooledUntil(v time.Time) *SoraAccountUpdateOne { + _u.mutation.SetCooledUntil(v) + return _u +} + +// SetNillableCooledUntil sets the "cooled_until" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableCooledUntil(v *time.Time) *SoraAccountUpdateOne { + if v != nil { + _u.SetCooledUntil(*v) + } + return _u +} + +// ClearCooledUntil clears the value of the "cooled_until" field. +func (_u *SoraAccountUpdateOne) ClearCooledUntil() *SoraAccountUpdateOne { + _u.mutation.ClearCooledUntil() + return _u +} + +// SetImageEnabled sets the "image_enabled" field. +func (_u *SoraAccountUpdateOne) SetImageEnabled(v bool) *SoraAccountUpdateOne { + _u.mutation.SetImageEnabled(v) + return _u +} + +// SetNillableImageEnabled sets the "image_enabled" field if the given value is not nil. 
+func (_u *SoraAccountUpdateOne) SetNillableImageEnabled(v *bool) *SoraAccountUpdateOne { + if v != nil { + _u.SetImageEnabled(*v) + } + return _u +} + +// SetVideoEnabled sets the "video_enabled" field. +func (_u *SoraAccountUpdateOne) SetVideoEnabled(v bool) *SoraAccountUpdateOne { + _u.mutation.SetVideoEnabled(v) + return _u +} + +// SetNillableVideoEnabled sets the "video_enabled" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableVideoEnabled(v *bool) *SoraAccountUpdateOne { + if v != nil { + _u.SetVideoEnabled(*v) + } + return _u +} + +// SetImageConcurrency sets the "image_concurrency" field. +func (_u *SoraAccountUpdateOne) SetImageConcurrency(v int) *SoraAccountUpdateOne { + _u.mutation.ResetImageConcurrency() + _u.mutation.SetImageConcurrency(v) + return _u +} + +// SetNillableImageConcurrency sets the "image_concurrency" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableImageConcurrency(v *int) *SoraAccountUpdateOne { + if v != nil { + _u.SetImageConcurrency(*v) + } + return _u +} + +// AddImageConcurrency adds value to the "image_concurrency" field. +func (_u *SoraAccountUpdateOne) AddImageConcurrency(v int) *SoraAccountUpdateOne { + _u.mutation.AddImageConcurrency(v) + return _u +} + +// SetVideoConcurrency sets the "video_concurrency" field. +func (_u *SoraAccountUpdateOne) SetVideoConcurrency(v int) *SoraAccountUpdateOne { + _u.mutation.ResetVideoConcurrency() + _u.mutation.SetVideoConcurrency(v) + return _u +} + +// SetNillableVideoConcurrency sets the "video_concurrency" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableVideoConcurrency(v *int) *SoraAccountUpdateOne { + if v != nil { + _u.SetVideoConcurrency(*v) + } + return _u +} + +// AddVideoConcurrency adds value to the "video_concurrency" field. 
+func (_u *SoraAccountUpdateOne) AddVideoConcurrency(v int) *SoraAccountUpdateOne { + _u.mutation.AddVideoConcurrency(v) + return _u +} + +// SetIsExpired sets the "is_expired" field. +func (_u *SoraAccountUpdateOne) SetIsExpired(v bool) *SoraAccountUpdateOne { + _u.mutation.SetIsExpired(v) + return _u +} + +// SetNillableIsExpired sets the "is_expired" field if the given value is not nil. +func (_u *SoraAccountUpdateOne) SetNillableIsExpired(v *bool) *SoraAccountUpdateOne { + if v != nil { + _u.SetIsExpired(*v) + } + return _u +} + +// Mutation returns the SoraAccountMutation object of the builder. +func (_u *SoraAccountUpdateOne) Mutation() *SoraAccountMutation { + return _u.mutation +} + +// Where appends a list predicates to the SoraAccountUpdate builder. +func (_u *SoraAccountUpdateOne) Where(ps ...predicate.SoraAccount) *SoraAccountUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *SoraAccountUpdateOne) Select(field string, fields ...string) *SoraAccountUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated SoraAccount entity. +func (_u *SoraAccountUpdateOne) Save(ctx context.Context) (*SoraAccount, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *SoraAccountUpdateOne) SaveX(ctx context.Context) *SoraAccount { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *SoraAccountUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_u *SoraAccountUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *SoraAccountUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := soraaccount.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +func (_u *SoraAccountUpdateOne) sqlSave(ctx context.Context) (_node *SoraAccount, err error) { + _spec := sqlgraph.NewUpdateSpec(soraaccount.Table, soraaccount.Columns, sqlgraph.NewFieldSpec(soraaccount.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "SoraAccount.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, soraaccount.FieldID) + for _, f := range fields { + if !soraaccount.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != soraaccount.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(soraaccount.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.AccountID(); ok { + _spec.SetField(soraaccount.FieldAccountID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedAccountID(); ok { + _spec.AddField(soraaccount.FieldAccountID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AccessToken(); ok { + _spec.SetField(soraaccount.FieldAccessToken, field.TypeString, value) + } + if _u.mutation.AccessTokenCleared() { + _spec.ClearField(soraaccount.FieldAccessToken, field.TypeString) + } + if value, ok := _u.mutation.SessionToken(); 
ok { + _spec.SetField(soraaccount.FieldSessionToken, field.TypeString, value) + } + if _u.mutation.SessionTokenCleared() { + _spec.ClearField(soraaccount.FieldSessionToken, field.TypeString) + } + if value, ok := _u.mutation.RefreshToken(); ok { + _spec.SetField(soraaccount.FieldRefreshToken, field.TypeString, value) + } + if _u.mutation.RefreshTokenCleared() { + _spec.ClearField(soraaccount.FieldRefreshToken, field.TypeString) + } + if value, ok := _u.mutation.ClientID(); ok { + _spec.SetField(soraaccount.FieldClientID, field.TypeString, value) + } + if _u.mutation.ClientIDCleared() { + _spec.ClearField(soraaccount.FieldClientID, field.TypeString) + } + if value, ok := _u.mutation.Email(); ok { + _spec.SetField(soraaccount.FieldEmail, field.TypeString, value) + } + if _u.mutation.EmailCleared() { + _spec.ClearField(soraaccount.FieldEmail, field.TypeString) + } + if value, ok := _u.mutation.Username(); ok { + _spec.SetField(soraaccount.FieldUsername, field.TypeString, value) + } + if _u.mutation.UsernameCleared() { + _spec.ClearField(soraaccount.FieldUsername, field.TypeString) + } + if value, ok := _u.mutation.Remark(); ok { + _spec.SetField(soraaccount.FieldRemark, field.TypeString, value) + } + if _u.mutation.RemarkCleared() { + _spec.ClearField(soraaccount.FieldRemark, field.TypeString) + } + if value, ok := _u.mutation.UseCount(); ok { + _spec.SetField(soraaccount.FieldUseCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedUseCount(); ok { + _spec.AddField(soraaccount.FieldUseCount, field.TypeInt, value) + } + if value, ok := _u.mutation.PlanType(); ok { + _spec.SetField(soraaccount.FieldPlanType, field.TypeString, value) + } + if _u.mutation.PlanTypeCleared() { + _spec.ClearField(soraaccount.FieldPlanType, field.TypeString) + } + if value, ok := _u.mutation.PlanTitle(); ok { + _spec.SetField(soraaccount.FieldPlanTitle, field.TypeString, value) + } + if _u.mutation.PlanTitleCleared() { + _spec.ClearField(soraaccount.FieldPlanTitle, 
field.TypeString) + } + if value, ok := _u.mutation.SubscriptionEnd(); ok { + _spec.SetField(soraaccount.FieldSubscriptionEnd, field.TypeTime, value) + } + if _u.mutation.SubscriptionEndCleared() { + _spec.ClearField(soraaccount.FieldSubscriptionEnd, field.TypeTime) + } + if value, ok := _u.mutation.SoraSupported(); ok { + _spec.SetField(soraaccount.FieldSoraSupported, field.TypeBool, value) + } + if value, ok := _u.mutation.SoraInviteCode(); ok { + _spec.SetField(soraaccount.FieldSoraInviteCode, field.TypeString, value) + } + if _u.mutation.SoraInviteCodeCleared() { + _spec.ClearField(soraaccount.FieldSoraInviteCode, field.TypeString) + } + if value, ok := _u.mutation.SoraRedeemedCount(); ok { + _spec.SetField(soraaccount.FieldSoraRedeemedCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedSoraRedeemedCount(); ok { + _spec.AddField(soraaccount.FieldSoraRedeemedCount, field.TypeInt, value) + } + if value, ok := _u.mutation.SoraRemainingCount(); ok { + _spec.SetField(soraaccount.FieldSoraRemainingCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedSoraRemainingCount(); ok { + _spec.AddField(soraaccount.FieldSoraRemainingCount, field.TypeInt, value) + } + if value, ok := _u.mutation.SoraTotalCount(); ok { + _spec.SetField(soraaccount.FieldSoraTotalCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedSoraTotalCount(); ok { + _spec.AddField(soraaccount.FieldSoraTotalCount, field.TypeInt, value) + } + if value, ok := _u.mutation.SoraCooldownUntil(); ok { + _spec.SetField(soraaccount.FieldSoraCooldownUntil, field.TypeTime, value) + } + if _u.mutation.SoraCooldownUntilCleared() { + _spec.ClearField(soraaccount.FieldSoraCooldownUntil, field.TypeTime) + } + if value, ok := _u.mutation.CooledUntil(); ok { + _spec.SetField(soraaccount.FieldCooledUntil, field.TypeTime, value) + } + if _u.mutation.CooledUntilCleared() { + _spec.ClearField(soraaccount.FieldCooledUntil, field.TypeTime) + } + if value, ok := 
_u.mutation.ImageEnabled(); ok { + _spec.SetField(soraaccount.FieldImageEnabled, field.TypeBool, value) + } + if value, ok := _u.mutation.VideoEnabled(); ok { + _spec.SetField(soraaccount.FieldVideoEnabled, field.TypeBool, value) + } + if value, ok := _u.mutation.ImageConcurrency(); ok { + _spec.SetField(soraaccount.FieldImageConcurrency, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedImageConcurrency(); ok { + _spec.AddField(soraaccount.FieldImageConcurrency, field.TypeInt, value) + } + if value, ok := _u.mutation.VideoConcurrency(); ok { + _spec.SetField(soraaccount.FieldVideoConcurrency, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedVideoConcurrency(); ok { + _spec.AddField(soraaccount.FieldVideoConcurrency, field.TypeInt, value) + } + if value, ok := _u.mutation.IsExpired(); ok { + _spec.SetField(soraaccount.FieldIsExpired, field.TypeBool, value) + } + _node = &SoraAccount{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{soraaccount.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/soracachefile.go b/backend/ent/soracachefile.go new file mode 100644 index 00000000..8a44074a --- /dev/null +++ b/backend/ent/soracachefile.go @@ -0,0 +1,197 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/soracachefile" +) + +// SoraCacheFile is the model entity for the SoraCacheFile schema. +type SoraCacheFile struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // TaskID holds the value of the "task_id" field. 
+ TaskID *string `json:"task_id,omitempty"` + // AccountID holds the value of the "account_id" field. + AccountID int64 `json:"account_id,omitempty"` + // UserID holds the value of the "user_id" field. + UserID int64 `json:"user_id,omitempty"` + // MediaType holds the value of the "media_type" field. + MediaType string `json:"media_type,omitempty"` + // OriginalURL holds the value of the "original_url" field. + OriginalURL string `json:"original_url,omitempty"` + // CachePath holds the value of the "cache_path" field. + CachePath string `json:"cache_path,omitempty"` + // CacheURL holds the value of the "cache_url" field. + CacheURL string `json:"cache_url,omitempty"` + // SizeBytes holds the value of the "size_bytes" field. + SizeBytes int64 `json:"size_bytes,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*SoraCacheFile) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case soracachefile.FieldID, soracachefile.FieldAccountID, soracachefile.FieldUserID, soracachefile.FieldSizeBytes: + values[i] = new(sql.NullInt64) + case soracachefile.FieldTaskID, soracachefile.FieldMediaType, soracachefile.FieldOriginalURL, soracachefile.FieldCachePath, soracachefile.FieldCacheURL: + values[i] = new(sql.NullString) + case soracachefile.FieldCreatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the SoraCacheFile fields. 
+func (_m *SoraCacheFile) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case soracachefile.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case soracachefile.FieldTaskID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field task_id", values[i]) + } else if value.Valid { + _m.TaskID = new(string) + *_m.TaskID = value.String + } + case soracachefile.FieldAccountID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field account_id", values[i]) + } else if value.Valid { + _m.AccountID = value.Int64 + } + case soracachefile.FieldUserID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field user_id", values[i]) + } else if value.Valid { + _m.UserID = value.Int64 + } + case soracachefile.FieldMediaType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field media_type", values[i]) + } else if value.Valid { + _m.MediaType = value.String + } + case soracachefile.FieldOriginalURL: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field original_url", values[i]) + } else if value.Valid { + _m.OriginalURL = value.String + } + case soracachefile.FieldCachePath: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field cache_path", values[i]) + } else if value.Valid { + _m.CachePath = value.String + } + case soracachefile.FieldCacheURL: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field cache_url", values[i]) + } else if value.Valid { + _m.CacheURL = value.String 
+ } + case soracachefile.FieldSizeBytes: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field size_bytes", values[i]) + } else if value.Valid { + _m.SizeBytes = value.Int64 + } + case soracachefile.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the SoraCacheFile. +// This includes values selected through modifiers, order, etc. +func (_m *SoraCacheFile) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this SoraCacheFile. +// Note that you need to call SoraCacheFile.Unwrap() before calling this method if this SoraCacheFile +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *SoraCacheFile) Update() *SoraCacheFileUpdateOne { + return NewSoraCacheFileClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the SoraCacheFile entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *SoraCacheFile) Unwrap() *SoraCacheFile { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: SoraCacheFile is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. 
+func (_m *SoraCacheFile) String() string { + var builder strings.Builder + builder.WriteString("SoraCacheFile(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + if v := _m.TaskID; v != nil { + builder.WriteString("task_id=") + builder.WriteString(*v) + } + builder.WriteString(", ") + builder.WriteString("account_id=") + builder.WriteString(fmt.Sprintf("%v", _m.AccountID)) + builder.WriteString(", ") + builder.WriteString("user_id=") + builder.WriteString(fmt.Sprintf("%v", _m.UserID)) + builder.WriteString(", ") + builder.WriteString("media_type=") + builder.WriteString(_m.MediaType) + builder.WriteString(", ") + builder.WriteString("original_url=") + builder.WriteString(_m.OriginalURL) + builder.WriteString(", ") + builder.WriteString("cache_path=") + builder.WriteString(_m.CachePath) + builder.WriteString(", ") + builder.WriteString("cache_url=") + builder.WriteString(_m.CacheURL) + builder.WriteString(", ") + builder.WriteString("size_bytes=") + builder.WriteString(fmt.Sprintf("%v", _m.SizeBytes)) + builder.WriteString(", ") + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// SoraCacheFiles is a parsable slice of SoraCacheFile. +type SoraCacheFiles []*SoraCacheFile diff --git a/backend/ent/soracachefile/soracachefile.go b/backend/ent/soracachefile/soracachefile.go new file mode 100644 index 00000000..c39436a9 --- /dev/null +++ b/backend/ent/soracachefile/soracachefile.go @@ -0,0 +1,124 @@ +// Code generated by ent, DO NOT EDIT. + +package soracachefile + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the soracachefile type in the database. + Label = "sora_cache_file" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldTaskID holds the string denoting the task_id field in the database. 
+ FieldTaskID = "task_id" + // FieldAccountID holds the string denoting the account_id field in the database. + FieldAccountID = "account_id" + // FieldUserID holds the string denoting the user_id field in the database. + FieldUserID = "user_id" + // FieldMediaType holds the string denoting the media_type field in the database. + FieldMediaType = "media_type" + // FieldOriginalURL holds the string denoting the original_url field in the database. + FieldOriginalURL = "original_url" + // FieldCachePath holds the string denoting the cache_path field in the database. + FieldCachePath = "cache_path" + // FieldCacheURL holds the string denoting the cache_url field in the database. + FieldCacheURL = "cache_url" + // FieldSizeBytes holds the string denoting the size_bytes field in the database. + FieldSizeBytes = "size_bytes" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // Table holds the table name of the soracachefile in the database. + Table = "sora_cache_files" +) + +// Columns holds all SQL columns for soracachefile fields. +var Columns = []string{ + FieldID, + FieldTaskID, + FieldAccountID, + FieldUserID, + FieldMediaType, + FieldOriginalURL, + FieldCachePath, + FieldCacheURL, + FieldSizeBytes, + FieldCreatedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // TaskIDValidator is a validator for the "task_id" field. It is called by the builders before save. + TaskIDValidator func(string) error + // MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. + MediaTypeValidator func(string) error + // DefaultSizeBytes holds the default value on creation for the "size_bytes" field. 
+ DefaultSizeBytes int64 + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time +) + +// OrderOption defines the ordering options for the SoraCacheFile queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByTaskID orders the results by the task_id field. +func ByTaskID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTaskID, opts...).ToFunc() +} + +// ByAccountID orders the results by the account_id field. +func ByAccountID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAccountID, opts...).ToFunc() +} + +// ByUserID orders the results by the user_id field. +func ByUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUserID, opts...).ToFunc() +} + +// ByMediaType orders the results by the media_type field. +func ByMediaType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMediaType, opts...).ToFunc() +} + +// ByOriginalURL orders the results by the original_url field. +func ByOriginalURL(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldOriginalURL, opts...).ToFunc() +} + +// ByCachePath orders the results by the cache_path field. +func ByCachePath(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCachePath, opts...).ToFunc() +} + +// ByCacheURL orders the results by the cache_url field. +func ByCacheURL(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCacheURL, opts...).ToFunc() +} + +// BySizeBytes orders the results by the size_bytes field. +func BySizeBytes(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSizeBytes, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. 
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} diff --git a/backend/ent/soracachefile/where.go b/backend/ent/soracachefile/where.go new file mode 100644 index 00000000..a4d0ac93 --- /dev/null +++ b/backend/ent/soracachefile/where.go @@ -0,0 +1,610 @@ +// Code generated by ent, DO NOT EDIT. + +package soracachefile + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. 
+func IDLTE(id int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLTE(FieldID, id)) +} + +// TaskID applies equality check predicate on the "task_id" field. It's identical to TaskIDEQ. +func TaskID(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldTaskID, v)) +} + +// AccountID applies equality check predicate on the "account_id" field. It's identical to AccountIDEQ. +func AccountID(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldAccountID, v)) +} + +// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ. +func UserID(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldUserID, v)) +} + +// MediaType applies equality check predicate on the "media_type" field. It's identical to MediaTypeEQ. +func MediaType(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldMediaType, v)) +} + +// OriginalURL applies equality check predicate on the "original_url" field. It's identical to OriginalURLEQ. +func OriginalURL(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldOriginalURL, v)) +} + +// CachePath applies equality check predicate on the "cache_path" field. It's identical to CachePathEQ. +func CachePath(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldCachePath, v)) +} + +// CacheURL applies equality check predicate on the "cache_url" field. It's identical to CacheURLEQ. +func CacheURL(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldCacheURL, v)) +} + +// SizeBytes applies equality check predicate on the "size_bytes" field. It's identical to SizeBytesEQ. +func SizeBytes(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldSizeBytes, v)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. 
It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldCreatedAt, v)) +} + +// TaskIDEQ applies the EQ predicate on the "task_id" field. +func TaskIDEQ(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldTaskID, v)) +} + +// TaskIDNEQ applies the NEQ predicate on the "task_id" field. +func TaskIDNEQ(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNEQ(FieldTaskID, v)) +} + +// TaskIDIn applies the In predicate on the "task_id" field. +func TaskIDIn(vs ...string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldIn(FieldTaskID, vs...)) +} + +// TaskIDNotIn applies the NotIn predicate on the "task_id" field. +func TaskIDNotIn(vs ...string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNotIn(FieldTaskID, vs...)) +} + +// TaskIDGT applies the GT predicate on the "task_id" field. +func TaskIDGT(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGT(FieldTaskID, v)) +} + +// TaskIDGTE applies the GTE predicate on the "task_id" field. +func TaskIDGTE(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGTE(FieldTaskID, v)) +} + +// TaskIDLT applies the LT predicate on the "task_id" field. +func TaskIDLT(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLT(FieldTaskID, v)) +} + +// TaskIDLTE applies the LTE predicate on the "task_id" field. +func TaskIDLTE(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLTE(FieldTaskID, v)) +} + +// TaskIDContains applies the Contains predicate on the "task_id" field. +func TaskIDContains(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldContains(FieldTaskID, v)) +} + +// TaskIDHasPrefix applies the HasPrefix predicate on the "task_id" field. 
+func TaskIDHasPrefix(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldHasPrefix(FieldTaskID, v)) +} + +// TaskIDHasSuffix applies the HasSuffix predicate on the "task_id" field. +func TaskIDHasSuffix(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldHasSuffix(FieldTaskID, v)) +} + +// TaskIDIsNil applies the IsNil predicate on the "task_id" field. +func TaskIDIsNil() predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldIsNull(FieldTaskID)) +} + +// TaskIDNotNil applies the NotNil predicate on the "task_id" field. +func TaskIDNotNil() predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNotNull(FieldTaskID)) +} + +// TaskIDEqualFold applies the EqualFold predicate on the "task_id" field. +func TaskIDEqualFold(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEqualFold(FieldTaskID, v)) +} + +// TaskIDContainsFold applies the ContainsFold predicate on the "task_id" field. +func TaskIDContainsFold(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldContainsFold(FieldTaskID, v)) +} + +// AccountIDEQ applies the EQ predicate on the "account_id" field. +func AccountIDEQ(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldAccountID, v)) +} + +// AccountIDNEQ applies the NEQ predicate on the "account_id" field. +func AccountIDNEQ(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNEQ(FieldAccountID, v)) +} + +// AccountIDIn applies the In predicate on the "account_id" field. +func AccountIDIn(vs ...int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldIn(FieldAccountID, vs...)) +} + +// AccountIDNotIn applies the NotIn predicate on the "account_id" field. +func AccountIDNotIn(vs ...int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNotIn(FieldAccountID, vs...)) +} + +// AccountIDGT applies the GT predicate on the "account_id" field. 
+func AccountIDGT(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGT(FieldAccountID, v)) +} + +// AccountIDGTE applies the GTE predicate on the "account_id" field. +func AccountIDGTE(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGTE(FieldAccountID, v)) +} + +// AccountIDLT applies the LT predicate on the "account_id" field. +func AccountIDLT(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLT(FieldAccountID, v)) +} + +// AccountIDLTE applies the LTE predicate on the "account_id" field. +func AccountIDLTE(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLTE(FieldAccountID, v)) +} + +// UserIDEQ applies the EQ predicate on the "user_id" field. +func UserIDEQ(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldUserID, v)) +} + +// UserIDNEQ applies the NEQ predicate on the "user_id" field. +func UserIDNEQ(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNEQ(FieldUserID, v)) +} + +// UserIDIn applies the In predicate on the "user_id" field. +func UserIDIn(vs ...int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldIn(FieldUserID, vs...)) +} + +// UserIDNotIn applies the NotIn predicate on the "user_id" field. +func UserIDNotIn(vs ...int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNotIn(FieldUserID, vs...)) +} + +// UserIDGT applies the GT predicate on the "user_id" field. +func UserIDGT(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGT(FieldUserID, v)) +} + +// UserIDGTE applies the GTE predicate on the "user_id" field. +func UserIDGTE(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGTE(FieldUserID, v)) +} + +// UserIDLT applies the LT predicate on the "user_id" field. 
+func UserIDLT(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLT(FieldUserID, v)) +} + +// UserIDLTE applies the LTE predicate on the "user_id" field. +func UserIDLTE(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLTE(FieldUserID, v)) +} + +// MediaTypeEQ applies the EQ predicate on the "media_type" field. +func MediaTypeEQ(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldMediaType, v)) +} + +// MediaTypeNEQ applies the NEQ predicate on the "media_type" field. +func MediaTypeNEQ(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNEQ(FieldMediaType, v)) +} + +// MediaTypeIn applies the In predicate on the "media_type" field. +func MediaTypeIn(vs ...string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldIn(FieldMediaType, vs...)) +} + +// MediaTypeNotIn applies the NotIn predicate on the "media_type" field. +func MediaTypeNotIn(vs ...string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNotIn(FieldMediaType, vs...)) +} + +// MediaTypeGT applies the GT predicate on the "media_type" field. +func MediaTypeGT(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGT(FieldMediaType, v)) +} + +// MediaTypeGTE applies the GTE predicate on the "media_type" field. +func MediaTypeGTE(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGTE(FieldMediaType, v)) +} + +// MediaTypeLT applies the LT predicate on the "media_type" field. +func MediaTypeLT(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLT(FieldMediaType, v)) +} + +// MediaTypeLTE applies the LTE predicate on the "media_type" field. +func MediaTypeLTE(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLTE(FieldMediaType, v)) +} + +// MediaTypeContains applies the Contains predicate on the "media_type" field. 
+func MediaTypeContains(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldContains(FieldMediaType, v)) +} + +// MediaTypeHasPrefix applies the HasPrefix predicate on the "media_type" field. +func MediaTypeHasPrefix(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldHasPrefix(FieldMediaType, v)) +} + +// MediaTypeHasSuffix applies the HasSuffix predicate on the "media_type" field. +func MediaTypeHasSuffix(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldHasSuffix(FieldMediaType, v)) +} + +// MediaTypeEqualFold applies the EqualFold predicate on the "media_type" field. +func MediaTypeEqualFold(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEqualFold(FieldMediaType, v)) +} + +// MediaTypeContainsFold applies the ContainsFold predicate on the "media_type" field. +func MediaTypeContainsFold(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldContainsFold(FieldMediaType, v)) +} + +// OriginalURLEQ applies the EQ predicate on the "original_url" field. +func OriginalURLEQ(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldOriginalURL, v)) +} + +// OriginalURLNEQ applies the NEQ predicate on the "original_url" field. +func OriginalURLNEQ(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNEQ(FieldOriginalURL, v)) +} + +// OriginalURLIn applies the In predicate on the "original_url" field. +func OriginalURLIn(vs ...string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldIn(FieldOriginalURL, vs...)) +} + +// OriginalURLNotIn applies the NotIn predicate on the "original_url" field. +func OriginalURLNotIn(vs ...string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNotIn(FieldOriginalURL, vs...)) +} + +// OriginalURLGT applies the GT predicate on the "original_url" field. 
+func OriginalURLGT(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGT(FieldOriginalURL, v)) +} + +// OriginalURLGTE applies the GTE predicate on the "original_url" field. +func OriginalURLGTE(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGTE(FieldOriginalURL, v)) +} + +// OriginalURLLT applies the LT predicate on the "original_url" field. +func OriginalURLLT(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLT(FieldOriginalURL, v)) +} + +// OriginalURLLTE applies the LTE predicate on the "original_url" field. +func OriginalURLLTE(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLTE(FieldOriginalURL, v)) +} + +// OriginalURLContains applies the Contains predicate on the "original_url" field. +func OriginalURLContains(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldContains(FieldOriginalURL, v)) +} + +// OriginalURLHasPrefix applies the HasPrefix predicate on the "original_url" field. +func OriginalURLHasPrefix(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldHasPrefix(FieldOriginalURL, v)) +} + +// OriginalURLHasSuffix applies the HasSuffix predicate on the "original_url" field. +func OriginalURLHasSuffix(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldHasSuffix(FieldOriginalURL, v)) +} + +// OriginalURLEqualFold applies the EqualFold predicate on the "original_url" field. +func OriginalURLEqualFold(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEqualFold(FieldOriginalURL, v)) +} + +// OriginalURLContainsFold applies the ContainsFold predicate on the "original_url" field. +func OriginalURLContainsFold(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldContainsFold(FieldOriginalURL, v)) +} + +// CachePathEQ applies the EQ predicate on the "cache_path" field. 
+func CachePathEQ(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldCachePath, v)) +} + +// CachePathNEQ applies the NEQ predicate on the "cache_path" field. +func CachePathNEQ(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNEQ(FieldCachePath, v)) +} + +// CachePathIn applies the In predicate on the "cache_path" field. +func CachePathIn(vs ...string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldIn(FieldCachePath, vs...)) +} + +// CachePathNotIn applies the NotIn predicate on the "cache_path" field. +func CachePathNotIn(vs ...string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNotIn(FieldCachePath, vs...)) +} + +// CachePathGT applies the GT predicate on the "cache_path" field. +func CachePathGT(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGT(FieldCachePath, v)) +} + +// CachePathGTE applies the GTE predicate on the "cache_path" field. +func CachePathGTE(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGTE(FieldCachePath, v)) +} + +// CachePathLT applies the LT predicate on the "cache_path" field. +func CachePathLT(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLT(FieldCachePath, v)) +} + +// CachePathLTE applies the LTE predicate on the "cache_path" field. +func CachePathLTE(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLTE(FieldCachePath, v)) +} + +// CachePathContains applies the Contains predicate on the "cache_path" field. +func CachePathContains(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldContains(FieldCachePath, v)) +} + +// CachePathHasPrefix applies the HasPrefix predicate on the "cache_path" field. 
+func CachePathHasPrefix(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldHasPrefix(FieldCachePath, v)) +} + +// CachePathHasSuffix applies the HasSuffix predicate on the "cache_path" field. +func CachePathHasSuffix(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldHasSuffix(FieldCachePath, v)) +} + +// CachePathEqualFold applies the EqualFold predicate on the "cache_path" field. +func CachePathEqualFold(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEqualFold(FieldCachePath, v)) +} + +// CachePathContainsFold applies the ContainsFold predicate on the "cache_path" field. +func CachePathContainsFold(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldContainsFold(FieldCachePath, v)) +} + +// CacheURLEQ applies the EQ predicate on the "cache_url" field. +func CacheURLEQ(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldCacheURL, v)) +} + +// CacheURLNEQ applies the NEQ predicate on the "cache_url" field. +func CacheURLNEQ(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNEQ(FieldCacheURL, v)) +} + +// CacheURLIn applies the In predicate on the "cache_url" field. +func CacheURLIn(vs ...string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldIn(FieldCacheURL, vs...)) +} + +// CacheURLNotIn applies the NotIn predicate on the "cache_url" field. +func CacheURLNotIn(vs ...string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNotIn(FieldCacheURL, vs...)) +} + +// CacheURLGT applies the GT predicate on the "cache_url" field. +func CacheURLGT(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGT(FieldCacheURL, v)) +} + +// CacheURLGTE applies the GTE predicate on the "cache_url" field. 
+func CacheURLGTE(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGTE(FieldCacheURL, v)) +} + +// CacheURLLT applies the LT predicate on the "cache_url" field. +func CacheURLLT(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLT(FieldCacheURL, v)) +} + +// CacheURLLTE applies the LTE predicate on the "cache_url" field. +func CacheURLLTE(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLTE(FieldCacheURL, v)) +} + +// CacheURLContains applies the Contains predicate on the "cache_url" field. +func CacheURLContains(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldContains(FieldCacheURL, v)) +} + +// CacheURLHasPrefix applies the HasPrefix predicate on the "cache_url" field. +func CacheURLHasPrefix(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldHasPrefix(FieldCacheURL, v)) +} + +// CacheURLHasSuffix applies the HasSuffix predicate on the "cache_url" field. +func CacheURLHasSuffix(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldHasSuffix(FieldCacheURL, v)) +} + +// CacheURLEqualFold applies the EqualFold predicate on the "cache_url" field. +func CacheURLEqualFold(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEqualFold(FieldCacheURL, v)) +} + +// CacheURLContainsFold applies the ContainsFold predicate on the "cache_url" field. +func CacheURLContainsFold(v string) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldContainsFold(FieldCacheURL, v)) +} + +// SizeBytesEQ applies the EQ predicate on the "size_bytes" field. +func SizeBytesEQ(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldSizeBytes, v)) +} + +// SizeBytesNEQ applies the NEQ predicate on the "size_bytes" field. 
+func SizeBytesNEQ(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNEQ(FieldSizeBytes, v)) +} + +// SizeBytesIn applies the In predicate on the "size_bytes" field. +func SizeBytesIn(vs ...int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldIn(FieldSizeBytes, vs...)) +} + +// SizeBytesNotIn applies the NotIn predicate on the "size_bytes" field. +func SizeBytesNotIn(vs ...int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNotIn(FieldSizeBytes, vs...)) +} + +// SizeBytesGT applies the GT predicate on the "size_bytes" field. +func SizeBytesGT(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGT(FieldSizeBytes, v)) +} + +// SizeBytesGTE applies the GTE predicate on the "size_bytes" field. +func SizeBytesGTE(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGTE(FieldSizeBytes, v)) +} + +// SizeBytesLT applies the LT predicate on the "size_bytes" field. +func SizeBytesLT(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLT(FieldSizeBytes, v)) +} + +// SizeBytesLTE applies the LTE predicate on the "size_bytes" field. +func SizeBytesLTE(v int64) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLTE(FieldSizeBytes, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. 
+func CreatedAtNotIn(vs ...time.Time) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.FieldLTE(FieldCreatedAt, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.SoraCacheFile) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.SoraCacheFile) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.SoraCacheFile) predicate.SoraCacheFile { + return predicate.SoraCacheFile(sql.NotPredicates(p)) +} diff --git a/backend/ent/soracachefile_create.go b/backend/ent/soracachefile_create.go new file mode 100644 index 00000000..35e0b525 --- /dev/null +++ b/backend/ent/soracachefile_create.go @@ -0,0 +1,1004 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/soracachefile" +) + +// SoraCacheFileCreate is the builder for creating a SoraCacheFile entity. +type SoraCacheFileCreate struct { + config + mutation *SoraCacheFileMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetTaskID sets the "task_id" field. +func (_c *SoraCacheFileCreate) SetTaskID(v string) *SoraCacheFileCreate { + _c.mutation.SetTaskID(v) + return _c +} + +// SetNillableTaskID sets the "task_id" field if the given value is not nil. +func (_c *SoraCacheFileCreate) SetNillableTaskID(v *string) *SoraCacheFileCreate { + if v != nil { + _c.SetTaskID(*v) + } + return _c +} + +// SetAccountID sets the "account_id" field. +func (_c *SoraCacheFileCreate) SetAccountID(v int64) *SoraCacheFileCreate { + _c.mutation.SetAccountID(v) + return _c +} + +// SetUserID sets the "user_id" field. +func (_c *SoraCacheFileCreate) SetUserID(v int64) *SoraCacheFileCreate { + _c.mutation.SetUserID(v) + return _c +} + +// SetMediaType sets the "media_type" field. +func (_c *SoraCacheFileCreate) SetMediaType(v string) *SoraCacheFileCreate { + _c.mutation.SetMediaType(v) + return _c +} + +// SetOriginalURL sets the "original_url" field. +func (_c *SoraCacheFileCreate) SetOriginalURL(v string) *SoraCacheFileCreate { + _c.mutation.SetOriginalURL(v) + return _c +} + +// SetCachePath sets the "cache_path" field. +func (_c *SoraCacheFileCreate) SetCachePath(v string) *SoraCacheFileCreate { + _c.mutation.SetCachePath(v) + return _c +} + +// SetCacheURL sets the "cache_url" field. +func (_c *SoraCacheFileCreate) SetCacheURL(v string) *SoraCacheFileCreate { + _c.mutation.SetCacheURL(v) + return _c +} + +// SetSizeBytes sets the "size_bytes" field. 
+func (_c *SoraCacheFileCreate) SetSizeBytes(v int64) *SoraCacheFileCreate { + _c.mutation.SetSizeBytes(v) + return _c +} + +// SetNillableSizeBytes sets the "size_bytes" field if the given value is not nil. +func (_c *SoraCacheFileCreate) SetNillableSizeBytes(v *int64) *SoraCacheFileCreate { + if v != nil { + _c.SetSizeBytes(*v) + } + return _c +} + +// SetCreatedAt sets the "created_at" field. +func (_c *SoraCacheFileCreate) SetCreatedAt(v time.Time) *SoraCacheFileCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *SoraCacheFileCreate) SetNillableCreatedAt(v *time.Time) *SoraCacheFileCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// Mutation returns the SoraCacheFileMutation object of the builder. +func (_c *SoraCacheFileCreate) Mutation() *SoraCacheFileMutation { + return _c.mutation +} + +// Save creates the SoraCacheFile in the database. +func (_c *SoraCacheFileCreate) Save(ctx context.Context) (*SoraCacheFile, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *SoraCacheFileCreate) SaveX(ctx context.Context) *SoraCacheFile { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *SoraCacheFileCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *SoraCacheFileCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (_c *SoraCacheFileCreate) defaults() { + if _, ok := _c.mutation.SizeBytes(); !ok { + v := soracachefile.DefaultSizeBytes + _c.mutation.SetSizeBytes(v) + } + if _, ok := _c.mutation.CreatedAt(); !ok { + v := soracachefile.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *SoraCacheFileCreate) check() error { + if v, ok := _c.mutation.TaskID(); ok { + if err := soracachefile.TaskIDValidator(v); err != nil { + return &ValidationError{Name: "task_id", err: fmt.Errorf(`ent: validator failed for field "SoraCacheFile.task_id": %w`, err)} + } + } + if _, ok := _c.mutation.AccountID(); !ok { + return &ValidationError{Name: "account_id", err: errors.New(`ent: missing required field "SoraCacheFile.account_id"`)} + } + if _, ok := _c.mutation.UserID(); !ok { + return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "SoraCacheFile.user_id"`)} + } + if _, ok := _c.mutation.MediaType(); !ok { + return &ValidationError{Name: "media_type", err: errors.New(`ent: missing required field "SoraCacheFile.media_type"`)} + } + if v, ok := _c.mutation.MediaType(); ok { + if err := soracachefile.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "SoraCacheFile.media_type": %w`, err)} + } + } + if _, ok := _c.mutation.OriginalURL(); !ok { + return &ValidationError{Name: "original_url", err: errors.New(`ent: missing required field "SoraCacheFile.original_url"`)} + } + if _, ok := _c.mutation.CachePath(); !ok { + return &ValidationError{Name: "cache_path", err: errors.New(`ent: missing required field "SoraCacheFile.cache_path"`)} + } + if _, ok := _c.mutation.CacheURL(); !ok { + return &ValidationError{Name: "cache_url", err: errors.New(`ent: missing required field "SoraCacheFile.cache_url"`)} + } + if _, ok := _c.mutation.SizeBytes(); !ok { + return &ValidationError{Name: "size_bytes", err: 
errors.New(`ent: missing required field "SoraCacheFile.size_bytes"`)} + } + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "SoraCacheFile.created_at"`)} + } + return nil +} + +func (_c *SoraCacheFileCreate) sqlSave(ctx context.Context) (*SoraCacheFile, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *SoraCacheFileCreate) createSpec() (*SoraCacheFile, *sqlgraph.CreateSpec) { + var ( + _node = &SoraCacheFile{config: _c.config} + _spec = sqlgraph.NewCreateSpec(soracachefile.Table, sqlgraph.NewFieldSpec(soracachefile.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.TaskID(); ok { + _spec.SetField(soracachefile.FieldTaskID, field.TypeString, value) + _node.TaskID = &value + } + if value, ok := _c.mutation.AccountID(); ok { + _spec.SetField(soracachefile.FieldAccountID, field.TypeInt64, value) + _node.AccountID = value + } + if value, ok := _c.mutation.UserID(); ok { + _spec.SetField(soracachefile.FieldUserID, field.TypeInt64, value) + _node.UserID = value + } + if value, ok := _c.mutation.MediaType(); ok { + _spec.SetField(soracachefile.FieldMediaType, field.TypeString, value) + _node.MediaType = value + } + if value, ok := _c.mutation.OriginalURL(); ok { + _spec.SetField(soracachefile.FieldOriginalURL, field.TypeString, value) + _node.OriginalURL = value + } + if value, ok := _c.mutation.CachePath(); ok { + _spec.SetField(soracachefile.FieldCachePath, field.TypeString, value) + _node.CachePath = value + } + if value, ok := _c.mutation.CacheURL(); ok { + 
_spec.SetField(soracachefile.FieldCacheURL, field.TypeString, value) + _node.CacheURL = value + } + if value, ok := _c.mutation.SizeBytes(); ok { + _spec.SetField(soracachefile.FieldSizeBytes, field.TypeInt64, value) + _node.SizeBytes = value + } + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(soracachefile.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.SoraCacheFile.Create(). +// SetTaskID(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.SoraCacheFileUpsert) { +// SetTaskID(v+v). +// }). +// Exec(ctx) +func (_c *SoraCacheFileCreate) OnConflict(opts ...sql.ConflictOption) *SoraCacheFileUpsertOne { + _c.conflict = opts + return &SoraCacheFileUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.SoraCacheFile.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *SoraCacheFileCreate) OnConflictColumns(columns ...string) *SoraCacheFileUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &SoraCacheFileUpsertOne{ + create: _c, + } +} + +type ( + // SoraCacheFileUpsertOne is the builder for "upsert"-ing + // one SoraCacheFile node. + SoraCacheFileUpsertOne struct { + create *SoraCacheFileCreate + } + + // SoraCacheFileUpsert is the "OnConflict" setter. + SoraCacheFileUpsert struct { + *sql.UpdateSet + } +) + +// SetTaskID sets the "task_id" field. 
+func (u *SoraCacheFileUpsert) SetTaskID(v string) *SoraCacheFileUpsert { + u.Set(soracachefile.FieldTaskID, v) + return u +} + +// UpdateTaskID sets the "task_id" field to the value that was provided on create. +func (u *SoraCacheFileUpsert) UpdateTaskID() *SoraCacheFileUpsert { + u.SetExcluded(soracachefile.FieldTaskID) + return u +} + +// ClearTaskID clears the value of the "task_id" field. +func (u *SoraCacheFileUpsert) ClearTaskID() *SoraCacheFileUpsert { + u.SetNull(soracachefile.FieldTaskID) + return u +} + +// SetAccountID sets the "account_id" field. +func (u *SoraCacheFileUpsert) SetAccountID(v int64) *SoraCacheFileUpsert { + u.Set(soracachefile.FieldAccountID, v) + return u +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *SoraCacheFileUpsert) UpdateAccountID() *SoraCacheFileUpsert { + u.SetExcluded(soracachefile.FieldAccountID) + return u +} + +// AddAccountID adds v to the "account_id" field. +func (u *SoraCacheFileUpsert) AddAccountID(v int64) *SoraCacheFileUpsert { + u.Add(soracachefile.FieldAccountID, v) + return u +} + +// SetUserID sets the "user_id" field. +func (u *SoraCacheFileUpsert) SetUserID(v int64) *SoraCacheFileUpsert { + u.Set(soracachefile.FieldUserID, v) + return u +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *SoraCacheFileUpsert) UpdateUserID() *SoraCacheFileUpsert { + u.SetExcluded(soracachefile.FieldUserID) + return u +} + +// AddUserID adds v to the "user_id" field. +func (u *SoraCacheFileUpsert) AddUserID(v int64) *SoraCacheFileUpsert { + u.Add(soracachefile.FieldUserID, v) + return u +} + +// SetMediaType sets the "media_type" field. +func (u *SoraCacheFileUpsert) SetMediaType(v string) *SoraCacheFileUpsert { + u.Set(soracachefile.FieldMediaType, v) + return u +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. 
+func (u *SoraCacheFileUpsert) UpdateMediaType() *SoraCacheFileUpsert { + u.SetExcluded(soracachefile.FieldMediaType) + return u +} + +// SetOriginalURL sets the "original_url" field. +func (u *SoraCacheFileUpsert) SetOriginalURL(v string) *SoraCacheFileUpsert { + u.Set(soracachefile.FieldOriginalURL, v) + return u +} + +// UpdateOriginalURL sets the "original_url" field to the value that was provided on create. +func (u *SoraCacheFileUpsert) UpdateOriginalURL() *SoraCacheFileUpsert { + u.SetExcluded(soracachefile.FieldOriginalURL) + return u +} + +// SetCachePath sets the "cache_path" field. +func (u *SoraCacheFileUpsert) SetCachePath(v string) *SoraCacheFileUpsert { + u.Set(soracachefile.FieldCachePath, v) + return u +} + +// UpdateCachePath sets the "cache_path" field to the value that was provided on create. +func (u *SoraCacheFileUpsert) UpdateCachePath() *SoraCacheFileUpsert { + u.SetExcluded(soracachefile.FieldCachePath) + return u +} + +// SetCacheURL sets the "cache_url" field. +func (u *SoraCacheFileUpsert) SetCacheURL(v string) *SoraCacheFileUpsert { + u.Set(soracachefile.FieldCacheURL, v) + return u +} + +// UpdateCacheURL sets the "cache_url" field to the value that was provided on create. +func (u *SoraCacheFileUpsert) UpdateCacheURL() *SoraCacheFileUpsert { + u.SetExcluded(soracachefile.FieldCacheURL) + return u +} + +// SetSizeBytes sets the "size_bytes" field. +func (u *SoraCacheFileUpsert) SetSizeBytes(v int64) *SoraCacheFileUpsert { + u.Set(soracachefile.FieldSizeBytes, v) + return u +} + +// UpdateSizeBytes sets the "size_bytes" field to the value that was provided on create. +func (u *SoraCacheFileUpsert) UpdateSizeBytes() *SoraCacheFileUpsert { + u.SetExcluded(soracachefile.FieldSizeBytes) + return u +} + +// AddSizeBytes adds v to the "size_bytes" field. +func (u *SoraCacheFileUpsert) AddSizeBytes(v int64) *SoraCacheFileUpsert { + u.Add(soracachefile.FieldSizeBytes, v) + return u +} + +// SetCreatedAt sets the "created_at" field. 
+func (u *SoraCacheFileUpsert) SetCreatedAt(v time.Time) *SoraCacheFileUpsert { + u.Set(soracachefile.FieldCreatedAt, v) + return u +} + +// UpdateCreatedAt sets the "created_at" field to the value that was provided on create. +func (u *SoraCacheFileUpsert) UpdateCreatedAt() *SoraCacheFileUpsert { + u.SetExcluded(soracachefile.FieldCreatedAt) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.SoraCacheFile.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *SoraCacheFileUpsertOne) UpdateNewValues() *SoraCacheFileUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.SoraCacheFile.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *SoraCacheFileUpsertOne) Ignore() *SoraCacheFileUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *SoraCacheFileUpsertOne) DoNothing() *SoraCacheFileUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the SoraCacheFileCreate.OnConflict +// documentation for more info. +func (u *SoraCacheFileUpsertOne) Update(set func(*SoraCacheFileUpsert)) *SoraCacheFileUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&SoraCacheFileUpsert{UpdateSet: update}) + })) + return u +} + +// SetTaskID sets the "task_id" field. 
+func (u *SoraCacheFileUpsertOne) SetTaskID(v string) *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetTaskID(v) + }) +} + +// UpdateTaskID sets the "task_id" field to the value that was provided on create. +func (u *SoraCacheFileUpsertOne) UpdateTaskID() *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateTaskID() + }) +} + +// ClearTaskID clears the value of the "task_id" field. +func (u *SoraCacheFileUpsertOne) ClearTaskID() *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.ClearTaskID() + }) +} + +// SetAccountID sets the "account_id" field. +func (u *SoraCacheFileUpsertOne) SetAccountID(v int64) *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetAccountID(v) + }) +} + +// AddAccountID adds v to the "account_id" field. +func (u *SoraCacheFileUpsertOne) AddAccountID(v int64) *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.AddAccountID(v) + }) +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *SoraCacheFileUpsertOne) UpdateAccountID() *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateAccountID() + }) +} + +// SetUserID sets the "user_id" field. +func (u *SoraCacheFileUpsertOne) SetUserID(v int64) *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetUserID(v) + }) +} + +// AddUserID adds v to the "user_id" field. +func (u *SoraCacheFileUpsertOne) AddUserID(v int64) *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.AddUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *SoraCacheFileUpsertOne) UpdateUserID() *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateUserID() + }) +} + +// SetMediaType sets the "media_type" field. 
+func (u *SoraCacheFileUpsertOne) SetMediaType(v string) *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetMediaType(v) + }) +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. +func (u *SoraCacheFileUpsertOne) UpdateMediaType() *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateMediaType() + }) +} + +// SetOriginalURL sets the "original_url" field. +func (u *SoraCacheFileUpsertOne) SetOriginalURL(v string) *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetOriginalURL(v) + }) +} + +// UpdateOriginalURL sets the "original_url" field to the value that was provided on create. +func (u *SoraCacheFileUpsertOne) UpdateOriginalURL() *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateOriginalURL() + }) +} + +// SetCachePath sets the "cache_path" field. +func (u *SoraCacheFileUpsertOne) SetCachePath(v string) *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetCachePath(v) + }) +} + +// UpdateCachePath sets the "cache_path" field to the value that was provided on create. +func (u *SoraCacheFileUpsertOne) UpdateCachePath() *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateCachePath() + }) +} + +// SetCacheURL sets the "cache_url" field. +func (u *SoraCacheFileUpsertOne) SetCacheURL(v string) *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetCacheURL(v) + }) +} + +// UpdateCacheURL sets the "cache_url" field to the value that was provided on create. +func (u *SoraCacheFileUpsertOne) UpdateCacheURL() *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateCacheURL() + }) +} + +// SetSizeBytes sets the "size_bytes" field. 
+func (u *SoraCacheFileUpsertOne) SetSizeBytes(v int64) *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetSizeBytes(v) + }) +} + +// AddSizeBytes adds v to the "size_bytes" field. +func (u *SoraCacheFileUpsertOne) AddSizeBytes(v int64) *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.AddSizeBytes(v) + }) +} + +// UpdateSizeBytes sets the "size_bytes" field to the value that was provided on create. +func (u *SoraCacheFileUpsertOne) UpdateSizeBytes() *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateSizeBytes() + }) +} + +// SetCreatedAt sets the "created_at" field. +func (u *SoraCacheFileUpsertOne) SetCreatedAt(v time.Time) *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetCreatedAt(v) + }) +} + +// UpdateCreatedAt sets the "created_at" field to the value that was provided on create. +func (u *SoraCacheFileUpsertOne) UpdateCreatedAt() *SoraCacheFileUpsertOne { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateCreatedAt() + }) +} + +// Exec executes the query. +func (u *SoraCacheFileUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for SoraCacheFileCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *SoraCacheFileUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *SoraCacheFileUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. 
+func (u *SoraCacheFileUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// SoraCacheFileCreateBulk is the builder for creating many SoraCacheFile entities in bulk. +type SoraCacheFileCreateBulk struct { + config + err error + builders []*SoraCacheFileCreate + conflict []sql.ConflictOption +} + +// Save creates the SoraCacheFile entities in the database. +func (_c *SoraCacheFileCreateBulk) Save(ctx context.Context) ([]*SoraCacheFile, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*SoraCacheFile, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*SoraCacheFileMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. 
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *SoraCacheFileCreateBulk) SaveX(ctx context.Context) []*SoraCacheFile { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *SoraCacheFileCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *SoraCacheFileCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.SoraCacheFile.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.SoraCacheFileUpsert) { +// SetTaskID(v+v). +// }). +// Exec(ctx) +func (_c *SoraCacheFileCreateBulk) OnConflict(opts ...sql.ConflictOption) *SoraCacheFileUpsertBulk { + _c.conflict = opts + return &SoraCacheFileUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. 
Using this option is equivalent to using: +// +// client.SoraCacheFile.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *SoraCacheFileCreateBulk) OnConflictColumns(columns ...string) *SoraCacheFileUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &SoraCacheFileUpsertBulk{ + create: _c, + } +} + +// SoraCacheFileUpsertBulk is the builder for "upsert"-ing +// a bulk of SoraCacheFile nodes. +type SoraCacheFileUpsertBulk struct { + create *SoraCacheFileCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.SoraCacheFile.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *SoraCacheFileUpsertBulk) UpdateNewValues() *SoraCacheFileUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.SoraCacheFile.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *SoraCacheFileUpsertBulk) Ignore() *SoraCacheFileUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *SoraCacheFileUpsertBulk) DoNothing() *SoraCacheFileUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the SoraCacheFileCreateBulk.OnConflict +// documentation for more info. 
+func (u *SoraCacheFileUpsertBulk) Update(set func(*SoraCacheFileUpsert)) *SoraCacheFileUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&SoraCacheFileUpsert{UpdateSet: update}) + })) + return u +} + +// SetTaskID sets the "task_id" field. +func (u *SoraCacheFileUpsertBulk) SetTaskID(v string) *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetTaskID(v) + }) +} + +// UpdateTaskID sets the "task_id" field to the value that was provided on create. +func (u *SoraCacheFileUpsertBulk) UpdateTaskID() *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateTaskID() + }) +} + +// ClearTaskID clears the value of the "task_id" field. +func (u *SoraCacheFileUpsertBulk) ClearTaskID() *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.ClearTaskID() + }) +} + +// SetAccountID sets the "account_id" field. +func (u *SoraCacheFileUpsertBulk) SetAccountID(v int64) *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetAccountID(v) + }) +} + +// AddAccountID adds v to the "account_id" field. +func (u *SoraCacheFileUpsertBulk) AddAccountID(v int64) *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.AddAccountID(v) + }) +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *SoraCacheFileUpsertBulk) UpdateAccountID() *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateAccountID() + }) +} + +// SetUserID sets the "user_id" field. +func (u *SoraCacheFileUpsertBulk) SetUserID(v int64) *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetUserID(v) + }) +} + +// AddUserID adds v to the "user_id" field. 
+func (u *SoraCacheFileUpsertBulk) AddUserID(v int64) *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.AddUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *SoraCacheFileUpsertBulk) UpdateUserID() *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateUserID() + }) +} + +// SetMediaType sets the "media_type" field. +func (u *SoraCacheFileUpsertBulk) SetMediaType(v string) *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetMediaType(v) + }) +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. +func (u *SoraCacheFileUpsertBulk) UpdateMediaType() *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateMediaType() + }) +} + +// SetOriginalURL sets the "original_url" field. +func (u *SoraCacheFileUpsertBulk) SetOriginalURL(v string) *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetOriginalURL(v) + }) +} + +// UpdateOriginalURL sets the "original_url" field to the value that was provided on create. +func (u *SoraCacheFileUpsertBulk) UpdateOriginalURL() *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateOriginalURL() + }) +} + +// SetCachePath sets the "cache_path" field. +func (u *SoraCacheFileUpsertBulk) SetCachePath(v string) *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetCachePath(v) + }) +} + +// UpdateCachePath sets the "cache_path" field to the value that was provided on create. +func (u *SoraCacheFileUpsertBulk) UpdateCachePath() *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateCachePath() + }) +} + +// SetCacheURL sets the "cache_url" field. 
+func (u *SoraCacheFileUpsertBulk) SetCacheURL(v string) *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetCacheURL(v) + }) +} + +// UpdateCacheURL sets the "cache_url" field to the value that was provided on create. +func (u *SoraCacheFileUpsertBulk) UpdateCacheURL() *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateCacheURL() + }) +} + +// SetSizeBytes sets the "size_bytes" field. +func (u *SoraCacheFileUpsertBulk) SetSizeBytes(v int64) *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetSizeBytes(v) + }) +} + +// AddSizeBytes adds v to the "size_bytes" field. +func (u *SoraCacheFileUpsertBulk) AddSizeBytes(v int64) *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.AddSizeBytes(v) + }) +} + +// UpdateSizeBytes sets the "size_bytes" field to the value that was provided on create. +func (u *SoraCacheFileUpsertBulk) UpdateSizeBytes() *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateSizeBytes() + }) +} + +// SetCreatedAt sets the "created_at" field. +func (u *SoraCacheFileUpsertBulk) SetCreatedAt(v time.Time) *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.SetCreatedAt(v) + }) +} + +// UpdateCreatedAt sets the "created_at" field to the value that was provided on create. +func (u *SoraCacheFileUpsertBulk) UpdateCreatedAt() *SoraCacheFileUpsertBulk { + return u.Update(func(s *SoraCacheFileUpsert) { + s.UpdateCreatedAt() + }) +} + +// Exec executes the query. +func (u *SoraCacheFileUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. 
Set it on the SoraCacheFileCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for SoraCacheFileCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *SoraCacheFileUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/soracachefile_delete.go b/backend/ent/soracachefile_delete.go new file mode 100644 index 00000000..bbd18485 --- /dev/null +++ b/backend/ent/soracachefile_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/soracachefile" +) + +// SoraCacheFileDelete is the builder for deleting a SoraCacheFile entity. +type SoraCacheFileDelete struct { + config + hooks []Hook + mutation *SoraCacheFileMutation +} + +// Where appends a list predicates to the SoraCacheFileDelete builder. +func (_d *SoraCacheFileDelete) Where(ps ...predicate.SoraCacheFile) *SoraCacheFileDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *SoraCacheFileDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_d *SoraCacheFileDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *SoraCacheFileDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(soracachefile.Table, sqlgraph.NewFieldSpec(soracachefile.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// SoraCacheFileDeleteOne is the builder for deleting a single SoraCacheFile entity. +type SoraCacheFileDeleteOne struct { + _d *SoraCacheFileDelete +} + +// Where appends a list predicates to the SoraCacheFileDelete builder. +func (_d *SoraCacheFileDeleteOne) Where(ps ...predicate.SoraCacheFile) *SoraCacheFileDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *SoraCacheFileDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{soracachefile.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *SoraCacheFileDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/soracachefile_query.go b/backend/ent/soracachefile_query.go new file mode 100644 index 00000000..15d6d95e --- /dev/null +++ b/backend/ent/soracachefile_query.go @@ -0,0 +1,564 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/soracachefile" +) + +// SoraCacheFileQuery is the builder for querying SoraCacheFile entities. +type SoraCacheFileQuery struct { + config + ctx *QueryContext + order []soracachefile.OrderOption + inters []Interceptor + predicates []predicate.SoraCacheFile + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the SoraCacheFileQuery builder. +func (_q *SoraCacheFileQuery) Where(ps ...predicate.SoraCacheFile) *SoraCacheFileQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *SoraCacheFileQuery) Limit(limit int) *SoraCacheFileQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *SoraCacheFileQuery) Offset(offset int) *SoraCacheFileQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *SoraCacheFileQuery) Unique(unique bool) *SoraCacheFileQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *SoraCacheFileQuery) Order(o ...soracachefile.OrderOption) *SoraCacheFileQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first SoraCacheFile entity from the query. +// Returns a *NotFoundError when no SoraCacheFile was found. 
+func (_q *SoraCacheFileQuery) First(ctx context.Context) (*SoraCacheFile, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{soracachefile.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *SoraCacheFileQuery) FirstX(ctx context.Context) *SoraCacheFile { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first SoraCacheFile ID from the query. +// Returns a *NotFoundError when no SoraCacheFile ID was found. +func (_q *SoraCacheFileQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{soracachefile.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *SoraCacheFileQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single SoraCacheFile entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one SoraCacheFile entity is found. +// Returns a *NotFoundError when no SoraCacheFile entities are found. +func (_q *SoraCacheFileQuery) Only(ctx context.Context) (*SoraCacheFile, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{soracachefile.Label} + default: + return nil, &NotSingularError{soracachefile.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. 
+func (_q *SoraCacheFileQuery) OnlyX(ctx context.Context) *SoraCacheFile { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only SoraCacheFile ID in the query. +// Returns a *NotSingularError when more than one SoraCacheFile ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *SoraCacheFileQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{soracachefile.Label} + default: + err = &NotSingularError{soracachefile.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *SoraCacheFileQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of SoraCacheFiles. +func (_q *SoraCacheFileQuery) All(ctx context.Context) ([]*SoraCacheFile, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*SoraCacheFile, *SoraCacheFileQuery]() + return withInterceptors[[]*SoraCacheFile](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *SoraCacheFileQuery) AllX(ctx context.Context) []*SoraCacheFile { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of SoraCacheFile IDs. 
+func (_q *SoraCacheFileQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(soracachefile.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *SoraCacheFileQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *SoraCacheFileQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*SoraCacheFileQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *SoraCacheFileQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *SoraCacheFileQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *SoraCacheFileQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the SoraCacheFileQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. 
+func (_q *SoraCacheFileQuery) Clone() *SoraCacheFileQuery { + if _q == nil { + return nil + } + return &SoraCacheFileQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]soracachefile.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.SoraCacheFile{}, _q.predicates...), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// TaskID string `json:"task_id,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.SoraCacheFile.Query(). +// GroupBy(soracachefile.FieldTaskID). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *SoraCacheFileQuery) GroupBy(field string, fields ...string) *SoraCacheFileGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &SoraCacheFileGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = soracachefile.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// TaskID string `json:"task_id,omitempty"` +// } +// +// client.SoraCacheFile.Query(). +// Select(soracachefile.FieldTaskID). +// Scan(ctx, &v) +func (_q *SoraCacheFileQuery) Select(fields ...string) *SoraCacheFileSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &SoraCacheFileSelect{SoraCacheFileQuery: _q} + sbuild.label = soracachefile.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a SoraCacheFileSelect configured with the given aggregations. +func (_q *SoraCacheFileQuery) Aggregate(fns ...AggregateFunc) *SoraCacheFileSelect { + return _q.Select().Aggregate(fns...) 
+} + +func (_q *SoraCacheFileQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !soracachefile.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *SoraCacheFileQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*SoraCacheFile, error) { + var ( + nodes = []*SoraCacheFile{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*SoraCacheFile).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &SoraCacheFile{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (_q *SoraCacheFileQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *SoraCacheFileQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(soracachefile.Table, soracachefile.Columns, sqlgraph.NewFieldSpec(soracachefile.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; 
unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, soracachefile.FieldID) + for i := range fields { + if fields[i] != soracachefile.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *SoraCacheFileQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(soracachefile.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = soracachefile.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... 
for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *SoraCacheFileQuery) ForUpdate(opts ...sql.LockOption) *SoraCacheFileQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *SoraCacheFileQuery) ForShare(opts ...sql.LockOption) *SoraCacheFileQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// SoraCacheFileGroupBy is the group-by builder for SoraCacheFile entities. +type SoraCacheFileGroupBy struct { + selector + build *SoraCacheFileQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *SoraCacheFileGroupBy) Aggregate(fns ...AggregateFunc) *SoraCacheFileGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_g *SoraCacheFileGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*SoraCacheFileQuery, *SoraCacheFileGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *SoraCacheFileGroupBy) sqlScan(ctx context.Context, root *SoraCacheFileQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// SoraCacheFileSelect is the builder for selecting fields of SoraCacheFile entities. +type SoraCacheFileSelect struct { + *SoraCacheFileQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *SoraCacheFileSelect) Aggregate(fns ...AggregateFunc) *SoraCacheFileSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_s *SoraCacheFileSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*SoraCacheFileQuery, *SoraCacheFileSelect](ctx, _s.SoraCacheFileQuery, _s, _s.inters, v) +} + +func (_s *SoraCacheFileSelect) sqlScan(ctx context.Context, root *SoraCacheFileQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/soracachefile_update.go b/backend/ent/soracachefile_update.go new file mode 100644 index 00000000..44430f76 --- /dev/null +++ b/backend/ent/soracachefile_update.go @@ -0,0 +1,596 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/soracachefile" +) + +// SoraCacheFileUpdate is the builder for updating SoraCacheFile entities. +type SoraCacheFileUpdate struct { + config + hooks []Hook + mutation *SoraCacheFileMutation +} + +// Where appends a list predicates to the SoraCacheFileUpdate builder. +func (_u *SoraCacheFileUpdate) Where(ps ...predicate.SoraCacheFile) *SoraCacheFileUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetTaskID sets the "task_id" field. 
+func (_u *SoraCacheFileUpdate) SetTaskID(v string) *SoraCacheFileUpdate { + _u.mutation.SetTaskID(v) + return _u +} + +// SetNillableTaskID sets the "task_id" field if the given value is not nil. +func (_u *SoraCacheFileUpdate) SetNillableTaskID(v *string) *SoraCacheFileUpdate { + if v != nil { + _u.SetTaskID(*v) + } + return _u +} + +// ClearTaskID clears the value of the "task_id" field. +func (_u *SoraCacheFileUpdate) ClearTaskID() *SoraCacheFileUpdate { + _u.mutation.ClearTaskID() + return _u +} + +// SetAccountID sets the "account_id" field. +func (_u *SoraCacheFileUpdate) SetAccountID(v int64) *SoraCacheFileUpdate { + _u.mutation.ResetAccountID() + _u.mutation.SetAccountID(v) + return _u +} + +// SetNillableAccountID sets the "account_id" field if the given value is not nil. +func (_u *SoraCacheFileUpdate) SetNillableAccountID(v *int64) *SoraCacheFileUpdate { + if v != nil { + _u.SetAccountID(*v) + } + return _u +} + +// AddAccountID adds value to the "account_id" field. +func (_u *SoraCacheFileUpdate) AddAccountID(v int64) *SoraCacheFileUpdate { + _u.mutation.AddAccountID(v) + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *SoraCacheFileUpdate) SetUserID(v int64) *SoraCacheFileUpdate { + _u.mutation.ResetUserID() + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *SoraCacheFileUpdate) SetNillableUserID(v *int64) *SoraCacheFileUpdate { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// AddUserID adds value to the "user_id" field. +func (_u *SoraCacheFileUpdate) AddUserID(v int64) *SoraCacheFileUpdate { + _u.mutation.AddUserID(v) + return _u +} + +// SetMediaType sets the "media_type" field. +func (_u *SoraCacheFileUpdate) SetMediaType(v string) *SoraCacheFileUpdate { + _u.mutation.SetMediaType(v) + return _u +} + +// SetNillableMediaType sets the "media_type" field if the given value is not nil. 
+func (_u *SoraCacheFileUpdate) SetNillableMediaType(v *string) *SoraCacheFileUpdate { + if v != nil { + _u.SetMediaType(*v) + } + return _u +} + +// SetOriginalURL sets the "original_url" field. +func (_u *SoraCacheFileUpdate) SetOriginalURL(v string) *SoraCacheFileUpdate { + _u.mutation.SetOriginalURL(v) + return _u +} + +// SetNillableOriginalURL sets the "original_url" field if the given value is not nil. +func (_u *SoraCacheFileUpdate) SetNillableOriginalURL(v *string) *SoraCacheFileUpdate { + if v != nil { + _u.SetOriginalURL(*v) + } + return _u +} + +// SetCachePath sets the "cache_path" field. +func (_u *SoraCacheFileUpdate) SetCachePath(v string) *SoraCacheFileUpdate { + _u.mutation.SetCachePath(v) + return _u +} + +// SetNillableCachePath sets the "cache_path" field if the given value is not nil. +func (_u *SoraCacheFileUpdate) SetNillableCachePath(v *string) *SoraCacheFileUpdate { + if v != nil { + _u.SetCachePath(*v) + } + return _u +} + +// SetCacheURL sets the "cache_url" field. +func (_u *SoraCacheFileUpdate) SetCacheURL(v string) *SoraCacheFileUpdate { + _u.mutation.SetCacheURL(v) + return _u +} + +// SetNillableCacheURL sets the "cache_url" field if the given value is not nil. +func (_u *SoraCacheFileUpdate) SetNillableCacheURL(v *string) *SoraCacheFileUpdate { + if v != nil { + _u.SetCacheURL(*v) + } + return _u +} + +// SetSizeBytes sets the "size_bytes" field. +func (_u *SoraCacheFileUpdate) SetSizeBytes(v int64) *SoraCacheFileUpdate { + _u.mutation.ResetSizeBytes() + _u.mutation.SetSizeBytes(v) + return _u +} + +// SetNillableSizeBytes sets the "size_bytes" field if the given value is not nil. +func (_u *SoraCacheFileUpdate) SetNillableSizeBytes(v *int64) *SoraCacheFileUpdate { + if v != nil { + _u.SetSizeBytes(*v) + } + return _u +} + +// AddSizeBytes adds value to the "size_bytes" field. 
+func (_u *SoraCacheFileUpdate) AddSizeBytes(v int64) *SoraCacheFileUpdate { + _u.mutation.AddSizeBytes(v) + return _u +} + +// SetCreatedAt sets the "created_at" field. +func (_u *SoraCacheFileUpdate) SetCreatedAt(v time.Time) *SoraCacheFileUpdate { + _u.mutation.SetCreatedAt(v) + return _u +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_u *SoraCacheFileUpdate) SetNillableCreatedAt(v *time.Time) *SoraCacheFileUpdate { + if v != nil { + _u.SetCreatedAt(*v) + } + return _u +} + +// Mutation returns the SoraCacheFileMutation object of the builder. +func (_u *SoraCacheFileUpdate) Mutation() *SoraCacheFileMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *SoraCacheFileUpdate) Save(ctx context.Context) (int, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *SoraCacheFileUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *SoraCacheFileUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *SoraCacheFileUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_u *SoraCacheFileUpdate) check() error { + if v, ok := _u.mutation.TaskID(); ok { + if err := soracachefile.TaskIDValidator(v); err != nil { + return &ValidationError{Name: "task_id", err: fmt.Errorf(`ent: validator failed for field "SoraCacheFile.task_id": %w`, err)} + } + } + if v, ok := _u.mutation.MediaType(); ok { + if err := soracachefile.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "SoraCacheFile.media_type": %w`, err)} + } + } + return nil +} + +func (_u *SoraCacheFileUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(soracachefile.Table, soracachefile.Columns, sqlgraph.NewFieldSpec(soracachefile.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.TaskID(); ok { + _spec.SetField(soracachefile.FieldTaskID, field.TypeString, value) + } + if _u.mutation.TaskIDCleared() { + _spec.ClearField(soracachefile.FieldTaskID, field.TypeString) + } + if value, ok := _u.mutation.AccountID(); ok { + _spec.SetField(soracachefile.FieldAccountID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedAccountID(); ok { + _spec.AddField(soracachefile.FieldAccountID, field.TypeInt64, value) + } + if value, ok := _u.mutation.UserID(); ok { + _spec.SetField(soracachefile.FieldUserID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedUserID(); ok { + _spec.AddField(soracachefile.FieldUserID, field.TypeInt64, value) + } + if value, ok := _u.mutation.MediaType(); ok { + _spec.SetField(soracachefile.FieldMediaType, field.TypeString, value) + } + if value, ok := _u.mutation.OriginalURL(); ok { + _spec.SetField(soracachefile.FieldOriginalURL, field.TypeString, value) + } + if value, ok := _u.mutation.CachePath(); ok 
{ + _spec.SetField(soracachefile.FieldCachePath, field.TypeString, value) + } + if value, ok := _u.mutation.CacheURL(); ok { + _spec.SetField(soracachefile.FieldCacheURL, field.TypeString, value) + } + if value, ok := _u.mutation.SizeBytes(); ok { + _spec.SetField(soracachefile.FieldSizeBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSizeBytes(); ok { + _spec.AddField(soracachefile.FieldSizeBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.CreatedAt(); ok { + _spec.SetField(soracachefile.FieldCreatedAt, field.TypeTime, value) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{soracachefile.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// SoraCacheFileUpdateOne is the builder for updating a single SoraCacheFile entity. +type SoraCacheFileUpdateOne struct { + config + fields []string + hooks []Hook + mutation *SoraCacheFileMutation +} + +// SetTaskID sets the "task_id" field. +func (_u *SoraCacheFileUpdateOne) SetTaskID(v string) *SoraCacheFileUpdateOne { + _u.mutation.SetTaskID(v) + return _u +} + +// SetNillableTaskID sets the "task_id" field if the given value is not nil. +func (_u *SoraCacheFileUpdateOne) SetNillableTaskID(v *string) *SoraCacheFileUpdateOne { + if v != nil { + _u.SetTaskID(*v) + } + return _u +} + +// ClearTaskID clears the value of the "task_id" field. +func (_u *SoraCacheFileUpdateOne) ClearTaskID() *SoraCacheFileUpdateOne { + _u.mutation.ClearTaskID() + return _u +} + +// SetAccountID sets the "account_id" field. +func (_u *SoraCacheFileUpdateOne) SetAccountID(v int64) *SoraCacheFileUpdateOne { + _u.mutation.ResetAccountID() + _u.mutation.SetAccountID(v) + return _u +} + +// SetNillableAccountID sets the "account_id" field if the given value is not nil. 
+func (_u *SoraCacheFileUpdateOne) SetNillableAccountID(v *int64) *SoraCacheFileUpdateOne { + if v != nil { + _u.SetAccountID(*v) + } + return _u +} + +// AddAccountID adds value to the "account_id" field. +func (_u *SoraCacheFileUpdateOne) AddAccountID(v int64) *SoraCacheFileUpdateOne { + _u.mutation.AddAccountID(v) + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *SoraCacheFileUpdateOne) SetUserID(v int64) *SoraCacheFileUpdateOne { + _u.mutation.ResetUserID() + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *SoraCacheFileUpdateOne) SetNillableUserID(v *int64) *SoraCacheFileUpdateOne { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// AddUserID adds value to the "user_id" field. +func (_u *SoraCacheFileUpdateOne) AddUserID(v int64) *SoraCacheFileUpdateOne { + _u.mutation.AddUserID(v) + return _u +} + +// SetMediaType sets the "media_type" field. +func (_u *SoraCacheFileUpdateOne) SetMediaType(v string) *SoraCacheFileUpdateOne { + _u.mutation.SetMediaType(v) + return _u +} + +// SetNillableMediaType sets the "media_type" field if the given value is not nil. +func (_u *SoraCacheFileUpdateOne) SetNillableMediaType(v *string) *SoraCacheFileUpdateOne { + if v != nil { + _u.SetMediaType(*v) + } + return _u +} + +// SetOriginalURL sets the "original_url" field. +func (_u *SoraCacheFileUpdateOne) SetOriginalURL(v string) *SoraCacheFileUpdateOne { + _u.mutation.SetOriginalURL(v) + return _u +} + +// SetNillableOriginalURL sets the "original_url" field if the given value is not nil. +func (_u *SoraCacheFileUpdateOne) SetNillableOriginalURL(v *string) *SoraCacheFileUpdateOne { + if v != nil { + _u.SetOriginalURL(*v) + } + return _u +} + +// SetCachePath sets the "cache_path" field. 
+func (_u *SoraCacheFileUpdateOne) SetCachePath(v string) *SoraCacheFileUpdateOne { + _u.mutation.SetCachePath(v) + return _u +} + +// SetNillableCachePath sets the "cache_path" field if the given value is not nil. +func (_u *SoraCacheFileUpdateOne) SetNillableCachePath(v *string) *SoraCacheFileUpdateOne { + if v != nil { + _u.SetCachePath(*v) + } + return _u +} + +// SetCacheURL sets the "cache_url" field. +func (_u *SoraCacheFileUpdateOne) SetCacheURL(v string) *SoraCacheFileUpdateOne { + _u.mutation.SetCacheURL(v) + return _u +} + +// SetNillableCacheURL sets the "cache_url" field if the given value is not nil. +func (_u *SoraCacheFileUpdateOne) SetNillableCacheURL(v *string) *SoraCacheFileUpdateOne { + if v != nil { + _u.SetCacheURL(*v) + } + return _u +} + +// SetSizeBytes sets the "size_bytes" field. +func (_u *SoraCacheFileUpdateOne) SetSizeBytes(v int64) *SoraCacheFileUpdateOne { + _u.mutation.ResetSizeBytes() + _u.mutation.SetSizeBytes(v) + return _u +} + +// SetNillableSizeBytes sets the "size_bytes" field if the given value is not nil. +func (_u *SoraCacheFileUpdateOne) SetNillableSizeBytes(v *int64) *SoraCacheFileUpdateOne { + if v != nil { + _u.SetSizeBytes(*v) + } + return _u +} + +// AddSizeBytes adds value to the "size_bytes" field. +func (_u *SoraCacheFileUpdateOne) AddSizeBytes(v int64) *SoraCacheFileUpdateOne { + _u.mutation.AddSizeBytes(v) + return _u +} + +// SetCreatedAt sets the "created_at" field. +func (_u *SoraCacheFileUpdateOne) SetCreatedAt(v time.Time) *SoraCacheFileUpdateOne { + _u.mutation.SetCreatedAt(v) + return _u +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_u *SoraCacheFileUpdateOne) SetNillableCreatedAt(v *time.Time) *SoraCacheFileUpdateOne { + if v != nil { + _u.SetCreatedAt(*v) + } + return _u +} + +// Mutation returns the SoraCacheFileMutation object of the builder. 
+func (_u *SoraCacheFileUpdateOne) Mutation() *SoraCacheFileMutation { + return _u.mutation +} + +// Where appends a list predicates to the SoraCacheFileUpdate builder. +func (_u *SoraCacheFileUpdateOne) Where(ps ...predicate.SoraCacheFile) *SoraCacheFileUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *SoraCacheFileUpdateOne) Select(field string, fields ...string) *SoraCacheFileUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated SoraCacheFile entity. +func (_u *SoraCacheFileUpdateOne) Save(ctx context.Context) (*SoraCacheFile, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *SoraCacheFileUpdateOne) SaveX(ctx context.Context) *SoraCacheFile { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *SoraCacheFileUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *SoraCacheFileUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_u *SoraCacheFileUpdateOne) check() error { + if v, ok := _u.mutation.TaskID(); ok { + if err := soracachefile.TaskIDValidator(v); err != nil { + return &ValidationError{Name: "task_id", err: fmt.Errorf(`ent: validator failed for field "SoraCacheFile.task_id": %w`, err)} + } + } + if v, ok := _u.mutation.MediaType(); ok { + if err := soracachefile.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "SoraCacheFile.media_type": %w`, err)} + } + } + return nil +} + +func (_u *SoraCacheFileUpdateOne) sqlSave(ctx context.Context) (_node *SoraCacheFile, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(soracachefile.Table, soracachefile.Columns, sqlgraph.NewFieldSpec(soracachefile.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "SoraCacheFile.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, soracachefile.FieldID) + for _, f := range fields { + if !soracachefile.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != soracachefile.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.TaskID(); ok { + _spec.SetField(soracachefile.FieldTaskID, field.TypeString, value) + } + if _u.mutation.TaskIDCleared() { + _spec.ClearField(soracachefile.FieldTaskID, field.TypeString) + } + if value, ok := _u.mutation.AccountID(); ok { + _spec.SetField(soracachefile.FieldAccountID, field.TypeInt64, value) + } + if value, ok := 
_u.mutation.AddedAccountID(); ok { + _spec.AddField(soracachefile.FieldAccountID, field.TypeInt64, value) + } + if value, ok := _u.mutation.UserID(); ok { + _spec.SetField(soracachefile.FieldUserID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedUserID(); ok { + _spec.AddField(soracachefile.FieldUserID, field.TypeInt64, value) + } + if value, ok := _u.mutation.MediaType(); ok { + _spec.SetField(soracachefile.FieldMediaType, field.TypeString, value) + } + if value, ok := _u.mutation.OriginalURL(); ok { + _spec.SetField(soracachefile.FieldOriginalURL, field.TypeString, value) + } + if value, ok := _u.mutation.CachePath(); ok { + _spec.SetField(soracachefile.FieldCachePath, field.TypeString, value) + } + if value, ok := _u.mutation.CacheURL(); ok { + _spec.SetField(soracachefile.FieldCacheURL, field.TypeString, value) + } + if value, ok := _u.mutation.SizeBytes(); ok { + _spec.SetField(soracachefile.FieldSizeBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSizeBytes(); ok { + _spec.AddField(soracachefile.FieldSizeBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.CreatedAt(); ok { + _spec.SetField(soracachefile.FieldCreatedAt, field.TypeTime, value) + } + _node = &SoraCacheFile{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{soracachefile.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/soratask.go b/backend/ent/soratask.go new file mode 100644 index 00000000..806badf1 --- /dev/null +++ b/backend/ent/soratask.go @@ -0,0 +1,227 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/soratask" +) + +// SoraTask is the model entity for the SoraTask schema. +type SoraTask struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // TaskID holds the value of the "task_id" field. + TaskID string `json:"task_id,omitempty"` + // AccountID holds the value of the "account_id" field. + AccountID int64 `json:"account_id,omitempty"` + // Model holds the value of the "model" field. + Model string `json:"model,omitempty"` + // Prompt holds the value of the "prompt" field. + Prompt string `json:"prompt,omitempty"` + // Status holds the value of the "status" field. + Status string `json:"status,omitempty"` + // Progress holds the value of the "progress" field. + Progress float64 `json:"progress,omitempty"` + // ResultUrls holds the value of the "result_urls" field. + ResultUrls *string `json:"result_urls,omitempty"` + // ErrorMessage holds the value of the "error_message" field. + ErrorMessage *string `json:"error_message,omitempty"` + // RetryCount holds the value of the "retry_count" field. + RetryCount int `json:"retry_count,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // CompletedAt holds the value of the "completed_at" field. + CompletedAt *time.Time `json:"completed_at,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. 
+func (*SoraTask) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case soratask.FieldProgress: + values[i] = new(sql.NullFloat64) + case soratask.FieldID, soratask.FieldAccountID, soratask.FieldRetryCount: + values[i] = new(sql.NullInt64) + case soratask.FieldTaskID, soratask.FieldModel, soratask.FieldPrompt, soratask.FieldStatus, soratask.FieldResultUrls, soratask.FieldErrorMessage: + values[i] = new(sql.NullString) + case soratask.FieldCreatedAt, soratask.FieldCompletedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the SoraTask fields. +func (_m *SoraTask) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case soratask.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case soratask.FieldTaskID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field task_id", values[i]) + } else if value.Valid { + _m.TaskID = value.String + } + case soratask.FieldAccountID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field account_id", values[i]) + } else if value.Valid { + _m.AccountID = value.Int64 + } + case soratask.FieldModel: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field model", values[i]) + } else if value.Valid { + _m.Model = value.String + } + case soratask.FieldPrompt: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field prompt", values[i]) + } 
else if value.Valid { + _m.Prompt = value.String + } + case soratask.FieldStatus: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field status", values[i]) + } else if value.Valid { + _m.Status = value.String + } + case soratask.FieldProgress: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field progress", values[i]) + } else if value.Valid { + _m.Progress = value.Float64 + } + case soratask.FieldResultUrls: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field result_urls", values[i]) + } else if value.Valid { + _m.ResultUrls = new(string) + *_m.ResultUrls = value.String + } + case soratask.FieldErrorMessage: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field error_message", values[i]) + } else if value.Valid { + _m.ErrorMessage = new(string) + *_m.ErrorMessage = value.String + } + case soratask.FieldRetryCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field retry_count", values[i]) + } else if value.Valid { + _m.RetryCount = int(value.Int64) + } + case soratask.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case soratask.FieldCompletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field completed_at", values[i]) + } else if value.Valid { + _m.CompletedAt = new(time.Time) + *_m.CompletedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the SoraTask. +// This includes values selected through modifiers, order, etc. 
+func (_m *SoraTask) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this SoraTask. +// Note that you need to call SoraTask.Unwrap() before calling this method if this SoraTask +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *SoraTask) Update() *SoraTaskUpdateOne { + return NewSoraTaskClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the SoraTask entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *SoraTask) Unwrap() *SoraTask { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: SoraTask is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *SoraTask) String() string { + var builder strings.Builder + builder.WriteString("SoraTask(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("task_id=") + builder.WriteString(_m.TaskID) + builder.WriteString(", ") + builder.WriteString("account_id=") + builder.WriteString(fmt.Sprintf("%v", _m.AccountID)) + builder.WriteString(", ") + builder.WriteString("model=") + builder.WriteString(_m.Model) + builder.WriteString(", ") + builder.WriteString("prompt=") + builder.WriteString(_m.Prompt) + builder.WriteString(", ") + builder.WriteString("status=") + builder.WriteString(_m.Status) + builder.WriteString(", ") + builder.WriteString("progress=") + builder.WriteString(fmt.Sprintf("%v", _m.Progress)) + builder.WriteString(", ") + if v := _m.ResultUrls; v != nil { + builder.WriteString("result_urls=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.ErrorMessage; v != nil { + builder.WriteString("error_message=") + builder.WriteString(*v) + } + builder.WriteString(", ") + builder.WriteString("retry_count=") + builder.WriteString(fmt.Sprintf("%v", 
_m.RetryCount)) + builder.WriteString(", ") + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := _m.CompletedAt; v != nil { + builder.WriteString("completed_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteByte(')') + return builder.String() +} + +// SoraTasks is a parsable slice of SoraTask. +type SoraTasks []*SoraTask diff --git a/backend/ent/soratask/soratask.go b/backend/ent/soratask/soratask.go new file mode 100644 index 00000000..fc4e894b --- /dev/null +++ b/backend/ent/soratask/soratask.go @@ -0,0 +1,146 @@ +// Code generated by ent, DO NOT EDIT. + +package soratask + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the soratask type in the database. + Label = "sora_task" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldTaskID holds the string denoting the task_id field in the database. + FieldTaskID = "task_id" + // FieldAccountID holds the string denoting the account_id field in the database. + FieldAccountID = "account_id" + // FieldModel holds the string denoting the model field in the database. + FieldModel = "model" + // FieldPrompt holds the string denoting the prompt field in the database. + FieldPrompt = "prompt" + // FieldStatus holds the string denoting the status field in the database. + FieldStatus = "status" + // FieldProgress holds the string denoting the progress field in the database. + FieldProgress = "progress" + // FieldResultUrls holds the string denoting the result_urls field in the database. + FieldResultUrls = "result_urls" + // FieldErrorMessage holds the string denoting the error_message field in the database. + FieldErrorMessage = "error_message" + // FieldRetryCount holds the string denoting the retry_count field in the database. 
+ FieldRetryCount = "retry_count" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldCompletedAt holds the string denoting the completed_at field in the database. + FieldCompletedAt = "completed_at" + // Table holds the table name of the soratask in the database. + Table = "sora_tasks" +) + +// Columns holds all SQL columns for soratask fields. +var Columns = []string{ + FieldID, + FieldTaskID, + FieldAccountID, + FieldModel, + FieldPrompt, + FieldStatus, + FieldProgress, + FieldResultUrls, + FieldErrorMessage, + FieldRetryCount, + FieldCreatedAt, + FieldCompletedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // TaskIDValidator is a validator for the "task_id" field. It is called by the builders before save. + TaskIDValidator func(string) error + // ModelValidator is a validator for the "model" field. It is called by the builders before save. + ModelValidator func(string) error + // DefaultStatus holds the default value on creation for the "status" field. + DefaultStatus string + // StatusValidator is a validator for the "status" field. It is called by the builders before save. + StatusValidator func(string) error + // DefaultProgress holds the default value on creation for the "progress" field. + DefaultProgress float64 + // DefaultRetryCount holds the default value on creation for the "retry_count" field. + DefaultRetryCount int + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time +) + +// OrderOption defines the ordering options for the SoraTask queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. 
+func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByTaskID orders the results by the task_id field. +func ByTaskID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTaskID, opts...).ToFunc() +} + +// ByAccountID orders the results by the account_id field. +func ByAccountID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAccountID, opts...).ToFunc() +} + +// ByModel orders the results by the model field. +func ByModel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldModel, opts...).ToFunc() +} + +// ByPrompt orders the results by the prompt field. +func ByPrompt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPrompt, opts...).ToFunc() +} + +// ByStatus orders the results by the status field. +func ByStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatus, opts...).ToFunc() +} + +// ByProgress orders the results by the progress field. +func ByProgress(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProgress, opts...).ToFunc() +} + +// ByResultUrls orders the results by the result_urls field. +func ByResultUrls(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResultUrls, opts...).ToFunc() +} + +// ByErrorMessage orders the results by the error_message field. +func ByErrorMessage(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldErrorMessage, opts...).ToFunc() +} + +// ByRetryCount orders the results by the retry_count field. +func ByRetryCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRetryCount, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByCompletedAt orders the results by the completed_at field. 
+func ByCompletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCompletedAt, opts...).ToFunc() +} diff --git a/backend/ent/soratask/where.go b/backend/ent/soratask/where.go new file mode 100644 index 00000000..2d52c6dd --- /dev/null +++ b/backend/ent/soratask/where.go @@ -0,0 +1,745 @@ +// Code generated by ent, DO NOT EDIT. + +package soratask + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLTE(FieldID, id)) +} + +// TaskID applies equality check predicate on the "task_id" field. It's identical to TaskIDEQ. 
+func TaskID(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldTaskID, v)) +} + +// AccountID applies equality check predicate on the "account_id" field. It's identical to AccountIDEQ. +func AccountID(v int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldAccountID, v)) +} + +// Model applies equality check predicate on the "model" field. It's identical to ModelEQ. +func Model(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldModel, v)) +} + +// Prompt applies equality check predicate on the "prompt" field. It's identical to PromptEQ. +func Prompt(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldPrompt, v)) +} + +// Status applies equality check predicate on the "status" field. It's identical to StatusEQ. +func Status(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldStatus, v)) +} + +// Progress applies equality check predicate on the "progress" field. It's identical to ProgressEQ. +func Progress(v float64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldProgress, v)) +} + +// ResultUrls applies equality check predicate on the "result_urls" field. It's identical to ResultUrlsEQ. +func ResultUrls(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldResultUrls, v)) +} + +// ErrorMessage applies equality check predicate on the "error_message" field. It's identical to ErrorMessageEQ. +func ErrorMessage(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldErrorMessage, v)) +} + +// RetryCount applies equality check predicate on the "retry_count" field. It's identical to RetryCountEQ. +func RetryCount(v int) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldRetryCount, v)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. 
+func CreatedAt(v time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CompletedAt applies equality check predicate on the "completed_at" field. It's identical to CompletedAtEQ. +func CompletedAt(v time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldCompletedAt, v)) +} + +// TaskIDEQ applies the EQ predicate on the "task_id" field. +func TaskIDEQ(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldTaskID, v)) +} + +// TaskIDNEQ applies the NEQ predicate on the "task_id" field. +func TaskIDNEQ(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNEQ(FieldTaskID, v)) +} + +// TaskIDIn applies the In predicate on the "task_id" field. +func TaskIDIn(vs ...string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldIn(FieldTaskID, vs...)) +} + +// TaskIDNotIn applies the NotIn predicate on the "task_id" field. +func TaskIDNotIn(vs ...string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNotIn(FieldTaskID, vs...)) +} + +// TaskIDGT applies the GT predicate on the "task_id" field. +func TaskIDGT(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGT(FieldTaskID, v)) +} + +// TaskIDGTE applies the GTE predicate on the "task_id" field. +func TaskIDGTE(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGTE(FieldTaskID, v)) +} + +// TaskIDLT applies the LT predicate on the "task_id" field. +func TaskIDLT(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLT(FieldTaskID, v)) +} + +// TaskIDLTE applies the LTE predicate on the "task_id" field. +func TaskIDLTE(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLTE(FieldTaskID, v)) +} + +// TaskIDContains applies the Contains predicate on the "task_id" field. 
+func TaskIDContains(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldContains(FieldTaskID, v)) +} + +// TaskIDHasPrefix applies the HasPrefix predicate on the "task_id" field. +func TaskIDHasPrefix(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldHasPrefix(FieldTaskID, v)) +} + +// TaskIDHasSuffix applies the HasSuffix predicate on the "task_id" field. +func TaskIDHasSuffix(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldHasSuffix(FieldTaskID, v)) +} + +// TaskIDEqualFold applies the EqualFold predicate on the "task_id" field. +func TaskIDEqualFold(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEqualFold(FieldTaskID, v)) +} + +// TaskIDContainsFold applies the ContainsFold predicate on the "task_id" field. +func TaskIDContainsFold(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldContainsFold(FieldTaskID, v)) +} + +// AccountIDEQ applies the EQ predicate on the "account_id" field. +func AccountIDEQ(v int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldAccountID, v)) +} + +// AccountIDNEQ applies the NEQ predicate on the "account_id" field. +func AccountIDNEQ(v int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNEQ(FieldAccountID, v)) +} + +// AccountIDIn applies the In predicate on the "account_id" field. +func AccountIDIn(vs ...int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldIn(FieldAccountID, vs...)) +} + +// AccountIDNotIn applies the NotIn predicate on the "account_id" field. +func AccountIDNotIn(vs ...int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNotIn(FieldAccountID, vs...)) +} + +// AccountIDGT applies the GT predicate on the "account_id" field. +func AccountIDGT(v int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGT(FieldAccountID, v)) +} + +// AccountIDGTE applies the GTE predicate on the "account_id" field. 
+func AccountIDGTE(v int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGTE(FieldAccountID, v)) +} + +// AccountIDLT applies the LT predicate on the "account_id" field. +func AccountIDLT(v int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLT(FieldAccountID, v)) +} + +// AccountIDLTE applies the LTE predicate on the "account_id" field. +func AccountIDLTE(v int64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLTE(FieldAccountID, v)) +} + +// ModelEQ applies the EQ predicate on the "model" field. +func ModelEQ(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldModel, v)) +} + +// ModelNEQ applies the NEQ predicate on the "model" field. +func ModelNEQ(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNEQ(FieldModel, v)) +} + +// ModelIn applies the In predicate on the "model" field. +func ModelIn(vs ...string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldIn(FieldModel, vs...)) +} + +// ModelNotIn applies the NotIn predicate on the "model" field. +func ModelNotIn(vs ...string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNotIn(FieldModel, vs...)) +} + +// ModelGT applies the GT predicate on the "model" field. +func ModelGT(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGT(FieldModel, v)) +} + +// ModelGTE applies the GTE predicate on the "model" field. +func ModelGTE(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGTE(FieldModel, v)) +} + +// ModelLT applies the LT predicate on the "model" field. +func ModelLT(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLT(FieldModel, v)) +} + +// ModelLTE applies the LTE predicate on the "model" field. +func ModelLTE(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLTE(FieldModel, v)) +} + +// ModelContains applies the Contains predicate on the "model" field. 
+func ModelContains(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldContains(FieldModel, v)) +} + +// ModelHasPrefix applies the HasPrefix predicate on the "model" field. +func ModelHasPrefix(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldHasPrefix(FieldModel, v)) +} + +// ModelHasSuffix applies the HasSuffix predicate on the "model" field. +func ModelHasSuffix(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldHasSuffix(FieldModel, v)) +} + +// ModelEqualFold applies the EqualFold predicate on the "model" field. +func ModelEqualFold(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEqualFold(FieldModel, v)) +} + +// ModelContainsFold applies the ContainsFold predicate on the "model" field. +func ModelContainsFold(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldContainsFold(FieldModel, v)) +} + +// PromptEQ applies the EQ predicate on the "prompt" field. +func PromptEQ(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldPrompt, v)) +} + +// PromptNEQ applies the NEQ predicate on the "prompt" field. +func PromptNEQ(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNEQ(FieldPrompt, v)) +} + +// PromptIn applies the In predicate on the "prompt" field. +func PromptIn(vs ...string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldIn(FieldPrompt, vs...)) +} + +// PromptNotIn applies the NotIn predicate on the "prompt" field. +func PromptNotIn(vs ...string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNotIn(FieldPrompt, vs...)) +} + +// PromptGT applies the GT predicate on the "prompt" field. +func PromptGT(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGT(FieldPrompt, v)) +} + +// PromptGTE applies the GTE predicate on the "prompt" field. 
+func PromptGTE(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGTE(FieldPrompt, v)) +} + +// PromptLT applies the LT predicate on the "prompt" field. +func PromptLT(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLT(FieldPrompt, v)) +} + +// PromptLTE applies the LTE predicate on the "prompt" field. +func PromptLTE(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLTE(FieldPrompt, v)) +} + +// PromptContains applies the Contains predicate on the "prompt" field. +func PromptContains(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldContains(FieldPrompt, v)) +} + +// PromptHasPrefix applies the HasPrefix predicate on the "prompt" field. +func PromptHasPrefix(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldHasPrefix(FieldPrompt, v)) +} + +// PromptHasSuffix applies the HasSuffix predicate on the "prompt" field. +func PromptHasSuffix(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldHasSuffix(FieldPrompt, v)) +} + +// PromptEqualFold applies the EqualFold predicate on the "prompt" field. +func PromptEqualFold(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEqualFold(FieldPrompt, v)) +} + +// PromptContainsFold applies the ContainsFold predicate on the "prompt" field. +func PromptContainsFold(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldContainsFold(FieldPrompt, v)) +} + +// StatusEQ applies the EQ predicate on the "status" field. +func StatusEQ(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldStatus, v)) +} + +// StatusNEQ applies the NEQ predicate on the "status" field. +func StatusNEQ(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNEQ(FieldStatus, v)) +} + +// StatusIn applies the In predicate on the "status" field. 
+func StatusIn(vs ...string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldIn(FieldStatus, vs...)) +} + +// StatusNotIn applies the NotIn predicate on the "status" field. +func StatusNotIn(vs ...string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNotIn(FieldStatus, vs...)) +} + +// StatusGT applies the GT predicate on the "status" field. +func StatusGT(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGT(FieldStatus, v)) +} + +// StatusGTE applies the GTE predicate on the "status" field. +func StatusGTE(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGTE(FieldStatus, v)) +} + +// StatusLT applies the LT predicate on the "status" field. +func StatusLT(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLT(FieldStatus, v)) +} + +// StatusLTE applies the LTE predicate on the "status" field. +func StatusLTE(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLTE(FieldStatus, v)) +} + +// StatusContains applies the Contains predicate on the "status" field. +func StatusContains(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldContains(FieldStatus, v)) +} + +// StatusHasPrefix applies the HasPrefix predicate on the "status" field. +func StatusHasPrefix(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldHasPrefix(FieldStatus, v)) +} + +// StatusHasSuffix applies the HasSuffix predicate on the "status" field. +func StatusHasSuffix(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldHasSuffix(FieldStatus, v)) +} + +// StatusEqualFold applies the EqualFold predicate on the "status" field. +func StatusEqualFold(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEqualFold(FieldStatus, v)) +} + +// StatusContainsFold applies the ContainsFold predicate on the "status" field. 
+func StatusContainsFold(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldContainsFold(FieldStatus, v)) +} + +// ProgressEQ applies the EQ predicate on the "progress" field. +func ProgressEQ(v float64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldProgress, v)) +} + +// ProgressNEQ applies the NEQ predicate on the "progress" field. +func ProgressNEQ(v float64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNEQ(FieldProgress, v)) +} + +// ProgressIn applies the In predicate on the "progress" field. +func ProgressIn(vs ...float64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldIn(FieldProgress, vs...)) +} + +// ProgressNotIn applies the NotIn predicate on the "progress" field. +func ProgressNotIn(vs ...float64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNotIn(FieldProgress, vs...)) +} + +// ProgressGT applies the GT predicate on the "progress" field. +func ProgressGT(v float64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGT(FieldProgress, v)) +} + +// ProgressGTE applies the GTE predicate on the "progress" field. +func ProgressGTE(v float64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGTE(FieldProgress, v)) +} + +// ProgressLT applies the LT predicate on the "progress" field. +func ProgressLT(v float64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLT(FieldProgress, v)) +} + +// ProgressLTE applies the LTE predicate on the "progress" field. +func ProgressLTE(v float64) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLTE(FieldProgress, v)) +} + +// ResultUrlsEQ applies the EQ predicate on the "result_urls" field. +func ResultUrlsEQ(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldResultUrls, v)) +} + +// ResultUrlsNEQ applies the NEQ predicate on the "result_urls" field. 
+func ResultUrlsNEQ(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNEQ(FieldResultUrls, v)) +} + +// ResultUrlsIn applies the In predicate on the "result_urls" field. +func ResultUrlsIn(vs ...string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldIn(FieldResultUrls, vs...)) +} + +// ResultUrlsNotIn applies the NotIn predicate on the "result_urls" field. +func ResultUrlsNotIn(vs ...string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNotIn(FieldResultUrls, vs...)) +} + +// ResultUrlsGT applies the GT predicate on the "result_urls" field. +func ResultUrlsGT(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGT(FieldResultUrls, v)) +} + +// ResultUrlsGTE applies the GTE predicate on the "result_urls" field. +func ResultUrlsGTE(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGTE(FieldResultUrls, v)) +} + +// ResultUrlsLT applies the LT predicate on the "result_urls" field. +func ResultUrlsLT(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLT(FieldResultUrls, v)) +} + +// ResultUrlsLTE applies the LTE predicate on the "result_urls" field. +func ResultUrlsLTE(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLTE(FieldResultUrls, v)) +} + +// ResultUrlsContains applies the Contains predicate on the "result_urls" field. +func ResultUrlsContains(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldContains(FieldResultUrls, v)) +} + +// ResultUrlsHasPrefix applies the HasPrefix predicate on the "result_urls" field. +func ResultUrlsHasPrefix(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldHasPrefix(FieldResultUrls, v)) +} + +// ResultUrlsHasSuffix applies the HasSuffix predicate on the "result_urls" field. +func ResultUrlsHasSuffix(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldHasSuffix(FieldResultUrls, v)) +} + +// ResultUrlsIsNil applies the IsNil predicate on the "result_urls" field. 
+func ResultUrlsIsNil() predicate.SoraTask { + return predicate.SoraTask(sql.FieldIsNull(FieldResultUrls)) +} + +// ResultUrlsNotNil applies the NotNil predicate on the "result_urls" field. +func ResultUrlsNotNil() predicate.SoraTask { + return predicate.SoraTask(sql.FieldNotNull(FieldResultUrls)) +} + +// ResultUrlsEqualFold applies the EqualFold predicate on the "result_urls" field. +func ResultUrlsEqualFold(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEqualFold(FieldResultUrls, v)) +} + +// ResultUrlsContainsFold applies the ContainsFold predicate on the "result_urls" field. +func ResultUrlsContainsFold(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldContainsFold(FieldResultUrls, v)) +} + +// ErrorMessageEQ applies the EQ predicate on the "error_message" field. +func ErrorMessageEQ(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldErrorMessage, v)) +} + +// ErrorMessageNEQ applies the NEQ predicate on the "error_message" field. +func ErrorMessageNEQ(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNEQ(FieldErrorMessage, v)) +} + +// ErrorMessageIn applies the In predicate on the "error_message" field. +func ErrorMessageIn(vs ...string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldIn(FieldErrorMessage, vs...)) +} + +// ErrorMessageNotIn applies the NotIn predicate on the "error_message" field. +func ErrorMessageNotIn(vs ...string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNotIn(FieldErrorMessage, vs...)) +} + +// ErrorMessageGT applies the GT predicate on the "error_message" field. +func ErrorMessageGT(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGT(FieldErrorMessage, v)) +} + +// ErrorMessageGTE applies the GTE predicate on the "error_message" field. 
+func ErrorMessageGTE(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGTE(FieldErrorMessage, v)) +} + +// ErrorMessageLT applies the LT predicate on the "error_message" field. +func ErrorMessageLT(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLT(FieldErrorMessage, v)) +} + +// ErrorMessageLTE applies the LTE predicate on the "error_message" field. +func ErrorMessageLTE(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLTE(FieldErrorMessage, v)) +} + +// ErrorMessageContains applies the Contains predicate on the "error_message" field. +func ErrorMessageContains(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldContains(FieldErrorMessage, v)) +} + +// ErrorMessageHasPrefix applies the HasPrefix predicate on the "error_message" field. +func ErrorMessageHasPrefix(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldHasPrefix(FieldErrorMessage, v)) +} + +// ErrorMessageHasSuffix applies the HasSuffix predicate on the "error_message" field. +func ErrorMessageHasSuffix(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldHasSuffix(FieldErrorMessage, v)) +} + +// ErrorMessageIsNil applies the IsNil predicate on the "error_message" field. +func ErrorMessageIsNil() predicate.SoraTask { + return predicate.SoraTask(sql.FieldIsNull(FieldErrorMessage)) +} + +// ErrorMessageNotNil applies the NotNil predicate on the "error_message" field. +func ErrorMessageNotNil() predicate.SoraTask { + return predicate.SoraTask(sql.FieldNotNull(FieldErrorMessage)) +} + +// ErrorMessageEqualFold applies the EqualFold predicate on the "error_message" field. +func ErrorMessageEqualFold(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEqualFold(FieldErrorMessage, v)) +} + +// ErrorMessageContainsFold applies the ContainsFold predicate on the "error_message" field. 
+func ErrorMessageContainsFold(v string) predicate.SoraTask { + return predicate.SoraTask(sql.FieldContainsFold(FieldErrorMessage, v)) +} + +// RetryCountEQ applies the EQ predicate on the "retry_count" field. +func RetryCountEQ(v int) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldRetryCount, v)) +} + +// RetryCountNEQ applies the NEQ predicate on the "retry_count" field. +func RetryCountNEQ(v int) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNEQ(FieldRetryCount, v)) +} + +// RetryCountIn applies the In predicate on the "retry_count" field. +func RetryCountIn(vs ...int) predicate.SoraTask { + return predicate.SoraTask(sql.FieldIn(FieldRetryCount, vs...)) +} + +// RetryCountNotIn applies the NotIn predicate on the "retry_count" field. +func RetryCountNotIn(vs ...int) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNotIn(FieldRetryCount, vs...)) +} + +// RetryCountGT applies the GT predicate on the "retry_count" field. +func RetryCountGT(v int) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGT(FieldRetryCount, v)) +} + +// RetryCountGTE applies the GTE predicate on the "retry_count" field. +func RetryCountGTE(v int) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGTE(FieldRetryCount, v)) +} + +// RetryCountLT applies the LT predicate on the "retry_count" field. +func RetryCountLT(v int) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLT(FieldRetryCount, v)) +} + +// RetryCountLTE applies the LTE predicate on the "retry_count" field. +func RetryCountLTE(v int) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLTE(FieldRetryCount, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. 
+func CreatedAtNEQ(v time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLTE(FieldCreatedAt, v)) +} + +// CompletedAtEQ applies the EQ predicate on the "completed_at" field. +func CompletedAtEQ(v time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldEQ(FieldCompletedAt, v)) +} + +// CompletedAtNEQ applies the NEQ predicate on the "completed_at" field. +func CompletedAtNEQ(v time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNEQ(FieldCompletedAt, v)) +} + +// CompletedAtIn applies the In predicate on the "completed_at" field. +func CompletedAtIn(vs ...time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldIn(FieldCompletedAt, vs...)) +} + +// CompletedAtNotIn applies the NotIn predicate on the "completed_at" field. 
+func CompletedAtNotIn(vs ...time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldNotIn(FieldCompletedAt, vs...)) +} + +// CompletedAtGT applies the GT predicate on the "completed_at" field. +func CompletedAtGT(v time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGT(FieldCompletedAt, v)) +} + +// CompletedAtGTE applies the GTE predicate on the "completed_at" field. +func CompletedAtGTE(v time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldGTE(FieldCompletedAt, v)) +} + +// CompletedAtLT applies the LT predicate on the "completed_at" field. +func CompletedAtLT(v time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLT(FieldCompletedAt, v)) +} + +// CompletedAtLTE applies the LTE predicate on the "completed_at" field. +func CompletedAtLTE(v time.Time) predicate.SoraTask { + return predicate.SoraTask(sql.FieldLTE(FieldCompletedAt, v)) +} + +// CompletedAtIsNil applies the IsNil predicate on the "completed_at" field. +func CompletedAtIsNil() predicate.SoraTask { + return predicate.SoraTask(sql.FieldIsNull(FieldCompletedAt)) +} + +// CompletedAtNotNil applies the NotNil predicate on the "completed_at" field. +func CompletedAtNotNil() predicate.SoraTask { + return predicate.SoraTask(sql.FieldNotNull(FieldCompletedAt)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.SoraTask) predicate.SoraTask { + return predicate.SoraTask(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.SoraTask) predicate.SoraTask { + return predicate.SoraTask(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. 
+func Not(p predicate.SoraTask) predicate.SoraTask { + return predicate.SoraTask(sql.NotPredicates(p)) +} diff --git a/backend/ent/soratask_create.go b/backend/ent/soratask_create.go new file mode 100644 index 00000000..57efb168 --- /dev/null +++ b/backend/ent/soratask_create.go @@ -0,0 +1,1189 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/soratask" +) + +// SoraTaskCreate is the builder for creating a SoraTask entity. +type SoraTaskCreate struct { + config + mutation *SoraTaskMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetTaskID sets the "task_id" field. +func (_c *SoraTaskCreate) SetTaskID(v string) *SoraTaskCreate { + _c.mutation.SetTaskID(v) + return _c +} + +// SetAccountID sets the "account_id" field. +func (_c *SoraTaskCreate) SetAccountID(v int64) *SoraTaskCreate { + _c.mutation.SetAccountID(v) + return _c +} + +// SetModel sets the "model" field. +func (_c *SoraTaskCreate) SetModel(v string) *SoraTaskCreate { + _c.mutation.SetModel(v) + return _c +} + +// SetPrompt sets the "prompt" field. +func (_c *SoraTaskCreate) SetPrompt(v string) *SoraTaskCreate { + _c.mutation.SetPrompt(v) + return _c +} + +// SetStatus sets the "status" field. +func (_c *SoraTaskCreate) SetStatus(v string) *SoraTaskCreate { + _c.mutation.SetStatus(v) + return _c +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (_c *SoraTaskCreate) SetNillableStatus(v *string) *SoraTaskCreate { + if v != nil { + _c.SetStatus(*v) + } + return _c +} + +// SetProgress sets the "progress" field. +func (_c *SoraTaskCreate) SetProgress(v float64) *SoraTaskCreate { + _c.mutation.SetProgress(v) + return _c +} + +// SetNillableProgress sets the "progress" field if the given value is not nil. 
+func (_c *SoraTaskCreate) SetNillableProgress(v *float64) *SoraTaskCreate { + if v != nil { + _c.SetProgress(*v) + } + return _c +} + +// SetResultUrls sets the "result_urls" field. +func (_c *SoraTaskCreate) SetResultUrls(v string) *SoraTaskCreate { + _c.mutation.SetResultUrls(v) + return _c +} + +// SetNillableResultUrls sets the "result_urls" field if the given value is not nil. +func (_c *SoraTaskCreate) SetNillableResultUrls(v *string) *SoraTaskCreate { + if v != nil { + _c.SetResultUrls(*v) + } + return _c +} + +// SetErrorMessage sets the "error_message" field. +func (_c *SoraTaskCreate) SetErrorMessage(v string) *SoraTaskCreate { + _c.mutation.SetErrorMessage(v) + return _c +} + +// SetNillableErrorMessage sets the "error_message" field if the given value is not nil. +func (_c *SoraTaskCreate) SetNillableErrorMessage(v *string) *SoraTaskCreate { + if v != nil { + _c.SetErrorMessage(*v) + } + return _c +} + +// SetRetryCount sets the "retry_count" field. +func (_c *SoraTaskCreate) SetRetryCount(v int) *SoraTaskCreate { + _c.mutation.SetRetryCount(v) + return _c +} + +// SetNillableRetryCount sets the "retry_count" field if the given value is not nil. +func (_c *SoraTaskCreate) SetNillableRetryCount(v *int) *SoraTaskCreate { + if v != nil { + _c.SetRetryCount(*v) + } + return _c +} + +// SetCreatedAt sets the "created_at" field. +func (_c *SoraTaskCreate) SetCreatedAt(v time.Time) *SoraTaskCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *SoraTaskCreate) SetNillableCreatedAt(v *time.Time) *SoraTaskCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetCompletedAt sets the "completed_at" field. +func (_c *SoraTaskCreate) SetCompletedAt(v time.Time) *SoraTaskCreate { + _c.mutation.SetCompletedAt(v) + return _c +} + +// SetNillableCompletedAt sets the "completed_at" field if the given value is not nil. 
+func (_c *SoraTaskCreate) SetNillableCompletedAt(v *time.Time) *SoraTaskCreate { + if v != nil { + _c.SetCompletedAt(*v) + } + return _c +} + +// Mutation returns the SoraTaskMutation object of the builder. +func (_c *SoraTaskCreate) Mutation() *SoraTaskMutation { + return _c.mutation +} + +// Save creates the SoraTask in the database. +func (_c *SoraTaskCreate) Save(ctx context.Context) (*SoraTask, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *SoraTaskCreate) SaveX(ctx context.Context) *SoraTask { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *SoraTaskCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *SoraTaskCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *SoraTaskCreate) defaults() { + if _, ok := _c.mutation.Status(); !ok { + v := soratask.DefaultStatus + _c.mutation.SetStatus(v) + } + if _, ok := _c.mutation.Progress(); !ok { + v := soratask.DefaultProgress + _c.mutation.SetProgress(v) + } + if _, ok := _c.mutation.RetryCount(); !ok { + v := soratask.DefaultRetryCount + _c.mutation.SetRetryCount(v) + } + if _, ok := _c.mutation.CreatedAt(); !ok { + v := soratask.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_c *SoraTaskCreate) check() error { + if _, ok := _c.mutation.TaskID(); !ok { + return &ValidationError{Name: "task_id", err: errors.New(`ent: missing required field "SoraTask.task_id"`)} + } + if v, ok := _c.mutation.TaskID(); ok { + if err := soratask.TaskIDValidator(v); err != nil { + return &ValidationError{Name: "task_id", err: fmt.Errorf(`ent: validator failed for field "SoraTask.task_id": %w`, err)} + } + } + if _, ok := _c.mutation.AccountID(); !ok { + return &ValidationError{Name: "account_id", err: errors.New(`ent: missing required field "SoraTask.account_id"`)} + } + if _, ok := _c.mutation.Model(); !ok { + return &ValidationError{Name: "model", err: errors.New(`ent: missing required field "SoraTask.model"`)} + } + if v, ok := _c.mutation.Model(); ok { + if err := soratask.ModelValidator(v); err != nil { + return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "SoraTask.model": %w`, err)} + } + } + if _, ok := _c.mutation.Prompt(); !ok { + return &ValidationError{Name: "prompt", err: errors.New(`ent: missing required field "SoraTask.prompt"`)} + } + if _, ok := _c.mutation.Status(); !ok { + return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "SoraTask.status"`)} + } + if v, ok := _c.mutation.Status(); ok { + if err := soratask.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "SoraTask.status": %w`, err)} + } + } + if _, ok := _c.mutation.Progress(); !ok { + return &ValidationError{Name: "progress", err: errors.New(`ent: missing required field "SoraTask.progress"`)} + } + if _, ok := _c.mutation.RetryCount(); !ok { + return &ValidationError{Name: "retry_count", err: errors.New(`ent: missing required field "SoraTask.retry_count"`)} + } + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "SoraTask.created_at"`)} + } + return nil 
+} + +func (_c *SoraTaskCreate) sqlSave(ctx context.Context) (*SoraTask, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *SoraTaskCreate) createSpec() (*SoraTask, *sqlgraph.CreateSpec) { + var ( + _node = &SoraTask{config: _c.config} + _spec = sqlgraph.NewCreateSpec(soratask.Table, sqlgraph.NewFieldSpec(soratask.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.TaskID(); ok { + _spec.SetField(soratask.FieldTaskID, field.TypeString, value) + _node.TaskID = value + } + if value, ok := _c.mutation.AccountID(); ok { + _spec.SetField(soratask.FieldAccountID, field.TypeInt64, value) + _node.AccountID = value + } + if value, ok := _c.mutation.Model(); ok { + _spec.SetField(soratask.FieldModel, field.TypeString, value) + _node.Model = value + } + if value, ok := _c.mutation.Prompt(); ok { + _spec.SetField(soratask.FieldPrompt, field.TypeString, value) + _node.Prompt = value + } + if value, ok := _c.mutation.Status(); ok { + _spec.SetField(soratask.FieldStatus, field.TypeString, value) + _node.Status = value + } + if value, ok := _c.mutation.Progress(); ok { + _spec.SetField(soratask.FieldProgress, field.TypeFloat64, value) + _node.Progress = value + } + if value, ok := _c.mutation.ResultUrls(); ok { + _spec.SetField(soratask.FieldResultUrls, field.TypeString, value) + _node.ResultUrls = &value + } + if value, ok := _c.mutation.ErrorMessage(); ok { + _spec.SetField(soratask.FieldErrorMessage, field.TypeString, value) + _node.ErrorMessage = &value + } + if value, ok := _c.mutation.RetryCount(); ok { + _spec.SetField(soratask.FieldRetryCount, 
field.TypeInt, value) + _node.RetryCount = value + } + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(soratask.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.CompletedAt(); ok { + _spec.SetField(soratask.FieldCompletedAt, field.TypeTime, value) + _node.CompletedAt = &value + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.SoraTask.Create(). +// SetTaskID(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.SoraTaskUpsert) { +// SetTaskID(v+v). +// }). +// Exec(ctx) +func (_c *SoraTaskCreate) OnConflict(opts ...sql.ConflictOption) *SoraTaskUpsertOne { + _c.conflict = opts + return &SoraTaskUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.SoraTask.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *SoraTaskCreate) OnConflictColumns(columns ...string) *SoraTaskUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &SoraTaskUpsertOne{ + create: _c, + } +} + +type ( + // SoraTaskUpsertOne is the builder for "upsert"-ing + // one SoraTask node. + SoraTaskUpsertOne struct { + create *SoraTaskCreate + } + + // SoraTaskUpsert is the "OnConflict" setter. + SoraTaskUpsert struct { + *sql.UpdateSet + } +) + +// SetTaskID sets the "task_id" field. +func (u *SoraTaskUpsert) SetTaskID(v string) *SoraTaskUpsert { + u.Set(soratask.FieldTaskID, v) + return u +} + +// UpdateTaskID sets the "task_id" field to the value that was provided on create. 
+func (u *SoraTaskUpsert) UpdateTaskID() *SoraTaskUpsert { + u.SetExcluded(soratask.FieldTaskID) + return u +} + +// SetAccountID sets the "account_id" field. +func (u *SoraTaskUpsert) SetAccountID(v int64) *SoraTaskUpsert { + u.Set(soratask.FieldAccountID, v) + return u +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *SoraTaskUpsert) UpdateAccountID() *SoraTaskUpsert { + u.SetExcluded(soratask.FieldAccountID) + return u +} + +// AddAccountID adds v to the "account_id" field. +func (u *SoraTaskUpsert) AddAccountID(v int64) *SoraTaskUpsert { + u.Add(soratask.FieldAccountID, v) + return u +} + +// SetModel sets the "model" field. +func (u *SoraTaskUpsert) SetModel(v string) *SoraTaskUpsert { + u.Set(soratask.FieldModel, v) + return u +} + +// UpdateModel sets the "model" field to the value that was provided on create. +func (u *SoraTaskUpsert) UpdateModel() *SoraTaskUpsert { + u.SetExcluded(soratask.FieldModel) + return u +} + +// SetPrompt sets the "prompt" field. +func (u *SoraTaskUpsert) SetPrompt(v string) *SoraTaskUpsert { + u.Set(soratask.FieldPrompt, v) + return u +} + +// UpdatePrompt sets the "prompt" field to the value that was provided on create. +func (u *SoraTaskUpsert) UpdatePrompt() *SoraTaskUpsert { + u.SetExcluded(soratask.FieldPrompt) + return u +} + +// SetStatus sets the "status" field. +func (u *SoraTaskUpsert) SetStatus(v string) *SoraTaskUpsert { + u.Set(soratask.FieldStatus, v) + return u +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *SoraTaskUpsert) UpdateStatus() *SoraTaskUpsert { + u.SetExcluded(soratask.FieldStatus) + return u +} + +// SetProgress sets the "progress" field. +func (u *SoraTaskUpsert) SetProgress(v float64) *SoraTaskUpsert { + u.Set(soratask.FieldProgress, v) + return u +} + +// UpdateProgress sets the "progress" field to the value that was provided on create. 
+func (u *SoraTaskUpsert) UpdateProgress() *SoraTaskUpsert { + u.SetExcluded(soratask.FieldProgress) + return u +} + +// AddProgress adds v to the "progress" field. +func (u *SoraTaskUpsert) AddProgress(v float64) *SoraTaskUpsert { + u.Add(soratask.FieldProgress, v) + return u +} + +// SetResultUrls sets the "result_urls" field. +func (u *SoraTaskUpsert) SetResultUrls(v string) *SoraTaskUpsert { + u.Set(soratask.FieldResultUrls, v) + return u +} + +// UpdateResultUrls sets the "result_urls" field to the value that was provided on create. +func (u *SoraTaskUpsert) UpdateResultUrls() *SoraTaskUpsert { + u.SetExcluded(soratask.FieldResultUrls) + return u +} + +// ClearResultUrls clears the value of the "result_urls" field. +func (u *SoraTaskUpsert) ClearResultUrls() *SoraTaskUpsert { + u.SetNull(soratask.FieldResultUrls) + return u +} + +// SetErrorMessage sets the "error_message" field. +func (u *SoraTaskUpsert) SetErrorMessage(v string) *SoraTaskUpsert { + u.Set(soratask.FieldErrorMessage, v) + return u +} + +// UpdateErrorMessage sets the "error_message" field to the value that was provided on create. +func (u *SoraTaskUpsert) UpdateErrorMessage() *SoraTaskUpsert { + u.SetExcluded(soratask.FieldErrorMessage) + return u +} + +// ClearErrorMessage clears the value of the "error_message" field. +func (u *SoraTaskUpsert) ClearErrorMessage() *SoraTaskUpsert { + u.SetNull(soratask.FieldErrorMessage) + return u +} + +// SetRetryCount sets the "retry_count" field. +func (u *SoraTaskUpsert) SetRetryCount(v int) *SoraTaskUpsert { + u.Set(soratask.FieldRetryCount, v) + return u +} + +// UpdateRetryCount sets the "retry_count" field to the value that was provided on create. +func (u *SoraTaskUpsert) UpdateRetryCount() *SoraTaskUpsert { + u.SetExcluded(soratask.FieldRetryCount) + return u +} + +// AddRetryCount adds v to the "retry_count" field. 
+func (u *SoraTaskUpsert) AddRetryCount(v int) *SoraTaskUpsert { + u.Add(soratask.FieldRetryCount, v) + return u +} + +// SetCreatedAt sets the "created_at" field. +func (u *SoraTaskUpsert) SetCreatedAt(v time.Time) *SoraTaskUpsert { + u.Set(soratask.FieldCreatedAt, v) + return u +} + +// UpdateCreatedAt sets the "created_at" field to the value that was provided on create. +func (u *SoraTaskUpsert) UpdateCreatedAt() *SoraTaskUpsert { + u.SetExcluded(soratask.FieldCreatedAt) + return u +} + +// SetCompletedAt sets the "completed_at" field. +func (u *SoraTaskUpsert) SetCompletedAt(v time.Time) *SoraTaskUpsert { + u.Set(soratask.FieldCompletedAt, v) + return u +} + +// UpdateCompletedAt sets the "completed_at" field to the value that was provided on create. +func (u *SoraTaskUpsert) UpdateCompletedAt() *SoraTaskUpsert { + u.SetExcluded(soratask.FieldCompletedAt) + return u +} + +// ClearCompletedAt clears the value of the "completed_at" field. +func (u *SoraTaskUpsert) ClearCompletedAt() *SoraTaskUpsert { + u.SetNull(soratask.FieldCompletedAt) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.SoraTask.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *SoraTaskUpsertOne) UpdateNewValues() *SoraTaskUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.SoraTask.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *SoraTaskUpsertOne) Ignore() *SoraTaskUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. 
+func (u *SoraTaskUpsertOne) DoNothing() *SoraTaskUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the SoraTaskCreate.OnConflict +// documentation for more info. +func (u *SoraTaskUpsertOne) Update(set func(*SoraTaskUpsert)) *SoraTaskUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&SoraTaskUpsert{UpdateSet: update}) + })) + return u +} + +// SetTaskID sets the "task_id" field. +func (u *SoraTaskUpsertOne) SetTaskID(v string) *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.SetTaskID(v) + }) +} + +// UpdateTaskID sets the "task_id" field to the value that was provided on create. +func (u *SoraTaskUpsertOne) UpdateTaskID() *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateTaskID() + }) +} + +// SetAccountID sets the "account_id" field. +func (u *SoraTaskUpsertOne) SetAccountID(v int64) *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.SetAccountID(v) + }) +} + +// AddAccountID adds v to the "account_id" field. +func (u *SoraTaskUpsertOne) AddAccountID(v int64) *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.AddAccountID(v) + }) +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *SoraTaskUpsertOne) UpdateAccountID() *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateAccountID() + }) +} + +// SetModel sets the "model" field. +func (u *SoraTaskUpsertOne) SetModel(v string) *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.SetModel(v) + }) +} + +// UpdateModel sets the "model" field to the value that was provided on create. +func (u *SoraTaskUpsertOne) UpdateModel() *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateModel() + }) +} + +// SetPrompt sets the "prompt" field. 
+func (u *SoraTaskUpsertOne) SetPrompt(v string) *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.SetPrompt(v) + }) +} + +// UpdatePrompt sets the "prompt" field to the value that was provided on create. +func (u *SoraTaskUpsertOne) UpdatePrompt() *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdatePrompt() + }) +} + +// SetStatus sets the "status" field. +func (u *SoraTaskUpsertOne) SetStatus(v string) *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *SoraTaskUpsertOne) UpdateStatus() *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateStatus() + }) +} + +// SetProgress sets the "progress" field. +func (u *SoraTaskUpsertOne) SetProgress(v float64) *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.SetProgress(v) + }) +} + +// AddProgress adds v to the "progress" field. +func (u *SoraTaskUpsertOne) AddProgress(v float64) *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.AddProgress(v) + }) +} + +// UpdateProgress sets the "progress" field to the value that was provided on create. +func (u *SoraTaskUpsertOne) UpdateProgress() *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateProgress() + }) +} + +// SetResultUrls sets the "result_urls" field. +func (u *SoraTaskUpsertOne) SetResultUrls(v string) *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.SetResultUrls(v) + }) +} + +// UpdateResultUrls sets the "result_urls" field to the value that was provided on create. +func (u *SoraTaskUpsertOne) UpdateResultUrls() *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateResultUrls() + }) +} + +// ClearResultUrls clears the value of the "result_urls" field. 
+func (u *SoraTaskUpsertOne) ClearResultUrls() *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.ClearResultUrls() + }) +} + +// SetErrorMessage sets the "error_message" field. +func (u *SoraTaskUpsertOne) SetErrorMessage(v string) *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.SetErrorMessage(v) + }) +} + +// UpdateErrorMessage sets the "error_message" field to the value that was provided on create. +func (u *SoraTaskUpsertOne) UpdateErrorMessage() *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateErrorMessage() + }) +} + +// ClearErrorMessage clears the value of the "error_message" field. +func (u *SoraTaskUpsertOne) ClearErrorMessage() *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.ClearErrorMessage() + }) +} + +// SetRetryCount sets the "retry_count" field. +func (u *SoraTaskUpsertOne) SetRetryCount(v int) *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.SetRetryCount(v) + }) +} + +// AddRetryCount adds v to the "retry_count" field. +func (u *SoraTaskUpsertOne) AddRetryCount(v int) *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.AddRetryCount(v) + }) +} + +// UpdateRetryCount sets the "retry_count" field to the value that was provided on create. +func (u *SoraTaskUpsertOne) UpdateRetryCount() *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateRetryCount() + }) +} + +// SetCreatedAt sets the "created_at" field. +func (u *SoraTaskUpsertOne) SetCreatedAt(v time.Time) *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.SetCreatedAt(v) + }) +} + +// UpdateCreatedAt sets the "created_at" field to the value that was provided on create. +func (u *SoraTaskUpsertOne) UpdateCreatedAt() *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateCreatedAt() + }) +} + +// SetCompletedAt sets the "completed_at" field. 
+func (u *SoraTaskUpsertOne) SetCompletedAt(v time.Time) *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.SetCompletedAt(v) + }) +} + +// UpdateCompletedAt sets the "completed_at" field to the value that was provided on create. +func (u *SoraTaskUpsertOne) UpdateCompletedAt() *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateCompletedAt() + }) +} + +// ClearCompletedAt clears the value of the "completed_at" field. +func (u *SoraTaskUpsertOne) ClearCompletedAt() *SoraTaskUpsertOne { + return u.Update(func(s *SoraTaskUpsert) { + s.ClearCompletedAt() + }) +} + +// Exec executes the query. +func (u *SoraTaskUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for SoraTaskCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *SoraTaskUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *SoraTaskUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *SoraTaskUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// SoraTaskCreateBulk is the builder for creating many SoraTask entities in bulk. +type SoraTaskCreateBulk struct { + config + err error + builders []*SoraTaskCreate + conflict []sql.ConflictOption +} + +// Save creates the SoraTask entities in the database. 
+func (_c *SoraTaskCreateBulk) Save(ctx context.Context) ([]*SoraTask, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*SoraTask, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*SoraTaskMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *SoraTaskCreateBulk) SaveX(ctx context.Context) []*SoraTask { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. 
+func (_c *SoraTaskCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *SoraTaskCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.SoraTask.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.SoraTaskUpsert) { +// SetTaskID(v+v). +// }). +// Exec(ctx) +func (_c *SoraTaskCreateBulk) OnConflict(opts ...sql.ConflictOption) *SoraTaskUpsertBulk { + _c.conflict = opts + return &SoraTaskUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.SoraTask.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *SoraTaskCreateBulk) OnConflictColumns(columns ...string) *SoraTaskUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &SoraTaskUpsertBulk{ + create: _c, + } +} + +// SoraTaskUpsertBulk is the builder for "upsert"-ing +// a bulk of SoraTask nodes. +type SoraTaskUpsertBulk struct { + create *SoraTaskCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.SoraTask.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *SoraTaskUpsertBulk) UpdateNewValues() *SoraTaskUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + return u +} + +// Ignore sets each column to itself in case of conflict. 
+// Using this option is equivalent to using: +// +// client.SoraTask.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *SoraTaskUpsertBulk) Ignore() *SoraTaskUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *SoraTaskUpsertBulk) DoNothing() *SoraTaskUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the SoraTaskCreateBulk.OnConflict +// documentation for more info. +func (u *SoraTaskUpsertBulk) Update(set func(*SoraTaskUpsert)) *SoraTaskUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&SoraTaskUpsert{UpdateSet: update}) + })) + return u +} + +// SetTaskID sets the "task_id" field. +func (u *SoraTaskUpsertBulk) SetTaskID(v string) *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.SetTaskID(v) + }) +} + +// UpdateTaskID sets the "task_id" field to the value that was provided on create. +func (u *SoraTaskUpsertBulk) UpdateTaskID() *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateTaskID() + }) +} + +// SetAccountID sets the "account_id" field. +func (u *SoraTaskUpsertBulk) SetAccountID(v int64) *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.SetAccountID(v) + }) +} + +// AddAccountID adds v to the "account_id" field. +func (u *SoraTaskUpsertBulk) AddAccountID(v int64) *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.AddAccountID(v) + }) +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *SoraTaskUpsertBulk) UpdateAccountID() *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateAccountID() + }) +} + +// SetModel sets the "model" field. 
+func (u *SoraTaskUpsertBulk) SetModel(v string) *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.SetModel(v) + }) +} + +// UpdateModel sets the "model" field to the value that was provided on create. +func (u *SoraTaskUpsertBulk) UpdateModel() *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateModel() + }) +} + +// SetPrompt sets the "prompt" field. +func (u *SoraTaskUpsertBulk) SetPrompt(v string) *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.SetPrompt(v) + }) +} + +// UpdatePrompt sets the "prompt" field to the value that was provided on create. +func (u *SoraTaskUpsertBulk) UpdatePrompt() *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdatePrompt() + }) +} + +// SetStatus sets the "status" field. +func (u *SoraTaskUpsertBulk) SetStatus(v string) *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *SoraTaskUpsertBulk) UpdateStatus() *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateStatus() + }) +} + +// SetProgress sets the "progress" field. +func (u *SoraTaskUpsertBulk) SetProgress(v float64) *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.SetProgress(v) + }) +} + +// AddProgress adds v to the "progress" field. +func (u *SoraTaskUpsertBulk) AddProgress(v float64) *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.AddProgress(v) + }) +} + +// UpdateProgress sets the "progress" field to the value that was provided on create. +func (u *SoraTaskUpsertBulk) UpdateProgress() *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateProgress() + }) +} + +// SetResultUrls sets the "result_urls" field. 
+func (u *SoraTaskUpsertBulk) SetResultUrls(v string) *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.SetResultUrls(v) + }) +} + +// UpdateResultUrls sets the "result_urls" field to the value that was provided on create. +func (u *SoraTaskUpsertBulk) UpdateResultUrls() *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateResultUrls() + }) +} + +// ClearResultUrls clears the value of the "result_urls" field. +func (u *SoraTaskUpsertBulk) ClearResultUrls() *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.ClearResultUrls() + }) +} + +// SetErrorMessage sets the "error_message" field. +func (u *SoraTaskUpsertBulk) SetErrorMessage(v string) *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.SetErrorMessage(v) + }) +} + +// UpdateErrorMessage sets the "error_message" field to the value that was provided on create. +func (u *SoraTaskUpsertBulk) UpdateErrorMessage() *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateErrorMessage() + }) +} + +// ClearErrorMessage clears the value of the "error_message" field. +func (u *SoraTaskUpsertBulk) ClearErrorMessage() *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.ClearErrorMessage() + }) +} + +// SetRetryCount sets the "retry_count" field. +func (u *SoraTaskUpsertBulk) SetRetryCount(v int) *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.SetRetryCount(v) + }) +} + +// AddRetryCount adds v to the "retry_count" field. +func (u *SoraTaskUpsertBulk) AddRetryCount(v int) *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.AddRetryCount(v) + }) +} + +// UpdateRetryCount sets the "retry_count" field to the value that was provided on create. +func (u *SoraTaskUpsertBulk) UpdateRetryCount() *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateRetryCount() + }) +} + +// SetCreatedAt sets the "created_at" field. 
+func (u *SoraTaskUpsertBulk) SetCreatedAt(v time.Time) *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.SetCreatedAt(v) + }) +} + +// UpdateCreatedAt sets the "created_at" field to the value that was provided on create. +func (u *SoraTaskUpsertBulk) UpdateCreatedAt() *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateCreatedAt() + }) +} + +// SetCompletedAt sets the "completed_at" field. +func (u *SoraTaskUpsertBulk) SetCompletedAt(v time.Time) *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.SetCompletedAt(v) + }) +} + +// UpdateCompletedAt sets the "completed_at" field to the value that was provided on create. +func (u *SoraTaskUpsertBulk) UpdateCompletedAt() *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.UpdateCompletedAt() + }) +} + +// ClearCompletedAt clears the value of the "completed_at" field. +func (u *SoraTaskUpsertBulk) ClearCompletedAt() *SoraTaskUpsertBulk { + return u.Update(func(s *SoraTaskUpsert) { + s.ClearCompletedAt() + }) +} + +// Exec executes the query. +func (u *SoraTaskUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the SoraTaskCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for SoraTaskCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *SoraTaskUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/soratask_delete.go b/backend/ent/soratask_delete.go new file mode 100644 index 00000000..b33b181f --- /dev/null +++ b/backend/ent/soratask_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/soratask" +) + +// SoraTaskDelete is the builder for deleting a SoraTask entity. +type SoraTaskDelete struct { + config + hooks []Hook + mutation *SoraTaskMutation +} + +// Where appends a list predicates to the SoraTaskDelete builder. +func (_d *SoraTaskDelete) Where(ps ...predicate.SoraTask) *SoraTaskDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *SoraTaskDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *SoraTaskDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *SoraTaskDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(soratask.Table, sqlgraph.NewFieldSpec(soratask.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// SoraTaskDeleteOne is the builder for deleting a single SoraTask entity. +type SoraTaskDeleteOne struct { + _d *SoraTaskDelete +} + +// Where appends a list predicates to the SoraTaskDelete builder. +func (_d *SoraTaskDeleteOne) Where(ps ...predicate.SoraTask) *SoraTaskDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. 
+func (_d *SoraTaskDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{soratask.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *SoraTaskDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/soratask_query.go b/backend/ent/soratask_query.go new file mode 100644 index 00000000..f6a466b0 --- /dev/null +++ b/backend/ent/soratask_query.go @@ -0,0 +1,564 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/soratask" +) + +// SoraTaskQuery is the builder for querying SoraTask entities. +type SoraTaskQuery struct { + config + ctx *QueryContext + order []soratask.OrderOption + inters []Interceptor + predicates []predicate.SoraTask + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the SoraTaskQuery builder. +func (_q *SoraTaskQuery) Where(ps ...predicate.SoraTask) *SoraTaskQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *SoraTaskQuery) Limit(limit int) *SoraTaskQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *SoraTaskQuery) Offset(offset int) *SoraTaskQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. 
+func (_q *SoraTaskQuery) Unique(unique bool) *SoraTaskQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *SoraTaskQuery) Order(o ...soratask.OrderOption) *SoraTaskQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first SoraTask entity from the query. +// Returns a *NotFoundError when no SoraTask was found. +func (_q *SoraTaskQuery) First(ctx context.Context) (*SoraTask, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{soratask.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *SoraTaskQuery) FirstX(ctx context.Context) *SoraTask { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first SoraTask ID from the query. +// Returns a *NotFoundError when no SoraTask ID was found. +func (_q *SoraTaskQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{soratask.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *SoraTaskQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single SoraTask entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one SoraTask entity is found. +// Returns a *NotFoundError when no SoraTask entities are found. 
+func (_q *SoraTaskQuery) Only(ctx context.Context) (*SoraTask, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{soratask.Label} + default: + return nil, &NotSingularError{soratask.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *SoraTaskQuery) OnlyX(ctx context.Context) *SoraTask { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only SoraTask ID in the query. +// Returns a *NotSingularError when more than one SoraTask ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *SoraTaskQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{soratask.Label} + default: + err = &NotSingularError{soratask.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *SoraTaskQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of SoraTasks. +func (_q *SoraTaskQuery) All(ctx context.Context) ([]*SoraTask, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*SoraTask, *SoraTaskQuery]() + return withInterceptors[[]*SoraTask](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *SoraTaskQuery) AllX(ctx context.Context) []*SoraTask { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of SoraTask IDs. 
+func (_q *SoraTaskQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(soratask.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *SoraTaskQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *SoraTaskQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*SoraTaskQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *SoraTaskQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *SoraTaskQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *SoraTaskQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the SoraTaskQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. 
+func (_q *SoraTaskQuery) Clone() *SoraTaskQuery { + if _q == nil { + return nil + } + return &SoraTaskQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]soratask.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.SoraTask{}, _q.predicates...), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// TaskID string `json:"task_id,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.SoraTask.Query(). +// GroupBy(soratask.FieldTaskID). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *SoraTaskQuery) GroupBy(field string, fields ...string) *SoraTaskGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &SoraTaskGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = soratask.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// TaskID string `json:"task_id,omitempty"` +// } +// +// client.SoraTask.Query(). +// Select(soratask.FieldTaskID). +// Scan(ctx, &v) +func (_q *SoraTaskQuery) Select(fields ...string) *SoraTaskSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &SoraTaskSelect{SoraTaskQuery: _q} + sbuild.label = soratask.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a SoraTaskSelect configured with the given aggregations. +func (_q *SoraTaskQuery) Aggregate(fns ...AggregateFunc) *SoraTaskSelect { + return _q.Select().Aggregate(fns...) 
+} + +func (_q *SoraTaskQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !soratask.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *SoraTaskQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*SoraTask, error) { + var ( + nodes = []*SoraTask{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*SoraTask).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &SoraTask{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (_q *SoraTaskQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *SoraTaskQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(soratask.Table, soratask.Columns, sqlgraph.NewFieldSpec(soratask.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != 
nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, soratask.FieldID) + for i := range fields { + if fields[i] != soratask.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *SoraTaskQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(soratask.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = soratask.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. 
+func (_q *SoraTaskQuery) ForUpdate(opts ...sql.LockOption) *SoraTaskQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *SoraTaskQuery) ForShare(opts ...sql.LockOption) *SoraTaskQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// SoraTaskGroupBy is the group-by builder for SoraTask entities. +type SoraTaskGroupBy struct { + selector + build *SoraTaskQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *SoraTaskGroupBy) Aggregate(fns ...AggregateFunc) *SoraTaskGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *SoraTaskGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*SoraTaskQuery, *SoraTaskGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *SoraTaskGroupBy) sqlScan(ctx context.Context, root *SoraTaskQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) 
+ } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// SoraTaskSelect is the builder for selecting fields of SoraTask entities. +type SoraTaskSelect struct { + *SoraTaskQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *SoraTaskSelect) Aggregate(fns ...AggregateFunc) *SoraTaskSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *SoraTaskSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*SoraTaskQuery, *SoraTaskSelect](ctx, _s.SoraTaskQuery, _s, _s.inters, v) +} + +func (_s *SoraTaskSelect) sqlScan(ctx context.Context, root *SoraTaskQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/soratask_update.go b/backend/ent/soratask_update.go new file mode 100644 index 00000000..d7937ef6 --- /dev/null +++ b/backend/ent/soratask_update.go @@ -0,0 +1,710 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/soratask" +) + +// SoraTaskUpdate is the builder for updating SoraTask entities. +type SoraTaskUpdate struct { + config + hooks []Hook + mutation *SoraTaskMutation +} + +// Where appends a list predicates to the SoraTaskUpdate builder. +func (_u *SoraTaskUpdate) Where(ps ...predicate.SoraTask) *SoraTaskUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetTaskID sets the "task_id" field. +func (_u *SoraTaskUpdate) SetTaskID(v string) *SoraTaskUpdate { + _u.mutation.SetTaskID(v) + return _u +} + +// SetNillableTaskID sets the "task_id" field if the given value is not nil. +func (_u *SoraTaskUpdate) SetNillableTaskID(v *string) *SoraTaskUpdate { + if v != nil { + _u.SetTaskID(*v) + } + return _u +} + +// SetAccountID sets the "account_id" field. +func (_u *SoraTaskUpdate) SetAccountID(v int64) *SoraTaskUpdate { + _u.mutation.ResetAccountID() + _u.mutation.SetAccountID(v) + return _u +} + +// SetNillableAccountID sets the "account_id" field if the given value is not nil. +func (_u *SoraTaskUpdate) SetNillableAccountID(v *int64) *SoraTaskUpdate { + if v != nil { + _u.SetAccountID(*v) + } + return _u +} + +// AddAccountID adds value to the "account_id" field. +func (_u *SoraTaskUpdate) AddAccountID(v int64) *SoraTaskUpdate { + _u.mutation.AddAccountID(v) + return _u +} + +// SetModel sets the "model" field. +func (_u *SoraTaskUpdate) SetModel(v string) *SoraTaskUpdate { + _u.mutation.SetModel(v) + return _u +} + +// SetNillableModel sets the "model" field if the given value is not nil. +func (_u *SoraTaskUpdate) SetNillableModel(v *string) *SoraTaskUpdate { + if v != nil { + _u.SetModel(*v) + } + return _u +} + +// SetPrompt sets the "prompt" field. 
+func (_u *SoraTaskUpdate) SetPrompt(v string) *SoraTaskUpdate { + _u.mutation.SetPrompt(v) + return _u +} + +// SetNillablePrompt sets the "prompt" field if the given value is not nil. +func (_u *SoraTaskUpdate) SetNillablePrompt(v *string) *SoraTaskUpdate { + if v != nil { + _u.SetPrompt(*v) + } + return _u +} + +// SetStatus sets the "status" field. +func (_u *SoraTaskUpdate) SetStatus(v string) *SoraTaskUpdate { + _u.mutation.SetStatus(v) + return _u +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (_u *SoraTaskUpdate) SetNillableStatus(v *string) *SoraTaskUpdate { + if v != nil { + _u.SetStatus(*v) + } + return _u +} + +// SetProgress sets the "progress" field. +func (_u *SoraTaskUpdate) SetProgress(v float64) *SoraTaskUpdate { + _u.mutation.ResetProgress() + _u.mutation.SetProgress(v) + return _u +} + +// SetNillableProgress sets the "progress" field if the given value is not nil. +func (_u *SoraTaskUpdate) SetNillableProgress(v *float64) *SoraTaskUpdate { + if v != nil { + _u.SetProgress(*v) + } + return _u +} + +// AddProgress adds value to the "progress" field. +func (_u *SoraTaskUpdate) AddProgress(v float64) *SoraTaskUpdate { + _u.mutation.AddProgress(v) + return _u +} + +// SetResultUrls sets the "result_urls" field. +func (_u *SoraTaskUpdate) SetResultUrls(v string) *SoraTaskUpdate { + _u.mutation.SetResultUrls(v) + return _u +} + +// SetNillableResultUrls sets the "result_urls" field if the given value is not nil. +func (_u *SoraTaskUpdate) SetNillableResultUrls(v *string) *SoraTaskUpdate { + if v != nil { + _u.SetResultUrls(*v) + } + return _u +} + +// ClearResultUrls clears the value of the "result_urls" field. +func (_u *SoraTaskUpdate) ClearResultUrls() *SoraTaskUpdate { + _u.mutation.ClearResultUrls() + return _u +} + +// SetErrorMessage sets the "error_message" field. 
+func (_u *SoraTaskUpdate) SetErrorMessage(v string) *SoraTaskUpdate { + _u.mutation.SetErrorMessage(v) + return _u +} + +// SetNillableErrorMessage sets the "error_message" field if the given value is not nil. +func (_u *SoraTaskUpdate) SetNillableErrorMessage(v *string) *SoraTaskUpdate { + if v != nil { + _u.SetErrorMessage(*v) + } + return _u +} + +// ClearErrorMessage clears the value of the "error_message" field. +func (_u *SoraTaskUpdate) ClearErrorMessage() *SoraTaskUpdate { + _u.mutation.ClearErrorMessage() + return _u +} + +// SetRetryCount sets the "retry_count" field. +func (_u *SoraTaskUpdate) SetRetryCount(v int) *SoraTaskUpdate { + _u.mutation.ResetRetryCount() + _u.mutation.SetRetryCount(v) + return _u +} + +// SetNillableRetryCount sets the "retry_count" field if the given value is not nil. +func (_u *SoraTaskUpdate) SetNillableRetryCount(v *int) *SoraTaskUpdate { + if v != nil { + _u.SetRetryCount(*v) + } + return _u +} + +// AddRetryCount adds value to the "retry_count" field. +func (_u *SoraTaskUpdate) AddRetryCount(v int) *SoraTaskUpdate { + _u.mutation.AddRetryCount(v) + return _u +} + +// SetCreatedAt sets the "created_at" field. +func (_u *SoraTaskUpdate) SetCreatedAt(v time.Time) *SoraTaskUpdate { + _u.mutation.SetCreatedAt(v) + return _u +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_u *SoraTaskUpdate) SetNillableCreatedAt(v *time.Time) *SoraTaskUpdate { + if v != nil { + _u.SetCreatedAt(*v) + } + return _u +} + +// SetCompletedAt sets the "completed_at" field. +func (_u *SoraTaskUpdate) SetCompletedAt(v time.Time) *SoraTaskUpdate { + _u.mutation.SetCompletedAt(v) + return _u +} + +// SetNillableCompletedAt sets the "completed_at" field if the given value is not nil. +func (_u *SoraTaskUpdate) SetNillableCompletedAt(v *time.Time) *SoraTaskUpdate { + if v != nil { + _u.SetCompletedAt(*v) + } + return _u +} + +// ClearCompletedAt clears the value of the "completed_at" field. 
+func (_u *SoraTaskUpdate) ClearCompletedAt() *SoraTaskUpdate { + _u.mutation.ClearCompletedAt() + return _u +} + +// Mutation returns the SoraTaskMutation object of the builder. +func (_u *SoraTaskUpdate) Mutation() *SoraTaskMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *SoraTaskUpdate) Save(ctx context.Context) (int, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *SoraTaskUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *SoraTaskUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *SoraTaskUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_u *SoraTaskUpdate) check() error { + if v, ok := _u.mutation.TaskID(); ok { + if err := soratask.TaskIDValidator(v); err != nil { + return &ValidationError{Name: "task_id", err: fmt.Errorf(`ent: validator failed for field "SoraTask.task_id": %w`, err)} + } + } + if v, ok := _u.mutation.Model(); ok { + if err := soratask.ModelValidator(v); err != nil { + return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "SoraTask.model": %w`, err)} + } + } + if v, ok := _u.mutation.Status(); ok { + if err := soratask.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "SoraTask.status": %w`, err)} + } + } + return nil +} + +func (_u *SoraTaskUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(soratask.Table, soratask.Columns, sqlgraph.NewFieldSpec(soratask.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.TaskID(); ok { + _spec.SetField(soratask.FieldTaskID, field.TypeString, value) + } + if value, ok := _u.mutation.AccountID(); ok { + _spec.SetField(soratask.FieldAccountID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedAccountID(); ok { + _spec.AddField(soratask.FieldAccountID, field.TypeInt64, value) + } + if value, ok := _u.mutation.Model(); ok { + _spec.SetField(soratask.FieldModel, field.TypeString, value) + } + if value, ok := _u.mutation.Prompt(); ok { + _spec.SetField(soratask.FieldPrompt, field.TypeString, value) + } + if value, ok := _u.mutation.Status(); ok { + _spec.SetField(soratask.FieldStatus, field.TypeString, value) + } + if value, ok := _u.mutation.Progress(); ok { + _spec.SetField(soratask.FieldProgress, field.TypeFloat64, value) + } + if value, ok := 
_u.mutation.AddedProgress(); ok { + _spec.AddField(soratask.FieldProgress, field.TypeFloat64, value) + } + if value, ok := _u.mutation.ResultUrls(); ok { + _spec.SetField(soratask.FieldResultUrls, field.TypeString, value) + } + if _u.mutation.ResultUrlsCleared() { + _spec.ClearField(soratask.FieldResultUrls, field.TypeString) + } + if value, ok := _u.mutation.ErrorMessage(); ok { + _spec.SetField(soratask.FieldErrorMessage, field.TypeString, value) + } + if _u.mutation.ErrorMessageCleared() { + _spec.ClearField(soratask.FieldErrorMessage, field.TypeString) + } + if value, ok := _u.mutation.RetryCount(); ok { + _spec.SetField(soratask.FieldRetryCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedRetryCount(); ok { + _spec.AddField(soratask.FieldRetryCount, field.TypeInt, value) + } + if value, ok := _u.mutation.CreatedAt(); ok { + _spec.SetField(soratask.FieldCreatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.CompletedAt(); ok { + _spec.SetField(soratask.FieldCompletedAt, field.TypeTime, value) + } + if _u.mutation.CompletedAtCleared() { + _spec.ClearField(soratask.FieldCompletedAt, field.TypeTime) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{soratask.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// SoraTaskUpdateOne is the builder for updating a single SoraTask entity. +type SoraTaskUpdateOne struct { + config + fields []string + hooks []Hook + mutation *SoraTaskMutation +} + +// SetTaskID sets the "task_id" field. +func (_u *SoraTaskUpdateOne) SetTaskID(v string) *SoraTaskUpdateOne { + _u.mutation.SetTaskID(v) + return _u +} + +// SetNillableTaskID sets the "task_id" field if the given value is not nil. 
+func (_u *SoraTaskUpdateOne) SetNillableTaskID(v *string) *SoraTaskUpdateOne { + if v != nil { + _u.SetTaskID(*v) + } + return _u +} + +// SetAccountID sets the "account_id" field. +func (_u *SoraTaskUpdateOne) SetAccountID(v int64) *SoraTaskUpdateOne { + _u.mutation.ResetAccountID() + _u.mutation.SetAccountID(v) + return _u +} + +// SetNillableAccountID sets the "account_id" field if the given value is not nil. +func (_u *SoraTaskUpdateOne) SetNillableAccountID(v *int64) *SoraTaskUpdateOne { + if v != nil { + _u.SetAccountID(*v) + } + return _u +} + +// AddAccountID adds value to the "account_id" field. +func (_u *SoraTaskUpdateOne) AddAccountID(v int64) *SoraTaskUpdateOne { + _u.mutation.AddAccountID(v) + return _u +} + +// SetModel sets the "model" field. +func (_u *SoraTaskUpdateOne) SetModel(v string) *SoraTaskUpdateOne { + _u.mutation.SetModel(v) + return _u +} + +// SetNillableModel sets the "model" field if the given value is not nil. +func (_u *SoraTaskUpdateOne) SetNillableModel(v *string) *SoraTaskUpdateOne { + if v != nil { + _u.SetModel(*v) + } + return _u +} + +// SetPrompt sets the "prompt" field. +func (_u *SoraTaskUpdateOne) SetPrompt(v string) *SoraTaskUpdateOne { + _u.mutation.SetPrompt(v) + return _u +} + +// SetNillablePrompt sets the "prompt" field if the given value is not nil. +func (_u *SoraTaskUpdateOne) SetNillablePrompt(v *string) *SoraTaskUpdateOne { + if v != nil { + _u.SetPrompt(*v) + } + return _u +} + +// SetStatus sets the "status" field. +func (_u *SoraTaskUpdateOne) SetStatus(v string) *SoraTaskUpdateOne { + _u.mutation.SetStatus(v) + return _u +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (_u *SoraTaskUpdateOne) SetNillableStatus(v *string) *SoraTaskUpdateOne { + if v != nil { + _u.SetStatus(*v) + } + return _u +} + +// SetProgress sets the "progress" field. 
+func (_u *SoraTaskUpdateOne) SetProgress(v float64) *SoraTaskUpdateOne { + _u.mutation.ResetProgress() + _u.mutation.SetProgress(v) + return _u +} + +// SetNillableProgress sets the "progress" field if the given value is not nil. +func (_u *SoraTaskUpdateOne) SetNillableProgress(v *float64) *SoraTaskUpdateOne { + if v != nil { + _u.SetProgress(*v) + } + return _u +} + +// AddProgress adds value to the "progress" field. +func (_u *SoraTaskUpdateOne) AddProgress(v float64) *SoraTaskUpdateOne { + _u.mutation.AddProgress(v) + return _u +} + +// SetResultUrls sets the "result_urls" field. +func (_u *SoraTaskUpdateOne) SetResultUrls(v string) *SoraTaskUpdateOne { + _u.mutation.SetResultUrls(v) + return _u +} + +// SetNillableResultUrls sets the "result_urls" field if the given value is not nil. +func (_u *SoraTaskUpdateOne) SetNillableResultUrls(v *string) *SoraTaskUpdateOne { + if v != nil { + _u.SetResultUrls(*v) + } + return _u +} + +// ClearResultUrls clears the value of the "result_urls" field. +func (_u *SoraTaskUpdateOne) ClearResultUrls() *SoraTaskUpdateOne { + _u.mutation.ClearResultUrls() + return _u +} + +// SetErrorMessage sets the "error_message" field. +func (_u *SoraTaskUpdateOne) SetErrorMessage(v string) *SoraTaskUpdateOne { + _u.mutation.SetErrorMessage(v) + return _u +} + +// SetNillableErrorMessage sets the "error_message" field if the given value is not nil. +func (_u *SoraTaskUpdateOne) SetNillableErrorMessage(v *string) *SoraTaskUpdateOne { + if v != nil { + _u.SetErrorMessage(*v) + } + return _u +} + +// ClearErrorMessage clears the value of the "error_message" field. +func (_u *SoraTaskUpdateOne) ClearErrorMessage() *SoraTaskUpdateOne { + _u.mutation.ClearErrorMessage() + return _u +} + +// SetRetryCount sets the "retry_count" field. 
+func (_u *SoraTaskUpdateOne) SetRetryCount(v int) *SoraTaskUpdateOne { + _u.mutation.ResetRetryCount() + _u.mutation.SetRetryCount(v) + return _u +} + +// SetNillableRetryCount sets the "retry_count" field if the given value is not nil. +func (_u *SoraTaskUpdateOne) SetNillableRetryCount(v *int) *SoraTaskUpdateOne { + if v != nil { + _u.SetRetryCount(*v) + } + return _u +} + +// AddRetryCount adds value to the "retry_count" field. +func (_u *SoraTaskUpdateOne) AddRetryCount(v int) *SoraTaskUpdateOne { + _u.mutation.AddRetryCount(v) + return _u +} + +// SetCreatedAt sets the "created_at" field. +func (_u *SoraTaskUpdateOne) SetCreatedAt(v time.Time) *SoraTaskUpdateOne { + _u.mutation.SetCreatedAt(v) + return _u +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_u *SoraTaskUpdateOne) SetNillableCreatedAt(v *time.Time) *SoraTaskUpdateOne { + if v != nil { + _u.SetCreatedAt(*v) + } + return _u +} + +// SetCompletedAt sets the "completed_at" field. +func (_u *SoraTaskUpdateOne) SetCompletedAt(v time.Time) *SoraTaskUpdateOne { + _u.mutation.SetCompletedAt(v) + return _u +} + +// SetNillableCompletedAt sets the "completed_at" field if the given value is not nil. +func (_u *SoraTaskUpdateOne) SetNillableCompletedAt(v *time.Time) *SoraTaskUpdateOne { + if v != nil { + _u.SetCompletedAt(*v) + } + return _u +} + +// ClearCompletedAt clears the value of the "completed_at" field. +func (_u *SoraTaskUpdateOne) ClearCompletedAt() *SoraTaskUpdateOne { + _u.mutation.ClearCompletedAt() + return _u +} + +// Mutation returns the SoraTaskMutation object of the builder. +func (_u *SoraTaskUpdateOne) Mutation() *SoraTaskMutation { + return _u.mutation +} + +// Where appends a list predicates to the SoraTaskUpdate builder. +func (_u *SoraTaskUpdateOne) Where(ps ...predicate.SoraTask) *SoraTaskUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. 
+// The default is selecting all fields defined in the entity schema. +func (_u *SoraTaskUpdateOne) Select(field string, fields ...string) *SoraTaskUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated SoraTask entity. +func (_u *SoraTaskUpdateOne) Save(ctx context.Context) (*SoraTask, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *SoraTaskUpdateOne) SaveX(ctx context.Context) *SoraTask { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *SoraTaskUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *SoraTaskUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_u *SoraTaskUpdateOne) check() error { + if v, ok := _u.mutation.TaskID(); ok { + if err := soratask.TaskIDValidator(v); err != nil { + return &ValidationError{Name: "task_id", err: fmt.Errorf(`ent: validator failed for field "SoraTask.task_id": %w`, err)} + } + } + if v, ok := _u.mutation.Model(); ok { + if err := soratask.ModelValidator(v); err != nil { + return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "SoraTask.model": %w`, err)} + } + } + if v, ok := _u.mutation.Status(); ok { + if err := soratask.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "SoraTask.status": %w`, err)} + } + } + return nil +} + +func (_u *SoraTaskUpdateOne) sqlSave(ctx context.Context) (_node *SoraTask, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(soratask.Table, soratask.Columns, sqlgraph.NewFieldSpec(soratask.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "SoraTask.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, soratask.FieldID) + for _, f := range fields { + if !soratask.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != soratask.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.TaskID(); ok { + _spec.SetField(soratask.FieldTaskID, field.TypeString, value) + } + if value, ok := _u.mutation.AccountID(); ok { + _spec.SetField(soratask.FieldAccountID, field.TypeInt64, value) + } + if value, 
ok := _u.mutation.AddedAccountID(); ok { + _spec.AddField(soratask.FieldAccountID, field.TypeInt64, value) + } + if value, ok := _u.mutation.Model(); ok { + _spec.SetField(soratask.FieldModel, field.TypeString, value) + } + if value, ok := _u.mutation.Prompt(); ok { + _spec.SetField(soratask.FieldPrompt, field.TypeString, value) + } + if value, ok := _u.mutation.Status(); ok { + _spec.SetField(soratask.FieldStatus, field.TypeString, value) + } + if value, ok := _u.mutation.Progress(); ok { + _spec.SetField(soratask.FieldProgress, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedProgress(); ok { + _spec.AddField(soratask.FieldProgress, field.TypeFloat64, value) + } + if value, ok := _u.mutation.ResultUrls(); ok { + _spec.SetField(soratask.FieldResultUrls, field.TypeString, value) + } + if _u.mutation.ResultUrlsCleared() { + _spec.ClearField(soratask.FieldResultUrls, field.TypeString) + } + if value, ok := _u.mutation.ErrorMessage(); ok { + _spec.SetField(soratask.FieldErrorMessage, field.TypeString, value) + } + if _u.mutation.ErrorMessageCleared() { + _spec.ClearField(soratask.FieldErrorMessage, field.TypeString) + } + if value, ok := _u.mutation.RetryCount(); ok { + _spec.SetField(soratask.FieldRetryCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedRetryCount(); ok { + _spec.AddField(soratask.FieldRetryCount, field.TypeInt, value) + } + if value, ok := _u.mutation.CreatedAt(); ok { + _spec.SetField(soratask.FieldCreatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.CompletedAt(); ok { + _spec.SetField(soratask.FieldCompletedAt, field.TypeTime, value) + } + if _u.mutation.CompletedAtCleared() { + _spec.ClearField(soratask.FieldCompletedAt, field.TypeTime) + } + _node = &SoraTask{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = 
&NotFoundError{soratask.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/sorausagestat.go b/backend/ent/sorausagestat.go new file mode 100644 index 00000000..b99f313f --- /dev/null +++ b/backend/ent/sorausagestat.go @@ -0,0 +1,231 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/sorausagestat" +) + +// SoraUsageStat is the model entity for the SoraUsageStat schema. +type SoraUsageStat struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // 关联 accounts 表的 ID + AccountID int64 `json:"account_id,omitempty"` + // ImageCount holds the value of the "image_count" field. + ImageCount int `json:"image_count,omitempty"` + // VideoCount holds the value of the "video_count" field. + VideoCount int `json:"video_count,omitempty"` + // ErrorCount holds the value of the "error_count" field. + ErrorCount int `json:"error_count,omitempty"` + // LastErrorAt holds the value of the "last_error_at" field. + LastErrorAt *time.Time `json:"last_error_at,omitempty"` + // TodayImageCount holds the value of the "today_image_count" field. + TodayImageCount int `json:"today_image_count,omitempty"` + // TodayVideoCount holds the value of the "today_video_count" field. + TodayVideoCount int `json:"today_video_count,omitempty"` + // TodayErrorCount holds the value of the "today_error_count" field. + TodayErrorCount int `json:"today_error_count,omitempty"` + // TodayDate holds the value of the "today_date" field. 
+ TodayDate *time.Time `json:"today_date,omitempty"` + // ConsecutiveErrorCount holds the value of the "consecutive_error_count" field. + ConsecutiveErrorCount int `json:"consecutive_error_count,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*SoraUsageStat) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case sorausagestat.FieldID, sorausagestat.FieldAccountID, sorausagestat.FieldImageCount, sorausagestat.FieldVideoCount, sorausagestat.FieldErrorCount, sorausagestat.FieldTodayImageCount, sorausagestat.FieldTodayVideoCount, sorausagestat.FieldTodayErrorCount, sorausagestat.FieldConsecutiveErrorCount: + values[i] = new(sql.NullInt64) + case sorausagestat.FieldCreatedAt, sorausagestat.FieldUpdatedAt, sorausagestat.FieldLastErrorAt, sorausagestat.FieldTodayDate: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the SoraUsageStat fields. 
+func (_m *SoraUsageStat) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case sorausagestat.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case sorausagestat.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case sorausagestat.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case sorausagestat.FieldAccountID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field account_id", values[i]) + } else if value.Valid { + _m.AccountID = value.Int64 + } + case sorausagestat.FieldImageCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field image_count", values[i]) + } else if value.Valid { + _m.ImageCount = int(value.Int64) + } + case sorausagestat.FieldVideoCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field video_count", values[i]) + } else if value.Valid { + _m.VideoCount = int(value.Int64) + } + case sorausagestat.FieldErrorCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field error_count", values[i]) + } else if value.Valid { + _m.ErrorCount = int(value.Int64) + } + case sorausagestat.FieldLastErrorAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field last_error_at", values[i]) + } else if value.Valid { + _m.LastErrorAt = 
new(time.Time) + *_m.LastErrorAt = value.Time + } + case sorausagestat.FieldTodayImageCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field today_image_count", values[i]) + } else if value.Valid { + _m.TodayImageCount = int(value.Int64) + } + case sorausagestat.FieldTodayVideoCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field today_video_count", values[i]) + } else if value.Valid { + _m.TodayVideoCount = int(value.Int64) + } + case sorausagestat.FieldTodayErrorCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field today_error_count", values[i]) + } else if value.Valid { + _m.TodayErrorCount = int(value.Int64) + } + case sorausagestat.FieldTodayDate: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field today_date", values[i]) + } else if value.Valid { + _m.TodayDate = new(time.Time) + *_m.TodayDate = value.Time + } + case sorausagestat.FieldConsecutiveErrorCount: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field consecutive_error_count", values[i]) + } else if value.Valid { + _m.ConsecutiveErrorCount = int(value.Int64) + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the SoraUsageStat. +// This includes values selected through modifiers, order, etc. +func (_m *SoraUsageStat) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this SoraUsageStat. +// Note that you need to call SoraUsageStat.Unwrap() before calling this method if this SoraUsageStat +// was returned from a transaction, and the transaction was committed or rolled back. 
+func (_m *SoraUsageStat) Update() *SoraUsageStatUpdateOne { + return NewSoraUsageStatClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the SoraUsageStat entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *SoraUsageStat) Unwrap() *SoraUsageStat { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: SoraUsageStat is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *SoraUsageStat) String() string { + var builder strings.Builder + builder.WriteString("SoraUsageStat(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("account_id=") + builder.WriteString(fmt.Sprintf("%v", _m.AccountID)) + builder.WriteString(", ") + builder.WriteString("image_count=") + builder.WriteString(fmt.Sprintf("%v", _m.ImageCount)) + builder.WriteString(", ") + builder.WriteString("video_count=") + builder.WriteString(fmt.Sprintf("%v", _m.VideoCount)) + builder.WriteString(", ") + builder.WriteString("error_count=") + builder.WriteString(fmt.Sprintf("%v", _m.ErrorCount)) + builder.WriteString(", ") + if v := _m.LastErrorAt; v != nil { + builder.WriteString("last_error_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("today_image_count=") + builder.WriteString(fmt.Sprintf("%v", _m.TodayImageCount)) + builder.WriteString(", ") + builder.WriteString("today_video_count=") + builder.WriteString(fmt.Sprintf("%v", _m.TodayVideoCount)) + builder.WriteString(", ") + builder.WriteString("today_error_count=") + builder.WriteString(fmt.Sprintf("%v", _m.TodayErrorCount)) + 
builder.WriteString(", ") + if v := _m.TodayDate; v != nil { + builder.WriteString("today_date=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("consecutive_error_count=") + builder.WriteString(fmt.Sprintf("%v", _m.ConsecutiveErrorCount)) + builder.WriteByte(')') + return builder.String() +} + +// SoraUsageStats is a parsable slice of SoraUsageStat. +type SoraUsageStats []*SoraUsageStat diff --git a/backend/ent/sorausagestat/sorausagestat.go b/backend/ent/sorausagestat/sorausagestat.go new file mode 100644 index 00000000..070de5ff --- /dev/null +++ b/backend/ent/sorausagestat/sorausagestat.go @@ -0,0 +1,160 @@ +// Code generated by ent, DO NOT EDIT. + +package sorausagestat + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the sorausagestat type in the database. + Label = "sora_usage_stat" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldAccountID holds the string denoting the account_id field in the database. + FieldAccountID = "account_id" + // FieldImageCount holds the string denoting the image_count field in the database. + FieldImageCount = "image_count" + // FieldVideoCount holds the string denoting the video_count field in the database. + FieldVideoCount = "video_count" + // FieldErrorCount holds the string denoting the error_count field in the database. + FieldErrorCount = "error_count" + // FieldLastErrorAt holds the string denoting the last_error_at field in the database. + FieldLastErrorAt = "last_error_at" + // FieldTodayImageCount holds the string denoting the today_image_count field in the database. 
+ FieldTodayImageCount = "today_image_count" + // FieldTodayVideoCount holds the string denoting the today_video_count field in the database. + FieldTodayVideoCount = "today_video_count" + // FieldTodayErrorCount holds the string denoting the today_error_count field in the database. + FieldTodayErrorCount = "today_error_count" + // FieldTodayDate holds the string denoting the today_date field in the database. + FieldTodayDate = "today_date" + // FieldConsecutiveErrorCount holds the string denoting the consecutive_error_count field in the database. + FieldConsecutiveErrorCount = "consecutive_error_count" + // Table holds the table name of the sorausagestat in the database. + Table = "sora_usage_stats" +) + +// Columns holds all SQL columns for sorausagestat fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldAccountID, + FieldImageCount, + FieldVideoCount, + FieldErrorCount, + FieldLastErrorAt, + FieldTodayImageCount, + FieldTodayVideoCount, + FieldTodayErrorCount, + FieldTodayDate, + FieldConsecutiveErrorCount, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // DefaultImageCount holds the default value on creation for the "image_count" field. + DefaultImageCount int + // DefaultVideoCount holds the default value on creation for the "video_count" field. + DefaultVideoCount int + // DefaultErrorCount holds the default value on creation for the "error_count" field. 
+ DefaultErrorCount int + // DefaultTodayImageCount holds the default value on creation for the "today_image_count" field. + DefaultTodayImageCount int + // DefaultTodayVideoCount holds the default value on creation for the "today_video_count" field. + DefaultTodayVideoCount int + // DefaultTodayErrorCount holds the default value on creation for the "today_error_count" field. + DefaultTodayErrorCount int + // DefaultConsecutiveErrorCount holds the default value on creation for the "consecutive_error_count" field. + DefaultConsecutiveErrorCount int +) + +// OrderOption defines the ordering options for the SoraUsageStat queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByAccountID orders the results by the account_id field. +func ByAccountID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAccountID, opts...).ToFunc() +} + +// ByImageCount orders the results by the image_count field. +func ByImageCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldImageCount, opts...).ToFunc() +} + +// ByVideoCount orders the results by the video_count field. +func ByVideoCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldVideoCount, opts...).ToFunc() +} + +// ByErrorCount orders the results by the error_count field. 
+func ByErrorCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldErrorCount, opts...).ToFunc() +} + +// ByLastErrorAt orders the results by the last_error_at field. +func ByLastErrorAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastErrorAt, opts...).ToFunc() +} + +// ByTodayImageCount orders the results by the today_image_count field. +func ByTodayImageCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTodayImageCount, opts...).ToFunc() +} + +// ByTodayVideoCount orders the results by the today_video_count field. +func ByTodayVideoCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTodayVideoCount, opts...).ToFunc() +} + +// ByTodayErrorCount orders the results by the today_error_count field. +func ByTodayErrorCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTodayErrorCount, opts...).ToFunc() +} + +// ByTodayDate orders the results by the today_date field. +func ByTodayDate(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTodayDate, opts...).ToFunc() +} + +// ByConsecutiveErrorCount orders the results by the consecutive_error_count field. +func ByConsecutiveErrorCount(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldConsecutiveErrorCount, opts...).ToFunc() +} diff --git a/backend/ent/sorausagestat/where.go b/backend/ent/sorausagestat/where.go new file mode 100644 index 00000000..336a3d24 --- /dev/null +++ b/backend/ent/sorausagestat/where.go @@ -0,0 +1,630 @@ +// Code generated by ent, DO NOT EDIT. + +package sorausagestat + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. 
+func IDEQ(id int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// AccountID applies equality check predicate on the "account_id" field. It's identical to AccountIDEQ. 
+func AccountID(v int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldAccountID, v)) +} + +// ImageCount applies equality check predicate on the "image_count" field. It's identical to ImageCountEQ. +func ImageCount(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldImageCount, v)) +} + +// VideoCount applies equality check predicate on the "video_count" field. It's identical to VideoCountEQ. +func VideoCount(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldVideoCount, v)) +} + +// ErrorCount applies equality check predicate on the "error_count" field. It's identical to ErrorCountEQ. +func ErrorCount(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldErrorCount, v)) +} + +// LastErrorAt applies equality check predicate on the "last_error_at" field. It's identical to LastErrorAtEQ. +func LastErrorAt(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldLastErrorAt, v)) +} + +// TodayImageCount applies equality check predicate on the "today_image_count" field. It's identical to TodayImageCountEQ. +func TodayImageCount(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldTodayImageCount, v)) +} + +// TodayVideoCount applies equality check predicate on the "today_video_count" field. It's identical to TodayVideoCountEQ. +func TodayVideoCount(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldTodayVideoCount, v)) +} + +// TodayErrorCount applies equality check predicate on the "today_error_count" field. It's identical to TodayErrorCountEQ. +func TodayErrorCount(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldTodayErrorCount, v)) +} + +// TodayDate applies equality check predicate on the "today_date" field. It's identical to TodayDateEQ. 
+func TodayDate(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldTodayDate, v)) +} + +// ConsecutiveErrorCount applies equality check predicate on the "consecutive_error_count" field. It's identical to ConsecutiveErrorCountEQ. +func ConsecutiveErrorCount(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldConsecutiveErrorCount, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. 
+func CreatedAtLTE(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// AccountIDEQ applies the EQ predicate on the "account_id" field. +func AccountIDEQ(v int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldAccountID, v)) +} + +// AccountIDNEQ applies the NEQ predicate on the "account_id" field. 
+func AccountIDNEQ(v int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNEQ(FieldAccountID, v)) +} + +// AccountIDIn applies the In predicate on the "account_id" field. +func AccountIDIn(vs ...int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldIn(FieldAccountID, vs...)) +} + +// AccountIDNotIn applies the NotIn predicate on the "account_id" field. +func AccountIDNotIn(vs ...int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNotIn(FieldAccountID, vs...)) +} + +// AccountIDGT applies the GT predicate on the "account_id" field. +func AccountIDGT(v int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGT(FieldAccountID, v)) +} + +// AccountIDGTE applies the GTE predicate on the "account_id" field. +func AccountIDGTE(v int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGTE(FieldAccountID, v)) +} + +// AccountIDLT applies the LT predicate on the "account_id" field. +func AccountIDLT(v int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLT(FieldAccountID, v)) +} + +// AccountIDLTE applies the LTE predicate on the "account_id" field. +func AccountIDLTE(v int64) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLTE(FieldAccountID, v)) +} + +// ImageCountEQ applies the EQ predicate on the "image_count" field. +func ImageCountEQ(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldImageCount, v)) +} + +// ImageCountNEQ applies the NEQ predicate on the "image_count" field. +func ImageCountNEQ(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNEQ(FieldImageCount, v)) +} + +// ImageCountIn applies the In predicate on the "image_count" field. +func ImageCountIn(vs ...int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldIn(FieldImageCount, vs...)) +} + +// ImageCountNotIn applies the NotIn predicate on the "image_count" field. 
+func ImageCountNotIn(vs ...int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNotIn(FieldImageCount, vs...)) +} + +// ImageCountGT applies the GT predicate on the "image_count" field. +func ImageCountGT(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGT(FieldImageCount, v)) +} + +// ImageCountGTE applies the GTE predicate on the "image_count" field. +func ImageCountGTE(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGTE(FieldImageCount, v)) +} + +// ImageCountLT applies the LT predicate on the "image_count" field. +func ImageCountLT(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLT(FieldImageCount, v)) +} + +// ImageCountLTE applies the LTE predicate on the "image_count" field. +func ImageCountLTE(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLTE(FieldImageCount, v)) +} + +// VideoCountEQ applies the EQ predicate on the "video_count" field. +func VideoCountEQ(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldVideoCount, v)) +} + +// VideoCountNEQ applies the NEQ predicate on the "video_count" field. +func VideoCountNEQ(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNEQ(FieldVideoCount, v)) +} + +// VideoCountIn applies the In predicate on the "video_count" field. +func VideoCountIn(vs ...int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldIn(FieldVideoCount, vs...)) +} + +// VideoCountNotIn applies the NotIn predicate on the "video_count" field. +func VideoCountNotIn(vs ...int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNotIn(FieldVideoCount, vs...)) +} + +// VideoCountGT applies the GT predicate on the "video_count" field. +func VideoCountGT(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGT(FieldVideoCount, v)) +} + +// VideoCountGTE applies the GTE predicate on the "video_count" field. 
+func VideoCountGTE(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGTE(FieldVideoCount, v)) +} + +// VideoCountLT applies the LT predicate on the "video_count" field. +func VideoCountLT(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLT(FieldVideoCount, v)) +} + +// VideoCountLTE applies the LTE predicate on the "video_count" field. +func VideoCountLTE(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLTE(FieldVideoCount, v)) +} + +// ErrorCountEQ applies the EQ predicate on the "error_count" field. +func ErrorCountEQ(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldErrorCount, v)) +} + +// ErrorCountNEQ applies the NEQ predicate on the "error_count" field. +func ErrorCountNEQ(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNEQ(FieldErrorCount, v)) +} + +// ErrorCountIn applies the In predicate on the "error_count" field. +func ErrorCountIn(vs ...int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldIn(FieldErrorCount, vs...)) +} + +// ErrorCountNotIn applies the NotIn predicate on the "error_count" field. +func ErrorCountNotIn(vs ...int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNotIn(FieldErrorCount, vs...)) +} + +// ErrorCountGT applies the GT predicate on the "error_count" field. +func ErrorCountGT(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGT(FieldErrorCount, v)) +} + +// ErrorCountGTE applies the GTE predicate on the "error_count" field. +func ErrorCountGTE(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGTE(FieldErrorCount, v)) +} + +// ErrorCountLT applies the LT predicate on the "error_count" field. +func ErrorCountLT(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLT(FieldErrorCount, v)) +} + +// ErrorCountLTE applies the LTE predicate on the "error_count" field. 
+func ErrorCountLTE(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLTE(FieldErrorCount, v)) +} + +// LastErrorAtEQ applies the EQ predicate on the "last_error_at" field. +func LastErrorAtEQ(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldLastErrorAt, v)) +} + +// LastErrorAtNEQ applies the NEQ predicate on the "last_error_at" field. +func LastErrorAtNEQ(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNEQ(FieldLastErrorAt, v)) +} + +// LastErrorAtIn applies the In predicate on the "last_error_at" field. +func LastErrorAtIn(vs ...time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldIn(FieldLastErrorAt, vs...)) +} + +// LastErrorAtNotIn applies the NotIn predicate on the "last_error_at" field. +func LastErrorAtNotIn(vs ...time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNotIn(FieldLastErrorAt, vs...)) +} + +// LastErrorAtGT applies the GT predicate on the "last_error_at" field. +func LastErrorAtGT(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGT(FieldLastErrorAt, v)) +} + +// LastErrorAtGTE applies the GTE predicate on the "last_error_at" field. +func LastErrorAtGTE(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGTE(FieldLastErrorAt, v)) +} + +// LastErrorAtLT applies the LT predicate on the "last_error_at" field. +func LastErrorAtLT(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLT(FieldLastErrorAt, v)) +} + +// LastErrorAtLTE applies the LTE predicate on the "last_error_at" field. +func LastErrorAtLTE(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLTE(FieldLastErrorAt, v)) +} + +// LastErrorAtIsNil applies the IsNil predicate on the "last_error_at" field. 
+func LastErrorAtIsNil() predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldIsNull(FieldLastErrorAt)) +} + +// LastErrorAtNotNil applies the NotNil predicate on the "last_error_at" field. +func LastErrorAtNotNil() predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNotNull(FieldLastErrorAt)) +} + +// TodayImageCountEQ applies the EQ predicate on the "today_image_count" field. +func TodayImageCountEQ(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldTodayImageCount, v)) +} + +// TodayImageCountNEQ applies the NEQ predicate on the "today_image_count" field. +func TodayImageCountNEQ(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNEQ(FieldTodayImageCount, v)) +} + +// TodayImageCountIn applies the In predicate on the "today_image_count" field. +func TodayImageCountIn(vs ...int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldIn(FieldTodayImageCount, vs...)) +} + +// TodayImageCountNotIn applies the NotIn predicate on the "today_image_count" field. +func TodayImageCountNotIn(vs ...int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNotIn(FieldTodayImageCount, vs...)) +} + +// TodayImageCountGT applies the GT predicate on the "today_image_count" field. +func TodayImageCountGT(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGT(FieldTodayImageCount, v)) +} + +// TodayImageCountGTE applies the GTE predicate on the "today_image_count" field. +func TodayImageCountGTE(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGTE(FieldTodayImageCount, v)) +} + +// TodayImageCountLT applies the LT predicate on the "today_image_count" field. +func TodayImageCountLT(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLT(FieldTodayImageCount, v)) +} + +// TodayImageCountLTE applies the LTE predicate on the "today_image_count" field. 
+func TodayImageCountLTE(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLTE(FieldTodayImageCount, v)) +} + +// TodayVideoCountEQ applies the EQ predicate on the "today_video_count" field. +func TodayVideoCountEQ(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldTodayVideoCount, v)) +} + +// TodayVideoCountNEQ applies the NEQ predicate on the "today_video_count" field. +func TodayVideoCountNEQ(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNEQ(FieldTodayVideoCount, v)) +} + +// TodayVideoCountIn applies the In predicate on the "today_video_count" field. +func TodayVideoCountIn(vs ...int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldIn(FieldTodayVideoCount, vs...)) +} + +// TodayVideoCountNotIn applies the NotIn predicate on the "today_video_count" field. +func TodayVideoCountNotIn(vs ...int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNotIn(FieldTodayVideoCount, vs...)) +} + +// TodayVideoCountGT applies the GT predicate on the "today_video_count" field. +func TodayVideoCountGT(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGT(FieldTodayVideoCount, v)) +} + +// TodayVideoCountGTE applies the GTE predicate on the "today_video_count" field. +func TodayVideoCountGTE(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGTE(FieldTodayVideoCount, v)) +} + +// TodayVideoCountLT applies the LT predicate on the "today_video_count" field. +func TodayVideoCountLT(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLT(FieldTodayVideoCount, v)) +} + +// TodayVideoCountLTE applies the LTE predicate on the "today_video_count" field. +func TodayVideoCountLTE(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLTE(FieldTodayVideoCount, v)) +} + +// TodayErrorCountEQ applies the EQ predicate on the "today_error_count" field. 
+func TodayErrorCountEQ(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldTodayErrorCount, v)) +} + +// TodayErrorCountNEQ applies the NEQ predicate on the "today_error_count" field. +func TodayErrorCountNEQ(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNEQ(FieldTodayErrorCount, v)) +} + +// TodayErrorCountIn applies the In predicate on the "today_error_count" field. +func TodayErrorCountIn(vs ...int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldIn(FieldTodayErrorCount, vs...)) +} + +// TodayErrorCountNotIn applies the NotIn predicate on the "today_error_count" field. +func TodayErrorCountNotIn(vs ...int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNotIn(FieldTodayErrorCount, vs...)) +} + +// TodayErrorCountGT applies the GT predicate on the "today_error_count" field. +func TodayErrorCountGT(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGT(FieldTodayErrorCount, v)) +} + +// TodayErrorCountGTE applies the GTE predicate on the "today_error_count" field. +func TodayErrorCountGTE(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGTE(FieldTodayErrorCount, v)) +} + +// TodayErrorCountLT applies the LT predicate on the "today_error_count" field. +func TodayErrorCountLT(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLT(FieldTodayErrorCount, v)) +} + +// TodayErrorCountLTE applies the LTE predicate on the "today_error_count" field. +func TodayErrorCountLTE(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLTE(FieldTodayErrorCount, v)) +} + +// TodayDateEQ applies the EQ predicate on the "today_date" field. +func TodayDateEQ(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldTodayDate, v)) +} + +// TodayDateNEQ applies the NEQ predicate on the "today_date" field. 
+func TodayDateNEQ(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNEQ(FieldTodayDate, v)) +} + +// TodayDateIn applies the In predicate on the "today_date" field. +func TodayDateIn(vs ...time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldIn(FieldTodayDate, vs...)) +} + +// TodayDateNotIn applies the NotIn predicate on the "today_date" field. +func TodayDateNotIn(vs ...time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNotIn(FieldTodayDate, vs...)) +} + +// TodayDateGT applies the GT predicate on the "today_date" field. +func TodayDateGT(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGT(FieldTodayDate, v)) +} + +// TodayDateGTE applies the GTE predicate on the "today_date" field. +func TodayDateGTE(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGTE(FieldTodayDate, v)) +} + +// TodayDateLT applies the LT predicate on the "today_date" field. +func TodayDateLT(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLT(FieldTodayDate, v)) +} + +// TodayDateLTE applies the LTE predicate on the "today_date" field. +func TodayDateLTE(v time.Time) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLTE(FieldTodayDate, v)) +} + +// TodayDateIsNil applies the IsNil predicate on the "today_date" field. +func TodayDateIsNil() predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldIsNull(FieldTodayDate)) +} + +// TodayDateNotNil applies the NotNil predicate on the "today_date" field. +func TodayDateNotNil() predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNotNull(FieldTodayDate)) +} + +// ConsecutiveErrorCountEQ applies the EQ predicate on the "consecutive_error_count" field. 
+func ConsecutiveErrorCountEQ(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldEQ(FieldConsecutiveErrorCount, v)) +} + +// ConsecutiveErrorCountNEQ applies the NEQ predicate on the "consecutive_error_count" field. +func ConsecutiveErrorCountNEQ(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNEQ(FieldConsecutiveErrorCount, v)) +} + +// ConsecutiveErrorCountIn applies the In predicate on the "consecutive_error_count" field. +func ConsecutiveErrorCountIn(vs ...int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldIn(FieldConsecutiveErrorCount, vs...)) +} + +// ConsecutiveErrorCountNotIn applies the NotIn predicate on the "consecutive_error_count" field. +func ConsecutiveErrorCountNotIn(vs ...int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldNotIn(FieldConsecutiveErrorCount, vs...)) +} + +// ConsecutiveErrorCountGT applies the GT predicate on the "consecutive_error_count" field. +func ConsecutiveErrorCountGT(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGT(FieldConsecutiveErrorCount, v)) +} + +// ConsecutiveErrorCountGTE applies the GTE predicate on the "consecutive_error_count" field. +func ConsecutiveErrorCountGTE(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldGTE(FieldConsecutiveErrorCount, v)) +} + +// ConsecutiveErrorCountLT applies the LT predicate on the "consecutive_error_count" field. +func ConsecutiveErrorCountLT(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLT(FieldConsecutiveErrorCount, v)) +} + +// ConsecutiveErrorCountLTE applies the LTE predicate on the "consecutive_error_count" field. +func ConsecutiveErrorCountLTE(v int) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.FieldLTE(FieldConsecutiveErrorCount, v)) +} + +// And groups predicates with the AND operator between them. 
+func And(predicates ...predicate.SoraUsageStat) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.SoraUsageStat) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.SoraUsageStat) predicate.SoraUsageStat { + return predicate.SoraUsageStat(sql.NotPredicates(p)) +} diff --git a/backend/ent/sorausagestat_create.go b/backend/ent/sorausagestat_create.go new file mode 100644 index 00000000..c9aab3be --- /dev/null +++ b/backend/ent/sorausagestat_create.go @@ -0,0 +1,1334 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/sorausagestat" +) + +// SoraUsageStatCreate is the builder for creating a SoraUsageStat entity. +type SoraUsageStatCreate struct { + config + mutation *SoraUsageStatMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *SoraUsageStatCreate) SetCreatedAt(v time.Time) *SoraUsageStatCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *SoraUsageStatCreate) SetNillableCreatedAt(v *time.Time) *SoraUsageStatCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *SoraUsageStatCreate) SetUpdatedAt(v time.Time) *SoraUsageStatCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. 
+func (_c *SoraUsageStatCreate) SetNillableUpdatedAt(v *time.Time) *SoraUsageStatCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetAccountID sets the "account_id" field. +func (_c *SoraUsageStatCreate) SetAccountID(v int64) *SoraUsageStatCreate { + _c.mutation.SetAccountID(v) + return _c +} + +// SetImageCount sets the "image_count" field. +func (_c *SoraUsageStatCreate) SetImageCount(v int) *SoraUsageStatCreate { + _c.mutation.SetImageCount(v) + return _c +} + +// SetNillableImageCount sets the "image_count" field if the given value is not nil. +func (_c *SoraUsageStatCreate) SetNillableImageCount(v *int) *SoraUsageStatCreate { + if v != nil { + _c.SetImageCount(*v) + } + return _c +} + +// SetVideoCount sets the "video_count" field. +func (_c *SoraUsageStatCreate) SetVideoCount(v int) *SoraUsageStatCreate { + _c.mutation.SetVideoCount(v) + return _c +} + +// SetNillableVideoCount sets the "video_count" field if the given value is not nil. +func (_c *SoraUsageStatCreate) SetNillableVideoCount(v *int) *SoraUsageStatCreate { + if v != nil { + _c.SetVideoCount(*v) + } + return _c +} + +// SetErrorCount sets the "error_count" field. +func (_c *SoraUsageStatCreate) SetErrorCount(v int) *SoraUsageStatCreate { + _c.mutation.SetErrorCount(v) + return _c +} + +// SetNillableErrorCount sets the "error_count" field if the given value is not nil. +func (_c *SoraUsageStatCreate) SetNillableErrorCount(v *int) *SoraUsageStatCreate { + if v != nil { + _c.SetErrorCount(*v) + } + return _c +} + +// SetLastErrorAt sets the "last_error_at" field. +func (_c *SoraUsageStatCreate) SetLastErrorAt(v time.Time) *SoraUsageStatCreate { + _c.mutation.SetLastErrorAt(v) + return _c +} + +// SetNillableLastErrorAt sets the "last_error_at" field if the given value is not nil. 
+func (_c *SoraUsageStatCreate) SetNillableLastErrorAt(v *time.Time) *SoraUsageStatCreate { + if v != nil { + _c.SetLastErrorAt(*v) + } + return _c +} + +// SetTodayImageCount sets the "today_image_count" field. +func (_c *SoraUsageStatCreate) SetTodayImageCount(v int) *SoraUsageStatCreate { + _c.mutation.SetTodayImageCount(v) + return _c +} + +// SetNillableTodayImageCount sets the "today_image_count" field if the given value is not nil. +func (_c *SoraUsageStatCreate) SetNillableTodayImageCount(v *int) *SoraUsageStatCreate { + if v != nil { + _c.SetTodayImageCount(*v) + } + return _c +} + +// SetTodayVideoCount sets the "today_video_count" field. +func (_c *SoraUsageStatCreate) SetTodayVideoCount(v int) *SoraUsageStatCreate { + _c.mutation.SetTodayVideoCount(v) + return _c +} + +// SetNillableTodayVideoCount sets the "today_video_count" field if the given value is not nil. +func (_c *SoraUsageStatCreate) SetNillableTodayVideoCount(v *int) *SoraUsageStatCreate { + if v != nil { + _c.SetTodayVideoCount(*v) + } + return _c +} + +// SetTodayErrorCount sets the "today_error_count" field. +func (_c *SoraUsageStatCreate) SetTodayErrorCount(v int) *SoraUsageStatCreate { + _c.mutation.SetTodayErrorCount(v) + return _c +} + +// SetNillableTodayErrorCount sets the "today_error_count" field if the given value is not nil. +func (_c *SoraUsageStatCreate) SetNillableTodayErrorCount(v *int) *SoraUsageStatCreate { + if v != nil { + _c.SetTodayErrorCount(*v) + } + return _c +} + +// SetTodayDate sets the "today_date" field. +func (_c *SoraUsageStatCreate) SetTodayDate(v time.Time) *SoraUsageStatCreate { + _c.mutation.SetTodayDate(v) + return _c +} + +// SetNillableTodayDate sets the "today_date" field if the given value is not nil. +func (_c *SoraUsageStatCreate) SetNillableTodayDate(v *time.Time) *SoraUsageStatCreate { + if v != nil { + _c.SetTodayDate(*v) + } + return _c +} + +// SetConsecutiveErrorCount sets the "consecutive_error_count" field. 
+func (_c *SoraUsageStatCreate) SetConsecutiveErrorCount(v int) *SoraUsageStatCreate { + _c.mutation.SetConsecutiveErrorCount(v) + return _c +} + +// SetNillableConsecutiveErrorCount sets the "consecutive_error_count" field if the given value is not nil. +func (_c *SoraUsageStatCreate) SetNillableConsecutiveErrorCount(v *int) *SoraUsageStatCreate { + if v != nil { + _c.SetConsecutiveErrorCount(*v) + } + return _c +} + +// Mutation returns the SoraUsageStatMutation object of the builder. +func (_c *SoraUsageStatCreate) Mutation() *SoraUsageStatMutation { + return _c.mutation +} + +// Save creates the SoraUsageStat in the database. +func (_c *SoraUsageStatCreate) Save(ctx context.Context) (*SoraUsageStat, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *SoraUsageStatCreate) SaveX(ctx context.Context) *SoraUsageStat { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *SoraUsageStatCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *SoraUsageStatCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (_c *SoraUsageStatCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := sorausagestat.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := sorausagestat.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.ImageCount(); !ok { + v := sorausagestat.DefaultImageCount + _c.mutation.SetImageCount(v) + } + if _, ok := _c.mutation.VideoCount(); !ok { + v := sorausagestat.DefaultVideoCount + _c.mutation.SetVideoCount(v) + } + if _, ok := _c.mutation.ErrorCount(); !ok { + v := sorausagestat.DefaultErrorCount + _c.mutation.SetErrorCount(v) + } + if _, ok := _c.mutation.TodayImageCount(); !ok { + v := sorausagestat.DefaultTodayImageCount + _c.mutation.SetTodayImageCount(v) + } + if _, ok := _c.mutation.TodayVideoCount(); !ok { + v := sorausagestat.DefaultTodayVideoCount + _c.mutation.SetTodayVideoCount(v) + } + if _, ok := _c.mutation.TodayErrorCount(); !ok { + v := sorausagestat.DefaultTodayErrorCount + _c.mutation.SetTodayErrorCount(v) + } + if _, ok := _c.mutation.ConsecutiveErrorCount(); !ok { + v := sorausagestat.DefaultConsecutiveErrorCount + _c.mutation.SetConsecutiveErrorCount(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_c *SoraUsageStatCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "SoraUsageStat.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "SoraUsageStat.updated_at"`)} + } + if _, ok := _c.mutation.AccountID(); !ok { + return &ValidationError{Name: "account_id", err: errors.New(`ent: missing required field "SoraUsageStat.account_id"`)} + } + if _, ok := _c.mutation.ImageCount(); !ok { + return &ValidationError{Name: "image_count", err: errors.New(`ent: missing required field "SoraUsageStat.image_count"`)} + } + if _, ok := _c.mutation.VideoCount(); !ok { + return &ValidationError{Name: "video_count", err: errors.New(`ent: missing required field "SoraUsageStat.video_count"`)} + } + if _, ok := _c.mutation.ErrorCount(); !ok { + return &ValidationError{Name: "error_count", err: errors.New(`ent: missing required field "SoraUsageStat.error_count"`)} + } + if _, ok := _c.mutation.TodayImageCount(); !ok { + return &ValidationError{Name: "today_image_count", err: errors.New(`ent: missing required field "SoraUsageStat.today_image_count"`)} + } + if _, ok := _c.mutation.TodayVideoCount(); !ok { + return &ValidationError{Name: "today_video_count", err: errors.New(`ent: missing required field "SoraUsageStat.today_video_count"`)} + } + if _, ok := _c.mutation.TodayErrorCount(); !ok { + return &ValidationError{Name: "today_error_count", err: errors.New(`ent: missing required field "SoraUsageStat.today_error_count"`)} + } + if _, ok := _c.mutation.ConsecutiveErrorCount(); !ok { + return &ValidationError{Name: "consecutive_error_count", err: errors.New(`ent: missing required field "SoraUsageStat.consecutive_error_count"`)} + } + return nil +} + +func (_c *SoraUsageStatCreate) sqlSave(ctx context.Context) (*SoraUsageStat, error) { + if err := _c.check(); err != nil { + 
return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *SoraUsageStatCreate) createSpec() (*SoraUsageStat, *sqlgraph.CreateSpec) { + var ( + _node = &SoraUsageStat{config: _c.config} + _spec = sqlgraph.NewCreateSpec(sorausagestat.Table, sqlgraph.NewFieldSpec(sorausagestat.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(sorausagestat.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(sorausagestat.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.AccountID(); ok { + _spec.SetField(sorausagestat.FieldAccountID, field.TypeInt64, value) + _node.AccountID = value + } + if value, ok := _c.mutation.ImageCount(); ok { + _spec.SetField(sorausagestat.FieldImageCount, field.TypeInt, value) + _node.ImageCount = value + } + if value, ok := _c.mutation.VideoCount(); ok { + _spec.SetField(sorausagestat.FieldVideoCount, field.TypeInt, value) + _node.VideoCount = value + } + if value, ok := _c.mutation.ErrorCount(); ok { + _spec.SetField(sorausagestat.FieldErrorCount, field.TypeInt, value) + _node.ErrorCount = value + } + if value, ok := _c.mutation.LastErrorAt(); ok { + _spec.SetField(sorausagestat.FieldLastErrorAt, field.TypeTime, value) + _node.LastErrorAt = &value + } + if value, ok := _c.mutation.TodayImageCount(); ok { + _spec.SetField(sorausagestat.FieldTodayImageCount, field.TypeInt, value) + _node.TodayImageCount = value + } + if value, ok := _c.mutation.TodayVideoCount(); ok { + _spec.SetField(sorausagestat.FieldTodayVideoCount, 
field.TypeInt, value) + _node.TodayVideoCount = value + } + if value, ok := _c.mutation.TodayErrorCount(); ok { + _spec.SetField(sorausagestat.FieldTodayErrorCount, field.TypeInt, value) + _node.TodayErrorCount = value + } + if value, ok := _c.mutation.TodayDate(); ok { + _spec.SetField(sorausagestat.FieldTodayDate, field.TypeTime, value) + _node.TodayDate = &value + } + if value, ok := _c.mutation.ConsecutiveErrorCount(); ok { + _spec.SetField(sorausagestat.FieldConsecutiveErrorCount, field.TypeInt, value) + _node.ConsecutiveErrorCount = value + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.SoraUsageStat.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.SoraUsageStatUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *SoraUsageStatCreate) OnConflict(opts ...sql.ConflictOption) *SoraUsageStatUpsertOne { + _c.conflict = opts + return &SoraUsageStatUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.SoraUsageStat.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *SoraUsageStatCreate) OnConflictColumns(columns ...string) *SoraUsageStatUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &SoraUsageStatUpsertOne{ + create: _c, + } +} + +type ( + // SoraUsageStatUpsertOne is the builder for "upsert"-ing + // one SoraUsageStat node. + SoraUsageStatUpsertOne struct { + create *SoraUsageStatCreate + } + + // SoraUsageStatUpsert is the "OnConflict" setter. 
+ SoraUsageStatUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *SoraUsageStatUpsert) SetUpdatedAt(v time.Time) *SoraUsageStatUpsert { + u.Set(sorausagestat.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SoraUsageStatUpsert) UpdateUpdatedAt() *SoraUsageStatUpsert { + u.SetExcluded(sorausagestat.FieldUpdatedAt) + return u +} + +// SetAccountID sets the "account_id" field. +func (u *SoraUsageStatUpsert) SetAccountID(v int64) *SoraUsageStatUpsert { + u.Set(sorausagestat.FieldAccountID, v) + return u +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *SoraUsageStatUpsert) UpdateAccountID() *SoraUsageStatUpsert { + u.SetExcluded(sorausagestat.FieldAccountID) + return u +} + +// AddAccountID adds v to the "account_id" field. +func (u *SoraUsageStatUpsert) AddAccountID(v int64) *SoraUsageStatUpsert { + u.Add(sorausagestat.FieldAccountID, v) + return u +} + +// SetImageCount sets the "image_count" field. +func (u *SoraUsageStatUpsert) SetImageCount(v int) *SoraUsageStatUpsert { + u.Set(sorausagestat.FieldImageCount, v) + return u +} + +// UpdateImageCount sets the "image_count" field to the value that was provided on create. +func (u *SoraUsageStatUpsert) UpdateImageCount() *SoraUsageStatUpsert { + u.SetExcluded(sorausagestat.FieldImageCount) + return u +} + +// AddImageCount adds v to the "image_count" field. +func (u *SoraUsageStatUpsert) AddImageCount(v int) *SoraUsageStatUpsert { + u.Add(sorausagestat.FieldImageCount, v) + return u +} + +// SetVideoCount sets the "video_count" field. +func (u *SoraUsageStatUpsert) SetVideoCount(v int) *SoraUsageStatUpsert { + u.Set(sorausagestat.FieldVideoCount, v) + return u +} + +// UpdateVideoCount sets the "video_count" field to the value that was provided on create. 
+func (u *SoraUsageStatUpsert) UpdateVideoCount() *SoraUsageStatUpsert { + u.SetExcluded(sorausagestat.FieldVideoCount) + return u +} + +// AddVideoCount adds v to the "video_count" field. +func (u *SoraUsageStatUpsert) AddVideoCount(v int) *SoraUsageStatUpsert { + u.Add(sorausagestat.FieldVideoCount, v) + return u +} + +// SetErrorCount sets the "error_count" field. +func (u *SoraUsageStatUpsert) SetErrorCount(v int) *SoraUsageStatUpsert { + u.Set(sorausagestat.FieldErrorCount, v) + return u +} + +// UpdateErrorCount sets the "error_count" field to the value that was provided on create. +func (u *SoraUsageStatUpsert) UpdateErrorCount() *SoraUsageStatUpsert { + u.SetExcluded(sorausagestat.FieldErrorCount) + return u +} + +// AddErrorCount adds v to the "error_count" field. +func (u *SoraUsageStatUpsert) AddErrorCount(v int) *SoraUsageStatUpsert { + u.Add(sorausagestat.FieldErrorCount, v) + return u +} + +// SetLastErrorAt sets the "last_error_at" field. +func (u *SoraUsageStatUpsert) SetLastErrorAt(v time.Time) *SoraUsageStatUpsert { + u.Set(sorausagestat.FieldLastErrorAt, v) + return u +} + +// UpdateLastErrorAt sets the "last_error_at" field to the value that was provided on create. +func (u *SoraUsageStatUpsert) UpdateLastErrorAt() *SoraUsageStatUpsert { + u.SetExcluded(sorausagestat.FieldLastErrorAt) + return u +} + +// ClearLastErrorAt clears the value of the "last_error_at" field. +func (u *SoraUsageStatUpsert) ClearLastErrorAt() *SoraUsageStatUpsert { + u.SetNull(sorausagestat.FieldLastErrorAt) + return u +} + +// SetTodayImageCount sets the "today_image_count" field. +func (u *SoraUsageStatUpsert) SetTodayImageCount(v int) *SoraUsageStatUpsert { + u.Set(sorausagestat.FieldTodayImageCount, v) + return u +} + +// UpdateTodayImageCount sets the "today_image_count" field to the value that was provided on create. 
+func (u *SoraUsageStatUpsert) UpdateTodayImageCount() *SoraUsageStatUpsert { + u.SetExcluded(sorausagestat.FieldTodayImageCount) + return u +} + +// AddTodayImageCount adds v to the "today_image_count" field. +func (u *SoraUsageStatUpsert) AddTodayImageCount(v int) *SoraUsageStatUpsert { + u.Add(sorausagestat.FieldTodayImageCount, v) + return u +} + +// SetTodayVideoCount sets the "today_video_count" field. +func (u *SoraUsageStatUpsert) SetTodayVideoCount(v int) *SoraUsageStatUpsert { + u.Set(sorausagestat.FieldTodayVideoCount, v) + return u +} + +// UpdateTodayVideoCount sets the "today_video_count" field to the value that was provided on create. +func (u *SoraUsageStatUpsert) UpdateTodayVideoCount() *SoraUsageStatUpsert { + u.SetExcluded(sorausagestat.FieldTodayVideoCount) + return u +} + +// AddTodayVideoCount adds v to the "today_video_count" field. +func (u *SoraUsageStatUpsert) AddTodayVideoCount(v int) *SoraUsageStatUpsert { + u.Add(sorausagestat.FieldTodayVideoCount, v) + return u +} + +// SetTodayErrorCount sets the "today_error_count" field. +func (u *SoraUsageStatUpsert) SetTodayErrorCount(v int) *SoraUsageStatUpsert { + u.Set(sorausagestat.FieldTodayErrorCount, v) + return u +} + +// UpdateTodayErrorCount sets the "today_error_count" field to the value that was provided on create. +func (u *SoraUsageStatUpsert) UpdateTodayErrorCount() *SoraUsageStatUpsert { + u.SetExcluded(sorausagestat.FieldTodayErrorCount) + return u +} + +// AddTodayErrorCount adds v to the "today_error_count" field. +func (u *SoraUsageStatUpsert) AddTodayErrorCount(v int) *SoraUsageStatUpsert { + u.Add(sorausagestat.FieldTodayErrorCount, v) + return u +} + +// SetTodayDate sets the "today_date" field. +func (u *SoraUsageStatUpsert) SetTodayDate(v time.Time) *SoraUsageStatUpsert { + u.Set(sorausagestat.FieldTodayDate, v) + return u +} + +// UpdateTodayDate sets the "today_date" field to the value that was provided on create. 
+func (u *SoraUsageStatUpsert) UpdateTodayDate() *SoraUsageStatUpsert { + u.SetExcluded(sorausagestat.FieldTodayDate) + return u +} + +// ClearTodayDate clears the value of the "today_date" field. +func (u *SoraUsageStatUpsert) ClearTodayDate() *SoraUsageStatUpsert { + u.SetNull(sorausagestat.FieldTodayDate) + return u +} + +// SetConsecutiveErrorCount sets the "consecutive_error_count" field. +func (u *SoraUsageStatUpsert) SetConsecutiveErrorCount(v int) *SoraUsageStatUpsert { + u.Set(sorausagestat.FieldConsecutiveErrorCount, v) + return u +} + +// UpdateConsecutiveErrorCount sets the "consecutive_error_count" field to the value that was provided on create. +func (u *SoraUsageStatUpsert) UpdateConsecutiveErrorCount() *SoraUsageStatUpsert { + u.SetExcluded(sorausagestat.FieldConsecutiveErrorCount) + return u +} + +// AddConsecutiveErrorCount adds v to the "consecutive_error_count" field. +func (u *SoraUsageStatUpsert) AddConsecutiveErrorCount(v int) *SoraUsageStatUpsert { + u.Add(sorausagestat.FieldConsecutiveErrorCount, v) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.SoraUsageStat.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *SoraUsageStatUpsertOne) UpdateNewValues() *SoraUsageStatUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(sorausagestat.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.SoraUsageStat.Create(). +// OnConflict(sql.ResolveWithIgnore()). 
+// Exec(ctx) +func (u *SoraUsageStatUpsertOne) Ignore() *SoraUsageStatUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *SoraUsageStatUpsertOne) DoNothing() *SoraUsageStatUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the SoraUsageStatCreate.OnConflict +// documentation for more info. +func (u *SoraUsageStatUpsertOne) Update(set func(*SoraUsageStatUpsert)) *SoraUsageStatUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&SoraUsageStatUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *SoraUsageStatUpsertOne) SetUpdatedAt(v time.Time) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SoraUsageStatUpsertOne) UpdateUpdatedAt() *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetAccountID sets the "account_id" field. +func (u *SoraUsageStatUpsertOne) SetAccountID(v int64) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetAccountID(v) + }) +} + +// AddAccountID adds v to the "account_id" field. +func (u *SoraUsageStatUpsertOne) AddAccountID(v int64) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.AddAccountID(v) + }) +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *SoraUsageStatUpsertOne) UpdateAccountID() *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateAccountID() + }) +} + +// SetImageCount sets the "image_count" field. 
+func (u *SoraUsageStatUpsertOne) SetImageCount(v int) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetImageCount(v) + }) +} + +// AddImageCount adds v to the "image_count" field. +func (u *SoraUsageStatUpsertOne) AddImageCount(v int) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.AddImageCount(v) + }) +} + +// UpdateImageCount sets the "image_count" field to the value that was provided on create. +func (u *SoraUsageStatUpsertOne) UpdateImageCount() *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateImageCount() + }) +} + +// SetVideoCount sets the "video_count" field. +func (u *SoraUsageStatUpsertOne) SetVideoCount(v int) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetVideoCount(v) + }) +} + +// AddVideoCount adds v to the "video_count" field. +func (u *SoraUsageStatUpsertOne) AddVideoCount(v int) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.AddVideoCount(v) + }) +} + +// UpdateVideoCount sets the "video_count" field to the value that was provided on create. +func (u *SoraUsageStatUpsertOne) UpdateVideoCount() *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateVideoCount() + }) +} + +// SetErrorCount sets the "error_count" field. +func (u *SoraUsageStatUpsertOne) SetErrorCount(v int) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetErrorCount(v) + }) +} + +// AddErrorCount adds v to the "error_count" field. +func (u *SoraUsageStatUpsertOne) AddErrorCount(v int) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.AddErrorCount(v) + }) +} + +// UpdateErrorCount sets the "error_count" field to the value that was provided on create. 
+func (u *SoraUsageStatUpsertOne) UpdateErrorCount() *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateErrorCount() + }) +} + +// SetLastErrorAt sets the "last_error_at" field. +func (u *SoraUsageStatUpsertOne) SetLastErrorAt(v time.Time) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetLastErrorAt(v) + }) +} + +// UpdateLastErrorAt sets the "last_error_at" field to the value that was provided on create. +func (u *SoraUsageStatUpsertOne) UpdateLastErrorAt() *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateLastErrorAt() + }) +} + +// ClearLastErrorAt clears the value of the "last_error_at" field. +func (u *SoraUsageStatUpsertOne) ClearLastErrorAt() *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.ClearLastErrorAt() + }) +} + +// SetTodayImageCount sets the "today_image_count" field. +func (u *SoraUsageStatUpsertOne) SetTodayImageCount(v int) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetTodayImageCount(v) + }) +} + +// AddTodayImageCount adds v to the "today_image_count" field. +func (u *SoraUsageStatUpsertOne) AddTodayImageCount(v int) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.AddTodayImageCount(v) + }) +} + +// UpdateTodayImageCount sets the "today_image_count" field to the value that was provided on create. +func (u *SoraUsageStatUpsertOne) UpdateTodayImageCount() *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateTodayImageCount() + }) +} + +// SetTodayVideoCount sets the "today_video_count" field. +func (u *SoraUsageStatUpsertOne) SetTodayVideoCount(v int) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetTodayVideoCount(v) + }) +} + +// AddTodayVideoCount adds v to the "today_video_count" field. 
+func (u *SoraUsageStatUpsertOne) AddTodayVideoCount(v int) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.AddTodayVideoCount(v) + }) +} + +// UpdateTodayVideoCount sets the "today_video_count" field to the value that was provided on create. +func (u *SoraUsageStatUpsertOne) UpdateTodayVideoCount() *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateTodayVideoCount() + }) +} + +// SetTodayErrorCount sets the "today_error_count" field. +func (u *SoraUsageStatUpsertOne) SetTodayErrorCount(v int) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetTodayErrorCount(v) + }) +} + +// AddTodayErrorCount adds v to the "today_error_count" field. +func (u *SoraUsageStatUpsertOne) AddTodayErrorCount(v int) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.AddTodayErrorCount(v) + }) +} + +// UpdateTodayErrorCount sets the "today_error_count" field to the value that was provided on create. +func (u *SoraUsageStatUpsertOne) UpdateTodayErrorCount() *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateTodayErrorCount() + }) +} + +// SetTodayDate sets the "today_date" field. +func (u *SoraUsageStatUpsertOne) SetTodayDate(v time.Time) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetTodayDate(v) + }) +} + +// UpdateTodayDate sets the "today_date" field to the value that was provided on create. +func (u *SoraUsageStatUpsertOne) UpdateTodayDate() *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateTodayDate() + }) +} + +// ClearTodayDate clears the value of the "today_date" field. +func (u *SoraUsageStatUpsertOne) ClearTodayDate() *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.ClearTodayDate() + }) +} + +// SetConsecutiveErrorCount sets the "consecutive_error_count" field. 
+func (u *SoraUsageStatUpsertOne) SetConsecutiveErrorCount(v int) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetConsecutiveErrorCount(v) + }) +} + +// AddConsecutiveErrorCount adds v to the "consecutive_error_count" field. +func (u *SoraUsageStatUpsertOne) AddConsecutiveErrorCount(v int) *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.AddConsecutiveErrorCount(v) + }) +} + +// UpdateConsecutiveErrorCount sets the "consecutive_error_count" field to the value that was provided on create. +func (u *SoraUsageStatUpsertOne) UpdateConsecutiveErrorCount() *SoraUsageStatUpsertOne { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateConsecutiveErrorCount() + }) +} + +// Exec executes the query. +func (u *SoraUsageStatUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for SoraUsageStatCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *SoraUsageStatUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *SoraUsageStatUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *SoraUsageStatUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// SoraUsageStatCreateBulk is the builder for creating many SoraUsageStat entities in bulk. +type SoraUsageStatCreateBulk struct { + config + err error + builders []*SoraUsageStatCreate + conflict []sql.ConflictOption +} + +// Save creates the SoraUsageStat entities in the database. 
+func (_c *SoraUsageStatCreateBulk) Save(ctx context.Context) ([]*SoraUsageStat, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*SoraUsageStat, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*SoraUsageStatMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *SoraUsageStatCreateBulk) SaveX(ctx context.Context) []*SoraUsageStat { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. 
+func (_c *SoraUsageStatCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *SoraUsageStatCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.SoraUsageStat.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.SoraUsageStatUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *SoraUsageStatCreateBulk) OnConflict(opts ...sql.ConflictOption) *SoraUsageStatUpsertBulk { + _c.conflict = opts + return &SoraUsageStatUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.SoraUsageStat.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *SoraUsageStatCreateBulk) OnConflictColumns(columns ...string) *SoraUsageStatUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &SoraUsageStatUpsertBulk{ + create: _c, + } +} + +// SoraUsageStatUpsertBulk is the builder for "upsert"-ing +// a bulk of SoraUsageStat nodes. +type SoraUsageStatUpsertBulk struct { + create *SoraUsageStatCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.SoraUsageStat.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). 
+// Exec(ctx) +func (u *SoraUsageStatUpsertBulk) UpdateNewValues() *SoraUsageStatUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(sorausagestat.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.SoraUsageStat.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *SoraUsageStatUpsertBulk) Ignore() *SoraUsageStatUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *SoraUsageStatUpsertBulk) DoNothing() *SoraUsageStatUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the SoraUsageStatCreateBulk.OnConflict +// documentation for more info. +func (u *SoraUsageStatUpsertBulk) Update(set func(*SoraUsageStatUpsert)) *SoraUsageStatUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&SoraUsageStatUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *SoraUsageStatUpsertBulk) SetUpdatedAt(v time.Time) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SoraUsageStatUpsertBulk) UpdateUpdatedAt() *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetAccountID sets the "account_id" field. 
+func (u *SoraUsageStatUpsertBulk) SetAccountID(v int64) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetAccountID(v) + }) +} + +// AddAccountID adds v to the "account_id" field. +func (u *SoraUsageStatUpsertBulk) AddAccountID(v int64) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.AddAccountID(v) + }) +} + +// UpdateAccountID sets the "account_id" field to the value that was provided on create. +func (u *SoraUsageStatUpsertBulk) UpdateAccountID() *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateAccountID() + }) +} + +// SetImageCount sets the "image_count" field. +func (u *SoraUsageStatUpsertBulk) SetImageCount(v int) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetImageCount(v) + }) +} + +// AddImageCount adds v to the "image_count" field. +func (u *SoraUsageStatUpsertBulk) AddImageCount(v int) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.AddImageCount(v) + }) +} + +// UpdateImageCount sets the "image_count" field to the value that was provided on create. +func (u *SoraUsageStatUpsertBulk) UpdateImageCount() *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateImageCount() + }) +} + +// SetVideoCount sets the "video_count" field. +func (u *SoraUsageStatUpsertBulk) SetVideoCount(v int) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetVideoCount(v) + }) +} + +// AddVideoCount adds v to the "video_count" field. +func (u *SoraUsageStatUpsertBulk) AddVideoCount(v int) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.AddVideoCount(v) + }) +} + +// UpdateVideoCount sets the "video_count" field to the value that was provided on create. 
+func (u *SoraUsageStatUpsertBulk) UpdateVideoCount() *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateVideoCount() + }) +} + +// SetErrorCount sets the "error_count" field. +func (u *SoraUsageStatUpsertBulk) SetErrorCount(v int) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetErrorCount(v) + }) +} + +// AddErrorCount adds v to the "error_count" field. +func (u *SoraUsageStatUpsertBulk) AddErrorCount(v int) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.AddErrorCount(v) + }) +} + +// UpdateErrorCount sets the "error_count" field to the value that was provided on create. +func (u *SoraUsageStatUpsertBulk) UpdateErrorCount() *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateErrorCount() + }) +} + +// SetLastErrorAt sets the "last_error_at" field. +func (u *SoraUsageStatUpsertBulk) SetLastErrorAt(v time.Time) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetLastErrorAt(v) + }) +} + +// UpdateLastErrorAt sets the "last_error_at" field to the value that was provided on create. +func (u *SoraUsageStatUpsertBulk) UpdateLastErrorAt() *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateLastErrorAt() + }) +} + +// ClearLastErrorAt clears the value of the "last_error_at" field. +func (u *SoraUsageStatUpsertBulk) ClearLastErrorAt() *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.ClearLastErrorAt() + }) +} + +// SetTodayImageCount sets the "today_image_count" field. +func (u *SoraUsageStatUpsertBulk) SetTodayImageCount(v int) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetTodayImageCount(v) + }) +} + +// AddTodayImageCount adds v to the "today_image_count" field. 
+func (u *SoraUsageStatUpsertBulk) AddTodayImageCount(v int) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.AddTodayImageCount(v) + }) +} + +// UpdateTodayImageCount sets the "today_image_count" field to the value that was provided on create. +func (u *SoraUsageStatUpsertBulk) UpdateTodayImageCount() *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateTodayImageCount() + }) +} + +// SetTodayVideoCount sets the "today_video_count" field. +func (u *SoraUsageStatUpsertBulk) SetTodayVideoCount(v int) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetTodayVideoCount(v) + }) +} + +// AddTodayVideoCount adds v to the "today_video_count" field. +func (u *SoraUsageStatUpsertBulk) AddTodayVideoCount(v int) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.AddTodayVideoCount(v) + }) +} + +// UpdateTodayVideoCount sets the "today_video_count" field to the value that was provided on create. +func (u *SoraUsageStatUpsertBulk) UpdateTodayVideoCount() *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateTodayVideoCount() + }) +} + +// SetTodayErrorCount sets the "today_error_count" field. +func (u *SoraUsageStatUpsertBulk) SetTodayErrorCount(v int) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetTodayErrorCount(v) + }) +} + +// AddTodayErrorCount adds v to the "today_error_count" field. +func (u *SoraUsageStatUpsertBulk) AddTodayErrorCount(v int) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.AddTodayErrorCount(v) + }) +} + +// UpdateTodayErrorCount sets the "today_error_count" field to the value that was provided on create. +func (u *SoraUsageStatUpsertBulk) UpdateTodayErrorCount() *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateTodayErrorCount() + }) +} + +// SetTodayDate sets the "today_date" field. 
+func (u *SoraUsageStatUpsertBulk) SetTodayDate(v time.Time) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetTodayDate(v) + }) +} + +// UpdateTodayDate sets the "today_date" field to the value that was provided on create. +func (u *SoraUsageStatUpsertBulk) UpdateTodayDate() *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateTodayDate() + }) +} + +// ClearTodayDate clears the value of the "today_date" field. +func (u *SoraUsageStatUpsertBulk) ClearTodayDate() *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.ClearTodayDate() + }) +} + +// SetConsecutiveErrorCount sets the "consecutive_error_count" field. +func (u *SoraUsageStatUpsertBulk) SetConsecutiveErrorCount(v int) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.SetConsecutiveErrorCount(v) + }) +} + +// AddConsecutiveErrorCount adds v to the "consecutive_error_count" field. +func (u *SoraUsageStatUpsertBulk) AddConsecutiveErrorCount(v int) *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.AddConsecutiveErrorCount(v) + }) +} + +// UpdateConsecutiveErrorCount sets the "consecutive_error_count" field to the value that was provided on create. +func (u *SoraUsageStatUpsertBulk) UpdateConsecutiveErrorCount() *SoraUsageStatUpsertBulk { + return u.Update(func(s *SoraUsageStatUpsert) { + s.UpdateConsecutiveErrorCount() + }) +} + +// Exec executes the query. +func (u *SoraUsageStatUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. 
Set it on the SoraUsageStatCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for SoraUsageStatCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *SoraUsageStatUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/sorausagestat_delete.go b/backend/ent/sorausagestat_delete.go new file mode 100644 index 00000000..df4406a8 --- /dev/null +++ b/backend/ent/sorausagestat_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/sorausagestat" +) + +// SoraUsageStatDelete is the builder for deleting a SoraUsageStat entity. +type SoraUsageStatDelete struct { + config + hooks []Hook + mutation *SoraUsageStatMutation +} + +// Where appends a list predicates to the SoraUsageStatDelete builder. +func (_d *SoraUsageStatDelete) Where(ps ...predicate.SoraUsageStat) *SoraUsageStatDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *SoraUsageStatDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_d *SoraUsageStatDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *SoraUsageStatDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(sorausagestat.Table, sqlgraph.NewFieldSpec(sorausagestat.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// SoraUsageStatDeleteOne is the builder for deleting a single SoraUsageStat entity. +type SoraUsageStatDeleteOne struct { + _d *SoraUsageStatDelete +} + +// Where appends a list predicates to the SoraUsageStatDelete builder. +func (_d *SoraUsageStatDeleteOne) Where(ps ...predicate.SoraUsageStat) *SoraUsageStatDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *SoraUsageStatDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{sorausagestat.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *SoraUsageStatDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/sorausagestat_query.go b/backend/ent/sorausagestat_query.go new file mode 100644 index 00000000..da87d28c --- /dev/null +++ b/backend/ent/sorausagestat_query.go @@ -0,0 +1,564 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/sorausagestat" +) + +// SoraUsageStatQuery is the builder for querying SoraUsageStat entities. +type SoraUsageStatQuery struct { + config + ctx *QueryContext + order []sorausagestat.OrderOption + inters []Interceptor + predicates []predicate.SoraUsageStat + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the SoraUsageStatQuery builder. +func (_q *SoraUsageStatQuery) Where(ps ...predicate.SoraUsageStat) *SoraUsageStatQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *SoraUsageStatQuery) Limit(limit int) *SoraUsageStatQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *SoraUsageStatQuery) Offset(offset int) *SoraUsageStatQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *SoraUsageStatQuery) Unique(unique bool) *SoraUsageStatQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *SoraUsageStatQuery) Order(o ...sorausagestat.OrderOption) *SoraUsageStatQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first SoraUsageStat entity from the query. +// Returns a *NotFoundError when no SoraUsageStat was found. 
+func (_q *SoraUsageStatQuery) First(ctx context.Context) (*SoraUsageStat, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{sorausagestat.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *SoraUsageStatQuery) FirstX(ctx context.Context) *SoraUsageStat { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first SoraUsageStat ID from the query. +// Returns a *NotFoundError when no SoraUsageStat ID was found. +func (_q *SoraUsageStatQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{sorausagestat.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *SoraUsageStatQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single SoraUsageStat entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one SoraUsageStat entity is found. +// Returns a *NotFoundError when no SoraUsageStat entities are found. +func (_q *SoraUsageStatQuery) Only(ctx context.Context) (*SoraUsageStat, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{sorausagestat.Label} + default: + return nil, &NotSingularError{sorausagestat.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. 
+func (_q *SoraUsageStatQuery) OnlyX(ctx context.Context) *SoraUsageStat { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only SoraUsageStat ID in the query. +// Returns a *NotSingularError when more than one SoraUsageStat ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *SoraUsageStatQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{sorausagestat.Label} + default: + err = &NotSingularError{sorausagestat.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *SoraUsageStatQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of SoraUsageStats. +func (_q *SoraUsageStatQuery) All(ctx context.Context) ([]*SoraUsageStat, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*SoraUsageStat, *SoraUsageStatQuery]() + return withInterceptors[[]*SoraUsageStat](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *SoraUsageStatQuery) AllX(ctx context.Context) []*SoraUsageStat { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of SoraUsageStat IDs. 
+func (_q *SoraUsageStatQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(sorausagestat.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *SoraUsageStatQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *SoraUsageStatQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*SoraUsageStatQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *SoraUsageStatQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *SoraUsageStatQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *SoraUsageStatQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the SoraUsageStatQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. 
+func (_q *SoraUsageStatQuery) Clone() *SoraUsageStatQuery { + if _q == nil { + return nil + } + return &SoraUsageStatQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]sorausagestat.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.SoraUsageStat{}, _q.predicates...), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.SoraUsageStat.Query(). +// GroupBy(sorausagestat.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *SoraUsageStatQuery) GroupBy(field string, fields ...string) *SoraUsageStatGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &SoraUsageStatGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = sorausagestat.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.SoraUsageStat.Query(). +// Select(sorausagestat.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *SoraUsageStatQuery) Select(fields ...string) *SoraUsageStatSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &SoraUsageStatSelect{SoraUsageStatQuery: _q} + sbuild.label = sorausagestat.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a SoraUsageStatSelect configured with the given aggregations. 
+func (_q *SoraUsageStatQuery) Aggregate(fns ...AggregateFunc) *SoraUsageStatSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *SoraUsageStatQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !sorausagestat.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *SoraUsageStatQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*SoraUsageStat, error) { + var ( + nodes = []*SoraUsageStat{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*SoraUsageStat).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &SoraUsageStat{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (_q *SoraUsageStatQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *SoraUsageStatQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(sorausagestat.Table, 
sorausagestat.Columns, sqlgraph.NewFieldSpec(sorausagestat.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, sorausagestat.FieldID) + for i := range fields { + if fields[i] != sorausagestat.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *SoraUsageStatQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(sorausagestat.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = sorausagestat.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. 
+ selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *SoraUsageStatQuery) ForUpdate(opts ...sql.LockOption) *SoraUsageStatQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *SoraUsageStatQuery) ForShare(opts ...sql.LockOption) *SoraUsageStatQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// SoraUsageStatGroupBy is the group-by builder for SoraUsageStat entities. +type SoraUsageStatGroupBy struct { + selector + build *SoraUsageStatQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *SoraUsageStatGroupBy) Aggregate(fns ...AggregateFunc) *SoraUsageStatGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_g *SoraUsageStatGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*SoraUsageStatQuery, *SoraUsageStatGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *SoraUsageStatGroupBy) sqlScan(ctx context.Context, root *SoraUsageStatQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// SoraUsageStatSelect is the builder for selecting fields of SoraUsageStat entities. +type SoraUsageStatSelect struct { + *SoraUsageStatQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *SoraUsageStatSelect) Aggregate(fns ...AggregateFunc) *SoraUsageStatSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_s *SoraUsageStatSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*SoraUsageStatQuery, *SoraUsageStatSelect](ctx, _s.SoraUsageStatQuery, _s, _s.inters, v) +} + +func (_s *SoraUsageStatSelect) sqlScan(ctx context.Context, root *SoraUsageStatQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/sorausagestat_update.go b/backend/ent/sorausagestat_update.go new file mode 100644 index 00000000..3210edac --- /dev/null +++ b/backend/ent/sorausagestat_update.go @@ -0,0 +1,748 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/sorausagestat" +) + +// SoraUsageStatUpdate is the builder for updating SoraUsageStat entities. +type SoraUsageStatUpdate struct { + config + hooks []Hook + mutation *SoraUsageStatMutation +} + +// Where appends a list predicates to the SoraUsageStatUpdate builder. +func (_u *SoraUsageStatUpdate) Where(ps ...predicate.SoraUsageStat) *SoraUsageStatUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. 
+func (_u *SoraUsageStatUpdate) SetUpdatedAt(v time.Time) *SoraUsageStatUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetAccountID sets the "account_id" field. +func (_u *SoraUsageStatUpdate) SetAccountID(v int64) *SoraUsageStatUpdate { + _u.mutation.ResetAccountID() + _u.mutation.SetAccountID(v) + return _u +} + +// SetNillableAccountID sets the "account_id" field if the given value is not nil. +func (_u *SoraUsageStatUpdate) SetNillableAccountID(v *int64) *SoraUsageStatUpdate { + if v != nil { + _u.SetAccountID(*v) + } + return _u +} + +// AddAccountID adds value to the "account_id" field. +func (_u *SoraUsageStatUpdate) AddAccountID(v int64) *SoraUsageStatUpdate { + _u.mutation.AddAccountID(v) + return _u +} + +// SetImageCount sets the "image_count" field. +func (_u *SoraUsageStatUpdate) SetImageCount(v int) *SoraUsageStatUpdate { + _u.mutation.ResetImageCount() + _u.mutation.SetImageCount(v) + return _u +} + +// SetNillableImageCount sets the "image_count" field if the given value is not nil. +func (_u *SoraUsageStatUpdate) SetNillableImageCount(v *int) *SoraUsageStatUpdate { + if v != nil { + _u.SetImageCount(*v) + } + return _u +} + +// AddImageCount adds value to the "image_count" field. +func (_u *SoraUsageStatUpdate) AddImageCount(v int) *SoraUsageStatUpdate { + _u.mutation.AddImageCount(v) + return _u +} + +// SetVideoCount sets the "video_count" field. +func (_u *SoraUsageStatUpdate) SetVideoCount(v int) *SoraUsageStatUpdate { + _u.mutation.ResetVideoCount() + _u.mutation.SetVideoCount(v) + return _u +} + +// SetNillableVideoCount sets the "video_count" field if the given value is not nil. +func (_u *SoraUsageStatUpdate) SetNillableVideoCount(v *int) *SoraUsageStatUpdate { + if v != nil { + _u.SetVideoCount(*v) + } + return _u +} + +// AddVideoCount adds value to the "video_count" field. 
+func (_u *SoraUsageStatUpdate) AddVideoCount(v int) *SoraUsageStatUpdate { + _u.mutation.AddVideoCount(v) + return _u +} + +// SetErrorCount sets the "error_count" field. +func (_u *SoraUsageStatUpdate) SetErrorCount(v int) *SoraUsageStatUpdate { + _u.mutation.ResetErrorCount() + _u.mutation.SetErrorCount(v) + return _u +} + +// SetNillableErrorCount sets the "error_count" field if the given value is not nil. +func (_u *SoraUsageStatUpdate) SetNillableErrorCount(v *int) *SoraUsageStatUpdate { + if v != nil { + _u.SetErrorCount(*v) + } + return _u +} + +// AddErrorCount adds value to the "error_count" field. +func (_u *SoraUsageStatUpdate) AddErrorCount(v int) *SoraUsageStatUpdate { + _u.mutation.AddErrorCount(v) + return _u +} + +// SetLastErrorAt sets the "last_error_at" field. +func (_u *SoraUsageStatUpdate) SetLastErrorAt(v time.Time) *SoraUsageStatUpdate { + _u.mutation.SetLastErrorAt(v) + return _u +} + +// SetNillableLastErrorAt sets the "last_error_at" field if the given value is not nil. +func (_u *SoraUsageStatUpdate) SetNillableLastErrorAt(v *time.Time) *SoraUsageStatUpdate { + if v != nil { + _u.SetLastErrorAt(*v) + } + return _u +} + +// ClearLastErrorAt clears the value of the "last_error_at" field. +func (_u *SoraUsageStatUpdate) ClearLastErrorAt() *SoraUsageStatUpdate { + _u.mutation.ClearLastErrorAt() + return _u +} + +// SetTodayImageCount sets the "today_image_count" field. +func (_u *SoraUsageStatUpdate) SetTodayImageCount(v int) *SoraUsageStatUpdate { + _u.mutation.ResetTodayImageCount() + _u.mutation.SetTodayImageCount(v) + return _u +} + +// SetNillableTodayImageCount sets the "today_image_count" field if the given value is not nil. +func (_u *SoraUsageStatUpdate) SetNillableTodayImageCount(v *int) *SoraUsageStatUpdate { + if v != nil { + _u.SetTodayImageCount(*v) + } + return _u +} + +// AddTodayImageCount adds value to the "today_image_count" field. 
+func (_u *SoraUsageStatUpdate) AddTodayImageCount(v int) *SoraUsageStatUpdate { + _u.mutation.AddTodayImageCount(v) + return _u +} + +// SetTodayVideoCount sets the "today_video_count" field. +func (_u *SoraUsageStatUpdate) SetTodayVideoCount(v int) *SoraUsageStatUpdate { + _u.mutation.ResetTodayVideoCount() + _u.mutation.SetTodayVideoCount(v) + return _u +} + +// SetNillableTodayVideoCount sets the "today_video_count" field if the given value is not nil. +func (_u *SoraUsageStatUpdate) SetNillableTodayVideoCount(v *int) *SoraUsageStatUpdate { + if v != nil { + _u.SetTodayVideoCount(*v) + } + return _u +} + +// AddTodayVideoCount adds value to the "today_video_count" field. +func (_u *SoraUsageStatUpdate) AddTodayVideoCount(v int) *SoraUsageStatUpdate { + _u.mutation.AddTodayVideoCount(v) + return _u +} + +// SetTodayErrorCount sets the "today_error_count" field. +func (_u *SoraUsageStatUpdate) SetTodayErrorCount(v int) *SoraUsageStatUpdate { + _u.mutation.ResetTodayErrorCount() + _u.mutation.SetTodayErrorCount(v) + return _u +} + +// SetNillableTodayErrorCount sets the "today_error_count" field if the given value is not nil. +func (_u *SoraUsageStatUpdate) SetNillableTodayErrorCount(v *int) *SoraUsageStatUpdate { + if v != nil { + _u.SetTodayErrorCount(*v) + } + return _u +} + +// AddTodayErrorCount adds value to the "today_error_count" field. +func (_u *SoraUsageStatUpdate) AddTodayErrorCount(v int) *SoraUsageStatUpdate { + _u.mutation.AddTodayErrorCount(v) + return _u +} + +// SetTodayDate sets the "today_date" field. +func (_u *SoraUsageStatUpdate) SetTodayDate(v time.Time) *SoraUsageStatUpdate { + _u.mutation.SetTodayDate(v) + return _u +} + +// SetNillableTodayDate sets the "today_date" field if the given value is not nil. +func (_u *SoraUsageStatUpdate) SetNillableTodayDate(v *time.Time) *SoraUsageStatUpdate { + if v != nil { + _u.SetTodayDate(*v) + } + return _u +} + +// ClearTodayDate clears the value of the "today_date" field. 
+func (_u *SoraUsageStatUpdate) ClearTodayDate() *SoraUsageStatUpdate { + _u.mutation.ClearTodayDate() + return _u +} + +// SetConsecutiveErrorCount sets the "consecutive_error_count" field. +func (_u *SoraUsageStatUpdate) SetConsecutiveErrorCount(v int) *SoraUsageStatUpdate { + _u.mutation.ResetConsecutiveErrorCount() + _u.mutation.SetConsecutiveErrorCount(v) + return _u +} + +// SetNillableConsecutiveErrorCount sets the "consecutive_error_count" field if the given value is not nil. +func (_u *SoraUsageStatUpdate) SetNillableConsecutiveErrorCount(v *int) *SoraUsageStatUpdate { + if v != nil { + _u.SetConsecutiveErrorCount(*v) + } + return _u +} + +// AddConsecutiveErrorCount adds value to the "consecutive_error_count" field. +func (_u *SoraUsageStatUpdate) AddConsecutiveErrorCount(v int) *SoraUsageStatUpdate { + _u.mutation.AddConsecutiveErrorCount(v) + return _u +} + +// Mutation returns the SoraUsageStatMutation object of the builder. +func (_u *SoraUsageStatUpdate) Mutation() *SoraUsageStatMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *SoraUsageStatUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *SoraUsageStatUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *SoraUsageStatUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *SoraUsageStatUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (_u *SoraUsageStatUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := sorausagestat.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +func (_u *SoraUsageStatUpdate) sqlSave(ctx context.Context) (_node int, err error) { + _spec := sqlgraph.NewUpdateSpec(sorausagestat.Table, sorausagestat.Columns, sqlgraph.NewFieldSpec(sorausagestat.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(sorausagestat.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.AccountID(); ok { + _spec.SetField(sorausagestat.FieldAccountID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedAccountID(); ok { + _spec.AddField(sorausagestat.FieldAccountID, field.TypeInt64, value) + } + if value, ok := _u.mutation.ImageCount(); ok { + _spec.SetField(sorausagestat.FieldImageCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedImageCount(); ok { + _spec.AddField(sorausagestat.FieldImageCount, field.TypeInt, value) + } + if value, ok := _u.mutation.VideoCount(); ok { + _spec.SetField(sorausagestat.FieldVideoCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedVideoCount(); ok { + _spec.AddField(sorausagestat.FieldVideoCount, field.TypeInt, value) + } + if value, ok := _u.mutation.ErrorCount(); ok { + _spec.SetField(sorausagestat.FieldErrorCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedErrorCount(); ok { + _spec.AddField(sorausagestat.FieldErrorCount, field.TypeInt, value) + } + if value, ok := _u.mutation.LastErrorAt(); ok { + _spec.SetField(sorausagestat.FieldLastErrorAt, field.TypeTime, value) + } + if _u.mutation.LastErrorAtCleared() { + _spec.ClearField(sorausagestat.FieldLastErrorAt, field.TypeTime) + } + if value, ok := _u.mutation.TodayImageCount(); ok { + 
_spec.SetField(sorausagestat.FieldTodayImageCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedTodayImageCount(); ok { + _spec.AddField(sorausagestat.FieldTodayImageCount, field.TypeInt, value) + } + if value, ok := _u.mutation.TodayVideoCount(); ok { + _spec.SetField(sorausagestat.FieldTodayVideoCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedTodayVideoCount(); ok { + _spec.AddField(sorausagestat.FieldTodayVideoCount, field.TypeInt, value) + } + if value, ok := _u.mutation.TodayErrorCount(); ok { + _spec.SetField(sorausagestat.FieldTodayErrorCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedTodayErrorCount(); ok { + _spec.AddField(sorausagestat.FieldTodayErrorCount, field.TypeInt, value) + } + if value, ok := _u.mutation.TodayDate(); ok { + _spec.SetField(sorausagestat.FieldTodayDate, field.TypeTime, value) + } + if _u.mutation.TodayDateCleared() { + _spec.ClearField(sorausagestat.FieldTodayDate, field.TypeTime) + } + if value, ok := _u.mutation.ConsecutiveErrorCount(); ok { + _spec.SetField(sorausagestat.FieldConsecutiveErrorCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedConsecutiveErrorCount(); ok { + _spec.AddField(sorausagestat.FieldConsecutiveErrorCount, field.TypeInt, value) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{sorausagestat.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// SoraUsageStatUpdateOne is the builder for updating a single SoraUsageStat entity. +type SoraUsageStatUpdateOne struct { + config + fields []string + hooks []Hook + mutation *SoraUsageStatMutation +} + +// SetUpdatedAt sets the "updated_at" field. 
+func (_u *SoraUsageStatUpdateOne) SetUpdatedAt(v time.Time) *SoraUsageStatUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetAccountID sets the "account_id" field. +func (_u *SoraUsageStatUpdateOne) SetAccountID(v int64) *SoraUsageStatUpdateOne { + _u.mutation.ResetAccountID() + _u.mutation.SetAccountID(v) + return _u +} + +// SetNillableAccountID sets the "account_id" field if the given value is not nil. +func (_u *SoraUsageStatUpdateOne) SetNillableAccountID(v *int64) *SoraUsageStatUpdateOne { + if v != nil { + _u.SetAccountID(*v) + } + return _u +} + +// AddAccountID adds value to the "account_id" field. +func (_u *SoraUsageStatUpdateOne) AddAccountID(v int64) *SoraUsageStatUpdateOne { + _u.mutation.AddAccountID(v) + return _u +} + +// SetImageCount sets the "image_count" field. +func (_u *SoraUsageStatUpdateOne) SetImageCount(v int) *SoraUsageStatUpdateOne { + _u.mutation.ResetImageCount() + _u.mutation.SetImageCount(v) + return _u +} + +// SetNillableImageCount sets the "image_count" field if the given value is not nil. +func (_u *SoraUsageStatUpdateOne) SetNillableImageCount(v *int) *SoraUsageStatUpdateOne { + if v != nil { + _u.SetImageCount(*v) + } + return _u +} + +// AddImageCount adds value to the "image_count" field. +func (_u *SoraUsageStatUpdateOne) AddImageCount(v int) *SoraUsageStatUpdateOne { + _u.mutation.AddImageCount(v) + return _u +} + +// SetVideoCount sets the "video_count" field. +func (_u *SoraUsageStatUpdateOne) SetVideoCount(v int) *SoraUsageStatUpdateOne { + _u.mutation.ResetVideoCount() + _u.mutation.SetVideoCount(v) + return _u +} + +// SetNillableVideoCount sets the "video_count" field if the given value is not nil. +func (_u *SoraUsageStatUpdateOne) SetNillableVideoCount(v *int) *SoraUsageStatUpdateOne { + if v != nil { + _u.SetVideoCount(*v) + } + return _u +} + +// AddVideoCount adds value to the "video_count" field. 
+func (_u *SoraUsageStatUpdateOne) AddVideoCount(v int) *SoraUsageStatUpdateOne { + _u.mutation.AddVideoCount(v) + return _u +} + +// SetErrorCount sets the "error_count" field. +func (_u *SoraUsageStatUpdateOne) SetErrorCount(v int) *SoraUsageStatUpdateOne { + _u.mutation.ResetErrorCount() + _u.mutation.SetErrorCount(v) + return _u +} + +// SetNillableErrorCount sets the "error_count" field if the given value is not nil. +func (_u *SoraUsageStatUpdateOne) SetNillableErrorCount(v *int) *SoraUsageStatUpdateOne { + if v != nil { + _u.SetErrorCount(*v) + } + return _u +} + +// AddErrorCount adds value to the "error_count" field. +func (_u *SoraUsageStatUpdateOne) AddErrorCount(v int) *SoraUsageStatUpdateOne { + _u.mutation.AddErrorCount(v) + return _u +} + +// SetLastErrorAt sets the "last_error_at" field. +func (_u *SoraUsageStatUpdateOne) SetLastErrorAt(v time.Time) *SoraUsageStatUpdateOne { + _u.mutation.SetLastErrorAt(v) + return _u +} + +// SetNillableLastErrorAt sets the "last_error_at" field if the given value is not nil. +func (_u *SoraUsageStatUpdateOne) SetNillableLastErrorAt(v *time.Time) *SoraUsageStatUpdateOne { + if v != nil { + _u.SetLastErrorAt(*v) + } + return _u +} + +// ClearLastErrorAt clears the value of the "last_error_at" field. +func (_u *SoraUsageStatUpdateOne) ClearLastErrorAt() *SoraUsageStatUpdateOne { + _u.mutation.ClearLastErrorAt() + return _u +} + +// SetTodayImageCount sets the "today_image_count" field. +func (_u *SoraUsageStatUpdateOne) SetTodayImageCount(v int) *SoraUsageStatUpdateOne { + _u.mutation.ResetTodayImageCount() + _u.mutation.SetTodayImageCount(v) + return _u +} + +// SetNillableTodayImageCount sets the "today_image_count" field if the given value is not nil. +func (_u *SoraUsageStatUpdateOne) SetNillableTodayImageCount(v *int) *SoraUsageStatUpdateOne { + if v != nil { + _u.SetTodayImageCount(*v) + } + return _u +} + +// AddTodayImageCount adds value to the "today_image_count" field. 
+func (_u *SoraUsageStatUpdateOne) AddTodayImageCount(v int) *SoraUsageStatUpdateOne { + _u.mutation.AddTodayImageCount(v) + return _u +} + +// SetTodayVideoCount sets the "today_video_count" field. +func (_u *SoraUsageStatUpdateOne) SetTodayVideoCount(v int) *SoraUsageStatUpdateOne { + _u.mutation.ResetTodayVideoCount() + _u.mutation.SetTodayVideoCount(v) + return _u +} + +// SetNillableTodayVideoCount sets the "today_video_count" field if the given value is not nil. +func (_u *SoraUsageStatUpdateOne) SetNillableTodayVideoCount(v *int) *SoraUsageStatUpdateOne { + if v != nil { + _u.SetTodayVideoCount(*v) + } + return _u +} + +// AddTodayVideoCount adds value to the "today_video_count" field. +func (_u *SoraUsageStatUpdateOne) AddTodayVideoCount(v int) *SoraUsageStatUpdateOne { + _u.mutation.AddTodayVideoCount(v) + return _u +} + +// SetTodayErrorCount sets the "today_error_count" field. +func (_u *SoraUsageStatUpdateOne) SetTodayErrorCount(v int) *SoraUsageStatUpdateOne { + _u.mutation.ResetTodayErrorCount() + _u.mutation.SetTodayErrorCount(v) + return _u +} + +// SetNillableTodayErrorCount sets the "today_error_count" field if the given value is not nil. +func (_u *SoraUsageStatUpdateOne) SetNillableTodayErrorCount(v *int) *SoraUsageStatUpdateOne { + if v != nil { + _u.SetTodayErrorCount(*v) + } + return _u +} + +// AddTodayErrorCount adds value to the "today_error_count" field. +func (_u *SoraUsageStatUpdateOne) AddTodayErrorCount(v int) *SoraUsageStatUpdateOne { + _u.mutation.AddTodayErrorCount(v) + return _u +} + +// SetTodayDate sets the "today_date" field. +func (_u *SoraUsageStatUpdateOne) SetTodayDate(v time.Time) *SoraUsageStatUpdateOne { + _u.mutation.SetTodayDate(v) + return _u +} + +// SetNillableTodayDate sets the "today_date" field if the given value is not nil. 
+func (_u *SoraUsageStatUpdateOne) SetNillableTodayDate(v *time.Time) *SoraUsageStatUpdateOne { + if v != nil { + _u.SetTodayDate(*v) + } + return _u +} + +// ClearTodayDate clears the value of the "today_date" field. +func (_u *SoraUsageStatUpdateOne) ClearTodayDate() *SoraUsageStatUpdateOne { + _u.mutation.ClearTodayDate() + return _u +} + +// SetConsecutiveErrorCount sets the "consecutive_error_count" field. +func (_u *SoraUsageStatUpdateOne) SetConsecutiveErrorCount(v int) *SoraUsageStatUpdateOne { + _u.mutation.ResetConsecutiveErrorCount() + _u.mutation.SetConsecutiveErrorCount(v) + return _u +} + +// SetNillableConsecutiveErrorCount sets the "consecutive_error_count" field if the given value is not nil. +func (_u *SoraUsageStatUpdateOne) SetNillableConsecutiveErrorCount(v *int) *SoraUsageStatUpdateOne { + if v != nil { + _u.SetConsecutiveErrorCount(*v) + } + return _u +} + +// AddConsecutiveErrorCount adds value to the "consecutive_error_count" field. +func (_u *SoraUsageStatUpdateOne) AddConsecutiveErrorCount(v int) *SoraUsageStatUpdateOne { + _u.mutation.AddConsecutiveErrorCount(v) + return _u +} + +// Mutation returns the SoraUsageStatMutation object of the builder. +func (_u *SoraUsageStatUpdateOne) Mutation() *SoraUsageStatMutation { + return _u.mutation +} + +// Where appends a list predicates to the SoraUsageStatUpdate builder. +func (_u *SoraUsageStatUpdateOne) Where(ps ...predicate.SoraUsageStat) *SoraUsageStatUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *SoraUsageStatUpdateOne) Select(field string, fields ...string) *SoraUsageStatUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated SoraUsageStat entity. 
+func (_u *SoraUsageStatUpdateOne) Save(ctx context.Context) (*SoraUsageStat, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *SoraUsageStatUpdateOne) SaveX(ctx context.Context) *SoraUsageStat { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *SoraUsageStatUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *SoraUsageStatUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *SoraUsageStatUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := sorausagestat.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +func (_u *SoraUsageStatUpdateOne) sqlSave(ctx context.Context) (_node *SoraUsageStat, err error) { + _spec := sqlgraph.NewUpdateSpec(sorausagestat.Table, sorausagestat.Columns, sqlgraph.NewFieldSpec(sorausagestat.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "SoraUsageStat.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, sorausagestat.FieldID) + for _, f := range fields { + if !sorausagestat.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != sorausagestat.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := 
_u.mutation.UpdatedAt(); ok { + _spec.SetField(sorausagestat.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.AccountID(); ok { + _spec.SetField(sorausagestat.FieldAccountID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedAccountID(); ok { + _spec.AddField(sorausagestat.FieldAccountID, field.TypeInt64, value) + } + if value, ok := _u.mutation.ImageCount(); ok { + _spec.SetField(sorausagestat.FieldImageCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedImageCount(); ok { + _spec.AddField(sorausagestat.FieldImageCount, field.TypeInt, value) + } + if value, ok := _u.mutation.VideoCount(); ok { + _spec.SetField(sorausagestat.FieldVideoCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedVideoCount(); ok { + _spec.AddField(sorausagestat.FieldVideoCount, field.TypeInt, value) + } + if value, ok := _u.mutation.ErrorCount(); ok { + _spec.SetField(sorausagestat.FieldErrorCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedErrorCount(); ok { + _spec.AddField(sorausagestat.FieldErrorCount, field.TypeInt, value) + } + if value, ok := _u.mutation.LastErrorAt(); ok { + _spec.SetField(sorausagestat.FieldLastErrorAt, field.TypeTime, value) + } + if _u.mutation.LastErrorAtCleared() { + _spec.ClearField(sorausagestat.FieldLastErrorAt, field.TypeTime) + } + if value, ok := _u.mutation.TodayImageCount(); ok { + _spec.SetField(sorausagestat.FieldTodayImageCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedTodayImageCount(); ok { + _spec.AddField(sorausagestat.FieldTodayImageCount, field.TypeInt, value) + } + if value, ok := _u.mutation.TodayVideoCount(); ok { + _spec.SetField(sorausagestat.FieldTodayVideoCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedTodayVideoCount(); ok { + _spec.AddField(sorausagestat.FieldTodayVideoCount, field.TypeInt, value) + } + if value, ok := _u.mutation.TodayErrorCount(); ok { + _spec.SetField(sorausagestat.FieldTodayErrorCount, 
field.TypeInt, value) + } + if value, ok := _u.mutation.AddedTodayErrorCount(); ok { + _spec.AddField(sorausagestat.FieldTodayErrorCount, field.TypeInt, value) + } + if value, ok := _u.mutation.TodayDate(); ok { + _spec.SetField(sorausagestat.FieldTodayDate, field.TypeTime, value) + } + if _u.mutation.TodayDateCleared() { + _spec.ClearField(sorausagestat.FieldTodayDate, field.TypeTime) + } + if value, ok := _u.mutation.ConsecutiveErrorCount(); ok { + _spec.SetField(sorausagestat.FieldConsecutiveErrorCount, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedConsecutiveErrorCount(); ok { + _spec.AddField(sorausagestat.FieldConsecutiveErrorCount, field.TypeInt, value) + } + _node = &SoraUsageStat{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{sorausagestat.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/tx.go b/backend/ent/tx.go index 7ff16ec8..427c1552 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -32,6 +32,14 @@ type Tx struct { RedeemCode *RedeemCodeClient // Setting is the client for interacting with the Setting builders. Setting *SettingClient + // SoraAccount is the client for interacting with the SoraAccount builders. + SoraAccount *SoraAccountClient + // SoraCacheFile is the client for interacting with the SoraCacheFile builders. + SoraCacheFile *SoraCacheFileClient + // SoraTask is the client for interacting with the SoraTask builders. + SoraTask *SoraTaskClient + // SoraUsageStat is the client for interacting with the SoraUsageStat builders. + SoraUsageStat *SoraUsageStatClient // UsageCleanupTask is the client for interacting with the UsageCleanupTask builders. 
UsageCleanupTask *UsageCleanupTaskClient // UsageLog is the client for interacting with the UsageLog builders. @@ -186,6 +194,10 @@ func (tx *Tx) init() { tx.Proxy = NewProxyClient(tx.config) tx.RedeemCode = NewRedeemCodeClient(tx.config) tx.Setting = NewSettingClient(tx.config) + tx.SoraAccount = NewSoraAccountClient(tx.config) + tx.SoraCacheFile = NewSoraCacheFileClient(tx.config) + tx.SoraTask = NewSoraTaskClient(tx.config) + tx.SoraUsageStat = NewSoraUsageStatClient(tx.config) tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config) tx.UsageLog = NewUsageLogClient(tx.config) tx.User = NewUserClient(tx.config) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 00a78480..b7df2273 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -58,6 +58,7 @@ type Config struct { UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` Concurrency ConcurrencyConfig `mapstructure:"concurrency"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + Sora SoraConfig `mapstructure:"sora"` RunMode string `mapstructure:"run_mode" yaml:"run_mode"` Timezone string `mapstructure:"timezone"` // e.g. 
"Asia/Shanghai", "UTC" Gemini GeminiConfig `mapstructure:"gemini"` @@ -69,6 +70,38 @@ type GeminiConfig struct { Quota GeminiQuotaConfig `mapstructure:"quota"` } +type SoraConfig struct { + BaseURL string `mapstructure:"base_url"` + Timeout int `mapstructure:"timeout"` + MaxRetries int `mapstructure:"max_retries"` + PollInterval float64 `mapstructure:"poll_interval"` + CallLogicMode string `mapstructure:"call_logic_mode"` + Cache SoraCacheConfig `mapstructure:"cache"` + WatermarkFree SoraWatermarkFreeConfig `mapstructure:"watermark_free"` + TokenRefresh SoraTokenRefreshConfig `mapstructure:"token_refresh"` +} + +type SoraCacheConfig struct { + Enabled bool `mapstructure:"enabled"` + BaseDir string `mapstructure:"base_dir"` + VideoDir string `mapstructure:"video_dir"` + MaxBytes int64 `mapstructure:"max_bytes"` + AllowedHosts []string `mapstructure:"allowed_hosts"` + UserDirEnabled bool `mapstructure:"user_dir_enabled"` +} + +type SoraWatermarkFreeConfig struct { + Enabled bool `mapstructure:"enabled"` + ParseMethod string `mapstructure:"parse_method"` + CustomParseURL string `mapstructure:"custom_parse_url"` + CustomParseToken string `mapstructure:"custom_parse_token"` + FallbackOnFailure bool `mapstructure:"fallback_on_failure"` +} + +type SoraTokenRefreshConfig struct { + Enabled bool `mapstructure:"enabled"` +} + type GeminiOAuthConfig struct { ClientID string `mapstructure:"client_id"` ClientSecret string `mapstructure:"client_secret"` @@ -862,6 +895,24 @@ func setDefaults() { viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 + viper.SetDefault("sora.base_url", "https://sora.chatgpt.com/backend") + viper.SetDefault("sora.timeout", 120) + viper.SetDefault("sora.max_retries", 3) + viper.SetDefault("sora.poll_interval", 2.5) + viper.SetDefault("sora.call_logic_mode", "default") + viper.SetDefault("sora.cache.enabled", false) + viper.SetDefault("sora.cache.base_dir", "tmp/sora") + 
viper.SetDefault("sora.cache.video_dir", "data/video") + viper.SetDefault("sora.cache.max_bytes", int64(0)) + viper.SetDefault("sora.cache.allowed_hosts", []string{}) + viper.SetDefault("sora.cache.user_dir_enabled", true) + viper.SetDefault("sora.watermark_free.enabled", false) + viper.SetDefault("sora.watermark_free.parse_method", "third_party") + viper.SetDefault("sora.watermark_free.custom_parse_url", "") + viper.SetDefault("sora.watermark_free.custom_parse_token", "") + viper.SetDefault("sora.watermark_free.fallback_on_failure", true) + viper.SetDefault("sora.token_refresh.enabled", false) + // Gemini OAuth - configure via environment variables or config file // GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET // Default: uses Gemini CLI public credentials (set via environment) diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 926624d2..25c12b75 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler { type CreateGroupRequest struct { Name string `json:"name" binding:"required"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` RateMultiplier float64 `json:"rate_multiplier"` IsExclusive bool `json:"is_exclusive"` SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` @@ -49,7 +49,7 @@ type CreateGroupRequest struct { type UpdateGroupRequest struct { Name string `json:"name"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` 
RateMultiplier *float64 `json:"rate_multiplier"` IsExclusive *bool `json:"is_exclusive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"` diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 0e3e0a2f..5451b848 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -79,6 +79,23 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { FallbackModelAntigravity: settings.FallbackModelAntigravity, EnableIdentityPatch: settings.EnableIdentityPatch, IdentityPatchPrompt: settings.IdentityPatchPrompt, + SoraBaseURL: settings.SoraBaseURL, + SoraTimeout: settings.SoraTimeout, + SoraMaxRetries: settings.SoraMaxRetries, + SoraPollInterval: settings.SoraPollInterval, + SoraCallLogicMode: settings.SoraCallLogicMode, + SoraCacheEnabled: settings.SoraCacheEnabled, + SoraCacheBaseDir: settings.SoraCacheBaseDir, + SoraCacheVideoDir: settings.SoraCacheVideoDir, + SoraCacheMaxBytes: settings.SoraCacheMaxBytes, + SoraCacheAllowedHosts: settings.SoraCacheAllowedHosts, + SoraCacheUserDirEnabled: settings.SoraCacheUserDirEnabled, + SoraWatermarkFreeEnabled: settings.SoraWatermarkFreeEnabled, + SoraWatermarkFreeParseMethod: settings.SoraWatermarkFreeParseMethod, + SoraWatermarkFreeCustomParseURL: settings.SoraWatermarkFreeCustomParseURL, + SoraWatermarkFreeCustomParseToken: settings.SoraWatermarkFreeCustomParseToken, + SoraWatermarkFreeFallbackOnFailure: settings.SoraWatermarkFreeFallbackOnFailure, + SoraTokenRefreshEnabled: settings.SoraTokenRefreshEnabled, OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled, OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled, OpsQueryModeDefault: settings.OpsQueryModeDefault, @@ -138,6 +155,25 @@ type UpdateSettingsRequest struct { EnableIdentityPatch bool `json:"enable_identity_patch"` IdentityPatchPrompt string `json:"identity_patch_prompt"` + // Sora configuration + 
SoraBaseURL string `json:"sora_base_url"` + SoraTimeout int `json:"sora_timeout"` + SoraMaxRetries int `json:"sora_max_retries"` + SoraPollInterval float64 `json:"sora_poll_interval"` + SoraCallLogicMode string `json:"sora_call_logic_mode"` + SoraCacheEnabled bool `json:"sora_cache_enabled"` + SoraCacheBaseDir string `json:"sora_cache_base_dir"` + SoraCacheVideoDir string `json:"sora_cache_video_dir"` + SoraCacheMaxBytes int64 `json:"sora_cache_max_bytes"` + SoraCacheAllowedHosts []string `json:"sora_cache_allowed_hosts"` + SoraCacheUserDirEnabled bool `json:"sora_cache_user_dir_enabled"` + SoraWatermarkFreeEnabled bool `json:"sora_watermark_free_enabled"` + SoraWatermarkFreeParseMethod string `json:"sora_watermark_free_parse_method"` + SoraWatermarkFreeCustomParseURL string `json:"sora_watermark_free_custom_parse_url"` + SoraWatermarkFreeCustomParseToken string `json:"sora_watermark_free_custom_parse_token"` + SoraWatermarkFreeFallbackOnFailure bool `json:"sora_watermark_free_fallback_on_failure"` + SoraTokenRefreshEnabled bool `json:"sora_token_refresh_enabled"` + // Ops monitoring (vNext) OpsMonitoringEnabled *bool `json:"ops_monitoring_enabled"` OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"` @@ -227,6 +263,32 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } + // Sora 参数校验与清理 + req.SoraBaseURL = strings.TrimSpace(req.SoraBaseURL) + if req.SoraBaseURL == "" { + req.SoraBaseURL = previousSettings.SoraBaseURL + } + if req.SoraBaseURL != "" { + if err := config.ValidateAbsoluteHTTPURL(req.SoraBaseURL); err != nil { + response.BadRequest(c, "Sora Base URL must be an absolute http(s) URL") + return + } + } + if req.SoraTimeout <= 0 { + req.SoraTimeout = previousSettings.SoraTimeout + } + if req.SoraMaxRetries < 0 { + req.SoraMaxRetries = previousSettings.SoraMaxRetries + } + if req.SoraPollInterval <= 0 { + req.SoraPollInterval = previousSettings.SoraPollInterval + } + if req.SoraCacheMaxBytes < 0 { + req.SoraCacheMaxBytes 
= 0 + } + req.SoraCacheAllowedHosts = normalizeStringList(req.SoraCacheAllowedHosts) + req.SoraWatermarkFreeCustomParseURL = strings.TrimSpace(req.SoraWatermarkFreeCustomParseURL) + // Ops metrics collector interval validation (seconds). if req.OpsMetricsIntervalSeconds != nil { v := *req.OpsMetricsIntervalSeconds @@ -240,40 +302,57 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } settings := &service.SystemSettings{ - RegistrationEnabled: req.RegistrationEnabled, - EmailVerifyEnabled: req.EmailVerifyEnabled, - PromoCodeEnabled: req.PromoCodeEnabled, - SMTPHost: req.SMTPHost, - SMTPPort: req.SMTPPort, - SMTPUsername: req.SMTPUsername, - SMTPPassword: req.SMTPPassword, - SMTPFrom: req.SMTPFrom, - SMTPFromName: req.SMTPFromName, - SMTPUseTLS: req.SMTPUseTLS, - TurnstileEnabled: req.TurnstileEnabled, - TurnstileSiteKey: req.TurnstileSiteKey, - TurnstileSecretKey: req.TurnstileSecretKey, - LinuxDoConnectEnabled: req.LinuxDoConnectEnabled, - LinuxDoConnectClientID: req.LinuxDoConnectClientID, - LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, - LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, - SiteName: req.SiteName, - SiteLogo: req.SiteLogo, - SiteSubtitle: req.SiteSubtitle, - APIBaseURL: req.APIBaseURL, - ContactInfo: req.ContactInfo, - DocURL: req.DocURL, - HomeContent: req.HomeContent, - HideCcsImportButton: req.HideCcsImportButton, - DefaultConcurrency: req.DefaultConcurrency, - DefaultBalance: req.DefaultBalance, - EnableModelFallback: req.EnableModelFallback, - FallbackModelAnthropic: req.FallbackModelAnthropic, - FallbackModelOpenAI: req.FallbackModelOpenAI, - FallbackModelGemini: req.FallbackModelGemini, - FallbackModelAntigravity: req.FallbackModelAntigravity, - EnableIdentityPatch: req.EnableIdentityPatch, - IdentityPatchPrompt: req.IdentityPatchPrompt, + RegistrationEnabled: req.RegistrationEnabled, + EmailVerifyEnabled: req.EmailVerifyEnabled, + PromoCodeEnabled: req.PromoCodeEnabled, + SMTPHost: req.SMTPHost, + SMTPPort: 
req.SMTPPort, + SMTPUsername: req.SMTPUsername, + SMTPPassword: req.SMTPPassword, + SMTPFrom: req.SMTPFrom, + SMTPFromName: req.SMTPFromName, + SMTPUseTLS: req.SMTPUseTLS, + TurnstileEnabled: req.TurnstileEnabled, + TurnstileSiteKey: req.TurnstileSiteKey, + TurnstileSecretKey: req.TurnstileSecretKey, + LinuxDoConnectEnabled: req.LinuxDoConnectEnabled, + LinuxDoConnectClientID: req.LinuxDoConnectClientID, + LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, + LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, + SiteName: req.SiteName, + SiteLogo: req.SiteLogo, + SiteSubtitle: req.SiteSubtitle, + APIBaseURL: req.APIBaseURL, + ContactInfo: req.ContactInfo, + DocURL: req.DocURL, + HomeContent: req.HomeContent, + HideCcsImportButton: req.HideCcsImportButton, + DefaultConcurrency: req.DefaultConcurrency, + DefaultBalance: req.DefaultBalance, + EnableModelFallback: req.EnableModelFallback, + FallbackModelAnthropic: req.FallbackModelAnthropic, + FallbackModelOpenAI: req.FallbackModelOpenAI, + FallbackModelGemini: req.FallbackModelGemini, + FallbackModelAntigravity: req.FallbackModelAntigravity, + EnableIdentityPatch: req.EnableIdentityPatch, + IdentityPatchPrompt: req.IdentityPatchPrompt, + SoraBaseURL: req.SoraBaseURL, + SoraTimeout: req.SoraTimeout, + SoraMaxRetries: req.SoraMaxRetries, + SoraPollInterval: req.SoraPollInterval, + SoraCallLogicMode: req.SoraCallLogicMode, + SoraCacheEnabled: req.SoraCacheEnabled, + SoraCacheBaseDir: req.SoraCacheBaseDir, + SoraCacheVideoDir: req.SoraCacheVideoDir, + SoraCacheMaxBytes: req.SoraCacheMaxBytes, + SoraCacheAllowedHosts: req.SoraCacheAllowedHosts, + SoraCacheUserDirEnabled: req.SoraCacheUserDirEnabled, + SoraWatermarkFreeEnabled: req.SoraWatermarkFreeEnabled, + SoraWatermarkFreeParseMethod: req.SoraWatermarkFreeParseMethod, + SoraWatermarkFreeCustomParseURL: req.SoraWatermarkFreeCustomParseURL, + SoraWatermarkFreeCustomParseToken: req.SoraWatermarkFreeCustomParseToken, + SoraWatermarkFreeFallbackOnFailure: 
req.SoraWatermarkFreeFallbackOnFailure, + SoraTokenRefreshEnabled: req.SoraTokenRefreshEnabled, OpsMonitoringEnabled: func() bool { if req.OpsMonitoringEnabled != nil { return *req.OpsMonitoringEnabled @@ -349,6 +428,23 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity, EnableIdentityPatch: updatedSettings.EnableIdentityPatch, IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt, + SoraBaseURL: updatedSettings.SoraBaseURL, + SoraTimeout: updatedSettings.SoraTimeout, + SoraMaxRetries: updatedSettings.SoraMaxRetries, + SoraPollInterval: updatedSettings.SoraPollInterval, + SoraCallLogicMode: updatedSettings.SoraCallLogicMode, + SoraCacheEnabled: updatedSettings.SoraCacheEnabled, + SoraCacheBaseDir: updatedSettings.SoraCacheBaseDir, + SoraCacheVideoDir: updatedSettings.SoraCacheVideoDir, + SoraCacheMaxBytes: updatedSettings.SoraCacheMaxBytes, + SoraCacheAllowedHosts: updatedSettings.SoraCacheAllowedHosts, + SoraCacheUserDirEnabled: updatedSettings.SoraCacheUserDirEnabled, + SoraWatermarkFreeEnabled: updatedSettings.SoraWatermarkFreeEnabled, + SoraWatermarkFreeParseMethod: updatedSettings.SoraWatermarkFreeParseMethod, + SoraWatermarkFreeCustomParseURL: updatedSettings.SoraWatermarkFreeCustomParseURL, + SoraWatermarkFreeCustomParseToken: updatedSettings.SoraWatermarkFreeCustomParseToken, + SoraWatermarkFreeFallbackOnFailure: updatedSettings.SoraWatermarkFreeFallbackOnFailure, + SoraTokenRefreshEnabled: updatedSettings.SoraTokenRefreshEnabled, OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled, OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled, OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault, @@ -477,6 +573,57 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.IdentityPatchPrompt != after.IdentityPatchPrompt { changed = append(changed, "identity_patch_prompt") } + if before.SoraBaseURL != after.SoraBaseURL { 
+ changed = append(changed, "sora_base_url") + } + if before.SoraTimeout != after.SoraTimeout { + changed = append(changed, "sora_timeout") + } + if before.SoraMaxRetries != after.SoraMaxRetries { + changed = append(changed, "sora_max_retries") + } + if before.SoraPollInterval != after.SoraPollInterval { + changed = append(changed, "sora_poll_interval") + } + if before.SoraCallLogicMode != after.SoraCallLogicMode { + changed = append(changed, "sora_call_logic_mode") + } + if before.SoraCacheEnabled != after.SoraCacheEnabled { + changed = append(changed, "sora_cache_enabled") + } + if before.SoraCacheBaseDir != after.SoraCacheBaseDir { + changed = append(changed, "sora_cache_base_dir") + } + if before.SoraCacheVideoDir != after.SoraCacheVideoDir { + changed = append(changed, "sora_cache_video_dir") + } + if before.SoraCacheMaxBytes != after.SoraCacheMaxBytes { + changed = append(changed, "sora_cache_max_bytes") + } + if strings.Join(before.SoraCacheAllowedHosts, ",") != strings.Join(after.SoraCacheAllowedHosts, ",") { + changed = append(changed, "sora_cache_allowed_hosts") + } + if before.SoraCacheUserDirEnabled != after.SoraCacheUserDirEnabled { + changed = append(changed, "sora_cache_user_dir_enabled") + } + if before.SoraWatermarkFreeEnabled != after.SoraWatermarkFreeEnabled { + changed = append(changed, "sora_watermark_free_enabled") + } + if before.SoraWatermarkFreeParseMethod != after.SoraWatermarkFreeParseMethod { + changed = append(changed, "sora_watermark_free_parse_method") + } + if before.SoraWatermarkFreeCustomParseURL != after.SoraWatermarkFreeCustomParseURL { + changed = append(changed, "sora_watermark_free_custom_parse_url") + } + if before.SoraWatermarkFreeCustomParseToken != after.SoraWatermarkFreeCustomParseToken { + changed = append(changed, "sora_watermark_free_custom_parse_token") + } + if before.SoraWatermarkFreeFallbackOnFailure != after.SoraWatermarkFreeFallbackOnFailure { + changed = append(changed, "sora_watermark_free_fallback_on_failure") 
+ } + if before.SoraTokenRefreshEnabled != after.SoraTokenRefreshEnabled { + changed = append(changed, "sora_token_refresh_enabled") + } if before.OpsMonitoringEnabled != after.OpsMonitoringEnabled { changed = append(changed, "ops_monitoring_enabled") } @@ -492,6 +639,19 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, return changed } +func normalizeStringList(values []string) []string { + if len(values) == 0 { + return []string{} + } + normalized := make([]string, 0, len(values)) + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + normalized = append(normalized, trimmed) + } + } + return normalized +} + // TestSMTPRequest 测试SMTP连接请求 type TestSMTPRequest struct { SMTPHost string `json:"smtp_host" binding:"required"` diff --git a/backend/internal/handler/admin/sora_account_handler.go b/backend/internal/handler/admin/sora_account_handler.go new file mode 100644 index 00000000..adfefc0d --- /dev/null +++ b/backend/internal/handler/admin/sora_account_handler.go @@ -0,0 +1,355 @@ +package admin + +import ( + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// SoraAccountHandler Sora 账号扩展管理 +// 提供 Sora 扩展表的查询与更新能力。 +type SoraAccountHandler struct { + adminService service.AdminService + soraAccountRepo service.SoraAccountRepository + usageRepo service.SoraUsageStatRepository +} + +// NewSoraAccountHandler 创建 SoraAccountHandler +func NewSoraAccountHandler(adminService service.AdminService, soraAccountRepo service.SoraAccountRepository, usageRepo service.SoraUsageStatRepository) *SoraAccountHandler { + return &SoraAccountHandler{ + adminService: adminService, + soraAccountRepo: soraAccountRepo, + usageRepo: usageRepo, + } +} + +// SoraAccountUpdateRequest 更新/创建 Sora 
账号扩展请求 +// 使用指针类型区分未提供与设置为空值。 +type SoraAccountUpdateRequest struct { + AccessToken *string `json:"access_token"` + SessionToken *string `json:"session_token"` + RefreshToken *string `json:"refresh_token"` + ClientID *string `json:"client_id"` + Email *string `json:"email"` + Username *string `json:"username"` + Remark *string `json:"remark"` + UseCount *int `json:"use_count"` + PlanType *string `json:"plan_type"` + PlanTitle *string `json:"plan_title"` + SubscriptionEnd *int64 `json:"subscription_end"` + SoraSupported *bool `json:"sora_supported"` + SoraInviteCode *string `json:"sora_invite_code"` + SoraRedeemedCount *int `json:"sora_redeemed_count"` + SoraRemainingCount *int `json:"sora_remaining_count"` + SoraTotalCount *int `json:"sora_total_count"` + SoraCooldownUntil *int64 `json:"sora_cooldown_until"` + CooledUntil *int64 `json:"cooled_until"` + ImageEnabled *bool `json:"image_enabled"` + VideoEnabled *bool `json:"video_enabled"` + ImageConcurrency *int `json:"image_concurrency"` + VideoConcurrency *int `json:"video_concurrency"` + IsExpired *bool `json:"is_expired"` +} + +// SoraAccountBatchRequest 批量导入请求 +// accounts 支持批量 upsert。 +type SoraAccountBatchRequest struct { + Accounts []SoraAccountBatchItem `json:"accounts"` +} + +// SoraAccountBatchItem 批量导入条目 +type SoraAccountBatchItem struct { + AccountID int64 `json:"account_id"` + SoraAccountUpdateRequest +} + +// SoraAccountBatchResult 批量导入结果 +// 仅返回成功/失败数量与明细。 +type SoraAccountBatchResult struct { + Success int `json:"success"` + Failed int `json:"failed"` + Results []SoraAccountBatchItemResult `json:"results"` +} + +// SoraAccountBatchItemResult 批量导入单条结果 +type SoraAccountBatchItemResult struct { + AccountID int64 `json:"account_id"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` +} + +// List 获取 Sora 账号扩展列表 +// GET /api/v1/admin/sora/accounts +func (h *SoraAccountHandler) List(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + search := 
strings.TrimSpace(c.Query("search")) + + accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, service.PlatformSora, "", "", search) + if err != nil { + response.ErrorFrom(c, err) + return + } + + accountIDs := make([]int64, 0, len(accounts)) + for i := range accounts { + accountIDs = append(accountIDs, accounts[i].ID) + } + + soraMap := map[int64]*service.SoraAccount{} + if h.soraAccountRepo != nil { + soraMap, _ = h.soraAccountRepo.GetByAccountIDs(c.Request.Context(), accountIDs) + } + + usageMap := map[int64]*service.SoraUsageStat{} + if h.usageRepo != nil { + usageMap, _ = h.usageRepo.GetByAccountIDs(c.Request.Context(), accountIDs) + } + + result := make([]dto.SoraAccount, 0, len(accounts)) + for i := range accounts { + acc := accounts[i] + item := dto.SoraAccountFromService(&acc, soraMap[acc.ID], usageMap[acc.ID]) + if item != nil { + result = append(result, *item) + } + } + + response.Paginated(c, result, total, page, pageSize) +} + +// Get 获取单个 Sora 账号扩展 +// GET /api/v1/admin/sora/accounts/:id +func (h *SoraAccountHandler) Get(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "账号 ID 无效") + return + } + + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.ErrorFrom(c, err) + return + } + if account.Platform != service.PlatformSora { + response.BadRequest(c, "账号不是 Sora 平台") + return + } + + var soraAcc *service.SoraAccount + if h.soraAccountRepo != nil { + soraAcc, _ = h.soraAccountRepo.GetByAccountID(c.Request.Context(), accountID) + } + var usage *service.SoraUsageStat + if h.usageRepo != nil { + usage, _ = h.usageRepo.GetByAccountID(c.Request.Context(), accountID) + } + + response.Success(c, dto.SoraAccountFromService(account, soraAcc, usage)) +} + +// Upsert 更新或创建 Sora 账号扩展 +// PUT /api/v1/admin/sora/accounts/:id +func (h *SoraAccountHandler) Upsert(c *gin.Context) { + accountID, err := 
strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "账号 ID 无效") + return + } + + var req SoraAccountUpdateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "请求参数无效: "+err.Error()) + return + } + + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.ErrorFrom(c, err) + return + } + if account.Platform != service.PlatformSora { + response.BadRequest(c, "账号不是 Sora 平台") + return + } + + updates := buildSoraAccountUpdates(&req) + if h.soraAccountRepo != nil && len(updates) > 0 { + if err := h.soraAccountRepo.Upsert(c.Request.Context(), accountID, updates); err != nil { + response.ErrorFrom(c, err) + return + } + } + + var soraAcc *service.SoraAccount + if h.soraAccountRepo != nil { + soraAcc, _ = h.soraAccountRepo.GetByAccountID(c.Request.Context(), accountID) + } + var usage *service.SoraUsageStat + if h.usageRepo != nil { + usage, _ = h.usageRepo.GetByAccountID(c.Request.Context(), accountID) + } + + response.Success(c, dto.SoraAccountFromService(account, soraAcc, usage)) +} + +// BatchUpsert 批量导入 Sora 账号扩展 +// POST /api/v1/admin/sora/accounts/import +func (h *SoraAccountHandler) BatchUpsert(c *gin.Context) { + var req SoraAccountBatchRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "请求参数无效: "+err.Error()) + return + } + if len(req.Accounts) == 0 { + response.BadRequest(c, "accounts 不能为空") + return + } + + ids := make([]int64, 0, len(req.Accounts)) + for _, item := range req.Accounts { + if item.AccountID > 0 { + ids = append(ids, item.AccountID) + } + } + + accountMap := make(map[int64]*service.Account, len(ids)) + if len(ids) > 0 { + accounts, _ := h.adminService.GetAccountsByIDs(c.Request.Context(), ids) + for _, acc := range accounts { + if acc != nil { + accountMap[acc.ID] = acc + } + } + } + + result := SoraAccountBatchResult{ + Results: make([]SoraAccountBatchItemResult, 0, len(req.Accounts)), + } + + for _, item := 
range req.Accounts { + entry := SoraAccountBatchItemResult{AccountID: item.AccountID} + acc := accountMap[item.AccountID] + if acc == nil { + entry.Error = "账号不存在" + result.Results = append(result.Results, entry) + result.Failed++ + continue + } + if acc.Platform != service.PlatformSora { + entry.Error = "账号不是 Sora 平台" + result.Results = append(result.Results, entry) + result.Failed++ + continue + } + updates := buildSoraAccountUpdates(&item.SoraAccountUpdateRequest) + if h.soraAccountRepo != nil && len(updates) > 0 { + if err := h.soraAccountRepo.Upsert(c.Request.Context(), item.AccountID, updates); err != nil { + entry.Error = err.Error() + result.Results = append(result.Results, entry) + result.Failed++ + continue + } + } + entry.Success = true + result.Results = append(result.Results, entry) + result.Success++ + } + + response.Success(c, result) +} + +// ListUsage 获取 Sora 调用统计 +// GET /api/v1/admin/sora/usage +func (h *SoraAccountHandler) ListUsage(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + if h.usageRepo == nil { + response.Paginated(c, []dto.SoraUsageStat{}, 0, page, pageSize) + return + } + stats, paginationResult, err := h.usageRepo.List(c.Request.Context(), params) + if err != nil { + response.ErrorFrom(c, err) + return + } + result := make([]dto.SoraUsageStat, 0, len(stats)) + for _, stat := range stats { + item := dto.SoraUsageStatFromService(stat) + if item != nil { + result = append(result, *item) + } + } + response.Paginated(c, result, paginationResult.Total, paginationResult.Page, paginationResult.PageSize) +} + +func buildSoraAccountUpdates(req *SoraAccountUpdateRequest) map[string]any { + if req == nil { + return nil + } + updates := make(map[string]any) + setString := func(key string, value *string) { + if value == nil { + return + } + updates[key] = strings.TrimSpace(*value) + } + setString("access_token", req.AccessToken) + setString("session_token", 
req.SessionToken) + setString("refresh_token", req.RefreshToken) + setString("client_id", req.ClientID) + setString("email", req.Email) + setString("username", req.Username) + setString("remark", req.Remark) + setString("plan_type", req.PlanType) + setString("plan_title", req.PlanTitle) + setString("sora_invite_code", req.SoraInviteCode) + + if req.UseCount != nil { + updates["use_count"] = *req.UseCount + } + if req.SoraSupported != nil { + updates["sora_supported"] = *req.SoraSupported + } + if req.SoraRedeemedCount != nil { + updates["sora_redeemed_count"] = *req.SoraRedeemedCount + } + if req.SoraRemainingCount != nil { + updates["sora_remaining_count"] = *req.SoraRemainingCount + } + if req.SoraTotalCount != nil { + updates["sora_total_count"] = *req.SoraTotalCount + } + if req.ImageEnabled != nil { + updates["image_enabled"] = *req.ImageEnabled + } + if req.VideoEnabled != nil { + updates["video_enabled"] = *req.VideoEnabled + } + if req.ImageConcurrency != nil { + updates["image_concurrency"] = *req.ImageConcurrency + } + if req.VideoConcurrency != nil { + updates["video_concurrency"] = *req.VideoConcurrency + } + if req.IsExpired != nil { + updates["is_expired"] = *req.IsExpired + } + if req.SubscriptionEnd != nil && *req.SubscriptionEnd > 0 { + updates["subscription_end"] = time.Unix(*req.SubscriptionEnd, 0).UTC() + } + if req.SoraCooldownUntil != nil && *req.SoraCooldownUntil > 0 { + updates["sora_cooldown_until"] = time.Unix(*req.SoraCooldownUntil, 0).UTC() + } + if req.CooledUntil != nil && *req.CooledUntil > 0 { + updates["cooled_until"] = time.Unix(*req.CooledUntil, 0).UTC() + } + return updates +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index d58a8a29..3d48e13b 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -287,6 +287,72 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi } } +func 
SoraUsageStatFromService(stat *service.SoraUsageStat) *SoraUsageStat { + if stat == nil { + return nil + } + return &SoraUsageStat{ + AccountID: stat.AccountID, + ImageCount: stat.ImageCount, + VideoCount: stat.VideoCount, + ErrorCount: stat.ErrorCount, + LastErrorAt: timeToUnixSeconds(stat.LastErrorAt), + TodayImageCount: stat.TodayImageCount, + TodayVideoCount: stat.TodayVideoCount, + TodayErrorCount: stat.TodayErrorCount, + TodayDate: timeToUnixSeconds(stat.TodayDate), + ConsecutiveErrorCount: stat.ConsecutiveErrorCount, + CreatedAt: stat.CreatedAt, + UpdatedAt: stat.UpdatedAt, + } +} + +func SoraAccountFromService(account *service.Account, soraAcc *service.SoraAccount, usage *service.SoraUsageStat) *SoraAccount { + if account == nil { + return nil + } + out := &SoraAccount{ + AccountID: account.ID, + AccountName: account.Name, + AccountStatus: account.Status, + AccountType: account.Type, + AccountConcurrency: account.Concurrency, + ProxyID: account.ProxyID, + Usage: SoraUsageStatFromService(usage), + CreatedAt: account.CreatedAt, + UpdatedAt: account.UpdatedAt, + } + if soraAcc == nil { + return out + } + out.AccessToken = soraAcc.AccessToken + out.SessionToken = soraAcc.SessionToken + out.RefreshToken = soraAcc.RefreshToken + out.ClientID = soraAcc.ClientID + out.Email = soraAcc.Email + out.Username = soraAcc.Username + out.Remark = soraAcc.Remark + out.UseCount = soraAcc.UseCount + out.PlanType = soraAcc.PlanType + out.PlanTitle = soraAcc.PlanTitle + out.SubscriptionEnd = timeToUnixSeconds(soraAcc.SubscriptionEnd) + out.SoraSupported = soraAcc.SoraSupported + out.SoraInviteCode = soraAcc.SoraInviteCode + out.SoraRedeemedCount = soraAcc.SoraRedeemedCount + out.SoraRemainingCount = soraAcc.SoraRemainingCount + out.SoraTotalCount = soraAcc.SoraTotalCount + out.SoraCooldownUntil = timeToUnixSeconds(soraAcc.SoraCooldownUntil) + out.CooledUntil = timeToUnixSeconds(soraAcc.CooledUntil) + out.ImageEnabled = soraAcc.ImageEnabled + out.VideoEnabled = 
soraAcc.VideoEnabled + out.ImageConcurrency = soraAcc.ImageConcurrency + out.VideoConcurrency = soraAcc.VideoConcurrency + out.IsExpired = soraAcc.IsExpired + out.CreatedAt = soraAcc.CreatedAt + out.UpdatedAt = soraAcc.UpdatedAt + return out +} + func ProxyAccountSummaryFromService(a *service.ProxyAccountSummary) *ProxyAccountSummary { if a == nil { return nil diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 01f39478..f643b8c1 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -46,6 +46,25 @@ type SystemSettings struct { EnableIdentityPatch bool `json:"enable_identity_patch"` IdentityPatchPrompt string `json:"identity_patch_prompt"` + // Sora configuration + SoraBaseURL string `json:"sora_base_url"` + SoraTimeout int `json:"sora_timeout"` + SoraMaxRetries int `json:"sora_max_retries"` + SoraPollInterval float64 `json:"sora_poll_interval"` + SoraCallLogicMode string `json:"sora_call_logic_mode"` + SoraCacheEnabled bool `json:"sora_cache_enabled"` + SoraCacheBaseDir string `json:"sora_cache_base_dir"` + SoraCacheVideoDir string `json:"sora_cache_video_dir"` + SoraCacheMaxBytes int64 `json:"sora_cache_max_bytes"` + SoraCacheAllowedHosts []string `json:"sora_cache_allowed_hosts"` + SoraCacheUserDirEnabled bool `json:"sora_cache_user_dir_enabled"` + SoraWatermarkFreeEnabled bool `json:"sora_watermark_free_enabled"` + SoraWatermarkFreeParseMethod string `json:"sora_watermark_free_parse_method"` + SoraWatermarkFreeCustomParseURL string `json:"sora_watermark_free_custom_parse_url"` + SoraWatermarkFreeCustomParseToken string `json:"sora_watermark_free_custom_parse_token"` + SoraWatermarkFreeFallbackOnFailure bool `json:"sora_watermark_free_fallback_on_failure"` + SoraTokenRefreshEnabled bool `json:"sora_token_refresh_enabled"` + // Ops monitoring (vNext) OpsMonitoringEnabled bool `json:"ops_monitoring_enabled"` OpsRealtimeMonitoringEnabled bool 
`json:"ops_realtime_monitoring_enabled"` diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 938d707c..ebe8cdfe 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -141,6 +141,56 @@ type Account struct { Groups []*Group `json:"groups,omitempty"` } +type SoraUsageStat struct { + AccountID int64 `json:"account_id"` + ImageCount int `json:"image_count"` + VideoCount int `json:"video_count"` + ErrorCount int `json:"error_count"` + LastErrorAt *int64 `json:"last_error_at"` + TodayImageCount int `json:"today_image_count"` + TodayVideoCount int `json:"today_video_count"` + TodayErrorCount int `json:"today_error_count"` + TodayDate *int64 `json:"today_date"` + ConsecutiveErrorCount int `json:"consecutive_error_count"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +type SoraAccount struct { + AccountID int64 `json:"account_id"` + AccountName string `json:"account_name"` + AccountStatus string `json:"account_status"` + AccountType string `json:"account_type"` + AccountConcurrency int `json:"account_concurrency"` + ProxyID *int64 `json:"proxy_id"` + AccessToken string `json:"access_token"` + SessionToken string `json:"session_token"` + RefreshToken string `json:"refresh_token"` + ClientID string `json:"client_id"` + Email string `json:"email"` + Username string `json:"username"` + Remark string `json:"remark"` + UseCount int `json:"use_count"` + PlanType string `json:"plan_type"` + PlanTitle string `json:"plan_title"` + SubscriptionEnd *int64 `json:"subscription_end"` + SoraSupported bool `json:"sora_supported"` + SoraInviteCode string `json:"sora_invite_code"` + SoraRedeemedCount int `json:"sora_redeemed_count"` + SoraRemainingCount int `json:"sora_remaining_count"` + SoraTotalCount int `json:"sora_total_count"` + SoraCooldownUntil *int64 `json:"sora_cooldown_until"` + CooledUntil *int64 `json:"cooled_until"` + ImageEnabled bool 
`json:"image_enabled"` + VideoEnabled bool `json:"video_enabled"` + ImageConcurrency int `json:"image_concurrency"` + VideoConcurrency int `json:"video_concurrency"` + IsExpired bool `json:"is_expired"` + Usage *SoraUsageStat `json:"usage,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + type AccountGroup struct { AccountID int64 `json:"account_id"` GroupID int64 `json:"group_id"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 70ea51bf..559633f8 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -17,6 +17,7 @@ import ( pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/pkg/sora" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -508,6 +509,13 @@ func (h *GatewayHandler) Models(c *gin.Context) { }) return } + if platform == service.PlatformSora { + c.JSON(http.StatusOK, gin.H{ + "object": "list", + "data": sora.ListModels(), + }) + return + } c.JSON(http.StatusOK, gin.H{ "object": "list", diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 5b1b317d..19a27f17 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -17,6 +17,7 @@ type AdminHandlers struct { Proxy *admin.ProxyHandler Redeem *admin.RedeemHandler Promo *admin.PromoHandler + SoraAccount *admin.SoraAccountHandler Setting *admin.SettingHandler Ops *admin.OpsHandler System *admin.SystemHandler @@ -36,6 +37,7 @@ type Handlers struct { Admin *AdminHandlers Gateway *GatewayHandler OpenAIGateway *OpenAIGatewayHandler + SoraGateway *SoraGatewayHandler Setting *SettingHandler } diff --git a/backend/internal/handler/ops_error_logger.go 
b/backend/internal/handler/ops_error_logger.go index f62e6b3e..a24d4333 100644 --- a/backend/internal/handler/ops_error_logger.go +++ b/backend/internal/handler/ops_error_logger.go @@ -814,6 +814,8 @@ func guessPlatformFromPath(path string) string { return service.PlatformAntigravity case strings.HasPrefix(p, "/v1beta/"): return service.PlatformGemini + case strings.Contains(p, "/chat/completions"): + return service.PlatformSora case strings.Contains(p, "/responses"): return service.PlatformOpenAI default: diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go new file mode 100644 index 00000000..7fbd8a9b --- /dev/null +++ b/backend/internal/handler/sora_gateway_handler.go @@ -0,0 +1,364 @@ +package handler + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/sora" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// SoraGatewayHandler handles Sora OpenAI compatible endpoints. +type SoraGatewayHandler struct { + gatewayService *service.GatewayService + soraGatewayService *service.SoraGatewayService + billingCacheService *service.BillingCacheService + concurrencyHelper *ConcurrencyHelper + maxAccountSwitches int +} + +// NewSoraGatewayHandler creates a new SoraGatewayHandler. 
+func NewSoraGatewayHandler( + gatewayService *service.GatewayService, + soraGatewayService *service.SoraGatewayService, + concurrencyService *service.ConcurrencyService, + billingCacheService *service.BillingCacheService, + cfg *config.Config, +) *SoraGatewayHandler { + pingInterval := time.Duration(0) + maxAccountSwitches := 3 + if cfg != nil { + pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second + if cfg.Gateway.MaxAccountSwitches > 0 { + maxAccountSwitches = cfg.Gateway.MaxAccountSwitches + } + } + return &SoraGatewayHandler{ + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + billingCacheService: billingCacheService, + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + maxAccountSwitches: maxAccountSwitches, + } +} + +// ChatCompletions handles Sora OpenAI-compatible chat completions endpoint. +// POST /v1/chat/completions +func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { + apiKey, ok := middleware.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + + body, err := io.ReadAll(c.Request.Body) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + if len(body) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse 
request body") + return + } + + model, _ := reqBody["model"].(string) + if strings.TrimSpace(model) == "" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + stream, _ := reqBody["stream"].(bool) + + prompt, imageData, videoData, remixID, err := parseSoraPrompt(reqBody) + if err != nil { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + return + } + if remixID == "" { + remixID = sora.ExtractRemixID(prompt) + } + if remixID != "" { + prompt = strings.ReplaceAll(prompt, remixID, "") + } + + if apiKey.Group != nil && apiKey.Group.Platform != service.PlatformSora { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "当前分组不支持 Sora 平台") + return + } + + streamStarted := false + maxWait := service.CalculateMaxWait(subject.Concurrency) + canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) + waitCounted := false + if err == nil && canWait { + waitCounted = true + } + if err == nil && !canWait { + h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return + } + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + } + }() + + userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, stream, &streamStarted) + if err != nil { + h.handleConcurrencyError(c, err, "user", streamStarted) + return + } + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + waitCounted = false + } + userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + failedAccountIDs := make(map[int64]struct{}) + maxSwitches := h.maxAccountSwitches + if mode := h.soraGatewayService.CallLogicMode(c.Request.Context()); strings.EqualFold(mode, "native") { + 
maxSwitches = 1 + } + + for switchCount := 0; switchCount < maxSwitches; switchCount++ { + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, "", model, failedAccountIDs, "") + if err != nil { + h.errorResponse(c, http.StatusServiceUnavailable, "server_error", err.Error()) + return + } + account := selection.Account + releaseFunc := selection.ReleaseFunc + + result, err := h.soraGatewayService.Generate(c.Request.Context(), account, service.SoraGenerationRequest{ + Model: model, + Prompt: prompt, + Image: imageData, + Video: videoData, + RemixTargetID: remixID, + Stream: stream, + UserID: subject.UserID, + }) + if err != nil { + // 失败路径:立即释放槽位,而非 defer + if releaseFunc != nil { + releaseFunc() + } + + if errors.Is(err, service.ErrSoraAccountMissingToken) || errors.Is(err, service.ErrSoraAccountNotEligible) { + failedAccountIDs[account.ID] = struct{}{} + continue + } + h.handleStreamingAwareError(c, http.StatusBadGateway, "server_error", err.Error(), streamStarted) + return + } + + // 成功路径:使用 defer 在函数退出时释放 + if releaseFunc != nil { + defer releaseFunc() + } + + h.respondCompletion(c, model, result, stream) + return + } + + h.handleFailoverExhausted(c, http.StatusServiceUnavailable, streamStarted) +} + +func (h *SoraGatewayHandler) respondCompletion(c *gin.Context, model string, result *service.SoraGenerationResult, stream bool) { + if result == nil { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "Empty response") + return + } + if stream { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + first := buildSoraStreamChunk(model, "", true, "") + if _, err := c.Writer.WriteString(first); err != nil { + return + } + final := buildSoraStreamChunk(model, result.Content, false, "stop") + if _, err := c.Writer.WriteString(final); err != nil { + return + } + _, _ = c.Writer.WriteString("data: [DONE]\n\n") + return + } + + 
c.JSON(http.StatusOK, buildSoraNonStreamResponse(model, result.Content)) +} + +func buildSoraStreamChunk(model, content string, isFirst bool, finishReason string) string { + chunkID := fmt.Sprintf("chatcmpl-%d", time.Now().UnixMilli()) + delta := map[string]any{} + if isFirst { + delta["role"] = "assistant" + } + if content != "" { + delta["content"] = content + } else { + delta["content"] = nil + } + response := map[string]any{ + "id": chunkID, + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []any{ + map[string]any{ + "index": 0, + "delta": delta, + "finish_reason": finishReason, + }, + }, + } + payload, _ := json.Marshal(response) + return "data: " + string(payload) + "\n\n" +} + +func buildSoraNonStreamResponse(model, content string) map[string]any { + return map[string]any{ + "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixMilli()), + "object": "chat.completion", + "created": time.Now().Unix(), + "model": model, + "choices": []any{ + map[string]any{ + "index": 0, + "message": map[string]any{ + "role": "assistant", + "content": content, + }, + "finish_reason": "stop", + }, + }, + } +} + +func parseSoraPrompt(req map[string]any) (prompt, imageData, videoData, remixID string, err error) { + messages, ok := req["messages"].([]any) + if !ok || len(messages) == 0 { + return "", "", "", "", fmt.Errorf("messages is required") + } + last := messages[len(messages)-1] + msg, ok := last.(map[string]any) + if !ok { + return "", "", "", "", fmt.Errorf("invalid message format") + } + content, ok := msg["content"] + if !ok { + return "", "", "", "", fmt.Errorf("content is required") + } + + if v, ok := req["image"].(string); ok && v != "" { + imageData = v + } + if v, ok := req["video"].(string); ok && v != "" { + videoData = v + } + if v, ok := req["remix_target_id"].(string); ok { + remixID = v + } + + switch value := content.(type) { + case string: + prompt = value + case []any: + for _, item := range value { + part, ok := 
item.(map[string]any) + if !ok { + continue + } + switch part["type"] { + case "text": + if text, ok := part["text"].(string); ok { + prompt = text + } + case "image_url": + if image, ok := part["image_url"].(map[string]any); ok { + if url, ok := image["url"].(string); ok { + imageData = url + } + } + case "video_url": + if video, ok := part["video_url"].(map[string]any); ok { + if url, ok := video["url"].(string); ok { + videoData = url + } + } + } + } + default: + return "", "", "", "", fmt.Errorf("invalid content format") + } + if strings.TrimSpace(prompt) == "" && strings.TrimSpace(videoData) == "" { + return "", "", "", "", fmt.Errorf("prompt is required") + } + return prompt, imageData, videoData, remixID, nil +} + +func looksLikeURL(value string) bool { + trimmed := strings.ToLower(strings.TrimSpace(value)) + return strings.HasPrefix(trimmed, "http://") || strings.HasPrefix(trimmed, "https://") +} + +func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { + if streamStarted { + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", err.Error(), true) + return + } + c.JSON(http.StatusTooManyRequests, gin.H{"error": err.Error()}) +} + +func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { + message := "No available Sora accounts" + h.handleStreamingAwareError(c, statusCode, "server_error", message, streamStarted) +} + +func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { + if streamStarted { + payload := map[string]any{"error": map[string]any{"message": message, "type": errType, "param": nil, "code": nil}} + data, _ := json.Marshal(payload) + _, _ = c.Writer.WriteString("data: " + string(data) + "\n\n") + _, _ = c.Writer.WriteString("data: [DONE]\n\n") + return + } + h.errorResponse(c, status, errType, message) +} + +func (h *SoraGatewayHandler) 
errorResponse(c *gin.Context, status int, errType, message string) { + c.JSON(status, gin.H{ + "error": gin.H{ + "message": message, + "type": errType, + "param": nil, + "code": nil, + }, + }) +} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 2af7905e..48362c13 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -20,6 +20,7 @@ func ProvideAdminHandlers( proxyHandler *admin.ProxyHandler, redeemHandler *admin.RedeemHandler, promoHandler *admin.PromoHandler, + soraAccountHandler *admin.SoraAccountHandler, settingHandler *admin.SettingHandler, opsHandler *admin.OpsHandler, systemHandler *admin.SystemHandler, @@ -39,6 +40,7 @@ func ProvideAdminHandlers( Proxy: proxyHandler, Redeem: redeemHandler, Promo: promoHandler, + SoraAccount: soraAccountHandler, Setting: settingHandler, Ops: opsHandler, System: systemHandler, @@ -69,6 +71,7 @@ func ProvideHandlers( adminHandlers *AdminHandlers, gatewayHandler *GatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler, + soraGatewayHandler *SoraGatewayHandler, settingHandler *SettingHandler, ) *Handlers { return &Handlers{ @@ -81,6 +84,7 @@ func ProvideHandlers( Admin: adminHandlers, Gateway: gatewayHandler, OpenAIGateway: openaiGatewayHandler, + SoraGateway: soraGatewayHandler, Setting: settingHandler, } } @@ -96,6 +100,7 @@ var ProviderSet = wire.NewSet( NewSubscriptionHandler, NewGatewayHandler, NewOpenAIGatewayHandler, + NewSoraGatewayHandler, ProvideSettingHandler, // Admin handlers @@ -110,6 +115,7 @@ var ProviderSet = wire.NewSet( admin.NewProxyHandler, admin.NewRedeemHandler, admin.NewPromoHandler, + admin.NewSoraAccountHandler, admin.NewSettingHandler, admin.NewOpsHandler, ProvideSystemHandler, diff --git a/backend/internal/pkg/sora/character.go b/backend/internal/pkg/sora/character.go new file mode 100644 index 00000000..eff08712 --- /dev/null +++ b/backend/internal/pkg/sora/character.go @@ -0,0 +1,148 @@ +package sora + +import ( + "bytes" + 
"context" + "encoding/json" + "errors" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/textproto" +) + +// UploadCharacterVideo uploads a character video and returns cameo ID. +func (c *Client) UploadCharacterVideo(ctx context.Context, opts RequestOptions, data []byte) (string, error) { + if len(data) == 0 { + return "", errors.New("video data empty") + } + var buf bytes.Buffer + writer := multipart.NewWriter(&buf) + if err := writeMultipartFile(writer, "file", "video.mp4", "video/mp4", data); err != nil { + return "", err + } + if err := writer.WriteField("timestamps", "0,3"); err != nil { + return "", err + } + if err := writer.Close(); err != nil { + return "", err + } + resp, err := c.doRequest(ctx, "POST", "/characters/upload", opts, &buf, writer.FormDataContentType(), false) + if err != nil { + return "", err + } + return stringFromJSON(resp, "id"), nil +} + +// GetCameoStatus returns cameo processing status. +func (c *Client) GetCameoStatus(ctx context.Context, opts RequestOptions, cameoID string) (map[string]any, error) { + if cameoID == "" { + return nil, errors.New("cameo id empty") + } + return c.doRequest(ctx, "GET", "/project_y/cameos/in_progress/"+cameoID, opts, nil, "", false) +} + +// DownloadCharacterImage downloads character avatar image data. 
+func (c *Client) DownloadCharacterImage(ctx context.Context, opts RequestOptions, imageURL string) ([]byte, error) { + if c.upstream == nil { + return nil, errors.New("upstream is nil") + } + req, err := http.NewRequestWithContext(ctx, "GET", imageURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("User-Agent", defaultDesktopUA) + resp, err := c.upstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, opts.AccountConcurrency, c.enableTLSFingerprint) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("download image failed: %d", resp.StatusCode) + } + return io.ReadAll(resp.Body) +} + +// UploadCharacterImage uploads character avatar and returns asset pointer. +func (c *Client) UploadCharacterImage(ctx context.Context, opts RequestOptions, data []byte) (string, error) { + if len(data) == 0 { + return "", errors.New("image data empty") + } + var buf bytes.Buffer + writer := multipart.NewWriter(&buf) + if err := writeMultipartFile(writer, "file", "profile.webp", "image/webp", data); err != nil { + return "", err + } + if err := writer.WriteField("use_case", "profile"); err != nil { + return "", err + } + if err := writer.Close(); err != nil { + return "", err + } + resp, err := c.doRequest(ctx, "POST", "/project_y/file/upload", opts, &buf, writer.FormDataContentType(), false) + if err != nil { + return "", err + } + return stringFromJSON(resp, "asset_pointer"), nil +} + +// FinalizeCharacter finalizes character creation and returns character ID. 
+func (c *Client) FinalizeCharacter(ctx context.Context, opts RequestOptions, cameoID, username, displayName, assetPointer string) (string, error) { + payload := map[string]any{ + "cameo_id": cameoID, + "username": username, + "display_name": displayName, + "profile_asset_pointer": assetPointer, + "instruction_set": nil, + "safety_instruction_set": nil, + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + resp, err := c.doRequest(ctx, "POST", "/characters/finalize", opts, bytes.NewReader(body), "application/json", false) + if err != nil { + return "", err + } + if character, ok := resp["character"].(map[string]any); ok { + if id, ok := character["character_id"].(string); ok { + return id, nil + } + } + return "", nil +} + +// SetCharacterPublic marks character as public. +func (c *Client) SetCharacterPublic(ctx context.Context, opts RequestOptions, cameoID string) error { + payload := map[string]any{"visibility": "public"} + body, err := json.Marshal(payload) + if err != nil { + return err + } + _, err = c.doRequest(ctx, "POST", "/project_y/cameos/by_id/"+cameoID+"/update_v2", opts, bytes.NewReader(body), "application/json", false) + return err +} + +// DeleteCharacter deletes a character by ID. 
+func (c *Client) DeleteCharacter(ctx context.Context, opts RequestOptions, characterID string) error { + if characterID == "" { + return nil + } + _, err := c.doRequest(ctx, "DELETE", "/project_y/characters/"+characterID, opts, nil, "", false) + return err +} + +func writeMultipartFile(writer *multipart.Writer, field, filename, contentType string, data []byte) error { + header := make(textproto.MIMEHeader) + header.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, field, filename)) + if contentType != "" { + header.Set("Content-Type", contentType) + } + part, err := writer.CreatePart(header) + if err != nil { + return err + } + _, err = part.Write(data) + return err +} diff --git a/backend/internal/pkg/sora/client.go b/backend/internal/pkg/sora/client.go new file mode 100644 index 00000000..01398d0d --- /dev/null +++ b/backend/internal/pkg/sora/client.go @@ -0,0 +1,612 @@ +package sora + +import ( + "bytes" + "context" + "crypto/sha3" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "mime/multipart" + "net/http" + "strings" + "sync" + "time" + + "github.com/google/uuid" +) + +const ( + chatGPTBaseURL = "https://chatgpt.com" + sentinelFlow = "sora_2_create_task" + maxAPIResponseSize = 1 * 1024 * 1024 // 1MB +) + +var ( + defaultMobileUA = "Sora/1.2026.007 (Android 15; Pixel 8 Pro; build 2600700)" + defaultDesktopUA = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36" + sentinelCache sync.Map // 包级缓存,存储 Sentinel Token,key 为 accountID +) + +// sentinelCacheEntry 是 Sentinel Token 缓存条目 +type sentinelCacheEntry struct { + token string + expiresAt time.Time +} + +// UpstreamClient defines the HTTP client interface for Sora requests. 
+type UpstreamClient interface { + Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) + DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) +} + +// Client is a minimal Sora API client. +type Client struct { + baseURL string + timeout time.Duration + upstream UpstreamClient + enableTLSFingerprint bool +} + +// RequestOptions configures per-request context. +type RequestOptions struct { + AccountID int64 + AccountConcurrency int + ProxyURL string + AccessToken string +} + +// getCachedSentinel 从缓存中获取 Sentinel Token +func getCachedSentinel(accountID int64) (string, bool) { + v, ok := sentinelCache.Load(accountID) + if !ok { + return "", false + } + entry := v.(*sentinelCacheEntry) + if time.Now().After(entry.expiresAt) { + sentinelCache.Delete(accountID) + return "", false + } + return entry.token, true +} + +// cacheSentinel 缓存 Sentinel Token +func cacheSentinel(accountID int64, token string) { + sentinelCache.Store(accountID, &sentinelCacheEntry{ + token: token, + expiresAt: time.Now().Add(3 * time.Minute), // 3分钟有效期 + }) +} + +// NewClient creates a Sora client. +func NewClient(baseURL string, timeout time.Duration, upstream UpstreamClient, enableTLSFingerprint bool) *Client { + return &Client{ + baseURL: strings.TrimRight(baseURL, "/"), + timeout: timeout, + upstream: upstream, + enableTLSFingerprint: enableTLSFingerprint, + } +} + +// UploadImage uploads an image and returns media ID. 
+func (c *Client) UploadImage(ctx context.Context, opts RequestOptions, data []byte, filename string) (string, error) { + if filename == "" { + filename = "image.png" + } + var buf bytes.Buffer + writer := multipart.NewWriter(&buf) + part, err := writer.CreateFormFile("file", filename) + if err != nil { + return "", err + } + if _, err := part.Write(data); err != nil { + return "", err + } + if err := writer.WriteField("file_name", filename); err != nil { + return "", err + } + if err := writer.Close(); err != nil { + return "", err + } + resp, err := c.doRequest(ctx, "POST", "/uploads", opts, &buf, writer.FormDataContentType(), false) + if err != nil { + return "", err + } + return stringFromJSON(resp, "id"), nil +} + +// GenerateImage creates an image generation task. +func (c *Client) GenerateImage(ctx context.Context, opts RequestOptions, prompt string, width, height int, mediaID string) (string, error) { + operation := "simple_compose" + var inpaint []map[string]any + if mediaID != "" { + operation = "remix" + inpaint = []map[string]any{ + { + "type": "image", + "frame_index": 0, + "upload_media_id": mediaID, + }, + } + } + payload := map[string]any{ + "type": "image_gen", + "operation": operation, + "prompt": prompt, + "width": width, + "height": height, + "n_variants": 1, + "n_frames": 1, + "inpaint_items": inpaint, + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + resp, err := c.doRequest(ctx, "POST", "/video_gen", opts, bytes.NewReader(body), "application/json", true) + if err != nil { + return "", err + } + return stringFromJSON(resp, "id"), nil +} + +// GenerateVideo creates a video generation task. 
+func (c *Client) GenerateVideo(ctx context.Context, opts RequestOptions, prompt, orientation string, nFrames int, mediaID, styleID, model, size string) (string, error) { + var inpaint []map[string]any + if mediaID != "" { + inpaint = []map[string]any{{"kind": "upload", "upload_id": mediaID}} + } + payload := map[string]any{ + "kind": "video", + "prompt": prompt, + "orientation": orientation, + "size": size, + "n_frames": nFrames, + "model": model, + "inpaint_items": inpaint, + "style_id": styleID, + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + resp, err := c.doRequest(ctx, "POST", "/nf/create", opts, bytes.NewReader(body), "application/json", true) + if err != nil { + return "", err + } + return stringFromJSON(resp, "id"), nil +} + +// GenerateStoryboard creates a storyboard video task. +func (c *Client) GenerateStoryboard(ctx context.Context, opts RequestOptions, prompt, orientation string, nFrames int, mediaID, styleID string) (string, error) { + var inpaint []map[string]any + if mediaID != "" { + inpaint = []map[string]any{{"kind": "upload", "upload_id": mediaID}} + } + payload := map[string]any{ + "kind": "video", + "prompt": prompt, + "title": "Draft your video", + "orientation": orientation, + "size": "small", + "n_frames": nFrames, + "storyboard_id": nil, + "inpaint_items": inpaint, + "remix_target_id": nil, + "model": "sy_8", + "metadata": nil, + "style_id": styleID, + "cameo_ids": nil, + "cameo_replacements": nil, + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + resp, err := c.doRequest(ctx, "POST", "/nf/create/storyboard", opts, bytes.NewReader(body), "application/json", true) + if err != nil { + return "", err + } + return stringFromJSON(resp, "id"), nil +} + +// RemixVideo creates a remix task. 
+func (c *Client) RemixVideo(ctx context.Context, opts RequestOptions, remixTargetID, prompt, orientation string, nFrames int, styleID string) (string, error) { + payload := map[string]any{ + "kind": "video", + "prompt": prompt, + "inpaint_items": []map[string]any{}, + "remix_target_id": remixTargetID, + "cameo_ids": []string{}, + "cameo_replacements": map[string]any{}, + "model": "sy_8", + "orientation": orientation, + "n_frames": nFrames, + "style_id": styleID, + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + resp, err := c.doRequest(ctx, "POST", "/nf/create", opts, bytes.NewReader(body), "application/json", true) + if err != nil { + return "", err + } + return stringFromJSON(resp, "id"), nil +} + +// GetImageTasks returns recent image tasks. +func (c *Client) GetImageTasks(ctx context.Context, opts RequestOptions) (map[string]any, error) { + return c.doRequest(ctx, "GET", "/v2/recent_tasks?limit=20", opts, nil, "", false) +} + +// GetPendingTasks returns pending video tasks. +func (c *Client) GetPendingTasks(ctx context.Context, opts RequestOptions) ([]map[string]any, error) { + resp, err := c.doRequestAny(ctx, "GET", "/nf/pending/v2", opts, nil, "", false) + if err != nil { + return nil, err + } + switch v := resp.(type) { + case []any: + return convertList(v), nil + case map[string]any: + if list, ok := v["items"].([]any); ok { + return convertList(list), nil + } + if arr, ok := v["data"].([]any); ok { + return convertList(arr), nil + } + return convertListFromAny(v), nil + default: + return nil, nil + } +} + +// GetVideoDrafts returns recent video drafts. +func (c *Client) GetVideoDrafts(ctx context.Context, opts RequestOptions) (map[string]any, error) { + return c.doRequest(ctx, "GET", "/project_y/profile/drafts?limit=15", opts, nil, "", false) +} + +// EnhancePrompt calls prompt enhancement API. 
+func (c *Client) EnhancePrompt(ctx context.Context, opts RequestOptions, prompt, expansionLevel string, durationS int) (string, error) { + payload := map[string]any{ + "prompt": prompt, + "expansion_level": expansionLevel, + "duration_s": durationS, + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + resp, err := c.doRequest(ctx, "POST", "/editor/enhance_prompt", opts, bytes.NewReader(body), "application/json", false) + if err != nil { + return "", err + } + return stringFromJSON(resp, "enhanced_prompt"), nil +} + +// PostVideoForWatermarkFree publishes a video for watermark-free parsing. +func (c *Client) PostVideoForWatermarkFree(ctx context.Context, opts RequestOptions, generationID string) (string, error) { + payload := map[string]any{ + "attachments_to_create": []map[string]any{{ + "generation_id": generationID, + "kind": "sora", + }}, + "post_text": "", + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + resp, err := c.doRequest(ctx, "POST", "/project_y/post", opts, bytes.NewReader(body), "application/json", true) + if err != nil { + return "", err + } + post, _ := resp["post"].(map[string]any) + if post == nil { + return "", nil + } + return stringFromJSON(post, "id"), nil +} + +// DeletePost deletes a Sora post. 
+func (c *Client) DeletePost(ctx context.Context, opts RequestOptions, postID string) error { + if postID == "" { + return nil + } + _, err := c.doRequest(ctx, "DELETE", "/project_y/post/"+postID, opts, nil, "", false) + return err +} + +func (c *Client) doRequest(ctx context.Context, method, endpoint string, opts RequestOptions, body io.Reader, contentType string, addSentinel bool) (map[string]any, error) { + resp, err := c.doRequestAny(ctx, method, endpoint, opts, body, contentType, addSentinel) + if err != nil { + return nil, err + } + parsed, ok := resp.(map[string]any) + if !ok { + return nil, errors.New("unexpected response format") + } + return parsed, nil +} + +func (c *Client) doRequestAny(ctx context.Context, method, endpoint string, opts RequestOptions, body io.Reader, contentType string, addSentinel bool) (any, error) { + if c.upstream == nil { + return nil, errors.New("upstream is nil") + } + url := c.baseURL + endpoint + req, err := http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + return nil, err + } + if contentType != "" { + req.Header.Set("Content-Type", contentType) + } + if opts.AccessToken != "" { + req.Header.Set("Authorization", "Bearer "+opts.AccessToken) + } + req.Header.Set("User-Agent", defaultMobileUA) + if addSentinel { + sentinel, err := c.generateSentinelToken(ctx, opts) + if err != nil { + return nil, err + } + req.Header.Set("openai-sentinel-token", sentinel) + } + resp, err := c.upstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, opts.AccountConcurrency, c.enableTLSFingerprint) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + // 使用 LimitReader 限制最大响应大小,防止 DoS 攻击 + limitedReader := io.LimitReader(resp.Body, maxAPIResponseSize+1) + data, err := io.ReadAll(limitedReader) + if err != nil { + return nil, err + } + + // 检查是否超过大小限制 + if int64(len(data)) > maxAPIResponseSize { + return nil, fmt.Errorf("API 响应过大 (最大 %d 字节)", maxAPIResponseSize) + } + + if resp.StatusCode < 200 || resp.StatusCode 
>= 300 { + return nil, fmt.Errorf("sora api error: %d %s", resp.StatusCode, strings.TrimSpace(string(data))) + } + if len(data) == 0 { + return map[string]any{}, nil + } + var parsed any + if err := json.Unmarshal(data, &parsed); err != nil { + return nil, err + } + return parsed, nil +} + +func (c *Client) generateSentinelToken(ctx context.Context, opts RequestOptions) (string, error) { + // 尝试从缓存获取 + if token, ok := getCachedSentinel(opts.AccountID); ok { + return token, nil + } + + reqID := uuid.New().String() + powToken, err := generatePowToken(defaultDesktopUA) + if err != nil { + return "", err + } + payload := map[string]any{"p": powToken, "flow": sentinelFlow, "id": reqID} + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + url := chatGPTBaseURL + "/backend-api/sentinel/req" + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + req.Header.Set("User-Agent", defaultDesktopUA) + if opts.AccessToken != "" { + req.Header.Set("Authorization", "Bearer "+opts.AccessToken) + } + resp, err := c.upstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, opts.AccountConcurrency, c.enableTLSFingerprint) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // 使用 LimitReader 限制最大响应大小,防止 DoS 攻击 + limitedReader := io.LimitReader(resp.Body, maxAPIResponseSize+1) + data, err := io.ReadAll(limitedReader) + if err != nil { + return "", err + } + + // 检查是否超过大小限制 + if int64(len(data)) > maxAPIResponseSize { + return "", fmt.Errorf("API 响应过大 (最大 %d 字节)", maxAPIResponseSize) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return "", fmt.Errorf("sentinel request failed: %d %s", resp.StatusCode, strings.TrimSpace(string(data))) + } + var parsed map[string]any + if err := json.Unmarshal(data, &parsed); 
err != nil { + return "", err + } + token := buildSentinelToken(reqID, powToken, parsed) + + // 缓存结果 + cacheSentinel(opts.AccountID, token) + + return token, nil +} + +func buildSentinelToken(reqID, powToken string, resp map[string]any) string { + finalPow := powToken + pow, _ := resp["proofofwork"].(map[string]any) + if pow != nil { + required, _ := pow["required"].(bool) + if required { + seed, _ := pow["seed"].(string) + difficulty, _ := pow["difficulty"].(string) + if seed != "" && difficulty != "" { + candidate, _ := solvePow(seed, difficulty, defaultDesktopUA) + if candidate != "" { + finalPow = "gAAAAAB" + candidate + } + } + } + } + if !strings.HasSuffix(finalPow, "~S") { + finalPow += "~S" + } + turnstile := "" + if t, ok := resp["turnstile"].(map[string]any); ok { + turnstile, _ = t["dx"].(string) + } + token := "" + if v, ok := resp["token"].(string); ok { + token = v + } + payload := map[string]any{ + "p": finalPow, + "t": turnstile, + "c": token, + "id": reqID, + "flow": sentinelFlow, + } + data, _ := json.Marshal(payload) + return string(data) +} + +func generatePowToken(userAgent string) (string, error) { + seed := fmt.Sprintf("%f", float64(time.Now().UnixNano())/1e9) + candidate, _ := solvePow(seed, "0fffff", userAgent) + if candidate == "" { + return "", errors.New("pow generation failed") + } + return "gAAAAAC" + candidate, nil +} + +func solvePow(seed, difficulty, userAgent string) (string, bool) { + config := powConfig(userAgent) + seedBytes := []byte(seed) + diffBytes, err := hexDecode(difficulty) + if err != nil { + return "", false + } + configBytes, err := json.Marshal(config) + if err != nil { + return "", false + } + prefix := configBytes[:len(configBytes)-1] + for i := 0; i < 500000; i++ { + payload := append(prefix, []byte(fmt.Sprintf(",%d,%d]", i, i>>1))...) 
+ b64 := base64.StdEncoding.EncodeToString(payload) + h := sha3.Sum512(append(seedBytes, []byte(b64)...)) + if bytes.Compare(h[:len(diffBytes)], diffBytes) <= 0 { + return b64, true + } + } + return "", false +} + +func powConfig(userAgent string) []any { + return []any{ + 3000, + formatPowTime(), + 4294705152, + 0, + userAgent, + "", + nil, + "en-US", + "en-US,es-US,en,es", + 0, + "webdriver-false", + "location", + "window", + time.Now().UnixMilli(), + uuid.New().String(), + "", + 16, + float64(time.Now().UnixMilli()), + } +} + +func formatPowTime() string { + loc := time.FixedZone("EST", -5*60*60) + return time.Now().In(loc).Format("Mon Jan 02 2006 15:04:05") + " GMT-0500 (Eastern Standard Time)" +} + +func hexDecode(s string) ([]byte, error) { + if len(s)%2 != 0 { + return nil, errors.New("invalid hex length") + } + out := make([]byte, len(s)/2) + for i := 0; i < len(out); i++ { + byteVal, err := hexPair(s[i*2 : i*2+2]) + if err != nil { + return nil, err + } + out[i] = byteVal + } + return out, nil +} + +func hexPair(pair string) (byte, error) { + var v byte + for i := 0; i < 2; i++ { + c := pair[i] + var n byte + switch { + case c >= '0' && c <= '9': + n = c - '0' + case c >= 'a' && c <= 'f': + n = c - 'a' + 10 + case c >= 'A' && c <= 'F': + n = c - 'A' + 10 + default: + return 0, errors.New("invalid hex") + } + v = v<<4 | n + } + return v, nil +} + +func stringFromJSON(data map[string]any, key string) string { + if data == nil { + return "" + } + if v, ok := data[key].(string); ok { + return v + } + return "" +} + +func convertList(list []any) []map[string]any { + results := make([]map[string]any, 0, len(list)) + for _, item := range list { + if m, ok := item.(map[string]any); ok { + results = append(results, m) + } + } + return results +} + +func convertListFromAny(data map[string]any) []map[string]any { + if data == nil { + return nil + } + items, ok := data["items"].([]any) + if ok { + return convertList(items) + } + return nil +} diff --git 
// ModelConfig describes one Sora model preset: the media type plus the
// type-specific knobs (image dimensions, video orientation/length, upstream
// model/size selectors, prompt-enhance parameters).
type ModelConfig struct {
	Type           string // "image", "video" or "prompt_enhance"
	Width          int    // image width in pixels (image only)
	Height         int    // image height in pixels (image only)
	Orientation    string // "landscape" or "portrait" (video only)
	NFrames        int    // frame count at 30fps (video only)
	Model          string // upstream model selector, e.g. "sy_8" / "sy_ore"
	Size           string // upstream render size, e.g. "small" / "large"
	RequirePro     bool   // account must be on a Pro plan
	ExpansionLevel string // "short" / "medium" / "long" (prompt_enhance only)
	DurationS      int    // target duration in seconds (prompt_enhance only)
}

// ModelConfigs maps every externally exposed model id to its preset.
var ModelConfigs = map[string]ModelConfig{
	"gpt-image": {
		Type:   "image",
		Width:  360,
		Height: 360,
	},
	"gpt-image-landscape": {
		Type:   "image",
		Width:  540,
		Height: 360,
	},
	"gpt-image-portrait": {
		Type:   "image",
		Width:  360,
		Height: 540,
	},
	"sora2-landscape-10s": {
		Type:        "video",
		Orientation: "landscape",
		NFrames:     300,
	},
	"sora2-portrait-10s": {
		Type:        "video",
		Orientation: "portrait",
		NFrames:     300,
	},
	"sora2-landscape-15s": {
		Type:        "video",
		Orientation: "landscape",
		NFrames:     450,
	},
	"sora2-portrait-15s": {
		Type:        "video",
		Orientation: "portrait",
		NFrames:     450,
	},
	"sora2-landscape-25s": {
		Type:        "video",
		Orientation: "landscape",
		NFrames:     750,
		Model:       "sy_8",
		Size:        "small",
		RequirePro:  true,
	},
	"sora2-portrait-25s": {
		Type:        "video",
		Orientation: "portrait",
		NFrames:     750,
		Model:       "sy_8",
		Size:        "small",
		RequirePro:  true,
	},
	"sora2pro-landscape-10s": {
		Type:        "video",
		Orientation: "landscape",
		NFrames:     300,
		Model:       "sy_ore",
		Size:        "small",
		RequirePro:  true,
	},
	"sora2pro-portrait-10s": {
		Type:        "video",
		Orientation: "portrait",
		NFrames:     300,
		Model:       "sy_ore",
		Size:        "small",
		RequirePro:  true,
	},
	"sora2pro-landscape-15s": {
		Type:        "video",
		Orientation: "landscape",
		NFrames:     450,
		Model:       "sy_ore",
		Size:        "small",
		RequirePro:  true,
	},
	"sora2pro-portrait-15s": {
		Type:        "video",
		Orientation: "portrait",
		NFrames:     450,
		Model:       "sy_ore",
		Size:        "small",
		RequirePro:  true,
	},
	"sora2pro-landscape-25s": {
		Type:        "video",
		Orientation: "landscape",
		NFrames:     750,
		Model:       "sy_ore",
		Size:        "small",
		RequirePro:  true,
	},
	"sora2pro-portrait-25s": {
		Type:        "video",
		Orientation: "portrait",
		NFrames:     750,
		Model:       "sy_ore",
		Size:        "small",
		RequirePro:  true,
	},
	"sora2pro-hd-landscape-10s": {
		Type:        "video",
		Orientation: "landscape",
		NFrames:     300,
		Model:       "sy_ore",
		Size:        "large",
		RequirePro:  true,
	},
	"sora2pro-hd-portrait-10s": {
		Type:        "video",
		Orientation: "portrait",
		NFrames:     300,
		Model:       "sy_ore",
		Size:        "large",
		RequirePro:  true,
	},
	"sora2pro-hd-landscape-15s": {
		Type:        "video",
		Orientation: "landscape",
		NFrames:     450,
		Model:       "sy_ore",
		Size:        "large",
		RequirePro:  true,
	},
	"sora2pro-hd-portrait-15s": {
		Type:        "video",
		Orientation: "portrait",
		NFrames:     450,
		Model:       "sy_ore",
		Size:        "large",
		RequirePro:  true,
	},
	"prompt-enhance-short-10s": {
		Type:           "prompt_enhance",
		ExpansionLevel: "short",
		DurationS:      10,
	},
	"prompt-enhance-short-15s": {
		Type:           "prompt_enhance",
		ExpansionLevel: "short",
		DurationS:      15,
	},
	"prompt-enhance-short-20s": {
		Type:           "prompt_enhance",
		ExpansionLevel: "short",
		DurationS:      20,
	},
	"prompt-enhance-medium-10s": {
		Type:           "prompt_enhance",
		ExpansionLevel: "medium",
		DurationS:      10,
	},
	"prompt-enhance-medium-15s": {
		Type:           "prompt_enhance",
		ExpansionLevel: "medium",
		DurationS:      15,
	},
	"prompt-enhance-medium-20s": {
		Type:           "prompt_enhance",
		ExpansionLevel: "medium",
		DurationS:      20,
	},
	"prompt-enhance-long-10s": {
		Type:           "prompt_enhance",
		ExpansionLevel: "long",
		DurationS:      10,
	},
	"prompt-enhance-long-15s": {
		Type:           "prompt_enhance",
		ExpansionLevel: "long",
		DurationS:      15,
	},
	"prompt-enhance-long-20s": {
		Type:           "prompt_enhance",
		ExpansionLevel: "long",
		DurationS:      20,
	},
}

// ModelListItem is one entry in the OpenAI-style model listing.
type ModelListItem struct {
	ID          string `json:"id"`
	Object      string `json:"object"`
	OwnedBy     string `json:"owned_by"`
	Description string `json:"description"`
}

// ListModels renders ModelConfigs as an OpenAI-style model list.
// Order is unspecified (Go map iteration); callers must not rely on it.
func ListModels() []ModelListItem {
	models := make([]ModelListItem, 0, len(ModelConfigs))
	for id, cfg := range ModelConfigs {
		description := ""
		switch cfg.Type {
		case "image":
			description = "Image generation"
			if cfg.Width > 0 && cfg.Height > 0 {
				description += " - " + itoa(cfg.Width) + "x" + itoa(cfg.Height)
			}
		case "video":
			description = "Video generation"
			if cfg.Orientation != "" {
				description += " - " + cfg.Orientation
			}
		case "prompt_enhance":
			description = "Prompt enhancement"
			if cfg.ExpansionLevel != "" {
				description += " - " + cfg.ExpansionLevel
			}
			if cfg.DurationS > 0 {
				description += " (" + itoa(cfg.DurationS) + "s)"
			}
		default:
			description = "Sora model"
		}
		models = append(models, ModelListItem{
			ID:          id,
			Object:      "model",
			OwnedBy:     "sora",
			Description: description,
		})
	}
	return models
}

// itoa converts an int to its decimal string (this package avoids strconv to
// stay dependency-free). Computes the magnitude in unsigned space so the
// minimum int value is handled correctly — the previous `val = -val`
// negation overflowed for MinInt and produced "-" with no digits.
func itoa(val int) string {
	if val == 0 {
		return "0"
	}
	neg := val < 0
	u := uint(val)
	if neg {
		u = -u // two's complement: correct magnitude even for MinInt
	}
	var buf [21]byte // max 20 digits for a 64-bit uint plus sign
	i := len(buf)
	for u > 0 {
		i--
		buf[i] = byte('0' + u%10)
		u /= 10
	}
	if neg {
		i--
		buf[i] = '-'
	}
	return string(buf[i:])
}

// storyboardRe matches a storyboard timestamp tag such as "[2.5s]".
var storyboardRe = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]`)

// shotRe captures one storyboard shot: the duration tag plus the scene text
// up to the next tag. Hoisted to package level so it is compiled once instead
// of on every FormatStoryboardPrompt call.
var shotRe = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`)

// remixIDRe matches a Sora share/remix id ("s_" followed by 32 lowercase hex
// chars). Hoisted for the same reason as shotRe.
var remixIDRe = regexp.MustCompile(`s_[a-f0-9]{32}`)

// IsStoryboardPrompt reports whether the prompt contains at least one
// "[Ns]" storyboard timestamp tag.
func IsStoryboardPrompt(prompt string) bool {
	if strings.TrimSpace(prompt) == "" {
		return false
	}
	return storyboardRe.MatchString(prompt)
}

// FormatStoryboardPrompt converts a "[1s] scene one [2s] scene two" style
// prompt into the timeline format the API expects. Text before the first tag
// becomes an "instructions" section. Prompts without tags pass through
// unchanged.
func FormatStoryboardPrompt(prompt string) string {
	prompt = strings.TrimSpace(prompt)
	if prompt == "" {
		return prompt
	}
	matches := storyboardRe.FindAllStringSubmatchIndex(prompt, -1)
	if len(matches) == 0 {
		return prompt
	}
	firstIdx := matches[0][0]
	instructions := strings.TrimSpace(prompt[:firstIdx])

	shotMatches := shotRe.FindAllStringSubmatch(prompt, -1)
	if len(shotMatches) == 0 {
		return prompt
	}

	shots := make([]string, 0, len(shotMatches))
	for i, sm := range shotMatches {
		if len(sm) < 3 {
			continue
		}
		duration := strings.TrimSpace(sm[1])
		scene := strings.TrimSpace(sm[2])
		shots = append(shots, "Shot "+itoa(i+1)+":\nduration: "+duration+"sec\nScene: "+scene)
	}

	timeline := strings.Join(shots, "\n\n")
	if instructions != "" {
		return "current timeline:\n" + timeline + "\n\ninstructions:\n" + instructions
	}
	return timeline
}

// ExtractRemixID pulls the remix id ("s_" + 32 hex chars) out of a share
// link or arbitrary text; returns "" when none is present.
func ExtractRemixID(text string) string {
	text = strings.TrimSpace(text)
	if text == "" {
		return ""
	}
	return remixIDRe.FindString(text)
}
// New generates a UUIDv7 string: a 48-bit big-endian Unix-millisecond
// timestamp in the first six bytes, version nibble 7, the RFC variant bits,
// and the remaining 74 bits random.
func New() (string, error) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	// Overwrite the leading six bytes with the millisecond timestamp.
	now := uint64(time.Now().UnixMilli())
	for i := 0; i < 6; i++ {
		raw[i] = byte(now >> (uint(5-i) * 8))
	}
	raw[6] = raw[6]&0x0f | 0x70 // version 7
	raw[8] = raw[8]&0x3f | 0x80 // RFC 4122 variant (10xx)
	// %x on a byte slice emits two lowercase hex digits per byte, giving the
	// canonical 8-4-4-4-12 grouping.
	return fmt.Sprintf("%x-%x-%x-%x-%x",
		raw[0:4], raw[4:6], raw[6:8], raw[8:10], raw[10:16]), nil
}
mapSoraAccount(acc), nil +} + +func (r *soraAccountRepository) GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*service.SoraAccount, error) { + if len(accountIDs) == 0 { + return map[int64]*service.SoraAccount{}, nil + } + records, err := r.client.SoraAccount.Query().Where(dbsoraaccount.AccountIDIn(accountIDs...)).All(ctx) + if err != nil { + return nil, err + } + result := make(map[int64]*service.SoraAccount, len(records)) + for _, record := range records { + if record == nil { + continue + } + result[record.AccountID] = mapSoraAccount(record) + } + return result, nil +} + +func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error { + if accountID <= 0 { + return errors.New("invalid account_id") + } + acc, err := r.client.SoraAccount.Query().Where(dbsoraaccount.AccountIDEQ(accountID)).Only(ctx) + if err != nil && !ent.IsNotFound(err) { + return err + } + if acc == nil { + builder := r.client.SoraAccount.Create().SetAccountID(accountID) + applySoraAccountUpdates(builder.Mutation(), updates) + return builder.Exec(ctx) + } + updater := r.client.SoraAccount.UpdateOneID(acc.ID) + applySoraAccountUpdates(updater.Mutation(), updates) + return updater.Exec(ctx) +} + +func applySoraAccountUpdates(m *ent.SoraAccountMutation, updates map[string]any) { + if updates == nil { + return + } + for key, val := range updates { + switch key { + case "access_token": + if v, ok := val.(string); ok { + m.SetAccessToken(v) + } + case "session_token": + if v, ok := val.(string); ok { + m.SetSessionToken(v) + } + case "refresh_token": + if v, ok := val.(string); ok { + m.SetRefreshToken(v) + } + case "client_id": + if v, ok := val.(string); ok { + m.SetClientID(v) + } + case "email": + if v, ok := val.(string); ok { + m.SetEmail(v) + } + case "username": + if v, ok := val.(string); ok { + m.SetUsername(v) + } + case "remark": + if v, ok := val.(string); ok { + m.SetRemark(v) + } + case "plan_type": + if v, ok := 
val.(string); ok { + m.SetPlanType(v) + } + case "plan_title": + if v, ok := val.(string); ok { + m.SetPlanTitle(v) + } + case "subscription_end": + if v, ok := val.(time.Time); ok { + m.SetSubscriptionEnd(v) + } + if v, ok := val.(*time.Time); ok && v != nil { + m.SetSubscriptionEnd(*v) + } + case "sora_supported": + if v, ok := val.(bool); ok { + m.SetSoraSupported(v) + } + case "sora_invite_code": + if v, ok := val.(string); ok { + m.SetSoraInviteCode(v) + } + case "sora_redeemed_count": + if v, ok := val.(int); ok { + m.SetSoraRedeemedCount(v) + } + case "sora_remaining_count": + if v, ok := val.(int); ok { + m.SetSoraRemainingCount(v) + } + case "sora_total_count": + if v, ok := val.(int); ok { + m.SetSoraTotalCount(v) + } + case "sora_cooldown_until": + if v, ok := val.(time.Time); ok { + m.SetSoraCooldownUntil(v) + } + if v, ok := val.(*time.Time); ok && v != nil { + m.SetSoraCooldownUntil(*v) + } + case "cooled_until": + if v, ok := val.(time.Time); ok { + m.SetCooledUntil(v) + } + if v, ok := val.(*time.Time); ok && v != nil { + m.SetCooledUntil(*v) + } + case "image_enabled": + if v, ok := val.(bool); ok { + m.SetImageEnabled(v) + } + case "video_enabled": + if v, ok := val.(bool); ok { + m.SetVideoEnabled(v) + } + case "image_concurrency": + if v, ok := val.(int); ok { + m.SetImageConcurrency(v) + } + case "video_concurrency": + if v, ok := val.(int); ok { + m.SetVideoConcurrency(v) + } + case "is_expired": + if v, ok := val.(bool); ok { + m.SetIsExpired(v) + } + } + } +} + +func mapSoraAccount(acc *ent.SoraAccount) *service.SoraAccount { + if acc == nil { + return nil + } + return &service.SoraAccount{ + AccountID: acc.AccountID, + AccessToken: derefString(acc.AccessToken), + SessionToken: derefString(acc.SessionToken), + RefreshToken: derefString(acc.RefreshToken), + ClientID: derefString(acc.ClientID), + Email: derefString(acc.Email), + Username: derefString(acc.Username), + Remark: derefString(acc.Remark), + UseCount: acc.UseCount, + PlanType: 
derefString(acc.PlanType), + PlanTitle: derefString(acc.PlanTitle), + SubscriptionEnd: acc.SubscriptionEnd, + SoraSupported: acc.SoraSupported, + SoraInviteCode: derefString(acc.SoraInviteCode), + SoraRedeemedCount: acc.SoraRedeemedCount, + SoraRemainingCount: acc.SoraRemainingCount, + SoraTotalCount: acc.SoraTotalCount, + SoraCooldownUntil: acc.SoraCooldownUntil, + CooledUntil: acc.CooledUntil, + ImageEnabled: acc.ImageEnabled, + VideoEnabled: acc.VideoEnabled, + ImageConcurrency: acc.ImageConcurrency, + VideoConcurrency: acc.VideoConcurrency, + IsExpired: acc.IsExpired, + CreatedAt: acc.CreatedAt, + UpdatedAt: acc.UpdatedAt, + } +} + +func mapSoraUsageStat(stat *ent.SoraUsageStat) *service.SoraUsageStat { + if stat == nil { + return nil + } + return &service.SoraUsageStat{ + AccountID: stat.AccountID, + ImageCount: stat.ImageCount, + VideoCount: stat.VideoCount, + ErrorCount: stat.ErrorCount, + LastErrorAt: stat.LastErrorAt, + TodayImageCount: stat.TodayImageCount, + TodayVideoCount: stat.TodayVideoCount, + TodayErrorCount: stat.TodayErrorCount, + TodayDate: stat.TodayDate, + ConsecutiveErrorCount: stat.ConsecutiveErrorCount, + CreatedAt: stat.CreatedAt, + UpdatedAt: stat.UpdatedAt, + } +} + +func mapSoraCacheFile(file *ent.SoraCacheFile) *service.SoraCacheFile { + if file == nil { + return nil + } + return &service.SoraCacheFile{ + ID: int64(file.ID), + TaskID: derefString(file.TaskID), + AccountID: file.AccountID, + UserID: file.UserID, + MediaType: file.MediaType, + OriginalURL: file.OriginalURL, + CachePath: file.CachePath, + CacheURL: file.CacheURL, + SizeBytes: file.SizeBytes, + CreatedAt: file.CreatedAt, + } +} + +// SoraUsageStat + +type soraUsageStatRepository struct { + client *ent.Client + sql sqlExecutor +} + +func NewSoraUsageStatRepository(client *ent.Client, sqlDB *sql.DB) service.SoraUsageStatRepository { + return &soraUsageStatRepository{client: client, sql: sqlDB} +} + +func (r *soraUsageStatRepository) RecordSuccess(ctx context.Context, 
accountID int64, isVideo bool) error { + if accountID <= 0 { + return nil + } + field := "image_count" + todayField := "today_image_count" + if isVideo { + field = "video_count" + todayField = "today_video_count" + } + today := time.Now().UTC().Truncate(24 * time.Hour) + query := "INSERT INTO sora_usage_stats (account_id, " + field + ", " + todayField + ", today_date, consecutive_error_count, created_at, updated_at) " + + "VALUES ($1, 1, 1, $2, 0, NOW(), NOW()) " + + "ON CONFLICT (account_id) DO UPDATE SET " + + field + " = sora_usage_stats." + field + " + 1, " + + todayField + " = CASE WHEN sora_usage_stats.today_date = $2 THEN sora_usage_stats." + todayField + " + 1 ELSE 1 END, " + + "today_date = $2, consecutive_error_count = 0, updated_at = NOW()" + _, err := r.sql.ExecContext(ctx, query, accountID, today) + return err +} + +func (r *soraUsageStatRepository) RecordError(ctx context.Context, accountID int64) (int, error) { + if accountID <= 0 { + return 0, nil + } + today := time.Now().UTC().Truncate(24 * time.Hour) + query := "INSERT INTO sora_usage_stats (account_id, error_count, today_error_count, today_date, consecutive_error_count, last_error_at, created_at, updated_at) " + + "VALUES ($1, 1, 1, $2, 1, NOW(), NOW(), NOW()) " + + "ON CONFLICT (account_id) DO UPDATE SET " + + "error_count = sora_usage_stats.error_count + 1, " + + "today_error_count = CASE WHEN sora_usage_stats.today_date = $2 THEN sora_usage_stats.today_error_count + 1 ELSE 1 END, " + + "today_date = $2, consecutive_error_count = sora_usage_stats.consecutive_error_count + 1, last_error_at = NOW(), updated_at = NOW() " + + "RETURNING consecutive_error_count" + var consecutive int + err := scanSingleRow(ctx, r.sql, query, []any{accountID, today}, &consecutive) + if err != nil { + return 0, err + } + return consecutive, nil +} + +func (r *soraUsageStatRepository) ResetConsecutiveErrors(ctx context.Context, accountID int64) error { + if accountID <= 0 { + return nil + } + err := 
r.client.SoraUsageStat.Update().Where(dbsorausagestat.AccountIDEQ(accountID)). + SetConsecutiveErrorCount(0). + Exec(ctx) + if ent.IsNotFound(err) { + return nil + } + return err +} + +func (r *soraUsageStatRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraUsageStat, error) { + if accountID <= 0 { + return nil, nil + } + stat, err := r.client.SoraUsageStat.Query().Where(dbsorausagestat.AccountIDEQ(accountID)).Only(ctx) + if err != nil { + if ent.IsNotFound(err) { + return nil, nil + } + return nil, err + } + return mapSoraUsageStat(stat), nil +} + +func (r *soraUsageStatRepository) GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*service.SoraUsageStat, error) { + if len(accountIDs) == 0 { + return map[int64]*service.SoraUsageStat{}, nil + } + stats, err := r.client.SoraUsageStat.Query().Where(dbsorausagestat.AccountIDIn(accountIDs...)).All(ctx) + if err != nil { + return nil, err + } + result := make(map[int64]*service.SoraUsageStat, len(stats)) + for _, stat := range stats { + if stat == nil { + continue + } + result[stat.AccountID] = mapSoraUsageStat(stat) + } + return result, nil +} + +func (r *soraUsageStatRepository) List(ctx context.Context, params pagination.PaginationParams) ([]*service.SoraUsageStat, *pagination.PaginationResult, error) { + query := r.client.SoraUsageStat.Query() + total, err := query.Count(ctx) + if err != nil { + return nil, nil, err + } + stats, err := query.Order(ent.Desc(dbsorausagestat.FieldUpdatedAt)). + Limit(params.Limit()). + Offset(params.Offset()). 
+ All(ctx) + if err != nil { + return nil, nil, err + } + result := make([]*service.SoraUsageStat, 0, len(stats)) + for _, stat := range stats { + result = append(result, mapSoraUsageStat(stat)) + } + return result, paginationResultFromTotal(int64(total), params), nil +} + +// SoraTask + +type soraTaskRepository struct { + client *ent.Client +} + +func NewSoraTaskRepository(client *ent.Client) service.SoraTaskRepository { + return &soraTaskRepository{client: client} +} + +func (r *soraTaskRepository) Create(ctx context.Context, task *service.SoraTask) error { + if task == nil { + return nil + } + builder := r.client.SoraTask.Create(). + SetTaskID(task.TaskID). + SetAccountID(task.AccountID). + SetModel(task.Model). + SetPrompt(task.Prompt). + SetStatus(task.Status). + SetProgress(task.Progress). + SetRetryCount(task.RetryCount) + if task.ResultURLs != "" { + builder.SetResultUrls(task.ResultURLs) + } + if task.ErrorMessage != "" { + builder.SetErrorMessage(task.ErrorMessage) + } + if task.CreatedAt.IsZero() { + builder.SetCreatedAt(time.Now()) + } else { + builder.SetCreatedAt(task.CreatedAt) + } + if task.CompletedAt != nil { + builder.SetCompletedAt(*task.CompletedAt) + } + return builder.Exec(ctx) +} + +func (r *soraTaskRepository) UpdateStatus(ctx context.Context, taskID string, status string, progress float64, resultURLs string, errorMessage string, completedAt *time.Time) error { + if taskID == "" { + return nil + } + builder := r.client.SoraTask.Update().Where(dbsoratask.TaskIDEQ(taskID)). + SetStatus(status). 
+ SetProgress(progress) + if resultURLs != "" { + builder.SetResultUrls(resultURLs) + } + if errorMessage != "" { + builder.SetErrorMessage(errorMessage) + } + if completedAt != nil { + builder.SetCompletedAt(*completedAt) + } + _, err := builder.Save(ctx) + if ent.IsNotFound(err) { + return nil + } + return err +} + +// SoraCacheFile + +type soraCacheFileRepository struct { + client *ent.Client +} + +func NewSoraCacheFileRepository(client *ent.Client) service.SoraCacheFileRepository { + return &soraCacheFileRepository{client: client} +} + +func (r *soraCacheFileRepository) Create(ctx context.Context, file *service.SoraCacheFile) error { + if file == nil { + return nil + } + builder := r.client.SoraCacheFile.Create(). + SetAccountID(file.AccountID). + SetUserID(file.UserID). + SetMediaType(file.MediaType). + SetOriginalURL(file.OriginalURL). + SetCachePath(file.CachePath). + SetCacheURL(file.CacheURL). + SetSizeBytes(file.SizeBytes) + if file.TaskID != "" { + builder.SetTaskID(file.TaskID) + } + if file.CreatedAt.IsZero() { + builder.SetCreatedAt(time.Now()) + } else { + builder.SetCreatedAt(file.CreatedAt) + } + return builder.Exec(ctx) +} + +func (r *soraCacheFileRepository) ListOldest(ctx context.Context, limit int) ([]*service.SoraCacheFile, error) { + if limit <= 0 { + return []*service.SoraCacheFile{}, nil + } + records, err := r.client.SoraCacheFile.Query(). + Order(dbsoracachefile.ByCreatedAt(entsql.OrderAsc())). + Limit(limit). 
+ All(ctx) + if err != nil { + return nil, err + } + result := make([]*service.SoraCacheFile, 0, len(records)) + for _, record := range records { + if record == nil { + continue + } + result = append(result, mapSoraCacheFile(record)) + } + return result, nil +} + +func (r *soraCacheFileRepository) DeleteByIDs(ctx context.Context, ids []int64) error { + if len(ids) == 0 { + return nil + } + _, err := r.client.SoraCacheFile.Delete().Where(dbsoracachefile.IDIn(ids...)).Exec(ctx) + return err +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 7a8d85f4..68168bb4 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -64,6 +64,10 @@ var ProviderSet = wire.NewSet( NewUserSubscriptionRepository, NewUserAttributeDefinitionRepository, NewUserAttributeValueRepository, + NewSoraAccountRepository, + NewSoraUsageStatRepository, + NewSoraTaskRepository, + NewSoraCacheFileRepository, // Cache implementations NewGatewayCache, diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index cf9015e4..e179f758 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -1,7 +1,10 @@ package server import ( + "context" "log" + "path/filepath" + "strings" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" @@ -46,6 +49,22 @@ func SetupRouter( } } + // Serve Sora cached videos when enabled + cacheVideoDir := "" + cacheEnabled := false + if settingService != nil { + soraCfg := settingService.GetSoraConfig(context.Background()) + cacheEnabled = soraCfg.Cache.Enabled + cacheVideoDir = strings.TrimSpace(soraCfg.Cache.VideoDir) + } else if cfg != nil { + cacheEnabled = cfg.Sora.Cache.Enabled + cacheVideoDir = strings.TrimSpace(cfg.Sora.Cache.VideoDir) + } + if cacheEnabled && cacheVideoDir != "" { + videoDir := filepath.Clean(cacheVideoDir) + r.Static("/data/video", videoDir) + } + // 注册路由 registerRoutes(r, 
handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg, redisClient) diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 050e724d..c2248219 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -29,6 +29,9 @@ func RegisterAdminRoutes( // 账号管理 registerAccountRoutes(admin, h) + // Sora 账号扩展 + registerSoraRoutes(admin, h) + // OpenAI OAuth registerOpenAIOAuthRoutes(admin, h) @@ -229,6 +232,17 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { } } +func registerSoraRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + sora := admin.Group("/sora") + { + sora.GET("/accounts", h.Admin.SoraAccount.List) + sora.GET("/accounts/:id", h.Admin.SoraAccount.Get) + sora.PUT("/accounts/:id", h.Admin.SoraAccount.Upsert) + sora.POST("/accounts/import", h.Admin.SoraAccount.BatchUpsert) + sora.GET("/usage", h.Admin.SoraAccount.ListUsage) + } +} + func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { openai := admin.Group("/openai") { diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index bf019ce3..6dc39b3b 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -33,6 +33,7 @@ func RegisterGatewayRoutes( gateway.POST("/messages", h.Gateway.Messages) gateway.POST("/messages/count_tokens", h.Gateway.CountTokens) gateway.GET("/models", h.Gateway.Models) + gateway.POST("/chat/completions", h.SoraGateway.ChatCompletions) gateway.GET("/usage", h.Gateway.Usage) // OpenAI Responses API gateway.POST("/responses", h.OpenAIGateway.Responses) diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 3bb63ffa..8a0d6f3a 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -22,6 +22,7 @@ const ( 
PlatformOpenAI = "openai" PlatformGemini = "gemini" PlatformAntigravity = "antigravity" + PlatformSora = "sora" ) // Account type constants @@ -124,6 +125,28 @@ const ( SettingKeyEnableIdentityPatch = "enable_identity_patch" SettingKeyIdentityPatchPrompt = "identity_patch_prompt" + // ========================= + // Sora Settings + // ========================= + + SettingKeySoraBaseURL = "sora_base_url" + SettingKeySoraTimeout = "sora_timeout" + SettingKeySoraMaxRetries = "sora_max_retries" + SettingKeySoraPollInterval = "sora_poll_interval" + SettingKeySoraCallLogicMode = "sora_call_logic_mode" + SettingKeySoraCacheEnabled = "sora_cache_enabled" + SettingKeySoraCacheBaseDir = "sora_cache_base_dir" + SettingKeySoraCacheVideoDir = "sora_cache_video_dir" + SettingKeySoraCacheMaxBytes = "sora_cache_max_bytes" + SettingKeySoraCacheAllowedHosts = "sora_cache_allowed_hosts" + SettingKeySoraCacheUserDirEnabled = "sora_cache_user_dir_enabled" + SettingKeySoraWatermarkFreeEnabled = "sora_watermark_free_enabled" + SettingKeySoraWatermarkFreeParseMethod = "sora_watermark_free_parse_method" + SettingKeySoraWatermarkFreeCustomParseURL = "sora_watermark_free_custom_parse_url" + SettingKeySoraWatermarkFreeCustomParseToken = "sora_watermark_free_custom_parse_token" + SettingKeySoraWatermarkFreeFallbackOnFailure = "sora_watermark_free_fallback_on_failure" + SettingKeySoraTokenRefreshEnabled = "sora_token_refresh_enabled" + // ========================= // Ops Monitoring (vNext) // ========================= diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index b3714ed1..9d2619c8 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -378,7 +378,7 @@ func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupI if len(groupIDs) == 0 { return nil } - platforms := []string{PlatformAnthropic, PlatformGemini, 
PlatformOpenAI, PlatformAntigravity} + platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformSora, PlatformAntigravity} var firstErr error for _, platform := range platforms { if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason); err != nil && firstErr == nil { @@ -661,7 +661,7 @@ func (s *SchedulerSnapshotService) fullRebuildInterval() time.Duration { func (s *SchedulerSnapshotService) defaultBuckets(ctx context.Context) ([]SchedulerBucket, error) { buckets := make([]SchedulerBucket, 0) - platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity} + platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformSora, PlatformAntigravity} for _, platform := range platforms { buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeSingle}) buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeForced}) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index d77dd30d..4716c051 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -219,6 +219,29 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyEnableIdentityPatch] = strconv.FormatBool(settings.EnableIdentityPatch) updates[SettingKeyIdentityPatchPrompt] = settings.IdentityPatchPrompt + // Sora settings + updates[SettingKeySoraBaseURL] = strings.TrimSpace(settings.SoraBaseURL) + updates[SettingKeySoraTimeout] = strconv.Itoa(settings.SoraTimeout) + updates[SettingKeySoraMaxRetries] = strconv.Itoa(settings.SoraMaxRetries) + updates[SettingKeySoraPollInterval] = strconv.FormatFloat(settings.SoraPollInterval, 'f', -1, 64) + updates[SettingKeySoraCallLogicMode] = settings.SoraCallLogicMode + updates[SettingKeySoraCacheEnabled] = strconv.FormatBool(settings.SoraCacheEnabled) + 
updates[SettingKeySoraCacheBaseDir] = settings.SoraCacheBaseDir + updates[SettingKeySoraCacheVideoDir] = settings.SoraCacheVideoDir + updates[SettingKeySoraCacheMaxBytes] = strconv.FormatInt(settings.SoraCacheMaxBytes, 10) + allowedHostsRaw, err := marshalStringSliceSetting(settings.SoraCacheAllowedHosts) + if err != nil { + return fmt.Errorf("marshal sora cache allowed hosts: %w", err) + } + updates[SettingKeySoraCacheAllowedHosts] = allowedHostsRaw + updates[SettingKeySoraCacheUserDirEnabled] = strconv.FormatBool(settings.SoraCacheUserDirEnabled) + updates[SettingKeySoraWatermarkFreeEnabled] = strconv.FormatBool(settings.SoraWatermarkFreeEnabled) + updates[SettingKeySoraWatermarkFreeParseMethod] = settings.SoraWatermarkFreeParseMethod + updates[SettingKeySoraWatermarkFreeCustomParseURL] = strings.TrimSpace(settings.SoraWatermarkFreeCustomParseURL) + updates[SettingKeySoraWatermarkFreeCustomParseToken] = settings.SoraWatermarkFreeCustomParseToken + updates[SettingKeySoraWatermarkFreeFallbackOnFailure] = strconv.FormatBool(settings.SoraWatermarkFreeFallbackOnFailure) + updates[SettingKeySoraTokenRefreshEnabled] = strconv.FormatBool(settings.SoraTokenRefreshEnabled) + // Ops monitoring (vNext) updates[SettingKeyOpsMonitoringEnabled] = strconv.FormatBool(settings.OpsMonitoringEnabled) updates[SettingKeyOpsRealtimeMonitoringEnabled] = strconv.FormatBool(settings.OpsRealtimeMonitoringEnabled) @@ -227,7 +250,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyOpsMetricsIntervalSeconds] = strconv.Itoa(settings.OpsMetricsIntervalSeconds) } - err := s.settingRepo.SetMultiple(ctx, updates) + err = s.settingRepo.SetMultiple(ctx, updates) if err == nil && s.onUpdate != nil { s.onUpdate() // Invalidate cache after settings update } @@ -295,6 +318,41 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 { return s.cfg.Default.UserBalance } +// GetSoraConfig 获取 Sora 配置(优先读取 DB 设置,回退 config.yaml) +func (s 
*SettingService) GetSoraConfig(ctx context.Context) config.SoraConfig { + base := config.SoraConfig{} + if s.cfg != nil { + base = s.cfg.Sora + } + if s.settingRepo == nil { + return base + } + keys := []string{ + SettingKeySoraBaseURL, + SettingKeySoraTimeout, + SettingKeySoraMaxRetries, + SettingKeySoraPollInterval, + SettingKeySoraCallLogicMode, + SettingKeySoraCacheEnabled, + SettingKeySoraCacheBaseDir, + SettingKeySoraCacheVideoDir, + SettingKeySoraCacheMaxBytes, + SettingKeySoraCacheAllowedHosts, + SettingKeySoraCacheUserDirEnabled, + SettingKeySoraWatermarkFreeEnabled, + SettingKeySoraWatermarkFreeParseMethod, + SettingKeySoraWatermarkFreeCustomParseURL, + SettingKeySoraWatermarkFreeCustomParseToken, + SettingKeySoraWatermarkFreeFallbackOnFailure, + SettingKeySoraTokenRefreshEnabled, + } + values, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return base + } + return mergeSoraConfig(base, values) +} + // InitializeDefaultSettings 初始化默认设置 func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 检查是否已有设置 @@ -308,6 +366,12 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { } // 初始化默认设置 + soraCfg := config.SoraConfig{} + if s.cfg != nil { + soraCfg = s.cfg.Sora + } + allowedHostsRaw, _ := marshalStringSliceSetting(soraCfg.Cache.AllowedHosts) + defaults := map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyEmailVerifyEnabled: "false", @@ -328,6 +392,25 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyEnableIdentityPatch: "true", SettingKeyIdentityPatchPrompt: "", + // Sora defaults + SettingKeySoraBaseURL: soraCfg.BaseURL, + SettingKeySoraTimeout: strconv.Itoa(soraCfg.Timeout), + SettingKeySoraMaxRetries: strconv.Itoa(soraCfg.MaxRetries), + SettingKeySoraPollInterval: strconv.FormatFloat(soraCfg.PollInterval, 'f', -1, 64), + SettingKeySoraCallLogicMode: soraCfg.CallLogicMode, + SettingKeySoraCacheEnabled: 
strconv.FormatBool(soraCfg.Cache.Enabled), + SettingKeySoraCacheBaseDir: soraCfg.Cache.BaseDir, + SettingKeySoraCacheVideoDir: soraCfg.Cache.VideoDir, + SettingKeySoraCacheMaxBytes: strconv.FormatInt(soraCfg.Cache.MaxBytes, 10), + SettingKeySoraCacheAllowedHosts: allowedHostsRaw, + SettingKeySoraCacheUserDirEnabled: strconv.FormatBool(soraCfg.Cache.UserDirEnabled), + SettingKeySoraWatermarkFreeEnabled: strconv.FormatBool(soraCfg.WatermarkFree.Enabled), + SettingKeySoraWatermarkFreeParseMethod: soraCfg.WatermarkFree.ParseMethod, + SettingKeySoraWatermarkFreeCustomParseURL: soraCfg.WatermarkFree.CustomParseURL, + SettingKeySoraWatermarkFreeCustomParseToken: soraCfg.WatermarkFree.CustomParseToken, + SettingKeySoraWatermarkFreeFallbackOnFailure: strconv.FormatBool(soraCfg.WatermarkFree.FallbackOnFailure), + SettingKeySoraTokenRefreshEnabled: strconv.FormatBool(soraCfg.TokenRefresh.Enabled), + // Ops monitoring defaults (vNext) SettingKeyOpsMonitoringEnabled: "true", SettingKeyOpsRealtimeMonitoringEnabled: "true", @@ -434,6 +517,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } result.IdentityPatchPrompt = settings[SettingKeyIdentityPatchPrompt] + // Sora settings + soraCfg := s.parseSoraConfig(settings) + result.SoraBaseURL = soraCfg.BaseURL + result.SoraTimeout = soraCfg.Timeout + result.SoraMaxRetries = soraCfg.MaxRetries + result.SoraPollInterval = soraCfg.PollInterval + result.SoraCallLogicMode = soraCfg.CallLogicMode + result.SoraCacheEnabled = soraCfg.Cache.Enabled + result.SoraCacheBaseDir = soraCfg.Cache.BaseDir + result.SoraCacheVideoDir = soraCfg.Cache.VideoDir + result.SoraCacheMaxBytes = soraCfg.Cache.MaxBytes + result.SoraCacheAllowedHosts = soraCfg.Cache.AllowedHosts + result.SoraCacheUserDirEnabled = soraCfg.Cache.UserDirEnabled + result.SoraWatermarkFreeEnabled = soraCfg.WatermarkFree.Enabled + result.SoraWatermarkFreeParseMethod = soraCfg.WatermarkFree.ParseMethod + result.SoraWatermarkFreeCustomParseURL = 
soraCfg.WatermarkFree.CustomParseURL + result.SoraWatermarkFreeCustomParseToken = soraCfg.WatermarkFree.CustomParseToken + result.SoraWatermarkFreeFallbackOnFailure = soraCfg.WatermarkFree.FallbackOnFailure + result.SoraTokenRefreshEnabled = soraCfg.TokenRefresh.Enabled + // Ops monitoring settings (default: enabled, fail-open) result.OpsMonitoringEnabled = !isFalseSettingValue(settings[SettingKeyOpsMonitoringEnabled]) result.OpsRealtimeMonitoringEnabled = !isFalseSettingValue(settings[SettingKeyOpsRealtimeMonitoringEnabled]) @@ -471,6 +574,131 @@ func (s *SettingService) getStringOrDefault(settings map[string]string, key, def return defaultValue } +func (s *SettingService) parseSoraConfig(settings map[string]string) config.SoraConfig { + base := config.SoraConfig{} + if s.cfg != nil { + base = s.cfg.Sora + } + return mergeSoraConfig(base, settings) +} + +func mergeSoraConfig(base config.SoraConfig, settings map[string]string) config.SoraConfig { + cfg := base + if settings == nil { + return cfg + } + if raw, ok := settings[SettingKeySoraBaseURL]; ok { + if trimmed := strings.TrimSpace(raw); trimmed != "" { + cfg.BaseURL = trimmed + } + } + if raw, ok := settings[SettingKeySoraTimeout]; ok { + if v, err := strconv.Atoi(strings.TrimSpace(raw)); err == nil && v > 0 { + cfg.Timeout = v + } + } + if raw, ok := settings[SettingKeySoraMaxRetries]; ok { + if v, err := strconv.Atoi(strings.TrimSpace(raw)); err == nil && v >= 0 { + cfg.MaxRetries = v + } + } + if raw, ok := settings[SettingKeySoraPollInterval]; ok { + if v, err := strconv.ParseFloat(strings.TrimSpace(raw), 64); err == nil && v > 0 { + cfg.PollInterval = v + } + } + if raw, ok := settings[SettingKeySoraCallLogicMode]; ok && strings.TrimSpace(raw) != "" { + cfg.CallLogicMode = strings.TrimSpace(raw) + } + if raw, ok := settings[SettingKeySoraCacheEnabled]; ok { + cfg.Cache.Enabled = parseBoolSetting(raw, cfg.Cache.Enabled) + } + if raw, ok := settings[SettingKeySoraCacheBaseDir]; ok && strings.TrimSpace(raw) 
!= "" { + cfg.Cache.BaseDir = strings.TrimSpace(raw) + } + if raw, ok := settings[SettingKeySoraCacheVideoDir]; ok && strings.TrimSpace(raw) != "" { + cfg.Cache.VideoDir = strings.TrimSpace(raw) + } + if raw, ok := settings[SettingKeySoraCacheMaxBytes]; ok { + if v, err := strconv.ParseInt(strings.TrimSpace(raw), 10, 64); err == nil && v >= 0 { + cfg.Cache.MaxBytes = v + } + } + if raw, ok := settings[SettingKeySoraCacheAllowedHosts]; ok { + cfg.Cache.AllowedHosts = parseStringSliceSetting(raw) + } + if raw, ok := settings[SettingKeySoraCacheUserDirEnabled]; ok { + cfg.Cache.UserDirEnabled = parseBoolSetting(raw, cfg.Cache.UserDirEnabled) + } + if raw, ok := settings[SettingKeySoraWatermarkFreeEnabled]; ok { + cfg.WatermarkFree.Enabled = parseBoolSetting(raw, cfg.WatermarkFree.Enabled) + } + if raw, ok := settings[SettingKeySoraWatermarkFreeParseMethod]; ok && strings.TrimSpace(raw) != "" { + cfg.WatermarkFree.ParseMethod = strings.TrimSpace(raw) + } + if raw, ok := settings[SettingKeySoraWatermarkFreeCustomParseURL]; ok && strings.TrimSpace(raw) != "" { + cfg.WatermarkFree.CustomParseURL = strings.TrimSpace(raw) + } + if raw, ok := settings[SettingKeySoraWatermarkFreeCustomParseToken]; ok { + cfg.WatermarkFree.CustomParseToken = raw + } + if raw, ok := settings[SettingKeySoraWatermarkFreeFallbackOnFailure]; ok { + cfg.WatermarkFree.FallbackOnFailure = parseBoolSetting(raw, cfg.WatermarkFree.FallbackOnFailure) + } + if raw, ok := settings[SettingKeySoraTokenRefreshEnabled]; ok { + cfg.TokenRefresh.Enabled = parseBoolSetting(raw, cfg.TokenRefresh.Enabled) + } + return cfg +} + +func parseBoolSetting(raw string, fallback bool) bool { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return fallback + } + if v, err := strconv.ParseBool(trimmed); err == nil { + return v + } + return fallback +} + +func parseStringSliceSetting(raw string) []string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return []string{} + } + var values []string + if err 
:= json.Unmarshal([]byte(trimmed), &values); err == nil { + return normalizeStringSlice(values) + } + parts := strings.FieldsFunc(trimmed, func(r rune) bool { + return r == ',' || r == '\n' || r == ';' + }) + return normalizeStringSlice(parts) +} + +func marshalStringSliceSetting(values []string) (string, error) { + normalized := normalizeStringSlice(values) + data, err := json.Marshal(normalized) + if err != nil { + return "", err + } + return string(data), nil +} + +func normalizeStringSlice(values []string) []string { + if len(values) == 0 { + return []string{} + } + normalized := make([]string, 0, len(values)) + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + normalized = append(normalized, trimmed) + } + } + return normalized +} + // IsTurnstileEnabled 检查是否启用 Turnstile 验证 func (s *SettingService) IsTurnstileEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, SettingKeyTurnstileEnabled) diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 919344e5..b20a474d 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -49,6 +49,25 @@ type SystemSettings struct { EnableIdentityPatch bool `json:"enable_identity_patch"` IdentityPatchPrompt string `json:"identity_patch_prompt"` + // Sora configuration + SoraBaseURL string + SoraTimeout int + SoraMaxRetries int + SoraPollInterval float64 + SoraCallLogicMode string + SoraCacheEnabled bool + SoraCacheBaseDir string + SoraCacheVideoDir string + SoraCacheMaxBytes int64 + SoraCacheAllowedHosts []string + SoraCacheUserDirEnabled bool + SoraWatermarkFreeEnabled bool + SoraWatermarkFreeParseMethod string + SoraWatermarkFreeCustomParseURL string + SoraWatermarkFreeCustomParseToken string + SoraWatermarkFreeFallbackOnFailure bool + SoraTokenRefreshEnabled bool + // Ops monitoring (vNext) OpsMonitoringEnabled bool OpsRealtimeMonitoringEnabled bool diff --git 
a/backend/internal/service/sora_cache_cleanup_service.go b/backend/internal/service/sora_cache_cleanup_service.go new file mode 100644 index 00000000..7ba1dc77 --- /dev/null +++ b/backend/internal/service/sora_cache_cleanup_service.go @@ -0,0 +1,156 @@ +package service + +import ( + "context" + "log" + "os" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +const ( + soraCacheCleanupInterval = time.Hour + soraCacheCleanupBatch = 200 +) + +// SoraCacheCleanupService 负责清理 Sora 视频缓存文件。 +type SoraCacheCleanupService struct { + cacheRepo SoraCacheFileRepository + settingService *SettingService + cfg *config.Config + stopCh chan struct{} + stopOnce sync.Once +} + +func NewSoraCacheCleanupService(cacheRepo SoraCacheFileRepository, settingService *SettingService, cfg *config.Config) *SoraCacheCleanupService { + return &SoraCacheCleanupService{ + cacheRepo: cacheRepo, + settingService: settingService, + cfg: cfg, + stopCh: make(chan struct{}), + } +} + +func (s *SoraCacheCleanupService) Start() { + if s == nil || s.cacheRepo == nil { + return + } + go s.cleanupLoop() +} + +func (s *SoraCacheCleanupService) Stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + close(s.stopCh) + }) +} + +func (s *SoraCacheCleanupService) cleanupLoop() { + ticker := time.NewTicker(soraCacheCleanupInterval) + defer ticker.Stop() + + s.cleanupOnce() + for { + select { + case <-ticker.C: + s.cleanupOnce() + case <-s.stopCh: + return + } + } +} + +func (s *SoraCacheCleanupService) cleanupOnce() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + + if s.cacheRepo == nil { + return + } + + cfg := s.getSoraConfig(ctx) + videoDir := strings.TrimSpace(cfg.Cache.VideoDir) + if videoDir == "" { + return + } + maxBytes := cfg.Cache.MaxBytes + if maxBytes <= 0 { + return + } + + size, err := dirSize(videoDir) + if err != nil { + log.Printf("[SoraCacheCleanup] 计算目录大小失败: %v", err) + return + } + if size <= maxBytes { + 
return + } + + for size > maxBytes { + entries, err := s.cacheRepo.ListOldest(ctx, soraCacheCleanupBatch) + if err != nil { + log.Printf("[SoraCacheCleanup] 读取缓存记录失败: %v", err) + return + } + if len(entries) == 0 { + log.Printf("[SoraCacheCleanup] 无缓存记录但目录仍超限: size=%d max=%d", size, maxBytes) + return + } + + ids := make([]int64, 0, len(entries)) + for _, entry := range entries { + if entry == nil { + continue + } + removedSize := entry.SizeBytes + if entry.CachePath != "" { + if info, err := os.Stat(entry.CachePath); err == nil { + if removedSize <= 0 { + removedSize = info.Size() + } + } + if err := os.Remove(entry.CachePath); err != nil && !os.IsNotExist(err) { + log.Printf("[SoraCacheCleanup] 删除缓存文件失败: path=%s err=%v", entry.CachePath, err) + } + } + + if entry.ID > 0 { + ids = append(ids, entry.ID) + } + if removedSize > 0 { + size -= removedSize + if size < 0 { + size = 0 + } + } + } + + if len(ids) > 0 { + if err := s.cacheRepo.DeleteByIDs(ctx, ids); err != nil { + log.Printf("[SoraCacheCleanup] 删除缓存记录失败: %v", err) + } + } + + if size > maxBytes { + if refreshed, err := dirSize(videoDir); err == nil { + size = refreshed + } + } + } +} + +func (s *SoraCacheCleanupService) getSoraConfig(ctx context.Context) config.SoraConfig { + if s.settingService != nil { + return s.settingService.GetSoraConfig(ctx) + } + if s.cfg != nil { + return s.cfg.Sora + } + return config.SoraConfig{} +} diff --git a/backend/internal/service/sora_cache_service.go b/backend/internal/service/sora_cache_service.go new file mode 100644 index 00000000..304a8d40 --- /dev/null +++ b/backend/internal/service/sora_cache_service.go @@ -0,0 +1,246 @@ +package service + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/uuidv7" + "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" +) + +// SoraCacheService 提供 Sora 视频缓存能力。 +type 
SoraCacheService struct { + cfg *config.Config + cacheRepo SoraCacheFileRepository + settingService *SettingService + accountRepo AccountRepository + httpUpstream HTTPUpstream +} + +// NewSoraCacheService 创建 SoraCacheService。 +func NewSoraCacheService(cfg *config.Config, cacheRepo SoraCacheFileRepository, settingService *SettingService, accountRepo AccountRepository, httpUpstream HTTPUpstream) *SoraCacheService { + return &SoraCacheService{ + cfg: cfg, + cacheRepo: cacheRepo, + settingService: settingService, + accountRepo: accountRepo, + httpUpstream: httpUpstream, + } +} + +func (s *SoraCacheService) CacheVideo(ctx context.Context, accountID, userID int64, taskID, mediaURL string) (*SoraCacheFile, error) { + cfg := s.getSoraConfig(ctx) + if !cfg.Cache.Enabled { + return nil, nil + } + trimmed := strings.TrimSpace(mediaURL) + if trimmed == "" { + return nil, nil + } + + allowedHosts := cfg.Cache.AllowedHosts + useAllowlist := true + if len(allowedHosts) == 0 { + if s.cfg != nil { + allowedHosts = s.cfg.Security.URLAllowlist.UpstreamHosts + useAllowlist = s.cfg.Security.URLAllowlist.Enabled + } else { + useAllowlist = false + } + } + + if useAllowlist { + if _, err := urlvalidator.ValidateHTTPSURL(trimmed, urlvalidator.ValidationOptions{ + AllowedHosts: allowedHosts, + RequireAllowlist: true, + AllowPrivate: s.cfg != nil && s.cfg.Security.URLAllowlist.AllowPrivateHosts, + }); err != nil { + return nil, fmt.Errorf("缓存下载地址不合法: %w", err) + } + } else { + allowInsecure := false + if s.cfg != nil { + allowInsecure = s.cfg.Security.URLAllowlist.AllowInsecureHTTP + } + if _, err := urlvalidator.ValidateURLFormat(trimmed, allowInsecure); err != nil { + return nil, fmt.Errorf("缓存下载地址不合法: %w", err) + } + } + + videoDir := strings.TrimSpace(cfg.Cache.VideoDir) + if videoDir == "" { + return nil, nil + } + + if cfg.Cache.MaxBytes > 0 { + size, err := dirSize(videoDir) + if err != nil { + return nil, err + } + if size >= cfg.Cache.MaxBytes { + return nil, nil + } + } + + 
relativeDir := "" + if cfg.Cache.UserDirEnabled && userID > 0 { + relativeDir = fmt.Sprintf("u_%d", userID) + } + + targetDir := filepath.Join(videoDir, relativeDir) + if err := os.MkdirAll(targetDir, 0o755); err != nil { + return nil, err + } + + uuid, err := uuidv7.New() + if err != nil { + return nil, err + } + + name := deriveFileName(trimmed) + if name == "" { + name = "video.mp4" + } + name = sanitizeFileName(name) + filename := uuid + "_" + name + cachePath := filepath.Join(targetDir, filename) + + resp, err := s.downloadMedia(ctx, accountID, trimmed, time.Duration(cfg.Timeout)*time.Second) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("缓存下载失败: %d", resp.StatusCode) + } + + out, err := os.Create(cachePath) + if err != nil { + return nil, err + } + defer out.Close() + + written, err := io.Copy(out, resp.Body) + if err != nil { + return nil, err + } + + cacheURL := buildCacheURL(relativeDir, filename) + + record := &SoraCacheFile{ + TaskID: taskID, + AccountID: accountID, + UserID: userID, + MediaType: "video", + OriginalURL: trimmed, + CachePath: cachePath, + CacheURL: cacheURL, + SizeBytes: written, + CreatedAt: time.Now(), + } + if s.cacheRepo != nil { + if err := s.cacheRepo.Create(ctx, record); err != nil { + return nil, err + } + } + return record, nil +} + +func buildCacheURL(relativeDir, filename string) string { + base := "/data/video" + if relativeDir != "" { + return path.Join(base, relativeDir, filename) + } + return path.Join(base, filename) +} + +func (s *SoraCacheService) getSoraConfig(ctx context.Context) config.SoraConfig { + if s.settingService != nil { + return s.settingService.GetSoraConfig(ctx) + } + if s.cfg != nil { + return s.cfg.Sora + } + return config.SoraConfig{} +} + +func (s *SoraCacheService) downloadMedia(ctx context.Context, accountID int64, mediaURL string, timeout time.Duration) (*http.Response, error) { + if timeout <= 0 { + 
timeout = 120 * time.Second + } + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", mediaURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36") + + if s.httpUpstream == nil { + client := &http.Client{Timeout: timeout} + return client.Do(req) + } + + var accountConcurrency int + proxyURL := "" + if s.accountRepo != nil && accountID > 0 { + account, err := s.accountRepo.GetByID(ctx, accountID) + if err == nil && account != nil { + accountConcurrency = account.Concurrency + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + } + } + enableTLS := false + if s.cfg != nil { + enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled + } + return s.httpUpstream.DoWithTLS(req, proxyURL, accountID, accountConcurrency, enableTLS) +} + +func deriveFileName(rawURL string) string { + parsed, err := url.Parse(rawURL) + if err != nil { + return "" + } + name := path.Base(parsed.Path) + if name == "/" || name == "." 
{ + return "" + } + return name +} + +func sanitizeFileName(name string) string { + name = strings.TrimSpace(name) + if name == "" { + return "" + } + sanitized := strings.Map(func(r rune) rune { + switch { + case r >= 'a' && r <= 'z': + return r + case r >= 'A' && r <= 'Z': + return r + case r >= '0' && r <= '9': + return r + case r == '-' || r == '_' || r == '.': + return r + case r == ' ': // 空格替换为下划线 + return '_' + default: + return -1 + } + }, name) + return strings.TrimLeft(sanitized, ".") +} diff --git a/backend/internal/service/sora_cache_utils.go b/backend/internal/service/sora_cache_utils.go new file mode 100644 index 00000000..d52a861e --- /dev/null +++ b/backend/internal/service/sora_cache_utils.go @@ -0,0 +1,28 @@ +package service + +import ( + "os" + "path/filepath" +) + +func dirSize(root string) (int64, error) { + var size int64 + err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + return nil + } + info, err := d.Info() + if err != nil { + return err + } + size += info.Size() + return nil + }) + if err != nil && os.IsNotExist(err) { + return 0, nil + } + return size, err +} diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go new file mode 100644 index 00000000..69da1879 --- /dev/null +++ b/backend/internal/service/sora_gateway_service.go @@ -0,0 +1,853 @@ +package service + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "math/rand" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/sora" + "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" +) + +const ( + soraErrorDisableThreshold = 5 + maxImageDownloadSize = 20 * 1024 * 1024 // 20MB + maxVideoDownloadSize = 200 * 1024 * 1024 // 200MB +) + +var ( + ErrSoraAccountMissingToken = errors.New("sora account missing access token") + 
ErrSoraAccountNotEligible = errors.New("sora account not eligible") +) + +// SoraGenerationRequest 表示 Sora 生成请求。 +type SoraGenerationRequest struct { + Model string + Prompt string + Image string + Video string + RemixTargetID string + Stream bool + UserID int64 +} + +// SoraGenerationResult 表示 Sora 生成结果。 +type SoraGenerationResult struct { + Content string + MediaType string + ResultURLs []string + TaskID string +} + +// SoraGatewayService 处理 Sora 生成流程。 +type SoraGatewayService struct { + accountRepo AccountRepository + soraAccountRepo SoraAccountRepository + usageRepo SoraUsageStatRepository + taskRepo SoraTaskRepository + cacheService *SoraCacheService + settingService *SettingService + concurrency *ConcurrencyService + cfg *config.Config + httpUpstream HTTPUpstream +} + +// NewSoraGatewayService 创建 SoraGatewayService。 +func NewSoraGatewayService( + accountRepo AccountRepository, + soraAccountRepo SoraAccountRepository, + usageRepo SoraUsageStatRepository, + taskRepo SoraTaskRepository, + cacheService *SoraCacheService, + settingService *SettingService, + concurrencyService *ConcurrencyService, + cfg *config.Config, + httpUpstream HTTPUpstream, +) *SoraGatewayService { + return &SoraGatewayService{ + accountRepo: accountRepo, + soraAccountRepo: soraAccountRepo, + usageRepo: usageRepo, + taskRepo: taskRepo, + cacheService: cacheService, + settingService: settingService, + concurrency: concurrencyService, + cfg: cfg, + httpUpstream: httpUpstream, + } +} + +// ListModels 返回 Sora 模型列表。 +func (s *SoraGatewayService) ListModels() []sora.ModelListItem { + return sora.ListModels() +} + +// Generate 执行 Sora 生成流程。 +func (s *SoraGatewayService) Generate(ctx context.Context, account *Account, req SoraGenerationRequest) (*SoraGenerationResult, error) { + client, cfg := s.getClient(ctx) + if client == nil { + return nil, errors.New("sora client is not configured") + } + modelCfg, ok := sora.ModelConfigs[req.Model] + if !ok { + return nil, fmt.Errorf("unsupported model: %s", 
req.Model) + } + accessToken, soraAcc, err := s.getAccessToken(ctx, account) + if err != nil { + return nil, err + } + if soraAcc != nil && soraAcc.SoraCooldownUntil != nil && time.Now().Before(*soraAcc.SoraCooldownUntil) { + return nil, ErrSoraAccountNotEligible + } + if modelCfg.RequirePro && !isSoraProAccount(soraAcc) { + return nil, ErrSoraAccountNotEligible + } + if modelCfg.Type == "video" && soraAcc != nil { + if !soraAcc.VideoEnabled || !soraAcc.SoraSupported || soraAcc.IsExpired { + return nil, ErrSoraAccountNotEligible + } + } + if modelCfg.Type == "image" && soraAcc != nil { + if !soraAcc.ImageEnabled || soraAcc.IsExpired { + return nil, ErrSoraAccountNotEligible + } + } + + opts := sora.RequestOptions{ + AccountID: account.ID, + AccountConcurrency: account.Concurrency, + AccessToken: accessToken, + } + if account.Proxy != nil { + opts.ProxyURL = account.Proxy.URL() + } + + releaseFunc, err := s.acquireSoraSlots(ctx, account, soraAcc, modelCfg.Type == "video") + if err != nil { + return nil, err + } + if releaseFunc != nil { + defer releaseFunc() + } + + if modelCfg.Type == "prompt_enhance" { + content, err := client.EnhancePrompt(ctx, opts, req.Prompt, modelCfg.ExpansionLevel, modelCfg.DurationS) + if err != nil { + return nil, err + } + return &SoraGenerationResult{Content: content, MediaType: "text"}, nil + } + + var mediaID string + if req.Image != "" { + data, err := s.loadImageBytes(ctx, opts, req.Image) + if err != nil { + return nil, err + } + mediaID, err = client.UploadImage(ctx, opts, data, "image.png") + if err != nil { + return nil, err + } + } + if req.Video != "" && modelCfg.Type != "video" { + return nil, errors.New("视频输入仅支持视频模型") + } + if req.Video != "" && req.Image != "" { + return nil, errors.New("不能同时传入 image 与 video") + } + + var cleanupCharacter func() + if req.Video != "" && req.RemixTargetID == "" { + username, characterID, err := s.createCharacter(ctx, client, opts, req.Video) + if err != nil { + return nil, err + } + if 
strings.TrimSpace(req.Prompt) == "" { + return &SoraGenerationResult{ + Content: fmt.Sprintf("角色创建成功,角色名@%s", username), + MediaType: "text", + }, nil + } + if username != "" { + req.Prompt = fmt.Sprintf("@%s %s", username, strings.TrimSpace(req.Prompt)) + } + if characterID != "" { + cleanupCharacter = func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + _ = client.DeleteCharacter(ctx, opts, characterID) + } + } + } + if cleanupCharacter != nil { + defer cleanupCharacter() + } + + var taskID string + if modelCfg.Type == "image" { + taskID, err = client.GenerateImage(ctx, opts, req.Prompt, modelCfg.Width, modelCfg.Height, mediaID) + } else { + orientation := modelCfg.Orientation + if orientation == "" { + orientation = "landscape" + } + modelName := modelCfg.Model + if modelName == "" { + modelName = "sy_8" + } + size := modelCfg.Size + if size == "" { + size = "small" + } + if req.RemixTargetID != "" { + taskID, err = client.RemixVideo(ctx, opts, req.RemixTargetID, req.Prompt, orientation, modelCfg.NFrames, "") + } else if sora.IsStoryboardPrompt(req.Prompt) { + formatted := sora.FormatStoryboardPrompt(req.Prompt) + taskID, err = client.GenerateStoryboard(ctx, opts, formatted, orientation, modelCfg.NFrames, mediaID, "") + } else { + taskID, err = client.GenerateVideo(ctx, opts, req.Prompt, orientation, modelCfg.NFrames, mediaID, "", modelName, size) + } + } + if err != nil { + return nil, err + } + + if s.taskRepo != nil { + _ = s.taskRepo.Create(ctx, &SoraTask{ + TaskID: taskID, + AccountID: account.ID, + Model: req.Model, + Prompt: req.Prompt, + Status: "processing", + Progress: 0, + CreatedAt: time.Now(), + }) + } + + result, err := s.pollResult(ctx, client, cfg, opts, taskID, modelCfg.Type == "video", req) + if err != nil { + if s.taskRepo != nil { + _ = s.taskRepo.UpdateStatus(ctx, taskID, "failed", 0, "", err.Error(), timePtr(time.Now())) + } + consecutive := 0 + if s.usageRepo != nil { + consecutive, _ = 
s.usageRepo.RecordError(ctx, account.ID) + } + if consecutive >= soraErrorDisableThreshold { + _ = s.accountRepo.SetError(ctx, account.ID, "Sora 连续错误次数过多,已自动禁用") + } + return nil, err + } + + if s.taskRepo != nil { + payload, _ := json.Marshal(result.ResultURLs) + _ = s.taskRepo.UpdateStatus(ctx, taskID, "completed", 100, string(payload), "", timePtr(time.Now())) + } + if s.usageRepo != nil { + _ = s.usageRepo.RecordSuccess(ctx, account.ID, modelCfg.Type == "video") + } + return result, nil +} + +func (s *SoraGatewayService) pollResult(ctx context.Context, client *sora.Client, cfg config.SoraConfig, opts sora.RequestOptions, taskID string, isVideo bool, req SoraGenerationRequest) (*SoraGenerationResult, error) { + if taskID == "" { + return nil, errors.New("missing task id") + } + pollInterval := 2 * time.Second + if cfg.PollInterval > 0 { + pollInterval = time.Duration(cfg.PollInterval*1000) * time.Millisecond + } + timeout := 300 * time.Second + if cfg.Timeout > 0 { + timeout = time.Duration(cfg.Timeout) * time.Second + } + deadline := time.Now().Add(timeout) + + for time.Now().Before(deadline) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if isVideo { + pending, err := client.GetPendingTasks(ctx, opts) + if err == nil { + for _, task := range pending { + if stringFromMap(task, "id") == taskID { + continue + } + } + } + drafts, err := client.GetVideoDrafts(ctx, opts) + if err != nil { + return nil, err + } + items, _ := drafts["items"].([]any) + for _, item := range items { + entry, ok := item.(map[string]any) + if !ok { + continue + } + if stringFromMap(entry, "task_id") != taskID { + continue + } + url := firstNonEmpty(stringFromMap(entry, "downloadable_url"), stringFromMap(entry, "url")) + reason := stringFromMap(entry, "reason_str") + if url == "" { + if reason == "" { + reason = "视频生成失败" + } + return nil, errors.New(reason) + } + finalURL, err := s.handleWatermark(ctx, client, cfg, opts, url, entry, req, opts.AccountID, taskID) + 
if err != nil { + return nil, err + } + return &SoraGenerationResult{ + Content: buildVideoMarkdown(finalURL), + MediaType: "video", + ResultURLs: []string{finalURL}, + TaskID: taskID, + }, nil + } + } else { + resp, err := client.GetImageTasks(ctx, opts) + if err != nil { + return nil, err + } + tasks, _ := resp["task_responses"].([]any) + for _, item := range tasks { + entry, ok := item.(map[string]any) + if !ok { + continue + } + if stringFromMap(entry, "id") != taskID { + continue + } + status := stringFromMap(entry, "status") + switch status { + case "succeeded": + urls := extractImageURLs(entry) + if len(urls) == 0 { + return nil, errors.New("image urls empty") + } + content := buildImageMarkdown(urls) + return &SoraGenerationResult{ + Content: content, + MediaType: "image", + ResultURLs: urls, + TaskID: taskID, + }, nil + case "failed": + message := stringFromMap(entry, "error_message") + if message == "" { + message = "image generation failed" + } + return nil, errors.New(message) + } + } + } + + time.Sleep(pollInterval) + } + return nil, errors.New("generation timeout") +} + +func (s *SoraGatewayService) handleWatermark(ctx context.Context, client *sora.Client, cfg config.SoraConfig, opts sora.RequestOptions, url string, entry map[string]any, req SoraGenerationRequest, accountID int64, taskID string) (string, error) { + if !cfg.WatermarkFree.Enabled { + return s.cacheVideo(ctx, url, req, accountID, taskID), nil + } + generationID := stringFromMap(entry, "id") + if generationID == "" { + return s.cacheVideo(ctx, url, req, accountID, taskID), nil + } + postID, err := client.PostVideoForWatermarkFree(ctx, opts, generationID) + if err != nil { + if cfg.WatermarkFree.FallbackOnFailure { + return s.cacheVideo(ctx, url, req, accountID, taskID), nil + } + return "", err + } + if postID == "" { + if cfg.WatermarkFree.FallbackOnFailure { + return s.cacheVideo(ctx, url, req, accountID, taskID), nil + } + return "", errors.New("watermark-free post id empty") + } + var 
parsedURL string + if cfg.WatermarkFree.ParseMethod == "custom" { + if cfg.WatermarkFree.CustomParseURL == "" || cfg.WatermarkFree.CustomParseToken == "" { + return "", errors.New("custom parse 未配置") + } + parsedURL, err = s.fetchCustomWatermarkURL(ctx, cfg.WatermarkFree.CustomParseURL, cfg.WatermarkFree.CustomParseToken, postID) + if err != nil { + if cfg.WatermarkFree.FallbackOnFailure { + return s.cacheVideo(ctx, url, req, accountID, taskID), nil + } + return "", err + } + } else { + parsedURL = fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID) + } + cached := s.cacheVideo(ctx, parsedURL, req, accountID, taskID) + _ = client.DeletePost(ctx, opts, postID) + return cached, nil +} + +func (s *SoraGatewayService) cacheVideo(ctx context.Context, url string, req SoraGenerationRequest, accountID int64, taskID string) string { + if s.cacheService == nil { + return url + } + file, err := s.cacheService.CacheVideo(ctx, accountID, req.UserID, taskID, url) + if err != nil || file == nil { + return url + } + return file.CacheURL +} + +func (s *SoraGatewayService) getAccessToken(ctx context.Context, account *Account) (string, *SoraAccount, error) { + if account == nil { + return "", nil, errors.New("account is nil") + } + var soraAcc *SoraAccount + if s.soraAccountRepo != nil { + soraAcc, _ = s.soraAccountRepo.GetByAccountID(ctx, account.ID) + } + if soraAcc != nil && soraAcc.AccessToken != "" { + return soraAcc.AccessToken, soraAcc, nil + } + if account.Credentials != nil { + if v, ok := account.Credentials["access_token"].(string); ok && v != "" { + return v, soraAcc, nil + } + if v, ok := account.Credentials["token"].(string); ok && v != "" { + return v, soraAcc, nil + } + } + return "", soraAcc, ErrSoraAccountMissingToken +} + +func (s *SoraGatewayService) getClient(ctx context.Context) (*sora.Client, config.SoraConfig) { + cfg := s.getSoraConfig(ctx) + if s.httpUpstream == nil { + return nil, cfg + } + baseURL := strings.TrimSpace(cfg.BaseURL) + if baseURL == "" 
{ + return nil, cfg + } + timeout := time.Duration(cfg.Timeout) * time.Second + if cfg.Timeout <= 0 { + timeout = 120 * time.Second + } + enableTLS := false + if s.cfg != nil { + enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled + } + return sora.NewClient(baseURL, timeout, s.httpUpstream, enableTLS), cfg +} + +func decodeBase64(raw string) ([]byte, error) { + data := raw + if idx := strings.Index(raw, "base64,"); idx != -1 { + data = raw[idx+7:] + } + return base64.StdEncoding.DecodeString(data) +} + +func extractImageURLs(entry map[string]any) []string { + generations, _ := entry["generations"].([]any) + urls := make([]string, 0, len(generations)) + for _, gen := range generations { + m, ok := gen.(map[string]any) + if !ok { + continue + } + if url, ok := m["url"].(string); ok && url != "" { + urls = append(urls, url) + } + } + return urls +} + +func buildImageMarkdown(urls []string) string { + parts := make([]string, 0, len(urls)) + for _, u := range urls { + parts = append(parts, fmt.Sprintf("![Generated Image](%s)", u)) + } + return strings.Join(parts, "\n") +} + +func buildVideoMarkdown(url string) string { + return fmt.Sprintf("```html\n\n```", url) +} + +func stringFromMap(m map[string]any, key string) string { + if m == nil { + return "" + } + if v, ok := m[key].(string); ok { + return v + } + return "" +} + +func firstNonEmpty(values ...string) string { + for _, v := range values { + if strings.TrimSpace(v) != "" { + return v + } + } + return "" +} + +func isSoraProAccount(acc *SoraAccount) bool { + if acc == nil { + return false + } + return strings.EqualFold(acc.PlanType, "chatgpt_pro") +} + +func timePtr(t time.Time) *time.Time { + return &t +} + +// fetchCustomWatermarkURL 使用自定义解析服务获取无水印视频 URL +func (s *SoraGatewayService) fetchCustomWatermarkURL(ctx context.Context, parseURL, parseToken, postID string) (string, error) { + // 使用项目的 URL 校验器验证 parseURL 格式,防止 SSRF 攻击 + if _, err := urlvalidator.ValidateHTTPSURL(parseURL, urlvalidator.ValidationOptions{}); 
err != nil { + return "", fmt.Errorf("无效的解析服务地址: %w", err) + } + + payload := map[string]any{ + "url": fmt.Sprintf("https://sora.chatgpt.com/p/%s", postID), + "token": parseToken, + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + req, err := http.NewRequestWithContext(ctx, "POST", strings.TrimRight(parseURL, "/")+"/get-sora-link", strings.NewReader(string(body))) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + + // 复用 httpUpstream,遵守代理和 TLS 配置 + enableTLS := false + if s.cfg != nil { + enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled + } + resp, err := s.httpUpstream.DoWithTLS(req, "", 0, 1, enableTLS) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return "", fmt.Errorf("custom parse failed: %d", resp.StatusCode) + } + var parsed map[string]any + if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil { + return "", err + } + if errMsg, ok := parsed["error"].(string); ok && errMsg != "" { + return "", errors.New(errMsg) + } + if link, ok := parsed["download_link"].(string); ok { + return link, nil + } + return "", errors.New("custom parse response missing download_link") +} + +const ( + soraSlotImageLock int64 = 1 + soraSlotImageLimit int64 = 2 + soraSlotVideoLimit int64 = 3 + soraDefaultUsername = "character" +) + +func (s *SoraGatewayService) CallLogicMode(ctx context.Context) string { + return strings.TrimSpace(s.getSoraConfig(ctx).CallLogicMode) +} + +func (s *SoraGatewayService) getSoraConfig(ctx context.Context) config.SoraConfig { + if s.settingService != nil { + return s.settingService.GetSoraConfig(ctx) + } + if s.cfg != nil { + return s.cfg.Sora + } + return config.SoraConfig{} +} + +func (s *SoraGatewayService) acquireSoraSlots(ctx context.Context, account *Account, soraAcc *SoraAccount, isVideo bool) (func(), error) { + if s.concurrency == nil || account == nil || soraAcc == nil { + return nil, 
nil + } + releases := make([]func(), 0, 2) + appendRelease := func(release func()) { + if release != nil { + releases = append(releases, release) + } + } + // 错误时释放所有已获取的槽位 + releaseAll := func() { + for _, r := range releases { + r() + } + } + + if isVideo { + if soraAcc.VideoConcurrency > 0 { + release, err := s.acquireSoraSlot(ctx, account.ID, soraAcc.VideoConcurrency, soraSlotVideoLimit) + if err != nil { + releaseAll() + return nil, err + } + appendRelease(release) + } + } else { + release, err := s.acquireSoraSlot(ctx, account.ID, 1, soraSlotImageLock) + if err != nil { + releaseAll() + return nil, err + } + appendRelease(release) + if soraAcc.ImageConcurrency > 0 { + release, err := s.acquireSoraSlot(ctx, account.ID, soraAcc.ImageConcurrency, soraSlotImageLimit) + if err != nil { + releaseAll() // 释放已获取的 soraSlotImageLock + return nil, err + } + appendRelease(release) + } + } + + if len(releases) == 0 { + return nil, nil + } + return func() { + for _, release := range releases { + release() + } + }, nil +} + +func (s *SoraGatewayService) acquireSoraSlot(ctx context.Context, accountID int64, maxConcurrency int, slotType int64) (func(), error) { + if s.concurrency == nil || maxConcurrency <= 0 { + return nil, nil + } + derivedID := soraConcurrencyAccountID(accountID, slotType) + result, err := s.concurrency.AcquireAccountSlot(ctx, derivedID, maxConcurrency) + if err != nil { + return nil, err + } + if !result.Acquired { + return nil, ErrSoraAccountNotEligible + } + return result.ReleaseFunc, nil +} + +func soraConcurrencyAccountID(accountID int64, slotType int64) int64 { + if accountID < 0 { + accountID = -accountID + } + return -(accountID*10 + slotType) +} + +func (s *SoraGatewayService) createCharacter(ctx context.Context, client *sora.Client, opts sora.RequestOptions, rawVideo string) (string, string, error) { + videoBytes, err := s.loadVideoBytes(ctx, opts, rawVideo) + if err != nil { + return "", "", err + } + cameoID, err := 
client.UploadCharacterVideo(ctx, opts, videoBytes) + if err != nil { + return "", "", err + } + status, err := s.pollCameoStatus(ctx, client, opts, cameoID) + if err != nil { + return "", "", err + } + username := processCharacterUsername(stringFromMap(status, "username_hint")) + if username == "" { + username = soraDefaultUsername + } + displayName := stringFromMap(status, "display_name_hint") + if displayName == "" { + displayName = "Character" + } + profileURL := stringFromMap(status, "profile_asset_url") + if profileURL == "" { + return "", "", errors.New("profile asset url missing") + } + avatarData, err := client.DownloadCharacterImage(ctx, opts, profileURL) + if err != nil { + return "", "", err + } + assetPointer, err := client.UploadCharacterImage(ctx, opts, avatarData) + if err != nil { + return "", "", err + } + characterID, err := client.FinalizeCharacter(ctx, opts, cameoID, username, displayName, assetPointer) + if err != nil { + return "", "", err + } + if err := client.SetCharacterPublic(ctx, opts, cameoID); err != nil { + return "", "", err + } + return username, characterID, nil +} + +func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, client *sora.Client, opts sora.RequestOptions, cameoID string) (map[string]any, error) { + if cameoID == "" { + return nil, errors.New("cameo id empty") + } + timeout := 600 * time.Second + pollInterval := 5 * time.Second + deadline := time.Now().Add(timeout) + consecutiveErrors := 0 + maxConsecutiveErrors := 3 + + for time.Now().Before(deadline) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + time.Sleep(pollInterval) + status, err := client.GetCameoStatus(ctx, opts, cameoID) + if err != nil { + consecutiveErrors++ + if consecutiveErrors >= maxConsecutiveErrors { + return nil, err + } + continue + } + consecutiveErrors = 0 + statusValue := stringFromMap(status, "status") + statusMessage := stringFromMap(status, "status_message") + if statusValue == "failed" { + if 
statusMessage == "" { + statusMessage = "角色创建失败" + } + return nil, fmt.Errorf("角色创建失败: %s", statusMessage) + } + if strings.EqualFold(statusMessage, "Completed") || strings.EqualFold(statusValue, "finalized") { + return status, nil + } + } + return nil, errors.New("角色创建超时") +} + +func (s *SoraGatewayService) loadVideoBytes(ctx context.Context, opts sora.RequestOptions, rawVideo string) ([]byte, error) { + trimmed := strings.TrimSpace(rawVideo) + if trimmed == "" { + return nil, errors.New("video data is empty") + } + if looksLikeURL(trimmed) { + if err := s.validateMediaURL(trimmed); err != nil { + return nil, err + } + return s.downloadMedia(ctx, opts, trimmed, maxVideoDownloadSize) + } + return decodeBase64(trimmed) +} + +func (s *SoraGatewayService) loadImageBytes(ctx context.Context, opts sora.RequestOptions, rawImage string) ([]byte, error) { + trimmed := strings.TrimSpace(rawImage) + if trimmed == "" { + return nil, errors.New("image data is empty") + } + if looksLikeURL(trimmed) { + if err := s.validateMediaURL(trimmed); err != nil { + return nil, err + } + return s.downloadMedia(ctx, opts, trimmed, maxImageDownloadSize) + } + return decodeBase64(trimmed) +} + +func (s *SoraGatewayService) validateMediaURL(rawURL string) error { + cfg := s.cfg + if cfg == nil { + return nil + } + if cfg.Security.URLAllowlist.Enabled { + _, err := urlvalidator.ValidateHTTPSURL(rawURL, urlvalidator.ValidationOptions{ + AllowedHosts: cfg.Security.URLAllowlist.UpstreamHosts, + RequireAllowlist: true, + AllowPrivate: cfg.Security.URLAllowlist.AllowPrivateHosts, + }) + if err != nil { + return fmt.Errorf("媒体地址不合法: %w", err) + } + return nil + } + if _, err := urlvalidator.ValidateURLFormat(rawURL, cfg.Security.URLAllowlist.AllowInsecureHTTP); err != nil { + return fmt.Errorf("媒体地址不合法: %w", err) + } + return nil +} + +func (s *SoraGatewayService) downloadMedia(ctx context.Context, opts sora.RequestOptions, mediaURL string, maxSize int64) ([]byte, error) { + if s.httpUpstream == nil 
{ + return nil, errors.New("upstream is nil") + } + req, err := http.NewRequestWithContext(ctx, "GET", mediaURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36") + enableTLS := false + if s.cfg != nil { + enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled + } + resp, err := s.httpUpstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, opts.AccountConcurrency, enableTLS) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("下载失败: %d", resp.StatusCode) + } + + // 使用 LimitReader 限制最大读取大小,防止 DoS 攻击 + limitedReader := io.LimitReader(resp.Body, maxSize+1) + data, err := io.ReadAll(limitedReader) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 检查是否超过大小限制 + if int64(len(data)) > maxSize { + return nil, fmt.Errorf("媒体文件过大 (最大 %d 字节, 实际 %d 字节)", maxSize, len(data)) + } + + return data, nil +} + +func processCharacterUsername(usernameHint string) string { + trimmed := strings.TrimSpace(usernameHint) + if trimmed == "" { + return "" + } + base := trimmed + if idx := strings.LastIndex(trimmed, "."); idx != -1 && idx+1 < len(trimmed) { + base = trimmed[idx+1:] + } + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + return fmt.Sprintf("%s%d", base, rng.Intn(900)+100) +} + +func looksLikeURL(value string) bool { + trimmed := strings.ToLower(strings.TrimSpace(value)) + return strings.HasPrefix(trimmed, "http://") || strings.HasPrefix(trimmed, "https://") +} diff --git a/backend/internal/service/sora_repository.go b/backend/internal/service/sora_repository.go new file mode 100644 index 00000000..578260bd --- /dev/null +++ b/backend/internal/service/sora_repository.go @@ -0,0 +1,113 @@ +package service + +import ( + "context" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +// SoraAccount 表示 Sora 
账号扩展信息。 +type SoraAccount struct { + AccountID int64 + AccessToken string + SessionToken string + RefreshToken string + ClientID string + Email string + Username string + Remark string + UseCount int + PlanType string + PlanTitle string + SubscriptionEnd *time.Time + SoraSupported bool + SoraInviteCode string + SoraRedeemedCount int + SoraRemainingCount int + SoraTotalCount int + SoraCooldownUntil *time.Time + CooledUntil *time.Time + ImageEnabled bool + VideoEnabled bool + ImageConcurrency int + VideoConcurrency int + IsExpired bool + CreatedAt time.Time + UpdatedAt time.Time +} + +// SoraUsageStat 表示 Sora 调用统计。 +type SoraUsageStat struct { + AccountID int64 + ImageCount int + VideoCount int + ErrorCount int + LastErrorAt *time.Time + TodayImageCount int + TodayVideoCount int + TodayErrorCount int + TodayDate *time.Time + ConsecutiveErrorCount int + CreatedAt time.Time + UpdatedAt time.Time +} + +// SoraTask 表示 Sora 任务记录。 +type SoraTask struct { + TaskID string + AccountID int64 + Model string + Prompt string + Status string + Progress float64 + ResultURLs string + ErrorMessage string + RetryCount int + CreatedAt time.Time + CompletedAt *time.Time +} + +// SoraCacheFile 表示 Sora 缓存文件记录。 +type SoraCacheFile struct { + ID int64 + TaskID string + AccountID int64 + UserID int64 + MediaType string + OriginalURL string + CachePath string + CacheURL string + SizeBytes int64 + CreatedAt time.Time +} + +// SoraAccountRepository 定义 Sora 账号仓储接口。 +type SoraAccountRepository interface { + GetByAccountID(ctx context.Context, accountID int64) (*SoraAccount, error) + GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*SoraAccount, error) + Upsert(ctx context.Context, accountID int64, updates map[string]any) error +} + +// SoraUsageStatRepository 定义 Sora 调用统计仓储接口。 +type SoraUsageStatRepository interface { + RecordSuccess(ctx context.Context, accountID int64, isVideo bool) error + RecordError(ctx context.Context, accountID int64) (int, error) + 
ResetConsecutiveErrors(ctx context.Context, accountID int64) error + GetByAccountID(ctx context.Context, accountID int64) (*SoraUsageStat, error) + GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*SoraUsageStat, error) + List(ctx context.Context, params pagination.PaginationParams) ([]*SoraUsageStat, *pagination.PaginationResult, error) +} + +// SoraTaskRepository 定义 Sora 任务仓储接口。 +type SoraTaskRepository interface { + Create(ctx context.Context, task *SoraTask) error + UpdateStatus(ctx context.Context, taskID string, status string, progress float64, resultURLs string, errorMessage string, completedAt *time.Time) error +} + +// SoraCacheFileRepository 定义 Sora 缓存文件仓储接口。 +type SoraCacheFileRepository interface { + Create(ctx context.Context, file *SoraCacheFile) error + ListOldest(ctx context.Context, limit int) ([]*SoraCacheFile, error) + DeleteByIDs(ctx context.Context, ids []int64) error +} diff --git a/backend/internal/service/sora_token_refresh_service.go b/backend/internal/service/sora_token_refresh_service.go new file mode 100644 index 00000000..0caf40e5 --- /dev/null +++ b/backend/internal/service/sora_token_refresh_service.go @@ -0,0 +1,313 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net/http" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +const defaultSoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK" + +// SoraTokenRefreshService handles Sora access token refresh. 
+type SoraTokenRefreshService struct { + accountRepo AccountRepository + soraAccountRepo SoraAccountRepository + settingService *SettingService + httpUpstream HTTPUpstream + cfg *config.Config + stopCh chan struct{} + stopOnce sync.Once +} + +func NewSoraTokenRefreshService( + accountRepo AccountRepository, + soraAccountRepo SoraAccountRepository, + settingService *SettingService, + httpUpstream HTTPUpstream, + cfg *config.Config, +) *SoraTokenRefreshService { + return &SoraTokenRefreshService{ + accountRepo: accountRepo, + soraAccountRepo: soraAccountRepo, + settingService: settingService, + httpUpstream: httpUpstream, + cfg: cfg, + stopCh: make(chan struct{}), + } +} + +func (s *SoraTokenRefreshService) Start() { + if s == nil { + return + } + go s.refreshLoop() +} + +func (s *SoraTokenRefreshService) Stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + close(s.stopCh) + }) +} + +func (s *SoraTokenRefreshService) refreshLoop() { + for { + wait := s.nextRunDelay() + timer := time.NewTimer(wait) + select { + case <-timer.C: + s.refreshOnce() + case <-s.stopCh: + timer.Stop() + return + } + } +} + +func (s *SoraTokenRefreshService) refreshOnce() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + defer cancel() + + if !s.isEnabled(ctx) { + log.Println("[SoraTokenRefresh] disabled by settings") + return + } + if s.accountRepo == nil || s.soraAccountRepo == nil { + log.Println("[SoraTokenRefresh] repository not configured") + return + } + + accounts, err := s.accountRepo.ListByPlatform(ctx, PlatformSora) + if err != nil { + log.Printf("[SoraTokenRefresh] list accounts failed: %v", err) + return + } + if len(accounts) == 0 { + log.Println("[SoraTokenRefresh] no sora accounts") + return + } + ids := make([]int64, 0, len(accounts)) + accountMap := make(map[int64]*Account, len(accounts)) + for i := range accounts { + acc := accounts[i] + ids = append(ids, acc.ID) + accountMap[acc.ID] = &acc + } + accountExtras, err := 
s.soraAccountRepo.GetByAccountIDs(ctx, ids) + if err != nil { + log.Printf("[SoraTokenRefresh] load sora accounts failed: %v", err) + return + } + + success := 0 + failed := 0 + skipped := 0 + for accountID, account := range accountMap { + extra := accountExtras[accountID] + if extra == nil { + skipped++ + continue + } + result, err := s.refreshForAccount(ctx, account, extra) + if err != nil { + failed++ + log.Printf("[SoraTokenRefresh] account %d refresh failed: %v", accountID, err) + continue + } + if result == nil { + skipped++ + continue + } + + updates := map[string]any{ + "access_token": result.AccessToken, + } + if result.RefreshToken != "" { + updates["refresh_token"] = result.RefreshToken + } + if result.Email != "" { + updates["email"] = result.Email + } + if err := s.soraAccountRepo.Upsert(ctx, accountID, updates); err != nil { + failed++ + log.Printf("[SoraTokenRefresh] account %d update failed: %v", accountID, err) + continue + } + success++ + } + log.Printf("[SoraTokenRefresh] done: success=%d failed=%d skipped=%d", success, failed, skipped) +} + +func (s *SoraTokenRefreshService) refreshForAccount(ctx context.Context, account *Account, extra *SoraAccount) (*soraRefreshResult, error) { + if extra == nil { + return nil, nil + } + if strings.TrimSpace(extra.SessionToken) == "" && strings.TrimSpace(extra.RefreshToken) == "" { + return nil, nil + } + + if extra.SessionToken != "" { + result, err := s.refreshWithSessionToken(ctx, account, extra.SessionToken) + if err == nil && result != nil && result.AccessToken != "" { + return result, nil + } + if strings.TrimSpace(extra.RefreshToken) == "" { + return nil, err + } + } + + clientID := strings.TrimSpace(extra.ClientID) + if clientID == "" { + clientID = defaultSoraClientID + } + return s.refreshWithRefreshToken(ctx, account, extra.RefreshToken, clientID) +} + +type soraRefreshResult struct { + AccessToken string + RefreshToken string + Email string +} + +type soraSessionResponse struct { + AccessToken 
string `json:"accessToken"` + User struct { + Email string `json:"email"` + } `json:"user"` +} + +type soraRefreshResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` +} + +func (s *SoraTokenRefreshService) refreshWithSessionToken(ctx context.Context, account *Account, sessionToken string) (*soraRefreshResult, error) { + if s.httpUpstream == nil { + return nil, fmt.Errorf("upstream not configured") + } + req, err := http.NewRequestWithContext(ctx, "GET", "https://sora.chatgpt.com/api/auth/session", nil) + if err != nil { + return nil, err + } + req.Header.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken) + req.Header.Set("Accept", "application/json") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36") + + enableTLS := false + if s.cfg != nil { + enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled + } + proxyURL := "" + accountConcurrency := 0 + accountID := int64(0) + if account != nil { + accountID = account.ID + accountConcurrency = account.Concurrency + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + } + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, accountID, accountConcurrency, enableTLS) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("session refresh failed: %d", resp.StatusCode) + } + var payload soraSessionResponse + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return nil, err + } + if payload.AccessToken == "" { + return nil, errors.New("session refresh missing access token") + } + return &soraRefreshResult{AccessToken: payload.AccessToken, Email: payload.User.Email}, nil +} + +func (s *SoraTokenRefreshService) refreshWithRefreshToken(ctx 
context.Context, account *Account, refreshToken, clientID string) (*soraRefreshResult, error) { + if s.httpUpstream == nil { + return nil, fmt.Errorf("upstream not configured") + } + payload := map[string]any{ + "client_id": clientID, + "grant_type": "refresh_token", + "redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", + "refresh_token": refreshToken, + } + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, "POST", "https://auth.openai.com/oauth/token", bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36") + + enableTLS := false + if s.cfg != nil { + enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled + } + proxyURL := "" + accountConcurrency := 0 + accountID := int64(0) + if account != nil { + accountID = account.ID + accountConcurrency = account.Concurrency + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + } + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, accountID, accountConcurrency, enableTLS) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("refresh token failed: %d", resp.StatusCode) + } + var payloadResp soraRefreshResponse + if err := json.NewDecoder(resp.Body).Decode(&payloadResp); err != nil { + return nil, err + } + if payloadResp.AccessToken == "" { + return nil, errors.New("refresh token missing access token") + } + return &soraRefreshResult{AccessToken: payloadResp.AccessToken, RefreshToken: payloadResp.RefreshToken}, nil +} + +func (s *SoraTokenRefreshService) nextRunDelay() time.Duration { + location := time.Local + if s.cfg != nil && strings.TrimSpace(s.cfg.Timezone) != "" { + if 
tz, err := time.LoadLocation(strings.TrimSpace(s.cfg.Timezone)); err == nil { + location = tz + } + } + now := time.Now().In(location) + next := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, location).Add(24 * time.Hour) + return time.Until(next) +} + +func (s *SoraTokenRefreshService) isEnabled(ctx context.Context) bool { + if s.settingService == nil { + return s.cfg != nil && s.cfg.Sora.TokenRefresh.Enabled + } + cfg := s.settingService.GetSoraConfig(ctx) + return cfg.TokenRefresh.Enabled +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index b210286d..f68ed6ba 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -51,6 +51,30 @@ func ProvideTokenRefreshService( return svc } +// ProvideSoraTokenRefreshService creates and starts SoraTokenRefreshService. +func ProvideSoraTokenRefreshService( + accountRepo AccountRepository, + soraAccountRepo SoraAccountRepository, + settingService *SettingService, + httpUpstream HTTPUpstream, + cfg *config.Config, +) *SoraTokenRefreshService { + svc := NewSoraTokenRefreshService(accountRepo, soraAccountRepo, settingService, httpUpstream, cfg) + svc.Start() + return svc +} + +// ProvideSoraCacheCleanupService creates and starts SoraCacheCleanupService. 
+func ProvideSoraCacheCleanupService( + cacheRepo SoraCacheFileRepository, + settingService *SettingService, + cfg *config.Config, +) *SoraCacheCleanupService { + svc := NewSoraCacheCleanupService(cacheRepo, settingService, cfg) + svc.Start() + return svc +} + // ProvideDashboardAggregationService 创建并启动仪表盘聚合服务 func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService { svc := NewDashboardAggregationService(repo, timingWheel, cfg) @@ -222,6 +246,8 @@ var ProviderSet = wire.NewSet( NewAdminService, NewGatewayService, NewOpenAIGatewayService, + NewSoraCacheService, + NewSoraGatewayService, NewOAuthService, NewOpenAIOAuthService, NewGeminiOAuthService, @@ -255,6 +281,8 @@ var ProviderSet = wire.NewSet( NewCRSSyncService, ProvideUpdateService, ProvideTokenRefreshService, + ProvideSoraTokenRefreshService, + ProvideSoraCacheCleanupService, ProvideAccountExpiryService, ProvideTimingWheelService, ProvideDashboardAggregationService, diff --git a/backend/migrations/044_add_sora_tables.sql b/backend/migrations/044_add_sora_tables.sql new file mode 100644 index 00000000..77e24a51 --- /dev/null +++ b/backend/migrations/044_add_sora_tables.sql @@ -0,0 +1,94 @@ +-- Add Sora platform tables + +CREATE TABLE IF NOT EXISTS sora_accounts ( + id BIGSERIAL PRIMARY KEY, + account_id BIGINT NOT NULL UNIQUE, + access_token TEXT, + session_token TEXT, + refresh_token TEXT, + client_id TEXT, + email TEXT, + username TEXT, + remark TEXT, + use_count INT DEFAULT 0, + plan_type TEXT, + plan_title TEXT, + subscription_end TIMESTAMPTZ, + sora_supported BOOLEAN DEFAULT FALSE, + sora_invite_code TEXT, + sora_redeemed_count INT DEFAULT 0, + sora_remaining_count INT DEFAULT 0, + sora_total_count INT DEFAULT 0, + sora_cooldown_until TIMESTAMPTZ, + cooled_until TIMESTAMPTZ, + image_enabled BOOLEAN DEFAULT TRUE, + video_enabled BOOLEAN DEFAULT TRUE, + image_concurrency INT DEFAULT -1, + video_concurrency 
INT DEFAULT -1, + is_expired BOOLEAN DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + FOREIGN KEY (account_id) REFERENCES accounts(id) +); + +CREATE INDEX IF NOT EXISTS idx_sora_accounts_plan_type ON sora_accounts (plan_type); +CREATE INDEX IF NOT EXISTS idx_sora_accounts_sora_supported ON sora_accounts (sora_supported); +CREATE INDEX IF NOT EXISTS idx_sora_accounts_image_enabled ON sora_accounts (image_enabled); +CREATE INDEX IF NOT EXISTS idx_sora_accounts_video_enabled ON sora_accounts (video_enabled); + +CREATE TABLE IF NOT EXISTS sora_usage_stats ( + id BIGSERIAL PRIMARY KEY, + account_id BIGINT NOT NULL UNIQUE, + image_count INT DEFAULT 0, + video_count INT DEFAULT 0, + error_count INT DEFAULT 0, + last_error_at TIMESTAMPTZ, + today_image_count INT DEFAULT 0, + today_video_count INT DEFAULT 0, + today_error_count INT DEFAULT 0, + today_date DATE, + consecutive_error_count INT DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + FOREIGN KEY (account_id) REFERENCES accounts(id) +); + +CREATE INDEX IF NOT EXISTS idx_sora_usage_stats_today_date ON sora_usage_stats (today_date); + +CREATE TABLE IF NOT EXISTS sora_tasks ( + id BIGSERIAL PRIMARY KEY, + task_id TEXT NOT NULL UNIQUE, + account_id BIGINT NOT NULL, + model TEXT NOT NULL, + prompt TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'processing', + progress DOUBLE PRECISION DEFAULT 0, + result_urls TEXT, + error_message TEXT, + retry_count INT DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + completed_at TIMESTAMPTZ, + FOREIGN KEY (account_id) REFERENCES accounts(id) +); + +CREATE INDEX IF NOT EXISTS idx_sora_tasks_account_id ON sora_tasks (account_id); +CREATE INDEX IF NOT EXISTS idx_sora_tasks_status ON sora_tasks (status); + +CREATE TABLE IF NOT EXISTS sora_cache_files ( + id BIGSERIAL PRIMARY KEY, + task_id TEXT, + account_id BIGINT NOT NULL, + user_id BIGINT NOT NULL, + 
media_type TEXT NOT NULL, + original_url TEXT NOT NULL, + cache_path TEXT NOT NULL, + cache_url TEXT NOT NULL, + size_bytes BIGINT DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + FOREIGN KEY (account_id) REFERENCES accounts(id), + FOREIGN KEY (user_id) REFERENCES users(id) +); + +CREATE INDEX IF NOT EXISTS idx_sora_cache_files_account_id ON sora_cache_files (account_id); +CREATE INDEX IF NOT EXISTS idx_sora_cache_files_user_id ON sora_cache_files (user_id); +CREATE INDEX IF NOT EXISTS idx_sora_cache_files_media_type ON sora_cache_files (media_type); diff --git a/config.yaml b/config.yaml index 5e7513fb..bd171417 100644 --- a/config.yaml +++ b/config.yaml @@ -525,3 +525,63 @@ gemini: # Cooldown time (minutes) after hitting quota # 达到配额后的冷却时间(分钟) cooldown_minutes: 5 + +# ============================================================================= +# Sora +# Sora 配置 +# ============================================================================= +sora: + # Sora Backend API base URL + # Sora 后端 API 基础地址 + base_url: "https://sora.chatgpt.com/backend" + # Request timeout in seconds + # 请求超时时间(秒) + timeout: 120 + # Max retry attempts for upstream requests + # 上游请求最大重试次数 + max_retries: 3 + # Poll interval in seconds for task status + # 任务状态轮询间隔(秒) + poll_interval: 2.5 + # Call logic mode: default/native/proxy (default keeps current behavior) + # 调用模式:default/native/proxy(default 保持当前默认策略) + call_logic_mode: "default" + cache: + # Enable media caching + # 是否启用媒体缓存 + enabled: false + # Base cache directory (temporary files, intermediate downloads) + # 缓存根目录(临时文件、中间下载) + base_dir: "tmp/sora" + # Video cache directory (separated from images) + # 视频缓存目录(与图片分离) + video_dir: "data/video" + # Max bytes for cache dir (0 = unlimited) + # 缓存目录最大字节数(0 = 不限制) + max_bytes: 0 + # Allowed hosts for cache download (empty -> fallback to global allowlist) + # 缓存下载白名单域名(为空则回退全局 allowlist) + allowed_hosts: [] + # Enable user directory isolation (data/video/u_{user_id}) + # 
是否按用户隔离目录(data/video/u_{user_id}) + user_dir_enabled: true + watermark_free: + # Enable watermark-free flow + # 是否启用去水印流程 + enabled: false + # Parse method: third_party/custom + # 解析方式:third_party/custom + parse_method: "third_party" + # Custom parse server URL + # 自定义解析服务 URL + custom_parse_url: "" + # Custom parse token + # 自定义解析 token + custom_parse_token: "" + # Fallback to watermark video when parse fails + # 去水印失败时是否回退原视频 + fallback_on_failure: true + token_refresh: + # Enable periodic token refresh + # 是否启用定时刷新 + enabled: false diff --git a/deploy/.env.example b/deploy/.env.example index f21a3c62..668c2b4f 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -194,6 +194,28 @@ GEMINI_OAUTH_SCOPES= # GEMINI_QUOTA_POLICY={"tiers":{"LEGACY":{"pro_rpd":50,"flash_rpd":1500,"cooldown_minutes":30},"PRO":{"pro_rpd":1500,"flash_rpd":4000,"cooldown_minutes":5},"ULTRA":{"pro_rpd":2000,"flash_rpd":0,"cooldown_minutes":5}}} GEMINI_QUOTA_POLICY= +# ----------------------------------------------------------------------------- +# Sora Configuration (OPTIONAL) +# ----------------------------------------------------------------------------- +SORA_BASE_URL=https://sora.chatgpt.com/backend +SORA_TIMEOUT=120 +SORA_MAX_RETRIES=3 +SORA_POLL_INTERVAL=2.5 +SORA_CALL_LOGIC_MODE=default +SORA_CACHE_ENABLED=false +SORA_CACHE_BASE_DIR=tmp/sora +SORA_CACHE_VIDEO_DIR=data/video +SORA_CACHE_MAX_BYTES=0 +# Comma-separated hosts (leave empty to use global allowlist) +SORA_CACHE_ALLOWED_HOSTS= +SORA_CACHE_USER_DIR_ENABLED=true +SORA_WATERMARK_FREE_ENABLED=false +SORA_WATERMARK_FREE_PARSE_METHOD=third_party +SORA_WATERMARK_FREE_CUSTOM_PARSE_URL= +SORA_WATERMARK_FREE_CUSTOM_PARSE_TOKEN= +SORA_WATERMARK_FREE_FALLBACK_ON_FAILURE=true +SORA_TOKEN_REFRESH_ENABLED=false + # ----------------------------------------------------------------------------- # Ops Monitoring Configuration (运维监控配置) # ----------------------------------------------------------------------------- diff --git 
a/deploy/config.example.yaml b/deploy/config.example.yaml index 558b8ef0..f3a001e7 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -583,6 +583,66 @@ gemini: # 达到配额后的冷却时间(分钟) cooldown_minutes: 5 +# ============================================================================= +# Sora +# Sora 配置 +# ============================================================================= +sora: + # Sora Backend API base URL + # Sora 后端 API 基础地址 + base_url: "https://sora.chatgpt.com/backend" + # Request timeout in seconds + # 请求超时时间(秒) + timeout: 120 + # Max retry attempts for upstream requests + # 上游请求最大重试次数 + max_retries: 3 + # Poll interval in seconds for task status + # 任务状态轮询间隔(秒) + poll_interval: 2.5 + # Call logic mode: default/native/proxy (default keeps current behavior) + # 调用模式:default/native/proxy(default 保持当前默认策略) + call_logic_mode: "default" + cache: + # Enable media caching + # 是否启用媒体缓存 + enabled: false + # Base cache directory (temporary files, intermediate downloads) + # 缓存根目录(临时文件、中间下载) + base_dir: "tmp/sora" + # Video cache directory (separated from images) + # 视频缓存目录(与图片分离) + video_dir: "data/video" + # Max bytes for cache dir (0 = unlimited) + # 缓存目录最大字节数(0 = 不限制) + max_bytes: 0 + # Allowed hosts for cache download (empty -> fallback to global allowlist) + # 缓存下载白名单域名(为空则回退全局 allowlist) + allowed_hosts: [] + # Enable user directory isolation (data/video/u_{user_id}) + # 是否按用户隔离目录(data/video/u_{user_id}) + user_dir_enabled: true + watermark_free: + # Enable watermark-free flow + # 是否启用去水印流程 + enabled: false + # Parse method: third_party/custom + # 解析方式:third_party/custom + parse_method: "third_party" + # Custom parse server URL + # 自定义解析服务 URL + custom_parse_url: "" + # Custom parse token + # 自定义解析 token + custom_parse_token: "" + # Fallback to watermark video when parse fails + # 去水印失败时是否回退原视频 + fallback_on_failure: true + token_refresh: + # Enable periodic token refresh + # 是否启用定时刷新 + enabled: false + # 
============================================================================= # Update Configuration (在线更新配置) # ============================================================================= diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 6e2ade00..3e127135 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -55,6 +55,25 @@ export interface SystemSettings { enable_identity_patch: boolean identity_patch_prompt: string + // Sora configuration + sora_base_url: string + sora_timeout: number + sora_max_retries: number + sora_poll_interval: number + sora_call_logic_mode: string + sora_cache_enabled: boolean + sora_cache_base_dir: string + sora_cache_video_dir: string + sora_cache_max_bytes: number + sora_cache_allowed_hosts: string[] + sora_cache_user_dir_enabled: boolean + sora_watermark_free_enabled: boolean + sora_watermark_free_parse_method: string + sora_watermark_free_custom_parse_url: string + sora_watermark_free_custom_parse_token: string + sora_watermark_free_fallback_on_failure: boolean + sora_token_refresh_enabled: boolean + // Ops Monitoring (vNext) ops_monitoring_enabled: boolean ops_realtime_monitoring_enabled: boolean @@ -97,6 +116,23 @@ export interface UpdateSettingsRequest { fallback_model_antigravity?: string enable_identity_patch?: boolean identity_patch_prompt?: string + sora_base_url?: string + sora_timeout?: number + sora_max_retries?: number + sora_poll_interval?: number + sora_call_logic_mode?: string + sora_cache_enabled?: boolean + sora_cache_base_dir?: string + sora_cache_video_dir?: string + sora_cache_max_bytes?: number + sora_cache_allowed_hosts?: string[] + sora_cache_user_dir_enabled?: boolean + sora_watermark_free_enabled?: boolean + sora_watermark_free_parse_method?: string + sora_watermark_free_custom_parse_url?: string + sora_watermark_free_custom_parse_token?: string + sora_watermark_free_fallback_on_failure?: boolean + sora_token_refresh_enabled?: 
boolean ops_monitoring_enabled?: boolean ops_realtime_monitoring_enabled?: boolean ops_query_mode_default?: 'auto' | 'raw' | 'preagg' | string diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 144241ff..9a73764a 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -147,6 +147,19 @@ Antigravity + @@ -672,6 +685,8 @@ ? 'https://api.openai.com' : form.platform === 'gemini' ? 'https://generativelanguage.googleapis.com' + : form.platform === 'sora' + ? 'https://sora.chatgpt.com/backend' : 'https://api.anthropic.com' " /> @@ -689,6 +704,8 @@ ? 'sk-proj-...' : form.platform === 'gemini' ? 'AIza...' + : form.platform === 'sora' + ? 'access-token...' : 'sk-ant-...' " /> @@ -1850,12 +1867,14 @@ const oauthStepTitle = computed(() => { const baseUrlHint = computed(() => { if (form.platform === 'openai') return t('admin.accounts.openai.baseUrlHint') if (form.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint') + if (form.platform === 'sora') return t('admin.accounts.sora.baseUrlHint') return t('admin.accounts.baseUrlHint') }) const apiKeyHint = computed(() => { if (form.platform === 'openai') return t('admin.accounts.openai.apiKeyHint') if (form.platform === 'gemini') return t('admin.accounts.gemini.apiKeyHint') + if (form.platform === 'sora') return t('admin.accounts.sora.apiKeyHint') return t('admin.accounts.apiKeyHint') }) @@ -2100,7 +2119,9 @@ watch( ? 'https://api.openai.com' : newPlatform === 'gemini' ? 'https://generativelanguage.googleapis.com' - : 'https://api.anthropic.com' + : newPlatform === 'sora' + ? 
'https://sora.chatgpt.com/backend' + : 'https://api.anthropic.com' // Clear model-related settings allowedModels.value = [] modelMappings.value = [] @@ -2112,6 +2133,9 @@ watch( if (newPlatform === 'antigravity') { accountCategory.value = 'oauth-based' } + if (newPlatform === 'sora') { + accountCategory.value = 'apikey' + } // Reset OAuth states oauth.resetState() openaiOAuth.resetState() @@ -2383,12 +2407,17 @@ const handleSubmit = async () => { ? 'https://api.openai.com' : form.platform === 'gemini' ? 'https://generativelanguage.googleapis.com' - : 'https://api.anthropic.com' + : form.platform === 'sora' + ? 'https://sora.chatgpt.com/backend' + : 'https://api.anthropic.com' // Build credentials with optional model mapping - const credentials: Record = { - base_url: apiKeyBaseUrl.value.trim() || defaultBaseUrl, - api_key: apiKeyValue.value.trim() + const credentials: Record = {} + if (form.platform === 'sora') { + credentials.access_token = apiKeyValue.value.trim() + } else { + credentials.base_url = apiKeyBaseUrl.value.trim() || defaultBaseUrl + credentials.api_key = apiKeyValue.value.trim() } if (form.platform === 'gemini') { credentials.tier_id = geminiTierAIStudio.value diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 0dd855ef..f86694e3 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -39,6 +39,8 @@ ? 'https://api.openai.com' : account.platform === 'gemini' ? 'https://generativelanguage.googleapis.com' + : account.platform === 'sora' + ? 'https://sora.chatgpt.com/backend' : 'https://api.anthropic.com' " /> @@ -55,6 +57,8 @@ ? 'sk-proj-...' : account.platform === 'gemini' ? 'AIza...' + : account.platform === 'sora' + ? 'access-token...' : 'sk-ant-...' 
" /> @@ -919,6 +923,7 @@ const baseUrlHint = computed(() => { if (!props.account) return t('admin.accounts.baseUrlHint') if (props.account.platform === 'openai') return t('admin.accounts.openai.baseUrlHint') if (props.account.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint') + if (props.account.platform === 'sora') return t('admin.accounts.sora.baseUrlHint') return t('admin.accounts.baseUrlHint') }) @@ -997,6 +1002,7 @@ const tempUnschedPresets = computed(() => [ const defaultBaseUrl = computed(() => { if (props.account?.platform === 'openai') return 'https://api.openai.com' if (props.account?.platform === 'gemini') return 'https://generativelanguage.googleapis.com' + if (props.account?.platform === 'sora') return 'https://sora.chatgpt.com/backend' return 'https://api.anthropic.com' }) @@ -1061,7 +1067,9 @@ watch( ? 'https://api.openai.com' : newAccount.platform === 'gemini' ? 'https://generativelanguage.googleapis.com' - : 'https://api.anthropic.com' + : newAccount.platform === 'sora' + ? 'https://sora.chatgpt.com/backend' + : 'https://api.anthropic.com' editBaseUrl.value = (credentials.base_url as string) || platformDefaultUrl // Load model mappings and detect mode @@ -1104,7 +1112,9 @@ watch( ? 'https://api.openai.com' : newAccount.platform === 'gemini' ? 'https://generativelanguage.googleapis.com' - : 'https://api.anthropic.com' + : newAccount.platform === 'sora' + ? 
'https://sora.chatgpt.com/backend' + : 'https://api.anthropic.com' editBaseUrl.value = platformDefaultUrl modelRestrictionMode.value = 'whitelist' modelMappings.value = [] @@ -1381,17 +1391,32 @@ const handleSubmit = async () => { if (props.account.type === 'apikey') { const currentCredentials = (props.account.credentials as Record) || {} const newBaseUrl = editBaseUrl.value.trim() || defaultBaseUrl.value + const isSora = props.account.platform === 'sora' const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) // Always update credentials for apikey type to handle model mapping changes - const newCredentials: Record = { - base_url: newBaseUrl + const newCredentials: Record = {} + if (!isSora) { + newCredentials.base_url = newBaseUrl } // Handle API key if (editApiKey.value.trim()) { // User provided a new API key - newCredentials.api_key = editApiKey.value.trim() + if (isSora) { + newCredentials.access_token = editApiKey.value.trim() + } else { + newCredentials.api_key = editApiKey.value.trim() + } + } else if (isSora) { + const existingToken = (currentCredentials.access_token || currentCredentials.token) as string | undefined + if (existingToken) { + newCredentials.access_token = existingToken + } else { + appStore.showError(t('admin.accounts.apiKeyIsRequired')) + submitting.value = false + return + } } else if (currentCredentials.api_key) { // Preserve existing api_key newCredentials.api_key = currentCredentials.api_key diff --git a/frontend/src/components/account/OAuthAuthorizationFlow.vue b/frontend/src/components/account/OAuthAuthorizationFlow.vue index 194237fa..ca904c00 100644 --- a/frontend/src/components/account/OAuthAuthorizationFlow.vue +++ b/frontend/src/components/account/OAuthAuthorizationFlow.vue @@ -428,7 +428,7 @@ interface Props { allowMultiple?: boolean methodLabel?: string showCookieOption?: boolean // Whether to show cookie auto-auth option - platform?: 'anthropic' | 'openai' | 'gemini' 
| 'antigravity' // Platform type for different UI/text + platform?: 'anthropic' | 'openai' | 'gemini' | 'antigravity' | 'sora' // Platform type for different UI/text showProjectId?: boolean // New prop to control project ID visibility } diff --git a/frontend/src/components/admin/account/AccountTableFilters.vue b/frontend/src/components/admin/account/AccountTableFilters.vue index 47ceedd7..16d34453 100644 --- a/frontend/src/components/admin/account/AccountTableFilters.vue +++ b/frontend/src/components/admin/account/AccountTableFilters.vue @@ -19,7 +19,7 @@ const props = defineProps(['searchQuery', 'filters']); const emit = defineEmits( const updatePlatform = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, platform: value }) } const updateType = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, type: value }) } const updateStatus = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, status: value }) } -const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }]) +const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'sora', label: 'Sora' }, { value: 'antigravity', label: 'Antigravity' }]) const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }]) const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 
'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }]) diff --git a/frontend/src/components/common/GroupBadge.vue b/frontend/src/components/common/GroupBadge.vue index 239d0452..4dd64c71 100644 --- a/frontend/src/components/common/GroupBadge.vue +++ b/frontend/src/components/common/GroupBadge.vue @@ -97,6 +97,9 @@ const labelClass = computed(() => { if (props.platform === 'gemini') { return `${base} bg-blue-200/60 text-blue-800 dark:bg-blue-800/40 dark:text-blue-300` } + if (props.platform === 'sora') { + return `${base} bg-rose-200/60 text-rose-800 dark:bg-rose-800/40 dark:text-rose-300` + } return `${base} bg-violet-200/60 text-violet-800 dark:bg-violet-800/40 dark:text-violet-300` }) @@ -118,6 +121,11 @@ const badgeClass = computed(() => { ? 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400' : 'bg-sky-50 text-sky-700 dark:bg-sky-900/20 dark:text-sky-400' } + if (props.platform === 'sora') { + return isSubscription.value + ? 'bg-rose-100 text-rose-700 dark:bg-rose-900/30 dark:text-rose-400' + : 'bg-rose-50 text-rose-700 dark:bg-rose-900/20 dark:text-rose-400' + } // Fallback: original colors return isSubscription.value ? 
'bg-violet-100 text-violet-700 dark:bg-violet-900/30 dark:text-violet-400' diff --git a/frontend/src/components/common/PlatformIcon.vue b/frontend/src/components/common/PlatformIcon.vue index 1e137ae5..32bf11ef 100644 --- a/frontend/src/components/common/PlatformIcon.vue +++ b/frontend/src/components/common/PlatformIcon.vue @@ -19,6 +19,10 @@ + + + + { if (props.platform === 'anthropic') return 'Anthropic' if (props.platform === 'openai') return 'OpenAI' if (props.platform === 'antigravity') return 'Antigravity' + if (props.platform === 'sora') return 'Sora' return 'Gemini' }) @@ -74,6 +75,9 @@ const platformClass = computed(() => { if (props.platform === 'antigravity') { return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400' } + if (props.platform === 'sora') { + return 'bg-rose-100 text-rose-700 dark:bg-rose-900/30 dark:text-rose-400' + } return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400' }) @@ -87,6 +91,9 @@ const typeClass = computed(() => { if (props.platform === 'antigravity') { return 'bg-purple-100 text-purple-600 dark:bg-purple-900/30 dark:text-purple-400' } + if (props.platform === 'sora') { + return 'bg-rose-100 text-rose-600 dark:bg-rose-900/30 dark:text-rose-400' + } return 'bg-blue-100 text-blue-600 dark:bg-blue-900/30 dark:text-blue-400' }) diff --git a/frontend/src/components/keys/UseKeyModal.vue b/frontend/src/components/keys/UseKeyModal.vue index 7f9bd1ed..99d09d1b 100644 --- a/frontend/src/components/keys/UseKeyModal.vue +++ b/frontend/src/components/keys/UseKeyModal.vue @@ -180,6 +180,8 @@ const defaultClientTab = computed(() => { switch (props.platform) { case 'openai': return 'codex' + case 'sora': + return 'codex' case 'gemini': return 'gemini' case 'antigravity': @@ -266,6 +268,7 @@ const clientTabs = computed((): TabConfig[] => { if (!props.platform) return [] switch (props.platform) { case 'openai': + case 'sora': return [ { id: 'codex', label: t('keys.useKeyModal.cliTabs.codexCli'), icon: 
TerminalIcon }, { id: 'opencode', label: t('keys.useKeyModal.cliTabs.opencode'), icon: TerminalIcon } @@ -306,7 +309,7 @@ const showShellTabs = computed(() => activeClientTab.value !== 'opencode') const currentTabs = computed(() => { if (!showShellTabs.value) return [] - if (props.platform === 'openai') { + if (props.platform === 'openai' || props.platform === 'sora') { return openaiTabs } return shellTabs @@ -315,6 +318,7 @@ const currentTabs = computed(() => { const platformDescription = computed(() => { switch (props.platform) { case 'openai': + case 'sora': return t('keys.useKeyModal.openai.description') case 'gemini': return t('keys.useKeyModal.gemini.description') @@ -328,6 +332,7 @@ const platformDescription = computed(() => { const platformNote = computed(() => { switch (props.platform) { case 'openai': + case 'sora': return activeTab.value === 'windows' ? t('keys.useKeyModal.openai.noteWindows') : t('keys.useKeyModal.openai.note') @@ -386,6 +391,7 @@ const currentFiles = computed((): FileConfig[] => { case 'anthropic': return [generateOpenCodeConfig('anthropic', apiBase, apiKey)] case 'openai': + case 'sora': return [generateOpenCodeConfig('openai', apiBase, apiKey)] case 'gemini': return [generateOpenCodeConfig('gemini', geminiBase, apiKey)] @@ -401,6 +407,7 @@ const currentFiles = computed((): FileConfig[] => { switch (props.platform) { case 'openai': + case 'sora': return generateOpenAIFiles(baseUrl, apiKey) case 'gemini': return [generateGeminiCliContent(baseUrl, apiKey)] diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts index d4fa2993..d3e8fb1f 100644 --- a/frontend/src/composables/useModelWhitelist.ts +++ b/frontend/src/composables/useModelWhitelist.ts @@ -52,6 +52,38 @@ const geminiModels = [ 'gemini-3-pro-preview' ] +// OpenAI Sora +const soraModels = [ + 'gpt-image', + 'gpt-image-landscape', + 'gpt-image-portrait', + 'sora2-landscape-10s', + 'sora2-portrait-10s', + 'sora2-landscape-15s', + 
'sora2-portrait-15s', + 'sora2-landscape-25s', + 'sora2-portrait-25s', + 'sora2pro-landscape-10s', + 'sora2pro-portrait-10s', + 'sora2pro-landscape-15s', + 'sora2pro-portrait-15s', + 'sora2pro-landscape-25s', + 'sora2pro-portrait-25s', + 'sora2pro-hd-landscape-10s', + 'sora2pro-hd-portrait-10s', + 'sora2pro-hd-landscape-15s', + 'sora2pro-hd-portrait-15s', + 'prompt-enhance-short-10s', + 'prompt-enhance-short-15s', + 'prompt-enhance-short-20s', + 'prompt-enhance-medium-10s', + 'prompt-enhance-medium-15s', + 'prompt-enhance-medium-20s', + 'prompt-enhance-long-10s', + 'prompt-enhance-long-15s', + 'prompt-enhance-long-20s' +] + // 智谱 GLM const zhipuModels = [ 'glm-4', 'glm-4v', 'glm-4-plus', 'glm-4-0520', @@ -182,6 +214,7 @@ const allModelsList: string[] = [ ...openaiModels, ...claudeModels, ...geminiModels, + ...soraModels, ...zhipuModels, ...qwenModels, ...deepseekModels, @@ -258,6 +291,7 @@ export function getModelsByPlatform(platform: string): string[] { case 'anthropic': case 'claude': return claudeModels case 'gemini': return geminiModels + case 'sora': return soraModels case 'zhipu': return zhipuModels case 'qwen': return qwenModels case 'deepseek': return deepseekModels @@ -281,6 +315,7 @@ export function getModelsByPlatform(platform: string): string[] { export function getPresetMappingsByPlatform(platform: string) { if (platform === 'openai') return openaiPresetMappings if (platform === 'gemini') return geminiPresetMappings + if (platform === 'sora') return [] return anthropicPresetMappings } diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index e293491b..d100f47f 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -895,6 +895,7 @@ export default { anthropic: 'Anthropic', openai: 'OpenAI', gemini: 'Gemini', + sora: 'Sora', antigravity: 'Antigravity' }, deleteConfirm: @@ -1079,6 +1080,7 @@ export default { claude: 'Claude', openai: 'OpenAI', gemini: 'Gemini', + sora: 'Sora', antigravity: 
'Antigravity' }, types: { @@ -1247,6 +1249,11 @@ export default { baseUrlHint: 'Leave default for official OpenAI API', apiKeyHint: 'Your OpenAI API Key' }, + // Sora specific hints + sora: { + baseUrlHint: 'Leave empty to use global Sora Base URL', + apiKeyHint: 'Your Sora access token' + }, modelRestriction: 'Model Restriction (Optional)', modelWhitelist: 'Model Whitelist', modelMapping: 'Model Mapping', @@ -2784,6 +2791,47 @@ export default { defaultConcurrency: 'Default Concurrency', defaultConcurrencyHint: 'Maximum concurrent requests for new users' }, + sora: { + title: 'Sora Settings', + description: 'Configure Sora upstream requests, cache, and watermark-free flow', + baseUrl: 'Sora Base URL', + baseUrlPlaceholder: 'https://sora.chatgpt.com/backend', + baseUrlHint: 'Base URL for the Sora backend API', + callLogicMode: 'Call Mode', + callLogicModeDefault: 'Default', + callLogicModeNative: 'Native', + callLogicModeProxy: 'Proxy', + callLogicModeHint: 'Default keeps the existing behavior', + timeout: 'Timeout (seconds)', + timeoutHint: 'Timeout for single request', + maxRetries: 'Max Retries', + maxRetriesHint: 'Retry count for upstream failures', + pollInterval: 'Poll Interval (seconds)', + pollIntervalHint: 'Polling interval for task status', + cacheEnabled: 'Enable Cache', + cacheEnabledHint: 'Cache generated media for local downloads', + cacheBaseDir: 'Cache Base Dir', + cacheVideoDir: 'Video Cache Dir', + cacheMaxBytes: 'Cache Size (bytes)', + cacheMaxBytesHint: '0 means unlimited', + cacheUserDirEnabled: 'User Directory Isolation', + cacheUserDirEnabledHint: 'Create per-user subdirectories', + cacheAllowedHosts: 'Cache Allowlist', + cacheAllowedHostsPlaceholder: 'One host per line, e.g. 
oscdn2.dyysy.com', + cacheAllowedHostsHint: 'Empty falls back to the global URL allowlist', + watermarkFreeEnabled: 'Enable Watermark-Free', + watermarkFreeEnabledHint: 'Try to resolve watermark-free videos', + watermarkFreeParseMethod: 'Parse Method', + watermarkFreeParseMethodThirdParty: 'Third-party', + watermarkFreeParseMethodCustom: 'Custom', + watermarkFreeParseMethodHint: 'Select the watermark-free parse method', + watermarkFreeCustomParseUrl: 'Custom Parse URL', + watermarkFreeCustomParseToken: 'Custom Parse Token', + watermarkFreeFallback: 'Fallback on Failure', + watermarkFreeFallbackHint: 'Return the original video on failure', + tokenRefreshEnabled: 'Enable Token Refresh', + tokenRefreshEnabledHint: 'Periodic token refresh (requires scheduler)' + }, site: { title: 'Site Settings', description: 'Customize site branding', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index dbeb3819..019aa357 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -941,6 +941,7 @@ export default { anthropic: 'Anthropic', openai: 'OpenAI', gemini: 'Gemini', + sora: 'Sora', antigravity: 'Antigravity' }, saving: '保存中...', @@ -1199,6 +1200,7 @@ export default { openai: 'OpenAI', anthropic: 'Anthropic', gemini: 'Gemini', + sora: 'Sora', antigravity: 'Antigravity' }, types: { @@ -1382,6 +1384,11 @@ export default { baseUrlHint: '留空使用官方 OpenAI API', apiKeyHint: '您的 OpenAI API Key' }, + // Sora specific hints + sora: { + baseUrlHint: '留空使用全局 Sora Base URL', + apiKeyHint: '您的 Sora access token' + }, modelRestriction: '模型限制(可选)', modelWhitelist: '模型白名单', modelMapping: '模型映射', @@ -2936,6 +2943,47 @@ export default { defaultConcurrency: '默认并发数', defaultConcurrencyHint: '新用户的最大并发请求数' }, + sora: { + title: 'Sora 设置', + description: '配置 Sora 上游请求、缓存与去水印策略', + baseUrl: 'Sora Base URL', + baseUrlPlaceholder: 'https://sora.chatgpt.com/backend', + baseUrlHint: 'Sora 后端 API 基础地址', + callLogicMode: '调用模式', + callLogicModeDefault: 
'默认', + callLogicModeNative: '原生', + callLogicModeProxy: '代理', + callLogicModeHint: '默认保持当前策略', + timeout: '请求超时(秒)', + timeoutHint: '单次任务超时控制', + maxRetries: '最大重试次数', + maxRetriesHint: '上游请求失败时的重试次数', + pollInterval: '轮询间隔(秒)', + pollIntervalHint: '任务状态轮询间隔', + cacheEnabled: '启用缓存', + cacheEnabledHint: '启用生成结果缓存并提供本地下载', + cacheBaseDir: '缓存根目录', + cacheVideoDir: '视频缓存目录', + cacheMaxBytes: '缓存容量(字节)', + cacheMaxBytesHint: '0 表示不限制', + cacheUserDirEnabled: '按用户隔离缓存目录', + cacheUserDirEnabledHint: '开启后按用户创建子目录', + cacheAllowedHosts: '缓存下载白名单', + cacheAllowedHostsPlaceholder: '每行一个域名,例如: oscdn2.dyysy.com', + cacheAllowedHostsHint: '为空时回退全局 URL 白名单', + watermarkFreeEnabled: '启用去水印', + watermarkFreeEnabledHint: '尝试通过解析服务获取无水印视频', + watermarkFreeParseMethod: '解析方式', + watermarkFreeParseMethodThirdParty: '第三方解析', + watermarkFreeParseMethodCustom: '自定义解析', + watermarkFreeParseMethodHint: '选择去水印解析方式', + watermarkFreeCustomParseUrl: '自定义解析地址', + watermarkFreeCustomParseToken: '自定义解析 Token', + watermarkFreeFallback: '解析失败降级', + watermarkFreeFallbackHint: '失败时返回原视频', + tokenRefreshEnabled: '启用 Token 刷新', + tokenRefreshEnabledHint: '定时刷新 Sora Token(需配置调度)' + }, site: { title: '站点设置', description: '自定义站点品牌', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 37c9f030..41c9c82e 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -252,7 +252,7 @@ export interface PaginationConfig { // ==================== API Key & Group Types ==================== -export type GroupPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' +export type GroupPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' | 'sora' export type SubscriptionType = 'standard' | 'subscription' @@ -355,7 +355,7 @@ export interface UpdateGroupRequest { // ==================== Account & Proxy Types ==================== -export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' +export type AccountPlatform = 'anthropic' | 'openai' | 
'gemini' | 'antigravity' | 'sora' export type AccountType = 'oauth' | 'setup-token' | 'apikey' export type OAuthAddMethod = 'oauth' | 'setup-token' export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h' diff --git a/frontend/src/views/admin/GroupsView.vue b/frontend/src/views/admin/GroupsView.vue index 78ef2e48..37daaf9b 100644 --- a/frontend/src/views/admin/GroupsView.vue +++ b/frontend/src/views/admin/GroupsView.vue @@ -1152,6 +1152,7 @@ const platformOptions = computed(() => [ { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, + { value: 'sora', label: 'Sora' }, { value: 'antigravity', label: 'Antigravity' } ]) @@ -1160,6 +1161,7 @@ const platformFilterOptions = computed(() => [ { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, + { value: 'sora', label: 'Sora' }, { value: 'antigravity', label: 'Antigravity' } ]) diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 7ebca114..ae46dc0f 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -561,6 +561,221 @@ + +
+
+

+ {{ t('admin.settings.sora.title') }} +

+

+ {{ t('admin.settings.sora.description') }} +

+
+
+
+
+ + +

+ {{ t('admin.settings.sora.baseUrlHint') }} +

+
+
+ + +

+ {{ t('admin.settings.sora.callLogicModeHint') }} +

+
+
+ +
+
+ + +

+ {{ t('admin.settings.sora.timeoutHint') }} +

+
+
+ + +

+ {{ t('admin.settings.sora.maxRetriesHint') }} +

+
+
+ + +

+ {{ t('admin.settings.sora.pollIntervalHint') }} +

+
+
+ +
+
+
+ +

+ {{ t('admin.settings.sora.cacheEnabledHint') }} +

+
+ +
+ +
+
+
+ + +
+
+ + +
+
+ + +

+ {{ t('admin.settings.sora.cacheMaxBytesHint') }} +

+
+
+ +
+
+ +

+ {{ t('admin.settings.sora.cacheUserDirEnabledHint') }} +

+
+ +
+ +
+ + +

+ {{ t('admin.settings.sora.cacheAllowedHostsHint') }} +

+
+
+
+ +
+
+
+ +

+ {{ t('admin.settings.sora.watermarkFreeEnabledHint') }} +

+
+ +
+ +
+
+
+ + +

+ {{ t('admin.settings.sora.watermarkFreeParseMethodHint') }} +

+
+
+
+ +

+ {{ t('admin.settings.sora.watermarkFreeFallbackHint') }} +

+
+ +
+
+ +
+
+ + +
+
+ + +
+
+
+
+ +
+
+ +

+ {{ t('admin.settings.sora.tokenRefreshEnabledHint') }} +

+
+ +
+
+
+
@@ -1023,6 +1238,7 @@ type SettingsForm = SystemSettings & { smtp_password: string turnstile_secret_key: string linuxdo_connect_client_secret: string + sora_cache_allowed_hosts_text: string } const form = reactive({ @@ -1067,6 +1283,25 @@ const form = reactive({ // Identity patch (Claude -> Gemini) enable_identity_patch: true, identity_patch_prompt: '', + // Sora + sora_base_url: 'https://sora.chatgpt.com/backend', + sora_timeout: 120, + sora_max_retries: 3, + sora_poll_interval: 2.5, + sora_call_logic_mode: 'default', + sora_cache_enabled: false, + sora_cache_base_dir: 'tmp/sora', + sora_cache_video_dir: 'data/video', + sora_cache_max_bytes: 0, + sora_cache_allowed_hosts: [], + sora_cache_user_dir_enabled: true, + sora_watermark_free_enabled: false, + sora_watermark_free_parse_method: 'third_party', + sora_watermark_free_custom_parse_url: '', + sora_watermark_free_custom_parse_token: '', + sora_watermark_free_fallback_on_failure: true, + sora_token_refresh_enabled: false, + sora_cache_allowed_hosts_text: '', // Ops monitoring (vNext) ops_monitoring_enabled: true, ops_realtime_monitoring_enabled: true, @@ -1136,6 +1371,7 @@ async function loadSettings() { form.smtp_password = '' form.turnstile_secret_key = '' form.linuxdo_connect_client_secret = '' + form.sora_cache_allowed_hosts_text = (settings.sora_cache_allowed_hosts || []).join('\n') } catch (error: any) { appStore.showError( t('admin.settings.failedToLoad') + ': ' + (error.message || t('common.unknownError')) @@ -1148,6 +1384,11 @@ async function loadSettings() { async function saveSettings() { saving.value = true try { + const soraAllowedHosts = form.sora_cache_allowed_hosts_text + .split(/\r?\n/) + .map((value) => value.trim()) + .filter((value) => value.length > 0) + const payload: UpdateSettingsRequest = { registration_enabled: form.registration_enabled, email_verify_enabled: form.email_verify_enabled, @@ -1182,13 +1423,31 @@ async function saveSettings() { fallback_model_gemini: 
form.fallback_model_gemini, fallback_model_antigravity: form.fallback_model_antigravity, enable_identity_patch: form.enable_identity_patch, - identity_patch_prompt: form.identity_patch_prompt + identity_patch_prompt: form.identity_patch_prompt, + sora_base_url: form.sora_base_url, + sora_timeout: form.sora_timeout, + sora_max_retries: form.sora_max_retries, + sora_poll_interval: form.sora_poll_interval, + sora_call_logic_mode: form.sora_call_logic_mode, + sora_cache_enabled: form.sora_cache_enabled, + sora_cache_base_dir: form.sora_cache_base_dir, + sora_cache_video_dir: form.sora_cache_video_dir, + sora_cache_max_bytes: form.sora_cache_max_bytes, + sora_cache_allowed_hosts: soraAllowedHosts, + sora_cache_user_dir_enabled: form.sora_cache_user_dir_enabled, + sora_watermark_free_enabled: form.sora_watermark_free_enabled, + sora_watermark_free_parse_method: form.sora_watermark_free_parse_method, + sora_watermark_free_custom_parse_url: form.sora_watermark_free_custom_parse_url, + sora_watermark_free_custom_parse_token: form.sora_watermark_free_custom_parse_token, + sora_watermark_free_fallback_on_failure: form.sora_watermark_free_fallback_on_failure, + sora_token_refresh_enabled: form.sora_token_refresh_enabled } const updated = await adminAPI.settings.updateSettings(payload) Object.assign(form, updated) form.smtp_password = '' form.turnstile_secret_key = '' form.linuxdo_connect_client_secret = '' + form.sora_cache_allowed_hosts_text = (updated.sora_cache_allowed_hosts || []).join('\n') // Refresh cached public settings so sidebar/header update immediately await appStore.fetchPublicSettings(true) appStore.showSuccess(t('admin.settings.settingsSaved')) diff --git a/frontend/src/views/admin/ops/components/OpsDashboardHeader.vue b/frontend/src/views/admin/ops/components/OpsDashboardHeader.vue index f2a7d787..493cb346 100644 --- a/frontend/src/views/admin/ops/components/OpsDashboardHeader.vue +++ b/frontend/src/views/admin/ops/components/OpsDashboardHeader.vue @@ -111,6 
+111,7 @@ const platformOptions = computed(() => [ { value: 'openai', label: 'OpenAI' }, { value: 'anthropic', label: 'Anthropic' }, { value: 'gemini', label: 'Gemini' }, + { value: 'sora', label: 'Sora' }, { value: 'antigravity', label: 'Antigravity' } ]) diff --git a/frontend/src/views/user/KeysView.vue b/frontend/src/views/user/KeysView.vue index b72ae9ad..b7e3d166 100644 --- a/frontend/src/views/user/KeysView.vue +++ b/frontend/src/views/user/KeysView.vue @@ -916,6 +916,7 @@ const executeCcsImport = (row: ApiKey, clientType: 'claude' | 'gemini') => { } else { switch (platform) { case 'openai': + case 'sora': app = 'codex' endpoint = baseUrl break From a505d992eefaeb11b70f626c48ed80d799db19f5 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Thu, 29 Jan 2026 20:33:26 +0800 Subject: [PATCH 003/363] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E9=85=8D?= =?UTF-8?q?=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yaml | 527 --------------------------------- deploy/docker-compose-test.yml | 4 +- 2 files changed, 2 insertions(+), 529 deletions(-) delete mode 100644 config.yaml diff --git a/config.yaml b/config.yaml deleted file mode 100644 index 5e7513fb..00000000 --- a/config.yaml +++ /dev/null @@ -1,527 +0,0 @@ -# Sub2API Configuration File -# Sub2API 配置文件 -# -# Copy this file to /etc/sub2api/config.yaml and modify as needed -# 复制此文件到 /etc/sub2api/config.yaml 并根据需要修改 -# -# Documentation / 文档: https://github.com/Wei-Shaw/sub2api - -# ============================================================================= -# Server Configuration -# 服务器配置 -# ============================================================================= -server: - # Bind address (0.0.0.0 for all interfaces) - # 绑定地址(0.0.0.0 表示监听所有网络接口) - host: "0.0.0.0" - # Port to listen on - # 监听端口 - port: 8080 - # Mode: "debug" for development, "release" for production - # 运行模式:"debug" 用于开发,"release" 用于生产环境 - mode: "release" - # Trusted proxies for 
X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies. - # 信任的代理地址(CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。 - trusted_proxies: [] - -# ============================================================================= -# Run Mode Configuration -# 运行模式配置 -# ============================================================================= -# Run mode: "standard" (default) or "simple" (for internal use) -# 运行模式:"standard"(默认)或 "simple"(内部使用) -# - standard: Full SaaS features with billing/balance checks -# - standard: 完整 SaaS 功能,包含计费和余额校验 -# - simple: Hides SaaS features and skips billing/balance checks -# - simple: 隐藏 SaaS 功能,跳过计费和余额校验 -run_mode: "standard" - -# ============================================================================= -# CORS Configuration -# 跨域资源共享 (CORS) 配置 -# ============================================================================= -cors: - # Allowed origins list. Leave empty to disable cross-origin requests. - # 允许的来源列表。留空则禁用跨域请求。 - allowed_origins: [] - # Allow credentials (cookies/authorization headers). Cannot be used with "*". 
- # 允许携带凭证(cookies/授权头)。不能与 "*" 通配符同时使用。 - allow_credentials: true - -# ============================================================================= -# Security Configuration -# 安全配置 -# ============================================================================= -security: - url_allowlist: - # Enable URL allowlist validation (disable to skip all URL checks) - # 启用 URL 白名单验证(禁用则跳过所有 URL 检查) - enabled: false - # Allowed upstream hosts for API proxying - # 允许代理的上游 API 主机列表 - upstream_hosts: - - "api.openai.com" - - "api.anthropic.com" - - "api.kimi.com" - - "open.bigmodel.cn" - - "api.minimaxi.com" - - "generativelanguage.googleapis.com" - - "cloudcode-pa.googleapis.com" - - "*.openai.azure.com" - # Allowed hosts for pricing data download - # 允许下载定价数据的主机列表 - pricing_hosts: - - "raw.githubusercontent.com" - # Allowed hosts for CRS sync (required when using CRS sync) - # 允许 CRS 同步的主机列表(使用 CRS 同步功能时必须配置) - crs_hosts: [] - # Allow localhost/private IPs for upstream/pricing/CRS (use only in trusted networks) - # 允许本地/私有 IP 地址用于上游/定价/CRS(仅在可信网络中使用) - allow_private_hosts: true - # Allow http:// URLs when allowlist is disabled (default: false, require https) - # 白名单禁用时是否允许 http:// URL(默认: false,要求 https) - allow_insecure_http: true - response_headers: - # Enable configurable response header filtering (disable to use default allowlist) - # 启用可配置的响应头过滤(禁用则使用默认白名单) - enabled: false - # Extra allowed response headers from upstream - # 额外允许的上游响应头 - additional_allowed: [] - # Force-remove response headers from upstream - # 强制移除的上游响应头 - force_remove: [] - csp: - # Enable Content-Security-Policy header - # 启用内容安全策略 (CSP) 响应头 - enabled: true - # Default CSP policy (override if you host assets on other domains) - # 默认 CSP 策略(如果静态资源托管在其他域名,请自行覆盖) - policy: "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-ancestors 'none'; 
base-uri 'self'; form-action 'self'" - proxy_probe: - # Allow skipping TLS verification for proxy probe (debug only) - # 允许代理探测时跳过 TLS 证书验证(仅用于调试) - insecure_skip_verify: false - -# ============================================================================= -# Gateway Configuration -# 网关配置 -# ============================================================================= -gateway: - # Timeout for waiting upstream response headers (seconds) - # 等待上游响应头超时时间(秒) - response_header_timeout: 600 - # Max request body size in bytes (default: 100MB) - # 请求体最大字节数(默认 100MB) - max_body_size: 104857600 - # Connection pool isolation strategy: - # 连接池隔离策略: - # - proxy: Isolate by proxy, same proxy shares connection pool (suitable for few proxies, many accounts) - # - proxy: 按代理隔离,同一代理共享连接池(适合代理少、账户多) - # - account: Isolate by account, same account shares connection pool (suitable for few accounts, strict isolation) - # - account: 按账户隔离,同一账户共享连接池(适合账户少、需严格隔离) - # - account_proxy: Isolate by account+proxy combination (default, finest granularity) - # - account_proxy: 按账户+代理组合隔离(默认,最细粒度) - connection_pool_isolation: "account_proxy" - # HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults) - # HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值) - # Max idle connections across all hosts - # 所有主机的最大空闲连接数 - max_idle_conns: 240 - # Max idle connections per host - # 每个主机的最大空闲连接数 - max_idle_conns_per_host: 120 - # Max connections per host - # 每个主机的最大连接数 - max_conns_per_host: 240 - # Idle connection timeout (seconds) - # 空闲连接超时时间(秒) - idle_conn_timeout_seconds: 90 - # Upstream client cache settings - # 上游连接池客户端缓存配置 - # max_upstream_clients: Max cached clients, evicts least recently used when exceeded - # max_upstream_clients: 最大缓存客户端数量,超出后淘汰最久未使用的 - max_upstream_clients: 5000 - # client_idle_ttl_seconds: Client idle reclaim threshold (seconds), reclaimed when idle and no active requests - # client_idle_ttl_seconds: 客户端空闲回收阈值(秒),超时且无活跃请求时回收 - client_idle_ttl_seconds: 900 - # Concurrency 
slot expiration time (minutes) - # 并发槽位过期时间(分钟) - concurrency_slot_ttl_minutes: 30 - # Stream data interval timeout (seconds), 0=disable - # 流数据间隔超时(秒),0=禁用 - stream_data_interval_timeout: 180 - # Stream keepalive interval (seconds), 0=disable - # 流式 keepalive 间隔(秒),0=禁用 - stream_keepalive_interval: 10 - # SSE max line size in bytes (default: 40MB) - # SSE 单行最大字节数(默认 40MB) - max_line_size: 41943040 - # Log upstream error response body summary (safe/truncated; does not log request content) - # 记录上游错误响应体摘要(安全/截断;不记录请求内容) - log_upstream_error_body: true - # Max bytes to log from upstream error body - # 记录上游错误响应体的最大字节数 - log_upstream_error_body_max_bytes: 2048 - # Auto inject anthropic-beta header for API-key accounts when needed (default: off) - # 需要时自动为 API-key 账户注入 anthropic-beta 头(默认:关闭) - inject_beta_for_apikey: false - # Allow failover on selected 400 errors (default: off) - # 允许在特定 400 错误时进行故障转移(默认:关闭) - failover_on_400: false - -# ============================================================================= -# API Key Auth Cache Configuration -# API Key 认证缓存配置 -# ============================================================================= -api_key_auth_cache: - # L1 cache size (entries), in-process LRU/TTL cache - # L1 缓存容量(条目数),进程内 LRU/TTL 缓存 - l1_size: 65535 - # L1 cache TTL (seconds) - # L1 缓存 TTL(秒) - l1_ttl_seconds: 15 - # L2 cache TTL (seconds), stored in Redis - # L2 缓存 TTL(秒),Redis 中存储 - l2_ttl_seconds: 300 - # Negative cache TTL (seconds) - # 负缓存 TTL(秒) - negative_ttl_seconds: 30 - # TTL jitter percent (0-100) - # TTL 抖动百分比(0-100) - jitter_percent: 10 - # Enable singleflight for cache misses - # 缓存未命中时启用 singleflight 合并回源 - singleflight: true - -# ============================================================================= -# Dashboard Cache Configuration -# 仪表盘缓存配置 -# ============================================================================= -dashboard_cache: - # Enable dashboard cache - # 启用仪表盘缓存 - enabled: true - # Redis key prefix for 
multi-environment isolation - # Redis key 前缀,用于多环境隔离 - key_prefix: "sub2api:" - # Fresh TTL (seconds); within this window cached stats are considered fresh - # 新鲜阈值(秒);命中后处于该窗口视为新鲜数据 - stats_fresh_ttl_seconds: 15 - # Cache TTL (seconds) stored in Redis - # Redis 缓存 TTL(秒) - stats_ttl_seconds: 30 - # Async refresh timeout (seconds) - # 异步刷新超时(秒) - stats_refresh_timeout_seconds: 30 - -# ============================================================================= -# Dashboard Aggregation Configuration -# 仪表盘预聚合配置(重启生效) -# ============================================================================= -dashboard_aggregation: - # Enable aggregation job - # 启用聚合作业 - enabled: true - # Refresh interval (seconds) - # 刷新间隔(秒) - interval_seconds: 60 - # Lookback window (seconds) for late-arriving data - # 回看窗口(秒),处理迟到数据 - lookback_seconds: 120 - # Allow manual backfill - # 允许手动回填 - backfill_enabled: false - # Backfill max range (days) - # 回填最大跨度(天) - backfill_max_days: 31 - # Recompute recent N days on startup - # 启动时重算最近 N 天 - recompute_days: 2 - # Retention windows (days) - # 保留窗口(天) - retention: - # Raw usage_logs retention - # 原始 usage_logs 保留天数 - usage_logs_days: 90 - # Hourly aggregation retention - # 小时聚合保留天数 - hourly_days: 180 - # Daily aggregation retention - # 日聚合保留天数 - daily_days: 730 - -# ============================================================================= -# Usage Cleanup Task Configuration -# 使用记录清理任务配置(重启生效) -# ============================================================================= -usage_cleanup: - # Enable cleanup task worker - # 启用清理任务执行器 - enabled: true - # Max date range (days) per task - # 单次任务最大时间跨度(天) - max_range_days: 31 - # Batch delete size - # 单批删除数量 - batch_size: 5000 - # Worker interval (seconds) - # 执行器轮询间隔(秒) - worker_interval_seconds: 10 - # Task execution timeout (seconds) - # 单次任务最大执行时长(秒) - task_timeout_seconds: 1800 - -# ============================================================================= -# Concurrency Wait 
Configuration -# 并发等待配置 -# ============================================================================= -concurrency: - # SSE ping interval during concurrency wait (seconds) - # 并发等待期间的 SSE ping 间隔(秒) - ping_interval: 10 - -# ============================================================================= -# Database Configuration (PostgreSQL) -# 数据库配置 (PostgreSQL) -# ============================================================================= -database: - # Database host address - # 数据库主机地址 - host: "localhost" - # Database port - # 数据库端口 - port: 5432 - # Database username - # 数据库用户名 - user: "postgres" - # Database password - # 数据库密码 - password: "your_secure_password_here" - # Database name - # 数据库名称 - dbname: "sub2api" - # SSL mode: disable, require, verify-ca, verify-full - # SSL 模式:disable(禁用), require(要求), verify-ca(验证CA), verify-full(完全验证) - sslmode: "disable" - -# ============================================================================= -# Redis Configuration -# Redis 配置 -# ============================================================================= -redis: - # Redis host address - # Redis 主机地址 - host: "localhost" - # Redis port - # Redis 端口 - port: 6379 - # Redis password (leave empty if no password is set) - # Redis 密码(如果未设置密码则留空) - password: "" - # Database number (0-15) - # 数据库编号(0-15) - db: 0 - -# ============================================================================= -# Ops Monitoring (Optional) -# 运维监控 (可选) -# ============================================================================= -ops: - # Hard switch: disable all ops background jobs and APIs when false - # 硬开关:为 false 时禁用所有 Ops 后台任务与接口 - enabled: true - - # Prefer pre-aggregated tables (ops_metrics_hourly/ops_metrics_daily) for long-window dashboard queries. - # 优先使用预聚合表(用于长时间窗口查询性能) - use_preaggregated_tables: false - - # Data cleanup configuration - # 数据清理配置(vNext 默认统一保留 30 天) - cleanup: - enabled: true - # Cron expression (minute hour dom month dow), e.g. 
"0 2 * * *" = daily at 2 AM - # Cron 表达式(分 时 日 月 周),例如 "0 2 * * *" = 每天凌晨 2 点 - schedule: "0 2 * * *" - error_log_retention_days: 30 - minute_metrics_retention_days: 30 - hourly_metrics_retention_days: 30 - - # Pre-aggregation configuration - # 预聚合任务配置 - aggregation: - enabled: true - - # OpsMetricsCollector Redis cache (reduces duplicate expensive window aggregation in multi-replica deployments) - # 指标采集 Redis 缓存(多副本部署时减少重复计算) - metrics_collector_cache: - enabled: true - ttl: 65s - -# ============================================================================= -# JWT Configuration -# JWT 配置 -# ============================================================================= -jwt: - # IMPORTANT: Change this to a random string in production! - # 重要:生产环境中请更改为随机字符串! - # Generate with / 生成命令: openssl rand -hex 32 - secret: "change-this-to-a-secure-random-string" - # Token expiration time in hours (max 24) - # 令牌过期时间(小时,最大 24) - expire_hour: 24 - -# ============================================================================= -# Default Settings -# 默认设置 -# ============================================================================= -default: - # Initial admin account (created on first run) - # 初始管理员账户(首次运行时创建) - admin_email: "admin@example.com" - admin_password: "admin123" - - # Default settings for new users - # 新用户默认设置 - # Max concurrent requests per user - # 每用户最大并发请求数 - user_concurrency: 5 - # Initial balance for new users - # 新用户初始余额 - user_balance: 0 - - # API key settings - # API 密钥设置 - # Prefix for generated API keys - # 生成的 API 密钥前缀 - api_key_prefix: "sk-" - - # Rate multiplier (affects billing calculation) - # 费率倍数(影响计费计算) - rate_multiplier: 1.0 - -# ============================================================================= -# Rate Limiting -# 速率限制 -# ============================================================================= -rate_limit: - # Cooldown time (in minutes) when upstream returns 529 (overloaded) - # 上游返回 529(过载)时的冷却时间(分钟) - 
overload_cooldown_minutes: 10 - -# ============================================================================= -# Pricing Data Source (Optional) -# 定价数据源(可选) -# ============================================================================= -pricing: - # URL to fetch model pricing data (default: LiteLLM) - # 获取模型定价数据的 URL(默认:LiteLLM) - remote_url: "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" - # Hash verification URL (optional) - # 哈希校验 URL(可选) - hash_url: "" - # Local data directory for caching - # 本地数据缓存目录 - data_dir: "./data" - # Fallback pricing file - # 备用定价文件 - fallback_file: "./resources/model-pricing/model_prices_and_context_window.json" - # Update interval in hours - # 更新间隔(小时) - update_interval_hours: 24 - # Hash check interval in minutes - # 哈希检查间隔(分钟) - hash_check_interval_minutes: 10 - -# ============================================================================= -# Billing Configuration -# 计费配置 -# ============================================================================= -billing: - circuit_breaker: - # Enable circuit breaker for billing service - # 启用计费服务熔断器 - enabled: true - # Number of failures before opening circuit - # 触发熔断的失败次数阈值 - failure_threshold: 5 - # Time to wait before attempting reset (seconds) - # 熔断后重试等待时间(秒) - reset_timeout_seconds: 30 - # Number of requests to allow in half-open state - # 半开状态允许通过的请求数 - half_open_requests: 3 - -# ============================================================================= -# Turnstile Configuration -# Turnstile 人机验证配置 -# ============================================================================= -turnstile: - # Require Turnstile in release mode (when enabled, login/register will fail if not configured) - # 在 release 模式下要求 Turnstile 验证(启用后,若未配置则登录/注册会失败) - required: false - -# ============================================================================= -# Gemini OAuth (Required for Gemini accounts) -# Gemini OAuth 配置(Gemini 账户必需) -# 
============================================================================= -# Sub2API supports TWO Gemini OAuth modes: -# Sub2API 支持两种 Gemini OAuth 模式: -# -# 1. Code Assist OAuth (requires GCP project_id) -# 1. Code Assist OAuth(需要 GCP project_id) -# - Uses: cloudcode-pa.googleapis.com (Code Assist API) -# - 使用:cloudcode-pa.googleapis.com(Code Assist API) -# -# 2. AI Studio OAuth (no project_id needed) -# 2. AI Studio OAuth(不需要 project_id) -# - Uses: generativelanguage.googleapis.com (AI Studio API) -# - 使用:generativelanguage.googleapis.com(AI Studio API) -# -# Default: Uses Gemini CLI's public OAuth credentials (same as Google's official CLI tool) -# 默认:使用 Gemini CLI 的公开 OAuth 凭证(与 Google 官方 CLI 工具相同) -gemini: - oauth: - # Gemini CLI public OAuth credentials (works for both Code Assist and AI Studio) - # Gemini CLI 公开 OAuth 凭证(适用于 Code Assist 和 AI Studio) - client_id: "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" - client_secret: "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" - # Optional scopes (space-separated). Leave empty to auto-select based on oauth_type. - # 可选的权限范围(空格分隔)。留空则根据 oauth_type 自动选择。 - scopes: "" - quota: - # Optional: local quota simulation for Gemini Code Assist (local billing). - # 可选:Gemini Code Assist 本地配额模拟(本地计费)。 - # These values are used for UI progress + precheck scheduling, not official Google quotas. 
- # 这些值用于 UI 进度显示和预检调度,并非 Google 官方配额。 - tiers: - LEGACY: - # Pro model requests per day - # Pro 模型每日请求数 - pro_rpd: 50 - # Flash model requests per day - # Flash 模型每日请求数 - flash_rpd: 1500 - # Cooldown time (minutes) after hitting quota - # 达到配额后的冷却时间(分钟) - cooldown_minutes: 30 - PRO: - # Pro model requests per day - # Pro 模型每日请求数 - pro_rpd: 1500 - # Flash model requests per day - # Flash 模型每日请求数 - flash_rpd: 4000 - # Cooldown time (minutes) after hitting quota - # 达到配额后的冷却时间(分钟) - cooldown_minutes: 5 - ULTRA: - # Pro model requests per day - # Pro 模型每日请求数 - pro_rpd: 2000 - # Flash model requests per day (0 = unlimited) - # Flash 模型每日请求数(0 = 无限制) - flash_rpd: 0 - # Cooldown time (minutes) after hitting quota - # 达到配额后的冷却时间(分钟) - cooldown_minutes: 5 diff --git a/deploy/docker-compose-test.yml b/deploy/docker-compose-test.yml index bcda3141..19903f6f 100644 --- a/deploy/docker-compose-test.yml +++ b/deploy/docker-compose-test.yml @@ -33,7 +33,7 @@ services: # Data persistence (config.yaml will be auto-generated here) - sub2api_data:/app/data # Mount custom config.yaml (optional, overrides auto-generated config) - - ./config.yaml:/app/data/config.yaml:ro + # - ./config.yaml:/app/data/config.yaml:ro environment: # ======================================================================= # Auto Setup (REQUIRED for Docker deployment) @@ -150,7 +150,7 @@ services: # Redis Cache # =========================================================================== redis: - image: redis:7-alpine + image: redis:8-alpine container_name: sub2api-redis restart: unless-stopped ulimits: From 99dc3b59bc54ca8512478397b8ad69bb8b3834fa Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Fri, 30 Jan 2026 14:08:04 +0800 Subject: [PATCH 004/363] =?UTF-8?q?feat(=E8=B4=A6=E5=8F=B7):=20=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=20Sora=20=E8=B4=A6=E5=8F=B7=E5=8F=8C=E8=A1=A8?= =?UTF-8?q?=E5=90=8C=E6=AD=A5=E4=B8=8E=E5=88=9B=E5=BB=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 
8bit - 新增 sora_accounts 表与 accounts.extra GIN 索引\n- OpenAI OAuth 支持同时创建 Sora 账号并同步配置\n- Token 刷新同步关联 Sora 账号凭证与扩展表\n- 增加 Sora 账号连通性测试与前端开关文案 --- backend/cmd/server/wire_gen.go | 5 +- backend/internal/repository/account_repo.go | 61 ++++++++++++ .../internal/repository/sora_account_repo.go | 98 +++++++++++++++++++ backend/internal/repository/wire.go | 1 + backend/internal/server/api_contract_test.go | 2 +- backend/internal/service/account_service.go | 3 + .../service/account_service_delete_test.go | 4 + .../internal/service/account_test_service.go | 73 ++++++++++++++ backend/internal/service/admin_service.go | 15 +++ backend/internal/service/domain_constants.go | 1 + .../service/gateway_multiplatform_test.go | 3 + .../service/gemini_multiplatform_test.go | 3 + .../internal/service/sora_account_service.go | 40 ++++++++ .../internal/service/token_refresh_service.go | 16 ++- backend/internal/service/token_refresher.go | 79 ++++++++++++++- backend/internal/service/wire.go | 3 + .../045_add_accounts_extra_index.sql | 13 +++ backend/migrations/046_add_sora_accounts.sql | 24 +++++ .../components/account/CreateAccountModal.vue | 95 +++++++++++++++++- frontend/src/i18n/locales/en.ts | 6 +- frontend/src/i18n/locales/zh.ts | 6 +- 21 files changed, 542 insertions(+), 9 deletions(-) create mode 100644 backend/internal/repository/sora_account_repo.go create mode 100644 backend/internal/service/sora_account_service.go create mode 100644 backend/migrations/045_add_accounts_extra_index.sql create mode 100644 backend/migrations/046_add_sora_accounts.sql diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 7b22a31e..b8668665 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -86,10 +86,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService) schedulerCache := repository.NewSchedulerCache(redisClient) 
accountRepository := repository.NewAccountRepository(client, db, schedulerCache) + soraAccountRepository := repository.NewSoraAccountRepository(db) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) adminUserHandler := admin.NewUserHandler(adminService) groupHandler := admin.NewGroupHandler(adminService) claudeOAuthClient := repository.NewClaudeOAuthClient() @@ -176,7 +177,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) v := 
provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) application := &Application{ diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index c11c079b..5edc4f6d 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -1553,3 +1553,64 @@ func joinClauses(clauses []string, sep string) string { func itoa(v int) string { return strconv.Itoa(v) } + +// FindByExtraField 根据 extra 字段中的键值对查找账号。 +// 该方法限定 platform='sora',避免误查询其他平台的账号。 +// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。 +// +// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。 +// +// FindByExtraField finds accounts by key-value pairs in the extra field. +// Limited to platform='sora' to avoid querying accounts from other platforms. +// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index). +// +// Use case: Finding Sora accounts linked via linked_openai_account_id. +func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value interface{}) ([]service.Account, error) { + accounts, err := r.client.Account.Query(). 
+ Where( + dbaccount.PlatformEQ("sora"), // 限定平台为 sora + dbaccount.DeletedAtIsNil(), + func(s *entsql.Selector) { + path := sqljson.Path(key) + switch v := value.(type) { + case string: + preds := []*entsql.Predicate{sqljson.ValueEQ(dbaccount.FieldExtra, v, path)} + if parsed, err := strconv.ParseInt(v, 10, 64); err == nil { + preds = append(preds, sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path)) + } + if len(preds) == 1 { + s.Where(preds[0]) + } else { + s.Where(entsql.Or(preds...)) + } + case int: + s.Where(entsql.Or( + sqljson.ValueEQ(dbaccount.FieldExtra, v, path), + sqljson.ValueEQ(dbaccount.FieldExtra, strconv.Itoa(v), path), + )) + case int64: + s.Where(entsql.Or( + sqljson.ValueEQ(dbaccount.FieldExtra, v, path), + sqljson.ValueEQ(dbaccount.FieldExtra, strconv.FormatInt(v, 10), path), + )) + case json.Number: + if parsed, err := v.Int64(); err == nil { + s.Where(entsql.Or( + sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path), + sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path), + )) + } else { + s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path)) + } + default: + s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, value, path)) + } + }, + ). 
+ All(ctx) + if err != nil { + return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil) + } + + return r.accountsToService(ctx, accounts) +} diff --git a/backend/internal/repository/sora_account_repo.go b/backend/internal/repository/sora_account_repo.go new file mode 100644 index 00000000..e0ec6073 --- /dev/null +++ b/backend/internal/repository/sora_account_repo.go @@ -0,0 +1,98 @@ +package repository + +import ( + "context" + "database/sql" + "errors" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// soraAccountRepository 实现 service.SoraAccountRepository 接口。 +// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。 +// +// 设计说明: +// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理 +// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义 +// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除 +type soraAccountRepository struct { + sql *sql.DB +} + +// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例 +func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository { + return &soraAccountRepository{sql: sqlDB} +} + +// Upsert 创建或更新 Sora 账号扩展信息 +// 使用 PostgreSQL ON CONFLICT ... 
DO UPDATE 实现原子性 upsert +func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error { + accessToken, accessOK := updates["access_token"].(string) + refreshToken, refreshOK := updates["refresh_token"].(string) + sessionToken, sessionOK := updates["session_token"].(string) + + if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" { + if !sessionOK { + return errors.New("缺少 access_token/refresh_token,且未提供可更新字段") + } + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_accounts + SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END, + updated_at = NOW() + WHERE account_id = $1 + `, accountID, sessionToken) + if err != nil { + return err + } + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return errors.New("sora_accounts 记录不存在,无法仅更新 session_token") + } + return nil + } + + _, err := r.sql.ExecContext(ctx, ` + INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at) + VALUES ($1, $2, $3, $4, NOW(), NOW()) + ON CONFLICT (account_id) DO UPDATE SET + access_token = EXCLUDED.access_token, + refresh_token = EXCLUDED.refresh_token, + session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END, + updated_at = NOW() + `, accountID, accessToken, refreshToken, sessionToken) + return err +} + +// GetByAccountID 根据账号 ID 获取 Sora 扩展信息 +func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) { + rows, err := r.sql.QueryContext(ctx, ` + SELECT account_id, access_token, refresh_token, COALESCE(session_token, '') + FROM sora_accounts + WHERE account_id = $1 + `, accountID) + if err != nil { + return nil, err + } + defer rows.Close() + + if !rows.Next() { + return nil, nil // 记录不存在 + } + + var sa service.SoraAccount + if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, 
&sa.SessionToken); err != nil { + return nil, err + } + return &sa, nil +} + +// Delete 删除 Sora 账号扩展信息 +func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error { + _, err := r.sql.ExecContext(ctx, ` + DELETE FROM sora_accounts WHERE account_id = $1 + `, accountID) + return err +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 7a8d85f4..929eb22b 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -53,6 +53,7 @@ var ProviderSet = wire.NewSet( NewAPIKeyRepository, NewGroupRepository, NewAccountRepository, + NewSoraAccountRepository, // Sora 账号扩展表仓储 NewProxyRepository, NewRedeemCodeRepository, NewPromoCodeRepository, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 4d1b4be2..f3eebd41 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -594,7 +594,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 90365d2f..4befc996 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -25,6 +25,9 @@ type AccountRepository interface { // GetByCRSAccountID finds an account previously synced from CRS. 
// Returns (nil, nil) if not found. GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) + // FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora') + // 用于查找通过 linked_openai_account_id 关联的 Sora 账号 + FindByExtraField(ctx context.Context, key string, value interface{}) ([]Account, error) Update(ctx context.Context, account *Account) error Delete(ctx context.Context, id int64) error diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index e5eabfc6..f4e03e8e 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -54,6 +54,10 @@ func (s *accountRepoStub) GetByCRSAccountID(ctx context.Context, crsAccountID st panic("unexpected GetByCRSAccountID call") } +func (s *accountRepoStub) FindByExtraField(ctx context.Context, key string, value interface{}) ([]Account, error) { + panic("unexpected FindByExtraField call") +} + func (s *accountRepoStub) Update(ctx context.Context, account *Account) error { panic("unexpected Update call") } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 46376c69..f80a2af8 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -31,6 +31,7 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`) const ( testClaudeAPIURL = "https://api.anthropic.com/v1/messages" chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" + soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接 ) // TestEvent represents a SSE event for account testing @@ -163,6 +164,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int return s.testAntigravityAccountConnection(c, account, modelID) } + if account.Platform == PlatformSora { + return s.testSoraAccountConnection(c, account) + } + return 
s.testClaudeAccountConnection(c, account, modelID) } @@ -461,6 +466,74 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account return s.processGeminiStream(c, resp.Body) } +// testSoraAccountConnection 测试 Sora 账号的连接 +// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token) +func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error { + ctx := c.Request.Context() + + authToken := account.GetCredential("access_token") + if authToken == "" { + return s.sendErrorAndEnd(c, "No access token available") + } + + // Set SSE headers + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + // Send test_start event + s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"}) + + req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create request") + } + + // 使用 Sora 客户端标准请求头(参考 sora2api) + req.Header.Set("Authorization", "Bearer "+authToken) + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + req.Header.Set("Accept", "application/json") + + // Get proxy URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, string(body))) + } + + // 解析 /me 响应,提取用户信息 + var meResp map[string]any + if err := json.Unmarshal(body, &meResp); err != nil { + 
// 能收到 200 就说明 token 有效 + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora connection OK (token valid)"}) + } else { + // 尝试提取用户名或邮箱信息 + info := "Sora connection OK" + if name, ok := meResp["name"].(string); ok && name != "" { + info = fmt.Sprintf("Sora connection OK - User: %s", name) + } else if email, ok := meResp["email"].(string); ok && email != "" { + info = fmt.Sprintf("Sora connection OK - Email: %s", email) + } + s.sendEvent(c, TestEvent{Type: "content", Text: info}) + } + + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil +} + // testAntigravityAccountConnection tests an Antigravity account's connection // 支持 Claude 和 Gemini 两种协议,使用非流式请求 func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error { diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 0afa0716..398de0e0 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -272,6 +272,7 @@ type adminServiceImpl struct { userRepo UserRepository groupRepo GroupRepository accountRepo AccountRepository + soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储 proxyRepo ProxyRepository apiKeyRepo APIKeyRepository redeemCodeRepo RedeemCodeRepository @@ -286,6 +287,7 @@ func NewAdminService( userRepo UserRepository, groupRepo GroupRepository, accountRepo AccountRepository, + soraAccountRepo SoraAccountRepository, proxyRepo ProxyRepository, apiKeyRepo APIKeyRepository, redeemCodeRepo RedeemCodeRepository, @@ -298,6 +300,7 @@ func NewAdminService( userRepo: userRepo, groupRepo: groupRepo, accountRepo: accountRepo, + soraAccountRepo: soraAccountRepo, proxyRepo: proxyRepo, apiKeyRepo: apiKeyRepo, redeemCodeRepo: redeemCodeRepo, @@ -862,6 +865,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou return nil, err } + // 如果是 Sora 平台账号,自动创建 sora_accounts 扩展表记录 + if account.Platform == PlatformSora 
&& s.soraAccountRepo != nil { + soraUpdates := map[string]any{ + "access_token": account.GetCredential("access_token"), + "refresh_token": account.GetCredential("refresh_token"), + } + if err := s.soraAccountRepo.Upsert(ctx, account.ID, soraUpdates); err != nil { + // 只记录警告日志,不阻塞账号创建 + log.Printf("[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err) + } + } + // 绑定分组 if len(groupIDs) > 0 { if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil { diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 3bb63ffa..31c576ed 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -22,6 +22,7 @@ const ( PlatformOpenAI = "openai" PlatformGemini = "gemini" PlatformAntigravity = "antigravity" + PlatformSora = "sora" ) // Account type constants diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 26eb24e4..d9ae6709 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -77,6 +77,9 @@ func (m *mockAccountRepoForPlatform) Create(ctx context.Context, account *Accoun func (m *mockAccountRepoForPlatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) { return nil, nil } +func (m *mockAccountRepoForPlatform) FindByExtraField(ctx context.Context, key string, value interface{}) ([]Account, error) { + return nil, nil +} func (m *mockAccountRepoForPlatform) Update(ctx context.Context, account *Account) error { return nil } diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index c63a020c..a2c6f937 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -66,6 +66,9 @@ func (m *mockAccountRepoForGemini) 
Create(ctx context.Context, account *Account) func (m *mockAccountRepoForGemini) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) { return nil, nil } +func (m *mockAccountRepoForGemini) FindByExtraField(ctx context.Context, key string, value interface{}) ([]Account, error) { + return nil, nil +} func (m *mockAccountRepoForGemini) Update(ctx context.Context, account *Account) error { return nil } func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error { return nil } func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { diff --git a/backend/internal/service/sora_account_service.go b/backend/internal/service/sora_account_service.go new file mode 100644 index 00000000..eccc1acf --- /dev/null +++ b/backend/internal/service/sora_account_service.go @@ -0,0 +1,40 @@ +package service + +import "context" + +// SoraAccountRepository Sora 账号扩展表仓储接口 +// 用于管理 sora_accounts 表,与 accounts 主表形成双表结构。 +// +// 设计说明: +// - sora_accounts 表存储 Sora 账号的 OAuth 凭证副本 +// - Sora gateway 优先读取此表的字段以获得更好的查询性能 +// - 主表 accounts 通过 credentials JSON 字段也存储相同信息 +// - Token 刷新时需要同时更新两个表以保持数据一致性 +type SoraAccountRepository interface { + // Upsert 创建或更新 Sora 账号扩展信息 + // accountID: 关联的 accounts.id + // updates: 要更新的字段,支持 access_token、refresh_token、session_token + // + // 如果记录不存在则创建,存在则更新。 + // 用于: + // 1. 创建 Sora 账号时初始化扩展表 + // 2. 
Token 刷新时同步更新扩展表 + Upsert(ctx context.Context, accountID int64, updates map[string]any) error + + // GetByAccountID 根据账号 ID 获取 Sora 扩展信息 + // 返回 nil, nil 表示记录不存在(非错误) + GetByAccountID(ctx context.Context, accountID int64) (*SoraAccount, error) + + // Delete 删除 Sora 账号扩展信息 + // 通常由外键 ON DELETE CASCADE 自动处理,此方法用于手动清理 + Delete(ctx context.Context, accountID int64) error +} + +// SoraAccount Sora 账号扩展信息 +// 对应 sora_accounts 表,存储 Sora 账号的 OAuth 凭证副本 +type SoraAccount struct { + AccountID int64 // 关联的 accounts.id + AccessToken string // OAuth access_token + RefreshToken string // OAuth refresh_token + SessionToken string // Session token(可选,用于 ST→AT 兜底) +} diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 7364bd33..797ab721 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -15,6 +15,7 @@ import ( // 定期检查并刷新即将过期的token type TokenRefreshService struct { accountRepo AccountRepository + soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 refreshers []TokenRefresher cfg *config.TokenRefreshConfig cacheInvalidator TokenCacheInvalidator @@ -43,7 +44,7 @@ func NewTokenRefreshService( // 注册平台特定的刷新器 s.refreshers = []TokenRefresher{ NewClaudeTokenRefresher(oauthService), - NewOpenAITokenRefresher(openaiOAuthService), + NewOpenAITokenRefresher(openaiOAuthService, accountRepo), NewGeminiTokenRefresher(geminiOAuthService), NewAntigravityTokenRefresher(antigravityOAuthService), } @@ -51,6 +52,19 @@ func NewTokenRefreshService( return s } +// SetSoraAccountRepo 设置 Sora 账号扩展表仓储 +// 用于在 OpenAI Token 刷新时同步更新 sora_accounts 表 +// 需要在 Start() 之前调用 +func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) { + s.soraAccountRepo = repo + // 将 soraAccountRepo 注入到 OpenAITokenRefresher + for _, refresher := range s.refreshers { + if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok { + 
openaiRefresher.SetSoraAccountRepo(repo) + } + } +} + // Start 启动后台刷新服务 func (s *TokenRefreshService) Start() { if !s.cfg.Enabled { diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 214a290a..807524fd 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -2,6 +2,7 @@ package service import ( "context" + "log" "strconv" "time" ) @@ -82,16 +83,26 @@ func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (m // OpenAITokenRefresher 处理 OpenAI OAuth token刷新 type OpenAITokenRefresher struct { - openaiOAuthService *OpenAIOAuthService + openaiOAuthService *OpenAIOAuthService + accountRepo AccountRepository + soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 } // NewOpenAITokenRefresher 创建 OpenAI token刷新器 -func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService) *OpenAITokenRefresher { +func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService, accountRepo AccountRepository) *OpenAITokenRefresher { return &OpenAITokenRefresher{ openaiOAuthService: openaiOAuthService, + accountRepo: accountRepo, } } +// SetSoraAccountRepo 设置 Sora 账号扩展表仓储 +// 用于在 Token 刷新时同步更新 sora_accounts 表 +// 如果未设置,syncLinkedSoraAccounts 只会更新 accounts.credentials +func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) { + r.soraAccountRepo = repo +} + // CanRefresh 检查是否能处理此账号 // 只处理 openai 平台的 oauth 类型账号 func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool { @@ -112,6 +123,7 @@ func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time // Refresh 执行token刷新 // 保留原有credentials中的所有字段,只更新token相关字段 +// 刷新成功后,异步同步关联的 Sora 账号 func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) { tokenInfo, err := r.openaiOAuthService.RefreshAccountToken(ctx, account) if err != nil { @@ -128,5 +140,68 @@ func (r *OpenAITokenRefresher) Refresh(ctx 
context.Context, account *Account) (m } } + // 异步同步关联的 Sora 账号(不阻塞主流程) + if r.accountRepo != nil { + go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials) + } + return newCredentials, nil } + +// syncLinkedSoraAccounts 同步关联的 Sora 账号的 token(双表同步) +// 该方法异步执行,失败只记录日志,不影响主流程 +// +// 同步策略: +// 1. 更新 accounts.credentials(主表) +// 2. 更新 sora_accounts 扩展表(如果 soraAccountRepo 已设置) +// +// 超时控制:30 秒,防止数据库阻塞导致 goroutine 泄漏 +func (r *OpenAITokenRefresher) syncLinkedSoraAccounts(ctx context.Context, openaiAccountID int64, newCredentials map[string]any) { + // 添加超时控制,防止 goroutine 泄漏 + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + // 1. 查找所有关联的 Sora 账号(限定 platform='sora') + soraAccounts, err := r.accountRepo.FindByExtraField(ctx, "linked_openai_account_id", openaiAccountID) + if err != nil { + log.Printf("[TokenSync] 查找关联 Sora 账号失败: openai_account_id=%d err=%v", openaiAccountID, err) + return + } + + if len(soraAccounts) == 0 { + // 没有关联的 Sora 账号,直接返回 + return + } + + // 2. 
同步更新每个 Sora 账号的双表数据 + for _, soraAccount := range soraAccounts { + // 2.1 更新 accounts.credentials(主表) + soraAccount.Credentials["access_token"] = newCredentials["access_token"] + soraAccount.Credentials["refresh_token"] = newCredentials["refresh_token"] + if expiresAt, ok := newCredentials["expires_at"]; ok { + soraAccount.Credentials["expires_at"] = expiresAt + } + + if err := r.accountRepo.Update(ctx, &soraAccount); err != nil { + log.Printf("[TokenSync] 更新 Sora accounts 表失败: sora_account_id=%d openai_account_id=%d err=%v", + soraAccount.ID, openaiAccountID, err) + continue + } + + // 2.2 更新 sora_accounts 扩展表(如果仓储已设置) + if r.soraAccountRepo != nil { + soraUpdates := map[string]any{ + "access_token": newCredentials["access_token"], + "refresh_token": newCredentials["refresh_token"], + } + if err := r.soraAccountRepo.Upsert(ctx, soraAccount.ID, soraUpdates); err != nil { + log.Printf("[TokenSync] 更新 sora_accounts 表失败: account_id=%d openai_account_id=%d err=%v", + soraAccount.ID, openaiAccountID, err) + // 继续处理其他账号,不中断 + } + } + + log.Printf("[TokenSync] 成功同步 Sora 账号 token: sora_account_id=%d openai_account_id=%d dual_table=%v", + soraAccount.ID, openaiAccountID, r.soraAccountRepo != nil) + } +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index b210286d..73d23025 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -39,6 +39,7 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService { // ProvideTokenRefreshService creates and starts TokenRefreshService func ProvideTokenRefreshService( accountRepo AccountRepository, + soraAccountRepo SoraAccountRepository, // Sora 扩展表仓储,用于双表同步 oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, geminiOAuthService *GeminiOAuthService, @@ -47,6 +48,8 @@ func ProvideTokenRefreshService( cfg *config.Config, ) *TokenRefreshService { svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, 
antigravityOAuthService, cacheInvalidator, cfg) + // 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表 + svc.SetSoraAccountRepo(soraAccountRepo) svc.Start() return svc } diff --git a/backend/migrations/045_add_accounts_extra_index.sql b/backend/migrations/045_add_accounts_extra_index.sql new file mode 100644 index 00000000..05414062 --- /dev/null +++ b/backend/migrations/045_add_accounts_extra_index.sql @@ -0,0 +1,13 @@ +-- Migration: 045_add_accounts_extra_index +-- 为 accounts.extra 字段添加 GIN 索引,优化 FindByExtraField 查询性能 +-- 用于支持通过 extra 字段中的 linked_openai_account_id 快速查找关联的 Sora 账号 + +CREATE INDEX IF NOT EXISTS idx_accounts_extra_gin +ON accounts USING GIN (extra); + +-- 查询示例(使用 @> 操作符) +-- EXPLAIN ANALYZE +-- SELECT * FROM accounts +-- WHERE platform = 'sora' +-- AND extra @> '{"linked_openai_account_id": 123}'::jsonb +-- AND deleted_at IS NULL; diff --git a/backend/migrations/046_add_sora_accounts.sql b/backend/migrations/046_add_sora_accounts.sql new file mode 100644 index 00000000..62f98718 --- /dev/null +++ b/backend/migrations/046_add_sora_accounts.sql @@ -0,0 +1,24 @@ +-- Migration: 046_add_sora_accounts +-- 新增 sora_accounts 扩展表,存储 Sora 账号的 OAuth 凭证 +-- 与 accounts 主表形成双表结构: +-- - accounts: 统一账号管理和调度 +-- - sora_accounts: Sora gateway 快速读取和资格校验 +-- +-- 设计说明: +-- - account_id 为主键,外键关联 accounts.id +-- - ON DELETE CASCADE 确保删除账号时自动清理扩展表 +-- - access_token/refresh_token 与 accounts.credentials 保持同步 + +CREATE TABLE IF NOT EXISTS sora_accounts ( + account_id BIGINT PRIMARY KEY, + access_token TEXT NOT NULL, + refresh_token TEXT NOT NULL, + session_token TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT fk_sora_accounts_account_id + FOREIGN KEY (account_id) REFERENCES accounts(id) + ON DELETE CASCADE +); + +-- 索引说明:主键已自动创建唯一索引,无需额外创建 idx_sora_accounts_account_id diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 
144241ff..0e81a717 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1482,6 +1482,32 @@
+ +
+ +
+ ([]) const customErrorCodeInput = ref(null) const interceptWarmupRequests = ref(false) const autoPauseOnExpired = ref(true) +const enableSoraOnOpenAIOAuth = ref(false) // OpenAI OAuth 时同时启用 Sora const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling const tempUnschedEnabled = ref(false) const tempUnschedRules = ref([]) @@ -2334,6 +2361,7 @@ const resetForm = () => { customErrorCodeInput.value = null interceptWarmupRequests.value = false autoPauseOnExpired.value = true + enableSoraOnOpenAIOAuth.value = false // Reset quota control state windowCostEnabled.value = false windowCostLimit.value = null @@ -2509,7 +2537,72 @@ const handleOpenAIExchange = async (authCode: string) => { const credentials = openaiOAuth.buildCredentials(tokenInfo) const extra = openaiOAuth.buildExtraInfo(tokenInfo) - await createAccountAndFinish('openai', 'oauth', credentials, extra) + + // 应用临时不可调度配置 + if (!applyTempUnschedConfig(credentials)) { + return + } + + // 1. 创建 OpenAI 账号 + const openaiAccount = await adminAPI.accounts.create({ + name: form.name, + notes: form.notes, + platform: 'openai', + type: 'oauth', + credentials, + extra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + + appStore.showSuccess(t('admin.accounts.accountCreated')) + + // 2. 
如果启用了 Sora,同时创建 Sora 账号 + if (enableSoraOnOpenAIOAuth.value) { + try { + // Sora 使用相同的 OAuth credentials + const soraCredentials = { + access_token: credentials.access_token, + refresh_token: credentials.refresh_token, + expires_at: credentials.expires_at + } + + // 建立关联关系 + const soraExtra = { + ...extra, + linked_openai_account_id: String(openaiAccount.id) + } + + await adminAPI.accounts.create({ + name: `${form.name} (Sora)`, + notes: form.notes, + platform: 'sora', + type: 'oauth', + credentials: soraCredentials, + extra: soraExtra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + + appStore.showSuccess(t('admin.accounts.soraAccountCreated')) + } catch (error: any) { + console.error('创建 Sora 账号失败:', error) + appStore.showWarning(t('admin.accounts.soraAccountFailed')) + } + } + + emit('created') + handleClose() } catch (error: any) { openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') appStore.showError(openaiOAuth.error.value) diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index e293491b..a1403a8e 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1245,7 +1245,9 @@ export default { // OpenAI specific hints openai: { baseUrlHint: 'Leave default for official OpenAI API', - apiKeyHint: 'Your OpenAI API Key' + apiKeyHint: 'Your OpenAI API Key', + enableSora: 'Enable Sora simultaneously', + enableSoraHint: 'Sora uses the same OpenAI account. Enable to create Sora account simultaneously.' 
}, modelRestriction: 'Model Restriction (Optional)', modelWhitelist: 'Model Whitelist', @@ -1337,6 +1339,8 @@ export default { creating: 'Creating...', updating: 'Updating...', accountCreated: 'Account created successfully', + soraAccountCreated: 'Sora account created simultaneously', + soraAccountFailed: 'Failed to create Sora account, please add manually later', accountUpdated: 'Account updated successfully', failedToCreate: 'Failed to create account', failedToUpdate: 'Failed to update account', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index dbeb3819..7b85ca64 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1380,7 +1380,9 @@ export default { // OpenAI specific hints openai: { baseUrlHint: '留空使用官方 OpenAI API', - apiKeyHint: '您的 OpenAI API Key' + apiKeyHint: '您的 OpenAI API Key', + enableSora: '同时启用 Sora', + enableSoraHint: 'Sora 使用相同的 OpenAI 账号,开启后将同时创建 Sora 平台账号' }, modelRestriction: '模型限制(可选)', modelWhitelist: '模型白名单', @@ -1469,6 +1471,8 @@ export default { creating: '创建中...', updating: '更新中...', accountCreated: '账号创建成功', + soraAccountCreated: 'Sora 账号已同时创建', + soraAccountFailed: 'Sora 账号创建失败,请稍后手动添加', accountUpdated: '账号更新成功', failedToCreate: '创建账号失败', failedToUpdate: '更新账号失败', From 618a614cbf15f4040abab456afd64ccd7777be96 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sat, 31 Jan 2026 20:22:22 +0800 Subject: [PATCH 005/363] =?UTF-8?q?feat(Sora):=20=E5=AE=8C=E6=88=90Sora?= =?UTF-8?q?=E7=BD=91=E5=85=B3=E6=8E=A5=E5=85=A5=E4=B8=8E=E5=AA=92=E4=BD=93?= =?UTF-8?q?=E8=83=BD=E5=8A=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增 Sora 网关路由、账号调度与同步服务\n补充媒体代理与签名 URL、模型列表动态拉取\n完善计费配置、前端支持与相关测试 --- README_CN.md | 21 + backend/cmd/server/wire_gen.go | 15 +- backend/ent/group.go | 58 +- backend/ent/group/group.go | 32 + backend/ent/group/where.go | 220 ++++++ backend/ent/group_create.go | 392 +++++++++++ backend/ent/group_update.go | 288 ++++++++ 
backend/ent/migrate/schema.go | 31 +- backend/ent/mutation.go | 615 ++++++++++++++-- backend/ent/runtime/runtime.go | 10 +- backend/ent/schema/group.go | 18 + backend/ent/schema/usage_log.go | 5 + backend/ent/usagelog.go | 16 +- backend/ent/usagelog/usagelog.go | 10 + backend/ent/usagelog/where.go | 80 +++ backend/ent/usagelog_create.go | 83 +++ backend/ent/usagelog_update.go | 62 ++ backend/internal/config/config.go | 97 +++ .../internal/handler/admin/group_handler.go | 20 +- .../internal/handler/admin/model_handler.go | 55 ++ .../handler/admin/model_handler_test.go | 87 +++ backend/internal/handler/dto/mappers.go | 5 + backend/internal/handler/dto/types.go | 7 + backend/internal/handler/gateway_handler.go | 23 + backend/internal/handler/handler.go | 2 + .../internal/handler/sora_gateway_handler.go | 474 +++++++++++++ backend/internal/handler/wire.go | 6 + .../pkg/tlsfingerprint/dialer_test.go | 23 +- backend/internal/repository/api_key_repo.go | 8 + backend/internal/repository/group_repo.go | 8 + backend/internal/repository/usage_log_repo.go | 12 +- backend/internal/server/routes/admin.go | 7 + backend/internal/server/routes/gateway.go | 36 + backend/internal/service/admin_service.go | 192 ++++- .../service/admin_service_bulk_update_test.go | 55 ++ .../internal/service/api_key_auth_cache.go | 4 + .../service/api_key_auth_cache_impl.go | 8 + backend/internal/service/billing_service.go | 67 ++ backend/internal/service/gateway_service.go | 26 +- backend/internal/service/group.go | 18 + .../internal/service/openai_token_provider.go | 6 +- .../service/openai_token_provider_test.go | 4 +- backend/internal/service/sora2api_service.go | 355 ++++++++++ .../internal/service/sora2api_sync_service.go | 255 +++++++ .../internal/service/sora_gateway_service.go | 660 ++++++++++++++++++ backend/internal/service/sora_media_sign.go | 42 ++ .../internal/service/sora_media_sign_test.go | 34 + .../service/token_cache_invalidator.go | 2 +- .../internal/service/token_refresh_service.go 
| 12 + backend/internal/service/token_refresher.go | 28 +- backend/internal/service/usage_log.go | 1 + backend/internal/service/wire.go | 7 + .../047_add_sora_pricing_and_media_type.sql | 11 + deploy/Caddyfile | 51 +- deploy/config.example.yaml | 52 ++ frontend/src/api/admin/index.ts | 7 +- frontend/src/api/admin/models.ts | 14 + .../account/ModelWhitelistSelector.vue | 65 +- .../admin/account/AccountTableFilters.vue | 2 +- frontend/src/components/common/GroupBadge.vue | 8 + .../src/components/common/PlatformIcon.vue | 6 + .../components/common/PlatformTypeBadge.vue | 7 + frontend/src/composables/useModelWhitelist.ts | 21 + frontend/src/i18n/locales/en.ts | 17 +- frontend/src/i18n/locales/zh.ts | 17 +- frontend/src/types/index.ts | 17 +- frontend/src/views/admin/GroupsView.vue | 145 +++- 67 files changed, 4840 insertions(+), 202 deletions(-) create mode 100644 backend/internal/handler/admin/model_handler.go create mode 100644 backend/internal/handler/admin/model_handler_test.go create mode 100644 backend/internal/handler/sora_gateway_handler.go create mode 100644 backend/internal/service/sora2api_service.go create mode 100644 backend/internal/service/sora2api_sync_service.go create mode 100644 backend/internal/service/sora_gateway_service.go create mode 100644 backend/internal/service/sora_media_sign.go create mode 100644 backend/internal/service/sora_media_sign_test.go create mode 100644 backend/migrations/047_add_sora_pricing_and_media_type.sql create mode 100644 frontend/src/api/admin/models.ts diff --git a/README_CN.md b/README_CN.md index 8129c3b2..707f0201 100644 --- a/README_CN.md +++ b/README_CN.md @@ -300,6 +300,27 @@ default: rate_multiplier: 1.0 ``` +### Sora 媒体签名 URL(可选) + +当配置 `gateway.sora_media_signing_key` 且 `gateway.sora_media_signed_url_ttl_seconds > 0` 时,网关会将 Sora 输出的媒体地址改写为临时签名 URL(`/sora/media-signed/...`)。这样无需 API Key 即可在浏览器中直接访问,且具备过期控制与防篡改能力(签名包含 path + query)。 + +```yaml +gateway: + # /sora/media 是否强制要求 API Key(默认 false) + 
sora_media_require_api_key: false + # 媒体临时签名密钥(为空则禁用签名) + sora_media_signing_key: "your-signing-key" + # 临时签名 URL 有效期(秒) + sora_media_signed_url_ttl_seconds: 900 +``` + +> 若未配置签名密钥,`/sora/media-signed` 将返回 503。 +> 如需更严格的访问控制,可将 `sora_media_require_api_key` 设为 true,仅允许携带 API Key 的 `/sora/media` 访问。 + +访问策略说明: +- `/sora/media`:内部调用或客户端携带 API Key 才能下载 +- `/sora/media-signed`:外部可访问,但有签名 + 过期控制 + `config.yaml` 还支持以下安全相关配置: - `cors.allowed_origins` 配置 CORS 白名单 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index b8668665..1d88b612 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -87,10 +87,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { schedulerCache := repository.NewSchedulerCache(redisClient) accountRepository := repository.NewAccountRepository(client, db, schedulerCache) soraAccountRepository := repository.NewSoraAccountRepository(db) + sora2APIService := service.NewSora2APIService(configConfig) + sora2APISyncService := service.NewSora2APISyncService(sora2APIService, accountRepository) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, sora2APISyncService, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) adminUserHandler := admin.NewUserHandler(adminService) groupHandler := admin.NewGroupHandler(adminService) claudeOAuthClient := 
repository.NewClaudeOAuthClient() @@ -162,11 +164,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig) + modelHandler := admin.NewModelHandler(sora2APIService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, modelHandler) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, sora2APIService, concurrencyService, billingCacheService, configConfig) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig) + soraGatewayService := service.NewSoraGatewayService(sora2APIService, httpUpstream, rateLimitService, configConfig) + soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig) handlerSettingHandler := 
handler.ProvideSettingHandler(settingService, buildInfo) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) @@ -177,7 +182,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, sora2APISyncService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, 
tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) application := &Application{ diff --git a/backend/ent/group.go b/backend/ent/group.go index 0d0c0538..0a32543b 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -52,6 +52,14 @@ type Group struct { ImagePrice2k *float64 `json:"image_price_2k,omitempty"` // ImagePrice4k holds the value of the "image_price_4k" field. ImagePrice4k *float64 `json:"image_price_4k,omitempty"` + // SoraImagePrice360 holds the value of the "sora_image_price_360" field. + SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` + // SoraImagePrice540 holds the value of the "sora_image_price_540" field. + SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` + // SoraVideoPricePerRequest holds the value of the "sora_video_price_per_request" field. + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` + // SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field. 
+ SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"` // 是否仅允许 Claude Code 客户端 ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` // 非 Claude Code 请求降级使用的分组 ID @@ -170,7 +178,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { values[i] = new([]byte) case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled: values[i] = new(sql.NullBool) - case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: + case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd: values[i] = new(sql.NullFloat64) case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID: values[i] = new(sql.NullInt64) @@ -309,6 +317,34 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.ImagePrice4k = new(float64) *_m.ImagePrice4k = value.Float64 } + case group.FieldSoraImagePrice360: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_image_price_360", values[i]) + } else if value.Valid { + _m.SoraImagePrice360 = new(float64) + *_m.SoraImagePrice360 = value.Float64 + } + case group.FieldSoraImagePrice540: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_image_price_540", values[i]) + } else if value.Valid { + _m.SoraImagePrice540 = new(float64) + *_m.SoraImagePrice540 = value.Float64 + } + case group.FieldSoraVideoPricePerRequest: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_video_price_per_request", values[i]) + } 
else if value.Valid { + _m.SoraVideoPricePerRequest = new(float64) + *_m.SoraVideoPricePerRequest = value.Float64 + } + case group.FieldSoraVideoPricePerRequestHd: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_video_price_per_request_hd", values[i]) + } else if value.Valid { + _m.SoraVideoPricePerRequestHd = new(float64) + *_m.SoraVideoPricePerRequestHd = value.Float64 + } case group.FieldClaudeCodeOnly: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field claude_code_only", values[i]) @@ -479,6 +515,26 @@ func (_m *Group) String() string { builder.WriteString(fmt.Sprintf("%v", *v)) } builder.WriteString(", ") + if v := _m.SoraImagePrice360; v != nil { + builder.WriteString("sora_image_price_360=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.SoraImagePrice540; v != nil { + builder.WriteString("sora_image_price_540=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.SoraVideoPricePerRequest; v != nil { + builder.WriteString("sora_video_price_per_request=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.SoraVideoPricePerRequestHd; v != nil { + builder.WriteString("sora_video_price_per_request_hd=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") builder.WriteString("claude_code_only=") builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly)) builder.WriteString(", ") diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index d66d3edc..7470dd82 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -49,6 +49,14 @@ const ( FieldImagePrice2k = "image_price_2k" // FieldImagePrice4k holds the string denoting the image_price_4k field in the database. 
FieldImagePrice4k = "image_price_4k" + // FieldSoraImagePrice360 holds the string denoting the sora_image_price_360 field in the database. + FieldSoraImagePrice360 = "sora_image_price_360" + // FieldSoraImagePrice540 holds the string denoting the sora_image_price_540 field in the database. + FieldSoraImagePrice540 = "sora_image_price_540" + // FieldSoraVideoPricePerRequest holds the string denoting the sora_video_price_per_request field in the database. + FieldSoraVideoPricePerRequest = "sora_video_price_per_request" + // FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database. + FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd" // FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database. FieldClaudeCodeOnly = "claude_code_only" // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. @@ -149,6 +157,10 @@ var Columns = []string{ FieldImagePrice1k, FieldImagePrice2k, FieldImagePrice4k, + FieldSoraImagePrice360, + FieldSoraImagePrice540, + FieldSoraVideoPricePerRequest, + FieldSoraVideoPricePerRequestHd, FieldClaudeCodeOnly, FieldFallbackGroupID, FieldModelRouting, @@ -307,6 +319,26 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc() } +// BySoraImagePrice360 orders the results by the sora_image_price_360 field. +func BySoraImagePrice360(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraImagePrice360, opts...).ToFunc() +} + +// BySoraImagePrice540 orders the results by the sora_image_price_540 field. +func BySoraImagePrice540(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraImagePrice540, opts...).ToFunc() +} + +// BySoraVideoPricePerRequest orders the results by the sora_video_price_per_request field. 
+func BySoraVideoPricePerRequest(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraVideoPricePerRequest, opts...).ToFunc() +} + +// BySoraVideoPricePerRequestHd orders the results by the sora_video_price_per_request_hd field. +func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc() +} + // ByClaudeCodeOnly orders the results by the claude_code_only field. func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc() diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 6ce9e4c6..3f8f4c04 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -140,6 +140,26 @@ func ImagePrice4k(v float64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v)) } +// SoraImagePrice360 applies equality check predicate on the "sora_image_price_360" field. It's identical to SoraImagePrice360EQ. +func SoraImagePrice360(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice540 applies equality check predicate on the "sora_image_price_540" field. It's identical to SoraImagePrice540EQ. +func SoraImagePrice540(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v)) +} + +// SoraVideoPricePerRequest applies equality check predicate on the "sora_video_price_per_request" field. It's identical to SoraVideoPricePerRequestEQ. +func SoraVideoPricePerRequest(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestHd applies equality check predicate on the "sora_video_price_per_request_hd" field. It's identical to SoraVideoPricePerRequestHdEQ. 
+func SoraVideoPricePerRequestHd(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v)) +} + // ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ. func ClaudeCodeOnly(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) @@ -1010,6 +1030,206 @@ func ImagePrice4kNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldImagePrice4k)) } +// SoraImagePrice360EQ applies the EQ predicate on the "sora_image_price_360" field. +func SoraImagePrice360EQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360NEQ applies the NEQ predicate on the "sora_image_price_360" field. +func SoraImagePrice360NEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360In applies the In predicate on the "sora_image_price_360" field. +func SoraImagePrice360In(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraImagePrice360, vs...)) +} + +// SoraImagePrice360NotIn applies the NotIn predicate on the "sora_image_price_360" field. +func SoraImagePrice360NotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice360, vs...)) +} + +// SoraImagePrice360GT applies the GT predicate on the "sora_image_price_360" field. +func SoraImagePrice360GT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360GTE applies the GTE predicate on the "sora_image_price_360" field. +func SoraImagePrice360GTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360LT applies the LT predicate on the "sora_image_price_360" field. 
+func SoraImagePrice360LT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360LTE applies the LTE predicate on the "sora_image_price_360" field. +func SoraImagePrice360LTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360IsNil applies the IsNil predicate on the "sora_image_price_360" field. +func SoraImagePrice360IsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice360)) +} + +// SoraImagePrice360NotNil applies the NotNil predicate on the "sora_image_price_360" field. +func SoraImagePrice360NotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice360)) +} + +// SoraImagePrice540EQ applies the EQ predicate on the "sora_image_price_540" field. +func SoraImagePrice540EQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540NEQ applies the NEQ predicate on the "sora_image_price_540" field. +func SoraImagePrice540NEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540In applies the In predicate on the "sora_image_price_540" field. +func SoraImagePrice540In(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraImagePrice540, vs...)) +} + +// SoraImagePrice540NotIn applies the NotIn predicate on the "sora_image_price_540" field. +func SoraImagePrice540NotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice540, vs...)) +} + +// SoraImagePrice540GT applies the GT predicate on the "sora_image_price_540" field. +func SoraImagePrice540GT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540GTE applies the GTE predicate on the "sora_image_price_540" field. 
+func SoraImagePrice540GTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540LT applies the LT predicate on the "sora_image_price_540" field. +func SoraImagePrice540LT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540LTE applies the LTE predicate on the "sora_image_price_540" field. +func SoraImagePrice540LTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540IsNil applies the IsNil predicate on the "sora_image_price_540" field. +func SoraImagePrice540IsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice540)) +} + +// SoraImagePrice540NotNil applies the NotNil predicate on the "sora_image_price_540" field. +func SoraImagePrice540NotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice540)) +} + +// SoraVideoPricePerRequestEQ applies the EQ predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestNEQ applies the NEQ predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestNEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestIn applies the In predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequest, vs...)) +} + +// SoraVideoPricePerRequestNotIn applies the NotIn predicate on the "sora_video_price_per_request" field. 
+func SoraVideoPricePerRequestNotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequest, vs...)) +} + +// SoraVideoPricePerRequestGT applies the GT predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestGT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestGTE applies the GTE predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestGTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestLT applies the LT predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestLT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestLTE applies the LTE predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestLTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestIsNil applies the IsNil predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequest)) +} + +// SoraVideoPricePerRequestNotNil applies the NotNil predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequest)) +} + +// SoraVideoPricePerRequestHdEQ applies the EQ predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdNEQ applies the NEQ predicate on the "sora_video_price_per_request_hd" field. 
+func SoraVideoPricePerRequestHdNEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdIn applies the In predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequestHd, vs...)) +} + +// SoraVideoPricePerRequestHdNotIn applies the NotIn predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdNotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequestHd, vs...)) +} + +// SoraVideoPricePerRequestHdGT applies the GT predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdGT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdGTE applies the GTE predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdGTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdLT applies the LT predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdLT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdLTE applies the LTE predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdLTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdIsNil applies the IsNil predicate on the "sora_video_price_per_request_hd" field. 
+func SoraVideoPricePerRequestHdIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequestHd)) +} + +// SoraVideoPricePerRequestHdNotNil applies the NotNil predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd)) +} + // ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field. func ClaudeCodeOnlyEQ(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 0f251e0b..ac5cb4d5 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -258,6 +258,62 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate { return _c } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (_c *GroupCreate) SetSoraImagePrice360(v float64) *GroupCreate { + _c.mutation.SetSoraImagePrice360(v) + return _c +} + +// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraImagePrice360(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraImagePrice360(*v) + } + return _c +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (_c *GroupCreate) SetSoraImagePrice540(v float64) *GroupCreate { + _c.mutation.SetSoraImagePrice540(v) + return _c +} + +// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraImagePrice540(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraImagePrice540(*v) + } + return _c +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. 
+func (_c *GroupCreate) SetSoraVideoPricePerRequest(v float64) *GroupCreate { + _c.mutation.SetSoraVideoPricePerRequest(v) + return _c +} + +// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraVideoPricePerRequest(*v) + } + return _c +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (_c *GroupCreate) SetSoraVideoPricePerRequestHd(v float64) *GroupCreate { + _c.mutation.SetSoraVideoPricePerRequestHd(v) + return _c +} + +// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraVideoPricePerRequestHd(*v) + } + return _c +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate { _c.mutation.SetClaudeCodeOnly(v) @@ -632,6 +688,22 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value) _node.ImagePrice4k = &value } + if value, ok := _c.mutation.SoraImagePrice360(); ok { + _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + _node.SoraImagePrice360 = &value + } + if value, ok := _c.mutation.SoraImagePrice540(); ok { + _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + _node.SoraImagePrice540 = &value + } + if value, ok := _c.mutation.SoraVideoPricePerRequest(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + _node.SoraVideoPricePerRequest = &value + } + if value, ok := _c.mutation.SoraVideoPricePerRequestHd(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + _node.SoraVideoPricePerRequestHd = &value + } if 
value, ok := _c.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) _node.ClaudeCodeOnly = value @@ -1092,6 +1164,102 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert { return u } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (u *GroupUpsert) SetSoraImagePrice360(v float64) *GroupUpsert { + u.Set(group.FieldSoraImagePrice360, v) + return u +} + +// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraImagePrice360() *GroupUpsert { + u.SetExcluded(group.FieldSoraImagePrice360) + return u +} + +// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. +func (u *GroupUpsert) AddSoraImagePrice360(v float64) *GroupUpsert { + u.Add(group.FieldSoraImagePrice360, v) + return u +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (u *GroupUpsert) ClearSoraImagePrice360() *GroupUpsert { + u.SetNull(group.FieldSoraImagePrice360) + return u +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (u *GroupUpsert) SetSoraImagePrice540(v float64) *GroupUpsert { + u.Set(group.FieldSoraImagePrice540, v) + return u +} + +// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraImagePrice540() *GroupUpsert { + u.SetExcluded(group.FieldSoraImagePrice540) + return u +} + +// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. +func (u *GroupUpsert) AddSoraImagePrice540(v float64) *GroupUpsert { + u.Add(group.FieldSoraImagePrice540, v) + return u +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (u *GroupUpsert) ClearSoraImagePrice540() *GroupUpsert { + u.SetNull(group.FieldSoraImagePrice540) + return u +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. 
+func (u *GroupUpsert) SetSoraVideoPricePerRequest(v float64) *GroupUpsert { + u.Set(group.FieldSoraVideoPricePerRequest, v) + return u +} + +// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraVideoPricePerRequest() *GroupUpsert { + u.SetExcluded(group.FieldSoraVideoPricePerRequest) + return u +} + +// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. +func (u *GroupUpsert) AddSoraVideoPricePerRequest(v float64) *GroupUpsert { + u.Add(group.FieldSoraVideoPricePerRequest, v) + return u +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (u *GroupUpsert) ClearSoraVideoPricePerRequest() *GroupUpsert { + u.SetNull(group.FieldSoraVideoPricePerRequest) + return u +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (u *GroupUpsert) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsert { + u.Set(group.FieldSoraVideoPricePerRequestHd, v) + return u +} + +// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraVideoPricePerRequestHd() *GroupUpsert { + u.SetExcluded(group.FieldSoraVideoPricePerRequestHd) + return u +} + +// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. +func (u *GroupUpsert) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsert { + u.Add(group.FieldSoraVideoPricePerRequestHd, v) + return u +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert { + u.SetNull(group.FieldSoraVideoPricePerRequestHd) + return u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. 
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert { u.Set(group.FieldClaudeCodeOnly, v) @@ -1539,6 +1707,118 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne { }) } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (u *GroupUpsertOne) SetSoraImagePrice360(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice360(v) + }) +} + +// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. +func (u *GroupUpsertOne) AddSoraImagePrice360(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice360(v) + }) +} + +// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraImagePrice360() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice360() + }) +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (u *GroupUpsertOne) ClearSoraImagePrice360() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice360() + }) +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (u *GroupUpsertOne) SetSoraImagePrice540(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice540(v) + }) +} + +// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. +func (u *GroupUpsertOne) AddSoraImagePrice540(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice540(v) + }) +} + +// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraImagePrice540() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice540() + }) +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. 
+func (u *GroupUpsertOne) ClearSoraImagePrice540() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice540() + }) +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (u *GroupUpsertOne) SetSoraVideoPricePerRequest(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequest(v) + }) +} + +// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. +func (u *GroupUpsertOne) AddSoraVideoPricePerRequest(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequest(v) + }) +} + +// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequest() + }) +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (u *GroupUpsertOne) ClearSoraVideoPricePerRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequest() + }) +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequestHd(v) + }) +} + +// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequestHd(v) + }) +} + +// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. 
+func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequestHd() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequestHd() + }) +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequestHd() + }) +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -2163,6 +2443,118 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk { }) } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (u *GroupUpsertBulk) SetSoraImagePrice360(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice360(v) + }) +} + +// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. +func (u *GroupUpsertBulk) AddSoraImagePrice360(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice360(v) + }) +} + +// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraImagePrice360() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice360() + }) +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (u *GroupUpsertBulk) ClearSoraImagePrice360() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice360() + }) +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (u *GroupUpsertBulk) SetSoraImagePrice540(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice540(v) + }) +} + +// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. 
+func (u *GroupUpsertBulk) AddSoraImagePrice540(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice540(v) + }) +} + +// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraImagePrice540() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice540() + }) +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (u *GroupUpsertBulk) ClearSoraImagePrice540() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice540() + }) +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (u *GroupUpsertBulk) SetSoraVideoPricePerRequest(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequest(v) + }) +} + +// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. +func (u *GroupUpsertBulk) AddSoraVideoPricePerRequest(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequest(v) + }) +} + +// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequest() + }) +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequest() + }) +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. 
+func (u *GroupUpsertBulk) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequestHd(v) + }) +} + +// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertBulk) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequestHd(v) + }) +} + +// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequestHd() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequestHd() + }) +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequestHd() + }) +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index c3cc2708..528a7fe9 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -354,6 +354,114 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate { return _u } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (_u *GroupUpdate) SetSoraImagePrice360(v float64) *GroupUpdate { + _u.mutation.ResetSoraImagePrice360() + _u.mutation.SetSoraImagePrice360(v) + return _u +} + +// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraImagePrice360(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraImagePrice360(*v) + } + return _u +} + +// AddSoraImagePrice360 adds value to the "sora_image_price_360" field. 
+func (_u *GroupUpdate) AddSoraImagePrice360(v float64) *GroupUpdate { + _u.mutation.AddSoraImagePrice360(v) + return _u +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (_u *GroupUpdate) ClearSoraImagePrice360() *GroupUpdate { + _u.mutation.ClearSoraImagePrice360() + return _u +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (_u *GroupUpdate) SetSoraImagePrice540(v float64) *GroupUpdate { + _u.mutation.ResetSoraImagePrice540() + _u.mutation.SetSoraImagePrice540(v) + return _u +} + +// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraImagePrice540(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraImagePrice540(*v) + } + return _u +} + +// AddSoraImagePrice540 adds value to the "sora_image_price_540" field. +func (_u *GroupUpdate) AddSoraImagePrice540(v float64) *GroupUpdate { + _u.mutation.AddSoraImagePrice540(v) + return _u +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (_u *GroupUpdate) ClearSoraImagePrice540() *GroupUpdate { + _u.mutation.ClearSoraImagePrice540() + return _u +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (_u *GroupUpdate) SetSoraVideoPricePerRequest(v float64) *GroupUpdate { + _u.mutation.ResetSoraVideoPricePerRequest() + _u.mutation.SetSoraVideoPricePerRequest(v) + return _u +} + +// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraVideoPricePerRequest(*v) + } + return _u +} + +// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field. 
+func (_u *GroupUpdate) AddSoraVideoPricePerRequest(v float64) *GroupUpdate { + _u.mutation.AddSoraVideoPricePerRequest(v) + return _u +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (_u *GroupUpdate) ClearSoraVideoPricePerRequest() *GroupUpdate { + _u.mutation.ClearSoraVideoPricePerRequest() + return _u +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdate) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdate { + _u.mutation.ResetSoraVideoPricePerRequestHd() + _u.mutation.SetSoraVideoPricePerRequestHd(v) + return _u +} + +// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraVideoPricePerRequestHd(*v) + } + return _u +} + +// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdate) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdate { + _u.mutation.AddSoraVideoPricePerRequestHd(v) + return _u +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate { + _u.mutation.ClearSoraVideoPricePerRequestHd() + return _u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. 
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate { _u.mutation.SetClaudeCodeOnly(v) @@ -817,6 +925,42 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } + if value, ok := _u.mutation.SoraImagePrice360(); ok { + _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice360(); ok { + _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice360Cleared() { + _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraImagePrice540(); ok { + _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice540(); ok { + _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice540Cleared() { + _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestHdCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) + } if value, ok := _u.mutation.ClaudeCodeOnly(); ok { 
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } @@ -1472,6 +1616,114 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne { return _u } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (_u *GroupUpdateOne) SetSoraImagePrice360(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraImagePrice360() + _u.mutation.SetSoraImagePrice360(v) + return _u +} + +// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraImagePrice360(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraImagePrice360(*v) + } + return _u +} + +// AddSoraImagePrice360 adds value to the "sora_image_price_360" field. +func (_u *GroupUpdateOne) AddSoraImagePrice360(v float64) *GroupUpdateOne { + _u.mutation.AddSoraImagePrice360(v) + return _u +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (_u *GroupUpdateOne) ClearSoraImagePrice360() *GroupUpdateOne { + _u.mutation.ClearSoraImagePrice360() + return _u +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (_u *GroupUpdateOne) SetSoraImagePrice540(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraImagePrice540() + _u.mutation.SetSoraImagePrice540(v) + return _u +} + +// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraImagePrice540(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraImagePrice540(*v) + } + return _u +} + +// AddSoraImagePrice540 adds value to the "sora_image_price_540" field. +func (_u *GroupUpdateOne) AddSoraImagePrice540(v float64) *GroupUpdateOne { + _u.mutation.AddSoraImagePrice540(v) + return _u +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. 
+func (_u *GroupUpdateOne) ClearSoraImagePrice540() *GroupUpdateOne { + _u.mutation.ClearSoraImagePrice540() + return _u +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (_u *GroupUpdateOne) SetSoraVideoPricePerRequest(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraVideoPricePerRequest() + _u.mutation.SetSoraVideoPricePerRequest(v) + return _u +} + +// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraVideoPricePerRequest(*v) + } + return _u +} + +// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field. +func (_u *GroupUpdateOne) AddSoraVideoPricePerRequest(v float64) *GroupUpdateOne { + _u.mutation.AddSoraVideoPricePerRequest(v) + return _u +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequest() *GroupUpdateOne { + _u.mutation.ClearSoraVideoPricePerRequest() + return _u +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdateOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraVideoPricePerRequestHd() + _u.mutation.SetSoraVideoPricePerRequestHd(v) + return _u +} + +// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraVideoPricePerRequestHd(*v) + } + return _u +} + +// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field. 
+func (_u *GroupUpdateOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne { + _u.mutation.AddSoraVideoPricePerRequestHd(v) + return _u +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne { + _u.mutation.ClearSoraVideoPricePerRequestHd() + return _u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne { _u.mutation.SetClaudeCodeOnly(v) @@ -1965,6 +2217,42 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } + if value, ok := _u.mutation.SoraImagePrice360(); ok { + _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice360(); ok { + _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice360Cleared() { + _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraImagePrice540(); ok { + _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice540(); ok { + _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice540Cleared() { + _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64) + } + if value, ok := 
_u.mutation.SoraVideoPricePerRequestHd(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestHdCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) + } if value, ok := _u.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index d1f05186..fe1f80a8 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -224,6 +224,10 @@ var ( {Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_image_price_360", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, {Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, @@ -499,6 +503,7 @@ var ( {Name: "ip_address", Type: 
field.TypeString, Nullable: true, Size: 45}, {Name: "image_count", Type: field.TypeInt, Default: 0}, {Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10}, + {Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "api_key_id", Type: field.TypeInt64}, {Name: "account_id", Type: field.TypeInt64}, @@ -514,31 +519,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[27]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -547,32 +552,32 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, }, { Name: "usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[27]}, }, { Name: "usagelog_account_id", 
Unique: false, - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[26]}, }, { Name: "usagelog_model", @@ -587,12 +592,12 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[26]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[26], UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[26]}, }, }, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 9b330616..b3d1e410 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -3836,61 +3836,69 @@ func (m *AccountGroupMutation) ResetEdge(name string) error { // GroupMutation represents an operation that mutates the Group nodes in the graph. 
type GroupMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - description *string - rate_multiplier *float64 - addrate_multiplier *float64 - is_exclusive *bool - status *string - platform *string - subscription_type *string - daily_limit_usd *float64 - adddaily_limit_usd *float64 - weekly_limit_usd *float64 - addweekly_limit_usd *float64 - monthly_limit_usd *float64 - addmonthly_limit_usd *float64 - default_validity_days *int - adddefault_validity_days *int - image_price_1k *float64 - addimage_price_1k *float64 - image_price_2k *float64 - addimage_price_2k *float64 - image_price_4k *float64 - addimage_price_4k *float64 - claude_code_only *bool - fallback_group_id *int64 - addfallback_group_id *int64 - model_routing *map[string][]int64 - model_routing_enabled *bool - clearedFields map[string]struct{} - api_keys map[int64]struct{} - removedapi_keys map[int64]struct{} - clearedapi_keys bool - redeem_codes map[int64]struct{} - removedredeem_codes map[int64]struct{} - clearedredeem_codes bool - subscriptions map[int64]struct{} - removedsubscriptions map[int64]struct{} - clearedsubscriptions bool - usage_logs map[int64]struct{} - removedusage_logs map[int64]struct{} - clearedusage_logs bool - accounts map[int64]struct{} - removedaccounts map[int64]struct{} - clearedaccounts bool - allowed_users map[int64]struct{} - removedallowed_users map[int64]struct{} - clearedallowed_users bool - done bool - oldValue func(context.Context) (*Group, error) - predicates []predicate.Group + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + description *string + rate_multiplier *float64 + addrate_multiplier *float64 + is_exclusive *bool + status *string + platform *string + subscription_type *string + daily_limit_usd *float64 + adddaily_limit_usd *float64 + weekly_limit_usd *float64 + addweekly_limit_usd *float64 + 
monthly_limit_usd *float64 + addmonthly_limit_usd *float64 + default_validity_days *int + adddefault_validity_days *int + image_price_1k *float64 + addimage_price_1k *float64 + image_price_2k *float64 + addimage_price_2k *float64 + image_price_4k *float64 + addimage_price_4k *float64 + sora_image_price_360 *float64 + addsora_image_price_360 *float64 + sora_image_price_540 *float64 + addsora_image_price_540 *float64 + sora_video_price_per_request *float64 + addsora_video_price_per_request *float64 + sora_video_price_per_request_hd *float64 + addsora_video_price_per_request_hd *float64 + claude_code_only *bool + fallback_group_id *int64 + addfallback_group_id *int64 + model_routing *map[string][]int64 + model_routing_enabled *bool + clearedFields map[string]struct{} + api_keys map[int64]struct{} + removedapi_keys map[int64]struct{} + clearedapi_keys bool + redeem_codes map[int64]struct{} + removedredeem_codes map[int64]struct{} + clearedredeem_codes bool + subscriptions map[int64]struct{} + removedsubscriptions map[int64]struct{} + clearedsubscriptions bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + accounts map[int64]struct{} + removedaccounts map[int64]struct{} + clearedaccounts bool + allowed_users map[int64]struct{} + removedallowed_users map[int64]struct{} + clearedallowed_users bool + done bool + oldValue func(context.Context) (*Group, error) + predicates []predicate.Group } var _ ent.Mutation = (*GroupMutation)(nil) @@ -4873,6 +4881,286 @@ func (m *GroupMutation) ResetImagePrice4k() { delete(m.clearedFields, group.FieldImagePrice4k) } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (m *GroupMutation) SetSoraImagePrice360(f float64) { + m.sora_image_price_360 = &f + m.addsora_image_price_360 = nil +} + +// SoraImagePrice360 returns the value of the "sora_image_price_360" field in the mutation. 
+func (m *GroupMutation) SoraImagePrice360() (r float64, exists bool) { + v := m.sora_image_price_360 + if v == nil { + return + } + return *v, true +} + +// OldSoraImagePrice360 returns the old "sora_image_price_360" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraImagePrice360(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraImagePrice360 is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraImagePrice360 requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraImagePrice360: %w", err) + } + return oldValue.SoraImagePrice360, nil +} + +// AddSoraImagePrice360 adds f to the "sora_image_price_360" field. +func (m *GroupMutation) AddSoraImagePrice360(f float64) { + if m.addsora_image_price_360 != nil { + *m.addsora_image_price_360 += f + } else { + m.addsora_image_price_360 = &f + } +} + +// AddedSoraImagePrice360 returns the value that was added to the "sora_image_price_360" field in this mutation. +func (m *GroupMutation) AddedSoraImagePrice360() (r float64, exists bool) { + v := m.addsora_image_price_360 + if v == nil { + return + } + return *v, true +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (m *GroupMutation) ClearSoraImagePrice360() { + m.sora_image_price_360 = nil + m.addsora_image_price_360 = nil + m.clearedFields[group.FieldSoraImagePrice360] = struct{}{} +} + +// SoraImagePrice360Cleared returns if the "sora_image_price_360" field was cleared in this mutation. 
+func (m *GroupMutation) SoraImagePrice360Cleared() bool { + _, ok := m.clearedFields[group.FieldSoraImagePrice360] + return ok +} + +// ResetSoraImagePrice360 resets all changes to the "sora_image_price_360" field. +func (m *GroupMutation) ResetSoraImagePrice360() { + m.sora_image_price_360 = nil + m.addsora_image_price_360 = nil + delete(m.clearedFields, group.FieldSoraImagePrice360) +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (m *GroupMutation) SetSoraImagePrice540(f float64) { + m.sora_image_price_540 = &f + m.addsora_image_price_540 = nil +} + +// SoraImagePrice540 returns the value of the "sora_image_price_540" field in the mutation. +func (m *GroupMutation) SoraImagePrice540() (r float64, exists bool) { + v := m.sora_image_price_540 + if v == nil { + return + } + return *v, true +} + +// OldSoraImagePrice540 returns the old "sora_image_price_540" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraImagePrice540(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraImagePrice540 is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraImagePrice540 requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraImagePrice540: %w", err) + } + return oldValue.SoraImagePrice540, nil +} + +// AddSoraImagePrice540 adds f to the "sora_image_price_540" field. 
+func (m *GroupMutation) AddSoraImagePrice540(f float64) { + if m.addsora_image_price_540 != nil { + *m.addsora_image_price_540 += f + } else { + m.addsora_image_price_540 = &f + } +} + +// AddedSoraImagePrice540 returns the value that was added to the "sora_image_price_540" field in this mutation. +func (m *GroupMutation) AddedSoraImagePrice540() (r float64, exists bool) { + v := m.addsora_image_price_540 + if v == nil { + return + } + return *v, true +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (m *GroupMutation) ClearSoraImagePrice540() { + m.sora_image_price_540 = nil + m.addsora_image_price_540 = nil + m.clearedFields[group.FieldSoraImagePrice540] = struct{}{} +} + +// SoraImagePrice540Cleared returns if the "sora_image_price_540" field was cleared in this mutation. +func (m *GroupMutation) SoraImagePrice540Cleared() bool { + _, ok := m.clearedFields[group.FieldSoraImagePrice540] + return ok +} + +// ResetSoraImagePrice540 resets all changes to the "sora_image_price_540" field. +func (m *GroupMutation) ResetSoraImagePrice540() { + m.sora_image_price_540 = nil + m.addsora_image_price_540 = nil + delete(m.clearedFields, group.FieldSoraImagePrice540) +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (m *GroupMutation) SetSoraVideoPricePerRequest(f float64) { + m.sora_video_price_per_request = &f + m.addsora_video_price_per_request = nil +} + +// SoraVideoPricePerRequest returns the value of the "sora_video_price_per_request" field in the mutation. +func (m *GroupMutation) SoraVideoPricePerRequest() (r float64, exists bool) { + v := m.sora_video_price_per_request + if v == nil { + return + } + return *v, true +} + +// OldSoraVideoPricePerRequest returns the old "sora_video_price_per_request" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. 
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraVideoPricePerRequest(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraVideoPricePerRequest is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraVideoPricePerRequest requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequest: %w", err) + } + return oldValue.SoraVideoPricePerRequest, nil +} + +// AddSoraVideoPricePerRequest adds f to the "sora_video_price_per_request" field. +func (m *GroupMutation) AddSoraVideoPricePerRequest(f float64) { + if m.addsora_video_price_per_request != nil { + *m.addsora_video_price_per_request += f + } else { + m.addsora_video_price_per_request = &f + } +} + +// AddedSoraVideoPricePerRequest returns the value that was added to the "sora_video_price_per_request" field in this mutation. +func (m *GroupMutation) AddedSoraVideoPricePerRequest() (r float64, exists bool) { + v := m.addsora_video_price_per_request + if v == nil { + return + } + return *v, true +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (m *GroupMutation) ClearSoraVideoPricePerRequest() { + m.sora_video_price_per_request = nil + m.addsora_video_price_per_request = nil + m.clearedFields[group.FieldSoraVideoPricePerRequest] = struct{}{} +} + +// SoraVideoPricePerRequestCleared returns if the "sora_video_price_per_request" field was cleared in this mutation. +func (m *GroupMutation) SoraVideoPricePerRequestCleared() bool { + _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequest] + return ok +} + +// ResetSoraVideoPricePerRequest resets all changes to the "sora_video_price_per_request" field. 
+func (m *GroupMutation) ResetSoraVideoPricePerRequest() { + m.sora_video_price_per_request = nil + m.addsora_video_price_per_request = nil + delete(m.clearedFields, group.FieldSoraVideoPricePerRequest) +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (m *GroupMutation) SetSoraVideoPricePerRequestHd(f float64) { + m.sora_video_price_per_request_hd = &f + m.addsora_video_price_per_request_hd = nil +} + +// SoraVideoPricePerRequestHd returns the value of the "sora_video_price_per_request_hd" field in the mutation. +func (m *GroupMutation) SoraVideoPricePerRequestHd() (r float64, exists bool) { + v := m.sora_video_price_per_request_hd + if v == nil { + return + } + return *v, true +} + +// OldSoraVideoPricePerRequestHd returns the old "sora_video_price_per_request_hd" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraVideoPricePerRequestHd(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraVideoPricePerRequestHd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraVideoPricePerRequestHd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequestHd: %w", err) + } + return oldValue.SoraVideoPricePerRequestHd, nil +} + +// AddSoraVideoPricePerRequestHd adds f to the "sora_video_price_per_request_hd" field. 
+func (m *GroupMutation) AddSoraVideoPricePerRequestHd(f float64) { + if m.addsora_video_price_per_request_hd != nil { + *m.addsora_video_price_per_request_hd += f + } else { + m.addsora_video_price_per_request_hd = &f + } +} + +// AddedSoraVideoPricePerRequestHd returns the value that was added to the "sora_video_price_per_request_hd" field in this mutation. +func (m *GroupMutation) AddedSoraVideoPricePerRequestHd() (r float64, exists bool) { + v := m.addsora_video_price_per_request_hd + if v == nil { + return + } + return *v, true +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (m *GroupMutation) ClearSoraVideoPricePerRequestHd() { + m.sora_video_price_per_request_hd = nil + m.addsora_video_price_per_request_hd = nil + m.clearedFields[group.FieldSoraVideoPricePerRequestHd] = struct{}{} +} + +// SoraVideoPricePerRequestHdCleared returns if the "sora_video_price_per_request_hd" field was cleared in this mutation. +func (m *GroupMutation) SoraVideoPricePerRequestHdCleared() bool { + _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequestHd] + return ok +} + +// ResetSoraVideoPricePerRequestHd resets all changes to the "sora_video_price_per_request_hd" field. +func (m *GroupMutation) ResetSoraVideoPricePerRequestHd() { + m.sora_video_price_per_request_hd = nil + m.addsora_video_price_per_request_hd = nil + delete(m.clearedFields, group.FieldSoraVideoPricePerRequestHd) +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (m *GroupMutation) SetClaudeCodeOnly(b bool) { m.claude_code_only = &b @@ -5422,7 +5710,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 21) + fields := make([]string, 0, 25) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -5474,6 +5762,18 @@ func (m *GroupMutation) Fields() []string { if m.image_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } + if m.sora_image_price_360 != nil { + fields = append(fields, group.FieldSoraImagePrice360) + } + if m.sora_image_price_540 != nil { + fields = append(fields, group.FieldSoraImagePrice540) + } + if m.sora_video_price_per_request != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequest) + } + if m.sora_video_price_per_request_hd != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequestHd) + } if m.claude_code_only != nil { fields = append(fields, group.FieldClaudeCodeOnly) } @@ -5528,6 +5828,14 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.ImagePrice2k() case group.FieldImagePrice4k: return m.ImagePrice4k() + case group.FieldSoraImagePrice360: + return m.SoraImagePrice360() + case group.FieldSoraImagePrice540: + return m.SoraImagePrice540() + case group.FieldSoraVideoPricePerRequest: + return m.SoraVideoPricePerRequest() + case group.FieldSoraVideoPricePerRequestHd: + return m.SoraVideoPricePerRequestHd() case group.FieldClaudeCodeOnly: return m.ClaudeCodeOnly() case group.FieldFallbackGroupID: @@ -5579,6 +5887,14 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldImagePrice2k(ctx) case group.FieldImagePrice4k: return m.OldImagePrice4k(ctx) + case group.FieldSoraImagePrice360: + return m.OldSoraImagePrice360(ctx) + case group.FieldSoraImagePrice540: + return m.OldSoraImagePrice540(ctx) + case group.FieldSoraVideoPricePerRequest: + return m.OldSoraVideoPricePerRequest(ctx) + case group.FieldSoraVideoPricePerRequestHd: + return m.OldSoraVideoPricePerRequestHd(ctx) case group.FieldClaudeCodeOnly: return m.OldClaudeCodeOnly(ctx) case 
group.FieldFallbackGroupID: @@ -5715,6 +6031,34 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetImagePrice4k(v) return nil + case group.FieldSoraImagePrice360: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraImagePrice360(v) + return nil + case group.FieldSoraImagePrice540: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraImagePrice540(v) + return nil + case group.FieldSoraVideoPricePerRequest: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraVideoPricePerRequest(v) + return nil + case group.FieldSoraVideoPricePerRequestHd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraVideoPricePerRequestHd(v) + return nil case group.FieldClaudeCodeOnly: v, ok := value.(bool) if !ok { @@ -5775,6 +6119,18 @@ func (m *GroupMutation) AddedFields() []string { if m.addimage_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } + if m.addsora_image_price_360 != nil { + fields = append(fields, group.FieldSoraImagePrice360) + } + if m.addsora_image_price_540 != nil { + fields = append(fields, group.FieldSoraImagePrice540) + } + if m.addsora_video_price_per_request != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequest) + } + if m.addsora_video_price_per_request_hd != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequestHd) + } if m.addfallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } @@ -5802,6 +6158,14 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedImagePrice2k() case group.FieldImagePrice4k: return m.AddedImagePrice4k() + case group.FieldSoraImagePrice360: + return m.AddedSoraImagePrice360() + case group.FieldSoraImagePrice540: + return 
m.AddedSoraImagePrice540() + case group.FieldSoraVideoPricePerRequest: + return m.AddedSoraVideoPricePerRequest() + case group.FieldSoraVideoPricePerRequestHd: + return m.AddedSoraVideoPricePerRequestHd() case group.FieldFallbackGroupID: return m.AddedFallbackGroupID() } @@ -5869,6 +6233,34 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddImagePrice4k(v) return nil + case group.FieldSoraImagePrice360: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraImagePrice360(v) + return nil + case group.FieldSoraImagePrice540: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraImagePrice540(v) + return nil + case group.FieldSoraVideoPricePerRequest: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraVideoPricePerRequest(v) + return nil + case group.FieldSoraVideoPricePerRequestHd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraVideoPricePerRequestHd(v) + return nil case group.FieldFallbackGroupID: v, ok := value.(int64) if !ok { @@ -5908,6 +6300,18 @@ func (m *GroupMutation) ClearedFields() []string { if m.FieldCleared(group.FieldImagePrice4k) { fields = append(fields, group.FieldImagePrice4k) } + if m.FieldCleared(group.FieldSoraImagePrice360) { + fields = append(fields, group.FieldSoraImagePrice360) + } + if m.FieldCleared(group.FieldSoraImagePrice540) { + fields = append(fields, group.FieldSoraImagePrice540) + } + if m.FieldCleared(group.FieldSoraVideoPricePerRequest) { + fields = append(fields, group.FieldSoraVideoPricePerRequest) + } + if m.FieldCleared(group.FieldSoraVideoPricePerRequestHd) { + fields = append(fields, group.FieldSoraVideoPricePerRequestHd) + } if m.FieldCleared(group.FieldFallbackGroupID) { fields = append(fields, group.FieldFallbackGroupID) 
} @@ -5952,6 +6356,18 @@ func (m *GroupMutation) ClearField(name string) error { case group.FieldImagePrice4k: m.ClearImagePrice4k() return nil + case group.FieldSoraImagePrice360: + m.ClearSoraImagePrice360() + return nil + case group.FieldSoraImagePrice540: + m.ClearSoraImagePrice540() + return nil + case group.FieldSoraVideoPricePerRequest: + m.ClearSoraVideoPricePerRequest() + return nil + case group.FieldSoraVideoPricePerRequestHd: + m.ClearSoraVideoPricePerRequestHd() + return nil case group.FieldFallbackGroupID: m.ClearFallbackGroupID() return nil @@ -6017,6 +6433,18 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldImagePrice4k: m.ResetImagePrice4k() return nil + case group.FieldSoraImagePrice360: + m.ResetSoraImagePrice360() + return nil + case group.FieldSoraImagePrice540: + m.ResetSoraImagePrice540() + return nil + case group.FieldSoraVideoPricePerRequest: + m.ResetSoraVideoPricePerRequest() + return nil + case group.FieldSoraVideoPricePerRequestHd: + m.ResetSoraVideoPricePerRequestHd() + return nil case group.FieldClaudeCodeOnly: m.ResetClaudeCodeOnly() return nil @@ -11504,6 +11932,7 @@ type UsageLogMutation struct { image_count *int addimage_count *int image_size *string + media_type *string created_at *time.Time clearedFields map[string]struct{} user *int64 @@ -13130,6 +13559,55 @@ func (m *UsageLogMutation) ResetImageSize() { delete(m.clearedFields, usagelog.FieldImageSize) } +// SetMediaType sets the "media_type" field. +func (m *UsageLogMutation) SetMediaType(s string) { + m.media_type = &s +} + +// MediaType returns the value of the "media_type" field in the mutation. +func (m *UsageLogMutation) MediaType() (r string, exists bool) { + v := m.media_type + if v == nil { + return + } + return *v, true +} + +// OldMediaType returns the old "media_type" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. 
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldMediaType(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMediaType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMediaType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMediaType: %w", err) + } + return oldValue.MediaType, nil +} + +// ClearMediaType clears the value of the "media_type" field. +func (m *UsageLogMutation) ClearMediaType() { + m.media_type = nil + m.clearedFields[usagelog.FieldMediaType] = struct{}{} +} + +// MediaTypeCleared returns if the "media_type" field was cleared in this mutation. +func (m *UsageLogMutation) MediaTypeCleared() bool { + _, ok := m.clearedFields[usagelog.FieldMediaType] + return ok +} + +// ResetMediaType resets all changes to the "media_type" field. +func (m *UsageLogMutation) ResetMediaType() { + m.media_type = nil + delete(m.clearedFields, usagelog.FieldMediaType) +} + // SetCreatedAt sets the "created_at" field. func (m *UsageLogMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -13335,7 +13813,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 30) + fields := make([]string, 0, 31) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -13423,6 +13901,9 @@ func (m *UsageLogMutation) Fields() []string { if m.image_size != nil { fields = append(fields, usagelog.FieldImageSize) } + if m.media_type != nil { + fields = append(fields, usagelog.FieldMediaType) + } if m.created_at != nil { fields = append(fields, usagelog.FieldCreatedAt) } @@ -13492,6 +13973,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.ImageCount() case usagelog.FieldImageSize: return m.ImageSize() + case usagelog.FieldMediaType: + return m.MediaType() case usagelog.FieldCreatedAt: return m.CreatedAt() } @@ -13561,6 +14044,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldImageCount(ctx) case usagelog.FieldImageSize: return m.OldImageSize(ctx) + case usagelog.FieldMediaType: + return m.OldMediaType(ctx) case usagelog.FieldCreatedAt: return m.OldCreatedAt(ctx) } @@ -13775,6 +14260,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetImageSize(v) return nil + case usagelog.FieldMediaType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMediaType(v) + return nil case usagelog.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -14055,6 +14547,9 @@ func (m *UsageLogMutation) ClearedFields() []string { if m.FieldCleared(usagelog.FieldImageSize) { fields = append(fields, usagelog.FieldImageSize) } + if m.FieldCleared(usagelog.FieldMediaType) { + fields = append(fields, usagelog.FieldMediaType) + } return fields } @@ -14093,6 +14588,9 @@ func (m *UsageLogMutation) ClearField(name string) error { case usagelog.FieldImageSize: m.ClearImageSize() return nil + case usagelog.FieldMediaType: + m.ClearMediaType() + return nil } return fmt.Errorf("unknown UsageLog nullable field %s", 
name) } @@ -14188,6 +14686,9 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldImageSize: m.ResetImageSize() return nil + case usagelog.FieldMediaType: + m.ResetMediaType() + return nil case usagelog.FieldCreatedAt: m.ResetCreatedAt() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 1e3f4cbe..15b02ad1 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -278,11 +278,11 @@ func init() { // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) // groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field. - groupDescClaudeCodeOnly := groupFields[14].Descriptor() + groupDescClaudeCodeOnly := groupFields[18].Descriptor() // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. - groupDescModelRoutingEnabled := groupFields[17].Descriptor() + groupDescModelRoutingEnabled := groupFields[21].Descriptor() // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) promocodeFields := schema.PromoCode{}.Fields() @@ -647,8 +647,12 @@ func init() { usagelogDescImageSize := usagelogFields[28].Descriptor() // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) + // usagelogDescMediaType is the schema descriptor for media_type field. 
+ usagelogDescMediaType := usagelogFields[29].Descriptor() + // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. + usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) // usagelogDescCreatedAt is the schema descriptor for created_at field. - usagelogDescCreatedAt := usagelogFields[29].Descriptor() + usagelogDescCreatedAt := usagelogFields[30].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 5d0a1e9a..7fa04b8a 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -87,6 +87,24 @@ func (Group) Fields() []ent.Field { Nillable(). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + // Sora 按次计费配置(阶段 1) + field.Float("sora_image_price_360"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + field.Float("sora_image_price_540"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + field.Float("sora_video_price_per_request"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + field.Float("sora_video_price_per_request_hd"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + // Claude Code 客户端限制 (added by migration 029) field.Bool("claude_code_only"). Default(false). diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index fc7c7165..602f23f6 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -118,6 +118,11 @@ func (UsageLog) Fields() []ent.Field { MaxLen(10). Optional(). Nillable(), + // 媒体类型字段(sora 使用) + field.String("media_type"). + MaxLen(16). 
+ Optional(). + Nillable(), // 时间戳(只有 created_at,日志不可修改) field.Time("created_at"). diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index 81c466b4..63a14197 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -80,6 +80,8 @@ type UsageLog struct { ImageCount int `json:"image_count,omitempty"` // ImageSize holds the value of the "image_size" field. ImageSize *string `json:"image_size,omitempty"` + // MediaType holds the value of the "media_type" field. + MediaType *string `json:"media_type,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // Edges holds the relations/edges for other nodes in the graph. @@ -171,7 +173,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: values[i] = new(sql.NullInt64) - case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize: + case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: values[i] = new(sql.NullString) case usagelog.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -378,6 +380,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { _m.ImageSize = new(string) *_m.ImageSize = value.String } + case usagelog.FieldMediaType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field media_type", values[i]) + } else if 
value.Valid { + _m.MediaType = new(string) + *_m.MediaType = value.String + } case usagelog.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -548,6 +557,11 @@ func (_m *UsageLog) String() string { builder.WriteString(*v) } builder.WriteString(", ") + if v := _m.MediaType; v != nil { + builder.WriteString("media_type=") + builder.WriteString(*v) + } + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteByte(')') diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index 980f1e58..3ea5d054 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -72,6 +72,8 @@ const ( FieldImageCount = "image_count" // FieldImageSize holds the string denoting the image_size field in the database. FieldImageSize = "image_size" + // FieldMediaType holds the string denoting the media_type field in the database. + FieldMediaType = "media_type" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // EdgeUser holds the string denoting the user edge name in mutations. @@ -155,6 +157,7 @@ var Columns = []string{ FieldIPAddress, FieldImageCount, FieldImageSize, + FieldMediaType, FieldCreatedAt, } @@ -211,6 +214,8 @@ var ( DefaultImageCount int // ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. ImageSizeValidator func(string) error + // MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. + MediaTypeValidator func(string) error // DefaultCreatedAt holds the default value on creation for the "created_at" field. 
DefaultCreatedAt func() time.Time ) @@ -368,6 +373,11 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImageSize, opts...).ToFunc() } +// ByMediaType orders the results by the media_type field. +func ByMediaType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMediaType, opts...).ToFunc() +} + // ByCreatedAt orders the results by the created_at field. func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index 28e2ab4c..0a33dba2 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -200,6 +200,11 @@ func ImageSize(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v)) } +// MediaType applies equality check predicate on the "media_type" field. It's identical to MediaTypeEQ. +func MediaType(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) @@ -1440,6 +1445,81 @@ func ImageSizeContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v)) } +// MediaTypeEQ applies the EQ predicate on the "media_type" field. +func MediaTypeEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v)) +} + +// MediaTypeNEQ applies the NEQ predicate on the "media_type" field. +func MediaTypeNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldMediaType, v)) +} + +// MediaTypeIn applies the In predicate on the "media_type" field. 
+func MediaTypeIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldMediaType, vs...)) +} + +// MediaTypeNotIn applies the NotIn predicate on the "media_type" field. +func MediaTypeNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldMediaType, vs...)) +} + +// MediaTypeGT applies the GT predicate on the "media_type" field. +func MediaTypeGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldMediaType, v)) +} + +// MediaTypeGTE applies the GTE predicate on the "media_type" field. +func MediaTypeGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldMediaType, v)) +} + +// MediaTypeLT applies the LT predicate on the "media_type" field. +func MediaTypeLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldMediaType, v)) +} + +// MediaTypeLTE applies the LTE predicate on the "media_type" field. +func MediaTypeLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldMediaType, v)) +} + +// MediaTypeContains applies the Contains predicate on the "media_type" field. +func MediaTypeContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldMediaType, v)) +} + +// MediaTypeHasPrefix applies the HasPrefix predicate on the "media_type" field. +func MediaTypeHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldMediaType, v)) +} + +// MediaTypeHasSuffix applies the HasSuffix predicate on the "media_type" field. +func MediaTypeHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldMediaType, v)) +} + +// MediaTypeIsNil applies the IsNil predicate on the "media_type" field. +func MediaTypeIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldMediaType)) +} + +// MediaTypeNotNil applies the NotNil predicate on the "media_type" field. 
+func MediaTypeNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldMediaType)) +} + +// MediaTypeEqualFold applies the EqualFold predicate on the "media_type" field. +func MediaTypeEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldMediaType, v)) +} + +// MediaTypeContainsFold applies the ContainsFold predicate on the "media_type" field. +func MediaTypeContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index a17d6507..668a0ede 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -393,6 +393,20 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate { return _c } +// SetMediaType sets the "media_type" field. +func (_c *UsageLogCreate) SetMediaType(v string) *UsageLogCreate { + _c.mutation.SetMediaType(v) + return _c +} + +// SetNillableMediaType sets the "media_type" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate { + if v != nil { + _c.SetMediaType(*v) + } + return _c +} + // SetCreatedAt sets the "created_at" field. 
func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate { _c.mutation.SetCreatedAt(v) @@ -627,6 +641,11 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _c.mutation.MediaType(); ok { + if err := usagelog.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} + } + } if _, ok := _c.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)} } @@ -762,6 +781,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldImageSize, field.TypeString, value) _node.ImageSize = &value } + if value, ok := _c.mutation.MediaType(); ok { + _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) + _node.MediaType = &value + } if value, ok := _c.mutation.CreatedAt(); ok { _spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value @@ -1407,6 +1430,24 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert { return u } +// SetMediaType sets the "media_type" field. +func (u *UsageLogUpsert) SetMediaType(v string) *UsageLogUpsert { + u.Set(usagelog.FieldMediaType, v) + return u +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateMediaType() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldMediaType) + return u +} + +// ClearMediaType clears the value of the "media_type" field. +func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert { + u.SetNull(usagelog.FieldMediaType) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. 
// Using this option is equivalent to using: // @@ -2040,6 +2081,27 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne { }) } +// SetMediaType sets the "media_type" field. +func (u *UsageLogUpsertOne) SetMediaType(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetMediaType(v) + }) +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateMediaType() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateMediaType() + }) +} + +// ClearMediaType clears the value of the "media_type" field. +func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearMediaType() + }) +} + // Exec executes the query. func (u *UsageLogUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2839,6 +2901,27 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk { }) } +// SetMediaType sets the "media_type" field. +func (u *UsageLogUpsertBulk) SetMediaType(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetMediaType(v) + }) +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateMediaType() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateMediaType() + }) +} + +// ClearMediaType clears the value of the "media_type" field. +func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearMediaType() + }) +} + // Exec executes the query. 
func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index 571a7b3c..22f2613f 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -612,6 +612,26 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate { return _u } +// SetMediaType sets the "media_type" field. +func (_u *UsageLogUpdate) SetMediaType(v string) *UsageLogUpdate { + _u.mutation.SetMediaType(v) + return _u +} + +// SetNillableMediaType sets the "media_type" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableMediaType(v *string) *UsageLogUpdate { + if v != nil { + _u.SetMediaType(*v) + } + return _u +} + +// ClearMediaType clears the value of the "media_type" field. +func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate { + _u.mutation.ClearMediaType() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate { return _u.SetUserID(v.ID) @@ -726,6 +746,11 @@ func (_u *UsageLogUpdate) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _u.mutation.MediaType(); ok { + if err := usagelog.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} + } + } if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) } @@ -894,6 +919,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } + if value, ok := _u.mutation.MediaType(); ok { + _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) + } + if _u.mutation.MediaTypeCleared() { + 
_spec.ClearField(usagelog.FieldMediaType, field.TypeString) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -1639,6 +1670,26 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne { return _u } +// SetMediaType sets the "media_type" field. +func (_u *UsageLogUpdateOne) SetMediaType(v string) *UsageLogUpdateOne { + _u.mutation.SetMediaType(v) + return _u +} + +// SetNillableMediaType sets the "media_type" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableMediaType(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetMediaType(*v) + } + return _u +} + +// ClearMediaType clears the value of the "media_type" field. +func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne { + _u.mutation.ClearMediaType() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne { return _u.SetUserID(v.ID) @@ -1766,6 +1817,11 @@ func (_u *UsageLogUpdateOne) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _u.mutation.MediaType(); ok { + if err := usagelog.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} + } + } if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) } @@ -1951,6 +2007,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } + if value, ok := _u.mutation.MediaType(); ok { + _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) + } + if _u.mutation.MediaTypeCleared() { + _spec.ClearField(usagelog.FieldMediaType, field.TypeString) + } if _u.mutation.UserCleared() { edge 
:= &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 00a78480..5dd2b415 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -58,6 +58,7 @@ type Config struct { UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` Concurrency ConcurrencyConfig `mapstructure:"concurrency"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + Sora2API Sora2APIConfig `mapstructure:"sora2api"` RunMode string `mapstructure:"run_mode" yaml:"run_mode"` Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" Gemini GeminiConfig `mapstructure:"gemini"` @@ -204,6 +205,24 @@ type ConcurrencyConfig struct { PingInterval int `mapstructure:"ping_interval"` } +// Sora2APIConfig Sora2API 服务配置 +type Sora2APIConfig struct { + // BaseURL Sora2API 服务地址(例如 http://localhost:8000) + BaseURL string `mapstructure:"base_url"` + // APIKey Sora2API OpenAI 兼容接口的 API Key + APIKey string `mapstructure:"api_key"` + // AdminUsername 管理员用户名(用于 token 同步) + AdminUsername string `mapstructure:"admin_username"` + // AdminPassword 管理员密码(用于 token 同步) + AdminPassword string `mapstructure:"admin_password"` + // AdminTokenTTLSeconds 管理员 Token 缓存时长(秒) + AdminTokenTTLSeconds int `mapstructure:"admin_token_ttl_seconds"` + // AdminTimeoutSeconds 管理接口请求超时(秒) + AdminTimeoutSeconds int `mapstructure:"admin_timeout_seconds"` + // TokenImportMode token 导入模式:at/offline + TokenImportMode string `mapstructure:"token_import_mode"` +} + // GatewayConfig API网关相关配置 type GatewayConfig struct { // 等待上游响应头的超时时间(秒),0表示无超时 @@ -258,6 +277,24 @@ type GatewayConfig struct { // 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义) FailoverOn400 bool `mapstructure:"failover_on_400"` + // Sora 专用配置 + // SoraMaxBodySize: Sora 请求体最大字节数(0 表示使用 gateway.max_body_size) + SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"` + // SoraStreamTimeoutSeconds: Sora 流式请求总超时(秒,0 表示不限制) + SoraStreamTimeoutSeconds int 
`mapstructure:"sora_stream_timeout_seconds"` + // SoraRequestTimeoutSeconds: Sora 非流式请求超时(秒,0 表示不限制) + SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"` + // SoraStreamMode: stream 强制策略(force/error) + SoraStreamMode string `mapstructure:"sora_stream_mode"` + // SoraModelFilters: 模型列表过滤配置 + SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"` + // SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key + SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"` + // SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名) + SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"` + // SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用) + SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"` + // 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限) MaxAccountSwitches int `mapstructure:"max_account_switches"` // Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格) @@ -273,6 +310,12 @@ type GatewayConfig struct { TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"` } +// SoraModelFiltersConfig Sora 模型过滤配置 +type SoraModelFiltersConfig struct { + // HidePromptEnhance 是否隐藏 prompt-enhance 模型 + HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"` +} + // TLSFingerprintConfig TLS指纹伪装配置 // 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端 type TLSFingerprintConfig struct { @@ -823,6 +866,13 @@ func setDefaults() { viper.SetDefault("gateway.max_account_switches_gemini", 3) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) + viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024)) + viper.SetDefault("gateway.sora_stream_timeout_seconds", 900) + viper.SetDefault("gateway.sora_request_timeout_seconds", 180) + viper.SetDefault("gateway.sora_stream_mode", "force") + viper.SetDefault("gateway.sora_model_filters.hide_prompt_enhance", true) + viper.SetDefault("gateway.sora_media_require_api_key", true) + 
viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) // HTTP 上游连接池配置(针对 5000+ 并发用户优化) viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认) @@ -869,6 +919,15 @@ func setDefaults() { viper.SetDefault("gemini.oauth.client_secret", "") viper.SetDefault("gemini.oauth.scopes", "") viper.SetDefault("gemini.quota.policy", "") + + // Sora2API + viper.SetDefault("sora2api.base_url", "") + viper.SetDefault("sora2api.api_key", "") + viper.SetDefault("sora2api.admin_username", "") + viper.SetDefault("sora2api.admin_password", "") + viper.SetDefault("sora2api.admin_token_ttl_seconds", 900) + viper.SetDefault("sora2api.admin_timeout_seconds", 10) + viper.SetDefault("sora2api.token_import_mode", "at") } func (c *Config) Validate() error { @@ -1085,6 +1144,25 @@ func (c *Config) Validate() error { if c.Gateway.MaxBodySize <= 0 { return fmt.Errorf("gateway.max_body_size must be positive") } + if c.Gateway.SoraMaxBodySize < 0 { + return fmt.Errorf("gateway.sora_max_body_size must be non-negative") + } + if c.Gateway.SoraStreamTimeoutSeconds < 0 { + return fmt.Errorf("gateway.sora_stream_timeout_seconds must be non-negative") + } + if c.Gateway.SoraRequestTimeoutSeconds < 0 { + return fmt.Errorf("gateway.sora_request_timeout_seconds must be non-negative") + } + if c.Gateway.SoraMediaSignedURLTTLSeconds < 0 { + return fmt.Errorf("gateway.sora_media_signed_url_ttl_seconds must be non-negative") + } + if mode := strings.TrimSpace(strings.ToLower(c.Gateway.SoraStreamMode)); mode != "" { + switch mode { + case "force", "error": + default: + return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error") + } + } if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" { switch c.Gateway.ConnectionPoolIsolation { case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy: @@ -1181,6 +1259,25 @@ func (c *Config) 
Validate() error { c.Gateway.Scheduling.OutboxLagRebuildSeconds < c.Gateway.Scheduling.OutboxLagWarnSeconds { return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be >= outbox_lag_warn_seconds") } + if strings.TrimSpace(c.Sora2API.BaseURL) != "" { + if err := ValidateAbsoluteHTTPURL(c.Sora2API.BaseURL); err != nil { + return fmt.Errorf("sora2api.base_url invalid: %w", err) + } + warnIfInsecureURL("sora2api.base_url", c.Sora2API.BaseURL) + } + if mode := strings.TrimSpace(strings.ToLower(c.Sora2API.TokenImportMode)); mode != "" { + switch mode { + case "at", "offline": + default: + return fmt.Errorf("sora2api.token_import_mode must be one of: at/offline") + } + } + if c.Sora2API.AdminTokenTTLSeconds < 0 { + return fmt.Errorf("sora2api.admin_token_ttl_seconds must be non-negative") + } + if c.Sora2API.AdminTimeoutSeconds < 0 { + return fmt.Errorf("sora2api.admin_timeout_seconds must be non-negative") + } if c.Ops.MetricsCollectorCache.TTL < 0 { return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative") } diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 926624d2..1af570d9 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler { type CreateGroupRequest struct { Name string `json:"name" binding:"required"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` RateMultiplier float64 `json:"rate_multiplier"` IsExclusive bool `json:"is_exclusive"` SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` @@ -38,6 +38,10 @@ type CreateGroupRequest struct { ImagePrice1K *float64 `json:"image_price_1k"` 
ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` // 模型路由配置(仅 anthropic 平台使用) @@ -49,7 +53,7 @@ type CreateGroupRequest struct { type UpdateGroupRequest struct { Name string `json:"name"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` RateMultiplier *float64 `json:"rate_multiplier"` IsExclusive *bool `json:"is_exclusive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"` @@ -61,6 +65,10 @@ type UpdateGroupRequest struct { ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` ClaudeCodeOnly *bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` // 模型路由配置(仅 anthropic 平台使用) @@ -167,6 +175,10 @@ func (h *GroupHandler) Create(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, + SoraImagePrice360: req.SoraImagePrice360, + SoraImagePrice540: req.SoraImagePrice540, + SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, ClaudeCodeOnly: req.ClaudeCodeOnly, 
FallbackGroupID: req.FallbackGroupID, ModelRouting: req.ModelRouting, @@ -209,6 +221,10 @@ func (h *GroupHandler) Update(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, + SoraImagePrice360: req.SoraImagePrice360, + SoraImagePrice540: req.SoraImagePrice540, + SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, ClaudeCodeOnly: req.ClaudeCodeOnly, FallbackGroupID: req.FallbackGroupID, ModelRouting: req.ModelRouting, diff --git a/backend/internal/handler/admin/model_handler.go b/backend/internal/handler/admin/model_handler.go new file mode 100644 index 00000000..035b09bd --- /dev/null +++ b/backend/internal/handler/admin/model_handler.go @@ -0,0 +1,55 @@ +package admin + +import ( + "net/http" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// ModelHandler handles admin model listing requests. +type ModelHandler struct { + sora2apiService *service.Sora2APIService +} + +// NewModelHandler creates a new ModelHandler. 
+func NewModelHandler(sora2apiService *service.Sora2APIService) *ModelHandler { + return &ModelHandler{ + sora2apiService: sora2apiService, + } +} + +// List handles listing models for a specific platform +// GET /api/v1/admin/models?platform=sora +func (h *ModelHandler) List(c *gin.Context) { + platform := strings.TrimSpace(strings.ToLower(c.Query("platform"))) + if platform == "" { + response.BadRequest(c, "platform is required") + return + } + + switch platform { + case service.PlatformSora: + if h.sora2apiService == nil || !h.sora2apiService.Enabled() { + response.Error(c, http.StatusServiceUnavailable, "sora2api not configured") + return + } + models, err := h.sora2apiService.ListModels(c.Request.Context()) + if err != nil { + response.Error(c, http.StatusServiceUnavailable, "failed to fetch sora models") + return + } + ids := make([]string, 0, len(models)) + for _, m := range models { + if strings.TrimSpace(m.ID) != "" { + ids = append(ids, m.ID) + } + } + response.Success(c, ids) + default: + response.BadRequest(c, "unsupported platform") + } +} diff --git a/backend/internal/handler/admin/model_handler_test.go b/backend/internal/handler/admin/model_handler_test.go new file mode 100644 index 00000000..e61dc064 --- /dev/null +++ b/backend/internal/handler/admin/model_handler_test.go @@ -0,0 +1,87 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +func TestModelHandlerListSoraSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"object":"list","data":[{"id":"m1"},{"id":"m2"}]}`)) + })) + t.Cleanup(upstream.Close) + + cfg := &config.Config{} + 
cfg.Sora2API.BaseURL = upstream.URL + cfg.Sora2API.APIKey = "test-key" + soraService := service.NewSora2APIService(cfg) + + h := NewModelHandler(soraService) + router := gin.New() + router.GET("/admin/models", h.List) + + req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=sora", nil) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String()) + } + var resp response.Response + if err := json.Unmarshal(recorder.Body.Bytes(), &resp); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + if resp.Code != 0 { + t.Fatalf("响应 code=%d", resp.Code) + } + data, ok := resp.Data.([]any) + if !ok { + t.Fatalf("响应 data 类型错误") + } + if len(data) != 2 { + t.Fatalf("模型数量不符: %d", len(data)) + } +} + +func TestModelHandlerListSoraNotConfigured(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := NewModelHandler(&service.Sora2APIService{}) + router := gin.New() + router.GET("/admin/models", h.List) + + req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=sora", nil) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusServiceUnavailable { + t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String()) + } +} + +func TestModelHandlerListInvalidPlatform(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := NewModelHandler(&service.Sora2APIService{}) + router := gin.New() + router.GET("/admin/models", h.List) + + req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=unknown", nil) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusBadRequest { + t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String()) + } +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index d58a8a29..b44c3225 100644 --- a/backend/internal/handler/dto/mappers.go +++ 
b/backend/internal/handler/dto/mappers.go @@ -136,6 +136,10 @@ func groupFromServiceBase(g *service.Group) Group { ImagePrice1K: g.ImagePrice1K, ImagePrice2K: g.ImagePrice2K, ImagePrice4K: g.ImagePrice4K, + SoraImagePrice360: g.SoraImagePrice360, + SoraImagePrice540: g.SoraImagePrice540, + SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, CreatedAt: g.CreatedAt, @@ -379,6 +383,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { FirstTokenMs: l.FirstTokenMs, ImageCount: l.ImageCount, ImageSize: l.ImageSize, + MediaType: l.MediaType, UserAgent: l.UserAgent, CreatedAt: l.CreatedAt, User: UserFromServiceShallow(l.User), diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 938d707c..3ae899ee 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -61,6 +61,12 @@ type Group struct { ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + // Sora 按次计费配置 + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` + // Claude Code 客户端限制 ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` @@ -246,6 +252,7 @@ type UsageLog struct { // 图片生成字段 ImageCount int `json:"image_count"` ImageSize *string `json:"image_size"` + MediaType *string `json:"media_type"` // User-Agent UserAgent *string `json:"user_agent"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 70ea51bf..983cc6b3 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -29,6 +29,7 @@ type 
GatewayHandler struct { geminiCompatService *service.GeminiMessagesCompatService antigravityGatewayService *service.AntigravityGatewayService userService *service.UserService + sora2apiService *service.Sora2APIService billingCacheService *service.BillingCacheService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int @@ -41,6 +42,7 @@ func NewGatewayHandler( geminiCompatService *service.GeminiMessagesCompatService, antigravityGatewayService *service.AntigravityGatewayService, userService *service.UserService, + sora2apiService *service.Sora2APIService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, cfg *config.Config, @@ -62,6 +64,7 @@ func NewGatewayHandler( geminiCompatService: geminiCompatService, antigravityGatewayService: antigravityGatewayService, userService: userService, + sora2apiService: sora2apiService, billingCacheService: billingCacheService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), maxAccountSwitches: maxAccountSwitches, @@ -478,6 +481,26 @@ func (h *GatewayHandler) Models(c *gin.Context) { groupID = &apiKey.Group.ID platform = apiKey.Group.Platform } + if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok && strings.TrimSpace(forcedPlatform) != "" { + platform = forcedPlatform + } + + if platform == service.PlatformSora { + if h.sora2apiService == nil || !h.sora2apiService.Enabled() { + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "sora2api not configured") + return + } + models, err := h.sora2apiService.ListModels(c.Request.Context()) + if err != nil { + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Failed to fetch Sora models") + return + } + c.JSON(http.StatusOK, gin.H{ + "object": "list", + "data": models, + }) + return + } // Get available models from account configurations (without platform filter) availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), 
groupID, "") diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 5b1b317d..7905148c 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -23,6 +23,7 @@ type AdminHandlers struct { Subscription *admin.SubscriptionHandler Usage *admin.UsageHandler UserAttribute *admin.UserAttributeHandler + Model *admin.ModelHandler } // Handlers contains all HTTP handlers @@ -36,6 +37,7 @@ type Handlers struct { Admin *AdminHandlers Gateway *GatewayHandler OpenAIGateway *OpenAIGatewayHandler + SoraGateway *SoraGatewayHandler Setting *SettingHandler } diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go new file mode 100644 index 00000000..94f712df --- /dev/null +++ b/backend/internal/handler/sora_gateway_handler.go @@ -0,0 +1,474 @@ +package handler + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "path" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// SoraGatewayHandler handles Sora chat completions requests +type SoraGatewayHandler struct { + gatewayService *service.GatewayService + soraGatewayService *service.SoraGatewayService + billingCacheService *service.BillingCacheService + concurrencyHelper *ConcurrencyHelper + maxAccountSwitches int + streamMode string + sora2apiBaseURL string + soraMediaSigningKey string +} + +// NewSoraGatewayHandler creates a new SoraGatewayHandler +func NewSoraGatewayHandler( + gatewayService *service.GatewayService, + soraGatewayService *service.SoraGatewayService, + concurrencyService *service.ConcurrencyService, + billingCacheService *service.BillingCacheService, + cfg *config.Config, +) 
*SoraGatewayHandler { + pingInterval := time.Duration(0) + maxAccountSwitches := 3 + streamMode := "force" + signKey := "" + if cfg != nil { + pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second + if cfg.Gateway.MaxAccountSwitches > 0 { + maxAccountSwitches = cfg.Gateway.MaxAccountSwitches + } + if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" { + streamMode = mode + } + signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey) + } + baseURL := "" + if cfg != nil { + baseURL = strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/") + } + return &SoraGatewayHandler{ + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + billingCacheService: billingCacheService, + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + maxAccountSwitches: maxAccountSwitches, + streamMode: strings.ToLower(streamMode), + sora2apiBaseURL: baseURL, + soraMediaSigningKey: signKey, + } +} + +// ChatCompletions handles Sora /v1/chat/completions endpoint +func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + + body, err := io.ReadAll(c.Request.Body) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + if len(body) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + setOpsRequestContext(c, "", false, body) + 
+ var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + reqModel, _ := reqBody["model"].(string) + if reqModel == "" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + reqMessages, _ := reqBody["messages"].([]any) + if len(reqMessages) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required") + return + } + + clientStream, _ := reqBody["stream"].(bool) + if !clientStream { + if h.streamMode == "error" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true") + return + } + reqBody["stream"] = true + updated, err := json.Marshal(reqBody) + if err != nil { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") + return + } + body = updated + } + + setOpsRequestContext(c, reqModel, clientStream, body) + + platform := "" + if forced, ok := middleware2.GetForcePlatformFromContext(c); ok { + platform = forced + } else if apiKey.Group != nil { + platform = apiKey.Group.Platform + } + if platform != service.PlatformSora { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform") + return + } + + streamStarted := false + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + maxWait := service.CalculateMaxWait(subject.Concurrency) + canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) + waitCounted := false + if err != nil { + log.Printf("Increment wait count failed: %v", err) + } else if !canWait { + h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return + } + if err == nil && canWait { + waitCounted = true + } + defer func() { + if waitCounted { + 
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + } + }() + + userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted) + if err != nil { + log.Printf("User concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "user", streamStarted) + return + } + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + waitCounted = false + } + userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + log.Printf("Billing eligibility check failed after wait: %v", err) + status, code, message := billingErrorDetails(err) + h.handleStreamingAwareError(c, status, code, message, streamStarted) + return + } + + sessionHash := generateOpenAISessionHash(c, reqBody) + + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + lastFailoverStatus := 0 + + for { + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "") + if err != nil { + log.Printf("[Sora Handler] SelectAccount failed: %v", err) + if len(failedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + return + } + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return + } + account := selection.Account + setOpsSelectedAccount(c, account.ID) + + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + 
accountWaitCounted := false + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + log.Printf("Increment account wait count failed: %v", err) + } else if !canWait { + log.Printf("Account wait queue full: account=%d", account.ID) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) + return + } + if err == nil && canWait { + accountWaitCounted = true + } + defer func() { + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + }() + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + clientStream, + &streamStarted, + ) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false + } + } + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + + result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream) + if accountReleaseFunc != nil { + accountReleaseFunc() + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + failedAccountIDs[account.ID] = struct{}{} + if switchCount >= maxAccountSwitches { + lastFailoverStatus = failoverErr.StatusCode + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return + } + lastFailoverStatus = failoverErr.StatusCode + switchCount++ + log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + continue + } + log.Printf("Account %d: Forward request 
failed: %v", account.ID, err) + return + } + + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + + go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + UserAgent: ua, + IPAddress: ip, + }); err != nil { + log.Printf("Record usage failed: %v", err) + } + }(result, account, userAgent, clientIP) + return + } +} + +func generateOpenAISessionHash(c *gin.Context, reqBody map[string]any) string { + if c == nil { + return "" + } + sessionID := strings.TrimSpace(c.GetHeader("session_id")) + if sessionID == "" { + sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) + } + if sessionID == "" && reqBody != nil { + if v, ok := reqBody["prompt_cache_key"].(string); ok { + sessionID = strings.TrimSpace(v) + } + } + if sessionID == "" { + return "" + } + hash := sha256.Sum256([]byte(sessionID)) + return hex.EncodeToString(hash[:]) +} + +func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", + fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) +} + +func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { + status, errType, errMsg := h.mapUpstreamError(statusCode) + h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) +} + +func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) { + switch statusCode { + case 401: + return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" + case 403: + return 
http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" + case 429: + return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" + case 529: + return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later" + case 500, 502, 503, 504: + return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable" + default: + return http.StatusBadGateway, "upstream_error", "Upstream request failed" + } +} + +func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { + if streamStarted { + flusher, ok := c.Writer.(http.Flusher) + if ok { + errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) + if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { + _ = c.Error(err) + } + flusher.Flush() + } + return + } + h.errorResponse(c, status, errType, message) +} + +func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { + c.JSON(status, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +// MediaProxy proxies /tmp or /static media files from sora2api +func (h *SoraGatewayHandler) MediaProxy(c *gin.Context) { + h.proxySoraMedia(c, false) +} + +// MediaProxySigned proxies /tmp or /static media files with signature verification +func (h *SoraGatewayHandler) MediaProxySigned(c *gin.Context) { + h.proxySoraMedia(c, true) +} + +func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature bool) { + if h.sora2apiBaseURL == "" { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "sora2api 未配置", + }, + }) + return + } + + rawPath := c.Param("filepath") + if rawPath == "" { + c.Status(http.StatusNotFound) + return + } + cleaned := path.Clean(rawPath) + if 
!strings.HasPrefix(cleaned, "/tmp/") && !strings.HasPrefix(cleaned, "/static/") { + c.Status(http.StatusNotFound) + return + } + + query := c.Request.URL.Query() + if requireSignature { + if h.soraMediaSigningKey == "" { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "Sora 媒体签名未配置", + }, + }) + return + } + expiresStr := strings.TrimSpace(query.Get("expires")) + signature := strings.TrimSpace(query.Get("sig")) + expires, err := strconv.ParseInt(expiresStr, 10, 64) + if err != nil || expires <= time.Now().Unix() { + c.JSON(http.StatusUnauthorized, gin.H{ + "error": gin.H{ + "type": "authentication_error", + "message": "Sora 媒体签名已过期", + }, + }) + return + } + query.Del("sig") + query.Del("expires") + signingQuery := query.Encode() + if !service.VerifySoraMediaURL(cleaned, signingQuery, expires, signature, h.soraMediaSigningKey) { + c.JSON(http.StatusUnauthorized, gin.H{ + "error": gin.H{ + "type": "authentication_error", + "message": "Sora 媒体签名无效", + }, + }) + return + } + } + + targetURL := h.sora2apiBaseURL + cleaned + if rawQuery := query.Encode(); rawQuery != "" { + targetURL += "?" 
+ rawQuery + } + + req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL, nil) + if err != nil { + c.Status(http.StatusBadGateway) + return + } + copyHeaders := []string{"Range", "If-Range", "If-Modified-Since", "If-None-Match", "Accept", "User-Agent"} + for _, key := range copyHeaders { + if val := c.GetHeader(key); val != "" { + req.Header.Set(key, val) + } + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + c.Status(http.StatusBadGateway) + return + } + defer func() { _ = resp.Body.Close() }() + + for _, key := range []string{"Content-Type", "Content-Length", "Accept-Ranges", "Content-Range", "Cache-Control", "Last-Modified", "ETag"} { + if val := resp.Header.Get(key); val != "" { + c.Header(key, val) + } + } + c.Status(resp.StatusCode) + _, _ = io.Copy(c.Writer, resp.Body) +} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 2af7905e..1e3ef17d 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -26,6 +26,7 @@ func ProvideAdminHandlers( subscriptionHandler *admin.SubscriptionHandler, usageHandler *admin.UsageHandler, userAttributeHandler *admin.UserAttributeHandler, + modelHandler *admin.ModelHandler, ) *AdminHandlers { return &AdminHandlers{ Dashboard: dashboardHandler, @@ -45,6 +46,7 @@ func ProvideAdminHandlers( Subscription: subscriptionHandler, Usage: usageHandler, UserAttribute: userAttributeHandler, + Model: modelHandler, } } @@ -69,6 +71,7 @@ func ProvideHandlers( adminHandlers *AdminHandlers, gatewayHandler *GatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler, + soraGatewayHandler *SoraGatewayHandler, settingHandler *SettingHandler, ) *Handlers { return &Handlers{ @@ -81,6 +84,7 @@ func ProvideHandlers( Admin: adminHandlers, Gateway: gatewayHandler, OpenAIGateway: openaiGatewayHandler, + SoraGateway: soraGatewayHandler, Setting: settingHandler, } } @@ -96,6 +100,7 @@ var ProviderSet = wire.NewSet( NewSubscriptionHandler, 
NewGatewayHandler, NewOpenAIGatewayHandler, + NewSoraGatewayHandler, ProvideSettingHandler, // Admin handlers @@ -116,6 +121,7 @@ var ProviderSet = wire.NewSet( admin.NewSubscriptionHandler, admin.NewUsageHandler, admin.NewUserAttributeHandler, + admin.NewModelHandler, // AdminHandlers and Handlers constructors ProvideAdminHandlers, diff --git a/backend/internal/pkg/tlsfingerprint/dialer_test.go b/backend/internal/pkg/tlsfingerprint/dialer_test.go index 845d51e5..31a59fc7 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer_test.go +++ b/backend/internal/pkg/tlsfingerprint/dialer_test.go @@ -13,6 +13,7 @@ import ( "io" "net/http" "net/url" + "os" "strings" "testing" "time" @@ -38,9 +39,7 @@ type TLSInfo struct { // TestDialerBasicConnection tests that the dialer can establish TLS connections. func TestDialerBasicConnection(t *testing.T) { - if testing.Short() { - t.Skip("skipping network test in short mode") - } + skipNetworkTest(t) // Create a dialer with default profile profile := &Profile{ @@ -74,10 +73,7 @@ func TestDialerBasicConnection(t *testing.T) { // Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x) // Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP) func TestJA3Fingerprint(t *testing.T) { - // Skip if network is unavailable or if running in short mode - if testing.Short() { - t.Skip("skipping integration test in short mode") - } + skipNetworkTest(t) profile := &Profile{ Name: "Claude CLI Test", @@ -178,6 +174,15 @@ func TestJA3Fingerprint(t *testing.T) { } } +func skipNetworkTest(t *testing.T) { + if testing.Short() { + t.Skip("跳过网络测试(short 模式)") + } + if os.Getenv("TLSFINGERPRINT_NETWORK_TESTS") != "1" { + t.Skip("跳过网络测试(需要设置 TLSFINGERPRINT_NETWORK_TESTS=1)") + } +} + // TestDialerWithProfile tests that different profiles produce different fingerprints. 
func TestDialerWithProfile(t *testing.T) { // Create two dialers with different profiles @@ -317,9 +322,7 @@ type TestProfileExpectation struct { // TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws. // Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/... func TestAllProfiles(t *testing.T) { - if testing.Short() { - t.Skip("skipping integration test in short mode") - } + skipNetworkTest(t) // Define all profiles to test with their expected fingerprints // These profiles are from config.yaml gateway.tls_fingerprint.profiles diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index ab890844..9308326b 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -134,6 +134,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, + group.FieldSoraImagePrice360, + group.FieldSoraImagePrice540, + group.FieldSoraVideoPricePerRequest, + group.FieldSoraVideoPricePerRequestHd, group.FieldClaudeCodeOnly, group.FieldFallbackGroupID, group.FieldModelRoutingEnabled, @@ -421,6 +425,10 @@ func groupEntityToService(g *dbent.Group) *service.Group { ImagePrice1K: g.ImagePrice1k, ImagePrice2K: g.ImagePrice2k, ImagePrice4K: g.ImagePrice4k, + SoraImagePrice360: g.SoraImagePrice360, + SoraImagePrice540: g.SoraImagePrice540, + SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd, DefaultValidityDays: g.DefaultValidityDays, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 5c4d6cf4..75684fc9 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -47,6 +47,10 @@ func (r *groupRepository) Create(ctx 
context.Context, groupIn *service.Group) er SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). + SetNillableSoraImagePrice360(groupIn.SoraImagePrice360). + SetNillableSoraImagePrice540(groupIn.SoraImagePrice540). + SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest). + SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetNillableFallbackGroupID(groupIn.FallbackGroupID). @@ -106,6 +110,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). + SetNillableSoraImagePrice360(groupIn.SoraImagePrice360). + SetNillableSoraImagePrice540(groupIn.SoraImagePrice540). + SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest). + SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). 
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 963db7ba..0696c958 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -22,7 +22,7 @@ import ( "github.com/lib/pq" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, created_at" type usageLogRepository struct { client *dbent.Client @@ -114,6 +114,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ip_address, image_count, image_size, + media_type, created_at ) VALUES ( $1, $2, $3, $4, $5, @@ -121,7 +122,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30 + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -134,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log 
*service.UsageLog) userAgent := nullString(log.UserAgent) ipAddress := nullString(log.IPAddress) imageSize := nullString(log.ImageSize) + mediaType := nullString(log.MediaType) var requestIDArg any if requestID != "" { @@ -170,6 +172,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ipAddress, log.ImageCount, imageSize, + mediaType, createdAt, } if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { @@ -2090,6 +2093,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ipAddress sql.NullString imageCount int imageSize sql.NullString + mediaType sql.NullString createdAt time.Time ) @@ -2124,6 +2128,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &ipAddress, &imageCount, &imageSize, + &mediaType, &createdAt, ); err != nil { return nil, err @@ -2183,6 +2188,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if imageSize.Valid { log.ImageSize = &imageSize.String } + if mediaType.Valid { + log.MediaType = &mediaType.String + } return log, nil } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 050e724d..2c1762d3 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -64,6 +64,9 @@ func RegisterAdminRoutes( // 用户属性管理 registerUserAttributeRoutes(admin, h) + + // 模型列表 + registerModelRoutes(admin, h) } } @@ -371,3 +374,7 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) { attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition) } } + +func registerModelRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + admin.GET("/models", h.Admin.Model.List) +} diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index bf019ce3..32f34e0c 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go 
@@ -20,6 +20,11 @@ func RegisterGatewayRoutes( cfg *config.Config, ) { bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize) + soraMaxBodySize := cfg.Gateway.SoraMaxBodySize + if soraMaxBodySize <= 0 { + soraMaxBodySize = cfg.Gateway.MaxBodySize + } + soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize) clientRequestID := middleware.ClientRequestID() opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService) @@ -38,6 +43,16 @@ func RegisterGatewayRoutes( gateway.POST("/responses", h.OpenAIGateway.Responses) } + // Sora Chat Completions + soraGateway := r.Group("/v1") + soraGateway.Use(soraBodyLimit) + soraGateway.Use(clientRequestID) + soraGateway.Use(opsErrorLogger) + soraGateway.Use(gin.HandlerFunc(apiKeyAuth)) + { + soraGateway.POST("/chat/completions", h.SoraGateway.ChatCompletions) + } + // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) gemini := r.Group("/v1beta") gemini.Use(bodyLimit) @@ -82,4 +97,25 @@ func RegisterGatewayRoutes( antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) } + + // Sora 专用路由(强制使用 sora 平台) + soraV1 := r.Group("/sora/v1") + soraV1.Use(soraBodyLimit) + soraV1.Use(clientRequestID) + soraV1.Use(opsErrorLogger) + soraV1.Use(middleware.ForcePlatform(service.PlatformSora)) + soraV1.Use(gin.HandlerFunc(apiKeyAuth)) + { + soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions) + soraV1.GET("/models", h.Gateway.Models) + } + + // Sora 媒体代理(可选 API Key 验证) + if cfg.Gateway.SoraMediaRequireAPIKey { + r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy) + } else { + r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy) + } + // Sora 媒体代理(签名 URL,无需 API Key) + r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned) } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 398de0e0..a29bf4db 100644 --- 
a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -102,11 +102,16 @@ type CreateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + // Sora 按次计费配置 + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 + SoraVideoPricePerRequestHD *float64 + ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 ModelRoutingEnabled bool // 是否启用模型路由 @@ -124,11 +129,16 @@ type UpdateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + // Sora 按次计费配置 + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 + SoraVideoPricePerRequestHD *float64 + ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 ModelRoutingEnabled *bool // 是否启用模型路由 @@ -273,6 +283,7 @@ type adminServiceImpl struct { groupRepo GroupRepository accountRepo AccountRepository soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储 + soraSyncService *Sora2APISyncService // Sora2API 同步服务 proxyRepo ProxyRepository apiKeyRepo APIKeyRepository redeemCodeRepo RedeemCodeRepository @@ -288,6 +299,7 @@ func NewAdminService( groupRepo GroupRepository, accountRepo AccountRepository, soraAccountRepo SoraAccountRepository, + soraSyncService 
*Sora2APISyncService, proxyRepo ProxyRepository, apiKeyRepo APIKeyRepository, redeemCodeRepo RedeemCodeRepository, @@ -301,6 +313,7 @@ func NewAdminService( groupRepo: groupRepo, accountRepo: accountRepo, soraAccountRepo: soraAccountRepo, + soraSyncService: soraSyncService, proxyRepo: proxyRepo, apiKeyRepo: apiKeyRepo, redeemCodeRepo: redeemCodeRepo, @@ -567,6 +580,10 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn imagePrice1K := normalizePrice(input.ImagePrice1K) imagePrice2K := normalizePrice(input.ImagePrice2K) imagePrice4K := normalizePrice(input.ImagePrice4K) + soraImagePrice360 := normalizePrice(input.SoraImagePrice360) + soraImagePrice540 := normalizePrice(input.SoraImagePrice540) + soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest) + soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD) // 校验降级分组 if input.FallbackGroupID != nil { @@ -576,22 +593,26 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn } group := &Group{ - Name: input.Name, - Description: input.Description, - Platform: platform, - RateMultiplier: input.RateMultiplier, - IsExclusive: input.IsExclusive, - Status: StatusActive, - SubscriptionType: subscriptionType, - DailyLimitUSD: dailyLimit, - WeeklyLimitUSD: weeklyLimit, - MonthlyLimitUSD: monthlyLimit, - ImagePrice1K: imagePrice1K, - ImagePrice2K: imagePrice2K, - ImagePrice4K: imagePrice4K, - ClaudeCodeOnly: input.ClaudeCodeOnly, - FallbackGroupID: input.FallbackGroupID, - ModelRouting: input.ModelRouting, + Name: input.Name, + Description: input.Description, + Platform: platform, + RateMultiplier: input.RateMultiplier, + IsExclusive: input.IsExclusive, + Status: StatusActive, + SubscriptionType: subscriptionType, + DailyLimitUSD: dailyLimit, + WeeklyLimitUSD: weeklyLimit, + MonthlyLimitUSD: monthlyLimit, + ImagePrice1K: imagePrice1K, + ImagePrice2K: imagePrice2K, + ImagePrice4K: imagePrice4K, + SoraImagePrice360: soraImagePrice360, + 
SoraImagePrice540: soraImagePrice540, + SoraVideoPricePerRequest: soraVideoPrice, + SoraVideoPricePerRequestHD: soraVideoPriceHD, + ClaudeCodeOnly: input.ClaudeCodeOnly, + FallbackGroupID: input.FallbackGroupID, + ModelRouting: input.ModelRouting, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -702,6 +723,18 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.ImagePrice4K != nil { group.ImagePrice4K = normalizePrice(input.ImagePrice4K) } + if input.SoraImagePrice360 != nil { + group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360) + } + if input.SoraImagePrice540 != nil { + group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540) + } + if input.SoraVideoPricePerRequest != nil { + group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest) + } + if input.SoraVideoPricePerRequestHD != nil { + group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD) + } // Claude Code 客户端限制 if input.ClaudeCodeOnly != nil { @@ -884,6 +917,9 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou } } + // 同步到 sora2api(异步,不阻塞创建) + s.syncSoraAccountAsync(account) + return account, nil } @@ -974,7 +1010,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U } // 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象) - return s.accountRepo.GetByID(ctx, id) + updated, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + s.syncSoraAccountAsync(updated) + return updated, nil } // BulkUpdateAccounts updates multiple accounts in one request. @@ -990,16 +1031,23 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp return result, nil } - // Preload account platforms for mixed channel risk checks if group bindings are requested. 
+ needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck + needSoraSync := s != nil && s.soraSyncService != nil + + // 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。 platformByID := map[int64]string{} - if input.GroupIDs != nil && !input.SkipMixedChannelCheck { + if needMixedChannelCheck || needSoraSync { accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs) if err != nil { - return nil, err - } - for _, account := range accounts { - if account != nil { - platformByID[account.ID] = account.Platform + if needMixedChannelCheck { + return nil, err + } + log.Printf("[AdminService] 预加载账号平台信息失败,将逐个降级同步: err=%v", err) + } else { + for _, account := range accounts { + if account != nil { + platformByID[account.ID] = account.Platform + } } } } @@ -1086,13 +1134,46 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp result.Success++ result.SuccessIDs = append(result.SuccessIDs, accountID) result.Results = append(result.Results, entry) + + // 批量更新后同步 sora2api + if needSoraSync { + platform := platformByID[accountID] + if platform == "" { + updated, err := s.accountRepo.GetByID(ctx, accountID) + if err != nil { + log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err) + continue + } + if updated.Platform == PlatformSora { + s.syncSoraAccountAsync(updated) + } + continue + } + + if platform == PlatformSora { + updated, err := s.accountRepo.GetByID(ctx, accountID) + if err != nil { + log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err) + continue + } + s.syncSoraAccountAsync(updated) + } + } } return result, nil } func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { - return s.accountRepo.Delete(ctx, id) + account, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return err + } + if err := s.accountRepo.Delete(ctx, id); err != nil { + return err + } + s.deleteSoraAccountAsync(account) + return nil } func (s 
*adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) { @@ -1125,7 +1206,46 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { return nil, err } - return s.accountRepo.GetByID(ctx, id) + updated, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + s.syncSoraAccountAsync(updated) + return updated, nil +} + +func (s *adminServiceImpl) syncSoraAccountAsync(account *Account) { + if s == nil || s.soraSyncService == nil || account == nil { + return + } + if account.Platform != PlatformSora { + return + } + syncAccount := *account + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := s.soraSyncService.SyncAccount(ctx, &syncAccount); err != nil { + log.Printf("[AdminService] 同步 sora2api 失败: account_id=%d err=%v", syncAccount.ID, err) + } + }() +} + +func (s *adminServiceImpl) deleteSoraAccountAsync(account *Account) { + if s == nil || s.soraSyncService == nil || account == nil { + return + } + if account.Platform != PlatformSora { + return + } + syncAccount := *account + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := s.soraSyncService.DeleteAccount(ctx, &syncAccount); err != nil { + log.Printf("[AdminService] 删除 sora2api token 失败: account_id=%d err=%v", syncAccount.ID, err) + } + }() } // Proxy management implementations diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go index 662b95fb..cbdbe625 100644 --- a/backend/internal/service/admin_service_bulk_update_test.go +++ b/backend/internal/service/admin_service_bulk_update_test.go @@ -15,6 +15,13 @@ type accountRepoStubForBulkUpdate struct { bulkUpdateErr error bulkUpdateIDs []int64 bindGroupErrByID map[int64]error + getByIDsAccounts 
[]*Account + getByIDsErr error + getByIDsCalled bool + getByIDsIDs []int64 + getByIDAccounts map[int64]*Account + getByIDErrByID map[int64]error + getByIDCalled []int64 } func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) { @@ -32,6 +39,26 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i return nil } +func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) { + s.getByIDsCalled = true + s.getByIDsIDs = append([]int64{}, ids...) + if s.getByIDsErr != nil { + return nil, s.getByIDsErr + } + return s.getByIDsAccounts, nil +} + +func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Account, error) { + s.getByIDCalled = append(s.getByIDCalled, id) + if err, ok := s.getByIDErrByID[id]; ok { + return nil, err + } + if account, ok := s.getByIDAccounts[id]; ok { + return account, nil + } + return nil, errors.New("account not found") +} + // TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。 func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) { repo := &accountRepoStubForBulkUpdate{} @@ -78,3 +105,31 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) { require.ElementsMatch(t, []int64{2}, result.FailedIDs) require.Len(t, result.Results, 3) } + +// TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs 验证无分组更新时仍会触发 Sora 同步。 +func TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{ + getByIDsAccounts: []*Account{ + {ID: 1, Platform: PlatformSora}, + }, + getByIDAccounts: map[int64]*Account{ + 1: {ID: 1, Platform: PlatformSora}, + }, + } + svc := &adminServiceImpl{ + accountRepo: repo, + soraSyncService: &Sora2APISyncService{}, + } + + schedulable := true + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1}, + Schedulable: &schedulable, + } + + result, err := 
svc.BulkUpdateAccounts(context.Background(), input) + require.NoError(t, err) + require.Equal(t, 1, result.Success) + require.True(t, repo.getByIDsCalled) + require.ElementsMatch(t, []int64{1}, repo.getByIDCalled) +} diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 5b476dbc..6247da00 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -35,6 +35,10 @@ type APIKeyAuthGroupSnapshot struct { ImagePrice1K *float64 `json:"image_price_1k,omitempty"` ImagePrice2K *float64 `json:"image_price_2k,omitempty"` ImagePrice4K *float64 `json:"image_price_4k,omitempty"` + SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` + SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index eb5c7534..5569a503 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -235,6 +235,10 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ImagePrice1K: apiKey.Group.ImagePrice1K, ImagePrice2K: apiKey.Group.ImagePrice2K, ImagePrice4K: apiKey.Group.ImagePrice4K, + SoraImagePrice360: apiKey.Group.SoraImagePrice360, + SoraImagePrice540: apiKey.Group.SoraImagePrice540, + SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, FallbackGroupID: apiKey.Group.FallbackGroupID, ModelRouting: apiKey.Group.ModelRouting, @@ -279,6 +283,10 @@ 
func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ImagePrice1K: snapshot.Group.ImagePrice1K, ImagePrice2K: snapshot.Group.ImagePrice2K, ImagePrice4K: snapshot.Group.ImagePrice4K, + SoraImagePrice360: snapshot.Group.SoraImagePrice360, + SoraImagePrice540: snapshot.Group.SoraImagePrice540, + SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD, ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, FallbackGroupID: snapshot.Group.FallbackGroupID, ModelRouting: snapshot.Group.ModelRouting, diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index f2afc343..9b72bf6e 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -303,6 +303,14 @@ type ImagePriceConfig struct { Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值) } +// SoraPriceConfig Sora 按次计费配置 +type SoraPriceConfig struct { + ImagePrice360 *float64 + ImagePrice540 *float64 + VideoPricePerRequest *float64 + VideoPricePerRequestHD *float64 +} + // CalculateImageCost 计算图片生成费用 // model: 请求的模型名称(用于获取 LiteLLM 默认价格) // imageSize: 图片尺寸 "1K", "2K", "4K" @@ -332,6 +340,65 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag } } +// CalculateSoraImageCost 计算 Sora 图片按次费用 +func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown { + if imageCount <= 0 { + return &CostBreakdown{} + } + + unitPrice := 0.0 + if groupConfig != nil { + switch imageSize { + case "540": + if groupConfig.ImagePrice540 != nil { + unitPrice = *groupConfig.ImagePrice540 + } + default: + if groupConfig.ImagePrice360 != nil { + unitPrice = *groupConfig.ImagePrice360 + } + } + } + + totalCost := unitPrice * float64(imageCount) + if rateMultiplier <= 0 { + rateMultiplier = 1.0 + } + actualCost := totalCost * rateMultiplier + + 
return &CostBreakdown{ + TotalCost: totalCost, + ActualCost: actualCost, + } +} + +// CalculateSoraVideoCost 计算 Sora 视频按次费用 +func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown { + unitPrice := 0.0 + if groupConfig != nil { + modelLower := strings.ToLower(model) + if strings.Contains(modelLower, "sora2pro-hd") { + if groupConfig.VideoPricePerRequestHD != nil { + unitPrice = *groupConfig.VideoPricePerRequestHD + } + } + if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil { + unitPrice = *groupConfig.VideoPricePerRequest + } + } + + totalCost := unitPrice + if rateMultiplier <= 0 { + rateMultiplier = 1.0 + } + actualCost := totalCost * rateMultiplier + + return &CostBreakdown{ + TotalCost: totalCost, + ActualCost: actualCost, + } +} + // getImageUnitPrice 获取图片单价 func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 { // 优先使用分组配置的价格 diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 9565da29..f0933ae3 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -184,6 +184,10 @@ type ForwardResult struct { // 图片生成计费字段(仅 gemini-3-pro-image 使用) ImageCount int // 生成的图片数量 ImageSize string // 图片尺寸 "1K", "2K", "4K" + + // Sora 媒体字段 + MediaType string // image / video / prompt + MediaURL string // 生成后的媒体地址(可选) } // UpstreamFailoverError indicates an upstream error that should trigger account failover. 
@@ -3461,7 +3465,22 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu var cost *CostBreakdown // 根据请求类型选择计费方式 - if result.ImageCount > 0 { + if result.MediaType == "image" || result.MediaType == "video" || result.MediaType == "prompt" { + var soraConfig *SoraPriceConfig + if apiKey.Group != nil { + soraConfig = &SoraPriceConfig{ + ImagePrice360: apiKey.Group.SoraImagePrice360, + ImagePrice540: apiKey.Group.SoraImagePrice540, + VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, + VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, + } + } + if result.MediaType == "image" { + cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) + } else { + cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier) + } + } else if result.ImageCount > 0 { // 图片生成计费 var groupConfig *ImagePriceConfig if apiKey.Group != nil { @@ -3501,6 +3520,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu if result.ImageSize != "" { imageSize = &result.ImageSize } + var mediaType *string + if strings.TrimSpace(result.MediaType) != "" { + mediaType = &result.MediaType + } accountRateMultiplier := account.BillingRateMultiplier() usageLog := &UsageLog{ UserID: user.ID, @@ -3526,6 +3549,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: imageSize, + MediaType: mediaType, CreatedAt: time.Now(), } diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index d6d1269b..bc97e062 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -26,6 +26,12 @@ type Group struct { ImagePrice2K *float64 ImagePrice4K *float64 + // Sora 按次计费配置(阶段 1) + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 + SoraVideoPricePerRequestHD *float64 + 
// Claude Code 客户端限制 ClaudeCodeOnly bool FallbackGroupID *int64 @@ -83,6 +89,18 @@ func (g *Group) GetImagePrice(imageSize string) *float64 { } } +// GetSoraImagePrice 根据 Sora 图片尺寸返回价格(360/540) +func (g *Group) GetSoraImagePrice(imageSize string) *float64 { + switch imageSize { + case "360": + return g.SoraImagePrice360 + case "540": + return g.SoraImagePrice540 + default: + return g.SoraImagePrice360 + } +} + // IsGroupContextValid reports whether a group from context has the fields required for routing decisions. func IsGroupContextValid(group *Group) bool { if group == nil { diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index 87a7713b..026d9061 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -41,8 +41,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou if account == nil { return "", errors.New("account is nil") } - if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth { - return "", errors.New("not an openai oauth account") + if (account.Platform != PlatformOpenAI && account.Platform != PlatformSora) || account.Type != AccountTypeOAuth { + return "", errors.New("not an openai/sora oauth account") } cacheKey := OpenAITokenCacheKey(account) @@ -157,7 +157,7 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } } - accessToken := account.GetOpenAIAccessToken() + accessToken := account.GetCredential("access_token") if strings.TrimSpace(accessToken) == "" { return "", errors.New("access_token not found in credentials") } diff --git a/backend/internal/service/openai_token_provider_test.go b/backend/internal/service/openai_token_provider_test.go index c2e3dbb0..3c649a7e 100644 --- a/backend/internal/service/openai_token_provider_test.go +++ b/backend/internal/service/openai_token_provider_test.go @@ -375,7 +375,7 @@ func 
TestOpenAITokenProvider_WrongPlatform(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an openai oauth account") + require.Contains(t, err.Error(), "not an openai/sora oauth account") require.Empty(t, token) } @@ -389,7 +389,7 @@ func TestOpenAITokenProvider_WrongAccountType(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an openai oauth account") + require.Contains(t, err.Error(), "not an openai/sora oauth account") require.Empty(t, token) } diff --git a/backend/internal/service/sora2api_service.go b/backend/internal/service/sora2api_service.go new file mode 100644 index 00000000..d4bf9ba4 --- /dev/null +++ b/backend/internal/service/sora2api_service.go @@ -0,0 +1,355 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net/http" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +// Sora2APIModel represents a model entry returned by sora2api. +type Sora2APIModel struct { + ID string `json:"id"` + Object string `json:"object"` + OwnedBy string `json:"owned_by,omitempty"` + Description string `json:"description,omitempty"` +} + +// Sora2APIModelList represents /v1/models response. +type Sora2APIModelList struct { + Object string `json:"object"` + Data []Sora2APIModel `json:"data"` +} + +// Sora2APIImportTokenItem mirrors sora2api ImportTokenItem. 
+type Sora2APIImportTokenItem struct { + Email string `json:"email"` + AccessToken string `json:"access_token,omitempty"` + SessionToken string `json:"session_token,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + ClientID string `json:"client_id,omitempty"` + ProxyURL string `json:"proxy_url,omitempty"` + Remark string `json:"remark,omitempty"` + IsActive bool `json:"is_active"` + ImageEnabled bool `json:"image_enabled"` + VideoEnabled bool `json:"video_enabled"` + ImageConcurrency int `json:"image_concurrency"` + VideoConcurrency int `json:"video_concurrency"` +} + +// Sora2APIToken represents minimal fields for admin list. +type Sora2APIToken struct { + ID int64 `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Remark string `json:"remark"` +} + +// Sora2APIService provides access to sora2api endpoints. +type Sora2APIService struct { + cfg *config.Config + + baseURL string + apiKey string + adminUsername string + adminPassword string + adminTokenTTL time.Duration + adminTimeout time.Duration + tokenImportMode string + + client *http.Client + adminClient *http.Client + + adminToken string + adminTokenAt time.Time + adminMu sync.Mutex + + modelCache []Sora2APIModel + modelCacheAt time.Time + modelMu sync.RWMutex +} + +func NewSora2APIService(cfg *config.Config) *Sora2APIService { + if cfg == nil { + return &Sora2APIService{} + } + adminTTL := time.Duration(cfg.Sora2API.AdminTokenTTLSeconds) * time.Second + if adminTTL <= 0 { + adminTTL = 15 * time.Minute + } + adminTimeout := time.Duration(cfg.Sora2API.AdminTimeoutSeconds) * time.Second + if adminTimeout <= 0 { + adminTimeout = 10 * time.Second + } + return &Sora2APIService{ + cfg: cfg, + baseURL: strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/"), + apiKey: strings.TrimSpace(cfg.Sora2API.APIKey), + adminUsername: strings.TrimSpace(cfg.Sora2API.AdminUsername), + adminPassword: strings.TrimSpace(cfg.Sora2API.AdminPassword), + adminTokenTTL: adminTTL, + 
adminTimeout: adminTimeout, + tokenImportMode: strings.ToLower(strings.TrimSpace(cfg.Sora2API.TokenImportMode)), + client: &http.Client{}, + adminClient: &http.Client{Timeout: adminTimeout}, + } +} + +func (s *Sora2APIService) Enabled() bool { + return s != nil && s.baseURL != "" && s.apiKey != "" +} + +func (s *Sora2APIService) AdminEnabled() bool { + return s != nil && s.baseURL != "" && s.adminUsername != "" && s.adminPassword != "" +} + +func (s *Sora2APIService) buildURL(path string) string { + if s.baseURL == "" { + return path + } + if strings.HasPrefix(path, "/") { + return s.baseURL + path + } + return s.baseURL + "/" + path +} + +// BuildURL 返回完整的 sora2api URL(用于代理媒体) +func (s *Sora2APIService) BuildURL(path string) string { + return s.buildURL(path) +} + +func (s *Sora2APIService) NewAPIRequest(ctx context.Context, method string, path string, body []byte) (*http.Request, error) { + if !s.Enabled() { + return nil, errors.New("sora2api not configured") + } + req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+s.apiKey) + req.Header.Set("Content-Type", "application/json") + return req, nil +} + +func (s *Sora2APIService) ListModels(ctx context.Context) ([]Sora2APIModel, error) { + if !s.Enabled() { + return nil, errors.New("sora2api not configured") + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.buildURL("/v1/models"), nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+s.apiKey) + resp, err := s.client.Do(req) + if err != nil { + return s.cachedModelsOnError(err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return s.cachedModelsOnError(fmt.Errorf("sora2api models status: %d", resp.StatusCode)) + } + + var payload Sora2APIModelList + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return 
s.cachedModelsOnError(err) + } + models := payload.Data + if s.cfg != nil && s.cfg.Gateway.SoraModelFilters.HidePromptEnhance { + filtered := make([]Sora2APIModel, 0, len(models)) + for _, m := range models { + if strings.HasPrefix(strings.ToLower(m.ID), "prompt-enhance") { + continue + } + filtered = append(filtered, m) + } + models = filtered + } + + s.modelMu.Lock() + s.modelCache = models + s.modelCacheAt = time.Now() + s.modelMu.Unlock() + + return models, nil +} + +func (s *Sora2APIService) cachedModelsOnError(err error) ([]Sora2APIModel, error) { + s.modelMu.RLock() + cached := append([]Sora2APIModel(nil), s.modelCache...) + s.modelMu.RUnlock() + if len(cached) > 0 { + log.Printf("[Sora2API] 模型列表拉取失败,回退缓存: %v", err) + return cached, nil + } + return nil, err +} + +func (s *Sora2APIService) ImportTokens(ctx context.Context, items []Sora2APIImportTokenItem) error { + if !s.AdminEnabled() { + return errors.New("sora2api admin not configured") + } + mode := s.tokenImportMode + if mode == "" { + mode = "at" + } + payload := map[string]any{ + "tokens": items, + "mode": mode, + } + _, err := s.doAdminRequest(ctx, http.MethodPost, "/api/tokens/import", payload, nil) + return err +} + +func (s *Sora2APIService) ListTokens(ctx context.Context) ([]Sora2APIToken, error) { + if !s.AdminEnabled() { + return nil, errors.New("sora2api admin not configured") + } + var tokens []Sora2APIToken + _, err := s.doAdminRequest(ctx, http.MethodGet, "/api/tokens", nil, &tokens) + return tokens, err +} + +func (s *Sora2APIService) DisableToken(ctx context.Context, tokenID int64) error { + if !s.AdminEnabled() { + return errors.New("sora2api admin not configured") + } + path := fmt.Sprintf("/api/tokens/%d/disable", tokenID) + _, err := s.doAdminRequest(ctx, http.MethodPost, path, nil, nil) + return err +} + +func (s *Sora2APIService) DeleteToken(ctx context.Context, tokenID int64) error { + if !s.AdminEnabled() { + return errors.New("sora2api admin not configured") + } + path := 
fmt.Sprintf("/api/tokens/%d", tokenID) + _, err := s.doAdminRequest(ctx, http.MethodDelete, path, nil, nil) + return err +} + +func (s *Sora2APIService) doAdminRequest(ctx context.Context, method string, path string, body any, out any) (*http.Response, error) { + if !s.AdminEnabled() { + return nil, errors.New("sora2api admin not configured") + } + token, err := s.getAdminToken(ctx) + if err != nil { + return nil, err + } + resp, err := s.doAdminRequestWithToken(ctx, method, path, token, body, out) + if err == nil && resp != nil && resp.StatusCode != http.StatusUnauthorized { + return resp, nil + } + if resp != nil && resp.StatusCode == http.StatusUnauthorized { + s.invalidateAdminToken() + token, err = s.getAdminToken(ctx) + if err != nil { + return resp, err + } + return s.doAdminRequestWithToken(ctx, method, path, token, body, out) + } + return resp, err +} + +func (s *Sora2APIService) doAdminRequestWithToken(ctx context.Context, method string, path string, token string, body any, out any) (*http.Response, error) { + var reader *bytes.Reader + if body != nil { + buf, err := json.Marshal(body) + if err != nil { + return nil, err + } + reader = bytes.NewReader(buf) + } else { + reader = bytes.NewReader(nil) + } + req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), reader) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token) + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + resp, err := s.adminClient.Do(req) + if err != nil { + return resp, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return resp, fmt.Errorf("sora2api admin status: %d", resp.StatusCode) + } + if out != nil { + if err := json.NewDecoder(resp.Body).Decode(out); err != nil { + return resp, err + } + } + return resp, nil +} + +func (s *Sora2APIService) getAdminToken(ctx context.Context) (string, error) { + s.adminMu.Lock() + defer s.adminMu.Unlock() + + if 
s.adminToken != "" && time.Since(s.adminTokenAt) < s.adminTokenTTL { + return s.adminToken, nil + } + + if !s.AdminEnabled() { + return "", errors.New("sora2api admin not configured") + } + + payload := map[string]string{ + "username": s.adminUsername, + "password": s.adminPassword, + } + buf, err := json.Marshal(payload) + if err != nil { + return "", err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.buildURL("/api/login"), bytes.NewReader(buf)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + resp, err := s.adminClient.Do(req) + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("sora2api login failed: %d", resp.StatusCode) + } + var result struct { + Success bool `json:"success"` + Token string `json:"token"` + Message string `json:"message"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", err + } + if !result.Success || result.Token == "" { + if result.Message == "" { + result.Message = "sora2api login failed" + } + return "", errors.New(result.Message) + } + s.adminToken = result.Token + s.adminTokenAt = time.Now() + return result.Token, nil +} + +func (s *Sora2APIService) invalidateAdminToken() { + s.adminMu.Lock() + defer s.adminMu.Unlock() + s.adminToken = "" + s.adminTokenAt = time.Time{} +} diff --git a/backend/internal/service/sora2api_sync_service.go b/backend/internal/service/sora2api_sync_service.go new file mode 100644 index 00000000..33978432 --- /dev/null +++ b/backend/internal/service/sora2api_sync_service.go @@ -0,0 +1,255 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net/http" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// Sora2APISyncService 用于同步 Sora 账号到 sora2api token 池 +type Sora2APISyncService struct { + sora2api *Sora2APIService + accountRepo AccountRepository + httpClient 
*http.Client +} + +func NewSora2APISyncService(sora2api *Sora2APIService, accountRepo AccountRepository) *Sora2APISyncService { + return &Sora2APISyncService{ + sora2api: sora2api, + accountRepo: accountRepo, + httpClient: &http.Client{Timeout: 10 * time.Second}, + } +} + +func (s *Sora2APISyncService) Enabled() bool { + return s != nil && s.sora2api != nil && s.sora2api.AdminEnabled() +} + +// SyncAccount 将 Sora 账号同步到 sora2api(导入或更新) +func (s *Sora2APISyncService) SyncAccount(ctx context.Context, account *Account) error { + if !s.Enabled() { + return nil + } + if account == nil || account.Platform != PlatformSora { + return nil + } + + accessToken := strings.TrimSpace(account.GetCredential("access_token")) + if accessToken == "" { + return errors.New("sora 账号缺少 access_token") + } + + email, updated := s.resolveAccountEmail(ctx, account) + if email == "" { + return errors.New("无法解析 Sora 账号邮箱") + } + if updated && s.accountRepo != nil { + if err := s.accountRepo.Update(ctx, account); err != nil { + log.Printf("[SoraSync] 更新账号邮箱失败: account_id=%d err=%v", account.ID, err) + } + } + + item := Sora2APIImportTokenItem{ + Email: email, + AccessToken: accessToken, + SessionToken: strings.TrimSpace(account.GetCredential("session_token")), + RefreshToken: strings.TrimSpace(account.GetCredential("refresh_token")), + ClientID: strings.TrimSpace(account.GetCredential("client_id")), + Remark: account.Name, + IsActive: account.IsActive() && account.Schedulable, + ImageEnabled: true, + VideoEnabled: true, + ImageConcurrency: normalizeSoraConcurrency(account.Concurrency), + VideoConcurrency: normalizeSoraConcurrency(account.Concurrency), + } + + if err := s.sora2api.ImportTokens(ctx, []Sora2APIImportTokenItem{item}); err != nil { + return err + } + return nil +} + +// DisableAccount 禁用 sora2api 中的 token +func (s *Sora2APISyncService) DisableAccount(ctx context.Context, account *Account) error { + if !s.Enabled() { + return nil + } + if account == nil || account.Platform != 
PlatformSora { + return nil + } + tokenID, err := s.resolveTokenID(ctx, account) + if err != nil { + return err + } + return s.sora2api.DisableToken(ctx, tokenID) +} + +// DeleteAccount 删除 sora2api 中的 token +func (s *Sora2APISyncService) DeleteAccount(ctx context.Context, account *Account) error { + if !s.Enabled() { + return nil + } + if account == nil || account.Platform != PlatformSora { + return nil + } + tokenID, err := s.resolveTokenID(ctx, account) + if err != nil { + return err + } + return s.sora2api.DeleteToken(ctx, tokenID) +} + +func normalizeSoraConcurrency(value int) int { + if value <= 0 { + return -1 + } + return value +} + +func (s *Sora2APISyncService) resolveAccountEmail(ctx context.Context, account *Account) (string, bool) { + if account == nil { + return "", false + } + if email := strings.TrimSpace(account.GetCredential("email")); email != "" { + return email, false + } + if email := strings.TrimSpace(account.GetExtraString("email")); email != "" { + if account.Credentials == nil { + account.Credentials = map[string]any{} + } + account.Credentials["email"] = email + return email, true + } + if email := strings.TrimSpace(account.GetExtraString("sora_email")); email != "" { + if account.Credentials == nil { + account.Credentials = map[string]any{} + } + account.Credentials["email"] = email + return email, true + } + + accessToken := strings.TrimSpace(account.GetCredential("access_token")) + if accessToken != "" { + if email := extractEmailFromAccessToken(accessToken); email != "" { + if account.Credentials == nil { + account.Credentials = map[string]any{} + } + account.Credentials["email"] = email + return email, true + } + if email := s.fetchEmailFromSora(ctx, accessToken); email != "" { + if account.Credentials == nil { + account.Credentials = map[string]any{} + } + account.Credentials["email"] = email + return email, true + } + } + + return "", false +} + +func (s *Sora2APISyncService) resolveTokenID(ctx context.Context, account *Account) 
(int64, error) { + if account == nil { + return 0, errors.New("account is nil") + } + + if account.Extra != nil { + if v, ok := account.Extra["sora2api_token_id"]; ok { + if id, ok := v.(float64); ok && id > 0 { + return int64(id), nil + } + if id, ok := v.(int64); ok && id > 0 { + return id, nil + } + if id, ok := v.(int); ok && id > 0 { + return int64(id), nil + } + } + } + + email := strings.TrimSpace(account.GetCredential("email")) + if email == "" { + email, _ = s.resolveAccountEmail(ctx, account) + } + if email == "" { + return 0, errors.New("sora2api token email missing") + } + + tokenID, err := s.findTokenIDByEmail(ctx, email) + if err != nil { + return 0, err + } + return tokenID, nil +} + +func (s *Sora2APISyncService) findTokenIDByEmail(ctx context.Context, email string) (int64, error) { + if !s.Enabled() { + return 0, errors.New("sora2api admin not configured") + } + tokens, err := s.sora2api.ListTokens(ctx) + if err != nil { + return 0, err + } + for _, token := range tokens { + if strings.EqualFold(strings.TrimSpace(token.Email), strings.TrimSpace(email)) { + return token.ID, nil + } + } + return 0, fmt.Errorf("sora2api token not found for email: %s", email) +} + +func extractEmailFromAccessToken(accessToken string) string { + parser := jwt.NewParser(jwt.WithoutClaimsValidation()) + claims := jwt.MapClaims{} + _, _, err := parser.ParseUnverified(accessToken, claims) + if err != nil { + return "" + } + if email, ok := claims["email"].(string); ok && strings.TrimSpace(email) != "" { + return email + } + if profile, ok := claims["https://api.openai.com/profile"].(map[string]any); ok { + if email, ok := profile["email"].(string); ok && strings.TrimSpace(email) != "" { + return email + } + } + return "" +} + +func (s *Sora2APISyncService) fetchEmailFromSora(ctx context.Context, accessToken string) string { + if s.httpClient == nil { + return "" + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, soraMeAPIURL, nil) + if err != nil { + return 
"" + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + req.Header.Set("Accept", "application/json") + + resp, err := s.httpClient.Do(req) + if err != nil { + return "" + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return "" + } + var payload map[string]any + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return "" + } + if email, ok := payload["email"].(string); ok && strings.TrimSpace(email) != "" { + return email + } + return "" +} diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go new file mode 100644 index 00000000..82f4eaaa --- /dev/null +++ b/backend/internal/service/sora_gateway_service.go @@ -0,0 +1,660 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" +) + +var soraSSEDataRe = regexp.MustCompile(`^data:\s*`) +var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`) +var soraVideoHTMLRe = regexp.MustCompile(`(?i)]+src=['"]([^'"]+)['"]`) + +var soraImageSizeMap = map[string]string{ + "gpt-image": "360", + "gpt-image-landscape": "540", + "gpt-image-portrait": "540", +} + +type soraStreamingResult struct { + content string + mediaType string + mediaURLs []string + imageCount int + imageSize string + firstTokenMs *int +} + +// SoraGatewayService handles forwarding requests to sora2api. 
+type SoraGatewayService struct { + sora2api *Sora2APIService + httpUpstream HTTPUpstream + rateLimitService *RateLimitService + cfg *config.Config +} + +func NewSoraGatewayService( + sora2api *Sora2APIService, + httpUpstream HTTPUpstream, + rateLimitService *RateLimitService, + cfg *config.Config, +) *SoraGatewayService { + return &SoraGatewayService{ + sora2api: sora2api, + httpUpstream: httpUpstream, + rateLimitService: rateLimitService, + cfg: cfg, + } +} + +func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) { + startTime := time.Now() + + if s.sora2api == nil || !s.sora2api.Enabled() { + if c != nil { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "sora2api 未配置", + }, + }) + } + return nil, errors.New("sora2api not configured") + } + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + return nil, fmt.Errorf("parse request: %w", err) + } + reqModel, _ := reqBody["model"].(string) + reqStream, _ := reqBody["stream"].(bool) + + mappedModel := account.GetMappedModel(reqModel) + if mappedModel != reqModel && mappedModel != "" { + reqBody["model"] = mappedModel + if updated, err := json.Marshal(reqBody); err == nil { + body = updated + } + } + + reqCtx, cancel := s.withSoraTimeout(ctx, reqStream) + if cancel != nil { + defer cancel() + } + + upstreamReq, err := s.sora2api.NewAPIRequest(reqCtx, http.MethodPost, "/v1/chat/completions", body) + if err != nil { + return nil, err + } + if c != nil { + if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" { + upstreamReq.Header.Set("User-Agent", ua) + } + } + if reqStream { + upstreamReq.Header.Set("Accept", "text/event-stream") + } + + if c != nil { + c.Set(OpsUpstreamRequestBodyKey, string(body)) + } + + proxyURL := "" + if account != nil && account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } 
+ + var resp *http.Response + if s.httpUpstream != nil { + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + } else { + resp, err = http.DefaultClient.Do(upstreamReq) + } + if err != nil { + s.setUpstreamRequestError(c, account, err) + return nil, fmt.Errorf("upstream request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + }) + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } + return s.handleErrorResponse(ctx, resp, c, account, reqModel) + } + + streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, reqModel, clientStream) + if err != nil { + return nil, err + } + + result := &ForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Model: reqModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: streamResult.firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: streamResult.mediaType, + MediaURL: firstMediaURL(streamResult.mediaURLs), + ImageCount: streamResult.imageCount, + ImageSize: streamResult.imageSize, + } + + return result, nil +} + +func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { + if s == nil || s.cfg == nil { + return ctx, nil + } + timeoutSeconds := 
s.cfg.Gateway.SoraRequestTimeoutSeconds + if stream { + timeoutSeconds = s.cfg.Gateway.SoraStreamTimeoutSeconds + } + if timeoutSeconds <= 0 { + return ctx, nil + } + return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second) +} + +func (s *SoraGatewayService) setUpstreamRequestError(c *gin.Context, account *Account, err error) { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + if c != nil { + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + } +} + +func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool { + switch statusCode { + case 401, 402, 403, 429, 529: + return true + default: + return statusCode >= 500 + } +} + +func (s *SoraGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { + if s.rateLimitService == nil || account == nil || resp == nil { + return + } + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) +} + +func (s *SoraGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, reqModel string) (*ForwardResult, error) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if msg := soraProErrorMessage(reqModel, upstreamMsg); msg != "" { + upstreamMsg = msg + } + + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := 
s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + if c != nil { + responsePayload := s.buildErrorPayload(respBody, upstreamMsg) + c.JSON(resp.StatusCode, responsePayload) + } + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) +} + +func (s *SoraGatewayService) buildErrorPayload(respBody []byte, overrideMessage string) map[string]any { + if len(respBody) > 0 { + var payload map[string]any + if err := json.Unmarshal(respBody, &payload); err == nil { + if errObj, ok := payload["error"].(map[string]any); ok { + if overrideMessage != "" { + errObj["message"] = overrideMessage + } + payload["error"] = errObj + return payload + } + } + } + return map[string]any{ + "error": map[string]any{ + "type": "upstream_error", + "message": overrideMessage, + }, + } +} + +func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel string, clientStream bool) (*soraStreamingResult, error) { + if resp == nil { + return nil, errors.New("empty response") + } + + if clientStream { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + if v := resp.Header.Get("x-request-id"); v != "" { + c.Header("x-request-id", v) + } + } + + w := c.Writer + flusher, _ := w.(http.Flusher) + + 
contentBuilder := strings.Builder{} + var firstTokenMs *int + var upstreamError error + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 64*1024), maxLineSize) + + sendLine := func(line string) error { + if !clientStream { + return nil + } + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + return err + } + if flusher != nil { + flusher.Flush() + } + return nil + } + + for scanner.Scan() { + line := scanner.Text() + if soraSSEDataRe.MatchString(line) { + data := soraSSEDataRe.ReplaceAllString(line, "") + if data == "[DONE]" { + if err := sendLine("data: [DONE]"); err != nil { + return nil, err + } + break + } + updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel) + if errEvent != nil && upstreamError == nil { + upstreamError = errEvent + } + if contentDelta != "" { + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + contentBuilder.WriteString(contentDelta) + } + if err := sendLine(updatedLine); err != nil { + return nil, err + } + continue + } + if err := sendLine(line); err != nil { + return nil, err + } + } + + if err := scanner.Err(); err != nil { + if errors.Is(err, bufio.ErrTooLong) { + if clientStream { + _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"response_too_large\"}\n\n") + if flusher != nil { + flusher.Flush() + } + } + return nil, err + } + if ctx.Err() == context.DeadlineExceeded && s.rateLimitService != nil && account != nil { + s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) + } + if clientStream { + _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"stream_read_error\"}\n\n") + if flusher != nil { + flusher.Flush() + } + } + return nil, err + } + + content := contentBuilder.String() + mediaType, mediaURLs := s.extractSoraMedia(content) + if mediaType == "" && 
isSoraPromptEnhanceModel(originalModel) { + mediaType = "prompt" + } + imageSize := "" + imageCount := 0 + if mediaType == "image" { + imageSize = soraImageSizeFromModel(originalModel) + imageCount = len(mediaURLs) + } + + if upstreamError != nil && !clientStream { + if c != nil { + c.JSON(http.StatusBadGateway, map[string]any{ + "error": map[string]any{ + "type": "upstream_error", + "message": upstreamError.Error(), + }, + }) + } + return nil, upstreamError + } + + if !clientStream { + response := buildSoraNonStreamResponse(content, originalModel) + if len(mediaURLs) > 0 { + response["media_url"] = mediaURLs[0] + if len(mediaURLs) > 1 { + response["media_urls"] = mediaURLs + } + } + c.JSON(http.StatusOK, response) + } + + return &soraStreamingResult{ + content: content, + mediaType: mediaType, + mediaURLs: mediaURLs, + imageCount: imageCount, + imageSize: imageSize, + firstTokenMs: firstTokenMs, + }, nil +} + +func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string) (string, string, error) { + if strings.TrimSpace(data) == "" { + return "data: ", "", nil + } + + var payload map[string]any + if err := json.Unmarshal([]byte(data), &payload); err != nil { + return "data: " + data, "", nil + } + + if errObj, ok := payload["error"].(map[string]any); ok { + if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" { + return "data: " + data, "", errors.New(msg) + } + } + + if model, ok := payload["model"].(string); ok && model != "" && originalModel != "" { + payload["model"] = originalModel + } + + contentDelta, updated := extractSoraContent(payload) + if updated { + rewritten := s.rewriteSoraContent(contentDelta) + if rewritten != contentDelta { + applySoraContent(payload, rewritten) + contentDelta = rewritten + } + } + + updatedData, err := json.Marshal(payload) + if err != nil { + return "data: " + data, contentDelta, nil + } + return "data: " + string(updatedData), contentDelta, nil +} + +func extractSoraContent(payload 
map[string]any) (string, bool) { + choices, ok := payload["choices"].([]any) + if !ok || len(choices) == 0 { + return "", false + } + choice, ok := choices[0].(map[string]any) + if !ok { + return "", false + } + if delta, ok := choice["delta"].(map[string]any); ok { + if content, ok := delta["content"].(string); ok { + return content, true + } + } + if message, ok := choice["message"].(map[string]any); ok { + if content, ok := message["content"].(string); ok { + return content, true + } + } + return "", false +} + +func applySoraContent(payload map[string]any, content string) { + choices, ok := payload["choices"].([]any) + if !ok || len(choices) == 0 { + return + } + choice, ok := choices[0].(map[string]any) + if !ok { + return + } + if delta, ok := choice["delta"].(map[string]any); ok { + delta["content"] = content + choice["delta"] = delta + return + } + if message, ok := choice["message"].(map[string]any); ok { + message["content"] = content + choice["message"] = message + } +} + +func (s *SoraGatewayService) rewriteSoraContent(content string) string { + if content == "" { + return content + } + content = soraImageMarkdownRe.ReplaceAllStringFunc(content, func(match string) string { + sub := soraImageMarkdownRe.FindStringSubmatch(match) + if len(sub) < 2 { + return match + } + rewritten := s.rewriteSoraURL(sub[1]) + if rewritten == sub[1] { + return match + } + return strings.Replace(match, sub[1], rewritten, 1) + }) + content = soraVideoHTMLRe.ReplaceAllStringFunc(content, func(match string) string { + sub := soraVideoHTMLRe.FindStringSubmatch(match) + if len(sub) < 2 { + return match + } + rewritten := s.rewriteSoraURL(sub[1]) + if rewritten == sub[1] { + return match + } + return strings.Replace(match, sub[1], rewritten, 1) + }) + return content +} + +func (s *SoraGatewayService) rewriteSoraURL(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return raw + } + parsed, err := url.Parse(raw) + if err != nil { + return raw + } + path := 
parsed.Path + if !strings.HasPrefix(path, "/tmp/") && !strings.HasPrefix(path, "/static/") { + return raw + } + return s.buildSoraMediaURL(path, parsed.RawQuery) +} + +func (s *SoraGatewayService) extractSoraMedia(content string) (string, []string) { + if content == "" { + return "", nil + } + if match := soraVideoHTMLRe.FindStringSubmatch(content); len(match) > 1 { + return "video", []string{match[1]} + } + imageMatches := soraImageMarkdownRe.FindAllStringSubmatch(content, -1) + if len(imageMatches) == 0 { + return "", nil + } + urls := make([]string, 0, len(imageMatches)) + for _, match := range imageMatches { + if len(match) > 1 { + urls = append(urls, match[1]) + } + } + return "image", urls +} + +func buildSoraNonStreamResponse(content, model string) map[string]any { + return map[string]any{ + "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), + "object": "chat.completion", + "created": time.Now().Unix(), + "model": model, + "choices": []any{ + map[string]any{ + "index": 0, + "message": map[string]any{ + "role": "assistant", + "content": content, + }, + "finish_reason": "stop", + }, + }, + } +} + +func soraImageSizeFromModel(model string) string { + modelLower := strings.ToLower(model) + if size, ok := soraImageSizeMap[modelLower]; ok { + return size + } + if strings.Contains(modelLower, "landscape") || strings.Contains(modelLower, "portrait") { + return "540" + } + return "360" +} + +func isSoraPromptEnhanceModel(model string) bool { + return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "prompt-enhance") +} + +func soraProErrorMessage(model, upstreamMsg string) string { + modelLower := strings.ToLower(model) + if strings.Contains(modelLower, "sora2pro-hd") { + return "当前账号无法使用 Sora Pro-HD 模型,请更换模型或账号" + } + if strings.Contains(modelLower, "sora2pro") { + return "当前账号无法使用 Sora Pro 模型,请更换模型或账号" + } + return "" +} + +func firstMediaURL(urls []string) string { + if len(urls) == 0 { + return "" + } + return urls[0] +} + +func (s 
*SoraGatewayService) buildSoraMediaURL(path string, rawQuery string) string { + if path == "" { + return path + } + prefix := "/sora/media" + values := url.Values{} + if rawQuery != "" { + if parsed, err := url.ParseQuery(rawQuery); err == nil { + values = parsed + } + } + + signKey := "" + ttlSeconds := 0 + if s != nil && s.cfg != nil { + signKey = strings.TrimSpace(s.cfg.Gateway.SoraMediaSigningKey) + ttlSeconds = s.cfg.Gateway.SoraMediaSignedURLTTLSeconds + } + values.Del("sig") + values.Del("expires") + signingQuery := values.Encode() + if signKey != "" && ttlSeconds > 0 { + expires := time.Now().Add(time.Duration(ttlSeconds) * time.Second).Unix() + signature := SignSoraMediaURL(path, signingQuery, expires, signKey) + if signature != "" { + values.Set("expires", strconv.FormatInt(expires, 10)) + values.Set("sig", signature) + prefix = "/sora/media-signed" + } + } + + encoded := values.Encode() + if encoded == "" { + return prefix + path + } + return prefix + path + "?" + encoded +} diff --git a/backend/internal/service/sora_media_sign.go b/backend/internal/service/sora_media_sign.go new file mode 100644 index 00000000..5d4a8d88 --- /dev/null +++ b/backend/internal/service/sora_media_sign.go @@ -0,0 +1,42 @@ +package service + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "strconv" + "strings" +) + +// SignSoraMediaURL 生成 Sora 媒体临时签名 +func SignSoraMediaURL(path string, query string, expires int64, key string) string { + key = strings.TrimSpace(key) + if key == "" { + return "" + } + mac := hmac.New(sha256.New, []byte(key)) + mac.Write([]byte(buildSoraMediaSignPayload(path, query))) + mac.Write([]byte("|")) + mac.Write([]byte(strconv.FormatInt(expires, 10))) + return hex.EncodeToString(mac.Sum(nil)) +} + +// VerifySoraMediaURL 校验 Sora 媒体签名 +func VerifySoraMediaURL(path string, query string, expires int64, signature string, key string) bool { + signature = strings.TrimSpace(signature) + if signature == "" { + return false + } + expected := 
SignSoraMediaURL(path, query, expires, key) + if expected == "" { + return false + } + return hmac.Equal([]byte(signature), []byte(expected)) +} + +func buildSoraMediaSignPayload(path string, query string) string { + if strings.TrimSpace(query) == "" { + return path + } + return path + "?" + query +} diff --git a/backend/internal/service/sora_media_sign_test.go b/backend/internal/service/sora_media_sign_test.go new file mode 100644 index 00000000..2bbba987 --- /dev/null +++ b/backend/internal/service/sora_media_sign_test.go @@ -0,0 +1,34 @@ +package service + +import "testing" + +func TestSoraMediaSignVerify(t *testing.T) { + key := "test-key" + path := "/tmp/abc.png" + query := "a=1&b=2" + expires := int64(1700000000) + + signature := SignSoraMediaURL(path, query, expires, key) + if signature == "" { + t.Fatal("签名为空") + } + if !VerifySoraMediaURL(path, query, expires, signature, key) { + t.Fatal("签名校验失败") + } + if VerifySoraMediaURL(path, "a=1", expires, signature, key) { + t.Fatal("签名参数不同仍然通过") + } + if VerifySoraMediaURL(path, query, expires+1, signature, key) { + t.Fatal("签名过期校验未失败") + } +} + +func TestSoraMediaSignWithEmptyKey(t *testing.T) { + signature := SignSoraMediaURL("/tmp/a.png", "a=1", 1, "") + if signature != "" { + t.Fatalf("空密钥不应生成签名") + } + if VerifySoraMediaURL("/tmp/a.png", "a=1", 1, "sig", "") { + t.Fatalf("空密钥不应通过校验") + } +} diff --git a/backend/internal/service/token_cache_invalidator.go b/backend/internal/service/token_cache_invalidator.go index 74c9edc3..5c7ae8e9 100644 --- a/backend/internal/service/token_cache_invalidator.go +++ b/backend/internal/service/token_cache_invalidator.go @@ -42,7 +42,7 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac // Antigravity 同样可能有两种缓存键 keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account)) keysToDelete = append(keysToDelete, "ag:"+accountIDKey) - case PlatformOpenAI: + case PlatformOpenAI, PlatformSora: keysToDelete = append(keysToDelete, 
OpenAITokenCacheKey(account)) case PlatformAnthropic: keysToDelete = append(keysToDelete, ClaudeTokenCacheKey(account)) diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 797ab721..167d2b54 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -19,6 +19,7 @@ type TokenRefreshService struct { refreshers []TokenRefresher cfg *config.TokenRefreshConfig cacheInvalidator TokenCacheInvalidator + soraSyncService *Sora2APISyncService stopCh chan struct{} wg sync.WaitGroup @@ -65,6 +66,17 @@ func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) { } } +// SetSoraSyncService 设置 Sora2API 同步服务 +// 需要在 Start() 之前调用 +func (s *TokenRefreshService) SetSoraSyncService(svc *Sora2APISyncService) { + s.soraSyncService = svc + for _, refresher := range s.refreshers { + if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok { + openaiRefresher.SetSoraSyncService(svc) + } + } +} + // Start 启动后台刷新服务 func (s *TokenRefreshService) Start() { if !s.cfg.Enabled { diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 807524fd..9699092d 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -86,6 +86,7 @@ type OpenAITokenRefresher struct { openaiOAuthService *OpenAIOAuthService accountRepo AccountRepository soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 + soraSyncService *Sora2APISyncService // Sora2API 同步服务 } // NewOpenAITokenRefresher 创建 OpenAI token刷新器 @@ -103,17 +104,22 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) { r.soraAccountRepo = repo } +// SetSoraSyncService 设置 Sora2API 同步服务 +func (r *OpenAITokenRefresher) SetSoraSyncService(svc *Sora2APISyncService) { + r.soraSyncService = svc +} + // CanRefresh 检查是否能处理此账号 // 只处理 openai 平台的 oauth 类型账号 func (r 
*OpenAITokenRefresher) CanRefresh(account *Account) bool { - return account.Platform == PlatformOpenAI && + return (account.Platform == PlatformOpenAI || account.Platform == PlatformSora) && account.Type == AccountTypeOAuth } // NeedsRefresh 检查token是否需要刷新 // 基于 expires_at 字段判断是否在刷新窗口内 func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool { - expiresAt := account.GetOpenAITokenExpiresAt() + expiresAt := account.GetCredentialAsTime("expires_at") if expiresAt == nil { return false } @@ -145,6 +151,17 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials) } + // 如果是 Sora 平台账号,同步到 sora2api(不阻塞主流程) + if account.Platform == PlatformSora && r.soraSyncService != nil { + syncAccount := *account + syncAccount.Credentials = newCredentials + go func() { + if err := r.soraSyncService.SyncAccount(context.Background(), &syncAccount); err != nil { + log.Printf("[TokenSync] 同步 Sora2API 失败: account_id=%d err=%v", syncAccount.ID, err) + } + }() + } + return newCredentials, nil } @@ -201,6 +218,13 @@ func (r *OpenAITokenRefresher) syncLinkedSoraAccounts(ctx context.Context, opena } } + // 2.3 同步到 sora2api(如果配置) + if r.soraSyncService != nil { + if err := r.soraSyncService.SyncAccount(ctx, &soraAccount); err != nil { + log.Printf("[TokenSync] 同步 sora2api 失败: account_id=%d err=%v", soraAccount.ID, err) + } + } + log.Printf("[TokenSync] 成功同步 Sora 账号 token: sora_account_id=%d openai_account_id=%d dual_table=%v", soraAccount.ID, openaiAccountID, r.soraAccountRepo != nil) } diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index 3b0e934f..4be35501 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -46,6 +46,7 @@ type UsageLog struct { // 图片生成字段 ImageCount int ImageSize *string + MediaType *string CreatedAt time.Time diff --git 
a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 73d23025..689fa5d7 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -40,6 +40,7 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService { func ProvideTokenRefreshService( accountRepo AccountRepository, soraAccountRepo SoraAccountRepository, // Sora 扩展表仓储,用于双表同步 + soraSyncService *Sora2APISyncService, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, geminiOAuthService *GeminiOAuthService, @@ -50,6 +51,9 @@ func ProvideTokenRefreshService( svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg) // 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表 svc.SetSoraAccountRepo(soraAccountRepo) + if soraSyncService != nil { + svc.SetSoraSyncService(soraSyncService) + } svc.Start() return svc } @@ -224,6 +228,7 @@ var ProviderSet = wire.NewSet( NewBillingCacheService, NewAdminService, NewGatewayService, + NewSoraGatewayService, NewOpenAIGatewayService, NewOAuthService, NewOpenAIOAuthService, @@ -237,6 +242,8 @@ var ProviderSet = wire.NewSet( NewAntigravityTokenProvider, NewOpenAITokenProvider, NewClaudeTokenProvider, + NewSora2APIService, + NewSora2APISyncService, NewAntigravityGatewayService, ProvideRateLimitService, NewAccountUsageService, diff --git a/backend/migrations/047_add_sora_pricing_and_media_type.sql b/backend/migrations/047_add_sora_pricing_and_media_type.sql new file mode 100644 index 00000000..d70e37c5 --- /dev/null +++ b/backend/migrations/047_add_sora_pricing_and_media_type.sql @@ -0,0 +1,11 @@ +-- Migration: 047_add_sora_pricing_and_media_type +-- 新增 Sora 按次计费字段与 usage_logs.media_type + +ALTER TABLE groups + ADD COLUMN IF NOT EXISTS sora_image_price_360 decimal(20,8), + ADD COLUMN IF NOT EXISTS sora_image_price_540 decimal(20,8), + ADD COLUMN IF NOT EXISTS sora_video_price_per_request decimal(20,8), + ADD COLUMN 
IF NOT EXISTS sora_video_price_per_request_hd decimal(20,8); + +ALTER TABLE usage_logs + ADD COLUMN IF NOT EXISTS media_type VARCHAR(16); diff --git a/deploy/Caddyfile b/deploy/Caddyfile index d4144057..fce88654 100644 --- a/deploy/Caddyfile +++ b/deploy/Caddyfile @@ -1,39 +1,6 @@ -# ============================================================================= -# Sub2API Caddy Reverse Proxy Configuration (宿主机部署) -# ============================================================================= -# 使用方法: -# 1. 安装 Caddy: https://caddyserver.com/docs/install -# 2. 修改下方 example.com 为你的域名 -# 3. 确保域名 DNS 已指向服务器 -# 4. 复制配置: sudo cp Caddyfile /etc/caddy/Caddyfile -# 5. 重载配置: sudo systemctl reload caddy -# -# Caddy 会自动申请和续期 Let's Encrypt SSL 证书 -# ============================================================================= - -# 全局配置 -{ - # Let's Encrypt 邮箱通知 - email admin@example.com - - # 服务器配置 - servers { - # 启用 HTTP/2 和 HTTP/3 - protocols h1 h2 h3 - - # 超时配置 - timeouts { - read_body 30s - read_header 10s - write 300s - idle 300s - } - } -} - # 修改为你的域名 -example.com { - # ========================================================================= +api.sub2api.com { + # ========================================================================= # 静态资源长期缓存(高优先级,放在最前面) # 带 hash 的文件可以永久缓存,浏览器和 CDN 都会缓存 # ========================================================================= @@ -87,17 +54,13 @@ example.com { # 连接池优化 transport http { - versions h2c h1 keepalive 120s keepalive_idle_conns 256 read_buffer 16KB write_buffer 16KB compression off } - - # SSE/流式传输优化:禁用响应缓冲,立即刷新数据给客户端 - flush_interval -1 - + # 故障转移 fail_duration 30s max_fails 3 @@ -112,10 +75,6 @@ example.com { gzip 6 minimum_length 256 match { - # SSE 请求通常会带 Accept: text/event-stream,需排除压缩 - not header Accept text/event-stream* - # 排除已知 SSE 路径(即便 Accept 缺失) - not path /v1/messages /v1/responses /responses /antigravity/v1/messages /v1beta/models/* /antigravity/v1beta/models/* header Content-Type text/* header Content-Type 
application/json* header Content-Type application/javascript* @@ -199,7 +158,3 @@ example.com { respond "{err.status_code} {err.status_text}" } } - -# ============================================================================= -# HTTP 重定向到 HTTPS (Caddy 默认自动处理,此处显式声明) -# ============================================================================= diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 558b8ef0..99386fc9 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -116,6 +116,33 @@ gateway: # Max request body size in bytes (default: 100MB) # 请求体最大字节数(默认 100MB) max_body_size: 104857600 + # Sora max request body size in bytes (0=use max_body_size) + # Sora 请求体最大字节数(0=使用 max_body_size) + sora_max_body_size: 268435456 + # Sora stream timeout (seconds, 0=disable) + # Sora 流式请求总超时(秒,0=禁用) + sora_stream_timeout_seconds: 900 + # Sora non-stream timeout (seconds, 0=disable) + # Sora 非流式请求超时(秒,0=禁用) + sora_request_timeout_seconds: 180 + # Sora stream enforcement mode: force/error + # Sora stream 强制策略:force/error + sora_stream_mode: "force" + # Sora model filters + # Sora 模型过滤配置 + sora_model_filters: + # Hide prompt-enhance models by default + # 默认隐藏 prompt-enhance 模型 + hide_prompt_enhance: true + # Require API key for /sora/media proxy (default: true) + # /sora/media 是否强制要求 API Key(默认 true) + sora_media_require_api_key: true + # Sora media temporary signing key (empty disables signed URL) + # Sora 媒体临时签名密钥(为空则禁用签名) + sora_media_signing_key: "" + # Signed URL TTL seconds (<=0 disables) + # 临时签名 URL 有效期(秒,<=0 表示禁用) + sora_media_signed_url_ttl_seconds: 900 # Connection pool isolation strategy: # 连接池隔离策略: # - proxy: Isolate by proxy, same proxy shares connection pool (suitable for few proxies, many accounts) @@ -220,6 +247,31 @@ gateway: # name: "Custom Profile 1" # profile_2: # name: "Custom Profile 2" + +# ============================================================================= +# Sora2API Configuration +# Sora2API 配置
+# ============================================================================= +sora2api: + # Sora2API base URL + # Sora2API 服务地址 + base_url: "http://127.0.0.1:8000" + # Sora2API API Key (for /v1/chat/completions and /v1/models) + # Sora2API API Key(用于生成/模型列表) + api_key: "" + # Admin username/password (for token sync) + # 管理口用户名/密码(用于 token 同步) + admin_username: "admin" + admin_password: "admin" + # Admin token cache ttl (seconds) + # 管理口 token 缓存时长(秒) + admin_token_ttl_seconds: 900 + # Admin request timeout (seconds) + # 管理口请求超时(秒) + admin_timeout_seconds: 10 + # Token import mode: at/offline + # Token 导入模式:at/offline + token_import_mode: "at" # cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196] # curves: [29, 23, 24] # point_formats: [0] diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index e86f6348..505c1419 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -18,6 +18,7 @@ import geminiAPI from './gemini' import antigravityAPI from './antigravity' import userAttributesAPI from './userAttributes' import opsAPI from './ops' +import modelsAPI from './models' /** * Unified admin API object for convenient access @@ -37,7 +38,8 @@ export const adminAPI = { gemini: geminiAPI, antigravity: antigravityAPI, userAttributes: userAttributesAPI, - ops: opsAPI + ops: opsAPI, + models: modelsAPI } export { @@ -55,7 +57,8 @@ export { geminiAPI, antigravityAPI, userAttributesAPI, - opsAPI + opsAPI, + modelsAPI } export default adminAPI diff --git a/frontend/src/api/admin/models.ts b/frontend/src/api/admin/models.ts new file mode 100644 index 00000000..897304ac --- /dev/null +++ b/frontend/src/api/admin/models.ts @@ -0,0 +1,14 @@ +import { apiClient } from '@/api/client' + +export async function getPlatformModels(platform: string): Promise { + const { data } = await apiClient.get('/admin/models', { + params: { platform } + }) + return data +} + +export const modelsAPI = { + getPlatformModels +} + 
+export default modelsAPI diff --git a/frontend/src/components/account/ModelWhitelistSelector.vue b/frontend/src/components/account/ModelWhitelistSelector.vue index c8c1b852..227e6e61 100644 --- a/frontend/src/components/account/ModelWhitelistSelector.vue +++ b/frontend/src/components/account/ModelWhitelistSelector.vue @@ -45,6 +45,19 @@ :placeholder="t('admin.accounts.searchModels')" @click.stop /> +
+ + {{ t('admin.accounts.soraModelsLoading') }} + + +
+ +
+ +

+ {{ t('admin.groups.soraPricing.description') }} +

+
+
+ + +
+
+ + +
+
+
+
+ + +
+
+ + +
+
+
+
@@ -848,6 +906,64 @@
+ +
+ +

+ {{ t('admin.groups.soraPricing.description') }} +

+
+
+ + +
+
+ + +
+
+
+
+ + +
+
+ + +
+
+
+
@@ -1152,7 +1268,8 @@ const platformOptions = computed(() => [ { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, - { value: 'antigravity', label: 'Antigravity' } + { value: 'antigravity', label: 'Antigravity' }, + { value: 'sora', label: 'Sora' } ]) const platformFilterOptions = computed(() => [ @@ -1160,7 +1277,8 @@ const platformFilterOptions = computed(() => [ { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, - { value: 'antigravity', label: 'Antigravity' } + { value: 'antigravity', label: 'Antigravity' }, + { value: 'sora', label: 'Sora' } ]) const editStatusOptions = computed(() => [ @@ -1240,6 +1358,11 @@ const createForm = reactive({ image_price_1k: null as number | null, image_price_2k: null as number | null, image_price_4k: null as number | null, + // Sora 按次计费配置 + sora_image_price_360: null as number | null, + sora_image_price_540: null as number | null, + sora_video_price_per_request: null as number | null, + sora_video_price_per_request_hd: null as number | null, // Claude Code 客户端限制(仅 anthropic 平台使用) claude_code_only: false, fallback_group_id: null as number | null, @@ -1411,6 +1534,11 @@ const editForm = reactive({ image_price_1k: null as number | null, image_price_2k: null as number | null, image_price_4k: null as number | null, + // Sora 按次计费配置 + sora_image_price_360: null as number | null, + sora_image_price_540: null as number | null, + sora_video_price_per_request: null as number | null, + sora_video_price_per_request_hd: null as number | null, // Claude Code 客户端限制(仅 anthropic 平台使用) claude_code_only: false, fallback_group_id: null as number | null, @@ -1495,6 +1623,10 @@ const
closeCreateModal = () => { createForm.image_price_1k = null createForm.image_price_2k = null createForm.image_price_4k = null + createForm.sora_image_price_360 = null + createForm.sora_image_price_540 = null + createForm.sora_video_price_per_request = null + createForm.sora_video_price_per_request_hd = null createForm.claude_code_only = false createForm.fallback_group_id = null createModelRoutingRules.value = [] @@ -1544,6 +1681,10 @@ const handleEdit = async (group: AdminGroup) => { editForm.image_price_1k = group.image_price_1k editForm.image_price_2k = group.image_price_2k editForm.image_price_4k = group.image_price_4k + editForm.sora_image_price_360 = group.sora_image_price_360 + editForm.sora_image_price_540 = group.sora_image_price_540 + editForm.sora_video_price_per_request = group.sora_video_price_per_request + editForm.sora_video_price_per_request_hd = group.sora_video_price_per_request_hd editForm.claude_code_only = group.claude_code_only || false editForm.fallback_group_id = group.fallback_group_id editForm.model_routing_enabled = group.model_routing_enabled || false From 78d0ca3775da8fb2162dca8d7358508c2e9f3fdb Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sat, 31 Jan 2026 21:46:28 +0800 Subject: [PATCH 006/363] =?UTF-8?q?fix(sora):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E6=B5=81=E5=BC=8F=E9=87=8D=E5=86=99=E4=B8=8E=E8=AE=A1=E8=B4=B9?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/config/config.go | 15 +- .../internal/handler/admin/group_handler.go | 110 +++++++------- backend/internal/handler/dto/mappers.go | 42 +++--- backend/internal/handler/dto/types.go | 6 +- .../internal/handler/sora_gateway_handler.go | 12 +- backend/internal/repository/account_repo.go | 2 +- backend/internal/repository/api_key_repo.go | 50 +++---- .../internal/repository/sora_account_repo.go | 2 +- backend/internal/service/account_service.go | 2 +- 
.../internal/service/api_key_auth_cache.go | 34 ++--- .../service/api_key_auth_cache_impl.go | 78 +++++----- backend/internal/service/gateway_service.go | 4 +- backend/internal/service/group.go | 6 +- backend/internal/service/sora2api_service.go | 8 +- .../internal/service/sora_gateway_service.go | 134 +++++++++++++++++- backend/internal/service/sora_media_sign.go | 12 +- .../internal/service/token_refresh_service.go | 4 - backend/internal/service/token_refresher.go | 8 +- backend/internal/service/wire.go | 6 +- 19 files changed, 325 insertions(+), 210 deletions(-) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 5dd2b415..f3dec213 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -928,6 +928,7 @@ func setDefaults() { viper.SetDefault("sora2api.admin_token_ttl_seconds", 900) viper.SetDefault("sora2api.admin_timeout_seconds", 10) viper.SetDefault("sora2api.token_import_mode", "at") + } func (c *Config) Validate() error { @@ -1263,20 +1264,6 @@ func (c *Config) Validate() error { if err := ValidateAbsoluteHTTPURL(c.Sora2API.BaseURL); err != nil { return fmt.Errorf("sora2api.base_url invalid: %w", err) } - warnIfInsecureURL("sora2api.base_url", c.Sora2API.BaseURL) - } - if mode := strings.TrimSpace(strings.ToLower(c.Sora2API.TokenImportMode)); mode != "" { - switch mode { - case "at", "offline": - default: - return fmt.Errorf("sora2api.token_import_mode must be one of: at/offline") - } - } - if c.Sora2API.AdminTokenTTLSeconds < 0 { - return fmt.Errorf("sora2api.admin_token_ttl_seconds must be non-negative") - } - if c.Sora2API.AdminTimeoutSeconds < 0 { - return fmt.Errorf("sora2api.admin_timeout_seconds must be non-negative") } if c.Ops.MetricsCollectorCache.TTL < 0 { return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative") diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 1af570d9..328c8fce 100644 --- 
a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -35,15 +35,15 @@ type CreateGroupRequest struct { WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) - ImagePrice1K *float64 `json:"image_price_1k"` - ImagePrice2K *float64 `json:"image_price_2k"` - ImagePrice4K *float64 `json:"image_price_4k"` - SoraImagePrice360 *float64 `json:"sora_image_price_360"` - SoraImagePrice540 *float64 `json:"sora_image_price_540"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + ImagePrice1K *float64 `json:"image_price_1k"` + ImagePrice2K *float64 `json:"image_price_2k"` + ImagePrice4K *float64 `json:"image_price_4k"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` - ClaudeCodeOnly bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 `json:"model_routing"` ModelRoutingEnabled bool `json:"model_routing_enabled"` @@ -62,15 +62,15 @@ type UpdateGroupRequest struct { WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) - ImagePrice1K *float64 `json:"image_price_1k"` - ImagePrice2K *float64 `json:"image_price_2k"` - ImagePrice4K *float64 `json:"image_price_4k"` - SoraImagePrice360 *float64 `json:"sora_image_price_360"` - SoraImagePrice540 *float64 `json:"sora_image_price_540"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + ImagePrice1K *float64 `json:"image_price_1k"` + ImagePrice2K 
*float64 `json:"image_price_2k"` + ImagePrice4K *float64 `json:"image_price_4k"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` - ClaudeCodeOnly *bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id"` + ClaudeCodeOnly *bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 `json:"model_routing"` ModelRoutingEnabled *bool `json:"model_routing_enabled"` @@ -163,26 +163,26 @@ func (h *GroupHandler) Create(c *gin.Context) { } group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{ - Name: req.Name, - Description: req.Description, - Platform: req.Platform, - RateMultiplier: req.RateMultiplier, - IsExclusive: req.IsExclusive, - SubscriptionType: req.SubscriptionType, - DailyLimitUSD: req.DailyLimitUSD, - WeeklyLimitUSD: req.WeeklyLimitUSD, - MonthlyLimitUSD: req.MonthlyLimitUSD, - ImagePrice1K: req.ImagePrice1K, - ImagePrice2K: req.ImagePrice2K, - ImagePrice4K: req.ImagePrice4K, - SoraImagePrice360: req.SoraImagePrice360, - SoraImagePrice540: req.SoraImagePrice540, - SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, + Name: req.Name, + Description: req.Description, + Platform: req.Platform, + RateMultiplier: req.RateMultiplier, + IsExclusive: req.IsExclusive, + SubscriptionType: req.SubscriptionType, + DailyLimitUSD: req.DailyLimitUSD, + WeeklyLimitUSD: req.WeeklyLimitUSD, + MonthlyLimitUSD: req.MonthlyLimitUSD, + ImagePrice1K: req.ImagePrice1K, + ImagePrice2K: req.ImagePrice2K, + ImagePrice4K: req.ImagePrice4K, + SoraImagePrice360: req.SoraImagePrice360, + SoraImagePrice540: req.SoraImagePrice540, + SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, SoraVideoPricePerRequestHD: 
req.SoraVideoPricePerRequestHD, - ClaudeCodeOnly: req.ClaudeCodeOnly, - FallbackGroupID: req.FallbackGroupID, - ModelRouting: req.ModelRouting, - ModelRoutingEnabled: req.ModelRoutingEnabled, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: req.FallbackGroupID, + ModelRouting: req.ModelRouting, + ModelRoutingEnabled: req.ModelRoutingEnabled, }) if err != nil { response.ErrorFrom(c, err) @@ -208,27 +208,27 @@ func (h *GroupHandler) Update(c *gin.Context) { } group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{ - Name: req.Name, - Description: req.Description, - Platform: req.Platform, - RateMultiplier: req.RateMultiplier, - IsExclusive: req.IsExclusive, - Status: req.Status, - SubscriptionType: req.SubscriptionType, - DailyLimitUSD: req.DailyLimitUSD, - WeeklyLimitUSD: req.WeeklyLimitUSD, - MonthlyLimitUSD: req.MonthlyLimitUSD, - ImagePrice1K: req.ImagePrice1K, - ImagePrice2K: req.ImagePrice2K, - ImagePrice4K: req.ImagePrice4K, - SoraImagePrice360: req.SoraImagePrice360, - SoraImagePrice540: req.SoraImagePrice540, - SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, + Name: req.Name, + Description: req.Description, + Platform: req.Platform, + RateMultiplier: req.RateMultiplier, + IsExclusive: req.IsExclusive, + Status: req.Status, + SubscriptionType: req.SubscriptionType, + DailyLimitUSD: req.DailyLimitUSD, + WeeklyLimitUSD: req.WeeklyLimitUSD, + MonthlyLimitUSD: req.MonthlyLimitUSD, + ImagePrice1K: req.ImagePrice1K, + ImagePrice2K: req.ImagePrice2K, + ImagePrice4K: req.ImagePrice4K, + SoraImagePrice360: req.SoraImagePrice360, + SoraImagePrice540: req.SoraImagePrice540, + SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, - ClaudeCodeOnly: req.ClaudeCodeOnly, - FallbackGroupID: req.FallbackGroupID, - ModelRouting: req.ModelRouting, - ModelRoutingEnabled: req.ModelRoutingEnabled, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: 
req.FallbackGroupID, + ModelRouting: req.ModelRouting, + ModelRoutingEnabled: req.ModelRoutingEnabled, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index b44c3225..58a4ad86 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -122,28 +122,28 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { func groupFromServiceBase(g *service.Group) Group { return Group{ - ID: g.ID, - Name: g.Name, - Description: g.Description, - Platform: g.Platform, - RateMultiplier: g.RateMultiplier, - IsExclusive: g.IsExclusive, - Status: g.Status, - SubscriptionType: g.SubscriptionType, - DailyLimitUSD: g.DailyLimitUSD, - WeeklyLimitUSD: g.WeeklyLimitUSD, - MonthlyLimitUSD: g.MonthlyLimitUSD, - ImagePrice1K: g.ImagePrice1K, - ImagePrice2K: g.ImagePrice2K, - ImagePrice4K: g.ImagePrice4K, - SoraImagePrice360: g.SoraImagePrice360, - SoraImagePrice540: g.SoraImagePrice540, - SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, + ID: g.ID, + Name: g.Name, + Description: g.Description, + Platform: g.Platform, + RateMultiplier: g.RateMultiplier, + IsExclusive: g.IsExclusive, + Status: g.Status, + SubscriptionType: g.SubscriptionType, + DailyLimitUSD: g.DailyLimitUSD, + WeeklyLimitUSD: g.WeeklyLimitUSD, + MonthlyLimitUSD: g.MonthlyLimitUSD, + ImagePrice1K: g.ImagePrice1K, + ImagePrice2K: g.ImagePrice2K, + ImagePrice4K: g.ImagePrice4K, + SoraImagePrice360: g.SoraImagePrice360, + SoraImagePrice540: g.SoraImagePrice540, + SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD, - ClaudeCodeOnly: g.ClaudeCodeOnly, - FallbackGroupID: g.FallbackGroupID, - CreatedAt: g.CreatedAt, - UpdatedAt: g.UpdatedAt, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, + CreatedAt: g.CreatedAt, + UpdatedAt: g.UpdatedAt, } } diff --git a/backend/internal/handler/dto/types.go 
b/backend/internal/handler/dto/types.go index 3ae899ee..505f9dd4 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -62,9 +62,9 @@ type Group struct { ImagePrice4K *float64 `json:"image_price_4k"` // Sora 按次计费配置 - SoraImagePrice360 *float64 `json:"sora_image_price_360"` - SoraImagePrice540 *float64 `json:"sora_image_price_540"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` // Claude Code 客户端限制 diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 94f712df..05833144 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -33,6 +33,7 @@ type SoraGatewayHandler struct { streamMode string sora2apiBaseURL string soraMediaSigningKey string + mediaClient *http.Client } // NewSoraGatewayHandler creates a new SoraGatewayHandler @@ -61,6 +62,10 @@ func NewSoraGatewayHandler( if cfg != nil { baseURL = strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/") } + mediaTimeout := 180 * time.Second + if cfg != nil && cfg.Gateway.SoraRequestTimeoutSeconds > 0 { + mediaTimeout = time.Duration(cfg.Gateway.SoraRequestTimeoutSeconds) * time.Second + } return &SoraGatewayHandler{ gatewayService: gatewayService, soraGatewayService: soraGatewayService, @@ -70,6 +75,7 @@ func NewSoraGatewayHandler( streamMode: strings.ToLower(streamMode), sora2apiBaseURL: baseURL, soraMediaSigningKey: signKey, + mediaClient: &http.Client{Timeout: mediaTimeout}, } } @@ -457,7 +463,11 @@ func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature boo } } - resp, err := http.DefaultClient.Do(req) + client := 
h.mediaClient + if client == nil { + client = http.DefaultClient + } + resp, err := client.Do(req) if err != nil { c.Status(http.StatusBadGateway) return diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 5edc4f6d..170e5de9 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -1565,7 +1565,7 @@ func itoa(v int) string { // Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index). // // Use case: Finding Sora accounts linked via linked_openai_account_id. -func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value interface{}) ([]service.Account, error) { +func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) { accounts, err := r.client.Account.Query(). Where( dbaccount.PlatformEQ("sora"), // 限定平台为 sora diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 9308326b..a020ee2b 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -410,32 +410,32 @@ func groupEntityToService(g *dbent.Group) *service.Group { return nil } return &service.Group{ - ID: g.ID, - Name: g.Name, - Description: derefString(g.Description), - Platform: g.Platform, - RateMultiplier: g.RateMultiplier, - IsExclusive: g.IsExclusive, - Status: g.Status, - Hydrated: true, - SubscriptionType: g.SubscriptionType, - DailyLimitUSD: g.DailyLimitUsd, - WeeklyLimitUSD: g.WeeklyLimitUsd, - MonthlyLimitUSD: g.MonthlyLimitUsd, - ImagePrice1K: g.ImagePrice1k, - ImagePrice2K: g.ImagePrice2k, - ImagePrice4K: g.ImagePrice4k, - SoraImagePrice360: g.SoraImagePrice360, - SoraImagePrice540: g.SoraImagePrice540, - SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, + ID: g.ID, + Name: g.Name, + Description: derefString(g.Description), + Platform: g.Platform, + RateMultiplier: g.RateMultiplier, 
+ IsExclusive: g.IsExclusive, + Status: g.Status, + Hydrated: true, + SubscriptionType: g.SubscriptionType, + DailyLimitUSD: g.DailyLimitUsd, + WeeklyLimitUSD: g.WeeklyLimitUsd, + MonthlyLimitUSD: g.MonthlyLimitUsd, + ImagePrice1K: g.ImagePrice1k, + ImagePrice2K: g.ImagePrice2k, + ImagePrice4K: g.ImagePrice4k, + SoraImagePrice360: g.SoraImagePrice360, + SoraImagePrice540: g.SoraImagePrice540, + SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd, - DefaultValidityDays: g.DefaultValidityDays, - ClaudeCodeOnly: g.ClaudeCodeOnly, - FallbackGroupID: g.FallbackGroupID, - ModelRouting: g.ModelRouting, - ModelRoutingEnabled: g.ModelRoutingEnabled, - CreatedAt: g.CreatedAt, - UpdatedAt: g.UpdatedAt, + DefaultValidityDays: g.DefaultValidityDays, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + CreatedAt: g.CreatedAt, + UpdatedAt: g.UpdatedAt, } } diff --git a/backend/internal/repository/sora_account_repo.go b/backend/internal/repository/sora_account_repo.go index e0ec6073..ad2ae638 100644 --- a/backend/internal/repository/sora_account_repo.go +++ b/backend/internal/repository/sora_account_repo.go @@ -76,7 +76,7 @@ func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID in if err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() if !rows.Next() { return nil, nil // 记录不存在 diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 4befc996..a261fb21 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -27,7 +27,7 @@ type AccountRepository interface { GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) // FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora') // 用于查找通过 linked_openai_account_id 关联的 Sora 账号 - 
FindByExtraField(ctx context.Context, key string, value interface{}) ([]Account, error) + FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) Update(ctx context.Context, account *Account) error Delete(ctx context.Context, id int64) error diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 6247da00..9d8f87f2 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -23,24 +23,24 @@ type APIKeyAuthUserSnapshot struct { // APIKeyAuthGroupSnapshot 分组快照 type APIKeyAuthGroupSnapshot struct { - ID int64 `json:"id"` - Name string `json:"name"` - Platform string `json:"platform"` - Status string `json:"status"` - SubscriptionType string `json:"subscription_type"` - RateMultiplier float64 `json:"rate_multiplier"` - DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` - WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` - MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` - ImagePrice1K *float64 `json:"image_price_1k,omitempty"` - ImagePrice2K *float64 `json:"image_price_2k,omitempty"` - ImagePrice4K *float64 `json:"image_price_4k,omitempty"` - SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` - SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` + ID int64 `json:"id"` + Name string `json:"name"` + Platform string `json:"platform"` + Status string `json:"status"` + SubscriptionType string `json:"subscription_type"` + RateMultiplier float64 `json:"rate_multiplier"` + DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` + WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` + MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` + ImagePrice1K *float64 `json:"image_price_1k,omitempty"` + ImagePrice2K *float64 `json:"image_price_2k,omitempty"` + ImagePrice4K *float64 
`json:"image_price_4k,omitempty"` + SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` + SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"` - ClaudeCodeOnly bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` // Model routing is used by gateway account selection, so it must be part of auth cache snapshot. // Only anthropic groups use these fields; others may leave them empty. diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 5569a503..19ba4e79 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -223,26 +223,26 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { } if apiKey.Group != nil { snapshot.Group = &APIKeyAuthGroupSnapshot{ - ID: apiKey.Group.ID, - Name: apiKey.Group.Name, - Platform: apiKey.Group.Platform, - Status: apiKey.Group.Status, - SubscriptionType: apiKey.Group.SubscriptionType, - RateMultiplier: apiKey.Group.RateMultiplier, - DailyLimitUSD: apiKey.Group.DailyLimitUSD, - WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, - MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, - ImagePrice1K: apiKey.Group.ImagePrice1K, - ImagePrice2K: apiKey.Group.ImagePrice2K, - ImagePrice4K: apiKey.Group.ImagePrice4K, - SoraImagePrice360: apiKey.Group.SoraImagePrice360, - SoraImagePrice540: apiKey.Group.SoraImagePrice540, - SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, + ID: apiKey.Group.ID, + Name: apiKey.Group.Name, + Platform: apiKey.Group.Platform, + Status: apiKey.Group.Status, + SubscriptionType: apiKey.Group.SubscriptionType, + 
RateMultiplier: apiKey.Group.RateMultiplier, + DailyLimitUSD: apiKey.Group.DailyLimitUSD, + WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, + MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, + ImagePrice1K: apiKey.Group.ImagePrice1K, + ImagePrice2K: apiKey.Group.ImagePrice2K, + ImagePrice4K: apiKey.Group.ImagePrice4K, + SoraImagePrice360: apiKey.Group.SoraImagePrice360, + SoraImagePrice540: apiKey.Group.SoraImagePrice540, + SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, - ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, - FallbackGroupID: apiKey.Group.FallbackGroupID, - ModelRouting: apiKey.Group.ModelRouting, - ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, + ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, + FallbackGroupID: apiKey.Group.FallbackGroupID, + ModelRouting: apiKey.Group.ModelRouting, + ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, } } return snapshot @@ -270,27 +270,27 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho } if snapshot.Group != nil { apiKey.Group = &Group{ - ID: snapshot.Group.ID, - Name: snapshot.Group.Name, - Platform: snapshot.Group.Platform, - Status: snapshot.Group.Status, - Hydrated: true, - SubscriptionType: snapshot.Group.SubscriptionType, - RateMultiplier: snapshot.Group.RateMultiplier, - DailyLimitUSD: snapshot.Group.DailyLimitUSD, - WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, - MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, - ImagePrice1K: snapshot.Group.ImagePrice1K, - ImagePrice2K: snapshot.Group.ImagePrice2K, - ImagePrice4K: snapshot.Group.ImagePrice4K, - SoraImagePrice360: snapshot.Group.SoraImagePrice360, - SoraImagePrice540: snapshot.Group.SoraImagePrice540, - SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest, + ID: snapshot.Group.ID, + Name: snapshot.Group.Name, + Platform: snapshot.Group.Platform, + Status: snapshot.Group.Status, + Hydrated: true, + SubscriptionType: 
snapshot.Group.SubscriptionType, + RateMultiplier: snapshot.Group.RateMultiplier, + DailyLimitUSD: snapshot.Group.DailyLimitUSD, + WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, + MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, + ImagePrice1K: snapshot.Group.ImagePrice1K, + ImagePrice2K: snapshot.Group.ImagePrice2K, + ImagePrice4K: snapshot.Group.ImagePrice4K, + SoraImagePrice360: snapshot.Group.SoraImagePrice360, + SoraImagePrice540: snapshot.Group.SoraImagePrice540, + SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest, SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD, - ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, - FallbackGroupID: snapshot.Group.FallbackGroupID, - ModelRouting: snapshot.Group.ModelRouting, - ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, + ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, + FallbackGroupID: snapshot.Group.FallbackGroupID, + ModelRouting: snapshot.Group.ModelRouting, + ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, } } return apiKey diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index f0933ae3..6925801d 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -3465,7 +3465,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu var cost *CostBreakdown // 根据请求类型选择计费方式 - if result.MediaType == "image" || result.MediaType == "video" || result.MediaType == "prompt" { + if result.MediaType == "image" || result.MediaType == "video" { var soraConfig *SoraPriceConfig if apiKey.Group != nil { soraConfig = &SoraPriceConfig{ @@ -3480,6 +3480,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } else { cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier) } + } else if result.MediaType == "prompt" { + cost = &CostBreakdown{} } else if result.ImageCount > 0 { // 图片生成计费 var groupConfig 
*ImagePriceConfig diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index bc97e062..e8bf03d4 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -27,9 +27,9 @@ type Group struct { ImagePrice4K *float64 // Sora 按次计费配置(阶段 1) - SoraImagePrice360 *float64 - SoraImagePrice540 *float64 - SoraVideoPricePerRequest *float64 + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 SoraVideoPricePerRequestHD *float64 // Claude Code 客户端限制 diff --git a/backend/internal/service/sora2api_service.go b/backend/internal/service/sora2api_service.go index d4bf9ba4..c047cd40 100644 --- a/backend/internal/service/sora2api_service.go +++ b/backend/internal/service/sora2api_service.go @@ -62,7 +62,6 @@ type Sora2APIService struct { adminUsername string adminPassword string adminTokenTTL time.Duration - adminTimeout time.Duration tokenImportMode string client *http.Client @@ -72,9 +71,8 @@ type Sora2APIService struct { adminTokenAt time.Time adminMu sync.Mutex - modelCache []Sora2APIModel - modelCacheAt time.Time - modelMu sync.RWMutex + modelCache []Sora2APIModel + modelMu sync.RWMutex } func NewSora2APIService(cfg *config.Config) *Sora2APIService { @@ -96,7 +94,6 @@ func NewSora2APIService(cfg *config.Config) *Sora2APIService { adminUsername: strings.TrimSpace(cfg.Sora2API.AdminUsername), adminPassword: strings.TrimSpace(cfg.Sora2API.AdminPassword), adminTokenTTL: adminTTL, - adminTimeout: adminTimeout, tokenImportMode: strings.ToLower(strings.TrimSpace(cfg.Sora2API.TokenImportMode)), client: &http.Client{}, adminClient: &http.Client{Timeout: adminTimeout}, @@ -176,7 +173,6 @@ func (s *Sora2APIService) ListModels(ctx context.Context) ([]Sora2APIModel, erro s.modelMu.Lock() s.modelCache = models - s.modelCacheAt = time.Now() s.modelMu.Unlock() return models, nil diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index 
82f4eaaa..2909a76f 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -23,6 +23,8 @@ var soraSSEDataRe = regexp.MustCompile(`^data:\s*`) var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`) var soraVideoHTMLRe = regexp.MustCompile(`(?i)]+src=['"]([^'"]+)['"]`) +const soraRewriteBufferLimit = 2048 + var soraImageSizeMap = map[string]string{ "gpt-image": "360", "gpt-image-landscape": "540", @@ -30,7 +32,6 @@ var soraImageSizeMap = map[string]string{ } type soraStreamingResult struct { - content string mediaType string mediaURLs []string imageCount int @@ -307,6 +308,7 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp * contentBuilder := strings.Builder{} var firstTokenMs *int var upstreamError error + rewriteBuffer := "" scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -333,12 +335,29 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp * if soraSSEDataRe.MatchString(line) { data := soraSSEDataRe.ReplaceAllString(line, "") if data == "[DONE]" { + if rewriteBuffer != "" { + flushLine, flushContent, err := s.flushSoraRewriteBuffer(rewriteBuffer, originalModel) + if err != nil { + return nil, err + } + if flushLine != "" { + if flushContent != "" { + if _, err := contentBuilder.WriteString(flushContent); err != nil { + return nil, err + } + } + if err := sendLine(flushLine); err != nil { + return nil, err + } + } + rewriteBuffer = "" + } if err := sendLine("data: [DONE]"); err != nil { return nil, err } break } - updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel) + updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel, &rewriteBuffer) if errEvent != nil && upstreamError == nil { upstreamError = errEvent } @@ -347,7 +366,9 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp * ms := 
int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } - contentBuilder.WriteString(contentDelta) + if _, err := contentBuilder.WriteString(contentDelta); err != nil { + return nil, err + } } if err := sendLine(updatedLine); err != nil { return nil, err @@ -417,7 +438,6 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp * } return &soraStreamingResult{ - content: content, mediaType: mediaType, mediaURLs: mediaURLs, imageCount: imageCount, @@ -426,7 +446,7 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp * }, nil } -func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string) (string, string, error) { +func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string, rewriteBuffer *string) (string, string, error) { if strings.TrimSpace(data) == "" { return "data: ", "", nil } @@ -448,7 +468,12 @@ func (s *SoraGatewayService) processSoraSSEData(data string, originalModel strin contentDelta, updated := extractSoraContent(payload) if updated { - rewritten := s.rewriteSoraContent(contentDelta) + var rewritten string + if rewriteBuffer != nil { + rewritten = s.rewriteSoraContentWithBuffer(contentDelta, rewriteBuffer) + } else { + rewritten = s.rewriteSoraContent(contentDelta) + } if rewritten != contentDelta { applySoraContent(payload, rewritten) contentDelta = rewritten @@ -504,6 +529,78 @@ func applySoraContent(payload map[string]any, content string) { } } +func (s *SoraGatewayService) rewriteSoraContentWithBuffer(contentDelta string, buffer *string) string { + if buffer == nil { + return s.rewriteSoraContent(contentDelta) + } + if contentDelta == "" && *buffer == "" { + return "" + } + combined := *buffer + contentDelta + rewritten := s.rewriteSoraContent(combined) + bufferStart := s.findSoraRewriteBufferStart(rewritten) + if bufferStart < 0 { + *buffer = "" + return rewritten + } + if len(rewritten)-bufferStart > soraRewriteBufferLimit { + bufferStart = 
len(rewritten) - soraRewriteBufferLimit + } + output := rewritten[:bufferStart] + *buffer = rewritten[bufferStart:] + return output +} + +func (s *SoraGatewayService) findSoraRewriteBufferStart(content string) int { + minIndex := -1 + start := 0 + for { + idx := strings.Index(content[start:], "![") + if idx < 0 { + break + } + idx += start + if !hasSoraImageMatchAt(content, idx) { + if minIndex == -1 || idx < minIndex { + minIndex = idx + } + } + start = idx + 2 + } + lower := strings.ToLower(content) + start = 0 + for { + idx := strings.Index(lower[start:], "= len(content) { + return false + } + loc := soraImageMarkdownRe.FindStringIndex(content[idx:]) + return loc != nil && loc[0] == 0 +} + +func hasSoraVideoMatchAt(content string, idx int) bool { + if idx < 0 || idx >= len(content) { + return false + } + loc := soraVideoHTMLRe.FindStringIndex(content[idx:]) + return loc != nil && loc[0] == 0 +} + func (s *SoraGatewayService) rewriteSoraContent(content string) string { if content == "" { return content @@ -533,6 +630,31 @@ func (s *SoraGatewayService) rewriteSoraContent(content string) string { return content } +func (s *SoraGatewayService) flushSoraRewriteBuffer(buffer string, originalModel string) (string, string, error) { + if buffer == "" { + return "", "", nil + } + rewritten := s.rewriteSoraContent(buffer) + payload := map[string]any{ + "choices": []any{ + map[string]any{ + "delta": map[string]any{ + "content": rewritten, + }, + "index": 0, + }, + }, + } + if originalModel != "" { + payload["model"] = originalModel + } + updatedData, err := json.Marshal(payload) + if err != nil { + return "", "", err + } + return "data: " + string(updatedData), rewritten, nil +} + func (s *SoraGatewayService) rewriteSoraURL(raw string) string { raw = strings.TrimSpace(raw) if raw == "" { diff --git a/backend/internal/service/sora_media_sign.go b/backend/internal/service/sora_media_sign.go index 5d4a8d88..26bf8923 100644 --- a/backend/internal/service/sora_media_sign.go +++ 
b/backend/internal/service/sora_media_sign.go @@ -15,9 +15,15 @@ func SignSoraMediaURL(path string, query string, expires int64, key string) stri return "" } mac := hmac.New(sha256.New, []byte(key)) - mac.Write([]byte(buildSoraMediaSignPayload(path, query))) - mac.Write([]byte("|")) - mac.Write([]byte(strconv.FormatInt(expires, 10))) + if _, err := mac.Write([]byte(buildSoraMediaSignPayload(path, query))); err != nil { + return "" + } + if _, err := mac.Write([]byte("|")); err != nil { + return "" + } + if _, err := mac.Write([]byte(strconv.FormatInt(expires, 10))); err != nil { + return "" + } return hex.EncodeToString(mac.Sum(nil)) } diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 167d2b54..435056ab 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -15,11 +15,9 @@ import ( // 定期检查并刷新即将过期的token type TokenRefreshService struct { accountRepo AccountRepository - soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 refreshers []TokenRefresher cfg *config.TokenRefreshConfig cacheInvalidator TokenCacheInvalidator - soraSyncService *Sora2APISyncService stopCh chan struct{} wg sync.WaitGroup @@ -57,7 +55,6 @@ func NewTokenRefreshService( // 用于在 OpenAI Token 刷新时同步更新 sora_accounts 表 // 需要在 Start() 之前调用 func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) { - s.soraAccountRepo = repo // 将 soraAccountRepo 注入到 OpenAITokenRefresher for _, refresher := range s.refreshers { if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok { @@ -69,7 +66,6 @@ func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) { // SetSoraSyncService 设置 Sora2API 同步服务 // 需要在 Start() 之前调用 func (s *TokenRefreshService) SetSoraSyncService(svc *Sora2APISyncService) { - s.soraSyncService = svc for _, refresher := range s.refreshers { if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok { 
openaiRefresher.SetSoraSyncService(svc) diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 9699092d..7e084bd5 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -83,10 +83,10 @@ func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (m // OpenAITokenRefresher 处理 OpenAI OAuth token刷新 type OpenAITokenRefresher struct { - openaiOAuthService *OpenAIOAuthService - accountRepo AccountRepository - soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 - soraSyncService *Sora2APISyncService // Sora2API 同步服务 + openaiOAuthService *OpenAIOAuthService + accountRepo AccountRepository + soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 + soraSyncService *Sora2APISyncService // Sora2API 同步服务 } // NewOpenAITokenRefresher 创建 OpenAI token刷新器 diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 689fa5d7..fb0946d2 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -51,9 +51,7 @@ func ProvideTokenRefreshService( svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg) // 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表 svc.SetSoraAccountRepo(soraAccountRepo) - if soraSyncService != nil { - svc.SetSoraSyncService(soraSyncService) - } + svc.SetSoraSyncService(soraSyncService) svc.Start() return svc } @@ -242,8 +240,6 @@ var ProviderSet = wire.NewSet( NewAntigravityTokenProvider, NewOpenAITokenProvider, NewClaudeTokenProvider, - NewSora2APIService, - NewSora2APISyncService, NewAntigravityGatewayService, ProvideRateLimitService, NewAccountUsageService, From 399dd78b2ac62e6b008a16da5564ba423e7f8bbd Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sun, 1 Feb 2026 21:37:10 +0800 Subject: [PATCH 007/363] =?UTF-8?q?feat(Sora):=20=E7=9B=B4=E8=BF=9E?= 
=?UTF-8?q?=E7=94=9F=E6=88=90=E5=B9=B6=E7=A7=BB=E9=99=A4sora2api=E4=BE=9D?= =?UTF-8?q?=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 实现直连 Sora 客户端、媒体落地与清理策略\n更新网关与前端配置以支持 Sora 平台\n补齐单元测试与契约测试,新增 curl 测试脚本\n\n测试: go test ./... -tags=unit --- backend/cmd/server/wire.go | 7 + backend/cmd/server/wire_gen.go | 25 +- backend/internal/config/config.go | 116 ++- .../internal/handler/admin/model_handler.go | 55 -- .../handler/admin/model_handler_test.go | 87 -- backend/internal/handler/gateway_handler.go | 16 +- backend/internal/handler/handler.go | 1 - .../internal/handler/sora_gateway_handler.go | 83 +- .../handler/sora_gateway_handler_test.go | 441 +++++++++ backend/internal/handler/wire.go | 3 - backend/internal/server/api_contract_test.go | 9 + backend/internal/server/routes/admin.go | 7 - .../internal/service/account_test_service.go | 2 +- backend/internal/service/admin_service.go | 76 +- .../service/admin_service_bulk_update_test.go | 28 - backend/internal/service/sora2api_service.go | 351 ------- .../internal/service/sora2api_sync_service.go | 255 ----- backend/internal/service/sora_client.go | 884 ++++++++++++++++++ backend/internal/service/sora_client_test.go | 54 ++ .../internal/service/sora_gateway_service.go | 618 ++++++++++-- .../service/sora_gateway_service_test.go | 99 ++ .../service/sora_media_cleanup_service.go | 117 +++ .../sora_media_cleanup_service_test.go | 46 + .../internal/service/sora_media_storage.go | 256 +++++ .../service/sora_media_storage_test.go | 69 ++ backend/internal/service/sora_models.go | 252 +++++ .../internal/service/token_refresh_service.go | 10 - backend/internal/service/token_refresher.go | 24 - backend/internal/service/wire.go | 18 +- build_image.sh | 8 + deploy/Dockerfile | 111 +++ deploy/config.example.yaml | 82 +- frontend/src/api/admin/index.ts | 7 +- frontend/src/api/admin/models.ts | 14 - .../components/account/CreateAccountModal.vue | 6 +- 
.../account/ModelWhitelistSelector.vue | 54 +- .../account/OAuthAuthorizationFlow.vue | 11 +- frontend/src/composables/useModelWhitelist.ts | 2 +- frontend/src/views/admin/GroupsView.vue | 5 - 39 files changed, 3120 insertions(+), 1189 deletions(-) delete mode 100644 backend/internal/handler/admin/model_handler.go delete mode 100644 backend/internal/handler/admin/model_handler_test.go create mode 100644 backend/internal/handler/sora_gateway_handler_test.go delete mode 100644 backend/internal/service/sora2api_service.go delete mode 100644 backend/internal/service/sora2api_sync_service.go create mode 100644 backend/internal/service/sora_client.go create mode 100644 backend/internal/service/sora_client_test.go create mode 100644 backend/internal/service/sora_gateway_service_test.go create mode 100644 backend/internal/service/sora_media_cleanup_service.go create mode 100644 backend/internal/service/sora_media_cleanup_service_test.go create mode 100644 backend/internal/service/sora_media_storage.go create mode 100644 backend/internal/service/sora_media_storage_test.go create mode 100644 backend/internal/service/sora_models.go create mode 100755 build_image.sh create mode 100644 deploy/Dockerfile delete mode 100644 frontend/src/api/admin/models.ts diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 5ef04a66..1e9e440e 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -67,6 +67,7 @@ func provideCleanup( opsAlertEvaluator *service.OpsAlertEvaluatorService, opsCleanup *service.OpsCleanupService, opsScheduledReport *service.OpsScheduledReportService, + soraMediaCleanup *service.SoraMediaCleanupService, schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, accountExpiry *service.AccountExpiryService, @@ -100,6 +101,12 @@ func provideCleanup( } return nil }}, + {"SoraMediaCleanupService", func() error { + if soraMediaCleanup != nil { + soraMediaCleanup.Stop() + } + return nil + }}, 
{"OpsAlertEvaluatorService", func() error { if opsAlertEvaluator != nil { opsAlertEvaluator.Stop() diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 1d88b612..dd0eb0d9 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -87,12 +87,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { schedulerCache := repository.NewSchedulerCache(redisClient) accountRepository := repository.NewAccountRepository(client, db, schedulerCache) soraAccountRepository := repository.NewSoraAccountRepository(db) - sora2APIService := service.NewSora2APIService(configConfig) - sora2APISyncService := service.NewSora2APISyncService(sora2APIService, accountRepository) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, sora2APISyncService, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) adminUserHandler := admin.NewUserHandler(adminService) groupHandler := admin.NewGroupHandler(adminService) claudeOAuthClient := repository.NewClaudeOAuthClient() @@ -164,11 +162,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeHandler := 
admin.NewUserAttributeHandler(userAttributeService) - modelHandler := admin.NewModelHandler(sora2APIService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, modelHandler) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, sora2APIService, concurrencyService, billingCacheService, configConfig) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig) - soraGatewayService := service.NewSoraGatewayService(sora2APIService, httpUpstream, rateLimitService, configConfig) + soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider) + soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) + soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig) soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlers := 
handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler) @@ -182,9 +181,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, sora2APISyncService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig) + soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, 
pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) application := &Application{ Server: httpServer, Cleanup: v, @@ -214,6 +214,7 @@ func provideCleanup( opsAlertEvaluator *service.OpsAlertEvaluatorService, opsCleanup *service.OpsCleanupService, opsScheduledReport *service.OpsScheduledReportService, + soraMediaCleanup *service.SoraMediaCleanupService, schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, accountExpiry *service.AccountExpiryService, @@ -246,6 +247,12 @@ func provideCleanup( } return nil }}, + {"SoraMediaCleanupService", func() error { + if soraMediaCleanup != nil { + soraMediaCleanup.Stop() + } + return nil + }}, {"OpsAlertEvaluatorService", func() error { if opsAlertEvaluator != nil { opsAlertEvaluator.Stop() diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index f3dec213..147cc3e9 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -58,7 +58,7 @@ type Config struct { UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` Concurrency ConcurrencyConfig `mapstructure:"concurrency"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - Sora2API Sora2APIConfig `mapstructure:"sora2api"` + Sora SoraConfig `mapstructure:"sora"` RunMode string `mapstructure:"run_mode" yaml:"run_mode"` Timezone string `mapstructure:"timezone"` // e.g. 
"Asia/Shanghai", "UTC" Gemini GeminiConfig `mapstructure:"gemini"` @@ -205,22 +205,40 @@ type ConcurrencyConfig struct { PingInterval int `mapstructure:"ping_interval"` } -// Sora2APIConfig Sora2API 服务配置 -type Sora2APIConfig struct { - // BaseURL Sora2API 服务地址(例如 http://localhost:8000) - BaseURL string `mapstructure:"base_url"` - // APIKey Sora2API OpenAI 兼容接口的 API Key - APIKey string `mapstructure:"api_key"` - // AdminUsername 管理员用户名(用于 token 同步) - AdminUsername string `mapstructure:"admin_username"` - // AdminPassword 管理员密码(用于 token 同步) - AdminPassword string `mapstructure:"admin_password"` - // AdminTokenTTLSeconds 管理员 Token 缓存时长(秒) - AdminTokenTTLSeconds int `mapstructure:"admin_token_ttl_seconds"` - // AdminTimeoutSeconds 管理接口请求超时(秒) - AdminTimeoutSeconds int `mapstructure:"admin_timeout_seconds"` - // TokenImportMode token 导入模式:at/offline - TokenImportMode string `mapstructure:"token_import_mode"` +// SoraConfig 直连 Sora 配置 +type SoraConfig struct { + Client SoraClientConfig `mapstructure:"client"` + Storage SoraStorageConfig `mapstructure:"storage"` +} + +// SoraClientConfig 直连 Sora 客户端配置 +type SoraClientConfig struct { + BaseURL string `mapstructure:"base_url"` + TimeoutSeconds int `mapstructure:"timeout_seconds"` + MaxRetries int `mapstructure:"max_retries"` + PollIntervalSeconds int `mapstructure:"poll_interval_seconds"` + MaxPollAttempts int `mapstructure:"max_poll_attempts"` + Debug bool `mapstructure:"debug"` + Headers map[string]string `mapstructure:"headers"` + UserAgent string `mapstructure:"user_agent"` + DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"` +} + +// SoraStorageConfig 媒体存储配置 +type SoraStorageConfig struct { + Type string `mapstructure:"type"` + LocalPath string `mapstructure:"local_path"` + FallbackToUpstream bool `mapstructure:"fallback_to_upstream"` + MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"` + Debug bool `mapstructure:"debug"` + Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"` +} 
+ +// SoraStorageCleanupConfig 媒体清理配置 +type SoraStorageCleanupConfig struct { + Enabled bool `mapstructure:"enabled"` + Schedule string `mapstructure:"schedule"` + RetentionDays int `mapstructure:"retention_days"` } // GatewayConfig API网关相关配置 @@ -905,6 +923,26 @@ func setDefaults() { viper.SetDefault("gateway.tls_fingerprint.enabled", true) viper.SetDefault("concurrency.ping_interval", 10) + // Sora 直连配置 + viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend") + viper.SetDefault("sora.client.timeout_seconds", 120) + viper.SetDefault("sora.client.max_retries", 3) + viper.SetDefault("sora.client.poll_interval_seconds", 2) + viper.SetDefault("sora.client.max_poll_attempts", 600) + viper.SetDefault("sora.client.debug", false) + viper.SetDefault("sora.client.headers", map[string]string{}) + viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + viper.SetDefault("sora.client.disable_tls_fingerprint", false) + + viper.SetDefault("sora.storage.type", "local") + viper.SetDefault("sora.storage.local_path", "") + viper.SetDefault("sora.storage.fallback_to_upstream", true) + viper.SetDefault("sora.storage.max_concurrent_downloads", 4) + viper.SetDefault("sora.storage.debug", false) + viper.SetDefault("sora.storage.cleanup.enabled", true) + viper.SetDefault("sora.storage.cleanup.retention_days", 7) + viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *") + // TokenRefresh viper.SetDefault("token_refresh.enabled", true) viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次 @@ -920,15 +958,6 @@ func setDefaults() { viper.SetDefault("gemini.oauth.scopes", "") viper.SetDefault("gemini.quota.policy", "") - // Sora2API - viper.SetDefault("sora2api.base_url", "") - viper.SetDefault("sora2api.api_key", "") - viper.SetDefault("sora2api.admin_username", "") - viper.SetDefault("sora2api.admin_password", "") - viper.SetDefault("sora2api.admin_token_ttl_seconds", 900) - 
viper.SetDefault("sora2api.admin_timeout_seconds", 10) - viper.SetDefault("sora2api.token_import_mode", "at") - } func (c *Config) Validate() error { @@ -1164,6 +1193,36 @@ func (c *Config) Validate() error { return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error") } } + if c.Sora.Client.TimeoutSeconds < 0 { + return fmt.Errorf("sora.client.timeout_seconds must be non-negative") + } + if c.Sora.Client.MaxRetries < 0 { + return fmt.Errorf("sora.client.max_retries must be non-negative") + } + if c.Sora.Client.PollIntervalSeconds < 0 { + return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative") + } + if c.Sora.Client.MaxPollAttempts < 0 { + return fmt.Errorf("sora.client.max_poll_attempts must be non-negative") + } + if c.Sora.Storage.MaxConcurrentDownloads < 0 { + return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative") + } + if c.Sora.Storage.Cleanup.Enabled { + if c.Sora.Storage.Cleanup.RetentionDays <= 0 { + return fmt.Errorf("sora.storage.cleanup.retention_days must be positive") + } + if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" { + return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled") + } + } else { + if c.Sora.Storage.Cleanup.RetentionDays < 0 { + return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative") + } + } + if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" { + return fmt.Errorf("sora.storage.type must be 'local'") + } if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" { switch c.Gateway.ConnectionPoolIsolation { case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy: @@ -1260,11 +1319,6 @@ func (c *Config) Validate() error { c.Gateway.Scheduling.OutboxLagRebuildSeconds < c.Gateway.Scheduling.OutboxLagWarnSeconds { return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be >= 
outbox_lag_warn_seconds") } - if strings.TrimSpace(c.Sora2API.BaseURL) != "" { - if err := ValidateAbsoluteHTTPURL(c.Sora2API.BaseURL); err != nil { - return fmt.Errorf("sora2api.base_url invalid: %w", err) - } - } if c.Ops.MetricsCollectorCache.TTL < 0 { return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative") } diff --git a/backend/internal/handler/admin/model_handler.go b/backend/internal/handler/admin/model_handler.go deleted file mode 100644 index 035b09bd..00000000 --- a/backend/internal/handler/admin/model_handler.go +++ /dev/null @@ -1,55 +0,0 @@ -package admin - -import ( - "net/http" - "strings" - - "github.com/Wei-Shaw/sub2api/internal/pkg/response" - "github.com/Wei-Shaw/sub2api/internal/service" - - "github.com/gin-gonic/gin" -) - -// ModelHandler handles admin model listing requests. -type ModelHandler struct { - sora2apiService *service.Sora2APIService -} - -// NewModelHandler creates a new ModelHandler. -func NewModelHandler(sora2apiService *service.Sora2APIService) *ModelHandler { - return &ModelHandler{ - sora2apiService: sora2apiService, - } -} - -// List handles listing models for a specific platform -// GET /api/v1/admin/models?platform=sora -func (h *ModelHandler) List(c *gin.Context) { - platform := strings.TrimSpace(strings.ToLower(c.Query("platform"))) - if platform == "" { - response.BadRequest(c, "platform is required") - return - } - - switch platform { - case service.PlatformSora: - if h.sora2apiService == nil || !h.sora2apiService.Enabled() { - response.Error(c, http.StatusServiceUnavailable, "sora2api not configured") - return - } - models, err := h.sora2apiService.ListModels(c.Request.Context()) - if err != nil { - response.Error(c, http.StatusServiceUnavailable, "failed to fetch sora models") - return - } - ids := make([]string, 0, len(models)) - for _, m := range models { - if strings.TrimSpace(m.ID) != "" { - ids = append(ids, m.ID) - } - } - response.Success(c, ids) - default: - response.BadRequest(c, 
"unsupported platform") - } -} diff --git a/backend/internal/handler/admin/model_handler_test.go b/backend/internal/handler/admin/model_handler_test.go deleted file mode 100644 index e61dc064..00000000 --- a/backend/internal/handler/admin/model_handler_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package admin - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/response" - "github.com/Wei-Shaw/sub2api/internal/service" - - "github.com/gin-gonic/gin" -) - -func TestModelHandlerListSoraSuccess(t *testing.T) { - gin.SetMode(gin.TestMode) - - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"object":"list","data":[{"id":"m1"},{"id":"m2"}]}`)) - })) - t.Cleanup(upstream.Close) - - cfg := &config.Config{} - cfg.Sora2API.BaseURL = upstream.URL - cfg.Sora2API.APIKey = "test-key" - soraService := service.NewSora2APIService(cfg) - - h := NewModelHandler(soraService) - router := gin.New() - router.GET("/admin/models", h.List) - - req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=sora", nil) - recorder := httptest.NewRecorder() - router.ServeHTTP(recorder, req) - - if recorder.Code != http.StatusOK { - t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String()) - } - var resp response.Response - if err := json.Unmarshal(recorder.Body.Bytes(), &resp); err != nil { - t.Fatalf("解析响应失败: %v", err) - } - if resp.Code != 0 { - t.Fatalf("响应 code=%d", resp.Code) - } - data, ok := resp.Data.([]any) - if !ok { - t.Fatalf("响应 data 类型错误") - } - if len(data) != 2 { - t.Fatalf("模型数量不符: %d", len(data)) - } -} - -func TestModelHandlerListSoraNotConfigured(t *testing.T) { - gin.SetMode(gin.TestMode) - - h := NewModelHandler(&service.Sora2APIService{}) - router := gin.New() - router.GET("/admin/models", h.List) - - req := 
httptest.NewRequest(http.MethodGet, "/admin/models?platform=sora", nil) - recorder := httptest.NewRecorder() - router.ServeHTTP(recorder, req) - - if recorder.Code != http.StatusServiceUnavailable { - t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String()) - } -} - -func TestModelHandlerListInvalidPlatform(t *testing.T) { - gin.SetMode(gin.TestMode) - - h := NewModelHandler(&service.Sora2APIService{}) - router := gin.New() - router.GET("/admin/models", h.List) - - req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=unknown", nil) - recorder := httptest.NewRecorder() - router.ServeHTTP(recorder, req) - - if recorder.Code != http.StatusBadRequest { - t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String()) - } -} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 983cc6b3..a7b98940 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -29,11 +29,11 @@ type GatewayHandler struct { geminiCompatService *service.GeminiMessagesCompatService antigravityGatewayService *service.AntigravityGatewayService userService *service.UserService - sora2apiService *service.Sora2APIService billingCacheService *service.BillingCacheService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int maxAccountSwitchesGemini int + cfg *config.Config } // NewGatewayHandler creates a new GatewayHandler @@ -42,7 +42,6 @@ func NewGatewayHandler( geminiCompatService *service.GeminiMessagesCompatService, antigravityGatewayService *service.AntigravityGatewayService, userService *service.UserService, - sora2apiService *service.Sora2APIService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, cfg *config.Config, @@ -64,11 +63,11 @@ func NewGatewayHandler( geminiCompatService: geminiCompatService, antigravityGatewayService: antigravityGatewayService, userService: userService, - sora2apiService: 
sora2apiService, billingCacheService: billingCacheService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), maxAccountSwitches: maxAccountSwitches, maxAccountSwitchesGemini: maxAccountSwitchesGemini, + cfg: cfg, } } @@ -486,18 +485,9 @@ func (h *GatewayHandler) Models(c *gin.Context) { } if platform == service.PlatformSora { - if h.sora2apiService == nil || !h.sora2apiService.Enabled() { - h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "sora2api not configured") - return - } - models, err := h.sora2apiService.ListModels(c.Request.Context()) - if err != nil { - h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Failed to fetch Sora models") - return - } c.JSON(http.StatusOK, gin.H{ "object": "list", - "data": models, + "data": service.DefaultSoraModels(h.cfg), }) return } diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 7905148c..d7014a22 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -23,7 +23,6 @@ type AdminHandlers struct { Subscription *admin.SubscriptionHandler Usage *admin.UsageHandler UserAttribute *admin.UserAttributeHandler - Model *admin.ModelHandler } // Handlers contains all HTTP handlers diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 05833144..faed3b33 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -10,7 +10,9 @@ import ( "io" "log" "net/http" + "os" "path" + "path/filepath" "strconv" "strings" "time" @@ -31,9 +33,8 @@ type SoraGatewayHandler struct { concurrencyHelper *ConcurrencyHelper maxAccountSwitches int streamMode string - sora2apiBaseURL string soraMediaSigningKey string - mediaClient *http.Client + soraMediaRoot string } // NewSoraGatewayHandler creates a new SoraGatewayHandler @@ -48,6 +49,7 @@ func NewSoraGatewayHandler( 
maxAccountSwitches := 3 streamMode := "force" signKey := "" + mediaRoot := "/app/data/sora" if cfg != nil { pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second if cfg.Gateway.MaxAccountSwitches > 0 { @@ -57,14 +59,9 @@ func NewSoraGatewayHandler( streamMode = mode } signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey) - } - baseURL := "" - if cfg != nil { - baseURL = strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/") - } - mediaTimeout := 180 * time.Second - if cfg != nil && cfg.Gateway.SoraRequestTimeoutSeconds > 0 { - mediaTimeout = time.Duration(cfg.Gateway.SoraRequestTimeoutSeconds) * time.Second + if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" { + mediaRoot = root + } } return &SoraGatewayHandler{ gatewayService: gatewayService, @@ -73,9 +70,8 @@ func NewSoraGatewayHandler( concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), maxAccountSwitches: maxAccountSwitches, streamMode: strings.ToLower(streamMode), - sora2apiBaseURL: baseURL, soraMediaSigningKey: signKey, - mediaClient: &http.Client{Timeout: mediaTimeout}, + soraMediaRoot: mediaRoot, } } @@ -377,34 +373,24 @@ func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, }) } -// MediaProxy proxies /tmp or /static media files from sora2api +// MediaProxy serves local Sora media files. func (h *SoraGatewayHandler) MediaProxy(c *gin.Context) { h.proxySoraMedia(c, false) } -// MediaProxySigned proxies /tmp or /static media files with signature verification +// MediaProxySigned serves local Sora media files with signature verification. 
func (h *SoraGatewayHandler) MediaProxySigned(c *gin.Context) { h.proxySoraMedia(c, true) } func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature bool) { - if h.sora2apiBaseURL == "" { - c.JSON(http.StatusServiceUnavailable, gin.H{ - "error": gin.H{ - "type": "api_error", - "message": "sora2api 未配置", - }, - }) - return - } - rawPath := c.Param("filepath") if rawPath == "" { c.Status(http.StatusNotFound) return } cleaned := path.Clean(rawPath) - if !strings.HasPrefix(cleaned, "/tmp/") && !strings.HasPrefix(cleaned, "/static/") { + if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") { c.Status(http.StatusNotFound) return } @@ -445,40 +431,25 @@ func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature boo return } } - - targetURL := h.sora2apiBaseURL + cleaned - if rawQuery := query.Encode(); rawQuery != "" { - targetURL += "?" + rawQuery - } - - req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL, nil) - if err != nil { - c.Status(http.StatusBadGateway) + if strings.TrimSpace(h.soraMediaRoot) == "" { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "Sora 媒体目录未配置", + }, + }) return } - copyHeaders := []string{"Range", "If-Range", "If-Modified-Since", "If-None-Match", "Accept", "User-Agent"} - for _, key := range copyHeaders { - if val := c.GetHeader(key); val != "" { - req.Header.Set(key, val) - } - } - client := h.mediaClient - if client == nil { - client = http.DefaultClient - } - resp, err := client.Do(req) - if err != nil { - c.Status(http.StatusBadGateway) + relative := strings.TrimPrefix(cleaned, "/") + localPath := filepath.Join(h.soraMediaRoot, filepath.FromSlash(relative)) + if _, err := os.Stat(localPath); err != nil { + if os.IsNotExist(err) { + c.Status(http.StatusNotFound) + return + } + c.Status(http.StatusInternalServerError) return } - defer func() { _ = resp.Body.Close() }() - - for _, key 
:= range []string{"Content-Type", "Content-Length", "Accept-Ranges", "Content-Range", "Cache-Control", "Last-Modified", "ETag"} { - if val := resp.Header.Get(key); val != "" { - c.Header(key, val) - } - } - c.Status(resp.StatusCode) - _, _ = io.Copy(c.Writer, resp.Body) + c.File(localPath) } diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go new file mode 100644 index 00000000..91881dec --- /dev/null +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -0,0 +1,441 @@ +//go:build unit + +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type stubSoraClient struct { + imageURLs []string +} + +func (s *stubSoraClient) Enabled() bool { return true } +func (s *stubSoraClient) UploadImage(ctx context.Context, account *service.Account, data []byte, filename string) (string, error) { + return "upload", nil +} +func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.Account, req service.SoraImageRequest) (string, error) { + return "task-image", nil +} +func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) { + return "task-video", nil +} +func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) { + return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil +} +func (s *stubSoraClient) GetVideoTask(ctx context.Context, account *service.Account, taskID string) 
(*service.SoraVideoTaskStatus, error) { + return &service.SoraVideoTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil +} + +type stubConcurrencyCache struct{} + +func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil +} +func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { + return nil +} +func (c stubConcurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} +func (c stubConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { + return true, nil +} +func (c stubConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { + return nil +} +func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} +func (c stubConcurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil +} +func (c stubConcurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error { + return nil +} +func (c stubConcurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { + return 0, nil +} +func (c stubConcurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { + return true, nil +} +func (c stubConcurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error { + return nil +} +func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { + result := make(map[int64]*service.AccountLoadInfo, len(accounts)) + for _, acc := range accounts { + result[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID, LoadRate: 0} + } + return result, nil 
+} +func (c stubConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { + return nil +} + +type stubAccountRepo struct { + accounts map[int64]*service.Account +} + +func (r *stubAccountRepo) Create(ctx context.Context, account *service.Account) error { return nil } +func (r *stubAccountRepo) GetByID(ctx context.Context, id int64) (*service.Account, error) { + if acc, ok := r.accounts[id]; ok { + return acc, nil + } + return nil, service.ErrAccountNotFound +} +func (r *stubAccountRepo) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) { + var result []*service.Account + for _, id := range ids { + if acc, ok := r.accounts[id]; ok { + result = append(result, acc) + } + } + return result, nil +} +func (r *stubAccountRepo) ExistsByID(ctx context.Context, id int64) (bool, error) { + _, ok := r.accounts[id] + return ok, nil +} +func (r *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) { + return nil, nil +} +func (r *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) { + return nil, nil +} +func (r *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return nil } +func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error { return nil } +func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { + return nil, nil +} +func (r *stubAccountRepo) ListActive(ctx context.Context) ([]service.Account, error) { return nil, nil } +func (r 
*stubAccountRepo) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return r.listSchedulableByPlatform(platform), nil +} +func (r *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error { return nil } +func (r *stubAccountRepo) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error { + return nil +} +func (r *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { return nil } +func (r *stubAccountRepo) ClearError(ctx context.Context, id int64) error { return nil } +func (r *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { + return nil +} +func (r *stubAccountRepo) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) { + return 0, nil +} +func (r *stubAccountRepo) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { + return nil +} +func (r *stubAccountRepo) ListSchedulable(ctx context.Context) ([]service.Account, error) { + return r.listSchedulable(), nil +} +func (r *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) { + return r.listSchedulable(), nil +} +func (r *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return r.listSchedulableByPlatform(platform), nil +} +func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) { + return r.listSchedulableByPlatform(platform), nil +} +func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + var result []service.Account + for _, acc := range r.accounts { + for _, platform := range platforms { + if acc.Platform == platform && acc.IsSchedulable() { + result = append(result, *acc) + break + } + } + } + return result, nil +} +func (r *stubAccountRepo) 
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { + return r.ListSchedulableByPlatforms(ctx, platforms) +} +func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { + return nil +} +func (r *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error { + return nil +} +func (r *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { + return nil +} +func (r *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error { + return nil +} +func (r *stubAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + return nil +} +func (r *stubAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error { return nil } +func (r *stubAccountRepo) ClearRateLimit(ctx context.Context, id int64) error { return nil } +func (r *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error { + return nil +} +func (r *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error { return nil } +func (r *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { + return nil +} +func (r *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { + return nil +} +func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { + return 0, nil +} + +func (r *stubAccountRepo) listSchedulable() []service.Account { + var result []service.Account + for _, acc := range r.accounts { + if acc.IsSchedulable() { + result = append(result, *acc) + } + } + return result +} + +func (r *stubAccountRepo) listSchedulableByPlatform(platform string) []service.Account { + var result []service.Account + for 
_, acc := range r.accounts { + if acc.Platform == platform && acc.IsSchedulable() { + result = append(result, *acc) + } + } + return result +} + +type stubGroupRepo struct { + group *service.Group +} + +func (r *stubGroupRepo) Create(ctx context.Context, group *service.Group) error { return nil } +func (r *stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, error) { + return r.group, nil +} +func (r *stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) { + return r.group, nil +} +func (r *stubGroupRepo) Update(ctx context.Context, group *service.Group) error { return nil } +func (r *stubGroupRepo) Delete(ctx context.Context, id int64) error { return nil } +func (r *stubGroupRepo) DeleteCascade(ctx context.Context, id int64) ([]int64, error) { + return nil, nil +} +func (r *stubGroupRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) { return nil, nil } +func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) { + return nil, nil +} +func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) { + return false, nil +} +func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { + return 0, nil +} +func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { + return 0, nil +} + +type stubUsageLogRepo struct{} + +func (s *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) (bool, error) { + return true, nil +} +func (s *stubUsageLogRepo) 
GetByID(ctx context.Context, id int64) (*service.UsageLog, error) { + return nil, nil +} +func (s *stubUsageLogRepo) Delete(ctx context.Context, id int64) error { return nil } +func (s *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubUsageLogRepo) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubUsageLogRepo) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubUsageLogRepo) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubUsageLogRepo) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) 
(*usagestats.DashboardStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) { + return nil, nil +} +func (s *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) 
{ + return nil, nil, nil +} +func (s *stubUsageLogRepo) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) { + return nil, nil +} + +func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + RunMode: config.RunModeSimple, + Gateway: config.GatewayConfig{ + SoraStreamMode: "force", + MaxAccountSwitches: 1, + Scheduling: config.GatewaySchedulingConfig{ + LoadBatchEnabled: false, + }, + }, + Concurrency: config.ConcurrencyConfig{PingInterval: 0}, + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.test", + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + + account := &service.Account{ID: 1, Platform: service.PlatformSora, Status: 
service.StatusActive, Schedulable: true, Concurrency: 1, Priority: 1} + accountRepo := &stubAccountRepo{accounts: map[int64]*service.Account{account.ID: account}} + group := &service.Group{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Hydrated: true} + groupRepo := &stubGroupRepo{group: group} + + usageLogRepo := &stubUsageLogRepo{} + deferredService := service.NewDeferredService(accountRepo, nil, 0) + billingService := service.NewBillingService(cfg, nil) + concurrencyService := service.NewConcurrencyService(stubConcurrencyCache{}) + billingCacheService := service.NewBillingCacheService(nil, nil, nil, cfg) + t.Cleanup(func() { + billingCacheService.Stop() + }) + + gatewayService := service.NewGatewayService( + accountRepo, + groupRepo, + usageLogRepo, + nil, + nil, + nil, + cfg, + nil, + concurrencyService, + billingService, + nil, + billingCacheService, + nil, + nil, + deferredService, + nil, + nil, + ) + + soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}} + soraGatewayService := service.NewSoraGatewayService(soraClient, nil, nil, cfg) + + handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, cfg) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := `{"model":"gpt-image","messages":[{"role":"user","content":"hello"}]}` + c.Request = httptest.NewRequest(http.MethodPost, "/sora/v1/chat/completions", strings.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + apiKey := &service.APIKey{ + ID: 1, + UserID: 1, + Status: service.StatusActive, + GroupID: &group.ID, + User: &service.User{ID: 1, Concurrency: 1, Status: service.StatusActive}, + Group: group, + } + c.Set(string(middleware.ContextKeyAPIKey), apiKey) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: apiKey.User.Concurrency}) + + handler.ChatCompletions(c) + + require.Equal(t, http.StatusOK, 
rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.NotEmpty(t, resp["media_url"]) +} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 1e3ef17d..c20b7fbc 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -26,7 +26,6 @@ func ProvideAdminHandlers( subscriptionHandler *admin.SubscriptionHandler, usageHandler *admin.UsageHandler, userAttributeHandler *admin.UserAttributeHandler, - modelHandler *admin.ModelHandler, ) *AdminHandlers { return &AdminHandlers{ Dashboard: dashboardHandler, @@ -46,7 +45,6 @@ func ProvideAdminHandlers( Subscription: subscriptionHandler, Usage: usageHandler, UserAttribute: userAttributeHandler, - Model: modelHandler, } } @@ -121,7 +119,6 @@ var ProviderSet = wire.NewSet( admin.NewSubscriptionHandler, admin.NewUsageHandler, admin.NewUserAttributeHandler, - admin.NewModelHandler, // AdminHandlers and Handlers constructors ProvideAdminHandlers, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index f3eebd41..409a7625 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -178,6 +178,10 @@ func TestAPIContracts(t *testing.T) { "image_price_1k": null, "image_price_2k": null, "image_price_4k": null, + "sora_image_price_360": null, + "sora_image_price_540": null, + "sora_video_price_per_request": null, + "sora_video_price_per_request_hd": null, "claude_code_only": false, "fallback_group_id": null, "created_at": "2025-01-02T03:04:05Z", @@ -394,6 +398,7 @@ func TestAPIContracts(t *testing.T) { "first_token_ms": 50, "image_count": 0, "image_size": null, + "media_type": null, "created_at": "2025-01-02T03:04:05Z", "user_agent": null } @@ -887,6 +892,10 @@ func (s *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID st return nil, errors.New("not implemented") } +func (s *stubAccountRepo) 
FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return errors.New("not implemented") } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 2c1762d3..050e724d 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -64,9 +64,6 @@ func RegisterAdminRoutes( // 用户属性管理 registerUserAttributeRoutes(admin, h) - - // 模型列表 - registerModelRoutes(admin, h) } } @@ -374,7 +371,3 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) { attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition) } } - -func registerModelRoutes(admin *gin.RouterGroup, h *handler.Handlers) { - admin.GET("/models", h.Admin.Model.List) -} diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index f80a2af8..a76c4d20 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -491,7 +491,7 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * return s.sendErrorAndEnd(c, "Failed to create request") } - // 使用 Sora 客户端标准请求头(参考 sora2api) + // 使用 Sora 客户端标准请求头 req.Header.Set("Authorization", "Bearer "+authToken) req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") req.Header.Set("Accept", "application/json") diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index a29bf4db..94b18322 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -283,7 +283,6 @@ type adminServiceImpl struct { groupRepo GroupRepository accountRepo AccountRepository soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储 - soraSyncService *Sora2APISyncService // 
Sora2API 同步服务 proxyRepo ProxyRepository apiKeyRepo APIKeyRepository redeemCodeRepo RedeemCodeRepository @@ -299,7 +298,6 @@ func NewAdminService( groupRepo GroupRepository, accountRepo AccountRepository, soraAccountRepo SoraAccountRepository, - soraSyncService *Sora2APISyncService, proxyRepo ProxyRepository, apiKeyRepo APIKeyRepository, redeemCodeRepo RedeemCodeRepository, @@ -313,7 +311,6 @@ func NewAdminService( groupRepo: groupRepo, accountRepo: accountRepo, soraAccountRepo: soraAccountRepo, - soraSyncService: soraSyncService, proxyRepo: proxyRepo, apiKeyRepo: apiKeyRepo, redeemCodeRepo: redeemCodeRepo, @@ -917,9 +914,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou } } - // 同步到 sora2api(异步,不阻塞创建) - s.syncSoraAccountAsync(account) - return account, nil } @@ -1014,7 +1008,6 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U if err != nil { return nil, err } - s.syncSoraAccountAsync(updated) return updated, nil } @@ -1032,17 +1025,15 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp } needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck - needSoraSync := s != nil && s.soraSyncService != nil // 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。 platformByID := map[int64]string{} - if needMixedChannelCheck || needSoraSync { + if needMixedChannelCheck { accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs) if err != nil { if needMixedChannelCheck { return nil, err } - log.Printf("[AdminService] 预加载账号平台信息失败,将逐个降级同步: err=%v", err) } else { for _, account := range accounts { if account != nil { @@ -1134,45 +1125,15 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp result.Success++ result.SuccessIDs = append(result.SuccessIDs, accountID) result.Results = append(result.Results, entry) - - // 批量更新后同步 sora2api - if needSoraSync { - platform := platformByID[accountID] - if platform == "" { - updated, err := 
s.accountRepo.GetByID(ctx, accountID) - if err != nil { - log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err) - continue - } - if updated.Platform == PlatformSora { - s.syncSoraAccountAsync(updated) - } - continue - } - - if platform == PlatformSora { - updated, err := s.accountRepo.GetByID(ctx, accountID) - if err != nil { - log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err) - continue - } - s.syncSoraAccountAsync(updated) - } - } } return result, nil } func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { - account, err := s.accountRepo.GetByID(ctx, id) - if err != nil { - return err - } if err := s.accountRepo.Delete(ctx, id); err != nil { return err } - s.deleteSoraAccountAsync(account) return nil } @@ -1210,44 +1171,9 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, if err != nil { return nil, err } - s.syncSoraAccountAsync(updated) return updated, nil } -func (s *adminServiceImpl) syncSoraAccountAsync(account *Account) { - if s == nil || s.soraSyncService == nil || account == nil { - return - } - if account.Platform != PlatformSora { - return - } - syncAccount := *account - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - if err := s.soraSyncService.SyncAccount(ctx, &syncAccount); err != nil { - log.Printf("[AdminService] 同步 sora2api 失败: account_id=%d err=%v", syncAccount.ID, err) - } - }() -} - -func (s *adminServiceImpl) deleteSoraAccountAsync(account *Account) { - if s == nil || s.soraSyncService == nil || account == nil { - return - } - if account.Platform != PlatformSora { - return - } - syncAccount := *account - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - if err := s.soraSyncService.DeleteAccount(ctx, &syncAccount); err != nil { - log.Printf("[AdminService] 删除 sora2api token 失败: account_id=%d 
err=%v", syncAccount.ID, err) - } - }() -} - // Proxy management implementations func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go index cbdbe625..0dccacbb 100644 --- a/backend/internal/service/admin_service_bulk_update_test.go +++ b/backend/internal/service/admin_service_bulk_update_test.go @@ -105,31 +105,3 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) { require.ElementsMatch(t, []int64{2}, result.FailedIDs) require.Len(t, result.Results, 3) } - -// TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs 验证无分组更新时仍会触发 Sora 同步。 -func TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs(t *testing.T) { - repo := &accountRepoStubForBulkUpdate{ - getByIDsAccounts: []*Account{ - {ID: 1, Platform: PlatformSora}, - }, - getByIDAccounts: map[int64]*Account{ - 1: {ID: 1, Platform: PlatformSora}, - }, - } - svc := &adminServiceImpl{ - accountRepo: repo, - soraSyncService: &Sora2APISyncService{}, - } - - schedulable := true - input := &BulkUpdateAccountsInput{ - AccountIDs: []int64{1}, - Schedulable: &schedulable, - } - - result, err := svc.BulkUpdateAccounts(context.Background(), input) - require.NoError(t, err) - require.Equal(t, 1, result.Success) - require.True(t, repo.getByIDsCalled) - require.ElementsMatch(t, []int64{1}, repo.getByIDCalled) -} diff --git a/backend/internal/service/sora2api_service.go b/backend/internal/service/sora2api_service.go deleted file mode 100644 index c047cd40..00000000 --- a/backend/internal/service/sora2api_service.go +++ /dev/null @@ -1,351 +0,0 @@ -package service - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "log" - "net/http" - "strings" - "sync" - "time" - - 
"github.com/Wei-Shaw/sub2api/internal/config" -) - -// Sora2APIModel represents a model entry returned by sora2api. -type Sora2APIModel struct { - ID string `json:"id"` - Object string `json:"object"` - OwnedBy string `json:"owned_by,omitempty"` - Description string `json:"description,omitempty"` -} - -// Sora2APIModelList represents /v1/models response. -type Sora2APIModelList struct { - Object string `json:"object"` - Data []Sora2APIModel `json:"data"` -} - -// Sora2APIImportTokenItem mirrors sora2api ImportTokenItem. -type Sora2APIImportTokenItem struct { - Email string `json:"email"` - AccessToken string `json:"access_token,omitempty"` - SessionToken string `json:"session_token,omitempty"` - RefreshToken string `json:"refresh_token,omitempty"` - ClientID string `json:"client_id,omitempty"` - ProxyURL string `json:"proxy_url,omitempty"` - Remark string `json:"remark,omitempty"` - IsActive bool `json:"is_active"` - ImageEnabled bool `json:"image_enabled"` - VideoEnabled bool `json:"video_enabled"` - ImageConcurrency int `json:"image_concurrency"` - VideoConcurrency int `json:"video_concurrency"` -} - -// Sora2APIToken represents minimal fields for admin list. -type Sora2APIToken struct { - ID int64 `json:"id"` - Email string `json:"email"` - Name string `json:"name"` - Remark string `json:"remark"` -} - -// Sora2APIService provides access to sora2api endpoints. 
-type Sora2APIService struct { - cfg *config.Config - - baseURL string - apiKey string - adminUsername string - adminPassword string - adminTokenTTL time.Duration - tokenImportMode string - - client *http.Client - adminClient *http.Client - - adminToken string - adminTokenAt time.Time - adminMu sync.Mutex - - modelCache []Sora2APIModel - modelMu sync.RWMutex -} - -func NewSora2APIService(cfg *config.Config) *Sora2APIService { - if cfg == nil { - return &Sora2APIService{} - } - adminTTL := time.Duration(cfg.Sora2API.AdminTokenTTLSeconds) * time.Second - if adminTTL <= 0 { - adminTTL = 15 * time.Minute - } - adminTimeout := time.Duration(cfg.Sora2API.AdminTimeoutSeconds) * time.Second - if adminTimeout <= 0 { - adminTimeout = 10 * time.Second - } - return &Sora2APIService{ - cfg: cfg, - baseURL: strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/"), - apiKey: strings.TrimSpace(cfg.Sora2API.APIKey), - adminUsername: strings.TrimSpace(cfg.Sora2API.AdminUsername), - adminPassword: strings.TrimSpace(cfg.Sora2API.AdminPassword), - adminTokenTTL: adminTTL, - tokenImportMode: strings.ToLower(strings.TrimSpace(cfg.Sora2API.TokenImportMode)), - client: &http.Client{}, - adminClient: &http.Client{Timeout: adminTimeout}, - } -} - -func (s *Sora2APIService) Enabled() bool { - return s != nil && s.baseURL != "" && s.apiKey != "" -} - -func (s *Sora2APIService) AdminEnabled() bool { - return s != nil && s.baseURL != "" && s.adminUsername != "" && s.adminPassword != "" -} - -func (s *Sora2APIService) buildURL(path string) string { - if s.baseURL == "" { - return path - } - if strings.HasPrefix(path, "/") { - return s.baseURL + path - } - return s.baseURL + "/" + path -} - -// BuildURL 返回完整的 sora2api URL(用于代理媒体) -func (s *Sora2APIService) BuildURL(path string) string { - return s.buildURL(path) -} - -func (s *Sora2APIService) NewAPIRequest(ctx context.Context, method string, path string, body []byte) (*http.Request, error) { - if !s.Enabled() { - return nil, 
errors.New("sora2api not configured") - } - req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), bytes.NewReader(body)) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", "Bearer "+s.apiKey) - req.Header.Set("Content-Type", "application/json") - return req, nil -} - -func (s *Sora2APIService) ListModels(ctx context.Context) ([]Sora2APIModel, error) { - if !s.Enabled() { - return nil, errors.New("sora2api not configured") - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.buildURL("/v1/models"), nil) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", "Bearer "+s.apiKey) - resp, err := s.client.Do(req) - if err != nil { - return s.cachedModelsOnError(err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - return s.cachedModelsOnError(fmt.Errorf("sora2api models status: %d", resp.StatusCode)) - } - - var payload Sora2APIModelList - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { - return s.cachedModelsOnError(err) - } - models := payload.Data - if s.cfg != nil && s.cfg.Gateway.SoraModelFilters.HidePromptEnhance { - filtered := make([]Sora2APIModel, 0, len(models)) - for _, m := range models { - if strings.HasPrefix(strings.ToLower(m.ID), "prompt-enhance") { - continue - } - filtered = append(filtered, m) - } - models = filtered - } - - s.modelMu.Lock() - s.modelCache = models - s.modelMu.Unlock() - - return models, nil -} - -func (s *Sora2APIService) cachedModelsOnError(err error) ([]Sora2APIModel, error) { - s.modelMu.RLock() - cached := append([]Sora2APIModel(nil), s.modelCache...) 
- s.modelMu.RUnlock() - if len(cached) > 0 { - log.Printf("[Sora2API] 模型列表拉取失败,回退缓存: %v", err) - return cached, nil - } - return nil, err -} - -func (s *Sora2APIService) ImportTokens(ctx context.Context, items []Sora2APIImportTokenItem) error { - if !s.AdminEnabled() { - return errors.New("sora2api admin not configured") - } - mode := s.tokenImportMode - if mode == "" { - mode = "at" - } - payload := map[string]any{ - "tokens": items, - "mode": mode, - } - _, err := s.doAdminRequest(ctx, http.MethodPost, "/api/tokens/import", payload, nil) - return err -} - -func (s *Sora2APIService) ListTokens(ctx context.Context) ([]Sora2APIToken, error) { - if !s.AdminEnabled() { - return nil, errors.New("sora2api admin not configured") - } - var tokens []Sora2APIToken - _, err := s.doAdminRequest(ctx, http.MethodGet, "/api/tokens", nil, &tokens) - return tokens, err -} - -func (s *Sora2APIService) DisableToken(ctx context.Context, tokenID int64) error { - if !s.AdminEnabled() { - return errors.New("sora2api admin not configured") - } - path := fmt.Sprintf("/api/tokens/%d/disable", tokenID) - _, err := s.doAdminRequest(ctx, http.MethodPost, path, nil, nil) - return err -} - -func (s *Sora2APIService) DeleteToken(ctx context.Context, tokenID int64) error { - if !s.AdminEnabled() { - return errors.New("sora2api admin not configured") - } - path := fmt.Sprintf("/api/tokens/%d", tokenID) - _, err := s.doAdminRequest(ctx, http.MethodDelete, path, nil, nil) - return err -} - -func (s *Sora2APIService) doAdminRequest(ctx context.Context, method string, path string, body any, out any) (*http.Response, error) { - if !s.AdminEnabled() { - return nil, errors.New("sora2api admin not configured") - } - token, err := s.getAdminToken(ctx) - if err != nil { - return nil, err - } - resp, err := s.doAdminRequestWithToken(ctx, method, path, token, body, out) - if err == nil && resp != nil && resp.StatusCode != http.StatusUnauthorized { - return resp, nil - } - if resp != nil && resp.StatusCode == 
http.StatusUnauthorized { - s.invalidateAdminToken() - token, err = s.getAdminToken(ctx) - if err != nil { - return resp, err - } - return s.doAdminRequestWithToken(ctx, method, path, token, body, out) - } - return resp, err -} - -func (s *Sora2APIService) doAdminRequestWithToken(ctx context.Context, method string, path string, token string, body any, out any) (*http.Response, error) { - var reader *bytes.Reader - if body != nil { - buf, err := json.Marshal(body) - if err != nil { - return nil, err - } - reader = bytes.NewReader(buf) - } else { - reader = bytes.NewReader(nil) - } - req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), reader) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", "Bearer "+token) - if body != nil { - req.Header.Set("Content-Type", "application/json") - } - resp, err := s.adminClient.Do(req) - if err != nil { - return resp, err - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return resp, fmt.Errorf("sora2api admin status: %d", resp.StatusCode) - } - if out != nil { - if err := json.NewDecoder(resp.Body).Decode(out); err != nil { - return resp, err - } - } - return resp, nil -} - -func (s *Sora2APIService) getAdminToken(ctx context.Context) (string, error) { - s.adminMu.Lock() - defer s.adminMu.Unlock() - - if s.adminToken != "" && time.Since(s.adminTokenAt) < s.adminTokenTTL { - return s.adminToken, nil - } - - if !s.AdminEnabled() { - return "", errors.New("sora2api admin not configured") - } - - payload := map[string]string{ - "username": s.adminUsername, - "password": s.adminPassword, - } - buf, err := json.Marshal(payload) - if err != nil { - return "", err - } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.buildURL("/api/login"), bytes.NewReader(buf)) - if err != nil { - return "", err - } - req.Header.Set("Content-Type", "application/json") - resp, err := s.adminClient.Do(req) - if err != nil { - return "", err - } - 
defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("sora2api login failed: %d", resp.StatusCode) - } - var result struct { - Success bool `json:"success"` - Token string `json:"token"` - Message string `json:"message"` - } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return "", err - } - if !result.Success || result.Token == "" { - if result.Message == "" { - result.Message = "sora2api login failed" - } - return "", errors.New(result.Message) - } - s.adminToken = result.Token - s.adminTokenAt = time.Now() - return result.Token, nil -} - -func (s *Sora2APIService) invalidateAdminToken() { - s.adminMu.Lock() - defer s.adminMu.Unlock() - s.adminToken = "" - s.adminTokenAt = time.Time{} -} diff --git a/backend/internal/service/sora2api_sync_service.go b/backend/internal/service/sora2api_sync_service.go deleted file mode 100644 index 33978432..00000000 --- a/backend/internal/service/sora2api_sync_service.go +++ /dev/null @@ -1,255 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "log" - "net/http" - "strings" - "time" - - "github.com/golang-jwt/jwt/v5" -) - -// Sora2APISyncService 用于同步 Sora 账号到 sora2api token 池 -type Sora2APISyncService struct { - sora2api *Sora2APIService - accountRepo AccountRepository - httpClient *http.Client -} - -func NewSora2APISyncService(sora2api *Sora2APIService, accountRepo AccountRepository) *Sora2APISyncService { - return &Sora2APISyncService{ - sora2api: sora2api, - accountRepo: accountRepo, - httpClient: &http.Client{Timeout: 10 * time.Second}, - } -} - -func (s *Sora2APISyncService) Enabled() bool { - return s != nil && s.sora2api != nil && s.sora2api.AdminEnabled() -} - -// SyncAccount 将 Sora 账号同步到 sora2api(导入或更新) -func (s *Sora2APISyncService) SyncAccount(ctx context.Context, account *Account) error { - if !s.Enabled() { - return nil - } - if account == nil || account.Platform != PlatformSora { - return nil - } - - 
accessToken := strings.TrimSpace(account.GetCredential("access_token")) - if accessToken == "" { - return errors.New("sora 账号缺少 access_token") - } - - email, updated := s.resolveAccountEmail(ctx, account) - if email == "" { - return errors.New("无法解析 Sora 账号邮箱") - } - if updated && s.accountRepo != nil { - if err := s.accountRepo.Update(ctx, account); err != nil { - log.Printf("[SoraSync] 更新账号邮箱失败: account_id=%d err=%v", account.ID, err) - } - } - - item := Sora2APIImportTokenItem{ - Email: email, - AccessToken: accessToken, - SessionToken: strings.TrimSpace(account.GetCredential("session_token")), - RefreshToken: strings.TrimSpace(account.GetCredential("refresh_token")), - ClientID: strings.TrimSpace(account.GetCredential("client_id")), - Remark: account.Name, - IsActive: account.IsActive() && account.Schedulable, - ImageEnabled: true, - VideoEnabled: true, - ImageConcurrency: normalizeSoraConcurrency(account.Concurrency), - VideoConcurrency: normalizeSoraConcurrency(account.Concurrency), - } - - if err := s.sora2api.ImportTokens(ctx, []Sora2APIImportTokenItem{item}); err != nil { - return err - } - return nil -} - -// DisableAccount 禁用 sora2api 中的 token -func (s *Sora2APISyncService) DisableAccount(ctx context.Context, account *Account) error { - if !s.Enabled() { - return nil - } - if account == nil || account.Platform != PlatformSora { - return nil - } - tokenID, err := s.resolveTokenID(ctx, account) - if err != nil { - return err - } - return s.sora2api.DisableToken(ctx, tokenID) -} - -// DeleteAccount 删除 sora2api 中的 token -func (s *Sora2APISyncService) DeleteAccount(ctx context.Context, account *Account) error { - if !s.Enabled() { - return nil - } - if account == nil || account.Platform != PlatformSora { - return nil - } - tokenID, err := s.resolveTokenID(ctx, account) - if err != nil { - return err - } - return s.sora2api.DeleteToken(ctx, tokenID) -} - -func normalizeSoraConcurrency(value int) int { - if value <= 0 { - return -1 - } - return value -} - -func 
(s *Sora2APISyncService) resolveAccountEmail(ctx context.Context, account *Account) (string, bool) { - if account == nil { - return "", false - } - if email := strings.TrimSpace(account.GetCredential("email")); email != "" { - return email, false - } - if email := strings.TrimSpace(account.GetExtraString("email")); email != "" { - if account.Credentials == nil { - account.Credentials = map[string]any{} - } - account.Credentials["email"] = email - return email, true - } - if email := strings.TrimSpace(account.GetExtraString("sora_email")); email != "" { - if account.Credentials == nil { - account.Credentials = map[string]any{} - } - account.Credentials["email"] = email - return email, true - } - - accessToken := strings.TrimSpace(account.GetCredential("access_token")) - if accessToken != "" { - if email := extractEmailFromAccessToken(accessToken); email != "" { - if account.Credentials == nil { - account.Credentials = map[string]any{} - } - account.Credentials["email"] = email - return email, true - } - if email := s.fetchEmailFromSora(ctx, accessToken); email != "" { - if account.Credentials == nil { - account.Credentials = map[string]any{} - } - account.Credentials["email"] = email - return email, true - } - } - - return "", false -} - -func (s *Sora2APISyncService) resolveTokenID(ctx context.Context, account *Account) (int64, error) { - if account == nil { - return 0, errors.New("account is nil") - } - - if account.Extra != nil { - if v, ok := account.Extra["sora2api_token_id"]; ok { - if id, ok := v.(float64); ok && id > 0 { - return int64(id), nil - } - if id, ok := v.(int64); ok && id > 0 { - return id, nil - } - if id, ok := v.(int); ok && id > 0 { - return int64(id), nil - } - } - } - - email := strings.TrimSpace(account.GetCredential("email")) - if email == "" { - email, _ = s.resolveAccountEmail(ctx, account) - } - if email == "" { - return 0, errors.New("sora2api token email missing") - } - - tokenID, err := s.findTokenIDByEmail(ctx, email) - if err != 
nil { - return 0, err - } - return tokenID, nil -} - -func (s *Sora2APISyncService) findTokenIDByEmail(ctx context.Context, email string) (int64, error) { - if !s.Enabled() { - return 0, errors.New("sora2api admin not configured") - } - tokens, err := s.sora2api.ListTokens(ctx) - if err != nil { - return 0, err - } - for _, token := range tokens { - if strings.EqualFold(strings.TrimSpace(token.Email), strings.TrimSpace(email)) { - return token.ID, nil - } - } - return 0, fmt.Errorf("sora2api token not found for email: %s", email) -} - -func extractEmailFromAccessToken(accessToken string) string { - parser := jwt.NewParser(jwt.WithoutClaimsValidation()) - claims := jwt.MapClaims{} - _, _, err := parser.ParseUnverified(accessToken, claims) - if err != nil { - return "" - } - if email, ok := claims["email"].(string); ok && strings.TrimSpace(email) != "" { - return email - } - if profile, ok := claims["https://api.openai.com/profile"].(map[string]any); ok { - if email, ok := profile["email"].(string); ok && strings.TrimSpace(email) != "" { - return email - } - } - return "" -} - -func (s *Sora2APISyncService) fetchEmailFromSora(ctx context.Context, accessToken string) string { - if s.httpClient == nil { - return "" - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, soraMeAPIURL, nil) - if err != nil { - return "" - } - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") - req.Header.Set("Accept", "application/json") - - resp, err := s.httpClient.Do(req) - if err != nil { - return "" - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - return "" - } - var payload map[string]any - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { - return "" - } - if email, ok := payload["email"].(string); ok && strings.TrimSpace(email) != "" { - return email - } - return "" -} diff --git 
a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go new file mode 100644 index 00000000..9ecb4688 --- /dev/null +++ b/backend/internal/service/sora_client.go @@ -0,0 +1,884 @@ +package service + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "math/rand" + "mime" + "mime/multipart" + "net/http" + "net/textproto" + "net/url" + "path" + "strconv" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/google/uuid" + "golang.org/x/crypto/sha3" +) + +const ( + soraChatGPTBaseURL = "https://chatgpt.com" + soraSentinelFlow = "sora_2_create_task" + soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)" +) + +const ( + soraPowMaxIteration = 500000 +) + +var soraPowCores = []int{8, 16, 24, 32} + +var soraPowScripts = []string{ + "https://cdn.oaistatic.com/_next/static/cXh69klOLzS0Gy2joLDRS/_ssgManifest.js?dpl=453ebaec0d44c2decab71692e1bfe39be35a24b3", +} + +var soraPowDPL = []string{ + "prod-f501fe933b3edf57aea882da888e1a544df99840", +} + +var soraPowNavigatorKeys = []string{ + "registerProtocolHandler−function registerProtocolHandler() { [native code] }", + "storage−[object StorageManager]", + "locks−[object LockManager]", + "appCodeName−Mozilla", + "permissions−[object Permissions]", + "webdriver−false", + "vendor−Google Inc.", + "mediaDevices−[object MediaDevices]", + "cookieEnabled−true", + "product−Gecko", + "productSub−20030107", + "hardwareConcurrency−32", + "onLine−true", +} + +var soraPowDocumentKeys = []string{ + "_reactListeningo743lnnpvdg", + "location", +} + +var soraPowWindowKeys = []string{ + "0", "window", "self", "document", "name", "location", + "navigator", "screen", "innerWidth", "innerHeight", + "localStorage", "sessionStorage", "crypto", "performance", + "fetch", "setTimeout", "setInterval", "console", +} + +var soraDesktopUserAgents = []string{ + "Mozilla/5.0 (Windows NT 10.0; Win64; 
x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36", + "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", + "Mozilla/5.0 (Windows NT 11.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", +} + +var soraRand = rand.New(rand.NewSource(time.Now().UnixNano())) +var soraRandMu sync.Mutex +var soraPerfStart = time.Now() + +// SoraClient 定义直连 Sora 的任务操作接口。 +type SoraClient interface { + Enabled() bool + UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) + CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) + CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) + GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) + GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) +} + +// SoraImageRequest 图片生成请求参数 +type SoraImageRequest struct { + Prompt string + Width int + Height int + MediaID string +} + +// SoraVideoRequest 视频生成请求参数 +type SoraVideoRequest struct { + Prompt string + Orientation string + Frames int + Model string + Size string + MediaID string + RemixTargetID string +} + +// SoraImageTaskStatus 图片任务状态 +type SoraImageTaskStatus struct { + ID string + Status string + ProgressPct float64 + URLs []string + ErrorMsg string +} + +// SoraVideoTaskStatus 视频任务状态 +type 
SoraVideoTaskStatus struct { + ID string + Status string + ProgressPct int + URLs []string + ErrorMsg string +} + +// SoraUpstreamError 上游错误 +type SoraUpstreamError struct { + StatusCode int + Message string + Headers http.Header + Body []byte +} + +func (e *SoraUpstreamError) Error() string { + if e == nil { + return "sora upstream error" + } + if e.Message != "" { + return fmt.Sprintf("sora upstream error: %d %s", e.StatusCode, e.Message) + } + return fmt.Sprintf("sora upstream error: %d", e.StatusCode) +} + +// SoraDirectClient 直连 Sora 实现 +type SoraDirectClient struct { + cfg *config.Config + httpUpstream HTTPUpstream + tokenProvider *OpenAITokenProvider +} + +// NewSoraDirectClient 创建 Sora 直连客户端 +func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient { + return &SoraDirectClient{ + cfg: cfg, + httpUpstream: httpUpstream, + tokenProvider: tokenProvider, + } +} + +// Enabled 判断是否启用 Sora 直连 +func (c *SoraDirectClient) Enabled() bool { + if c == nil || c.cfg == nil { + return false + } + return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != "" +} + +func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) { + if len(data) == 0 { + return "", errors.New("empty image data") + } + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + if filename == "" { + filename = "image.png" + } + var body bytes.Buffer + writer := multipart.NewWriter(&body) + contentType := mime.TypeByExtension(path.Ext(filename)) + if contentType == "" { + contentType = "application/octet-stream" + } + partHeader := make(textproto.MIMEHeader) + partHeader.Set("Content-Disposition", fmt.Sprintf(`form-data; name="file"; filename="%s"`, filename)) + partHeader.Set("Content-Type", contentType) + part, err := writer.CreatePart(partHeader) + if err != nil { + return "", err + } + if _, err := part.Write(data); err != nil { + return "", 
err + } + if err := writer.WriteField("file_name", filename); err != nil { + return "", err + } + if err := writer.Close(); err != nil { + return "", err + } + + headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + headers.Set("Content-Type", writer.FormDataContentType()) + + respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/uploads"), headers, &body, false) + if err != nil { + return "", err + } + var payload map[string]any + if err := json.Unmarshal(respBody, &payload); err != nil { + return "", fmt.Errorf("parse upload response: %w", err) + } + id, _ := payload["id"].(string) + if strings.TrimSpace(id) == "" { + return "", errors.New("upload response missing id") + } + return id, nil +} + +func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + operation := "simple_compose" + inpaintItems := []map[string]any{} + if strings.TrimSpace(req.MediaID) != "" { + operation = "remix" + inpaintItems = append(inpaintItems, map[string]any{ + "type": "image", + "frame_index": 0, + "upload_media_id": req.MediaID, + }) + } + payload := map[string]any{ + "type": "image_gen", + "operation": operation, + "prompt": req.Prompt, + "width": req.Width, + "height": req.Height, + "n_variants": 1, + "n_frames": 1, + "inpaint_items": inpaintItems, + } + headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + sentinel, err := c.generateSentinelToken(ctx, account, token) + if err != nil { + return "", err + } + headers.Set("openai-sentinel-token", sentinel) + + respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/video_gen"), headers, 
bytes.NewReader(body), true) + if err != nil { + return "", err + } + var resp map[string]any + if err := json.Unmarshal(respBody, &resp); err != nil { + return "", err + } + taskID, _ := resp["id"].(string) + if strings.TrimSpace(taskID) == "" { + return "", errors.New("image task response missing id") + } + return taskID, nil +} + +func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + orientation := req.Orientation + if orientation == "" { + orientation = "landscape" + } + nFrames := req.Frames + if nFrames <= 0 { + nFrames = 450 + } + model := req.Model + if model == "" { + model = "sy_8" + } + size := req.Size + if size == "" { + size = "small" + } + + inpaintItems := []map[string]any{} + if strings.TrimSpace(req.MediaID) != "" { + inpaintItems = append(inpaintItems, map[string]any{ + "kind": "upload", + "upload_id": req.MediaID, + }) + } + payload := map[string]any{ + "kind": "video", + "prompt": req.Prompt, + "orientation": orientation, + "size": size, + "n_frames": nFrames, + "model": model, + "inpaint_items": inpaintItems, + } + if strings.TrimSpace(req.RemixTargetID) != "" { + payload["remix_target_id"] = req.RemixTargetID + payload["cameo_ids"] = []string{} + payload["cameo_replacements"] = map[string]any{} + } + + headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + sentinel, err := c.generateSentinelToken(ctx, account, token) + if err != nil { + return "", err + } + headers.Set("openai-sentinel-token", sentinel) + + respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/nf/create"), headers, bytes.NewReader(body), true) + if err != nil { + 
return "", err + } + var resp map[string]any + if err := json.Unmarshal(respBody, &resp); err != nil { + return "", err + } + taskID, _ := resp["id"].(string) + if strings.TrimSpace(taskID) == "" { + return "", errors.New("video task response missing id") + } + return taskID, nil +} + +func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return nil, err + } + headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/v2/recent_tasks?limit=20"), headers, nil, false) + if err != nil { + return nil, err + } + var resp map[string]any + if err := json.Unmarshal(respBody, &resp); err != nil { + return nil, err + } + taskResponses, _ := resp["task_responses"].([]any) + for _, item := range taskResponses { + taskResp, ok := item.(map[string]any) + if !ok { + continue + } + if id, _ := taskResp["id"].(string); id == taskID { + status := strings.TrimSpace(fmt.Sprintf("%v", taskResp["status"])) + progress := 0.0 + if v, ok := taskResp["progress_pct"].(float64); ok { + progress = v + } + urls := []string{} + if generations, ok := taskResp["generations"].([]any); ok { + for _, genItem := range generations { + gen, ok := genItem.(map[string]any) + if !ok { + continue + } + if urlStr, ok := gen["url"].(string); ok && strings.TrimSpace(urlStr) != "" { + urls = append(urls, urlStr) + } + } + } + return &SoraImageTaskStatus{ + ID: taskID, + Status: status, + ProgressPct: progress, + URLs: urls, + }, nil + } + } + return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, nil +} + +func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return nil, err + } + headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + + respBody, _, 
err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/pending/v2"), headers, nil, false) + if err != nil { + return nil, err + } + var pending any + if err := json.Unmarshal(respBody, &pending); err == nil { + if list, ok := pending.([]any); ok { + for _, item := range list { + task, ok := item.(map[string]any) + if !ok { + continue + } + if id, _ := task["id"].(string); id == taskID { + progress := 0 + if v, ok := task["progress_pct"].(float64); ok { + progress = int(v * 100) + } + status := strings.TrimSpace(fmt.Sprintf("%v", task["status"])) + return &SoraVideoTaskStatus{ + ID: taskID, + Status: status, + ProgressPct: progress, + }, nil + } + } + } + } + + respBody, _, err = c.doRequest(ctx, account, http.MethodGet, c.buildURL("/project_y/profile/drafts?limit=15"), headers, nil, false) + if err != nil { + return nil, err + } + var draftsResp map[string]any + if err := json.Unmarshal(respBody, &draftsResp); err != nil { + return nil, err + } + items, _ := draftsResp["items"].([]any) + for _, item := range items { + draft, ok := item.(map[string]any) + if !ok { + continue + } + if id, _ := draft["task_id"].(string); id == taskID { + kind := strings.TrimSpace(fmt.Sprintf("%v", draft["kind"])) + reason := strings.TrimSpace(fmt.Sprintf("%v", draft["reason_str"])) + if reason == "" { + reason = strings.TrimSpace(fmt.Sprintf("%v", draft["markdown_reason_str"])) + } + urlStr := strings.TrimSpace(fmt.Sprintf("%v", draft["downloadable_url"])) + if urlStr == "" { + urlStr = strings.TrimSpace(fmt.Sprintf("%v", draft["url"])) + } + + if kind == "sora_content_violation" || reason != "" || urlStr == "" { + msg := reason + if msg == "" { + msg = "Content violates guardrails" + } + return &SoraVideoTaskStatus{ + ID: taskID, + Status: "failed", + ErrorMsg: msg, + }, nil + } + return &SoraVideoTaskStatus{ + ID: taskID, + Status: "completed", + URLs: []string{urlStr}, + }, nil + } + } + + return &SoraVideoTaskStatus{ID: taskID, Status: "processing"}, nil +} + +func (c 
*SoraDirectClient) buildURL(endpoint string) string { + base := "" + if c != nil && c.cfg != nil { + base = strings.TrimRight(strings.TrimSpace(c.cfg.Sora.Client.BaseURL), "/") + } + if base == "" { + return endpoint + } + if strings.HasPrefix(endpoint, "/") { + return base + endpoint + } + return base + "/" + endpoint +} + +func (c *SoraDirectClient) defaultUserAgent() string { + if c == nil || c.cfg == nil { + return soraDefaultUserAgent + } + ua := strings.TrimSpace(c.cfg.Sora.Client.UserAgent) + if ua == "" { + return soraDefaultUserAgent + } + return ua +} + +func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + if c.tokenProvider != nil { + return c.tokenProvider.GetAccessToken(ctx, account) + } + token := strings.TrimSpace(account.GetCredential("access_token")) + if token == "" { + return "", errors.New("access_token not found") + } + return token, nil +} + +func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header { + headers := http.Header{} + if token != "" { + headers.Set("Authorization", "Bearer "+token) + } + if userAgent != "" { + headers.Set("User-Agent", userAgent) + } + if c != nil && c.cfg != nil { + for key, value := range c.cfg.Sora.Client.Headers { + if strings.EqualFold(key, "authorization") || strings.EqualFold(key, "openai-sentinel-token") { + continue + } + headers.Set(key, value) + } + } + return headers +} + +func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, method, urlStr string, headers http.Header, body io.Reader, allowRetry bool) ([]byte, http.Header, error) { + if strings.TrimSpace(urlStr) == "" { + return nil, nil, errors.New("empty upstream url") + } + timeout := 0 + if c != nil && c.cfg != nil { + timeout = c.cfg.Sora.Client.TimeoutSeconds + } + if timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Second) 
+ defer cancel() + } + maxRetries := 0 + if allowRetry && c != nil && c.cfg != nil { + maxRetries = c.cfg.Sora.Client.MaxRetries + } + if maxRetries < 0 { + maxRetries = 0 + } + + var bodyBytes []byte + if body != nil { + b, err := io.ReadAll(body) + if err != nil { + return nil, nil, err + } + bodyBytes = b + } + + attempts := maxRetries + 1 + for attempt := 1; attempt <= attempts; attempt++ { + var reader io.Reader + if bodyBytes != nil { + reader = bytes.NewReader(bodyBytes) + } + req, err := http.NewRequestWithContext(ctx, method, urlStr, reader) + if err != nil { + return nil, nil, err + } + req.Header = headers.Clone() + start := time.Now() + + proxyURL := "" + if account != nil && account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + resp, err := c.doHTTP(req, proxyURL, account) + if err != nil { + if attempt < attempts && allowRetry { + c.sleepRetry(attempt) + continue + } + return nil, nil, err + } + + respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + if readErr != nil { + return nil, resp.Header, readErr + } + + if c.cfg != nil && c.cfg.Sora.Client.Debug { + log.Printf("[SoraClient] %s %s status=%d cost=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, time.Since(start)) + } + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody) + if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) { + c.sleepRetry(attempt) + continue + } + return nil, resp.Header, upstreamErr + } + return respBody, resp.Header, nil + } + return nil, nil, errors.New("upstream retries exhausted") +} + +func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) { + enableTLS := false + if c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && 
!c.cfg.Sora.Client.DisableTLSFingerprint { + enableTLS = true + } + if c.httpUpstream != nil { + accountID := int64(0) + accountConcurrency := 0 + if account != nil { + accountID = account.ID + accountConcurrency = account.Concurrency + } + return c.httpUpstream.DoWithTLS(req, proxyURL, accountID, accountConcurrency, enableTLS) + } + return http.DefaultClient.Do(req) +} + +func (c *SoraDirectClient) sleepRetry(attempt int) { + backoff := time.Duration(attempt*attempt) * time.Second + if backoff > 10*time.Second { + backoff = 10 * time.Second + } + time.Sleep(backoff) +} + +func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte) error { + msg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + msg = sanitizeUpstreamErrorMessage(msg) + if msg == "" { + msg = truncateForLog(body, 256) + } + return &SoraUpstreamError{ + StatusCode: status, + Message: msg, + Headers: headers, + Body: body, + } +} + +func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) { + reqID := uuid.NewString() + userAgent := soraRandChoice(soraDesktopUserAgents) + powToken := soraGetPowToken(userAgent) + payload := map[string]any{ + "p": powToken, + "flow": soraSentinelFlow, + "id": reqID, + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + headers := http.Header{} + headers.Set("Accept", "application/json, text/plain, */*") + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + headers.Set("User-Agent", userAgent) + if accessToken != "" { + headers.Set("Authorization", "Bearer "+accessToken) + } + + urlStr := soraChatGPTBaseURL + "/backend-api/sentinel/req" + respBody, _, err := c.doRequest(ctx, account, http.MethodPost, urlStr, headers, bytes.NewReader(body), true) + if err != nil { + return "", err + } + var resp map[string]any + if err := 
json.Unmarshal(respBody, &resp); err != nil { + return "", err + } + + sentinel := soraBuildSentinelToken(soraSentinelFlow, reqID, powToken, resp, userAgent) + if sentinel == "" { + return "", errors.New("failed to build sentinel token") + } + return sentinel, nil +} + +func soraRandChoice(items []string) string { + if len(items) == 0 { + return "" + } + soraRandMu.Lock() + idx := soraRand.Intn(len(items)) + soraRandMu.Unlock() + return items[idx] +} + +func soraGetPowToken(userAgent string) string { + configList := soraBuildPowConfig(userAgent) + seed := strconv.FormatFloat(soraRandFloat(), 'f', -1, 64) + difficulty := "0fffff" + solution, _ := soraSolvePow(seed, difficulty, configList) + return "gAAAAAC" + solution +} + +func soraRandFloat() float64 { + soraRandMu.Lock() + defer soraRandMu.Unlock() + return soraRand.Float64() +} + +func soraBuildPowConfig(userAgent string) []any { + screen := soraRandChoice([]string{ + strconv.Itoa(1920 + 1080), + strconv.Itoa(2560 + 1440), + strconv.Itoa(1920 + 1200), + strconv.Itoa(2560 + 1600), + }) + screenVal, _ := strconv.Atoi(screen) + perfMs := float64(time.Since(soraPerfStart).Milliseconds()) + wallMs := float64(time.Now().UnixNano()) / 1e6 + diff := wallMs - perfMs + return []any{ + screenVal, + soraPowParseTime(), + 4294705152, + 0, + userAgent, + soraRandChoice(soraPowScripts), + soraRandChoice(soraPowDPL), + "en-US", + "en-US,es-US,en,es", + 0, + soraRandChoice(soraPowNavigatorKeys), + soraRandChoice(soraPowDocumentKeys), + soraRandChoice(soraPowWindowKeys), + perfMs, + uuid.NewString(), + "", + soraRandChoiceInt(soraPowCores), + diff, + } +} + +func soraRandChoiceInt(items []int) int { + if len(items) == 0 { + return 0 + } + soraRandMu.Lock() + idx := soraRand.Intn(len(items)) + soraRandMu.Unlock() + return items[idx] +} + +func soraPowParseTime() string { + loc := time.FixedZone("EST", -5*3600) + return time.Now().In(loc).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (Eastern Standard Time)") +} + +func 
soraSolvePow(seed, difficulty string, configList []any) (string, bool) { + diffLen := len(difficulty) / 2 + target, err := hexDecodeString(difficulty) + if err != nil { + return "", false + } + seedBytes := []byte(seed) + + part1 := mustMarshalJSON(configList[:3]) + part2 := mustMarshalJSON(configList[4:9]) + part3 := mustMarshalJSON(configList[10:]) + + staticPart1 := append(part1[:len(part1)-1], ',') + staticPart2 := append([]byte(","), append(part2[1:len(part2)-1], ',')...) + staticPart3 := append([]byte(","), part3[1:]...) + + for i := 0; i < soraPowMaxIteration; i++ { + dynamicI := []byte(strconv.Itoa(i)) + dynamicJ := []byte(strconv.Itoa(i >> 1)) + finalJSON := make([]byte, 0, len(staticPart1)+len(dynamicI)+len(staticPart2)+len(dynamicJ)+len(staticPart3)) + finalJSON = append(finalJSON, staticPart1...) + finalJSON = append(finalJSON, dynamicI...) + finalJSON = append(finalJSON, staticPart2...) + finalJSON = append(finalJSON, dynamicJ...) + finalJSON = append(finalJSON, staticPart3...) 
+ + b64 := base64.StdEncoding.EncodeToString(finalJSON) + hash := sha3.Sum512(append(seedBytes, []byte(b64)...)) + if bytes.Compare(hash[:diffLen], target[:diffLen]) <= 0 { + return b64, true + } + } + + errorToken := "wQ8Lk5FbGpA2NcR9dShT6gYjU7VxZ4D" + base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("\"%s\"", seed))) + return errorToken, false +} + +func soraBuildSentinelToken(flow, reqID, powToken string, resp map[string]any, userAgent string) string { + finalPow := powToken + proof, _ := resp["proofofwork"].(map[string]any) + if required, _ := proof["required"].(bool); required { + seed, _ := proof["seed"].(string) + difficulty, _ := proof["difficulty"].(string) + if seed != "" && difficulty != "" { + configList := soraBuildPowConfig(userAgent) + solution, _ := soraSolvePow(seed, difficulty, configList) + finalPow = "gAAAAAB" + solution + } + } + if !strings.HasSuffix(finalPow, "~S") { + finalPow += "~S" + } + turnstile, _ := resp["turnstile"].(map[string]any) + tokenPayload := map[string]any{ + "p": finalPow, + "t": safeMapString(turnstile, "dx"), + "c": safeString(resp["token"]), + "id": reqID, + "flow": flow, + } + encoded, _ := json.Marshal(tokenPayload) + return string(encoded) +} + +func safeMapString(m map[string]any, key string) string { + if m == nil { + return "" + } + if v, ok := m[key]; ok { + return safeString(v) + } + return "" +} + +func safeString(v any) string { + switch val := v.(type) { + case string: + return val + default: + return fmt.Sprintf("%v", val) + } +} + +func mustMarshalJSON(v any) []byte { + b, _ := json.Marshal(v) + return b +} + +func hexDecodeString(s string) ([]byte, error) { + dst := make([]byte, len(s)/2) + _, err := hex.Decode(dst, []byte(s)) + return dst, err +} + +func sanitizeSoraLogURL(raw string) string { + parsed, err := url.Parse(raw) + if err != nil { + return raw + } + q := parsed.Query() + q.Del("sig") + q.Del("expires") + parsed.RawQuery = q.Encode() + return parsed.String() +} diff --git 
a/backend/internal/service/sora_client_test.go b/backend/internal/service/sora_client_test.go new file mode 100644 index 00000000..abbe47a1 --- /dev/null +++ b/backend/internal/service/sora_client_test.go @@ -0,0 +1,54 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestSoraDirectClient_DoRequestSuccess(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer server.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{BaseURL: server.URL}, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + + body, _, err := client.doRequest(context.Background(), &Account{ID: 1}, http.MethodGet, server.URL, http.Header{}, nil, false) + require.NoError(t, err) + require.Contains(t, string(body), "ok") +} + +func TestSoraDirectClient_BuildBaseHeaders(t *testing.T) { + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + Headers: map[string]string{ + "X-Test": "yes", + "Authorization": "should-ignore", + "openai-sentinel-token": "skip", + }, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + + headers := client.buildBaseHeaders("token-123", "UA") + require.Equal(t, "Bearer token-123", headers.Get("Authorization")) + require.Equal(t, "UA", headers.Get("User-Agent")) + require.Equal(t, "yes", headers.Get("X-Test")) + require.Empty(t, headers.Get("openai-sentinel-token")) +} diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index 2909a76f..49cd7bba 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -4,10 +4,12 @@ 
import ( "bufio" "bytes" "context" + "encoding/base64" "encoding/json" "errors" "fmt" "io" + "mime" "net/http" "net/url" "regexp" @@ -39,23 +41,23 @@ type soraStreamingResult struct { firstTokenMs *int } -// SoraGatewayService handles forwarding requests to sora2api. +// SoraGatewayService handles forwarding requests to Sora upstream. type SoraGatewayService struct { - sora2api *Sora2APIService - httpUpstream HTTPUpstream + soraClient SoraClient + mediaStorage *SoraMediaStorage rateLimitService *RateLimitService cfg *config.Config } func NewSoraGatewayService( - sora2api *Sora2APIService, - httpUpstream HTTPUpstream, + soraClient SoraClient, + mediaStorage *SoraMediaStorage, rateLimitService *RateLimitService, cfg *config.Config, ) *SoraGatewayService { return &SoraGatewayService{ - sora2api: sora2api, - httpUpstream: httpUpstream, + soraClient: soraClient, + mediaStorage: mediaStorage, rateLimitService: rateLimitService, cfg: cfg, } @@ -64,31 +66,53 @@ func NewSoraGatewayService( func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) { startTime := time.Now() - if s.sora2api == nil || !s.sora2api.Enabled() { + if s.soraClient == nil || !s.soraClient.Enabled() { if c != nil { c.JSON(http.StatusServiceUnavailable, gin.H{ "error": gin.H{ "type": "api_error", - "message": "sora2api 未配置", + "message": "Sora 上游未配置", }, }) } - return nil, errors.New("sora2api not configured") + return nil, errors.New("sora upstream not configured") } var reqBody map[string]any if err := json.Unmarshal(body, &reqBody); err != nil { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body", clientStream) return nil, fmt.Errorf("parse request: %w", err) } reqModel, _ := reqBody["model"].(string) reqStream, _ := reqBody["stream"].(bool) + if strings.TrimSpace(reqModel) == "" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is 
required", clientStream) + return nil, errors.New("model is required") + } mappedModel := account.GetMappedModel(reqModel) - if mappedModel != reqModel && mappedModel != "" { - reqBody["model"] = mappedModel - if updated, err := json.Marshal(reqBody); err == nil { - body = updated - } + if mappedModel != "" && mappedModel != reqModel { + reqModel = mappedModel + } + + modelCfg, ok := GetSoraModelConfig(reqModel) + if !ok { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream) + return nil, fmt.Errorf("unsupported model: %s", reqModel) + } + if modelCfg.Type == "prompt_enhance" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Prompt-enhance 模型暂未支持", clientStream) + return nil, fmt.Errorf("prompt-enhance not supported") + } + + prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody) + if strings.TrimSpace(prompt) == "" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) + return nil, errors.New("prompt is required") + } + if strings.TrimSpace(videoInput) != "" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Video input is not supported yet", clientStream) + return nil, errors.New("video input not supported") } reqCtx, cancel := s.withSoraTimeout(ctx, reqStream) @@ -96,81 +120,122 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun defer cancel() } - upstreamReq, err := s.sora2api.NewAPIRequest(reqCtx, http.MethodPost, "/v1/chat/completions", body) - if err != nil { - return nil, err - } - if c != nil { - if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" { - upstreamReq.Header.Set("User-Agent", ua) + var imageData []byte + imageFilename := "" + if strings.TrimSpace(imageInput) != "" { + decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput) + if err != nil { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", 
err.Error(), clientStream) + return nil, err } - } - if reqStream { - upstreamReq.Header.Set("Accept", "text/event-stream") + imageData = decoded + imageFilename = filename } - if c != nil { - c.Set(OpsUpstreamRequestBodyKey, string(body)) + mediaID := "" + if len(imageData) > 0 { + uploadID, err := s.soraClient.UploadImage(reqCtx, account, imageData, imageFilename) + if err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + mediaID = uploadID } - proxyURL := "" - if account != nil && account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() + taskID := "" + var err error + switch modelCfg.Type { + case "image": + taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{ + Prompt: prompt, + Width: modelCfg.Width, + Height: modelCfg.Height, + MediaID: mediaID, + }) + case "video": + taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{ + Prompt: prompt, + Orientation: modelCfg.Orientation, + Frames: modelCfg.Frames, + Model: modelCfg.Model, + Size: modelCfg.Size, + MediaID: mediaID, + RemixTargetID: remixTargetID, + }) + default: + err = fmt.Errorf("unsupported model type: %s", modelCfg.Type) + } + if err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) } - var resp *http.Response - if s.httpUpstream != nil { - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if clientStream && c != nil { + s.prepareSoraStream(c, taskID) + } + + var mediaURLs []string + mediaType := modelCfg.Type + imageCount := 0 + imageSize := "" + if modelCfg.Type == "image" { + urls, pollErr := s.pollImageTask(reqCtx, c, account, taskID, clientStream) + if pollErr != nil { + return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) + } + mediaURLs = urls + imageCount = len(urls) + imageSize = soraImageSizeFromModel(reqModel) + } else if modelCfg.Type == "video" { + 
urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream) + if pollErr != nil { + return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) + } + mediaURLs = urls } else { - resp, err = http.DefaultClient.Do(upstreamReq) + mediaType = "prompt" } - if err != nil { - s.setUpstreamRequestError(c, account, err) - return nil, fmt.Errorf("upstream request failed: %w", err) - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode >= 400 { - if s.shouldFailoverUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "failover", - Message: upstreamMsg, - }) - s.handleFailoverSideEffects(ctx, resp, account) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + finalURLs := mediaURLs + if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() { + stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs) + if storeErr != nil { + return nil, s.handleSoraRequestError(ctx, account, storeErr, reqModel, c, clientStream) } - return s.handleErrorResponse(ctx, resp, c, account, reqModel) + finalURLs = s.normalizeSoraMediaURLs(stored) + } else { + finalURLs = s.normalizeSoraMediaURLs(mediaURLs) } - streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, reqModel, clientStream) - if err != nil { - return nil, err + content := buildSoraContent(mediaType, finalURLs) + var firstTokenMs *int + if clientStream { + ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) + if 
streamErr != nil { + return nil, streamErr + } + firstTokenMs = ms + } else if c != nil { + response := buildSoraNonStreamResponse(content, reqModel) + if len(finalURLs) > 0 { + response["media_url"] = finalURLs[0] + if len(finalURLs) > 1 { + response["media_urls"] = finalURLs + } + } + c.JSON(http.StatusOK, response) } - result := &ForwardResult{ - RequestID: resp.Header.Get("x-request-id"), + return &ForwardResult{ + RequestID: taskID, Model: reqModel, Stream: clientStream, Duration: time.Since(startTime), - FirstTokenMs: streamResult.firstTokenMs, + FirstTokenMs: firstTokenMs, Usage: ClaudeUsage{}, - MediaType: streamResult.mediaType, - MediaURL: firstMediaURL(streamResult.mediaURLs), - ImageCount: streamResult.imageCount, - ImageSize: streamResult.imageSize, - } - - return result, nil + MediaType: mediaType, + MediaURL: firstMediaURL(finalURLs), + ImageCount: imageCount, + ImageSize: imageSize, + }, nil } func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { @@ -780,3 +845,414 @@ func (s *SoraGatewayService) buildSoraMediaURL(path string, rawQuery string) str } return prefix + path + "?" 
+ encoded } + +func (s *SoraGatewayService) prepareSoraStream(c *gin.Context, requestID string) { + if c == nil { + return + } + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + if strings.TrimSpace(requestID) != "" { + c.Header("x-request-id", requestID) + } +} + +func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content string, startTime time.Time) (*int, error) { + if c == nil { + return nil, nil + } + writer := c.Writer + flusher, _ := writer.(http.Flusher) + + chunk := map[string]any{ + "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{ + "content": content, + }, + }, + }, + } + encoded, _ := json.Marshal(chunk) + if _, err := fmt.Fprintf(writer, "data: %s\n\n", encoded); err != nil { + return nil, err + } + if flusher != nil { + flusher.Flush() + } + ms := int(time.Since(startTime).Milliseconds()) + finalChunk := map[string]any{ + "id": chunk["id"], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{}, + "finish_reason": "stop", + }, + }, + } + finalEncoded, _ := json.Marshal(finalChunk) + if _, err := fmt.Fprintf(writer, "data: %s\n\n", finalEncoded); err != nil { + return &ms, err + } + if _, err := fmt.Fprint(writer, "data: [DONE]\n\n"); err != nil { + return &ms, err + } + if flusher != nil { + flusher.Flush() + } + return &ms, nil +} + +func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType, message string, stream bool) { + if c == nil { + return + } + if stream { + flusher, _ := c.Writer.(http.Flusher) + errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, 
message) + _, _ = fmt.Fprint(c.Writer, errorEvent) + _, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n") + if flusher != nil { + flusher.Flush() + } + return + } + c.JSON(status, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account *Account, err error, model string, c *gin.Context, stream bool) error { + if err == nil { + return nil + } + var upstreamErr *SoraUpstreamError + if errors.As(err, &upstreamErr) { + if s.rateLimitService != nil && account != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body) + } + if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) { + return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode} + } + msg := upstreamErr.Message + if override := soraProErrorMessage(model, msg); override != "" { + msg = override + } + s.writeSoraError(c, upstreamErr.StatusCode, "upstream_error", msg, stream) + return err + } + if errors.Is(err, context.DeadlineExceeded) { + s.writeSoraError(c, http.StatusGatewayTimeout, "timeout_error", "Sora generation timeout", stream) + return err + } + s.writeSoraError(c, http.StatusBadGateway, "api_error", err.Error(), stream) + return err +} + +func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) { + interval := s.pollInterval() + maxAttempts := s.pollMaxAttempts() + lastPing := time.Now() + for attempt := 0; attempt < maxAttempts; attempt++ { + status, err := s.soraClient.GetImageTask(ctx, account, taskID) + if err != nil { + return nil, err + } + switch strings.ToLower(status.Status) { + case "succeeded", "completed": + return status.URLs, nil + case "failed": + if status.ErrorMsg != "" { + return nil, errors.New(status.ErrorMsg) + } + return nil, errors.New("Sora image generation failed") + } + if stream { + s.maybeSendPing(c, &lastPing) + 
} + if err := sleepWithContext(ctx, interval); err != nil { + return nil, err + } + } + return nil, errors.New("Sora image generation timeout") +} + +func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) { + interval := s.pollInterval() + maxAttempts := s.pollMaxAttempts() + lastPing := time.Now() + for attempt := 0; attempt < maxAttempts; attempt++ { + status, err := s.soraClient.GetVideoTask(ctx, account, taskID) + if err != nil { + return nil, err + } + switch strings.ToLower(status.Status) { + case "completed", "succeeded": + return status.URLs, nil + case "failed": + if status.ErrorMsg != "" { + return nil, errors.New(status.ErrorMsg) + } + return nil, errors.New("Sora video generation failed") + } + if stream { + s.maybeSendPing(c, &lastPing) + } + if err := sleepWithContext(ctx, interval); err != nil { + return nil, err + } + } + return nil, errors.New("Sora video generation timeout") +} + +func (s *SoraGatewayService) pollInterval() time.Duration { + if s == nil || s.cfg == nil { + return 2 * time.Second + } + interval := s.cfg.Sora.Client.PollIntervalSeconds + if interval <= 0 { + interval = 2 + } + return time.Duration(interval) * time.Second +} + +func (s *SoraGatewayService) pollMaxAttempts() int { + if s == nil || s.cfg == nil { + return 600 + } + maxAttempts := s.cfg.Sora.Client.MaxPollAttempts + if maxAttempts <= 0 { + maxAttempts = 600 + } + return maxAttempts +} + +func (s *SoraGatewayService) maybeSendPing(c *gin.Context, lastPing *time.Time) { + if c == nil { + return + } + interval := 10 * time.Second + if s != nil && s.cfg != nil && s.cfg.Concurrency.PingInterval > 0 { + interval = time.Duration(s.cfg.Concurrency.PingInterval) * time.Second + } + if time.Since(*lastPing) < interval { + return + } + if _, err := fmt.Fprint(c.Writer, ":\n\n"); err == nil { + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } + *lastPing = time.Now() + } +} 
+
+func (s *SoraGatewayService) normalizeSoraMediaURLs(urls []string) []string {
+	if len(urls) == 0 {
+		return urls
+	}
+	output := make([]string, 0, len(urls))
+	for _, raw := range urls {
+		raw = strings.TrimSpace(raw)
+		if raw == "" {
+			continue
+		}
+		if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
+			output = append(output, raw)
+			continue
+		}
+		pathVal := raw
+		if !strings.HasPrefix(pathVal, "/") {
+			pathVal = "/" + pathVal
+		}
+		output = append(output, s.buildSoraMediaURL(pathVal, ""))
+	}
+	return output
+}
+
+func buildSoraContent(mediaType string, urls []string) string {
+	switch mediaType {
+	case "image":
+		parts := make([]string, 0, len(urls))
+		for _, u := range urls {
+			parts = append(parts, fmt.Sprintf("![image](%s)", u))
+		}
+		return strings.Join(parts, "\n")
+	case "video":
+		if len(urls) == 0 {
+			return ""
+		}
+		return fmt.Sprintf("```html\n<video src=\"%s\" controls></video>\n```", urls[0])
+	default:
+		return ""
+	}
+}
+
+func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remixTargetID string) {
+	if body == nil {
+		return "", "", "", ""
+	}
+	if v, ok := body["remix_target_id"].(string); ok {
+		remixTargetID = v
+	}
+	if v, ok := body["image"].(string); ok {
+		imageInput = v
+	}
+	if v, ok := body["video"].(string); ok {
+		videoInput = v
+	}
+	if v, ok := body["prompt"].(string); ok && strings.TrimSpace(v) != "" {
+		prompt = v
+	}
+	if messages, ok := body["messages"].([]any); ok {
+		builder := strings.Builder{}
+		for _, raw := range messages {
+			msg, ok := raw.(map[string]any)
+			if !ok {
+				continue
+			}
+			role, _ := msg["role"].(string)
+			if role != "" && role != "user" {
+				continue
+			}
+			content := msg["content"]
+			text, img, vid := parseSoraMessageContent(content)
+			if text != "" {
+				if builder.Len() > 0 {
+					builder.WriteString("\n")
+				}
+				builder.WriteString(text)
+			}
+			if imageInput == "" && img != "" {
+				imageInput = img
+			}
+			if videoInput == "" && vid != "" {
+				videoInput = vid
+			}
+		}
+		if prompt == "" {
+			prompt = 
builder.String() + } + } + return prompt, imageInput, videoInput, remixTargetID +} + +func parseSoraMessageContent(content any) (text, imageInput, videoInput string) { + switch val := content.(type) { + case string: + return val, "", "" + case []any: + builder := strings.Builder{} + for _, item := range val { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + t, _ := itemMap["type"].(string) + switch t { + case "text": + if txt, ok := itemMap["text"].(string); ok && strings.TrimSpace(txt) != "" { + if builder.Len() > 0 { + builder.WriteString("\n") + } + builder.WriteString(txt) + } + case "image_url": + if imageInput == "" { + if urlVal, ok := itemMap["image_url"].(map[string]any); ok { + imageInput = fmt.Sprintf("%v", urlVal["url"]) + } else if urlStr, ok := itemMap["image_url"].(string); ok { + imageInput = urlStr + } + } + case "video_url": + if videoInput == "" { + if urlVal, ok := itemMap["video_url"].(map[string]any); ok { + videoInput = fmt.Sprintf("%v", urlVal["url"]) + } else if urlStr, ok := itemMap["video_url"].(string); ok { + videoInput = urlStr + } + } + } + } + return builder.String(), imageInput, videoInput + default: + return "", "", "" + } +} + +func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) { + raw := strings.TrimSpace(input) + if raw == "" { + return nil, "", errors.New("empty image input") + } + if strings.HasPrefix(raw, "data:") { + parts := strings.SplitN(raw, ",", 2) + if len(parts) != 2 { + return nil, "", errors.New("invalid data url") + } + meta := parts[0] + payload := parts[1] + decoded, err := base64.StdEncoding.DecodeString(payload) + if err != nil { + return nil, "", err + } + ext := "" + if strings.HasPrefix(meta, "data:") { + metaParts := strings.SplitN(meta[5:], ";", 2) + if len(metaParts) > 0 { + if exts, err := mime.ExtensionsByType(metaParts[0]); err == nil && len(exts) > 0 { + ext = exts[0] + } + } + } + filename := "image" + ext + return decoded, filename, nil + } + if 
strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { + return downloadSoraImageInput(ctx, raw) + } + decoded, err := base64.StdEncoding.DecodeString(raw) + if err != nil { + return nil, "", errors.New("invalid base64 image") + } + return decoded, "image.png", nil +} + +func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) + if err != nil { + return nil, "", err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, "", err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode) + } + data, err := io.ReadAll(io.LimitReader(resp.Body, 20<<20)) + if err != nil { + return nil, "", err + } + ext := fileExtFromURL(rawURL) + if ext == "" { + ext = fileExtFromContentType(resp.Header.Get("Content-Type")) + } + filename := "image" + ext + return data, filename, nil +} diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go new file mode 100644 index 00000000..e4de8256 --- /dev/null +++ b/backend/internal/service/sora_gateway_service_test.go @@ -0,0 +1,99 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type stubSoraClientForPoll struct { + imageStatus *SoraImageTaskStatus + videoStatus *SoraVideoTaskStatus + imageCalls int + videoCalls int +} + +func (s *stubSoraClientForPoll) Enabled() bool { return true } +func (s *stubSoraClientForPoll) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) { + return "", nil +} +func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) { + return "task-image", nil +} +func (s 
*stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) { + return "task-video", nil +} +func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { + s.imageCalls++ + return s.imageStatus, nil +} +func (s *stubSoraClientForPoll) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) { + s.videoCalls++ + return s.videoStatus, nil +} + +func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) { + client := &stubSoraClientForPoll{ + imageStatus: &SoraImageTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/a.png"}, + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + service := NewSoraGatewayService(client, nil, nil, cfg) + + urls, err := service.pollImageTask(context.Background(), nil, &Account{ID: 1}, "task", false) + require.NoError(t, err) + require.Equal(t, []string{"https://example.com/a.png"}, urls) + require.Equal(t, 1, client.imageCalls) +} + +func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "failed", + ErrorMsg: "reject", + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + service := NewSoraGatewayService(client, nil, nil, cfg) + + urls, err := service.pollVideoTask(context.Background(), nil, &Account{ID: 1}, "task", false) + require.Error(t, err) + require.Empty(t, urls) + require.Contains(t, err.Error(), "reject") + require.Equal(t, 1, client.videoCalls) +} + +func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) { + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + SoraMediaSigningKey: "test-key", + SoraMediaSignedURLTTLSeconds: 
600, + }, + } + service := NewSoraGatewayService(nil, nil, nil, cfg) + + url := service.buildSoraMediaURL("/image/2025/01/01/a.png", "") + require.Contains(t, url, "/sora/media-signed") + require.Contains(t, url, "expires=") + require.Contains(t, url, "sig=") +} diff --git a/backend/internal/service/sora_media_cleanup_service.go b/backend/internal/service/sora_media_cleanup_service.go new file mode 100644 index 00000000..7de0f1c4 --- /dev/null +++ b/backend/internal/service/sora_media_cleanup_service.go @@ -0,0 +1,117 @@ +package service + +import ( + "log" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/robfig/cron/v3" +) + +var soraCleanupCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow) + +// SoraMediaCleanupService 定期清理本地媒体文件 +type SoraMediaCleanupService struct { + storage *SoraMediaStorage + cfg *config.Config + + cron *cron.Cron + + startOnce sync.Once + stopOnce sync.Once +} + +func NewSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService { + return &SoraMediaCleanupService{ + storage: storage, + cfg: cfg, + } +} + +func (s *SoraMediaCleanupService) Start() { + if s == nil || s.cfg == nil { + return + } + if !s.cfg.Sora.Storage.Cleanup.Enabled { + log.Printf("[SoraCleanup] not started (disabled)") + return + } + if s.storage == nil || !s.storage.Enabled() { + log.Printf("[SoraCleanup] not started (storage disabled)") + return + } + + s.startOnce.Do(func() { + schedule := strings.TrimSpace(s.cfg.Sora.Storage.Cleanup.Schedule) + if schedule == "" { + log.Printf("[SoraCleanup] not started (empty schedule)") + return + } + loc := time.Local + if strings.TrimSpace(s.cfg.Timezone) != "" { + if parsed, err := time.LoadLocation(strings.TrimSpace(s.cfg.Timezone)); err == nil && parsed != nil { + loc = parsed + } + } + c := cron.New(cron.WithParser(soraCleanupCronParser), cron.WithLocation(loc)) + if _, err := 
c.AddFunc(schedule, func() { s.runCleanup() }); err != nil { + log.Printf("[SoraCleanup] not started (invalid schedule=%q): %v", schedule, err) + return + } + s.cron = c + s.cron.Start() + log.Printf("[SoraCleanup] started (schedule=%q tz=%s)", schedule, loc.String()) + }) +} + +func (s *SoraMediaCleanupService) Stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + if s.cron != nil { + ctx := s.cron.Stop() + select { + case <-ctx.Done(): + case <-time.After(3 * time.Second): + log.Printf("[SoraCleanup] cron stop timed out") + } + } + }) +} + +func (s *SoraMediaCleanupService) runCleanup() { + retention := s.cfg.Sora.Storage.Cleanup.RetentionDays + if retention <= 0 { + log.Printf("[SoraCleanup] skipped (retention_days=%d)", retention) + return + } + cutoff := time.Now().AddDate(0, 0, -retention) + deleted := 0 + + roots := []string{s.storage.ImageRoot(), s.storage.VideoRoot()} + for _, root := range roots { + if root == "" { + continue + } + _ = filepath.Walk(root, func(p string, info os.FileInfo, err error) error { + if err != nil { + return nil + } + if info.IsDir() { + return nil + } + if info.ModTime().Before(cutoff) { + if rmErr := os.Remove(p); rmErr == nil { + deleted++ + } + } + return nil + }) + } + log.Printf("[SoraCleanup] cleanup finished, deleted=%d", deleted) +} diff --git a/backend/internal/service/sora_media_cleanup_service_test.go b/backend/internal/service/sora_media_cleanup_service_test.go new file mode 100644 index 00000000..63204104 --- /dev/null +++ b/backend/internal/service/sora_media_cleanup_service_test.go @@ -0,0 +1,46 @@ +//go:build unit + +package service + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestSoraMediaCleanupService_RunCleanup(t *testing.T) { + tmpDir := t.TempDir() + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + Cleanup: 
config.SoraStorageCleanupConfig{ + Enabled: true, + RetentionDays: 1, + }, + }, + }, + } + + storage := NewSoraMediaStorage(cfg) + require.NoError(t, storage.EnsureLocalDirs()) + + oldImage := filepath.Join(storage.ImageRoot(), "old.png") + newVideo := filepath.Join(storage.VideoRoot(), "new.mp4") + require.NoError(t, os.WriteFile(oldImage, []byte("old"), 0o644)) + require.NoError(t, os.WriteFile(newVideo, []byte("new"), 0o644)) + + oldTime := time.Now().Add(-48 * time.Hour) + require.NoError(t, os.Chtimes(oldImage, oldTime, oldTime)) + + cleanup := NewSoraMediaCleanupService(storage, cfg) + cleanup.runCleanup() + + require.NoFileExists(t, oldImage) + require.FileExists(t, newVideo) +} diff --git a/backend/internal/service/sora_media_storage.go b/backend/internal/service/sora_media_storage.go new file mode 100644 index 00000000..53214bb7 --- /dev/null +++ b/backend/internal/service/sora_media_storage.go @@ -0,0 +1,256 @@ +package service + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "mime" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/google/uuid" +) + +const ( + soraStorageDefaultRoot = "/app/data/sora" +) + +// SoraMediaStorage 负责下载并落地 Sora 媒体 +type SoraMediaStorage struct { + cfg *config.Config + root string + imageRoot string + videoRoot string + maxConcurrent int + fallbackToUpstream bool + debug bool + sem chan struct{} + ready bool +} + +func NewSoraMediaStorage(cfg *config.Config) *SoraMediaStorage { + storage := &SoraMediaStorage{cfg: cfg} + storage.refreshConfig() + if storage.Enabled() { + if err := storage.EnsureLocalDirs(); err != nil { + log.Printf("[SoraStorage] 初始化失败: %v", err) + } + } + return storage +} + +func (s *SoraMediaStorage) Enabled() bool { + if s == nil || s.cfg == nil { + return false + } + return strings.ToLower(strings.TrimSpace(s.cfg.Sora.Storage.Type)) == "local" +} + +func (s *SoraMediaStorage) Root() string { + if s == 
nil { + return "" + } + return s.root +} + +func (s *SoraMediaStorage) ImageRoot() string { + if s == nil { + return "" + } + return s.imageRoot +} + +func (s *SoraMediaStorage) VideoRoot() string { + if s == nil { + return "" + } + return s.videoRoot +} + +func (s *SoraMediaStorage) refreshConfig() { + if s == nil || s.cfg == nil { + return + } + root := strings.TrimSpace(s.cfg.Sora.Storage.LocalPath) + if root == "" { + root = soraStorageDefaultRoot + } + s.root = root + s.imageRoot = filepath.Join(root, "image") + s.videoRoot = filepath.Join(root, "video") + + maxConcurrent := s.cfg.Sora.Storage.MaxConcurrentDownloads + if maxConcurrent <= 0 { + maxConcurrent = 4 + } + s.maxConcurrent = maxConcurrent + s.fallbackToUpstream = s.cfg.Sora.Storage.FallbackToUpstream + s.debug = s.cfg.Sora.Storage.Debug + s.sem = make(chan struct{}, maxConcurrent) +} + +// EnsureLocalDirs 创建并校验本地目录 +func (s *SoraMediaStorage) EnsureLocalDirs() error { + if s == nil || !s.Enabled() { + return nil + } + if err := os.MkdirAll(s.imageRoot, 0o755); err != nil { + return fmt.Errorf("create image dir: %w", err) + } + if err := os.MkdirAll(s.videoRoot, 0o755); err != nil { + return fmt.Errorf("create video dir: %w", err) + } + s.ready = true + return nil +} + +// StoreFromURLs 下载并存储媒体,返回相对路径或回退 URL +func (s *SoraMediaStorage) StoreFromURLs(ctx context.Context, mediaType string, urls []string) ([]string, error) { + if len(urls) == 0 { + return nil, nil + } + if s == nil || !s.Enabled() { + return urls, nil + } + if !s.ready { + if err := s.EnsureLocalDirs(); err != nil { + return nil, err + } + } + results := make([]string, 0, len(urls)) + for _, raw := range urls { + relative, err := s.downloadAndStore(ctx, mediaType, raw) + if err != nil { + if s.fallbackToUpstream { + results = append(results, raw) + continue + } + return nil, err + } + results = append(results, relative) + } + return results, nil +} + +func (s *SoraMediaStorage) downloadAndStore(ctx context.Context, mediaType, rawURL 
string) (string, error) { + if strings.TrimSpace(rawURL) == "" { + return "", errors.New("empty url") + } + root := s.imageRoot + if mediaType == "video" { + root = s.videoRoot + } + if root == "" { + return "", errors.New("storage root not configured") + } + + retries := 3 + for attempt := 1; attempt <= retries; attempt++ { + release, err := s.acquire(ctx) + if err != nil { + return "", err + } + relative, err := s.downloadOnce(ctx, root, mediaType, rawURL) + release() + if err == nil { + return relative, nil + } + if s.debug { + log.Printf("[SoraStorage] 下载失败(%d/%d): %s err=%v", attempt, retries, sanitizeSoraLogURL(rawURL), err) + } + if attempt < retries { + time.Sleep(time.Duration(attempt*attempt) * time.Second) + continue + } + return "", err + } + return "", errors.New("download retries exhausted") +} + +func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, rawURL string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) + if err != nil { + return "", err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + return "", fmt.Errorf("download failed: %d %s", resp.StatusCode, string(body)) + } + + ext := fileExtFromURL(rawURL) + if ext == "" { + ext = fileExtFromContentType(resp.Header.Get("Content-Type")) + } + if ext == "" { + ext = ".bin" + } + + datePath := time.Now().Format("2006/01/02") + destDir := filepath.Join(root, filepath.FromSlash(datePath)) + if err := os.MkdirAll(destDir, 0o755); err != nil { + return "", err + } + filename := uuid.NewString() + ext + destPath := filepath.Join(destDir, filename) + out, err := os.Create(destPath) + if err != nil { + return "", err + } + defer func() { _ = out.Close() }() + + if _, err := io.Copy(out, resp.Body); err != nil { + _ = os.Remove(destPath) + return "", err + } + + 
relative := path.Join("/", mediaType, datePath, filename) + if s.debug { + log.Printf("[SoraStorage] 已落地 %s -> %s", sanitizeSoraLogURL(rawURL), relative) + } + return relative, nil +} + +func (s *SoraMediaStorage) acquire(ctx context.Context) (func(), error) { + if s.sem == nil { + return func() {}, nil + } + select { + case s.sem <- struct{}{}: + return func() { <-s.sem }, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func fileExtFromURL(raw string) string { + parsed, err := url.Parse(raw) + if err != nil { + return "" + } + ext := path.Ext(parsed.Path) + return strings.ToLower(ext) +} + +func fileExtFromContentType(ct string) string { + if ct == "" { + return "" + } + if exts, err := mime.ExtensionsByType(ct); err == nil && len(exts) > 0 { + return strings.ToLower(exts[0]) + } + return "" +} diff --git a/backend/internal/service/sora_media_storage_test.go b/backend/internal/service/sora_media_storage_test.go new file mode 100644 index 00000000..f86234d2 --- /dev/null +++ b/backend/internal/service/sora_media_storage_test.go @@ -0,0 +1,69 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestSoraMediaStorage_StoreFromURLs(t *testing.T) { + tmpDir := t.TempDir() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "image/png") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("data")) + })) + defer server.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + MaxConcurrentDownloads: 1, + }, + }, + } + + storage := NewSoraMediaStorage(cfg) + urls, err := storage.StoreFromURLs(context.Background(), "image", []string{server.URL + "/img.png"}) + require.NoError(t, err) + require.Len(t, urls, 1) + 
require.True(t, strings.HasPrefix(urls[0], "/image/")) + require.True(t, strings.HasSuffix(urls[0], ".png")) + + localPath := filepath.Join(tmpDir, filepath.FromSlash(strings.TrimPrefix(urls[0], "/"))) + require.FileExists(t, localPath) +} + +func TestSoraMediaStorage_FallbackToUpstream(t *testing.T) { + tmpDir := t.TempDir() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + FallbackToUpstream: true, + }, + }, + } + + storage := NewSoraMediaStorage(cfg) + url := server.URL + "/broken.png" + urls, err := storage.StoreFromURLs(context.Background(), "image", []string{url}) + require.NoError(t, err) + require.Equal(t, []string{url}, urls) +} diff --git a/backend/internal/service/sora_models.go b/backend/internal/service/sora_models.go new file mode 100644 index 00000000..ab095e46 --- /dev/null +++ b/backend/internal/service/sora_models.go @@ -0,0 +1,252 @@ +package service + +import ( + "strings" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" +) + +// SoraModelConfig Sora 模型配置 +type SoraModelConfig struct { + Type string + Width int + Height int + Orientation string + Frames int + Model string + Size string + RequirePro bool +} + +var soraModelConfigs = map[string]SoraModelConfig{ + "gpt-image": { + Type: "image", + Width: 360, + Height: 360, + }, + "gpt-image-landscape": { + Type: "image", + Width: 540, + Height: 360, + }, + "gpt-image-portrait": { + Type: "image", + Width: 360, + Height: 540, + }, + "sora2-landscape-10s": { + Type: "video", + Orientation: "landscape", + Frames: 300, + Model: "sy_8", + Size: "small", + }, + "sora2-portrait-10s": { + Type: "video", + Orientation: "portrait", + Frames: 300, + Model: "sy_8", + Size: "small", + }, + 
"sora2-landscape-15s": { + Type: "video", + Orientation: "landscape", + Frames: 450, + Model: "sy_8", + Size: "small", + }, + "sora2-portrait-15s": { + Type: "video", + Orientation: "portrait", + Frames: 450, + Model: "sy_8", + Size: "small", + }, + "sora2-landscape-25s": { + Type: "video", + Orientation: "landscape", + Frames: 750, + Model: "sy_8", + Size: "small", + RequirePro: true, + }, + "sora2-portrait-25s": { + Type: "video", + Orientation: "portrait", + Frames: 750, + Model: "sy_8", + Size: "small", + RequirePro: true, + }, + "sora2pro-landscape-10s": { + Type: "video", + Orientation: "landscape", + Frames: 300, + Model: "sy_ore", + Size: "small", + RequirePro: true, + }, + "sora2pro-portrait-10s": { + Type: "video", + Orientation: "portrait", + Frames: 300, + Model: "sy_ore", + Size: "small", + RequirePro: true, + }, + "sora2pro-landscape-15s": { + Type: "video", + Orientation: "landscape", + Frames: 450, + Model: "sy_ore", + Size: "small", + RequirePro: true, + }, + "sora2pro-portrait-15s": { + Type: "video", + Orientation: "portrait", + Frames: 450, + Model: "sy_ore", + Size: "small", + RequirePro: true, + }, + "sora2pro-landscape-25s": { + Type: "video", + Orientation: "landscape", + Frames: 750, + Model: "sy_ore", + Size: "small", + RequirePro: true, + }, + "sora2pro-portrait-25s": { + Type: "video", + Orientation: "portrait", + Frames: 750, + Model: "sy_ore", + Size: "small", + RequirePro: true, + }, + "sora2pro-hd-landscape-10s": { + Type: "video", + Orientation: "landscape", + Frames: 300, + Model: "sy_ore", + Size: "large", + RequirePro: true, + }, + "sora2pro-hd-portrait-10s": { + Type: "video", + Orientation: "portrait", + Frames: 300, + Model: "sy_ore", + Size: "large", + RequirePro: true, + }, + "sora2pro-hd-landscape-15s": { + Type: "video", + Orientation: "landscape", + Frames: 450, + Model: "sy_ore", + Size: "large", + RequirePro: true, + }, + "sora2pro-hd-portrait-15s": { + Type: "video", + Orientation: "portrait", + Frames: 450, + Model: 
"sy_ore", + Size: "large", + RequirePro: true, + }, + "prompt-enhance-short-10s": { + Type: "prompt_enhance", + }, + "prompt-enhance-short-15s": { + Type: "prompt_enhance", + }, + "prompt-enhance-short-20s": { + Type: "prompt_enhance", + }, + "prompt-enhance-medium-10s": { + Type: "prompt_enhance", + }, + "prompt-enhance-medium-15s": { + Type: "prompt_enhance", + }, + "prompt-enhance-medium-20s": { + Type: "prompt_enhance", + }, + "prompt-enhance-long-10s": { + Type: "prompt_enhance", + }, + "prompt-enhance-long-15s": { + Type: "prompt_enhance", + }, + "prompt-enhance-long-20s": { + Type: "prompt_enhance", + }, +} + +var soraModelIDs = []string{ + "gpt-image", + "gpt-image-landscape", + "gpt-image-portrait", + "sora2-landscape-10s", + "sora2-portrait-10s", + "sora2-landscape-15s", + "sora2-portrait-15s", + "sora2-landscape-25s", + "sora2-portrait-25s", + "sora2pro-landscape-10s", + "sora2pro-portrait-10s", + "sora2pro-landscape-15s", + "sora2pro-portrait-15s", + "sora2pro-landscape-25s", + "sora2pro-portrait-25s", + "sora2pro-hd-landscape-10s", + "sora2pro-hd-portrait-10s", + "sora2pro-hd-landscape-15s", + "sora2pro-hd-portrait-15s", + "prompt-enhance-short-10s", + "prompt-enhance-short-15s", + "prompt-enhance-short-20s", + "prompt-enhance-medium-10s", + "prompt-enhance-medium-15s", + "prompt-enhance-medium-20s", + "prompt-enhance-long-10s", + "prompt-enhance-long-15s", + "prompt-enhance-long-20s", +} + +// GetSoraModelConfig 返回 Sora 模型配置 +func GetSoraModelConfig(model string) (SoraModelConfig, bool) { + key := strings.ToLower(strings.TrimSpace(model)) + cfg, ok := soraModelConfigs[key] + return cfg, ok +} + +// DefaultSoraModels returns the default Sora model list. 
+func DefaultSoraModels(cfg *config.Config) []openai.Model { + models := make([]openai.Model, 0, len(soraModelIDs)) + for _, id := range soraModelIDs { + models = append(models, openai.Model{ + ID: id, + Object: "model", + OwnedBy: "openai", + Type: "model", + DisplayName: id, + }) + } + if cfg != nil && cfg.Gateway.SoraModelFilters.HidePromptEnhance { + filtered := models[:0] + for _, model := range models { + if strings.HasPrefix(strings.ToLower(model.ID), "prompt-enhance") { + continue + } + filtered = append(filtered, model) + } + models = filtered + } + return models +} diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 435056ab..7dccf393 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -63,16 +63,6 @@ func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) { } } -// SetSoraSyncService 设置 Sora2API 同步服务 -// 需要在 Start() 之前调用 -func (s *TokenRefreshService) SetSoraSyncService(svc *Sora2APISyncService) { - for _, refresher := range s.refreshers { - if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok { - openaiRefresher.SetSoraSyncService(svc) - } - } -} - // Start 启动后台刷新服务 func (s *TokenRefreshService) Start() { if !s.cfg.Enabled { diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 7e084bd5..46033f75 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -86,7 +86,6 @@ type OpenAITokenRefresher struct { openaiOAuthService *OpenAIOAuthService accountRepo AccountRepository soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 - soraSyncService *Sora2APISyncService // Sora2API 同步服务 } // NewOpenAITokenRefresher 创建 OpenAI token刷新器 @@ -104,11 +103,6 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) { r.soraAccountRepo = repo } -// SetSoraSyncService 设置 
Sora2API 同步服务 -func (r *OpenAITokenRefresher) SetSoraSyncService(svc *Sora2APISyncService) { - r.soraSyncService = svc -} - // CanRefresh 检查是否能处理此账号 // 只处理 openai 平台的 oauth 类型账号 func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool { @@ -151,17 +145,6 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials) } - // 如果是 Sora 平台账号,同步到 sora2api(不阻塞主流程) - if account.Platform == PlatformSora && r.soraSyncService != nil { - syncAccount := *account - syncAccount.Credentials = newCredentials - go func() { - if err := r.soraSyncService.SyncAccount(context.Background(), &syncAccount); err != nil { - log.Printf("[TokenSync] 同步 Sora2API 失败: account_id=%d err=%v", syncAccount.ID, err) - } - }() - } - return newCredentials, nil } @@ -218,13 +201,6 @@ func (r *OpenAITokenRefresher) syncLinkedSoraAccounts(ctx context.Context, opena } } - // 2.3 同步到 sora2api(如果配置) - if r.soraSyncService != nil { - if err := r.soraSyncService.SyncAccount(ctx, &soraAccount); err != nil { - log.Printf("[TokenSync] 同步 sora2api 失败: account_id=%d err=%v", soraAccount.ID, err) - } - } - log.Printf("[TokenSync] 成功同步 Sora 账号 token: sora_account_id=%d openai_account_id=%d dual_table=%v", soraAccount.ID, openaiAccountID, r.soraAccountRepo != nil) } diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index fb0946d2..9c13be93 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -40,7 +40,6 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService { func ProvideTokenRefreshService( accountRepo AccountRepository, soraAccountRepo SoraAccountRepository, // Sora 扩展表仓储,用于双表同步 - soraSyncService *Sora2APISyncService, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, geminiOAuthService *GeminiOAuthService, @@ -51,7 +50,6 @@ func ProvideTokenRefreshService( svc := NewTokenRefreshService(accountRepo, 
oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg) // 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表 svc.SetSoraAccountRepo(soraAccountRepo) - svc.SetSoraSyncService(soraSyncService) svc.Start() return svc } @@ -187,6 +185,18 @@ func ProvideOpsCleanupService( return svc } +// ProvideSoraMediaStorage 初始化 Sora 媒体存储 +func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage { + return NewSoraMediaStorage(cfg) +} + +// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务 +func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService { + svc := NewSoraMediaCleanupService(storage, cfg) + svc.Start() + return svc +} + // ProvideOpsScheduledReportService creates and starts OpsScheduledReportService. func ProvideOpsScheduledReportService( opsService *OpsService, @@ -226,6 +236,10 @@ var ProviderSet = wire.NewSet( NewBillingCacheService, NewAdminService, NewGatewayService, + ProvideSoraMediaStorage, + ProvideSoraMediaCleanupService, + NewSoraDirectClient, + wire.Bind(new(SoraClient), new(*SoraDirectClient)), NewSoraGatewayService, NewOpenAIGatewayService, NewOAuthService, diff --git a/build_image.sh b/build_image.sh new file mode 100755 index 00000000..2cea4925 --- /dev/null +++ b/build_image.sh @@ -0,0 +1,8 @@ +#!/bin/bash +# 本地构建镜像的快速脚本,避免在命令行反复输入构建参数。 + +docker build -t sub2api:latest \ + --build-arg GOPROXY=https://goproxy.cn,direct \ + --build-arg GOSUMDB=sum.golang.google.cn \ + -f Dockerfile \ + . 
diff --git a/deploy/Dockerfile b/deploy/Dockerfile new file mode 100644 index 00000000..b3320300 --- /dev/null +++ b/deploy/Dockerfile @@ -0,0 +1,111 @@ +# ============================================================================= +# Sub2API Multi-Stage Dockerfile +# ============================================================================= +# Stage 1: Build frontend +# Stage 2: Build Go backend with embedded frontend +# Stage 3: Final minimal image +# ============================================================================= + +ARG NODE_IMAGE=node:24-alpine +ARG GOLANG_IMAGE=golang:1.25.5-alpine +ARG ALPINE_IMAGE=alpine:3.20 +ARG GOPROXY=https://goproxy.cn,direct +ARG GOSUMDB=sum.golang.google.cn + +# ----------------------------------------------------------------------------- +# Stage 1: Frontend Builder +# ----------------------------------------------------------------------------- +FROM ${NODE_IMAGE} AS frontend-builder + +WORKDIR /app/frontend + +# Install pnpm +RUN corepack enable && corepack prepare pnpm@latest --activate + +# Install dependencies first (better caching) +COPY frontend/package.json frontend/pnpm-lock.yaml ./ +RUN pnpm install --frozen-lockfile + +# Copy frontend source and build +COPY frontend/ ./ +RUN pnpm run build + +# ----------------------------------------------------------------------------- +# Stage 2: Backend Builder +# ----------------------------------------------------------------------------- +FROM ${GOLANG_IMAGE} AS backend-builder + +# Build arguments for version info (set by CI) +ARG VERSION=docker +ARG COMMIT=docker +ARG DATE +ARG GOPROXY +ARG GOSUMDB + +ENV GOPROXY=${GOPROXY} +ENV GOSUMDB=${GOSUMDB} + +# Install build dependencies +RUN apk add --no-cache git ca-certificates tzdata + +WORKDIR /app/backend + +# Copy go mod files first (better caching) +COPY backend/go.mod backend/go.sum ./ +RUN go mod download + +# Copy backend source first +COPY backend/ ./ + +# Copy frontend dist from previous stage (must be after 
backend copy to avoid being overwritten) +COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist + +# Build the binary (BuildType=release for CI builds, embed frontend) +RUN CGO_ENABLED=0 GOOS=linux go build \ + -tags embed \ + -ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \ + -o /app/sub2api \ + ./cmd/server + +# ----------------------------------------------------------------------------- +# Stage 3: Final Runtime Image +# ----------------------------------------------------------------------------- +FROM ${ALPINE_IMAGE} + +# Labels +LABEL maintainer="Wei-Shaw " +LABEL description="Sub2API - AI API Gateway Platform" +LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api" + +# Install runtime dependencies +RUN apk add --no-cache \ + ca-certificates \ + tzdata \ + curl \ + && rm -rf /var/cache/apk/* + +# Create non-root user +RUN addgroup -g 1000 sub2api && \ + adduser -u 1000 -G sub2api -s /bin/sh -D sub2api + +# Set working directory +WORKDIR /app + +# Copy binary from builder +COPY --from=backend-builder /app/sub2api /app/sub2api + +# Create data directory +RUN mkdir -p /app/data && chown -R sub2api:sub2api /app + +# Switch to non-root user +USER sub2api + +# Expose port (can be overridden by SERVER_PORT env var) +EXPOSE 8080 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \ + CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1 + +# Run the application +ENTRYPOINT ["/app/sub2api"] diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 99386fc9..2c7a1778 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -249,32 +249,64 @@ gateway: # name: "Custom Profile 2" # ============================================================================= -# Sora2API Configuration -# Sora2API 配置 +# Sora Direct Client Configuration +# Sora 直连配置 # 
============================================================================= -sora2api: - # Sora2API base URL - # Sora2API 服务地址 - base_url: "http://127.0.0.1:8000" - # Sora2API API Key (for /v1/chat/completions and /v1/models) - # Sora2API API Key(用于生成/模型列表) - api_key: "" - # Admin username/password (for token sync) - # 管理口用户名/密码(用于 token 同步) - admin_username: "admin" - admin_password: "admin" - # Admin token cache ttl (seconds) - # 管理口 token 缓存时长(秒) - admin_token_ttl_seconds: 900 - # Admin request timeout (seconds) - # 管理口请求超时(秒) - admin_timeout_seconds: 10 - # Token import mode: at/offline - # Token 导入模式:at/offline - token_import_mode: "at" - # cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196] - # curves: [29, 23, 24] - # point_formats: [0] +sora: + client: + # Sora backend base URL + # Sora 上游 Base URL + base_url: "https://sora.chatgpt.com/backend" + # Request timeout (seconds) + # 请求超时(秒) + timeout_seconds: 120 + # Max retries for upstream requests + # 上游请求最大重试次数 + max_retries: 3 + # Poll interval (seconds) + # 轮询间隔(秒) + poll_interval_seconds: 2 + # Max poll attempts + # 最大轮询次数 + max_poll_attempts: 600 + # Enable debug logs for Sora upstream requests + # 启用 Sora 直连调试日志 + debug: false + # Optional custom headers (key-value) + # 额外请求头(键值对) + headers: {} + # Default User-Agent for Sora requests + # Sora 默认 User-Agent + user_agent: "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)" + # Disable TLS fingerprint for Sora upstream + # 关闭 Sora 上游 TLS 指纹伪装 + disable_tls_fingerprint: false + storage: + # Storage type (local only for now) + # 存储类型(首发仅支持 local) + type: "local" + # Local base path; empty uses /app/data/sora + # 本地存储基础路径;为空使用 /app/data/sora + local_path: "" + # Fallback to upstream URL when download fails + # 下载失败时回退到上游 URL + fallback_to_upstream: true + # Max concurrent downloads + # 并发下载上限 + max_concurrent_downloads: 4 + # Enable debug logs for media storage + # 启用媒体存储调试日志 + debug: false + cleanup: + # Enable cleanup task + # 启用清理任务 + 
enabled: true + # Retention days + # 保留天数 + retention_days: 7 + # Cron schedule + # Cron 调度表达式 + schedule: "0 3 * * *" # ============================================================================= # API Key Auth Cache Configuration diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index 505c1419..e86f6348 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -18,7 +18,6 @@ import geminiAPI from './gemini' import antigravityAPI from './antigravity' import userAttributesAPI from './userAttributes' import opsAPI from './ops' -import modelsAPI from './models' /** * Unified admin API object for convenient access @@ -38,8 +37,7 @@ export const adminAPI = { gemini: geminiAPI, antigravity: antigravityAPI, userAttributes: userAttributesAPI, - ops: opsAPI, - models: modelsAPI + ops: opsAPI } export { @@ -57,8 +55,7 @@ export { geminiAPI, antigravityAPI, userAttributesAPI, - opsAPI, - modelsAPI + opsAPI } export default adminAPI diff --git a/frontend/src/api/admin/models.ts b/frontend/src/api/admin/models.ts deleted file mode 100644 index 897304ac..00000000 --- a/frontend/src/api/admin/models.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { apiClient } from '@/api/client' - -export async function getPlatformModels(platform: string): Promise { - const { data } = await apiClient.get('/admin/models', { - params: { platform } - }) - return data -} - -export const modelsAPI = { - getPlatformModels -} - -export default modelsAPI diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 0e81a717..30ec9e63 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1501,9 +1501,9 @@
-
diff --git a/frontend/src/components/account/ModelWhitelistSelector.vue b/frontend/src/components/account/ModelWhitelistSelector.vue index 227e6e61..16ffa225 100644 --- a/frontend/src/components/account/ModelWhitelistSelector.vue +++ b/frontend/src/components/account/ModelWhitelistSelector.vue @@ -45,19 +45,6 @@ :placeholder="t('admin.accounts.searchModels')" @click.stop /> -
- - {{ t('admin.accounts.soraModelsLoading') }} - - -
+ +
+
+ + +

{{ t('admin.accounts.upstream.baseUrlHint') }}

+
+
+ + +

{{ t('admin.accounts.leaveEmptyToKeep') }}

+
+
+
@@ -1244,6 +1268,9 @@ watch( } else { selectedErrorCodes.value = [] } + } else if (newAccount.type === 'upstream' && newAccount.credentials) { + const credentials = newAccount.credentials as Record + editBaseUrl.value = (credentials.base_url as string) || '' } else { const platformDefaultUrl = newAccount.platform === 'openai' @@ -1584,6 +1611,22 @@ const handleSubmit = async () => { return } + updatePayload.credentials = newCredentials + } else if (props.account.type === 'upstream') { + const currentCredentials = (props.account.credentials as Record) || {} + const newCredentials: Record = { ...currentCredentials } + + newCredentials.base_url = editBaseUrl.value.trim() + + if (editApiKey.value.trim()) { + newCredentials.api_key = editApiKey.value.trim() + } + + if (!applyTempUnschedConfig(newCredentials)) { + submitting.value = false + return + } + updatePayload.credentials = newCredentials } else { // For oauth/setup-token types, only update intercept_warmup_requests if changed From 1563bd3dda85e7f18058357fc8fcfdc4308c94ef Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 08:33:09 +0800 Subject: [PATCH 053/363] feat(upstream): passthrough all client headers instead of manual header setting Replace manual header setting (Content-Type, anthropic-version, anthropic-beta) with full client header passthrough in ForwardUpstream/ForwardUpstreamGemini. Only authentication headers (Authorization, x-api-key) are overridden with upstream account credentials. Hop-by-hop headers are excluded. Add unit tests covering header passthrough, auth override, and hop-by-hop filtering. 
--- .../service/antigravity_gateway_service.go | 312 ++++-------------- .../upstream_header_passthrough_test.go | 285 ++++++++++++++++ 2 files changed, 352 insertions(+), 245 deletions(-) create mode 100644 backend/internal/service/upstream_header_passthrough_test.go diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index fd53ba71..fc29eeb3 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -47,6 +47,21 @@ const ( googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED" ) +// upstreamHopByHopHeaders 透传请求头时需要排除的 hop-by-hop 头 +var upstreamHopByHopHeaders = map[string]bool{ + "connection": true, + "keep-alive": true, + "proxy-authenticate": true, + "proxy-authorization": true, + "proxy-connection": true, + "te": true, + "trailer": true, + "transfer-encoding": true, + "upgrade": true, + "host": true, + "content-length": true, +} + // antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写) // 匹配时使用 strings.Contains,无需完全匹配 var antigravityPassthroughErrorMessages = []string{ @@ -3456,10 +3471,6 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. if mappedModel == "" { return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) } - loadModel := mappedModel - thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" - mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) - quotaScope, _ := resolveAntigravityQuotaScope(originalModel) // 代理 URL proxyURL := "" @@ -3469,98 +3480,38 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. 
// 统计模型调用次数 if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, loadModel) + _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) } apiURL := baseURL + "/antigravity/v1/messages" log.Printf("%s upstream_forward url=%s model=%s", prefix, apiURL, mappedModel) - // 预检查:模型级限流 - if remaining := account.GetRateLimitRemainingTimeWithContext(ctx, originalModel); remaining > 0 { - if remaining < antigravityRateLimitThreshold { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(remaining): - } - } else { - return nil, &UpstreamFailoverError{ - StatusCode: http.StatusServiceUnavailable, - ForceCacheBilling: isStickySession, - } + // 构建请求:body 原样透传 + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request") + } + // 透传客户端所有请求头(排除 hop-by-hop 和认证头) + for key, values := range c.Request.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + req.Header.Add(key, v) } } + // 覆盖认证头 + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("x-api-key", apiKey) - // 重试循环 - var resp *http.Response - var lastErr error - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) - if err != nil { - return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request") - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) - req.Header.Set("x-api-key", apiKey) - - // 透传 anthropic headers - if v := c.GetHeader("anthropic-version"); v != "" { - req.Header.Set("anthropic-version", v) - } else { - req.Header.Set("anthropic-version", "2023-06-01") - } - if v := 
c.GetHeader("anthropic-beta"); v != "" { - req.Header.Set("anthropic-beta", v) - } - - if c != nil && len(body) > 0 { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - - resp, err = s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - lastErr = err - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - return nil, ctx.Err() - } - continue - } - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") - } - - // 429/503 重试 - if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) - - if attempt < antigravityMaxRetries { - log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - return nil, ctx.Err() - } - continue - } - - return nil, &UpstreamFailoverError{ - StatusCode: resp.StatusCode, - ForceCacheBilling: isStickySession, - } - } - - break // 成功或非限流错误,跳出重试 + if c != nil && len(body) > 0 { + c.Set(OpsUpstreamRequestBodyKey, string(body)) } - if resp == nil { - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("upstream request failed: %v", lastErr)) + + // 单次发送,不重试 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("Upstream request failed: %v", err)) } defer func() { _ = resp.Body.Close() }() @@ -3568,44 +3519,7 @@ func (s *AntigravityGatewayService) 
ForwardUpstream(ctx context.Context, c *gin. if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - // signature 重试 - if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) { - log.Printf("%s upstream signature error, retrying with thinking stripped", prefix) - retryClaudeReq := claudeReq - retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...) - if stripped, stripErr := stripThinkingFromClaudeRequest(&retryClaudeReq); stripErr == nil && stripped { - retryBody, _ := json.Marshal(&retryClaudeReq) - retryReq, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(retryBody)) - if err == nil { - retryReq.Header.Set("Content-Type", "application/json") - retryReq.Header.Set("Authorization", "Bearer "+apiKey) - retryReq.Header.Set("x-api-key", apiKey) - retryReq.Header.Set("anthropic-version", "2023-06-01") - if v := c.GetHeader("anthropic-beta"); v != "" { - retryReq.Header.Set("anthropic-beta", v) - } - retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) - if retryErr == nil && retryResp != nil && retryResp.StatusCode < 400 { - resp = retryResp - goto upstreamClaudeSuccess - } - if retryResp != nil { - _ = retryResp.Body.Close() - } - } - } - } - - // prompt too long - if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) { - return nil, &PromptTooLongError{ - StatusCode: resp.StatusCode, - RequestID: resp.Header.Get("x-request-id"), - Body: respBody, - } - } - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession) if s.shouldFailoverUpstreamError(resp.StatusCode) { return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} @@ -3614,7 +3528,7 @@ func (s 
*AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) } -upstreamClaudeSuccess: + // 成功响应 requestID := resp.Header.Get("x-request-id") if requestID != "" { c.Header("x-request-id", requestID) @@ -3674,7 +3588,6 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c if len(body) == 0 { return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") } - quotaScope, _ := resolveAntigravityQuotaScope(originalModel) imageSize := s.extractImageSize(body) @@ -3712,143 +3625,52 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c } // 构建 upstream URL: base_url + /antigravity/v1beta/models/MODEL:ACTION - upstreamAction := action - if action == "generateContent" && !stream { - // 非流式也用 streamGenerateContent,与 OAuth 路径行为一致 - upstreamAction = action - } - apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction) - if stream || upstreamAction == "streamGenerateContent" { + apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, action) + if stream || action == "streamGenerateContent" { apiURL += "?alt=sse" } - log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, upstreamAction) + log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, action) - // 预检查:模型级限流 - if remaining := account.GetRateLimitRemainingTimeWithContext(ctx, originalModel); remaining > 0 { - if remaining < antigravityRateLimitThreshold { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(remaining): - } - } else { - return nil, &UpstreamFailoverError{ - StatusCode: http.StatusServiceUnavailable, - ForceCacheBilling: isStickySession, - } + // 构建请求:body 原样透传 + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, 
bytes.NewReader(body)) + if err != nil { + return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request") + } + // 透传客户端所有请求头(排除 hop-by-hop 和认证头) + for key, values := range c.Request.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + req.Header.Add(key, v) } } + // 覆盖认证头 + req.Header.Set("Authorization", "Bearer "+apiKey) - // 重试循环 - var resp *http.Response - var lastErr error - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) - if err != nil { - return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request") - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) - - if c != nil && len(body) > 0 { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - - resp, err = s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - lastErr = err - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - return nil, ctx.Err() - } - continue - } - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") - } - - // 429/503 重试 - if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) - - if attempt < antigravityMaxRetries { - log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) - if 
!sleepAntigravityBackoffWithContext(ctx, attempt) { - return nil, ctx.Err() - } - continue - } - - return nil, &UpstreamFailoverError{ - StatusCode: resp.StatusCode, - ForceCacheBilling: isStickySession, - } - } - - break + if c != nil && len(body) > 0 { + c.Set(OpsUpstreamRequestBodyKey, string(body)) } - if resp == nil { - return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("upstream request failed: %v", lastErr)) + + // 单次发送,不重试 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("Upstream request failed: %v", err)) } - defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() - } - }() + defer func() { _ = resp.Body.Close() }() // 错误响应处理 if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) contentType := resp.Header.Get("Content-Type") - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) - - // 模型兜底 - if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) && - isModelNotFoundError(resp.StatusCode, respBody) { - fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity) - if fallbackModel != "" && fallbackModel != mappedModel { - log.Printf("[Antigravity-Upstream] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name) - fallbackURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, fallbackModel, upstreamAction) - if stream || upstreamAction == "streamGenerateContent" { - fallbackURL += "?alt=sse" - } - fallbackReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fallbackURL, bytes.NewReader(body)) - if err == nil { - fallbackReq.Header.Set("Content-Type", "application/json") - fallbackReq.Header.Set("Authorization", "Bearer "+apiKey) - fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency) - if err 
== nil && fallbackResp.StatusCode < 400 { - _ = resp.Body.Close() - resp = fallbackResp - } else if fallbackResp != nil { - _ = fallbackResp.Body.Close() - } - } - } - } - - // fallback 成功 - if resp.StatusCode < 400 { - goto upstreamGeminiSuccess - } requestID := resp.Header.Get("x-request-id") if requestID != "" { c.Header("x-request-id", requestID) } - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession) upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamDetail := s.getUpstreamErrorDetail(respBody) @@ -3886,7 +3708,7 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) } -upstreamGeminiSuccess: + // 成功响应 requestID := resp.Header.Get("x-request-id") if requestID != "" { c.Header("x-request-id", requestID) diff --git a/backend/internal/service/upstream_header_passthrough_test.go b/backend/internal/service/upstream_header_passthrough_test.go new file mode 100644 index 00000000..51d8588b --- /dev/null +++ b/backend/internal/service/upstream_header_passthrough_test.go @@ -0,0 +1,285 @@ +//go:build unit + +package service + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// httpUpstreamCapture captures the outgoing *http.Request for assertion. 
+type httpUpstreamCapture struct { + capturedReq *http.Request + resp *http.Response + err error +} + +func (s *httpUpstreamCapture) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + s.capturedReq = req + return s.resp, s.err +} + +func (s *httpUpstreamCapture) DoWithTLS(req *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) { + s.capturedReq = req + return s.resp, s.err +} + +func newUpstreamAccount() *Account { + return &Account{ + ID: 100, + Name: "upstream-test", + Platform: PlatformAntigravity, + Type: AccountTypeUpstream, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "base_url": "https://upstream.example.com", + "api_key": "sk-upstream-secret", + }, + } +} + +// makeSSEOKResponse builds a minimal SSE response that +// handleClaudeStreamingResponse / handleGeminiStreamingResponse +// can consume without error. +// We return 502 to bypass streaming and hit the error branch instead, +// which is sufficient for testing header passthrough. 
+func makeUpstreamErrorResponse() *http.Response { + body := []byte(`{"error":{"message":"test error"}}`) + return &http.Response{ + StatusCode: http.StatusBadGateway, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewReader(body)), + } +} + +// --- ForwardUpstream tests --- + +func TestForwardUpstream_PassthroughHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{{"role": "user", "content": "hi"}}, + "max_tokens": 1, + "stream": false, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("anthropic-version", "2024-10-22") + req.Header.Set("anthropic-beta", "output-128k-2025-02-19") + req.Header.Set("X-Custom-Header", "custom-value") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) + + captured := stub.capturedReq + require.NotNil(t, captured, "upstream request should have been made") + + // 客户端 header 应被透传 + require.Equal(t, "application/json", captured.Header.Get("Content-Type")) + require.Equal(t, "2024-10-22", captured.Header.Get("anthropic-version")) + require.Equal(t, "output-128k-2025-02-19", captured.Header.Get("anthropic-beta")) + require.Equal(t, "custom-value", captured.Header.Get("X-Custom-Header")) +} + +func TestForwardUpstream_OverridesAuthHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{{"role": "user", "content": "hi"}}, 
+ "max_tokens": 1, + "stream": false, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + // 客户端发来的认证头应被覆盖 + req.Header.Set("Authorization", "Bearer client-token") + req.Header.Set("x-api-key", "client-api-key") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) + + captured := stub.capturedReq + require.NotNil(t, captured) + + // 认证头应使用上游账号的 api_key,而非客户端的 + require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization")) + require.Equal(t, "sk-upstream-secret", captured.Header.Get("x-api-key")) +} + +func TestForwardUpstream_ExcludesHopByHopHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{{"role": "user", "content": "hi"}}, + "max_tokens": 1, + "stream": false, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Connection", "keep-alive") + req.Header.Set("Keep-Alive", "timeout=5") + req.Header.Set("Transfer-Encoding", "chunked") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Te", "trailers") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) + + captured := stub.capturedReq + require.NotNil(t, captured) + + // hop-by-hop header 不应出现 + require.Empty(t, captured.Header.Get("Connection")) + 
require.Empty(t, captured.Header.Get("Keep-Alive")) + require.Empty(t, captured.Header.Get("Transfer-Encoding")) + require.Empty(t, captured.Header.Get("Upgrade")) + require.Empty(t, captured.Header.Get("Te")) + + // 但普通 header 应保留 + require.Equal(t, "application/json", captured.Header.Get("Content-Type")) +} + +// --- ForwardUpstreamGemini tests --- + +func TestForwardUpstreamGemini_PassthroughHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, + }, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Custom-Gemini", "gemini-value") + req.Header.Set("X-Request-Id", "req-abc-123") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) + + captured := stub.capturedReq + require.NotNil(t, captured, "upstream request should have been made") + + // 客户端 header 应被透传 + require.Equal(t, "application/json", captured.Header.Get("Content-Type")) + require.Equal(t, "gemini-value", captured.Header.Get("X-Custom-Gemini")) + require.Equal(t, "req-abc-123", captured.Header.Get("X-Request-Id")) +} + +func TestForwardUpstreamGemini_OverridesAuthHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, + }, + }) + + req := httptest.NewRequest(http.MethodPost, 
"/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer client-gemini-token") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) + + captured := stub.capturedReq + require.NotNil(t, captured) + + // 认证头应使用上游账号的 api_key + require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization")) +} + +func TestForwardUpstreamGemini_ExcludesHopByHopHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body, _ := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, + }, + }) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Connection", "keep-alive") + req.Header.Set("Proxy-Authorization", "Basic dXNlcjpwYXNz") + req.Header.Set("Host", "evil.example.com") + c.Request = req + + stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: stub, + } + + _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) + + captured := stub.capturedReq + require.NotNil(t, captured) + + // hop-by-hop header 不应出现 + require.Empty(t, captured.Header.Get("Connection")) + require.Empty(t, captured.Header.Get("Proxy-Authorization")) + // Host header 在 Go http.Request 中特殊处理,但我们的黑名单应阻止透传 + require.Empty(t, 
captured.Header.Values("Host")) + + // 普通 header 应保留 + require.Equal(t, "application/json", captured.Header.Get("Content-Type")) +} From 4f57d7f76188f2c767060c37d516ceb3fb05cdfe Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 08:36:35 +0800 Subject: [PATCH 054/363] fix: add nil guard for gin.Context in header passthrough to satisfy staticcheck SA5011 --- .../service/antigravity_gateway_service.go | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index fc29eeb3..c2983c47 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -3492,12 +3492,14 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request") } // 透传客户端所有请求头(排除 hop-by-hop 和认证头) - for key, values := range c.Request.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - req.Header.Add(key, v) + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + req.Header.Add(key, v) + } } } // 覆盖认证头 @@ -3638,12 +3640,14 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request") } // 透传客户端所有请求头(排除 hop-by-hop 和认证头) - for key, values := range c.Request.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - req.Header.Add(key, v) + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + 
req.Header.Add(key, v) + } } } // 覆盖认证头 From 6ab77f5eb5afceb99eb32bba011261866bf6cf14 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 08:49:43 +0800 Subject: [PATCH 055/363] fix(upstream): passthrough response body directly instead of parsing SSE ForwardUpstream/ForwardUpstreamGemini should pipe the upstream response directly to the client (headers + body), not parse it as SSE stream. --- .../service/antigravity_gateway_service.go | 99 +++++++------------ 1 file changed, 38 insertions(+), 61 deletions(-) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index c2983c47..2d96b1ab 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -3530,39 +3530,30 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) } - // 成功响应 + // 成功响应:透传 response header + body requestID := resp.Header.Get("x-request-id") - if requestID != "" { - c.Header("x-request-id", requestID) + + // 透传上游响应头(排除 hop-by-hop) + for key, values := range resp.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + c.Header(key, v) + } } - var usage *ClaudeUsage - var firstTokenMs *int - if claudeReq.Stream { - streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel) - if err != nil { - log.Printf("%s status=stream_error error=%v", prefix, err) - return nil, err - } - usage = streamRes.usage - firstTokenMs = streamRes.firstTokenMs - } else { - streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel) - if err != nil { - log.Printf("%s status=stream_collect_error error=%v", prefix, err) - return nil, err - } - usage = streamRes.usage - firstTokenMs = streamRes.firstTokenMs + c.Status(resp.StatusCode) + _, 
copyErr := io.Copy(c.Writer, resp.Body) + if copyErr != nil { + log.Printf("%s status=copy_error error=%v", prefix, copyErr) } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, - Stream: claudeReq.Stream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: requestID, + Model: originalModel, + Stream: claudeReq.Stream, + Duration: time.Since(startTime), }, nil } @@ -3712,35 +3703,23 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) } - // 成功响应 + // 成功响应:透传 response header + body requestID := resp.Header.Get("x-request-id") - if requestID != "" { - c.Header("x-request-id", requestID) + + // 透传上游响应头(排除 hop-by-hop) + for key, values := range resp.Header { + if upstreamHopByHopHeaders[strings.ToLower(key)] { + continue + } + for _, v := range values { + c.Header(key, v) + } } - var usage *ClaudeUsage - var firstTokenMs *int - - if stream { - streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime) - if err != nil { - log.Printf("%s status=stream_error error=%v", prefix, err) - return nil, err - } - usage = streamRes.usage - firstTokenMs = streamRes.firstTokenMs - } else { - streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime) - if err != nil { - log.Printf("%s status=stream_collect_error error=%v", prefix, err) - return nil, err - } - usage = streamRes.usage - firstTokenMs = streamRes.firstTokenMs - } - - if usage == nil { - usage = &ClaudeUsage{} + c.Status(resp.StatusCode) + _, copyErr := io.Copy(c.Writer, resp.Body) + if copyErr != nil { + log.Printf("%s status=copy_error error=%v", prefix, copyErr) } imageCount := 0 @@ -3749,13 +3728,11 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, - Stream: stream, - Duration: time.Since(startTime), - 
FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: imageSize, + RequestID: requestID, + Model: originalModel, + Stream: stream, + Duration: time.Since(startTime), + ImageCount: imageCount, + ImageSize: imageSize, }, nil } From bb5a5dd65eab240c03b6f8f457f7d8453d84fdc0 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sun, 8 Feb 2026 12:05:39 +0800 Subject: [PATCH 056/363] =?UTF-8?q?test:=20=E5=AE=8C=E5=96=84=E8=87=AA?= =?UTF-8?q?=E5=8A=A8=E5=8C=96=E6=B5=8B=E8=AF=95=E4=BD=93=E7=B3=BB=EF=BC=88?= =?UTF-8?q?7=E4=B8=AA=E6=A8=A1=E5=9D=97=EF=BC=8C73=E4=B8=AA=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 系统性地修复、补充和强化项目的自动化测试能力: 1. 测试基础设施修复 - 修复 stubConcurrencyCache 缺失方法和构造函数参数不匹配 - 创建 testutil 共享包(stubs.go, fixtures.go, httptest.go) - 为所有 Stub 添加编译期接口断言 2. 中间件测试补充 - 新增 JWT 认证中间件测试(有效/过期/篡改/缺失 Token) - 补充 rate_limiter 和 recovery 中间件测试场景 3. 网关核心路径测试 - 新增账户选择、等待队列、流式响应、并发控制、计费、Claude Code 检测测试 - 覆盖负载均衡、粘性会话、SSE 转发、槽位管理等关键逻辑 4. 前端测试体系(11个新测试文件,163个测试用例) - Pinia stores: auth, app, subscriptions - API client: 请求拦截器、响应拦截器、401 刷新 - Router guards: 认证重定向、管理员权限、简易模式限制 - Composables: useForm, useTableLoader, useClipboard - Components: LoginForm, ApiKeyCreate, Dashboard 5. CI/CD 流水线重构 - 重构 backend-ci.yml 为统一的 ci.yml - 前后端 4 个并行 Job + Postgres/Redis services - Race 检测、覆盖率收集与门禁、Docker 构建验证 6. E2E 自动化测试 - e2e-test.sh 自动化脚本(Docker 启动→健康检查→测试→清理) - 用户注册→登录→API Key→网关调用完整链路测试 - Mock 模式和 API Key 脱敏支持 7. 
修复预存问题 - tlsfingerprint dialer_test.go 缺失 build tag 导致集成测试编译冲突 Co-Authored-By: Claude Opus 4.6 --- .github/workflows/backend-ci.yml | 47 --- .github/workflows/ci.yml | 179 ++++++++++ backend/Makefile | 5 +- .../admin/batch_update_credentials_test.go | 18 +- .../handler/sora_gateway_handler_test.go | 69 ++-- .../internal/integration/e2e_gateway_test.go | 96 ++++-- .../internal/integration/e2e_helpers_test.go | 48 +++ .../integration/e2e_user_flow_test.go | 317 +++++++++++++++++ .../internal/middleware/rate_limiter_test.go | 43 +++ .../pkg/tlsfingerprint/dialer_test.go | 20 +- .../pkg/tlsfingerprint/test_types_test.go | 20 ++ .../repository/billing_cache_jitter_test.go | 2 +- backend/internal/server/api_contract_test.go | 2 +- .../server/middleware/jwt_auth_test.go | 234 +++++++++++++ .../server/middleware/recovery_test.go | 29 ++ .../service/antigravity_rate_limit_test.go | 6 + .../internal/service/billing_service_test.go | 310 +++++++++++++++++ .../service/claude_code_detection_test.go | 282 +++++++++++++++ .../service/concurrency_service_test.go | 280 +++++++++++++++ .../service/gateway_account_selection_test.go | 198 +++++++++++ .../service/gateway_streaming_test.go | 203 +++++++++++ .../service/gateway_waiting_queue_test.go | 120 +++++++ .../service/openai_gateway_service_test.go | 4 + .../ops_alert_evaluator_service_test.go | 2 + .../service/sora_gateway_service_test.go | 2 + .../subscription_calculate_progress_test.go | 2 +- backend/internal/testutil/fixtures.go | 78 +++++ backend/internal/testutil/httptest.go | 35 ++ backend/internal/testutil/stubs.go | 137 ++++++++ frontend/src/api/__tests__/client.spec.ts | 208 +++++++++++ .../components/__tests__/ApiKeyCreate.spec.ts | 184 ++++++++++ .../components/__tests__/Dashboard.spec.ts | 173 ++++++++++ .../components/__tests__/LoginForm.spec.ts | 178 ++++++++++ .../__tests__/useClipboard.spec.ts | 137 ++++++++ .../src/composables/__tests__/useForm.spec.ts | 143 ++++++++ .../__tests__/useTableLoader.spec.ts | 252 
++++++++++++++ frontend/src/router/__tests__/guards.spec.ts | 324 ++++++++++++++++++ frontend/src/stores/__tests__/app.spec.ts | 293 ++++++++++++++++ frontend/src/stores/__tests__/auth.spec.ts | 289 ++++++++++++++++ .../stores/__tests__/subscriptions.spec.ts | 239 +++++++++++++ frontend/vitest.config.ts | 75 ++-- 41 files changed, 5101 insertions(+), 182 deletions(-) delete mode 100644 .github/workflows/backend-ci.yml create mode 100644 .github/workflows/ci.yml create mode 100644 backend/internal/integration/e2e_helpers_test.go create mode 100644 backend/internal/integration/e2e_user_flow_test.go create mode 100644 backend/internal/pkg/tlsfingerprint/test_types_test.go create mode 100644 backend/internal/server/middleware/jwt_auth_test.go create mode 100644 backend/internal/service/billing_service_test.go create mode 100644 backend/internal/service/claude_code_detection_test.go create mode 100644 backend/internal/service/concurrency_service_test.go create mode 100644 backend/internal/service/gateway_account_selection_test.go create mode 100644 backend/internal/service/gateway_streaming_test.go create mode 100644 backend/internal/service/gateway_waiting_queue_test.go create mode 100644 backend/internal/testutil/fixtures.go create mode 100644 backend/internal/testutil/httptest.go create mode 100644 backend/internal/testutil/stubs.go create mode 100644 frontend/src/api/__tests__/client.spec.ts create mode 100644 frontend/src/components/__tests__/ApiKeyCreate.spec.ts create mode 100644 frontend/src/components/__tests__/Dashboard.spec.ts create mode 100644 frontend/src/components/__tests__/LoginForm.spec.ts create mode 100644 frontend/src/composables/__tests__/useClipboard.spec.ts create mode 100644 frontend/src/composables/__tests__/useForm.spec.ts create mode 100644 frontend/src/composables/__tests__/useTableLoader.spec.ts create mode 100644 frontend/src/router/__tests__/guards.spec.ts create mode 100644 frontend/src/stores/__tests__/app.spec.ts create mode 100644 
frontend/src/stores/__tests__/auth.spec.ts create mode 100644 frontend/src/stores/__tests__/subscriptions.spec.ts diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml deleted file mode 100644 index 2596a18c..00000000 --- a/.github/workflows/backend-ci.yml +++ /dev/null @@ -1,47 +0,0 @@ -name: CI - -on: - push: - pull_request: - -permissions: - contents: read - -jobs: - test: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version-file: backend/go.mod - check-latest: false - cache: true - - name: Verify Go version - run: | - go version | grep -q 'go1.25.7' - - name: Unit tests - working-directory: backend - run: make test-unit - - name: Integration tests - working-directory: backend - run: make test-integration - - golangci-lint: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version-file: backend/go.mod - check-latest: false - cache: true - - name: Verify Go version - run: | - go version | grep -q 'go1.25.7' - - name: golangci-lint - uses: golangci/golangci-lint-action@v9 - with: - version: v2.7 - args: --timeout=5m - working-directory: backend diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..03e7159f --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,179 @@ +name: CI + +on: + push: + pull_request: + +permissions: + contents: read + +jobs: + # ========================================================================== + # 后端测试(与前端并行运行) + # ========================================================================== + backend-test: + runs-on: ubuntu-latest + services: + postgres: + image: postgres:16-alpine + env: + POSTGRES_USER: test + POSTGRES_PASSWORD: test + POSTGRES_DB: sub2api_test + ports: + - 5432:5432 + options: >- + --health-cmd "pg_isready -U test" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + redis: + image: redis:7-alpine + ports: 
+ - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: backend/go.mod + check-latest: false + cache: true + + - name: 验证 Go 版本 + run: go version | grep -q 'go1.25.7' + + - name: 单元测试 + working-directory: backend + run: make test-unit + + - name: 集成测试 + working-directory: backend + env: + DATABASE_URL: postgres://test:test@localhost:5432/sub2api_test?sslmode=disable + REDIS_URL: redis://localhost:6379/0 + run: make test-integration + + - name: Race 检测 + working-directory: backend + run: go test -tags=unit -race -count=1 ./... + + - name: 覆盖率收集 + working-directory: backend + run: | + go test -tags=unit -coverprofile=coverage.out -count=1 ./... + echo "## 后端测试覆盖率" >> $GITHUB_STEP_SUMMARY + echo '```' >> $GITHUB_STEP_SUMMARY + go tool cover -func=coverage.out | tail -1 >> $GITHUB_STEP_SUMMARY + echo '```' >> $GITHUB_STEP_SUMMARY + + - name: 覆盖率门禁(≥8%) + working-directory: backend + run: | + COVERAGE=$(go tool cover -func=coverage.out | tail -1 | awk '{print $3}' | sed 's/%//') + echo "当前覆盖率: ${COVERAGE}%" + if [ "$(echo "$COVERAGE < 8" | bc -l)" -eq 1 ]; then + echo "::error::后端覆盖率 ${COVERAGE}% 低于门禁值 8%" + exit 1 + fi + + # ========================================================================== + # 后端代码检查 + # ========================================================================== + golangci-lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: backend/go.mod + check-latest: false + cache: true + - name: 验证 Go 版本 + run: go version | grep -q 'go1.25.7' + - name: golangci-lint + uses: golangci/golangci-lint-action@v9 + with: + version: v2.7 + args: --timeout=5m + working-directory: backend + + # ========================================================================== + # 前端测试(与后端并行运行) + # 
========================================================================== + frontend-test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: 安装 pnpm + uses: pnpm/action-setup@v4 + with: + version: 9 + - name: 安装 Node.js + uses: actions/setup-node@v4 + with: + node-version: '20' + cache: 'pnpm' + cache-dependency-path: frontend/pnpm-lock.yaml + - name: 安装依赖 + working-directory: frontend + run: pnpm install --frozen-lockfile + + - name: 类型检查 + working-directory: frontend + run: pnpm run typecheck + + - name: Lint 检查 + working-directory: frontend + run: pnpm run lint:check + + - name: 单元测试 + working-directory: frontend + run: pnpm run test:run + + - name: 覆盖率收集 + working-directory: frontend + run: | + pnpm run test:coverage -- --exclude '**/integration/**' || true + echo "## 前端测试覆盖率" >> $GITHUB_STEP_SUMMARY + if [ -f coverage/coverage-final.json ]; then + echo "覆盖率报告已生成" >> $GITHUB_STEP_SUMMARY + fi + + - name: 覆盖率门禁(≥20%) + working-directory: frontend + run: | + if [ ! -f coverage/coverage-final.json ]; then + echo "::warning::覆盖率报告未生成,跳过门禁检查" + exit 0 + fi + # 使用 node 解析覆盖率 JSON + COVERAGE=$(node -e " + const data = require('./coverage/coverage-final.json'); + let totalStatements = 0, coveredStatements = 0; + for (const file of Object.values(data)) { + const stmts = file.s; + totalStatements += Object.keys(stmts).length; + coveredStatements += Object.values(stmts).filter(v => v > 0).length; + } + const pct = totalStatements > 0 ? (coveredStatements / totalStatements * 100) : 0; + console.log(pct.toFixed(1)); + ") + echo "当前前端覆盖率: ${COVERAGE}%" + if [ "$(echo "$COVERAGE < 20" | bc -l 2>/dev/null || node -e "console.log($COVERAGE < 20 ? 
1 : 0)")" = "1" ]; then + echo "::warning::前端覆盖率 ${COVERAGE}% 低于门禁值 20%(当前为警告,不阻塞)" + fi + + # ========================================================================== + # Docker 构建验证 + # ========================================================================== + docker-build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Docker 构建验证 + run: docker build -t aicodex2api:ci-test . diff --git a/backend/Makefile b/backend/Makefile index 6a5d2caa..89db1104 100644 --- a/backend/Makefile +++ b/backend/Makefile @@ -14,4 +14,7 @@ test-integration: go test -tags=integration ./... test-e2e: - go test -tags=e2e ./... + ./scripts/e2e-test.sh + +test-e2e-local: + go test -tags=e2e -v -timeout=300s ./internal/integration/... diff --git a/backend/internal/handler/admin/batch_update_credentials_test.go b/backend/internal/handler/admin/batch_update_credentials_test.go index 4c47fadb..c8185735 100644 --- a/backend/internal/handler/admin/batch_update_credentials_test.go +++ b/backend/internal/handler/admin/batch_update_credentials_test.go @@ -60,7 +60,7 @@ func TestBatchUpdateCredentials_AllSuccess(t *testing.T) { require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount") } -func TestBatchUpdateCredentials_FailFast(t *testing.T) { +func TestBatchUpdateCredentials_PartialFailure(t *testing.T) { // 让第 2 个账号(ID=2)更新时失败 svc := &failingAdminService{ stubAdminService: newStubAdminService(), @@ -79,10 +79,18 @@ func TestBatchUpdateCredentials_FailFast(t *testing.T) { req.Header.Set("Content-Type", "application/json") router.ServeHTTP(w, req) - require.Equal(t, http.StatusInternalServerError, w.Code, "ID=2 失败时应返回 500") - // 验证 fail-fast:ID=1 更新成功,ID=2 失败,ID=3 不应被调用 - require.Equal(t, int64(2), svc.updateCallCount.Load(), - "fail-fast: 应只调用 2 次 UpdateAccount(ID=1 成功、ID=2 失败后停止)") + // 实现采用"部分成功"模式:总是返回 200 + 成功/失败明细 + require.Equal(t, http.StatusOK, w.Code, "批量更新返回 200 + 成功/失败明细") + + var resp map[string]any + require.NoError(t, 
json.Unmarshal(w.Body.Bytes(), &resp)) + data := resp["data"].(map[string]any) + require.Equal(t, float64(2), data["success"], "应有 2 个成功") + require.Equal(t, float64(1), data["failed"], "应有 1 个失败") + + // 所有 3 个账号都会被尝试更新(非 fail-fast) + require.Equal(t, int64(3), svc.updateCallCount.Load(), + "应调用 3 次 UpdateAccount(逐个尝试,失败后继续)") } func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) { diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 91881dec..ba266d5c 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -16,10 +16,17 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/Wei-Shaw/sub2api/internal/testutil" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) +// 编译期接口断言 +var _ service.SoraClient = (*stubSoraClient)(nil) +var _ service.AccountRepository = (*stubAccountRepo)(nil) +var _ service.GroupRepository = (*stubGroupRepo)(nil) +var _ service.UsageLogRepository = (*stubUsageLogRepo)(nil) + type stubSoraClient struct { imageURLs []string } @@ -41,52 +48,6 @@ func (s *stubSoraClient) GetVideoTask(ctx context.Context, account *service.Acco return &service.SoraVideoTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil } -type stubConcurrencyCache struct{} - -func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { - return true, nil -} -func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { - return nil -} -func (c stubConcurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { - return 0, nil -} -func (c stubConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, 
accountID int64, maxWait int) (bool, error) { - return true, nil -} -func (c stubConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { - return nil -} -func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { - return 0, nil -} -func (c stubConcurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { - return true, nil -} -func (c stubConcurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error { - return nil -} -func (c stubConcurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { - return 0, nil -} -func (c stubConcurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { - return true, nil -} -func (c stubConcurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error { - return nil -} -func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { - result := make(map[int64]*service.AccountLoadInfo, len(accounts)) - for _, acc := range accounts { - result[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID, LoadRate: 0} - } - return result, nil -} -func (c stubConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { - return nil -} - type stubAccountRepo struct { accounts map[int64]*service.Account } @@ -260,6 +221,12 @@ func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, nil } +func (r *stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) { + return nil, nil +} +func (r *stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error { + return nil +} type 
stubUsageLogRepo struct{} @@ -312,15 +279,18 @@ func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, e func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) { return nil, nil } -func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { +func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { return nil, nil } -func (s *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { +func (s *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { return nil, nil } func (s *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) { return nil, nil } +func (s *stubUsageLogRepo) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) { + return nil, nil +} func (s *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) { return nil, nil } @@ -384,7 +354,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { usageLogRepo := &stubUsageLogRepo{} deferredService := service.NewDeferredService(accountRepo, nil, 0) billingService := service.NewBillingService(cfg, nil) - concurrencyService := service.NewConcurrencyService(stubConcurrencyCache{}) + concurrencyService := service.NewConcurrencyService(testutil.StubConcurrencyCache{}) billingCacheService := service.NewBillingCacheService(nil, nil, nil, cfg) t.Cleanup(func() { 
billingCacheService.Stop() @@ -397,6 +367,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { nil, nil, nil, + nil, cfg, nil, concurrencyService, diff --git a/backend/internal/integration/e2e_gateway_test.go b/backend/internal/integration/e2e_gateway_test.go index ec0b29f7..8ee3f22e 100644 --- a/backend/internal/integration/e2e_gateway_test.go +++ b/backend/internal/integration/e2e_gateway_test.go @@ -21,11 +21,18 @@ var ( // - "" (默认): 使用 /v1/messages, /v1beta/models(混合模式,可调度 antigravity 账户) // - "/antigravity": 使用 /antigravity/v1/messages, /antigravity/v1beta/models(非混合模式,仅 antigravity 账户) endpointPrefix = getEnv("ENDPOINT_PREFIX", "") - claudeAPIKey = "sk-8e572bc3b3de92ace4f41f4256c28600ca11805732a7b693b5c44741346bbbb3" - geminiAPIKey = "sk-5950197a2085b38bbe5a1b229cc02b8ece914963fc44cacc06d497ae8b87410f" testInterval = 1 * time.Second // 测试间隔,防止限流 ) +const ( + // 注意:E2E 测试请使用环境变量注入密钥,避免任何凭证进入仓库历史。 + // 例如: + // export CLAUDE_API_KEY="sk-..." + // export GEMINI_API_KEY="sk-..." 
+ claudeAPIKeyEnv = "CLAUDE_API_KEY" + geminiAPIKeyEnv = "GEMINI_API_KEY" +) + func getEnv(key, defaultVal string) string { if v := os.Getenv(key); v != "" { return v @@ -65,16 +72,45 @@ func TestMain(m *testing.M) { if endpointPrefix != "" { mode = "Antigravity 模式" } - fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s)\n\n", baseURL, endpointPrefix, mode) + claudeKeySet := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) != "" + geminiKeySet := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) != "" + fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s, %s=%v, %s=%v)\n\n", + baseURL, + endpointPrefix, + mode, + claudeAPIKeyEnv, + claudeKeySet, + geminiAPIKeyEnv, + geminiKeySet, + ) os.Exit(m.Run()) } +func requireClaudeAPIKey(t *testing.T) string { + t.Helper() + key := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) + if key == "" { + t.Skipf("未设置 %s,跳过 Claude 相关 E2E 测试", claudeAPIKeyEnv) + } + return key +} + +func requireGeminiAPIKey(t *testing.T) string { + t.Helper() + key := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) + if key == "" { + t.Skipf("未设置 %s,跳过 Gemini 相关 E2E 测试", geminiAPIKeyEnv) + } + return key +} + // TestClaudeModelsList 测试 GET /v1/models func TestClaudeModelsList(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) url := baseURL + endpointPrefix + "/v1/models" req, _ := http.NewRequest("GET", url, nil) - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) @@ -106,10 +142,11 @@ func TestClaudeModelsList(t *testing.T) { // TestGeminiModelsList 测试 GET /v1beta/models func TestGeminiModelsList(t *testing.T) { + geminiKey := requireGeminiAPIKey(t) url := baseURL + endpointPrefix + "/v1beta/models" req, _ := http.NewRequest("GET", url, nil) - req.Header.Set("Authorization", "Bearer "+geminiAPIKey) + req.Header.Set("Authorization", "Bearer "+geminiKey) client := &http.Client{Timeout: 30 * time.Second} resp, err := 
client.Do(req) @@ -137,21 +174,22 @@ func TestGeminiModelsList(t *testing.T) { // TestClaudeMessages 测试 Claude /v1/messages 接口 func TestClaudeMessages(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) for i, model := range claudeModels { if i > 0 { time.Sleep(testInterval) } t.Run(model+"_非流式", func(t *testing.T) { - testClaudeMessage(t, model, false) + testClaudeMessage(t, claudeKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_流式", func(t *testing.T) { - testClaudeMessage(t, model, true) + testClaudeMessage(t, claudeKey, model, true) }) } } -func testClaudeMessage(t *testing.T, model string, stream bool) { +func testClaudeMessage(t *testing.T, claudeKey string, model string, stream bool) { url := baseURL + endpointPrefix + "/v1/messages" payload := map[string]any{ @@ -166,7 +204,7 @@ func testClaudeMessage(t *testing.T, model string, stream bool) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -213,21 +251,22 @@ func testClaudeMessage(t *testing.T, model string, stream bool) { // TestGeminiGenerateContent 测试 Gemini /v1beta/models/:model 接口 func TestGeminiGenerateContent(t *testing.T) { + geminiKey := requireGeminiAPIKey(t) for i, model := range geminiModels { if i > 0 { time.Sleep(testInterval) } t.Run(model+"_非流式", func(t *testing.T) { - testGeminiGenerate(t, model, false) + testGeminiGenerate(t, geminiKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_流式", func(t *testing.T) { - testGeminiGenerate(t, model, true) + testGeminiGenerate(t, geminiKey, model, true) }) } } -func testGeminiGenerate(t *testing.T, model string, stream bool) { +func testGeminiGenerate(t *testing.T, geminiKey string, model string, stream bool) { action := "generateContent" if stream { 
action = "streamGenerateContent" @@ -254,7 +293,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+geminiAPIKey) + req.Header.Set("Authorization", "Bearer "+geminiKey) client := &http.Client{Timeout: 60 * time.Second} resp, err := client.Do(req) @@ -301,6 +340,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) { // TestClaudeMessagesWithComplexTools 测试带复杂工具 schema 的请求 // 模拟 Claude Code 发送的请求,包含需要清理的 JSON Schema 字段 func TestClaudeMessagesWithComplexTools(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) // 测试模型列表(只测试几个代表性模型) models := []string{ "claude-opus-4-5-20251101", // Claude 模型 @@ -312,12 +352,12 @@ func TestClaudeMessagesWithComplexTools(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_复杂工具", func(t *testing.T) { - testClaudeMessageWithTools(t, model) + testClaudeMessageWithTools(t, claudeKey, model) }) } } -func testClaudeMessageWithTools(t *testing.T, model string) { +func testClaudeMessageWithTools(t *testing.T, claudeKey string, model string) { url := baseURL + endpointPrefix + "/v1/messages" // 构造包含复杂 schema 的工具定义(模拟 Claude Code 的工具) @@ -473,7 +513,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -519,6 +559,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) { // 验证:当历史 assistant 消息包含 tool_use 但没有 signature 时, // 系统应自动添加 dummy thought_signature 避免 Gemini 400 错误 func TestClaudeMessagesWithThinkingAndTools(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) models := []string{ 
"claude-haiku-4-5-20251001", // gemini-3-flash } @@ -527,12 +568,12 @@ func TestClaudeMessagesWithThinkingAndTools(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_thinking模式工具调用", func(t *testing.T) { - testClaudeThinkingWithToolHistory(t, model) + testClaudeThinkingWithToolHistory(t, claudeKey, model) }) } } -func testClaudeThinkingWithToolHistory(t *testing.T, model string) { +func testClaudeThinkingWithToolHistory(t *testing.T, claudeKey string, model string) { url := baseURL + endpointPrefix + "/v1/messages" // 模拟历史对话:用户请求 → assistant 调用工具 → 工具返回 → 继续对话 @@ -600,7 +641,7 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -649,6 +690,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { if endpointPrefix != "/antigravity" { t.Skip("仅在 Antigravity 模式下运行") } + claudeKey := requireClaudeAPIKey(t) // 测试通过 Claude 端点调用 Gemini 模型 geminiViaClaude := []string{ @@ -664,11 +706,11 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_通过Claude端点", func(t *testing.T) { - testClaudeMessage(t, model, false) + testClaudeMessage(t, claudeKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_通过Claude端点_流式", func(t *testing.T) { - testClaudeMessage(t, model, true) + testClaudeMessage(t, claudeKey, model, true) }) } } @@ -676,6 +718,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { // TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景 // 验证:Gemini 模型接受没有 signature 的 thinking block func TestClaudeMessagesWithNoSignature(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) models := []string{ "claude-haiku-4-5-20251001", // gemini-3-flash - 支持无 
signature } @@ -684,12 +727,12 @@ func TestClaudeMessagesWithNoSignature(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_无signature", func(t *testing.T) { - testClaudeWithNoSignature(t, model) + testClaudeWithNoSignature(t, claudeKey, model) }) } } -func testClaudeWithNoSignature(t *testing.T, model string) { +func testClaudeWithNoSignature(t *testing.T, claudeKey string, model string) { url := baseURL + endpointPrefix + "/v1/messages" // 模拟历史对话包含 thinking block 但没有 signature @@ -732,7 +775,7 @@ func testClaudeWithNoSignature(t *testing.T, model string) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -777,6 +820,7 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) { if endpointPrefix != "/antigravity" { t.Skip("仅在 Antigravity 模式下运行") } + geminiKey := requireGeminiAPIKey(t) // 测试通过 Gemini 端点调用 Claude 模型 claudeViaGemini := []string{ @@ -789,11 +833,11 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_通过Gemini端点", func(t *testing.T) { - testGeminiGenerate(t, model, false) + testGeminiGenerate(t, geminiKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) { - testGeminiGenerate(t, model, true) + testGeminiGenerate(t, geminiKey, model, true) }) } } diff --git a/backend/internal/integration/e2e_helpers_test.go b/backend/internal/integration/e2e_helpers_test.go new file mode 100644 index 00000000..7d266bcb --- /dev/null +++ b/backend/internal/integration/e2e_helpers_test.go @@ -0,0 +1,48 @@ +//go:build e2e + +package integration + +import ( + "os" + "strings" + "testing" +) + +// ============================================================================= +// E2E Mock 模式支持 +// 
============================================================================= +// 当 E2E_MOCK=true 时,使用本地 Mock 响应替代真实 API 调用。 +// 这允许在没有真实 API Key 的环境(如 CI)中验证基本的请求/响应流程。 + +// isMockMode 检查是否启用 Mock 模式 +func isMockMode() bool { + return strings.EqualFold(os.Getenv("E2E_MOCK"), "true") +} + +// skipIfNoRealAPI 如果未配置真实 API Key 且不在 Mock 模式,则跳过测试 +func skipIfNoRealAPI(t *testing.T) { + t.Helper() + if isMockMode() { + return // Mock 模式下不跳过 + } + claudeKey := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) + geminiKey := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) + if claudeKey == "" && geminiKey == "" { + t.Skip("未设置 API Key 且未启用 Mock 模式,跳过测试") + } +} + +// ============================================================================= +// API Key 脱敏(Task 6.10) +// ============================================================================= + +// safeLogKey 安全地记录 API Key(仅显示前 8 位) +func safeLogKey(t *testing.T, prefix string, key string) { + t.Helper() + key = strings.TrimSpace(key) + if len(key) <= 8 { + t.Logf("%s: ***(长度: %d)", prefix, len(key)) + return + } + t.Logf("%s: %s...(长度: %d)", prefix, key[:8], len(key)) +} diff --git a/backend/internal/integration/e2e_user_flow_test.go b/backend/internal/integration/e2e_user_flow_test.go new file mode 100644 index 00000000..5489d0a3 --- /dev/null +++ b/backend/internal/integration/e2e_user_flow_test.go @@ -0,0 +1,317 @@ +//go:build e2e + +package integration + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" +) + +// E2E 用户流程测试 +// 测试完整的用户操作链路:注册 → 登录 → 创建 API Key → 调用网关 → 查询用量 + +var ( + testUserEmail = "e2e-test-" + fmt.Sprintf("%d", time.Now().UnixMilli()) + "@test.local" + testUserPassword = "E2eTest@12345" + testUserName = "e2e-test-user" +) + +// TestUserRegistrationAndLogin 测试用户注册和登录流程 +func TestUserRegistrationAndLogin(t *testing.T) { + // 步骤 1: 注册新用户 + t.Run("注册新用户", func(t *testing.T) { + payload := map[string]string{ + "email": testUserEmail, + "password": 
testUserPassword, + "username": testUserName, + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/auth/register", body, "") + if err != nil { + t.Skipf("注册接口不可用,跳过用户流程测试: %v", err) + return + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + // 注册可能返回 200(成功)或 400(邮箱已存在)或 403(注册已关闭) + switch resp.StatusCode { + case 200: + t.Logf("✅ 用户注册成功: %s", testUserEmail) + case 400: + t.Logf("⚠️ 用户可能已存在: %s", string(respBody)) + case 403: + t.Skipf("注册功能已关闭: %s", string(respBody)) + default: + t.Logf("⚠️ 注册返回 HTTP %d: %s(继续尝试登录)", resp.StatusCode, string(respBody)) + } + }) + + // 步骤 2: 登录获取 JWT + var accessToken string + t.Run("用户登录获取JWT", func(t *testing.T) { + payload := map[string]string{ + "email": testUserEmail, + "password": testUserPassword, + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/auth/login", body, "") + if err != nil { + t.Fatalf("登录请求失败: %v", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != 200 { + t.Skipf("登录失败 HTTP %d: %s(可能需要先注册用户)", resp.StatusCode, string(respBody)) + return + } + + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + t.Fatalf("解析登录响应失败: %v", err) + } + + // 尝试从标准响应格式获取 token + if token, ok := result["access_token"].(string); ok && token != "" { + accessToken = token + } else if data, ok := result["data"].(map[string]any); ok { + if token, ok := data["access_token"].(string); ok { + accessToken = token + } + } + + if accessToken == "" { + t.Skipf("未获取到 access_token,响应: %s", string(respBody)) + return + } + + // 验证 token 不为空且格式基本正确 + if len(accessToken) < 10 { + t.Fatalf("access_token 格式异常: %s", accessToken) + } + + t.Logf("✅ 登录成功,获取 JWT(长度: %d)", len(accessToken)) + }) + + if accessToken == "" { + t.Skip("未获取到 JWT,跳过后续测试") + return + } + + // 步骤 3: 使用 JWT 获取当前用户信息 + t.Run("获取当前用户信息", func(t *testing.T) { + resp, err := doRequest(t, "GET", "/api/user/me", nil, 
accessToken) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + t.Logf("✅ 成功获取用户信息") + }) +} + +// TestAPIKeyLifecycle 测试 API Key 的创建和使用 +func TestAPIKeyLifecycle(t *testing.T) { + // 先登录获取 JWT + accessToken := loginTestUser(t) + if accessToken == "" { + t.Skip("无法登录,跳过 API Key 生命周期测试") + return + } + + var apiKey string + + // 步骤 1: 创建 API Key + t.Run("创建API_Key", func(t *testing.T) { + payload := map[string]string{ + "name": "e2e-test-key-" + fmt.Sprintf("%d", time.Now().UnixMilli()), + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/keys", body, accessToken) + if err != nil { + t.Fatalf("创建 API Key 请求失败: %v", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != 200 { + t.Skipf("创建 API Key 失败 HTTP %d: %s", resp.StatusCode, string(respBody)) + return + } + + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + + // 从响应中提取 key + if key, ok := result["key"].(string); ok { + apiKey = key + } else if data, ok := result["data"].(map[string]any); ok { + if key, ok := data["key"].(string); ok { + apiKey = key + } + } + + if apiKey == "" { + t.Skipf("未获取到 API Key,响应: %s", string(respBody)) + return + } + + // 验证 API Key 脱敏日志(只显示前 8 位) + masked := apiKey + if len(masked) > 8 { + masked = masked[:8] + "..." 
+ } + t.Logf("✅ API Key 创建成功: %s", masked) + }) + + if apiKey == "" { + t.Skip("未创建 API Key,跳过后续测试") + return + } + + // 步骤 2: 使用 API Key 调用网关(需要 Claude 或 Gemini 可用) + t.Run("使用API_Key调用网关", func(t *testing.T) { + // 尝试调用 models 列表(最轻量的 API 调用) + resp, err := doRequest(t, "GET", "/v1/models", nil, apiKey) + if err != nil { + t.Fatalf("网关请求失败: %v", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + // 可能返回 200(成功)或 402(余额不足)或 403(无可用账户) + switch { + case resp.StatusCode == 200: + t.Logf("✅ API Key 网关调用成功") + case resp.StatusCode == 402: + t.Logf("⚠️ 余额不足,但 API Key 认证通过") + case resp.StatusCode == 403: + t.Logf("⚠️ 无可用账户,但 API Key 认证通过") + default: + t.Logf("⚠️ 网关返回 HTTP %d: %s", resp.StatusCode, string(respBody)) + } + }) + + // 步骤 3: 查询用量记录 + t.Run("查询用量记录", func(t *testing.T) { + resp, err := doRequest(t, "GET", "/api/usage/dashboard", nil, accessToken) + if err != nil { + t.Fatalf("用量查询请求失败: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + t.Logf("⚠️ 用量查询返回 HTTP %d: %s", resp.StatusCode, string(body)) + return + } + + t.Logf("✅ 用量查询成功") + }) +} + +// ============================================================================= +// 辅助函数 +// ============================================================================= + +func doRequest(t *testing.T, method, path string, body []byte, token string) (*http.Response, error) { + t.Helper() + + url := baseURL + path + var bodyReader io.Reader + if body != nil { + bodyReader = bytes.NewReader(body) + } + + req, err := http.NewRequest(method, url, bodyReader) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + + client := &http.Client{Timeout: 30 * time.Second} + return client.Do(req) +} + +func loginTestUser(t *testing.T) string { + t.Helper() + + // 先尝试用管理员账户登录 + 
adminEmail := getEnv("ADMIN_EMAIL", "admin@sub2api.local") + adminPassword := getEnv("ADMIN_PASSWORD", "") + + if adminPassword == "" { + // 尝试用测试用户 + adminEmail = testUserEmail + adminPassword = testUserPassword + } + + payload := map[string]string{ + "email": adminEmail, + "password": adminPassword, + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/auth/login", body, "") + if err != nil { + return "" + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return "" + } + + respBody, _ := io.ReadAll(resp.Body) + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + return "" + } + + if token, ok := result["access_token"].(string); ok { + return token + } + if data, ok := result["data"].(map[string]any); ok { + if token, ok := data["access_token"].(string); ok { + return token + } + } + + return "" +} + +// redactAPIKey API Key 脱敏,只显示前 8 位 +func redactAPIKey(key string) string { + key = strings.TrimSpace(key) + if len(key) <= 8 { + return "***" + } + return key[:8] + "..." 
+} diff --git a/backend/internal/middleware/rate_limiter_test.go b/backend/internal/middleware/rate_limiter_test.go index 0c379c0f..e362274f 100644 --- a/backend/internal/middleware/rate_limiter_test.go +++ b/backend/internal/middleware/rate_limiter_test.go @@ -60,6 +60,49 @@ func TestRateLimiterFailureModes(t *testing.T) { require.Equal(t, http.StatusTooManyRequests, recorder.Code) } +func TestRateLimiterDifferentIPsIndependent(t *testing.T) { + gin.SetMode(gin.TestMode) + + callCounts := make(map[string]int64) + originalRun := rateLimitRun + rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) { + callCounts[key]++ + return callCounts[key], false, nil + } + t.Cleanup(func() { + rateLimitRun = originalRun + }) + + limiter := NewRateLimiter(redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"})) + + router := gin.New() + router.Use(limiter.Limit("api", 1, time.Second)) + router.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + // 第一个 IP 的请求应通过 + req1 := httptest.NewRequest(http.MethodGet, "/test", nil) + req1.RemoteAddr = "10.0.0.1:1234" + rec1 := httptest.NewRecorder() + router.ServeHTTP(rec1, req1) + require.Equal(t, http.StatusOK, rec1.Code, "第一个 IP 的第一次请求应通过") + + // 第二个 IP 的请求应独立通过(不受第一个 IP 的计数影响) + req2 := httptest.NewRequest(http.MethodGet, "/test", nil) + req2.RemoteAddr = "10.0.0.2:5678" + rec2 := httptest.NewRecorder() + router.ServeHTTP(rec2, req2) + require.Equal(t, http.StatusOK, rec2.Code, "第二个 IP 的第一次请求应独立通过") + + // 第一个 IP 的第二次请求应被限流 + req3 := httptest.NewRequest(http.MethodGet, "/test", nil) + req3.RemoteAddr = "10.0.0.1:1234" + rec3 := httptest.NewRecorder() + router.ServeHTTP(rec3, req3) + require.Equal(t, http.StatusTooManyRequests, rec3.Code, "第一个 IP 的第二次请求应被限流") +} + func TestRateLimiterSuccessAndLimit(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/pkg/tlsfingerprint/dialer_test.go 
b/backend/internal/pkg/tlsfingerprint/dialer_test.go index 345067e5..6d3db174 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer_test.go +++ b/backend/internal/pkg/tlsfingerprint/dialer_test.go @@ -1,3 +1,5 @@ +//go:build unit + // Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients. // // Unit tests for TLS fingerprint dialer. @@ -20,24 +22,6 @@ import ( "time" ) -// FingerprintResponse represents the response from tls.peet.ws/api/all. -type FingerprintResponse struct { - IP string `json:"ip"` - TLS TLSInfo `json:"tls"` - HTTP2 any `json:"http2"` -} - -// TLSInfo contains TLS fingerprint details. -type TLSInfo struct { - JA3 string `json:"ja3"` - JA3Hash string `json:"ja3_hash"` - JA4 string `json:"ja4"` - PeetPrint string `json:"peetprint"` - PeetPrintHash string `json:"peetprint_hash"` - ClientRandom string `json:"client_random"` - SessionID string `json:"session_id"` -} - // TestDialerBasicConnection tests that the dialer can establish TLS connections. func TestDialerBasicConnection(t *testing.T) { skipNetworkTest(t) diff --git a/backend/internal/pkg/tlsfingerprint/test_types_test.go b/backend/internal/pkg/tlsfingerprint/test_types_test.go new file mode 100644 index 00000000..2bbf2d22 --- /dev/null +++ b/backend/internal/pkg/tlsfingerprint/test_types_test.go @@ -0,0 +1,20 @@ +package tlsfingerprint + +// FingerprintResponse represents the response from tls.peet.ws/api/all. +// 共享测试类型,供 unit 和 integration 测试文件使用。 +type FingerprintResponse struct { + IP string `json:"ip"` + TLS TLSInfo `json:"tls"` + HTTP2 any `json:"http2"` +} + +// TLSInfo contains TLS fingerprint details. 
+type TLSInfo struct { + JA3 string `json:"ja3"` + JA3Hash string `json:"ja3_hash"` + JA4 string `json:"ja4"` + PeetPrint string `json:"peetprint"` + PeetPrintHash string `json:"peetprint_hash"` + ClientRandom string `json:"client_random"` + SessionID string `json:"session_id"` +} diff --git a/backend/internal/repository/billing_cache_jitter_test.go b/backend/internal/repository/billing_cache_jitter_test.go index 32c42cf4..ba4f2873 100644 --- a/backend/internal/repository/billing_cache_jitter_test.go +++ b/backend/internal/repository/billing_cache_jitter_test.go @@ -14,7 +14,7 @@ func TestJitteredTTL_WithinExpectedRange(t *testing.T) { // jitteredTTL 使用减法抖动: billingCacheTTL - [0, billingCacheJitter) // 所以结果应在 [billingCacheTTL - billingCacheJitter, billingCacheTTL] 范围内 lowerBound := billingCacheTTL - billingCacheJitter // 5min - 30s = 4min30s - upperBound := billingCacheTTL // 5min + upperBound := billingCacheTTL // 5min for i := 0; i < 200; i++ { ttl := jitteredTTL() diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index d92dcc47..6851e71a 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -603,7 +603,7 @@ func newContractDeps(t *testing.T) *contractDeps { usageRepo := newStubUsageLogRepo() usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) - subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, cfg) + subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, nil, cfg) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil) diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go new file mode 100644 index 00000000..e1b8e1ad --- /dev/null +++ b/backend/internal/server/middleware/jwt_auth_test.go 
@@ -0,0 +1,234 @@ +//go:build unit + +package middleware + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// stubJWTUserRepo 实现 UserRepository 的最小子集,仅支持 GetByID。 +type stubJWTUserRepo struct { + service.UserRepository + users map[int64]*service.User +} + +func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, error) { + u, ok := r.users[id] + if !ok { + return nil, errors.New("user not found") + } + return u, nil +} + +// newJWTTestEnv 创建 JWT 认证中间件测试环境。 +// 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。 +func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!" + cfg.JWT.AccessTokenExpireMinutes = 60 + + userRepo := &stubJWTUserRepo{users: users} + authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil) + userSvc := service.NewUserService(userRepo, nil, nil) + mw := NewJWTAuthMiddleware(authSvc, userSvc) + + r := gin.New() + r.Use(gin.HandlerFunc(mw)) + r.GET("/protected", func(c *gin.Context) { + subject, _ := GetAuthSubjectFromContext(c) + role, _ := GetUserRoleFromContext(c) + c.JSON(http.StatusOK, gin.H{ + "user_id": subject.UserID, + "role": role, + }) + }) + return r, authSvc +} + +func TestJWTAuth_ValidToken(t *testing.T) { + user := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + Concurrency: 5, + TokenVersion: 1, + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user}) + + token, err := authSvc.GenerateToken(user) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", 
"Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + + var body map[string]any + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, float64(1), body["user_id"]) + require.Equal(t, "user", body["role"]) +} + +func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) { + router, _ := newJWTTestEnv(nil) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "UNAUTHORIZED", body.Code) +} + +func TestJWTAuth_InvalidHeaderFormat(t *testing.T) { + tests := []struct { + name string + header string + }{ + {"无Bearer前缀", "Token abc123"}, + {"缺少空格分隔", "Bearerabc123"}, + {"仅有单词", "abc123"}, + } + router, _ := newJWTTestEnv(nil) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", tt.header) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "INVALID_AUTH_HEADER", body.Code) + }) + } +} + +func TestJWTAuth_EmptyToken(t *testing.T) { + router, _ := newJWTTestEnv(nil) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer ") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "EMPTY_TOKEN", body.Code) +} + +func TestJWTAuth_TamperedToken(t *testing.T) { + router, _ := newJWTTestEnv(nil) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + 
req.Header.Set("Authorization", "Bearer eyJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoxfQ.invalid_signature") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "INVALID_TOKEN", body.Code) +} + +func TestJWTAuth_UserNotFound(t *testing.T) { + // 使用 user ID=999 的 token,但 repo 中没有该用户 + fakeUser := &service.User{ + ID: 999, + Email: "ghost@example.com", + Role: "user", + Status: service.StatusActive, + TokenVersion: 1, + } + // 创建环境时不注入此用户,这样 GetByID 会失败 + router, authSvc := newJWTTestEnv(map[int64]*service.User{}) + + token, err := authSvc.GenerateToken(fakeUser) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "USER_NOT_FOUND", body.Code) +} + +func TestJWTAuth_UserInactive(t *testing.T) { + user := &service.User{ + ID: 1, + Email: "disabled@example.com", + Role: "user", + Status: service.StatusDisabled, + TokenVersion: 1, + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user}) + + token, err := authSvc.GenerateToken(user) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "USER_INACTIVE", body.Code) +} + +func TestJWTAuth_TokenVersionMismatch(t *testing.T) { + // Token 生成时 TokenVersion=1,但数据库中用户已更新为 TokenVersion=2(密码修改) + userForToken := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: 
service.StatusActive, + TokenVersion: 1, + } + userInDB := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + TokenVersion: 2, // 密码修改后版本递增 + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: userInDB}) + + token, err := authSvc.GenerateToken(userForToken) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "TOKEN_REVOKED", body.Code) +} diff --git a/backend/internal/server/middleware/recovery_test.go b/backend/internal/server/middleware/recovery_test.go index 439f44cb..33e71d51 100644 --- a/backend/internal/server/middleware/recovery_test.go +++ b/backend/internal/server/middleware/recovery_test.go @@ -3,6 +3,7 @@ package middleware import ( + "bytes" "encoding/json" "net/http" "net/http/httptest" @@ -14,6 +15,34 @@ import ( "github.com/stretchr/testify/require" ) +func TestRecovery_PanicLogContainsInfo(t *testing.T) { + gin.SetMode(gin.TestMode) + + // 临时替换 DefaultErrorWriter 以捕获日志输出 + var buf bytes.Buffer + originalWriter := gin.DefaultErrorWriter + gin.DefaultErrorWriter = &buf + t.Cleanup(func() { + gin.DefaultErrorWriter = originalWriter + }) + + r := gin.New() + r.Use(Recovery()) + r.GET("/panic", func(c *gin.Context) { + panic("custom panic message for test") + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/panic", nil) + r.ServeHTTP(w, req) + + require.Equal(t, http.StatusInternalServerError, w.Code) + + logOutput := buf.String() + require.Contains(t, logOutput, "custom panic message for test", "日志应包含 panic 信息") + require.Contains(t, logOutput, "recovery_test.go", "日志应包含堆栈跟踪文件名") +} + func TestRecovery(t *testing.T) { gin.SetMode(gin.TestMode) diff --git 
a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go index 20936356..2b4a5504 100644 --- a/backend/internal/service/antigravity_rate_limit_test.go +++ b/backend/internal/service/antigravity_rate_limit_test.go @@ -15,6 +15,12 @@ import ( "github.com/stretchr/testify/require" ) +// 编译期接口断言 +var _ HTTPUpstream = (*stubAntigravityUpstream)(nil) +var _ HTTPUpstream = (*recordingOKUpstream)(nil) +var _ AccountRepository = (*stubAntigravityAccountRepo)(nil) +var _ SchedulerCache = (*stubSchedulerCache)(nil) + type stubAntigravityUpstream struct { firstBase string secondBase string diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go new file mode 100644 index 00000000..cdaf6953 --- /dev/null +++ b/backend/internal/service/billing_service_test.go @@ -0,0 +1,310 @@ +//go:build unit + +package service + +import ( + "math" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func newTestBillingService() *BillingService { + return NewBillingService(&config.Config{}, nil) +} + +func TestCalculateCost_BasicComputation(t *testing.T) { + svc := newTestBillingService() + + // 使用 claude-sonnet-4 的回退价格:Input $3/MTok, Output $15/MTok + tokens := UsageTokens{ + InputTokens: 1000, + OutputTokens: 500, + } + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + // 1000 * 3e-6 = 0.003, 500 * 15e-6 = 0.0075 + expectedInput := 1000 * 3e-6 + expectedOutput := 500 * 15e-6 + require.InDelta(t, expectedInput, cost.InputCost, 1e-10) + require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, cost.TotalCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10) +} + +func TestCalculateCost_WithCacheTokens(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 1000, + 
OutputTokens: 500, + CacheCreationTokens: 2000, + CacheReadTokens: 3000, + } + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + expectedCacheCreation := 2000 * 3.75e-6 + expectedCacheRead := 3000 * 0.3e-6 + require.InDelta(t, expectedCacheCreation, cost.CacheCreationCost, 1e-10) + require.InDelta(t, expectedCacheRead, cost.CacheReadCost, 1e-10) + + expectedTotal := cost.InputCost + cost.OutputCost + expectedCacheCreation + expectedCacheRead + require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10) +} + +func TestCalculateCost_RateMultiplier(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + + cost1x, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + cost2x, err := svc.CalculateCost("claude-sonnet-4", tokens, 2.0) + require.NoError(t, err) + + // TotalCost 不受倍率影响,ActualCost 翻倍 + require.InDelta(t, cost1x.TotalCost, cost2x.TotalCost, 1e-10) + require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10) +} + +func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000} + + costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0) + require.NoError(t, err) + + costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) +} + +func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000} + + costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0) + require.NoError(t, err) + + costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) +} + +func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) { + svc := newTestBillingService() 
+ + tests := []struct { + model string + expectedInput float64 + }{ + {"claude-opus-4.5-20250101", 5e-6}, + {"claude-3-opus-20240229", 15e-6}, + {"claude-sonnet-4-20250514", 3e-6}, + {"claude-3-5-sonnet-20241022", 3e-6}, + {"claude-3-5-haiku-20241022", 1e-6}, + {"claude-3-haiku-20240307", 0.25e-6}, + } + + for _, tt := range tests { + pricing, err := svc.GetModelPricing(tt.model) + require.NoError(t, err, "模型 %s", tt.model) + require.InDelta(t, tt.expectedInput, pricing.InputPricePerToken, 1e-12, "模型 %s 输入价格", tt.model) + } +} + +func TestGetModelPricing_CaseInsensitive(t *testing.T) { + svc := newTestBillingService() + + p1, err := svc.GetModelPricing("Claude-Sonnet-4") + require.NoError(t, err) + + p2, err := svc.GetModelPricing("claude-sonnet-4") + require.NoError(t, err) + + require.Equal(t, p1.InputPricePerToken, p2.InputPricePerToken) +} + +func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) { + svc := newTestBillingService() + + // 不包含 opus/sonnet/haiku 关键词的 Claude 模型会走默认 Sonnet 价格 + pricing, err := svc.GetModelPricing("claude-unknown-model") + require.NoError(t, err) + require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) +} + +func TestCalculateCostWithLongContext_BelowThreshold(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 50000, + OutputTokens: 1000, + CacheReadTokens: 100000, + } + // 总输入 150k < 200k 阈值,应走正常计费 + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0) + require.NoError(t, err) + + normalCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, normalCost.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateCostWithLongContext_AboveThreshold_CacheExceedsThreshold(t *testing.T) { + svc := newTestBillingService() + + // 缓存 210k + 输入 10k = 220k > 200k 阈值 + // 缓存已超阈值:范围内 200k 缓存,范围外 10k 缓存 + 10k 输入 + tokens := UsageTokens{ + InputTokens: 10000, + OutputTokens: 1000, + 
CacheReadTokens: 210000, + } + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0) + require.NoError(t, err) + + // 范围内:200k cache + 0 input + 1k output + inRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{ + InputTokens: 0, + OutputTokens: 1000, + CacheReadTokens: 200000, + }, 1.0) + + // 范围外:10k cache + 10k input,倍率 2.0 + outRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{ + InputTokens: 10000, + CacheReadTokens: 10000, + }, 2.0) + + require.InDelta(t, inRange.ActualCost+outRange.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateCostWithLongContext_AboveThreshold_CacheBelowThreshold(t *testing.T) { + svc := newTestBillingService() + + // 缓存 100k + 输入 150k = 250k > 200k 阈值 + // 缓存未超阈值:范围内 100k 缓存 + 100k 输入,范围外 50k 输入 + tokens := UsageTokens{ + InputTokens: 150000, + OutputTokens: 1000, + CacheReadTokens: 100000, + } + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0) + require.NoError(t, err) + + require.True(t, cost.ActualCost > 0, "费用应大于 0") + + // 正常费用不含长上下文 + normalCost, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.True(t, cost.ActualCost > normalCost.ActualCost, "长上下文费用应高于正常费用") +} + +func TestCalculateCostWithLongContext_DisabledThreshold(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 300000, CacheReadTokens: 0} + + // threshold <= 0 应禁用长上下文计费 + cost1, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 0, 2.0) + require.NoError(t, err) + + cost2, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, cost2.ActualCost, cost1.ActualCost, 1e-10) +} + +func TestCalculateCostWithLongContext_ExtraMultiplierLessEqualOne(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 300000} + + // extraMultiplier <= 1 应禁用长上下文计费 + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", 
tokens, 1.0, 200000, 1.0) + require.NoError(t, err) + + normalCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, normalCost.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateImageCost(t *testing.T) { + svc := newTestBillingService() + + price := 0.134 + cfg := &ImagePriceConfig{Price1K: &price} + cost := svc.CalculateImageCost("gpt-image-1", "1K", 3, cfg, 1.0) + + require.InDelta(t, 0.134*3, cost.TotalCost, 1e-10) + require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10) +} + +func TestCalculateSoraVideoCost(t *testing.T) { + svc := newTestBillingService() + + price := 0.5 + cfg := &SoraPriceConfig{VideoPricePerRequest: &price} + cost := svc.CalculateSoraVideoCost("sora-video", cfg, 1.0) + + require.InDelta(t, 0.5, cost.TotalCost, 1e-10) +} + +func TestCalculateSoraVideoCost_HDModel(t *testing.T) { + svc := newTestBillingService() + + hdPrice := 1.0 + normalPrice := 0.5 + cfg := &SoraPriceConfig{ + VideoPricePerRequest: &normalPrice, + VideoPricePerRequestHD: &hdPrice, + } + cost := svc.CalculateSoraVideoCost("sora2pro-hd", cfg, 1.0) + require.InDelta(t, 1.0, cost.TotalCost, 1e-10) +} + +func TestIsModelSupported(t *testing.T) { + svc := newTestBillingService() + + require.True(t, svc.IsModelSupported("claude-sonnet-4")) + require.True(t, svc.IsModelSupported("Claude-Opus-4.5")) + require.True(t, svc.IsModelSupported("claude-3-haiku")) + require.False(t, svc.IsModelSupported("gpt-4o")) + require.False(t, svc.IsModelSupported("gemini-pro")) +} + +func TestCalculateCost_ZeroTokens(t *testing.T) { + svc := newTestBillingService() + + cost, err := svc.CalculateCost("claude-sonnet-4", UsageTokens{}, 1.0) + require.NoError(t, err) + require.Equal(t, 0.0, cost.TotalCost) + require.Equal(t, 0.0, cost.ActualCost) +} + +func TestCalculateCost_LargeTokenCount(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 1_000_000, + OutputTokens: 1_000_000, + } + cost, err := 
svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + // Input: 1M * 3e-6 = $3, Output: 1M * 15e-6 = $15 + require.InDelta(t, 3.0, cost.InputCost, 1e-6) + require.InDelta(t, 15.0, cost.OutputCost, 1e-6) + require.False(t, math.IsNaN(cost.TotalCost)) + require.False(t, math.IsInf(cost.TotalCost, 0)) +} diff --git a/backend/internal/service/claude_code_detection_test.go b/backend/internal/service/claude_code_detection_test.go new file mode 100644 index 00000000..ff7ad7f4 --- /dev/null +++ b/backend/internal/service/claude_code_detection_test.go @@ -0,0 +1,282 @@ +//go:build unit + +package service + +import ( + "context" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func newTestValidator() *ClaudeCodeValidator { + return NewClaudeCodeValidator() +} + +// validClaudeCodeBody 构造一个完整有效的 Claude Code 请求体 +func validClaudeCodeBody() map[string]any { + return map[string]any{ + "model": "claude-sonnet-4-20250514", + "system": []any{ + map[string]any{ + "type": "text", + "text": "You are Claude Code, Anthropic's official CLI for Claude.", + }, + }, + "metadata": map[string]any{ + "user_id": "user_" + "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_" + "12345678-1234-1234-1234-123456789abc", + }, + } +} + +func TestValidate_ClaudeCLIUserAgent(t *testing.T) { + v := newTestValidator() + + tests := []struct { + name string + ua string + want bool + }{ + {"标准版本号", "claude-cli/1.0.0", true}, + {"多位版本号", "claude-cli/12.34.56", true}, + {"大写开头", "Claude-CLI/1.0.0", true}, + {"非 claude-cli", "curl/7.64.1", false}, + {"空 User-Agent", "", false}, + {"部分匹配", "not-claude-cli/1.0.0", false}, + {"缺少版本号", "claude-cli/", false}, + {"版本格式不对", "claude-cli/1.0", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, v.ValidateUserAgent(tt.ua), "UA: %q", tt.ua) + }) + } +} + +func 
TestValidate_NonMessagesPath_UAOnly(t *testing.T) { + v := newTestValidator() + + // 非 messages 路径只检查 UA + req := httptest.NewRequest("GET", "/v1/models", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + + result := v.Validate(req, nil) + require.True(t, result, "非 messages 路径只需 UA 匹配") +} + +func TestValidate_NonMessagesPath_InvalidUA(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("GET", "/v1/models", nil) + req.Header.Set("User-Agent", "curl/7.64.1") + + result := v.Validate(req, nil) + require.False(t, result, "UA 不匹配时应返回 false") +} + +func TestValidate_MessagesPath_FullValid(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15") + req.Header.Set("anthropic-version", "2023-06-01") + + result := v.Validate(req, validClaudeCodeBody()) + require.True(t, result, "完整有效请求应通过") +} + +func TestValidate_MessagesPath_MissingHeaders(t *testing.T) { + v := newTestValidator() + body := validClaudeCodeBody() + + tests := []struct { + name string + missingHeader string + }{ + {"缺少 X-App", "X-App"}, + {"缺少 anthropic-beta", "anthropic-beta"}, + {"缺少 anthropic-version", "anthropic-version"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + req.Header.Del(tt.missingHeader) + + result := v.Validate(req, body) + require.False(t, result, "缺少 %s 应返回 false", tt.missingHeader) + }) + } +} + +func TestValidate_MessagesPath_InvalidMetadataUserID(t *testing.T) { + v := newTestValidator() + + tests := []struct { + name string + metadata map[string]any + }{ + {"缺少 metadata", nil}, + 
{"缺少 user_id", map[string]any{"other": "value"}}, + {"空 user_id", map[string]any{"user_id": ""}}, + {"格式错误", map[string]any{"user_id": "invalid-format"}}, + {"hex 长度不足", map[string]any{"user_id": "user_abc_account__session_uuid"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + + body := map[string]any{ + "model": "claude-sonnet-4", + "system": []any{ + map[string]any{ + "type": "text", + "text": "You are Claude Code, Anthropic's official CLI for Claude.", + }, + }, + } + if tt.metadata != nil { + body["metadata"] = tt.metadata + } + + result := v.Validate(req, body) + require.False(t, result, "metadata.user_id: %v", tt.metadata) + }) + } +} + +func TestValidate_MessagesPath_InvalidSystemPrompt(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + + body := map[string]any{ + "model": "claude-sonnet-4", + "system": []any{ + map[string]any{ + "type": "text", + "text": "Generate JSON data for testing database migrations.", + }, + }, + "metadata": map[string]any{ + "user_id": "user_" + "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_12345678-1234-1234-1234-123456789abc", + }, + } + + result := v.Validate(req, body) + require.False(t, result, "无关系统提示词应返回 false") +} + +func TestValidate_MaxTokensOneHaikuBypass(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + // 不设置 X-App 等头,通过 context 标记为 haiku 探测请求 + ctx := 
context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true) + req = req.WithContext(ctx) + + // 即使 body 不包含 system prompt,也应通过 + result := v.Validate(req, map[string]any{"model": "claude-3-haiku", "max_tokens": 1}) + require.True(t, result, "max_tokens=1+haiku 探测请求应绕过严格验证") +} + +func TestSystemPromptSimilarity(t *testing.T) { + v := newTestValidator() + + tests := []struct { + name string + prompt string + want bool + }{ + {"精确匹配", "You are Claude Code, Anthropic's official CLI for Claude.", true}, + {"带多余空格", "You are Claude Code, Anthropic's official CLI for Claude.", true}, + {"Agent SDK 模板", "You are a Claude agent, built on Anthropic's Claude Agent SDK.", true}, + {"文件搜索专家模板", "You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.", true}, + {"对话摘要模板", "You are a helpful AI assistant tasked with summarizing conversations.", true}, + {"交互式 CLI 模板", "You are an interactive CLI tool that helps users", true}, + {"无关文本", "Write me a poem about cats", false}, + {"空文本", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body := map[string]any{ + "model": "claude-sonnet-4", + "system": []any{ + map[string]any{"type": "text", "text": tt.prompt}, + }, + } + result := v.IncludesClaudeCodeSystemPrompt(body) + require.Equal(t, tt.want, result, "提示词: %q", tt.prompt) + }) + } +} + +func TestDiceCoefficient(t *testing.T) { + tests := []struct { + name string + a string + b string + want float64 + tol float64 + }{ + {"相同字符串", "hello", "hello", 1.0, 0.001}, + {"完全不同", "abc", "xyz", 0.0, 0.001}, + {"空字符串", "", "hello", 0.0, 0.001}, + {"单字符", "a", "b", 0.0, 0.001}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := diceCoefficient(tt.a, tt.b) + require.InDelta(t, tt.want, result, tt.tol) + }) + } +} + +func TestIsClaudeCodeClient_Context(t *testing.T) { + ctx := context.Background() + + // 默认应为 false + require.False(t, IsClaudeCodeClient(ctx)) + + // 设置为 true + 
ctx = SetClaudeCodeClient(ctx, true) + require.True(t, IsClaudeCodeClient(ctx)) + + // 设置为 false + ctx = SetClaudeCodeClient(ctx, false) + require.False(t, IsClaudeCodeClient(ctx)) +} + +func TestValidate_NilBody_MessagesPath(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + + result := v.Validate(req, nil) + require.False(t, result, "nil body 的 messages 请求应返回 false") +} diff --git a/backend/internal/service/concurrency_service_test.go b/backend/internal/service/concurrency_service_test.go new file mode 100644 index 00000000..33ce4cb9 --- /dev/null +++ b/backend/internal/service/concurrency_service_test.go @@ -0,0 +1,280 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩 +type stubConcurrencyCacheForTest struct { + acquireResult bool + acquireErr error + releaseErr error + concurrency int + concurrencyErr error + waitAllowed bool + waitErr error + waitCount int + waitCountErr error + loadBatch map[int64]*AccountLoadInfo + loadBatchErr error + usersLoadBatch map[int64]*UserLoadInfo + usersLoadErr error + cleanupErr error + + // 记录调用 + releasedAccountIDs []int64 + releasedRequestIDs []string +} + +var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil) + +func (c *stubConcurrencyCacheForTest) AcquireAccountSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) { + return c.acquireResult, c.acquireErr +} +func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, accountID int64, requestID string) error { + c.releasedAccountIDs = append(c.releasedAccountIDs, accountID) + c.releasedRequestIDs = append(c.releasedRequestIDs, requestID) + return c.releaseErr +} +func (c 
*stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) { + return c.concurrency, c.concurrencyErr +} +func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) { + return c.waitAllowed, c.waitErr +} +func (c *stubConcurrencyCacheForTest) DecrementAccountWaitCount(_ context.Context, _ int64) error { + return nil +} +func (c *stubConcurrencyCacheForTest) GetAccountWaitingCount(_ context.Context, _ int64) (int, error) { + return c.waitCount, c.waitCountErr +} +func (c *stubConcurrencyCacheForTest) AcquireUserSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) { + return c.acquireResult, c.acquireErr +} +func (c *stubConcurrencyCacheForTest) ReleaseUserSlot(_ context.Context, _ int64, _ string) error { + return c.releaseErr +} +func (c *stubConcurrencyCacheForTest) GetUserConcurrency(_ context.Context, _ int64) (int, error) { + return c.concurrency, c.concurrencyErr +} +func (c *stubConcurrencyCacheForTest) IncrementWaitCount(_ context.Context, _ int64, _ int) (bool, error) { + return c.waitAllowed, c.waitErr +} +func (c *stubConcurrencyCacheForTest) DecrementWaitCount(_ context.Context, _ int64) error { + return nil +} +func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + return c.loadBatch, c.loadBatchErr +} +func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) { + return c.usersLoadBatch, c.usersLoadErr +} +func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Context, _ int64) error { + return c.cleanupErr +} + +func TestAcquireAccountSlot_Success(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: true} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) + require.NoError(t, err) + require.True(t, 
result.Acquired) + require.NotNil(t, result.ReleaseFunc) +} + +func TestAcquireAccountSlot_Failure(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: false} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) + require.NoError(t, err) + require.False(t, result.Acquired) + require.Nil(t, result.ReleaseFunc) +} + +func TestAcquireAccountSlot_UnlimitedConcurrency(t *testing.T) { + svc := NewConcurrencyService(&stubConcurrencyCacheForTest{}) + + for _, maxConcurrency := range []int{0, -1} { + result, err := svc.AcquireAccountSlot(context.Background(), 1, maxConcurrency) + require.NoError(t, err) + require.True(t, result.Acquired, "maxConcurrency=%d 应无限制通过", maxConcurrency) + require.NotNil(t, result.ReleaseFunc, "ReleaseFunc 应为 no-op 函数") + } +} + +func TestAcquireAccountSlot_CacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireErr: errors.New("redis down")} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) + require.Error(t, err) + require.Nil(t, result) +} + +func TestAcquireAccountSlot_ReleaseDecrements(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: true} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 42, 5) + require.NoError(t, err) + require.True(t, result.Acquired) + + // 调用 ReleaseFunc 应释放槽位 + result.ReleaseFunc() + + require.Len(t, cache.releasedAccountIDs, 1) + require.Equal(t, int64(42), cache.releasedAccountIDs[0]) + require.Len(t, cache.releasedRequestIDs, 1) + require.NotEmpty(t, cache.releasedRequestIDs[0], "requestID 不应为空") +} + +func TestAcquireUserSlot_IndependentFromAccount(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: true} + svc := NewConcurrencyService(cache) + + // 用户槽位获取应独立于账户槽位 + result, err := svc.AcquireUserSlot(context.Background(), 100, 3) + require.NoError(t, err) + require.True(t, 
result.Acquired) + require.NotNil(t, result.ReleaseFunc) +} + +func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) { + svc := NewConcurrencyService(&stubConcurrencyCacheForTest{}) + + result, err := svc.AcquireUserSlot(context.Background(), 1, 0) + require.NoError(t, err) + require.True(t, result.Acquired) +} + +func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) { + expected := map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60}, + 2: {AccountID: 2, CurrentConcurrency: 5, WaitingCount: 2, LoadRate: 100}, + } + cache := &stubConcurrencyCacheForTest{loadBatch: expected} + svc := NewConcurrencyService(cache) + + accounts := []AccountWithConcurrency{ + {ID: 1, MaxConcurrency: 5}, + {ID: 2, MaxConcurrency: 5}, + } + result, err := svc.GetAccountsLoadBatch(context.Background(), accounts) + require.NoError(t, err) + require.Equal(t, expected, result) +} + +func TestGetAccountsLoadBatch_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + result, err := svc.GetAccountsLoadBatch(context.Background(), nil) + require.NoError(t, err) + require.Empty(t, result) +} + +func TestIncrementWaitCount_Success(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: true} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.True(t, allowed) +} + +func TestIncrementWaitCount_QueueFull(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: false} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.False(t, allowed) +} + +func TestIncrementWaitCount_FailOpen(t *testing.T) { + // Redis 错误时应 fail-open(允许请求通过) + cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis timeout")} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + 
require.NoError(t, err, "Redis 错误不应传播") + require.True(t, allowed, "Redis 错误时应 fail-open") +} + +func TestIncrementWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.True(t, allowed, "nil cache 应 fail-open") +} + +func TestCalculateMaxWait(t *testing.T) { + tests := []struct { + concurrency int + expected int + }{ + {5, 25}, // 5 + 20 + {1, 21}, // 1 + 20 + {0, 21}, // min(1) + 20 + {-1, 21}, // min(1) + 20 + {10, 30}, // 10 + 20 + } + for _, tt := range tests { + result := CalculateMaxWait(tt.concurrency) + require.Equal(t, tt.expected, result, "CalculateMaxWait(%d)", tt.concurrency) + } +} + +func TestGetAccountWaitingCount(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitCount: 5} + svc := NewConcurrencyService(cache) + + count, err := svc.GetAccountWaitingCount(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, 5, count) +} + +func TestGetAccountWaitingCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + count, err := svc.GetAccountWaitingCount(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, 0, count) +} + +func TestGetAccountConcurrencyBatch(t *testing.T) { + cache := &stubConcurrencyCacheForTest{concurrency: 3} + svc := NewConcurrencyService(cache) + + result, err := svc.GetAccountConcurrencyBatch(context.Background(), []int64{1, 2, 3}) + require.NoError(t, err) + require.Len(t, result, 3) + for _, id := range []int64{1, 2, 3} { + require.Equal(t, 3, result[id]) + } +} + +func TestIncrementAccountWaitCount_FailOpen(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis error")} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err, "Redis 错误不应传播") + require.True(t, allowed, "Redis 错误时应 fail-open") +} + +func 
TestIncrementAccountWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err) + require.True(t, allowed) +} diff --git a/backend/internal/service/gateway_account_selection_test.go b/backend/internal/service/gateway_account_selection_test.go new file mode 100644 index 00000000..70c5d6c5 --- /dev/null +++ b/backend/internal/service/gateway_account_selection_test.go @@ -0,0 +1,198 @@ +//go:build unit + +package service + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// --- helpers --- + +func testTimePtr(t time.Time) *time.Time { return &t } + +func makeAccWithLoad(id int64, priority int, loadRate int, lastUsed *time.Time, accType string) accountWithLoad { + return accountWithLoad{ + account: &Account{ + ID: id, + Priority: priority, + LastUsedAt: lastUsed, + Type: accType, + Schedulable: true, + Status: StatusActive, + }, + loadInfo: &AccountLoadInfo{ + AccountID: id, + CurrentConcurrency: 0, + LoadRate: loadRate, + }, + } +} + +// --- sortAccountsByPriorityAndLastUsed --- + +func TestSortAccountsByPriorityAndLastUsed_ByPriority(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 1, Priority: 5, LastUsedAt: testTimePtr(now)}, + {ID: 2, Priority: 1, LastUsedAt: testTimePtr(now)}, + {ID: 3, Priority: 3, LastUsedAt: testTimePtr(now)}, + } + sortAccountsByPriorityAndLastUsed(accounts, false) + require.Equal(t, int64(2), accounts[0].ID, "优先级最低的排第一") + require.Equal(t, int64(3), accounts[1].ID) + require.Equal(t, int64(1), accounts[2].ID) +} + +func TestSortAccountsByPriorityAndLastUsed_SamePriorityByLastUsed(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: testTimePtr(now)}, + {ID: 2, Priority: 1, LastUsedAt: testTimePtr(now.Add(-1 * time.Hour))}, + {ID: 3, Priority: 1, LastUsedAt: nil}, + } + sortAccountsByPriorityAndLastUsed(accounts, false) + 
require.Equal(t, int64(3), accounts[0].ID, "nil LastUsedAt 排最前") + require.Equal(t, int64(2), accounts[1].ID, "更早使用的排前面") + require.Equal(t, int64(1), accounts[2].ID) +} + +func TestSortAccountsByPriorityAndLastUsed_PreferOAuth(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + {ID: 2, Priority: 1, LastUsedAt: nil, Type: AccountTypeOAuth}, + } + sortAccountsByPriorityAndLastUsed(accounts, true) + require.Equal(t, int64(2), accounts[0].ID, "preferOAuth 时 OAuth 账号排前面") +} + +func TestSortAccountsByPriorityAndLastUsed_StableSort(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + {ID: 2, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + {ID: 3, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + } + sortAccountsByPriorityAndLastUsed(accounts, false) + // 稳定排序:相同键值的元素保持原始顺序 + require.Equal(t, int64(1), accounts[0].ID) + require.Equal(t, int64(2), accounts[1].ID) + require.Equal(t, int64(3), accounts[2].ID) +} + +func TestSortAccountsByPriorityAndLastUsed_MixedPriorityAndTime(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 1, Priority: 2, LastUsedAt: nil}, + {ID: 2, Priority: 1, LastUsedAt: testTimePtr(now)}, + {ID: 3, Priority: 1, LastUsedAt: testTimePtr(now.Add(-1 * time.Hour))}, + {ID: 4, Priority: 2, LastUsedAt: testTimePtr(now.Add(-2 * time.Hour))}, + } + sortAccountsByPriorityAndLastUsed(accounts, false) + // 优先级1排前:nil < earlier + require.Equal(t, int64(3), accounts[0].ID, "优先级1 + 更早") + require.Equal(t, int64(2), accounts[1].ID, "优先级1 + 现在") + // 优先级2排后:nil < time + require.Equal(t, int64(1), accounts[2].ID, "优先级2 + nil") + require.Equal(t, int64(4), accounts[3].ID, "优先级2 + 有时间") +} + +// --- selectByCallCount --- + +func TestSelectByCallCount_Empty(t *testing.T) { + result := selectByCallCount(nil, nil, false) + require.Nil(t, result) +} + +func TestSelectByCallCount_Single(t *testing.T) { + accounts := 
[]accountWithLoad{ + makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey), + } + result := selectByCallCount(accounts, map[int64]*ModelLoadInfo{1: {CallCount: 10}}, false) + require.NotNil(t, result) + require.Equal(t, int64(1), result.account.ID) +} + +func TestSelectByCallCount_NilModelLoadFallsBackToLRU(t *testing.T) { + now := time.Now() + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 50, testTimePtr(now), AccountTypeAPIKey), + makeAccWithLoad(2, 1, 50, testTimePtr(now.Add(-1*time.Hour)), AccountTypeAPIKey), + } + result := selectByCallCount(accounts, nil, false) + require.NotNil(t, result) + require.Equal(t, int64(2), result.account.ID, "nil modelLoadMap 应回退到 LRU 选择") +} + +func TestSelectByCallCount_SelectsMinCallCount(t *testing.T) { + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey), + makeAccWithLoad(2, 1, 50, nil, AccountTypeAPIKey), + makeAccWithLoad(3, 1, 50, nil, AccountTypeAPIKey), + } + modelLoad := map[int64]*ModelLoadInfo{ + 1: {CallCount: 100}, + 2: {CallCount: 5}, + 3: {CallCount: 50}, + } + // 运行多次确认总是选调用次数最少的 + for i := 0; i < 10; i++ { + result := selectByCallCount(accounts, modelLoad, false) + require.NotNil(t, result) + require.Equal(t, int64(2), result.account.ID, "应选择调用次数最少的账号") + } +} + +func TestSelectByCallCount_NewAccountUsesAverage(t *testing.T) { + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey), + makeAccWithLoad(2, 1, 50, nil, AccountTypeAPIKey), + makeAccWithLoad(3, 1, 50, nil, AccountTypeAPIKey), + } + // 账号1和2有调用记录,账号3是新账号(CallCount=0) + // 平均调用次数 = (100 + 200) / 2 = 150 + // 新账号用平均值 150,比账号1(100)多,所以应选账号1 + modelLoad := map[int64]*ModelLoadInfo{ + 1: {CallCount: 100}, + 2: {CallCount: 200}, + // 3 没有记录 + } + for i := 0; i < 10; i++ { + result := selectByCallCount(accounts, modelLoad, false) + require.NotNil(t, result) + require.Equal(t, int64(1), result.account.ID, "新账号虚拟调用次数(150)高于账号1(100),应选账号1") + } +} + +func 
TestSelectByCallCount_AllNewAccountsFallToAvgZero(t *testing.T) { + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey), + makeAccWithLoad(2, 1, 50, nil, AccountTypeAPIKey), + } + // 所有账号都是新的,avgCallCount = 0,所有人 effectiveCallCount 都是 0 + modelLoad := map[int64]*ModelLoadInfo{} + validIDs := map[int64]bool{1: true, 2: true} + for i := 0; i < 10; i++ { + result := selectByCallCount(accounts, modelLoad, false) + require.NotNil(t, result) + require.True(t, validIDs[result.account.ID], "所有新账号应随机选择") + } +} + +func TestSelectByCallCount_PreferOAuth(t *testing.T) { + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 50, nil, AccountTypeAPIKey), + makeAccWithLoad(2, 1, 50, nil, AccountTypeOAuth), + } + // 两个账号调用次数相同 + modelLoad := map[int64]*ModelLoadInfo{ + 1: {CallCount: 10}, + 2: {CallCount: 10}, + } + for i := 0; i < 10; i++ { + result := selectByCallCount(accounts, modelLoad, true) + require.NotNil(t, result) + require.Equal(t, int64(2), result.account.ID, "调用次数相同时应优先选择 OAuth 账号") + } +} diff --git a/backend/internal/service/gateway_streaming_test.go b/backend/internal/service/gateway_streaming_test.go new file mode 100644 index 00000000..50b998a3 --- /dev/null +++ b/backend/internal/service/gateway_streaming_test.go @@ -0,0 +1,203 @@ +//go:build unit + +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// --- parseSSEUsage 测试 --- + +func newMinimalGatewayService() *GatewayService { + return &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } +} + +func TestParseSSEUsage_MessageStart(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + data := 
`{"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation_input_tokens":50,"cache_read_input_tokens":200}}}` + svc.parseSSEUsage(data, usage) + + require.Equal(t, 100, usage.InputTokens) + require.Equal(t, 50, usage.CacheCreationInputTokens) + require.Equal(t, 200, usage.CacheReadInputTokens) + require.Equal(t, 0, usage.OutputTokens, "message_start 不应设置 output_tokens") +} + +func TestParseSSEUsage_MessageDelta(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + data := `{"type":"message_delta","usage":{"output_tokens":42}}` + svc.parseSSEUsage(data, usage) + + require.Equal(t, 42, usage.OutputTokens) + require.Equal(t, 0, usage.InputTokens, "message_delta 的 output_tokens 不应影响已有的 input_tokens") +} + +func TestParseSSEUsage_DeltaDoesNotOverwriteStartValues(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 先处理 message_start + svc.parseSSEUsage(`{"type":"message_start","message":{"usage":{"input_tokens":100}}}`, usage) + require.Equal(t, 100, usage.InputTokens) + + // 再处理 message_delta(output_tokens > 0, input_tokens = 0) + svc.parseSSEUsage(`{"type":"message_delta","usage":{"output_tokens":50}}`, usage) + require.Equal(t, 100, usage.InputTokens, "delta 中 input_tokens=0 不应覆盖 start 中的值") + require.Equal(t, 50, usage.OutputTokens) +} + +func TestParseSSEUsage_DeltaOverwritesWithNonZero(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // GLM 等 API 会在 delta 中包含所有 usage 信息 + svc.parseSSEUsage(`{"type":"message_delta","usage":{"input_tokens":200,"output_tokens":100,"cache_creation_input_tokens":30,"cache_read_input_tokens":60}}`, usage) + require.Equal(t, 200, usage.InputTokens) + require.Equal(t, 100, usage.OutputTokens) + require.Equal(t, 30, usage.CacheCreationInputTokens) + require.Equal(t, 60, usage.CacheReadInputTokens) +} + +func TestParseSSEUsage_InvalidJSON(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 无效 JSON 
不应 panic + svc.parseSSEUsage("not json", usage) + require.Equal(t, 0, usage.InputTokens) + require.Equal(t, 0, usage.OutputTokens) +} + +func TestParseSSEUsage_UnknownType(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 不是 message_start 或 message_delta 的类型 + svc.parseSSEUsage(`{"type":"content_block_delta","delta":{"text":"hello"}}`, usage) + require.Equal(t, 0, usage.InputTokens) + require.Equal(t, 0, usage.OutputTokens) +} + +func TestParseSSEUsage_EmptyString(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + svc.parseSSEUsage("", usage) + require.Equal(t, 0, usage.InputTokens) +} + +func TestParseSSEUsage_DoneEvent(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // [DONE] 事件不应影响 usage + svc.parseSSEUsage("[DONE]", usage) + require.Equal(t, 0, usage.InputTokens) +} + +// --- 流式响应端到端测试 --- + +func TestHandleStreamingResponse_CacheTokens(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":10,\"cache_creation_input_tokens\":20,\"cache_read_input_tokens\":30}}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":15}}\n\n")) + _, _ = pw.Write([]byte("data: [DONE]\n\n")) + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 10, result.usage.InputTokens) + require.Equal(t, 15, 
result.usage.OutputTokens) + require.Equal(t, 20, result.usage.CacheCreationInputTokens) + require.Equal(t, 30, result.usage.CacheReadInputTokens) +} + +func TestHandleStreamingResponse_EmptyStream(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + // 直接关闭,不发送任何事件 + _ = pw.Close() + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) +} + +func TestHandleStreamingResponse_SpecialCharactersInJSON(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + // 包含特殊字符的 content_block_delta(引号、换行、Unicode) + _, _ = pw.Write([]byte("data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello \\\"world\\\"\\n你好\"}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":3}}\n\n")) + _, _ = pw.Write([]byte("data: [DONE]\n\n")) + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 5, result.usage.InputTokens) + 
require.Equal(t, 3, result.usage.OutputTokens) + + // 验证响应中包含转发的数据 + body := rec.Body.String() + require.Contains(t, body, "content_block_delta", "响应应包含转发的 SSE 事件") +} diff --git a/backend/internal/service/gateway_waiting_queue_test.go b/backend/internal/service/gateway_waiting_queue_test.go new file mode 100644 index 00000000..0ed95c87 --- /dev/null +++ b/backend/internal/service/gateway_waiting_queue_test.go @@ -0,0 +1,120 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestDecrementWaitCount_NilCache 确保 nil cache 不会 panic +func TestDecrementWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + // 不应 panic + svc.DecrementWaitCount(context.Background(), 1) +} + +// TestDecrementWaitCount_CacheError 确保 cache 错误不会传播 +func TestDecrementWaitCount_CacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{} + svc := NewConcurrencyService(cache) + // DecrementWaitCount 使用 background context,错误只记录日志不传播 + svc.DecrementWaitCount(context.Background(), 1) +} + +// TestDecrementAccountWaitCount_NilCache 确保 nil cache 不会 panic +func TestDecrementAccountWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + svc.DecrementAccountWaitCount(context.Background(), 1) +} + +// TestDecrementAccountWaitCount_CacheError 确保 cache 错误不会传播 +func TestDecrementAccountWaitCount_CacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{} + svc := NewConcurrencyService(cache) + svc.DecrementAccountWaitCount(context.Background(), 1) +} + +// TestWaitingQueueFlow_IncrementThenDecrement 测试完整的等待队列增减流程 +func TestWaitingQueueFlow_IncrementThenDecrement(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: true} + svc := NewConcurrencyService(cache) + + // 进入等待队列 + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.True(t, allowed) + + // 离开等待队列(不应 panic) + 
svc.DecrementWaitCount(context.Background(), 1) +} + +// TestWaitingQueueFlow_AccountLevel 测试账号级等待队列流程 +func TestWaitingQueueFlow_AccountLevel(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: true} + svc := NewConcurrencyService(cache) + + // 进入账号等待队列 + allowed, err := svc.IncrementAccountWaitCount(context.Background(), 42, 10) + require.NoError(t, err) + require.True(t, allowed) + + // 离开账号等待队列 + svc.DecrementAccountWaitCount(context.Background(), 42) +} + +// TestWaitingQueueFull_Returns429Signal 测试等待队列满时返回 false +func TestWaitingQueueFull_Returns429Signal(t *testing.T) { + // waitAllowed=false 模拟队列已满 + cache := &stubConcurrencyCacheForTest{waitAllowed: false} + svc := NewConcurrencyService(cache) + + // 用户级等待队列满 + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.False(t, allowed, "等待队列满时应返回 false(调用方根据此返回 429)") + + // 账号级等待队列满 + allowed, err = svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err) + require.False(t, allowed, "账号等待队列满时应返回 false") +} + +// TestWaitingQueue_FailOpen_OnCacheError 测试 Redis 故障时 fail-open +func TestWaitingQueue_FailOpen_OnCacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis connection refused")} + svc := NewConcurrencyService(cache) + + // 用户级:Redis 错误时允许通过 + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err, "Redis 错误不应向调用方传播") + require.True(t, allowed, "Redis 故障时应 fail-open 放行") + + // 账号级:同样 fail-open + allowed, err = svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err, "Redis 错误不应向调用方传播") + require.True(t, allowed, "Redis 故障时应 fail-open 放行") +} + +// TestCalculateMaxWait_Scenarios 测试最大等待队列大小计算 +func TestCalculateMaxWait_Scenarios(t *testing.T) { + tests := []struct { + concurrency int + expected int + }{ + {5, 25}, // 5 + 20 + {10, 30}, // 10 + 20 + {1, 21}, // 1 + 20 + {0, 21}, // min(1) + 20 + {-1, 21}, // 
min(1) + 20 + {-10, 21}, // min(1) + 20 + {100, 120}, // 100 + 20 + } + for _, tt := range tests { + result := CalculateMaxWait(tt.concurrency) + require.Equal(t, tt.expected, result, "CalculateMaxWait(%d)", tt.concurrency) + } +} diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 91dbaa4b..a6eeb3eb 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -17,6 +17,10 @@ import ( "github.com/stretchr/testify/require" ) +// 编译期接口断言 +var _ AccountRepository = (*stubOpenAIAccountRepo)(nil) +var _ GatewayCache = (*stubGatewayCache)(nil) + type stubOpenAIAccountRepo struct { AccountRepository accounts []Account diff --git a/backend/internal/service/ops_alert_evaluator_service_test.go b/backend/internal/service/ops_alert_evaluator_service_test.go index 068ab6bb..83d358a3 100644 --- a/backend/internal/service/ops_alert_evaluator_service_test.go +++ b/backend/internal/service/ops_alert_evaluator_service_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/require" ) +var _ OpsRepository = (*stubOpsRepo)(nil) + type stubOpsRepo struct { OpsRepository overview *OpsDashboardOverview diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go index caa10427..0a77d228 100644 --- a/backend/internal/service/sora_gateway_service_test.go +++ b/backend/internal/service/sora_gateway_service_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/require" ) +var _ SoraClient = (*stubSoraClientForPoll)(nil) + type stubSoraClientForPoll struct { imageStatus *SoraImageTaskStatus videoStatus *SoraVideoTaskStatus diff --git a/backend/internal/service/subscription_calculate_progress_test.go b/backend/internal/service/subscription_calculate_progress_test.go index d8adf7f7..22018bcd 100644 --- 
a/backend/internal/service/subscription_calculate_progress_test.go +++ b/backend/internal/service/subscription_calculate_progress_test.go @@ -14,7 +14,7 @@ func newTestSubscriptionService() *SubscriptionService { return &SubscriptionService{} } -func ptrFloat64(v float64) *float64 { return &v } +func ptrFloat64(v float64) *float64 { return &v } func ptrTime(t time.Time) *time.Time { return &t } func TestCalculateProgress_BasicFields(t *testing.T) { diff --git a/backend/internal/testutil/fixtures.go b/backend/internal/testutil/fixtures.go new file mode 100644 index 00000000..747767bc --- /dev/null +++ b/backend/internal/testutil/fixtures.go @@ -0,0 +1,78 @@ +//go:build unit + +package testutil + +import ( + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// NewTestUser 创建一个可用的测试用户,可通过 opts 覆盖默认值。 +func NewTestUser(opts ...func(*service.User)) *service.User { + u := &service.User{ + ID: 1, + Email: "test@example.com", + Username: "testuser", + Role: "user", + Balance: 100.0, + Concurrency: 5, + Status: service.StatusActive, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + for _, opt := range opts { + opt(u) + } + return u +} + +// NewTestAccount 创建一个可用的测试账户,可通过 opts 覆盖默认值。 +func NewTestAccount(opts ...func(*service.Account)) *service.Account { + a := &service.Account{ + ID: 1, + Name: "test-account", + Platform: service.PlatformAnthropic, + Status: service.StatusActive, + Schedulable: true, + Concurrency: 5, + Priority: 1, + } + for _, opt := range opts { + opt(a) + } + return a +} + +// NewTestAPIKey 创建一个可用的测试 API Key,可通过 opts 覆盖默认值。 +func NewTestAPIKey(opts ...func(*service.APIKey)) *service.APIKey { + groupID := int64(1) + k := &service.APIKey{ + ID: 1, + UserID: 1, + Key: "sk-test-key-12345678", + Name: "test-key", + GroupID: &groupID, + Status: service.StatusActive, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + for _, opt := range opts { + opt(k) + } + return k +} + +// NewTestGroup 创建一个可用的测试分组,可通过 opts 覆盖默认值。 +func 
NewTestGroup(opts ...func(*service.Group)) *service.Group { + g := &service.Group{ + ID: 1, + Platform: service.PlatformAnthropic, + Status: service.StatusActive, + Hydrated: true, + } + for _, opt := range opts { + opt(g) + } + return g +} diff --git a/backend/internal/testutil/httptest.go b/backend/internal/testutil/httptest.go new file mode 100644 index 00000000..2a066a12 --- /dev/null +++ b/backend/internal/testutil/httptest.go @@ -0,0 +1,35 @@ +//go:build unit + +package testutil + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + + "github.com/gin-gonic/gin" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +// NewGinTestContext 创建一个 Gin 测试上下文和 ResponseRecorder。 +// body 为空字符串时创建无 body 的请求。 +func NewGinTestContext(method, path, body string) (*gin.Context, *httptest.ResponseRecorder) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + var bodyReader io.Reader + if body != "" { + bodyReader = strings.NewReader(body) + } + + c.Request = httptest.NewRequest(method, path, bodyReader) + if method == http.MethodPost || method == http.MethodPut || method == http.MethodPatch { + c.Request.Header.Set("Content-Type", "application/json") + } + + return c, rec +} diff --git a/backend/internal/testutil/stubs.go b/backend/internal/testutil/stubs.go new file mode 100644 index 00000000..81c40c42 --- /dev/null +++ b/backend/internal/testutil/stubs.go @@ -0,0 +1,137 @@ +//go:build unit + +// Package testutil 提供单元测试共享的 Stub、Fixture 和辅助函数。 +// 所有文件使用 //go:build unit 标签,确保不会被生产构建包含。 +package testutil + +import ( + "context" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// ============================================================ +// StubConcurrencyCache — service.ConcurrencyCache 的空实现 +// ============================================================ + +// 编译期接口断言 +var _ service.ConcurrencyCache = StubConcurrencyCache{} + +// StubConcurrencyCache 是 ConcurrencyCache 的默认空实现,所有方法返回零值。 +type StubConcurrencyCache struct{} + 
+func (c StubConcurrencyCache) AcquireAccountSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) { + return true, nil +} +func (c StubConcurrencyCache) ReleaseAccountSlot(_ context.Context, _ int64, _ string) error { + return nil +} +func (c StubConcurrencyCache) GetAccountConcurrency(_ context.Context, _ int64) (int, error) { + return 0, nil +} +func (c StubConcurrencyCache) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) { + return true, nil +} +func (c StubConcurrencyCache) DecrementAccountWaitCount(_ context.Context, _ int64) error { + return nil +} +func (c StubConcurrencyCache) GetAccountWaitingCount(_ context.Context, _ int64) (int, error) { + return 0, nil +} +func (c StubConcurrencyCache) AcquireUserSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) { + return true, nil +} +func (c StubConcurrencyCache) ReleaseUserSlot(_ context.Context, _ int64, _ string) error { + return nil +} +func (c StubConcurrencyCache) GetUserConcurrency(_ context.Context, _ int64) (int, error) { + return 0, nil +} +func (c StubConcurrencyCache) IncrementWaitCount(_ context.Context, _ int64, _ int) (bool, error) { + return true, nil +} +func (c StubConcurrencyCache) DecrementWaitCount(_ context.Context, _ int64) error { return nil } +func (c StubConcurrencyCache) GetAccountsLoadBatch(_ context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { + result := make(map[int64]*service.AccountLoadInfo, len(accounts)) + for _, acc := range accounts { + result[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID, LoadRate: 0} + } + return result, nil +} +func (c StubConcurrencyCache) GetUsersLoadBatch(_ context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { + result := make(map[int64]*service.UserLoadInfo, len(users)) + for _, u := range users { + result[u.ID] = &service.UserLoadInfo{UserID: u.ID, LoadRate: 0} + } + return result, nil +} +func (c 
StubConcurrencyCache) CleanupExpiredAccountSlots(_ context.Context, _ int64) error { + return nil +} + +// ============================================================ +// StubGatewayCache — service.GatewayCache 的空实现 +// ============================================================ + +var _ service.GatewayCache = StubGatewayCache{} + +type StubGatewayCache struct{} + +func (c StubGatewayCache) GetSessionAccountID(_ context.Context, _ int64, _ string) (int64, error) { + return 0, nil +} +func (c StubGatewayCache) SetSessionAccountID(_ context.Context, _ int64, _ string, _ int64, _ time.Duration) error { + return nil +} +func (c StubGatewayCache) RefreshSessionTTL(_ context.Context, _ int64, _ string, _ time.Duration) error { + return nil +} +func (c StubGatewayCache) DeleteSessionAccountID(_ context.Context, _ int64, _ string) error { + return nil +} +func (c StubGatewayCache) IncrModelCallCount(_ context.Context, _ int64, _ string) (int64, error) { + return 0, nil +} +func (c StubGatewayCache) GetModelLoadBatch(_ context.Context, _ []int64, _ string) (map[int64]*service.ModelLoadInfo, error) { + return nil, nil +} +func (c StubGatewayCache) FindGeminiSession(_ context.Context, _ int64, _, _ string) (string, int64, bool) { + return "", 0, false +} +func (c StubGatewayCache) SaveGeminiSession(_ context.Context, _ int64, _, _, _ string, _ int64) error { + return nil +} + +// ============================================================ +// StubSessionLimitCache — service.SessionLimitCache 的空实现 +// ============================================================ + +var _ service.SessionLimitCache = StubSessionLimitCache{} + +type StubSessionLimitCache struct{} + +func (c StubSessionLimitCache) RegisterSession(_ context.Context, _ int64, _ string, _ int, _ time.Duration) (bool, error) { + return true, nil +} +func (c StubSessionLimitCache) RefreshSession(_ context.Context, _ int64, _ string, _ time.Duration) error { + return nil +} +func (c StubSessionLimitCache) 
GetActiveSessionCount(_ context.Context, _ int64) (int, error) { + return 0, nil +} +func (c StubSessionLimitCache) GetActiveSessionCountBatch(_ context.Context, _ []int64, _ map[int64]time.Duration) (map[int64]int, error) { + return nil, nil +} +func (c StubSessionLimitCache) IsSessionActive(_ context.Context, _ int64, _ string) (bool, error) { + return false, nil +} +func (c StubSessionLimitCache) GetWindowCost(_ context.Context, _ int64) (float64, bool, error) { + return 0, false, nil +} +func (c StubSessionLimitCache) SetWindowCost(_ context.Context, _ int64, _ float64) error { + return nil +} +func (c StubSessionLimitCache) GetWindowCostBatch(_ context.Context, _ []int64) (map[int64]float64, error) { + return nil, nil +} diff --git a/frontend/src/api/__tests__/client.spec.ts b/frontend/src/api/__tests__/client.spec.ts new file mode 100644 index 00000000..0e92c6d1 --- /dev/null +++ b/frontend/src/api/__tests__/client.spec.ts @@ -0,0 +1,208 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import axios from 'axios' +import type { AxiosInstance, InternalAxiosRequestConfig, AxiosResponse, AxiosHeaders } from 'axios' + +// 需要在导入 client 之前设置 mock +vi.mock('@/i18n', () => ({ + getLocale: () => 'zh-CN', +})) + +describe('API Client', () => { + let apiClient: AxiosInstance + + beforeEach(async () => { + localStorage.clear() + // 每次测试重新导入以获取干净的模块状态 + vi.resetModules() + const mod = await import('@/api/client') + apiClient = mod.apiClient + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + // --- 请求拦截器 --- + + describe('请求拦截器', () => { + it('自动附加 Authorization 头', async () => { + localStorage.setItem('auth_token', 'my-jwt-token') + + // 拦截实际请求 + const adapter = vi.fn().mockResolvedValue({ + status: 200, + data: { code: 0, data: {} }, + headers: {}, + config: {}, + statusText: 'OK', + }) + apiClient.defaults.adapter = adapter + + await apiClient.get('/test') + + const config = adapter.mock.calls[0][0] + 
expect(config.headers.get('Authorization')).toBe('Bearer my-jwt-token') + }) + + it('无 token 时不附加 Authorization 头', async () => { + const adapter = vi.fn().mockResolvedValue({ + status: 200, + data: { code: 0, data: {} }, + headers: {}, + config: {}, + statusText: 'OK', + }) + apiClient.defaults.adapter = adapter + + await apiClient.get('/test') + + const config = adapter.mock.calls[0][0] + expect(config.headers.get('Authorization')).toBeFalsy() + }) + + it('GET 请求自动附加 timezone 参数', async () => { + const adapter = vi.fn().mockResolvedValue({ + status: 200, + data: { code: 0, data: {} }, + headers: {}, + config: {}, + statusText: 'OK', + }) + apiClient.defaults.adapter = adapter + + await apiClient.get('/test') + + const config = adapter.mock.calls[0][0] + expect(config.params).toHaveProperty('timezone') + }) + + it('POST 请求不附加 timezone 参数', async () => { + const adapter = vi.fn().mockResolvedValue({ + status: 200, + data: { code: 0, data: {} }, + headers: {}, + config: {}, + statusText: 'OK', + }) + apiClient.defaults.adapter = adapter + + await apiClient.post('/test', { foo: 'bar' }) + + const config = adapter.mock.calls[0][0] + expect(config.params?.timezone).toBeUndefined() + }) + }) + + // --- 响应拦截器 --- + + describe('响应拦截器', () => { + it('code=0 时解包 data 字段', async () => { + const adapter = vi.fn().mockResolvedValue({ + status: 200, + data: { code: 0, data: { name: 'test' }, message: 'ok' }, + headers: {}, + config: {}, + statusText: 'OK', + }) + apiClient.defaults.adapter = adapter + + const response = await apiClient.get('/test') + expect(response.data).toEqual({ name: 'test' }) + }) + + it('code!=0 时拒绝并返回结构化错误', async () => { + const adapter = vi.fn().mockResolvedValue({ + status: 200, + data: { code: 1001, message: '参数错误', data: null }, + headers: {}, + config: {}, + statusText: 'OK', + }) + apiClient.defaults.adapter = adapter + + await expect(apiClient.get('/test')).rejects.toEqual( + expect.objectContaining({ + code: 1001, + message: '参数错误', + }) + ) + 
}) + }) + + // --- 401 Token 刷新 --- + + describe('401 Token 刷新', () => { + it('无 refresh_token 时 401 清除 localStorage', async () => { + localStorage.setItem('auth_token', 'expired-token') + // 不设置 refresh_token + + // Mock window.location + const originalLocation = window.location + Object.defineProperty(window, 'location', { + value: { ...originalLocation, pathname: '/dashboard', href: '/dashboard' }, + writable: true, + }) + + const adapter = vi.fn().mockRejectedValue({ + response: { + status: 401, + data: { code: 'TOKEN_EXPIRED', message: 'Token expired' }, + }, + config: { + url: '/test', + headers: { Authorization: 'Bearer expired-token' }, + }, + code: 'ERR_BAD_REQUEST', + }) + apiClient.defaults.adapter = adapter + + await expect(apiClient.get('/test')).rejects.toBeDefined() + + expect(localStorage.getItem('auth_token')).toBeNull() + + // 恢复 location + Object.defineProperty(window, 'location', { + value: originalLocation, + writable: true, + }) + }) + }) + + // --- 网络错误 --- + + describe('网络错误', () => { + it('网络错误返回 status 0 的错误', async () => { + const adapter = vi.fn().mockRejectedValue({ + code: 'ERR_NETWORK', + message: 'Network Error', + config: { url: '/test' }, + // 没有 response + }) + apiClient.defaults.adapter = adapter + + await expect(apiClient.get('/test')).rejects.toEqual( + expect.objectContaining({ + status: 0, + message: 'Network error. 
Please check your connection.', + }) + ) + }) + }) + + // --- 请求取消 --- + + describe('请求取消', () => { + it('取消的请求保持原始取消错误', async () => { + const source = axios.CancelToken.source() + + const adapter = vi.fn().mockRejectedValue( + new axios.Cancel('Operation canceled') + ) + apiClient.defaults.adapter = adapter + + await expect( + apiClient.get('/test', { cancelToken: source.token }) + ).rejects.toBeDefined() + }) + }) +}) diff --git a/frontend/src/components/__tests__/ApiKeyCreate.spec.ts b/frontend/src/components/__tests__/ApiKeyCreate.spec.ts new file mode 100644 index 00000000..537f43e7 --- /dev/null +++ b/frontend/src/components/__tests__/ApiKeyCreate.spec.ts @@ -0,0 +1,184 @@ +/** + * API Key 创建逻辑测试 + * 通过封装组件测试 API Key 创建的核心流程 + */ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { mount, flushPromises } from '@vue/test-utils' +import { setActivePinia, createPinia } from 'pinia' +import { defineComponent, ref, reactive } from 'vue' + +// Mock keysAPI +const mockCreate = vi.fn() +const mockList = vi.fn() + +vi.mock('@/api', () => ({ + keysAPI: { + create: (...args: any[]) => mockCreate(...args), + list: (...args: any[]) => mockList(...args), + }, + authAPI: { + getCurrentUser: vi.fn().mockResolvedValue({ data: {} }), + logout: vi.fn(), + refreshToken: vi.fn(), + }, + isTotp2FARequired: () => false, +})) + +vi.mock('@/api/admin/system', () => ({ + checkUpdates: vi.fn(), +})) + +vi.mock('@/api/auth', () => ({ + getPublicSettings: vi.fn().mockResolvedValue({}), +})) + +// Mock app store - 使用固定引用确保组件和测试共享同一对象 +const mockShowSuccess = vi.fn() +const mockShowError = vi.fn() + +vi.mock('@/stores/app', () => ({ + useAppStore: () => ({ + showSuccess: mockShowSuccess, + showError: mockShowError, + }), +})) + +import { useAppStore } from '@/stores/app' + +/** + * 简化的 API Key 创建测试组件 + */ +const ApiKeyCreateTestComponent = defineComponent({ + setup() { + const appStore = useAppStore() + const loading = ref(false) + const createdKey = ref('') + const 
formData = reactive({ + name: '', + group_id: null as number | null, + }) + + const handleCreate = async () => { + if (!formData.name) return + + loading.value = true + try { + const result = await mockCreate({ + name: formData.name, + group_id: formData.group_id, + }) + createdKey.value = result.key + appStore.showSuccess('API Key 创建成功') + } catch (error: any) { + appStore.showError(error.message || '创建失败') + } finally { + loading.value = false + } + } + + return { formData, loading, createdKey, handleCreate } + }, + template: ` +
+
+ + + +
+
{{ createdKey }}
+
+ `, +}) + +describe('ApiKey 创建流程', () => { + beforeEach(() => { + setActivePinia(createPinia()) + vi.clearAllMocks() + }) + + it('创建 API Key 调用 API 并显示结果', async () => { + mockCreate.mockResolvedValue({ + id: 1, + key: 'sk-test-key-12345', + name: 'My Test Key', + }) + + const wrapper = mount(ApiKeyCreateTestComponent) + + await wrapper.find('#name').setValue('My Test Key') + await wrapper.find('form').trigger('submit') + await flushPromises() + + expect(mockCreate).toHaveBeenCalledWith({ + name: 'My Test Key', + group_id: null, + }) + + expect(wrapper.find('.created-key').text()).toBe('sk-test-key-12345') + }) + + it('选择分组后正确传参', async () => { + mockCreate.mockResolvedValue({ + id: 2, + key: 'sk-group-key', + name: 'Group Key', + }) + + const wrapper = mount(ApiKeyCreateTestComponent) + + await wrapper.find('#name').setValue('Group Key') + // 选择 group_id = 1 + await wrapper.find('#group').setValue('1') + await wrapper.find('form').trigger('submit') + await flushPromises() + + expect(mockCreate).toHaveBeenCalledWith({ + name: 'Group Key', + group_id: 1, + }) + }) + + it('创建失败时显示错误', async () => { + mockCreate.mockRejectedValue(new Error('配额不足')) + + const wrapper = mount(ApiKeyCreateTestComponent) + + await wrapper.find('#name').setValue('Fail Key') + await wrapper.find('form').trigger('submit') + await flushPromises() + + expect(mockShowError).toHaveBeenCalledWith('配额不足') + expect(wrapper.find('.created-key').exists()).toBe(false) + }) + + it('名称为空时不提交', async () => { + const wrapper = mount(ApiKeyCreateTestComponent) + + await wrapper.find('form').trigger('submit') + await flushPromises() + + expect(mockCreate).not.toHaveBeenCalled() + }) + + it('创建过程中按钮被禁用', async () => { + let resolveCreate: (v: any) => void + mockCreate.mockImplementation( + () => new Promise((resolve) => { resolveCreate = resolve }) + ) + + const wrapper = mount(ApiKeyCreateTestComponent) + + await wrapper.find('#name').setValue('Test Key') + await wrapper.find('form').trigger('submit') + + 
expect(wrapper.find('button').attributes('disabled')).toBeDefined() + + resolveCreate!({ id: 1, key: 'sk-test', name: 'Test Key' }) + await flushPromises() + + expect(wrapper.find('button').attributes('disabled')).toBeUndefined() + }) +}) diff --git a/frontend/src/components/__tests__/Dashboard.spec.ts b/frontend/src/components/__tests__/Dashboard.spec.ts new file mode 100644 index 00000000..b83808cc --- /dev/null +++ b/frontend/src/components/__tests__/Dashboard.spec.ts @@ -0,0 +1,173 @@ +/** + * Dashboard 数据加载逻辑测试 + * 通过封装组件测试仪表板核心数据加载流程 + */ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { mount, flushPromises } from '@vue/test-utils' +import { setActivePinia, createPinia } from 'pinia' +import { defineComponent, ref, onMounted, nextTick } from 'vue' + +// Mock API +const mockGetDashboardStats = vi.fn() +const mockRefreshUser = vi.fn() + +vi.mock('@/api', () => ({ + authAPI: { + getCurrentUser: vi.fn().mockResolvedValue({ + data: { id: 1, username: 'test', email: 'test@example.com', role: 'user', balance: 100, concurrency: 5, status: 'active', allowed_groups: null, created_at: '', updated_at: '' }, + }), + logout: vi.fn(), + refreshToken: vi.fn(), + }, + isTotp2FARequired: () => false, +})) + +vi.mock('@/api/usage', () => ({ + usageAPI: { + getDashboardStats: (...args: any[]) => mockGetDashboardStats(...args), + }, +})) + +vi.mock('@/api/admin/system', () => ({ + checkUpdates: vi.fn(), +})) + +vi.mock('@/api/auth', () => ({ + getPublicSettings: vi.fn().mockResolvedValue({}), +})) + +interface DashboardStats { + balance: number + api_key_count: number + active_api_key_count: number + today_requests: number + today_cost: number + today_tokens: number + total_tokens: number +} + +/** + * 简化的 Dashboard 测试组件 + */ +const DashboardTestComponent = defineComponent({ + setup() { + const stats = ref(null) + const loading = ref(false) + const error = ref('') + + const loadStats = async () => { + loading.value = true + error.value = '' + try { + 
stats.value = await mockGetDashboardStats() + } catch (e: any) { + error.value = e.message || '加载失败' + } finally { + loading.value = false + } + } + + onMounted(loadStats) + + return { stats, loading, error, loadStats } + }, + template: ` +
+
加载中...
+
{{ error }}
+
+ {{ stats.balance }} + {{ stats.api_key_count }} + {{ stats.today_requests }} + {{ stats.today_cost }} +
+ +
+ `, +}) + +describe('Dashboard 数据加载', () => { + beforeEach(() => { + setActivePinia(createPinia()) + vi.clearAllMocks() + }) + + const fakeStats: DashboardStats = { + balance: 100.5, + api_key_count: 3, + active_api_key_count: 2, + today_requests: 150, + today_cost: 2.5, + today_tokens: 50000, + total_tokens: 1000000, + } + + it('挂载后自动加载数据', async () => { + mockGetDashboardStats.mockResolvedValue(fakeStats) + + const wrapper = mount(DashboardTestComponent) + await flushPromises() + + expect(mockGetDashboardStats).toHaveBeenCalledTimes(1) + expect(wrapper.find('.balance').text()).toBe('100.5') + expect(wrapper.find('.api-keys').text()).toBe('3') + expect(wrapper.find('.today-requests').text()).toBe('150') + expect(wrapper.find('.today-cost').text()).toBe('2.5') + }) + + it('加载中显示 loading 状态', async () => { + let resolveStats: (v: any) => void + mockGetDashboardStats.mockImplementation( + () => new Promise((resolve) => { resolveStats = resolve }) + ) + + const wrapper = mount(DashboardTestComponent) + await nextTick() + + expect(wrapper.find('.loading').exists()).toBe(true) + + resolveStats!(fakeStats) + await flushPromises() + + expect(wrapper.find('.loading').exists()).toBe(false) + expect(wrapper.find('.stats').exists()).toBe(true) + }) + + it('加载失败时显示错误信息', async () => { + mockGetDashboardStats.mockRejectedValue(new Error('Network error')) + + const wrapper = mount(DashboardTestComponent) + await flushPromises() + + expect(wrapper.find('.error').text()).toBe('Network error') + expect(wrapper.find('.stats').exists()).toBe(false) + }) + + it('点击刷新按钮重新加载数据', async () => { + mockGetDashboardStats.mockResolvedValue(fakeStats) + + const wrapper = mount(DashboardTestComponent) + await flushPromises() + + expect(mockGetDashboardStats).toHaveBeenCalledTimes(1) + + // 更新数据 + const updatedStats = { ...fakeStats, today_requests: 200 } + mockGetDashboardStats.mockResolvedValue(updatedStats) + + await wrapper.find('.refresh').trigger('click') + await flushPromises() + + 
expect(mockGetDashboardStats).toHaveBeenCalledTimes(2) + expect(wrapper.find('.today-requests').text()).toBe('200') + }) + + it('数据为空时不显示统计信息', async () => { + mockGetDashboardStats.mockResolvedValue(null) + + const wrapper = mount(DashboardTestComponent) + await flushPromises() + + expect(wrapper.find('.stats').exists()).toBe(false) + }) +}) diff --git a/frontend/src/components/__tests__/LoginForm.spec.ts b/frontend/src/components/__tests__/LoginForm.spec.ts new file mode 100644 index 00000000..14b86fc2 --- /dev/null +++ b/frontend/src/components/__tests__/LoginForm.spec.ts @@ -0,0 +1,178 @@ +/** + * LoginView 组件核心逻辑测试 + * 测试登录表单提交、验证、2FA 等场景 + */ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { mount, flushPromises } from '@vue/test-utils' +import { setActivePinia, createPinia } from 'pinia' +import { defineComponent, reactive, ref } from 'vue' +import { useAuthStore } from '@/stores/auth' + +// Mock 所有外部依赖 +const mockLogin = vi.fn() +const mockLogin2FA = vi.fn() +const mockPush = vi.fn() + +vi.mock('@/api', () => ({ + authAPI: { + login: (...args: any[]) => mockLogin(...args), + login2FA: (...args: any[]) => mockLogin2FA(...args), + logout: vi.fn(), + getCurrentUser: vi.fn().mockResolvedValue({ data: {} }), + register: vi.fn(), + refreshToken: vi.fn(), + }, + isTotp2FARequired: (response: any) => response?.requires_2fa === true, +})) + +vi.mock('@/api/admin/system', () => ({ + checkUpdates: vi.fn(), +})) + +vi.mock('@/api/auth', () => ({ + getPublicSettings: vi.fn().mockResolvedValue({}), +})) + +/** + * 创建一个简化的测试组件来封装登录逻辑 + * 避免引入 LoginView.vue 的全部依赖(AuthLayout、i18n、Icon 等) + */ +const LoginFormTestComponent = defineComponent({ + setup() { + const authStore = useAuthStore() + const formData = reactive({ email: '', password: '' }) + const isLoading = ref(false) + const errorMessage = ref('') + + const handleLogin = async () => { + if (!formData.email || !formData.password) { + errorMessage.value = '请输入邮箱和密码' + return + } + + 
isLoading.value = true + errorMessage.value = '' + + try { + const response = await authStore.login({ + email: formData.email, + password: formData.password, + }) + + // 2FA 流程由调用方处理 + if ((response as any)?.requires_2fa) { + errorMessage.value = '需要 2FA 验证' + return + } + + mockPush('/dashboard') + } catch (error: any) { + errorMessage.value = error.message || '登录失败' + } finally { + isLoading.value = false + } + } + + return { formData, isLoading, errorMessage, handleLogin } + }, + template: ` +
+    <form @submit.prevent="handleLogin">
+      <input id="email" v-model="formData.email" type="email" />
+      <input id="password" v-model="formData.password" type="password" />
+      <div v-if="errorMessage" class="error">{{ errorMessage }}</div>
+      <button type="submit" :disabled="isLoading">登录</button>
+    </form>
+ `, +}) + +describe('LoginForm 核心逻辑', () => { + beforeEach(() => { + setActivePinia(createPinia()) + vi.clearAllMocks() + }) + + it('成功登录后跳转到 dashboard', async () => { + mockLogin.mockResolvedValue({ + access_token: 'token', + token_type: 'Bearer', + user: { id: 1, username: 'test', email: 'test@example.com', role: 'user', balance: 0, concurrency: 5, status: 'active', allowed_groups: null, created_at: '', updated_at: '' }, + }) + + const wrapper = mount(LoginFormTestComponent) + + await wrapper.find('#email').setValue('test@example.com') + await wrapper.find('#password').setValue('password123') + await wrapper.find('form').trigger('submit') + await flushPromises() + + expect(mockLogin).toHaveBeenCalledWith({ + email: 'test@example.com', + password: 'password123', + }) + expect(mockPush).toHaveBeenCalledWith('/dashboard') + }) + + it('登录失败时显示错误信息', async () => { + mockLogin.mockRejectedValue(new Error('Invalid credentials')) + + const wrapper = mount(LoginFormTestComponent) + + await wrapper.find('#email').setValue('test@example.com') + await wrapper.find('#password').setValue('wrong') + await wrapper.find('form').trigger('submit') + await flushPromises() + + expect(wrapper.find('.error').text()).toBe('Invalid credentials') + }) + + it('空表单提交显示验证错误', async () => { + const wrapper = mount(LoginFormTestComponent) + + await wrapper.find('form').trigger('submit') + await flushPromises() + + expect(wrapper.find('.error').text()).toBe('请输入邮箱和密码') + expect(mockLogin).not.toHaveBeenCalled() + }) + + it('需要 2FA 时不跳转', async () => { + mockLogin.mockResolvedValue({ + requires_2fa: true, + temp_token: 'temp-123', + }) + + const wrapper = mount(LoginFormTestComponent) + + await wrapper.find('#email').setValue('test@example.com') + await wrapper.find('#password').setValue('password123') + await wrapper.find('form').trigger('submit') + await flushPromises() + + expect(mockPush).not.toHaveBeenCalled() + expect(wrapper.find('.error').text()).toBe('需要 2FA 验证') + }) + + 
it('提交过程中按钮被禁用', async () => { + let resolveLogin: (v: any) => void + mockLogin.mockImplementation( + () => new Promise((resolve) => { resolveLogin = resolve }) + ) + + const wrapper = mount(LoginFormTestComponent) + + await wrapper.find('#email').setValue('test@example.com') + await wrapper.find('#password').setValue('password123') + await wrapper.find('form').trigger('submit') + + expect(wrapper.find('button').attributes('disabled')).toBeDefined() + + resolveLogin!({ + access_token: 'token', + token_type: 'Bearer', + user: { id: 1, username: 'test', email: 'test@example.com', role: 'user', balance: 0, concurrency: 5, status: 'active', allowed_groups: null, created_at: '', updated_at: '' }, + }) + await flushPromises() + + expect(wrapper.find('button').attributes('disabled')).toBeUndefined() + }) +}) diff --git a/frontend/src/composables/__tests__/useClipboard.spec.ts b/frontend/src/composables/__tests__/useClipboard.spec.ts new file mode 100644 index 00000000..b2c4de41 --- /dev/null +++ b/frontend/src/composables/__tests__/useClipboard.spec.ts @@ -0,0 +1,137 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { setActivePinia, createPinia } from 'pinia' + +// Mock i18n +vi.mock('@/i18n', () => ({ + i18n: { + global: { + t: (key: string) => key, + }, + }, +})) + +// Mock app store +const mockShowSuccess = vi.fn() +const mockShowError = vi.fn() + +vi.mock('@/stores/app', () => ({ + useAppStore: () => ({ + showSuccess: mockShowSuccess, + showError: mockShowError, + }), +})) + +import { useClipboard } from '@/composables/useClipboard' + +describe('useClipboard', () => { + beforeEach(() => { + setActivePinia(createPinia()) + vi.useFakeTimers() + vi.clearAllMocks() + + // 默认模拟安全上下文 + Clipboard API + Object.defineProperty(window, 'isSecureContext', { value: true, writable: true }) + Object.defineProperty(navigator, 'clipboard', { + value: { + writeText: vi.fn().mockResolvedValue(undefined), + }, + writable: true, + configurable: true, + 
}) + }) + + afterEach(() => { + vi.useRealTimers() + // 恢复 execCommand + if ('execCommand' in document) { + delete (document as any).execCommand + } + }) + + it('复制成功后 copied 变为 true', async () => { + const { copied, copyToClipboard } = useClipboard() + + expect(copied.value).toBe(false) + + await copyToClipboard('hello') + + expect(copied.value).toBe(true) + }) + + it('copied 在 2 秒后自动恢复为 false', async () => { + const { copied, copyToClipboard } = useClipboard() + + await copyToClipboard('hello') + expect(copied.value).toBe(true) + + vi.advanceTimersByTime(2000) + + expect(copied.value).toBe(false) + }) + + it('复制成功时调用 showSuccess', async () => { + const { copyToClipboard } = useClipboard() + + await copyToClipboard('hello', '已复制') + + expect(mockShowSuccess).toHaveBeenCalledWith('已复制') + }) + + it('无自定义消息时使用 i18n 默认消息', async () => { + const { copyToClipboard } = useClipboard() + + await copyToClipboard('hello') + + expect(mockShowSuccess).toHaveBeenCalledWith('common.copiedToClipboard') + }) + + it('空文本返回 false 且不复制', async () => { + const { copyToClipboard, copied } = useClipboard() + + const result = await copyToClipboard('') + + expect(result).toBe(false) + expect(copied.value).toBe(false) + expect(navigator.clipboard.writeText).not.toHaveBeenCalled() + }) + + it('Clipboard API 失败时降级到 fallback', async () => { + ;(navigator.clipboard.writeText as any).mockRejectedValue(new Error('API failed')) + + // jsdom 没有 execCommand,手动定义 + ;(document as any).execCommand = vi.fn().mockReturnValue(true) + + const { copyToClipboard, copied } = useClipboard() + const result = await copyToClipboard('fallback text') + + expect(result).toBe(true) + expect(copied.value).toBe(true) + expect(document.execCommand).toHaveBeenCalledWith('copy') + }) + + it('非安全上下文使用 fallback', async () => { + Object.defineProperty(window, 'isSecureContext', { value: false, writable: true }) + + ;(document as any).execCommand = vi.fn().mockReturnValue(true) + + const { copyToClipboard, copied } = 
useClipboard() + const result = await copyToClipboard('insecure context text') + + expect(result).toBe(true) + expect(copied.value).toBe(true) + expect(navigator.clipboard.writeText).not.toHaveBeenCalled() + expect(document.execCommand).toHaveBeenCalledWith('copy') + }) + + it('所有复制方式均失败时调用 showError', async () => { + ;(navigator.clipboard.writeText as any).mockRejectedValue(new Error('fail')) + ;(document as any).execCommand = vi.fn().mockReturnValue(false) + + const { copyToClipboard, copied } = useClipboard() + const result = await copyToClipboard('text') + + expect(result).toBe(false) + expect(copied.value).toBe(false) + expect(mockShowError).toHaveBeenCalled() + }) +}) diff --git a/frontend/src/composables/__tests__/useForm.spec.ts b/frontend/src/composables/__tests__/useForm.spec.ts new file mode 100644 index 00000000..bd9396a2 --- /dev/null +++ b/frontend/src/composables/__tests__/useForm.spec.ts @@ -0,0 +1,143 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { setActivePinia, createPinia } from 'pinia' +import { useForm } from '@/composables/useForm' +import { useAppStore } from '@/stores/app' + +// Mock API 依赖(app store 内部引用了这些) +vi.mock('@/api/admin/system', () => ({ + checkUpdates: vi.fn(), +})) +vi.mock('@/api/auth', () => ({ + getPublicSettings: vi.fn(), +})) + +describe('useForm', () => { + let appStore: ReturnType + + beforeEach(() => { + setActivePinia(createPinia()) + appStore = useAppStore() + vi.clearAllMocks() + }) + + it('submit 期间 loading 为 true,完成后为 false', async () => { + let resolveSubmit: () => void + const submitFn = vi.fn( + () => new Promise((resolve) => { resolveSubmit = resolve }) + ) + + const { loading, submit } = useForm({ + form: { name: 'test' }, + submitFn, + }) + + expect(loading.value).toBe(false) + + const submitPromise = submit() + // 提交中 + expect(loading.value).toBe(true) + + resolveSubmit!() + await submitPromise + + expect(loading.value).toBe(false) + }) + + it('submit 成功时显示成功消息', async () => { + 
const submitFn = vi.fn().mockResolvedValue(undefined) + const showSuccessSpy = vi.spyOn(appStore, 'showSuccess') + + const { submit } = useForm({ + form: { name: 'test' }, + submitFn, + successMsg: '保存成功', + }) + + await submit() + + expect(showSuccessSpy).toHaveBeenCalledWith('保存成功') + }) + + it('submit 成功但无 successMsg 时不调用 showSuccess', async () => { + const submitFn = vi.fn().mockResolvedValue(undefined) + const showSuccessSpy = vi.spyOn(appStore, 'showSuccess') + + const { submit } = useForm({ + form: { name: 'test' }, + submitFn, + }) + + await submit() + + expect(showSuccessSpy).not.toHaveBeenCalled() + }) + + it('submit 失败时显示错误消息并抛出错误', async () => { + const error = Object.assign(new Error('提交失败'), { + response: { data: { message: '服务器错误' } }, + }) + const submitFn = vi.fn().mockRejectedValue(error) + const showErrorSpy = vi.spyOn(appStore, 'showError') + + const { submit, loading } = useForm({ + form: { name: 'test' }, + submitFn, + }) + + await expect(submit()).rejects.toThrow('提交失败') + + expect(showErrorSpy).toHaveBeenCalled() + expect(loading.value).toBe(false) + }) + + it('submit 失败时使用自定义 errorMsg', async () => { + const submitFn = vi.fn().mockRejectedValue(new Error('network')) + const showErrorSpy = vi.spyOn(appStore, 'showError') + + const { submit } = useForm({ + form: { name: 'test' }, + submitFn, + errorMsg: '自定义错误提示', + }) + + await expect(submit()).rejects.toThrow() + + expect(showErrorSpy).toHaveBeenCalledWith('自定义错误提示') + }) + + it('loading 中不会重复提交', async () => { + let resolveSubmit: () => void + const submitFn = vi.fn( + () => new Promise((resolve) => { resolveSubmit = resolve }) + ) + + const { submit } = useForm({ + form: { name: 'test' }, + submitFn, + }) + + // 第一次提交 + const p1 = submit() + // 第二次提交(应被忽略,因为 loading=true) + submit() + + expect(submitFn).toHaveBeenCalledTimes(1) + + resolveSubmit!() + await p1 + }) + + it('传递 form 数据到 submitFn', async () => { + const formData = { name: 'test', email: 'test@example.com' } + const submitFn = 
vi.fn().mockResolvedValue(undefined) + + const { submit } = useForm({ + form: formData, + submitFn, + }) + + await submit() + + expect(submitFn).toHaveBeenCalledWith(formData) + }) +}) diff --git a/frontend/src/composables/__tests__/useTableLoader.spec.ts b/frontend/src/composables/__tests__/useTableLoader.spec.ts new file mode 100644 index 00000000..0eb6f42c --- /dev/null +++ b/frontend/src/composables/__tests__/useTableLoader.spec.ts @@ -0,0 +1,252 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { useTableLoader } from '@/composables/useTableLoader' +import { nextTick } from 'vue' + +// Mock @vueuse/core 的 useDebounceFn +vi.mock('@vueuse/core', () => ({ + useDebounceFn: (fn: Function, ms: number) => { + let timer: ReturnType | null = null + const debounced = (...args: any[]) => { + if (timer) clearTimeout(timer) + timer = setTimeout(() => fn(...args), ms) + } + debounced.cancel = () => { if (timer) clearTimeout(timer) } + return debounced + }, +})) + +// Mock Vue 的 onUnmounted(composable 外使用时会报错) +vi.mock('vue', async () => { + const actual = await vi.importActual('vue') + return { + ...actual, + onUnmounted: vi.fn(), + } +}) + +const createMockFetchFn = (items: any[] = [], total = 0, pages = 1) => { + return vi.fn().mockResolvedValue({ items, total, pages }) +} + +describe('useTableLoader', () => { + beforeEach(() => { + vi.useFakeTimers() + vi.clearAllMocks() + }) + + afterEach(() => { + vi.useRealTimers() + }) + + // --- 基础加载 --- + + describe('基础加载', () => { + it('load 执行 fetchFn 并更新 items', async () => { + const mockItems = [{ id: 1, name: 'item1' }, { id: 2, name: 'item2' }] + const fetchFn = createMockFetchFn(mockItems, 2, 1) + + const { items, loading, load, pagination } = useTableLoader({ + fetchFn, + }) + + expect(items.value).toHaveLength(0) + + await load() + + expect(items.value).toEqual(mockItems) + expect(pagination.total).toBe(2) + expect(pagination.pages).toBe(1) + expect(loading.value).toBe(false) + }) + + 
it('load 期间 loading 为 true', async () => { + let resolveLoad: (v: any) => void + const fetchFn = vi.fn( + () => new Promise((resolve) => { resolveLoad = resolve }) + ) + + const { loading, load } = useTableLoader({ fetchFn }) + + const p = load() + expect(loading.value).toBe(true) + + resolveLoad!({ items: [], total: 0, pages: 0 }) + await p + + expect(loading.value).toBe(false) + }) + + it('使用默认 pageSize=20', async () => { + const fetchFn = createMockFetchFn() + const { load, pagination } = useTableLoader({ fetchFn }) + + await load() + + expect(fetchFn).toHaveBeenCalledWith( + 1, + 20, + expect.anything(), + expect.objectContaining({ signal: expect.any(AbortSignal) }) + ) + expect(pagination.page_size).toBe(20) + }) + + it('可自定义 pageSize', async () => { + const fetchFn = createMockFetchFn() + const { load } = useTableLoader({ fetchFn, pageSize: 50 }) + + await load() + + expect(fetchFn).toHaveBeenCalledWith( + 1, + 50, + expect.anything(), + expect.anything() + ) + }) + }) + + // --- 分页 --- + + describe('分页', () => { + it('handlePageChange 更新页码并加载', async () => { + const fetchFn = createMockFetchFn([], 100, 5) + const { handlePageChange, pagination, load } = useTableLoader({ fetchFn }) + + await load() // 初始加载 + fetchFn.mockClear() + + handlePageChange(3) + + expect(pagination.page).toBe(3) + // 等待 load 完成 + await vi.runAllTimersAsync() + expect(fetchFn).toHaveBeenCalledWith(3, 20, expect.anything(), expect.anything()) + }) + + it('handlePageSizeChange 重置到第1页并加载', async () => { + const fetchFn = createMockFetchFn([], 100, 5) + const { handlePageSizeChange, pagination, load } = useTableLoader({ fetchFn }) + + await load() + pagination.page = 3 + fetchFn.mockClear() + + handlePageSizeChange(50) + + expect(pagination.page).toBe(1) + expect(pagination.page_size).toBe(50) + }) + + it('handlePageChange 限制页码范围', async () => { + const fetchFn = createMockFetchFn([], 100, 5) + const { handlePageChange, pagination, load } = useTableLoader({ fetchFn }) + + await load() + + 
// 超出范围的页码被限制 + handlePageChange(999) + expect(pagination.page).toBe(5) // 限制在 pages=5 + + handlePageChange(0) + expect(pagination.page).toBe(1) // 最小为 1 + }) + }) + + // --- 搜索防抖 --- + + describe('搜索防抖', () => { + it('debouncedReload 在 300ms 内多次调用只执行一次', async () => { + const fetchFn = createMockFetchFn() + const { debouncedReload } = useTableLoader({ fetchFn }) + + // 快速连续调用 + debouncedReload() + debouncedReload() + debouncedReload() + + // 还没到 300ms,不应调用 fetchFn + expect(fetchFn).not.toHaveBeenCalled() + + // 推进 300ms + vi.advanceTimersByTime(300) + + // 等待异步完成 + await vi.runAllTimersAsync() + + expect(fetchFn).toHaveBeenCalledTimes(1) + }) + + it('reload 重置到第 1 页', async () => { + const fetchFn = createMockFetchFn([], 100, 5) + const { reload, pagination, load } = useTableLoader({ fetchFn }) + + await load() + pagination.page = 3 + + await reload() + + expect(pagination.page).toBe(1) + }) + }) + + // --- 请求取消 --- + + describe('请求取消', () => { + it('新请求取消前一个未完成的请求', async () => { + let callCount = 0 + const fetchFn = vi.fn((_page, _size, _params, options) => { + callCount++ + const currentCall = callCount + return new Promise((resolve, reject) => { + // 模拟监听 abort + if (options?.signal) { + options.signal.addEventListener('abort', () => { + reject({ name: 'CanceledError', code: 'ERR_CANCELED' }) + }) + } + // 异步解决 + setTimeout(() => { + resolve({ items: [{ id: currentCall }], total: 1, pages: 1 }) + }, 1000) + }) + }) + + const { load, items } = useTableLoader({ fetchFn }) + + // 第一次加载 + const p1 = load() + // 第二次加载(应取消第一次) + const p2 = load() + + // 推进时间让第二次完成 + vi.advanceTimersByTime(1000) + await vi.runAllTimersAsync() + + // 等待两个 Promise settle + await Promise.allSettled([p1, p2]) + + // 第二次请求的结果生效 + expect(fetchFn).toHaveBeenCalledTimes(2) + }) + }) + + // --- 错误处理 --- + + describe('错误处理', () => { + it('非取消错误会被抛出', async () => { + const fetchFn = vi.fn().mockRejectedValue(new Error('Server error')) + const { load } = useTableLoader({ fetchFn }) + + await 
expect(load()).rejects.toThrow('Server error') + }) + + it('取消错误被静默处理', async () => { + const fetchFn = vi.fn().mockRejectedValue({ name: 'CanceledError', code: 'ERR_CANCELED' }) + const { load } = useTableLoader({ fetchFn }) + + // 不应抛出 + await load() + }) + }) +}) diff --git a/frontend/src/router/__tests__/guards.spec.ts b/frontend/src/router/__tests__/guards.spec.ts new file mode 100644 index 00000000..931f4534 --- /dev/null +++ b/frontend/src/router/__tests__/guards.spec.ts @@ -0,0 +1,324 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { createRouter, createMemoryHistory } from 'vue-router' +import { setActivePinia, createPinia } from 'pinia' +import { defineComponent, h } from 'vue' + +// Mock 导航加载状态 +vi.mock('@/composables/useNavigationLoading', () => { + const mockStart = vi.fn() + const mockEnd = vi.fn() + return { + useNavigationLoadingState: () => ({ + startNavigation: mockStart, + endNavigation: mockEnd, + isLoading: { value: false }, + }), + useNavigationLoading: () => ({ + startNavigation: mockStart, + endNavigation: mockEnd, + isLoading: { value: false }, + }), + } +}) + +// Mock 路由预加载 +vi.mock('@/composables/useRoutePrefetch', () => ({ + useRoutePrefetch: () => ({ + triggerPrefetch: vi.fn(), + cancelPendingPrefetch: vi.fn(), + resetPrefetchState: vi.fn(), + }), +})) + +// Mock API 相关模块 +vi.mock('@/api', () => ({ + authAPI: { + getCurrentUser: vi.fn().mockResolvedValue({ data: {} }), + logout: vi.fn(), + }, + isTotp2FARequired: () => false, +})) + +vi.mock('@/api/admin/system', () => ({ + checkUpdates: vi.fn(), +})) + +vi.mock('@/api/auth', () => ({ + getPublicSettings: vi.fn(), +})) + +const DummyComponent = defineComponent({ + render() { + return h('div', 'dummy') + }, +}) + +/** + * 创建带守卫逻辑的测试路由 + * 模拟 router/index.ts 中的 beforeEach 守卫逻辑 + */ +function createTestRouter() { + const router = createRouter({ + history: createMemoryHistory(), + routes: [ + { path: '/login', component: DummyComponent, meta: { requiresAuth: false, 
title: 'Login' } }, + { + path: '/register', + component: DummyComponent, + meta: { requiresAuth: false, title: 'Register' }, + }, + { path: '/home', component: DummyComponent, meta: { requiresAuth: false, title: 'Home' } }, + { path: '/dashboard', component: DummyComponent, meta: { title: 'Dashboard' } }, + { path: '/keys', component: DummyComponent, meta: { title: 'API Keys' } }, + { path: '/subscriptions', component: DummyComponent, meta: { title: 'Subscriptions' } }, + { path: '/redeem', component: DummyComponent, meta: { title: 'Redeem' } }, + { + path: '/admin/dashboard', + component: DummyComponent, + meta: { requiresAdmin: true, title: 'Admin Dashboard' }, + }, + { + path: '/admin/users', + component: DummyComponent, + meta: { requiresAdmin: true, title: 'Admin Users' }, + }, + { + path: '/admin/groups', + component: DummyComponent, + meta: { requiresAdmin: true, title: 'Admin Groups' }, + }, + { + path: '/admin/subscriptions', + component: DummyComponent, + meta: { requiresAdmin: true, title: 'Admin Subscriptions' }, + }, + { + path: '/admin/redeem', + component: DummyComponent, + meta: { requiresAdmin: true, title: 'Admin Redeem' }, + }, + ], + }) + + return router +} + +// 用于测试的 auth 状态 +interface MockAuthState { + isAuthenticated: boolean + isAdmin: boolean + isSimpleMode: boolean +} + +/** + * 将 router/index.ts 中 beforeEach 守卫的核心逻辑提取为可测试的函数 + */ +function simulateGuard( + toPath: string, + toMeta: Record, + authState: MockAuthState +): string | null { + const requiresAuth = toMeta.requiresAuth !== false + const requiresAdmin = toMeta.requiresAdmin === true + + // 不需要认证的路由 + if (!requiresAuth) { + if ( + authState.isAuthenticated && + (toPath === '/login' || toPath === '/register') + ) { + return authState.isAdmin ? 
'/admin/dashboard' : '/dashboard' + } + return null // 允许通过 + } + + // 需要认证但未登录 + if (!authState.isAuthenticated) { + return '/login' + } + + // 需要管理员但不是管理员 + if (requiresAdmin && !authState.isAdmin) { + return '/dashboard' + } + + // 简易模式限制 + if (authState.isSimpleMode) { + const restrictedPaths = [ + '/admin/groups', + '/admin/subscriptions', + '/admin/redeem', + '/subscriptions', + '/redeem', + ] + if (restrictedPaths.some((path) => toPath.startsWith(path))) { + return authState.isAdmin ? '/admin/dashboard' : '/dashboard' + } + } + + return null // 允许通过 +} + +describe('路由守卫逻辑', () => { + beforeEach(() => { + setActivePinia(createPinia()) + }) + + // --- 未认证用户 --- + + describe('未认证用户', () => { + const authState: MockAuthState = { + isAuthenticated: false, + isAdmin: false, + isSimpleMode: false, + } + + it('访问需要认证的页面重定向到 /login', () => { + const redirect = simulateGuard('/dashboard', {}, authState) + expect(redirect).toBe('/login') + }) + + it('访问管理页面重定向到 /login', () => { + const redirect = simulateGuard('/admin/dashboard', { requiresAdmin: true }, authState) + expect(redirect).toBe('/login') + }) + + it('访问公开页面允许通过', () => { + const redirect = simulateGuard('/login', { requiresAuth: false }, authState) + expect(redirect).toBeNull() + }) + + it('访问 /home 公开页面允许通过', () => { + const redirect = simulateGuard('/home', { requiresAuth: false }, authState) + expect(redirect).toBeNull() + }) + }) + + // --- 已认证普通用户 --- + + describe('已认证普通用户', () => { + const authState: MockAuthState = { + isAuthenticated: true, + isAdmin: false, + isSimpleMode: false, + } + + it('访问 /login 重定向到 /dashboard', () => { + const redirect = simulateGuard('/login', { requiresAuth: false }, authState) + expect(redirect).toBe('/dashboard') + }) + + it('访问 /register 重定向到 /dashboard', () => { + const redirect = simulateGuard('/register', { requiresAuth: false }, authState) + expect(redirect).toBe('/dashboard') + }) + + it('访问 /dashboard 允许通过', () => { + const redirect = simulateGuard('/dashboard', 
{}, authState) + expect(redirect).toBeNull() + }) + + it('访问管理页面被拒绝,重定向到 /dashboard', () => { + const redirect = simulateGuard('/admin/dashboard', { requiresAdmin: true }, authState) + expect(redirect).toBe('/dashboard') + }) + + it('访问 /admin/users 被拒绝', () => { + const redirect = simulateGuard('/admin/users', { requiresAdmin: true }, authState) + expect(redirect).toBe('/dashboard') + }) + }) + + // --- 已认证管理员 --- + + describe('已认证管理员', () => { + const authState: MockAuthState = { + isAuthenticated: true, + isAdmin: true, + isSimpleMode: false, + } + + it('访问 /login 重定向到 /admin/dashboard', () => { + const redirect = simulateGuard('/login', { requiresAuth: false }, authState) + expect(redirect).toBe('/admin/dashboard') + }) + + it('访问管理页面允许通过', () => { + const redirect = simulateGuard('/admin/dashboard', { requiresAdmin: true }, authState) + expect(redirect).toBeNull() + }) + + it('访问用户页面允许通过', () => { + const redirect = simulateGuard('/dashboard', {}, authState) + expect(redirect).toBeNull() + }) + }) + + // --- 简易模式 --- + + describe('简易模式受限路由', () => { + it('普通用户简易模式访问 /subscriptions 重定向到 /dashboard', () => { + const authState: MockAuthState = { + isAuthenticated: true, + isAdmin: false, + isSimpleMode: true, + } + const redirect = simulateGuard('/subscriptions', {}, authState) + expect(redirect).toBe('/dashboard') + }) + + it('普通用户简易模式访问 /redeem 重定向到 /dashboard', () => { + const authState: MockAuthState = { + isAuthenticated: true, + isAdmin: false, + isSimpleMode: true, + } + const redirect = simulateGuard('/redeem', {}, authState) + expect(redirect).toBe('/dashboard') + }) + + it('管理员简易模式访问 /admin/groups 重定向到 /admin/dashboard', () => { + const authState: MockAuthState = { + isAuthenticated: true, + isAdmin: true, + isSimpleMode: true, + } + const redirect = simulateGuard('/admin/groups', { requiresAdmin: true }, authState) + expect(redirect).toBe('/admin/dashboard') + }) + + it('管理员简易模式访问 /admin/subscriptions 重定向', () => { + const authState: MockAuthState = { 
+ isAuthenticated: true, + isAdmin: true, + isSimpleMode: true, + } + const redirect = simulateGuard( + '/admin/subscriptions', + { requiresAdmin: true }, + authState + ) + expect(redirect).toBe('/admin/dashboard') + }) + + it('简易模式下非受限页面正常访问', () => { + const authState: MockAuthState = { + isAuthenticated: true, + isAdmin: false, + isSimpleMode: true, + } + const redirect = simulateGuard('/dashboard', {}, authState) + expect(redirect).toBeNull() + }) + + it('简易模式下 /keys 正常访问', () => { + const authState: MockAuthState = { + isAuthenticated: true, + isAdmin: false, + isSimpleMode: true, + } + const redirect = simulateGuard('/keys', {}, authState) + expect(redirect).toBeNull() + }) + }) +}) diff --git a/frontend/src/stores/__tests__/app.spec.ts b/frontend/src/stores/__tests__/app.spec.ts new file mode 100644 index 00000000..432a7079 --- /dev/null +++ b/frontend/src/stores/__tests__/app.spec.ts @@ -0,0 +1,293 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { setActivePinia, createPinia } from 'pinia' +import { useAppStore } from '@/stores/app' + +// Mock API 模块 +vi.mock('@/api/admin/system', () => ({ + checkUpdates: vi.fn(), +})) + +vi.mock('@/api/auth', () => ({ + getPublicSettings: vi.fn(), +})) + +describe('useAppStore', () => { + beforeEach(() => { + setActivePinia(createPinia()) + vi.useFakeTimers() + // 清除 window.__APP_CONFIG__ + delete (window as any).__APP_CONFIG__ + }) + + afterEach(() => { + vi.useRealTimers() + }) + + // --- Toast 消息管理 --- + + describe('Toast 消息管理', () => { + it('showSuccess 创建 success 类型 toast', () => { + const store = useAppStore() + const id = store.showSuccess('操作成功') + + expect(id).toMatch(/^toast-/) + expect(store.toasts).toHaveLength(1) + expect(store.toasts[0].type).toBe('success') + expect(store.toasts[0].message).toBe('操作成功') + }) + + it('showError 创建 error 类型 toast', () => { + const store = useAppStore() + store.showError('出错了') + + expect(store.toasts).toHaveLength(1) + 
expect(store.toasts[0].type).toBe('error') + expect(store.toasts[0].message).toBe('出错了') + }) + + it('showWarning 创建 warning 类型 toast', () => { + const store = useAppStore() + store.showWarning('警告信息') + + expect(store.toasts).toHaveLength(1) + expect(store.toasts[0].type).toBe('warning') + }) + + it('showInfo 创建 info 类型 toast', () => { + const store = useAppStore() + store.showInfo('提示信息') + + expect(store.toasts).toHaveLength(1) + expect(store.toasts[0].type).toBe('info') + }) + + it('toast 在指定 duration 后自动消失', () => { + const store = useAppStore() + store.showSuccess('临时消息', 3000) + + expect(store.toasts).toHaveLength(1) + + vi.advanceTimersByTime(3000) + + expect(store.toasts).toHaveLength(0) + }) + + it('hideToast 移除指定 toast', () => { + const store = useAppStore() + const id = store.showSuccess('消息1') + store.showError('消息2') + + expect(store.toasts).toHaveLength(2) + + store.hideToast(id) + + expect(store.toasts).toHaveLength(1) + expect(store.toasts[0].type).toBe('error') + }) + + it('clearAllToasts 清除所有 toast', () => { + const store = useAppStore() + store.showSuccess('消息1') + store.showError('消息2') + store.showWarning('消息3') + + expect(store.toasts).toHaveLength(3) + + store.clearAllToasts() + + expect(store.toasts).toHaveLength(0) + }) + + it('hasActiveToasts 正确反映 toast 状态', () => { + const store = useAppStore() + expect(store.hasActiveToasts).toBe(false) + + store.showSuccess('消息') + expect(store.hasActiveToasts).toBe(true) + + store.clearAllToasts() + expect(store.hasActiveToasts).toBe(false) + }) + + it('多个 toast 的 ID 唯一', () => { + const store = useAppStore() + const id1 = store.showSuccess('消息1') + const id2 = store.showSuccess('消息2') + const id3 = store.showSuccess('消息3') + + expect(id1).not.toBe(id2) + expect(id2).not.toBe(id3) + }) + }) + + // --- 侧边栏 --- + + describe('侧边栏管理', () => { + it('toggleSidebar 切换折叠状态', () => { + const store = useAppStore() + expect(store.sidebarCollapsed).toBe(false) + + store.toggleSidebar() + 
expect(store.sidebarCollapsed).toBe(true) + + store.toggleSidebar() + expect(store.sidebarCollapsed).toBe(false) + }) + + it('setSidebarCollapsed 直接设置状态', () => { + const store = useAppStore() + + store.setSidebarCollapsed(true) + expect(store.sidebarCollapsed).toBe(true) + + store.setSidebarCollapsed(false) + expect(store.sidebarCollapsed).toBe(false) + }) + + it('toggleMobileSidebar 切换移动端状态', () => { + const store = useAppStore() + expect(store.mobileOpen).toBe(false) + + store.toggleMobileSidebar() + expect(store.mobileOpen).toBe(true) + + store.toggleMobileSidebar() + expect(store.mobileOpen).toBe(false) + }) + }) + + // --- Loading 状态 --- + + describe('Loading 状态管理', () => { + it('setLoading 管理引用计数', () => { + const store = useAppStore() + expect(store.loading).toBe(false) + + store.setLoading(true) + expect(store.loading).toBe(true) + + store.setLoading(true) // 两次 true + expect(store.loading).toBe(true) + + store.setLoading(false) // 第一次 false,计数还是 1 + expect(store.loading).toBe(true) + + store.setLoading(false) // 第二次 false,计数为 0 + expect(store.loading).toBe(false) + }) + + it('setLoading(false) 不会使计数为负', () => { + const store = useAppStore() + + store.setLoading(false) + store.setLoading(false) + expect(store.loading).toBe(false) + + store.setLoading(true) + expect(store.loading).toBe(true) + + store.setLoading(false) + expect(store.loading).toBe(false) + }) + + it('withLoading 自动管理 loading 状态', async () => { + const store = useAppStore() + + const result = await store.withLoading(async () => { + expect(store.loading).toBe(true) + return 'done' + }) + + expect(result).toBe('done') + expect(store.loading).toBe(false) + }) + + it('withLoading 错误时也恢复 loading 状态', async () => { + const store = useAppStore() + + await expect( + store.withLoading(async () => { + throw new Error('操作失败') + }) + ).rejects.toThrow('操作失败') + + expect(store.loading).toBe(false) + }) + + it('withLoadingAndError 错误时显示 toast 并返回 null', async () => { + const store = useAppStore() + + 
const result = await store.withLoadingAndError(async () => { + throw new Error('网络错误') + }) + + expect(result).toBeNull() + expect(store.loading).toBe(false) + expect(store.toasts).toHaveLength(1) + expect(store.toasts[0].type).toBe('error') + }) + }) + + // --- reset --- + + describe('reset', () => { + it('重置所有 UI 状态', () => { + const store = useAppStore() + + store.setSidebarCollapsed(true) + store.setLoading(true) + store.showSuccess('消息') + + store.reset() + + expect(store.sidebarCollapsed).toBe(false) + expect(store.loading).toBe(false) + expect(store.toasts).toHaveLength(0) + }) + }) + + // --- 公开设置 --- + + describe('公开设置加载', () => { + it('从 window.__APP_CONFIG__ 初始化', () => { + ;(window as any).__APP_CONFIG__ = { + site_name: 'TestSite', + site_logo: '/logo.png', + version: '1.0.0', + contact_info: 'test@test.com', + api_base_url: 'https://api.test.com', + doc_url: 'https://docs.test.com', + } + + const store = useAppStore() + const result = store.initFromInjectedConfig() + + expect(result).toBe(true) + expect(store.siteName).toBe('TestSite') + expect(store.siteLogo).toBe('/logo.png') + expect(store.siteVersion).toBe('1.0.0') + expect(store.publicSettingsLoaded).toBe(true) + }) + + it('无注入配置时返回 false', () => { + const store = useAppStore() + const result = store.initFromInjectedConfig() + + expect(result).toBe(false) + expect(store.publicSettingsLoaded).toBe(false) + }) + + it('clearPublicSettingsCache 清除缓存', () => { + ;(window as any).__APP_CONFIG__ = { site_name: 'Test' } + const store = useAppStore() + store.initFromInjectedConfig() + + expect(store.publicSettingsLoaded).toBe(true) + + store.clearPublicSettingsCache() + + expect(store.publicSettingsLoaded).toBe(false) + expect(store.cachedPublicSettings).toBeNull() + }) + }) +}) diff --git a/frontend/src/stores/__tests__/auth.spec.ts b/frontend/src/stores/__tests__/auth.spec.ts new file mode 100644 index 00000000..ee6ad24e --- /dev/null +++ b/frontend/src/stores/__tests__/auth.spec.ts @@ -0,0 +1,289 @@ 
+import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { setActivePinia, createPinia } from 'pinia' +import { useAuthStore } from '@/stores/auth' + +// Mock authAPI +const mockLogin = vi.fn() +const mockLogin2FA = vi.fn() +const mockLogout = vi.fn() +const mockGetCurrentUser = vi.fn() +const mockRegister = vi.fn() +const mockRefreshToken = vi.fn() + +vi.mock('@/api', () => ({ + authAPI: { + login: (...args: any[]) => mockLogin(...args), + login2FA: (...args: any[]) => mockLogin2FA(...args), + logout: (...args: any[]) => mockLogout(...args), + getCurrentUser: (...args: any[]) => mockGetCurrentUser(...args), + register: (...args: any[]) => mockRegister(...args), + refreshToken: (...args: any[]) => mockRefreshToken(...args), + }, + isTotp2FARequired: (response: any) => response?.requires_2fa === true, +})) + +const fakeUser = { + id: 1, + username: 'testuser', + email: 'test@example.com', + role: 'user' as const, + balance: 100, + concurrency: 5, + status: 'active' as const, + allowed_groups: null, + created_at: '2024-01-01', + updated_at: '2024-01-01', +} + +const fakeAdminUser = { + ...fakeUser, + id: 2, + username: 'admin', + email: 'admin@example.com', + role: 'admin' as const, +} + +const fakeAuthResponse = { + access_token: 'test-token-123', + refresh_token: 'refresh-token-456', + expires_in: 3600, + token_type: 'Bearer', + user: { ...fakeUser }, +} + +describe('useAuthStore', () => { + beforeEach(() => { + setActivePinia(createPinia()) + localStorage.clear() + vi.useFakeTimers() + vi.clearAllMocks() + }) + + afterEach(() => { + vi.useRealTimers() + }) + + // --- login --- + + describe('login', () => { + it('成功登录后设置 token 和 user', async () => { + mockLogin.mockResolvedValue(fakeAuthResponse) + const store = useAuthStore() + + await store.login({ email: 'test@example.com', password: '123456' }) + + expect(store.token).toBe('test-token-123') + expect(store.user).toEqual(fakeUser) + expect(store.isAuthenticated).toBe(true) + 
expect(localStorage.getItem('auth_token')).toBe('test-token-123') + expect(localStorage.getItem('auth_user')).toBe(JSON.stringify(fakeUser)) + }) + + it('登录失败时清除状态并抛出错误', async () => { + mockLogin.mockRejectedValue(new Error('Invalid credentials')) + const store = useAuthStore() + + await expect(store.login({ email: 'test@example.com', password: 'wrong' })).rejects.toThrow( + 'Invalid credentials' + ) + + expect(store.token).toBeNull() + expect(store.user).toBeNull() + expect(store.isAuthenticated).toBe(false) + }) + + it('需要 2FA 时返回响应但不设置认证状态', async () => { + const twoFAResponse = { requires_2fa: true, temp_token: 'temp-123' } + mockLogin.mockResolvedValue(twoFAResponse) + const store = useAuthStore() + + const result = await store.login({ email: 'test@example.com', password: '123456' }) + + expect(result).toEqual(twoFAResponse) + expect(store.token).toBeNull() + expect(store.isAuthenticated).toBe(false) + }) + }) + + // --- login2FA --- + + describe('login2FA', () => { + it('2FA 验证成功后设置认证状态', async () => { + mockLogin2FA.mockResolvedValue(fakeAuthResponse) + const store = useAuthStore() + + const user = await store.login2FA('temp-123', '654321') + + expect(store.token).toBe('test-token-123') + expect(store.user).toEqual(fakeUser) + expect(user).toEqual(fakeUser) + expect(mockLogin2FA).toHaveBeenCalledWith({ + temp_token: 'temp-123', + totp_code: '654321', + }) + }) + + it('2FA 验证失败时清除状态并抛出错误', async () => { + mockLogin2FA.mockRejectedValue(new Error('Invalid TOTP')) + const store = useAuthStore() + + await expect(store.login2FA('temp-123', '000000')).rejects.toThrow('Invalid TOTP') + expect(store.token).toBeNull() + expect(store.isAuthenticated).toBe(false) + }) + }) + + // --- logout --- + + describe('logout', () => { + it('注销后清除所有状态和 localStorage', async () => { + mockLogin.mockResolvedValue(fakeAuthResponse) + mockLogout.mockResolvedValue(undefined) + const store = useAuthStore() + + // 先登录 + await store.login({ email: 'test@example.com', password: '123456' 
}) + expect(store.isAuthenticated).toBe(true) + + // 注销 + await store.logout() + + expect(store.token).toBeNull() + expect(store.user).toBeNull() + expect(store.isAuthenticated).toBe(false) + expect(localStorage.getItem('auth_token')).toBeNull() + expect(localStorage.getItem('auth_user')).toBeNull() + expect(localStorage.getItem('refresh_token')).toBeNull() + expect(localStorage.getItem('token_expires_at')).toBeNull() + }) + }) + + // --- checkAuth --- + + describe('checkAuth', () => { + it('从 localStorage 恢复持久化状态', () => { + localStorage.setItem('auth_token', 'saved-token') + localStorage.setItem('auth_user', JSON.stringify(fakeUser)) + + // Mock refreshUser (getCurrentUser) 防止后台刷新报错 + mockGetCurrentUser.mockResolvedValue({ data: fakeUser }) + + const store = useAuthStore() + store.checkAuth() + + expect(store.token).toBe('saved-token') + expect(store.user).toEqual(fakeUser) + expect(store.isAuthenticated).toBe(true) + }) + + it('localStorage 无数据时保持未认证状态', () => { + const store = useAuthStore() + store.checkAuth() + + expect(store.token).toBeNull() + expect(store.user).toBeNull() + expect(store.isAuthenticated).toBe(false) + }) + + it('localStorage 中用户数据损坏时清除状态', () => { + localStorage.setItem('auth_token', 'saved-token') + localStorage.setItem('auth_user', 'invalid-json{{{') + + const store = useAuthStore() + store.checkAuth() + + expect(store.token).toBeNull() + expect(store.user).toBeNull() + expect(localStorage.getItem('auth_token')).toBeNull() + }) + + it('恢复 refresh token 和过期时间', () => { + const futureTs = String(Date.now() + 3600_000) + localStorage.setItem('auth_token', 'saved-token') + localStorage.setItem('auth_user', JSON.stringify(fakeUser)) + localStorage.setItem('refresh_token', 'saved-refresh') + localStorage.setItem('token_expires_at', futureTs) + + mockGetCurrentUser.mockResolvedValue({ data: fakeUser }) + + const store = useAuthStore() + store.checkAuth() + + expect(store.isAuthenticated).toBe(true) + }) + }) + + // --- isAdmin --- + + 
describe('isAdmin', () => { + it('管理员用户返回 true', async () => { + const adminResponse = { ...fakeAuthResponse, user: { ...fakeAdminUser } } + mockLogin.mockResolvedValue(adminResponse) + const store = useAuthStore() + + await store.login({ email: 'admin@example.com', password: '123456' }) + + expect(store.isAdmin).toBe(true) + }) + + it('普通用户返回 false', async () => { + mockLogin.mockResolvedValue(fakeAuthResponse) + const store = useAuthStore() + + await store.login({ email: 'test@example.com', password: '123456' }) + + expect(store.isAdmin).toBe(false) + }) + + it('未登录时返回 false', () => { + const store = useAuthStore() + expect(store.isAdmin).toBe(false) + }) + }) + + // --- refreshUser --- + + describe('refreshUser', () => { + it('刷新用户数据并更新 localStorage', async () => { + mockLogin.mockResolvedValue(fakeAuthResponse) + const store = useAuthStore() + await store.login({ email: 'test@example.com', password: '123456' }) + + const updatedUser = { ...fakeUser, username: 'updated-name' } + mockGetCurrentUser.mockResolvedValue({ data: updatedUser }) + + const result = await store.refreshUser() + + expect(result).toEqual(updatedUser) + expect(store.user).toEqual(updatedUser) + expect(JSON.parse(localStorage.getItem('auth_user')!)).toEqual(updatedUser) + }) + + it('未认证时抛出错误', async () => { + const store = useAuthStore() + await expect(store.refreshUser()).rejects.toThrow('Not authenticated') + }) + }) + + // --- isSimpleMode --- + + describe('isSimpleMode', () => { + it('run_mode 为 simple 时返回 true', async () => { + const simpleResponse = { + ...fakeAuthResponse, + user: { ...fakeUser, run_mode: 'simple' as const }, + } + mockLogin.mockResolvedValue(simpleResponse) + const store = useAuthStore() + + await store.login({ email: 'test@example.com', password: '123456' }) + + expect(store.isSimpleMode).toBe(true) + }) + + it('默认为 standard 模式', () => { + const store = useAuthStore() + expect(store.isSimpleMode).toBe(false) + }) + }) +}) diff --git 
a/frontend/src/stores/__tests__/subscriptions.spec.ts b/frontend/src/stores/__tests__/subscriptions.spec.ts new file mode 100644 index 00000000..4c0b4b89 --- /dev/null +++ b/frontend/src/stores/__tests__/subscriptions.spec.ts @@ -0,0 +1,239 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { setActivePinia, createPinia } from 'pinia' +import { useSubscriptionStore } from '@/stores/subscriptions' + +// Mock subscriptions API +const mockGetActiveSubscriptions = vi.fn() + +vi.mock('@/api/subscriptions', () => ({ + default: { + getActiveSubscriptions: (...args: any[]) => mockGetActiveSubscriptions(...args), + }, +})) + +const fakeSubscriptions = [ + { + id: 1, + user_id: 1, + group_id: 1, + status: 'active' as const, + daily_usage_usd: 5, + weekly_usage_usd: 20, + monthly_usage_usd: 50, + daily_window_start: null, + weekly_window_start: null, + monthly_window_start: null, + created_at: '2024-01-01', + updated_at: '2024-01-01', + expires_at: '2025-01-01', + }, + { + id: 2, + user_id: 1, + group_id: 2, + status: 'active' as const, + daily_usage_usd: 10, + weekly_usage_usd: 40, + monthly_usage_usd: 100, + daily_window_start: null, + weekly_window_start: null, + monthly_window_start: null, + created_at: '2024-02-01', + updated_at: '2024-02-01', + expires_at: '2025-02-01', + }, +] + +describe('useSubscriptionStore', () => { + beforeEach(() => { + setActivePinia(createPinia()) + vi.useFakeTimers() + vi.clearAllMocks() + }) + + afterEach(() => { + vi.useRealTimers() + }) + + // --- fetchActiveSubscriptions --- + + describe('fetchActiveSubscriptions', () => { + it('成功获取活跃订阅', async () => { + mockGetActiveSubscriptions.mockResolvedValue(fakeSubscriptions) + const store = useSubscriptionStore() + + const result = await store.fetchActiveSubscriptions() + + expect(result).toEqual(fakeSubscriptions) + expect(store.activeSubscriptions).toEqual(fakeSubscriptions) + expect(store.loading).toBe(false) + }) + + it('缓存有效时返回缓存数据', async () => { + 
mockGetActiveSubscriptions.mockResolvedValue(fakeSubscriptions) + const store = useSubscriptionStore() + + // 第一次请求 + await store.fetchActiveSubscriptions() + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(1) + + // 第二次请求(60秒内)- 应返回缓存 + const result = await store.fetchActiveSubscriptions() + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(1) // 没有新请求 + expect(result).toEqual(fakeSubscriptions) + }) + + it('缓存过期后重新请求', async () => { + mockGetActiveSubscriptions.mockResolvedValue(fakeSubscriptions) + const store = useSubscriptionStore() + + await store.fetchActiveSubscriptions() + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(1) + + // 推进 61 秒让缓存过期 + vi.advanceTimersByTime(61_000) + + const updatedSubs = [fakeSubscriptions[0]] + mockGetActiveSubscriptions.mockResolvedValue(updatedSubs) + + const result = await store.fetchActiveSubscriptions() + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(2) + expect(result).toEqual(updatedSubs) + }) + + it('force=true 强制重新请求', async () => { + mockGetActiveSubscriptions.mockResolvedValue(fakeSubscriptions) + const store = useSubscriptionStore() + + await store.fetchActiveSubscriptions() + + const updatedSubs = [fakeSubscriptions[0]] + mockGetActiveSubscriptions.mockResolvedValue(updatedSubs) + + const result = await store.fetchActiveSubscriptions(true) + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(2) + expect(result).toEqual(updatedSubs) + }) + + it('并发请求共享同一个 Promise(去重)', async () => { + let resolvePromise: (v: any) => void + mockGetActiveSubscriptions.mockImplementation( + () => new Promise((resolve) => { resolvePromise = resolve }) + ) + const store = useSubscriptionStore() + + // 并发发起两个请求 + const p1 = store.fetchActiveSubscriptions() + const p2 = store.fetchActiveSubscriptions() + + // 只调用了一次 API + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(1) + + // 解决 Promise + resolvePromise!(fakeSubscriptions) + + const [r1, r2] = await Promise.all([p1, p2]) + 
expect(r1).toEqual(fakeSubscriptions) + expect(r2).toEqual(fakeSubscriptions) + }) + + it('API 错误时抛出异常', async () => { + mockGetActiveSubscriptions.mockRejectedValue(new Error('Network error')) + const store = useSubscriptionStore() + + await expect(store.fetchActiveSubscriptions()).rejects.toThrow('Network error') + }) + }) + + // --- hasActiveSubscriptions --- + + describe('hasActiveSubscriptions', () => { + it('有订阅时返回 true', async () => { + mockGetActiveSubscriptions.mockResolvedValue(fakeSubscriptions) + const store = useSubscriptionStore() + + await store.fetchActiveSubscriptions() + + expect(store.hasActiveSubscriptions).toBe(true) + }) + + it('无订阅时返回 false', () => { + const store = useSubscriptionStore() + expect(store.hasActiveSubscriptions).toBe(false) + }) + + it('清除后返回 false', async () => { + mockGetActiveSubscriptions.mockResolvedValue(fakeSubscriptions) + const store = useSubscriptionStore() + + await store.fetchActiveSubscriptions() + expect(store.hasActiveSubscriptions).toBe(true) + + store.clear() + expect(store.hasActiveSubscriptions).toBe(false) + }) + }) + + // --- invalidateCache --- + + describe('invalidateCache', () => { + it('失效缓存后下次请求重新获取数据', async () => { + mockGetActiveSubscriptions.mockResolvedValue(fakeSubscriptions) + const store = useSubscriptionStore() + + await store.fetchActiveSubscriptions() + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(1) + + store.invalidateCache() + + await store.fetchActiveSubscriptions() + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(2) + }) + }) + + // --- clear --- + + describe('clear', () => { + it('清除所有订阅数据', async () => { + mockGetActiveSubscriptions.mockResolvedValue(fakeSubscriptions) + const store = useSubscriptionStore() + + await store.fetchActiveSubscriptions() + expect(store.activeSubscriptions).toHaveLength(2) + + store.clear() + + expect(store.activeSubscriptions).toHaveLength(0) + expect(store.hasActiveSubscriptions).toBe(false) + }) + }) + + // --- polling --- + + 
describe('startPolling / stopPolling', () => { + it('startPolling 不会创建重复 interval', () => { + const store = useSubscriptionStore() + mockGetActiveSubscriptions.mockResolvedValue([]) + + store.startPolling() + store.startPolling() // 重复调用 + + // 推进5分钟只触发一次 + vi.advanceTimersByTime(5 * 60 * 1000) + expect(mockGetActiveSubscriptions).toHaveBeenCalledTimes(1) + + store.stopPolling() + }) + + it('stopPolling 停止定期刷新', () => { + const store = useSubscriptionStore() + mockGetActiveSubscriptions.mockResolvedValue([]) + + store.startPolling() + store.stopPolling() + + vi.advanceTimersByTime(10 * 60 * 1000) + expect(mockGetActiveSubscriptions).not.toHaveBeenCalled() + }) + }) +}) diff --git a/frontend/vitest.config.ts b/frontend/vitest.config.ts index 0b20cb60..1007f6ed 100644 --- a/frontend/vitest.config.ts +++ b/frontend/vitest.config.ts @@ -1,35 +1,44 @@ -import { defineConfig, mergeConfig } from 'vitest/config' -import viteConfig from './vite.config' +import { defineConfig } from 'vitest/config' +import vue from '@vitejs/plugin-vue' +import { resolve } from 'path' -export default mergeConfig( - viteConfig, - defineConfig({ - test: { - globals: true, - environment: 'jsdom', - include: ['src/**/*.{test,spec}.{js,ts,jsx,tsx}'], - exclude: ['node_modules', 'dist'], - coverage: { - provider: 'v8', - reporter: ['text', 'json', 'html'], - include: ['src/**/*.{js,ts,vue}'], - exclude: [ - 'node_modules', - 'src/**/*.d.ts', - 'src/**/*.spec.ts', - 'src/**/*.test.ts', - 'src/main.ts' - ], - thresholds: { - global: { - statements: 80, - branches: 80, - functions: 80, - lines: 80 - } - } - }, - setupFiles: ['./src/__tests__/setup.ts'] +export default defineConfig({ + plugins: [vue()], + resolve: { + alias: { + '@': resolve(__dirname, 'src'), + 'vue-i18n': 'vue-i18n/dist/vue-i18n.runtime.esm-bundler.js' } - }) -) + }, + define: { + __INTLIFY_JIT_COMPILATION__: true + }, + test: { + globals: true, + environment: 'jsdom', + include: ['src/**/*.{test,spec}.{js,ts,jsx,tsx}'], + exclude: 
['node_modules', 'dist'], + coverage: { + provider: 'v8', + reporter: ['text', 'json', 'html'], + include: ['src/**/*.{js,ts,vue}'], + exclude: [ + 'node_modules', + 'src/**/*.d.ts', + 'src/**/*.spec.ts', + 'src/**/*.test.ts', + 'src/main.ts' + ], + thresholds: { + global: { + statements: 80, + branches: 80, + functions: 80, + lines: 80 + } + } + }, + setupFiles: ['./src/__tests__/setup.ts'], + testTimeout: 10000 + } +}) From 9da80e9fda18a3888d4963c475b2de54e9d9977b Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sun, 8 Feb 2026 12:13:29 +0800 Subject: [PATCH 057/363] feat: update --- .gitignore | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 48172982..bfa6bb1b 100644 --- a/.gitignore +++ b/.gitignore @@ -129,4 +129,6 @@ deploy/docker-compose.override.yml .gocache/ vite.config.js docs/* -.serena/ \ No newline at end of file +.serena/ + +frontend/coverage \ No newline at end of file From fb58560d15fa34d2fc14b89f301e946e039861e7 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 13:06:25 +0800 Subject: [PATCH 058/363] refactor(upstream): replace upstream account type with apikey, auto-append /antigravity Upstream accounts now use the standard APIKey type instead of a dedicated upstream type. GetBaseURL() and new GetGeminiBaseURL() automatically append /antigravity for Antigravity platform APIKey accounts, eliminating the need for separate upstream forwarding methods. 
- Remove ForwardUpstream, ForwardUpstreamGemini, testUpstreamConnection - Remove upstream branch guards in Forward/ForwardGemini/TestConnection - Add migration 052 to convert existing upstream accounts to apikey - Update frontend CreateAccountModal to create apikey type - Add unit tests for GetBaseURL and GetGeminiBaseURL --- backend/internal/handler/gateway_handler.go | 2 +- .../internal/handler/gemini_v1beta_handler.go | 2 +- backend/internal/service/account.go | 16 + .../internal/service/account_base_url_test.go | 160 ++++++++ .../service/antigravity_gateway_service.go | 386 ------------------ .../service/gemini_messages_compat_service.go | 25 +- .../upstream_header_passthrough_test.go | 285 ------------- .../052_migrate_upstream_to_apikey.sql | 11 + .../components/account/CreateAccountModal.vue | 6 +- 9 files changed, 197 insertions(+), 696 deletions(-) create mode 100644 backend/internal/service/account_base_url_test.go delete mode 100644 backend/internal/service/upstream_header_passthrough_test.go create mode 100644 backend/migrations/052_migrate_upstream_to_apikey.sql diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index ca4442e4..255d3fab 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -482,7 +482,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if switchCount > 0 { requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) } - if account.Platform == service.PlatformAntigravity { + if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession) } else { result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index b1477ac6..2b69be2e 100644 --- 
a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -410,7 +410,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if switchCount > 0 { requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) } - if account.Platform == service.PlatformAntigravity { + if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession) } else { result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index a6ae8a68..138d5bcb 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -425,6 +425,22 @@ func (a *Account) GetBaseURL() string { if baseURL == "" { return "https://api.anthropic.com" } + if a.Platform == PlatformAntigravity { + return strings.TrimRight(baseURL, "/") + "/antigravity" + } + return baseURL +} + +// GetGeminiBaseURL 返回 Gemini 兼容端点的 base URL。 +// Antigravity 平台的 APIKey 账号自动拼接 /antigravity。 +func (a *Account) GetGeminiBaseURL(defaultBaseURL string) string { + baseURL := strings.TrimSpace(a.GetCredential("base_url")) + if baseURL == "" { + return defaultBaseURL + } + if a.Platform == PlatformAntigravity && a.Type == AccountTypeAPIKey { + return strings.TrimRight(baseURL, "/") + "/antigravity" + } return baseURL } diff --git a/backend/internal/service/account_base_url_test.go b/backend/internal/service/account_base_url_test.go new file mode 100644 index 00000000..a1322193 --- /dev/null +++ b/backend/internal/service/account_base_url_test.go @@ -0,0 +1,160 @@ +//go:build unit + +package service + +import ( + "testing" +) + +func TestGetBaseURL(t *testing.T) { + tests := []struct { + name string + account Account + expected string + }{ + { + 
name: "non-apikey type returns empty", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAnthropic, + }, + expected: "", + }, + { + name: "apikey without base_url returns default anthropic", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAnthropic, + Credentials: map[string]any{}, + }, + expected: "https://api.anthropic.com", + }, + { + name: "apikey with custom base_url", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAnthropic, + Credentials: map[string]any{"base_url": "https://custom.example.com"}, + }, + expected: "https://custom.example.com", + }, + { + name: "antigravity apikey auto-appends /antigravity", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity apikey trims trailing slash before appending", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com/"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity non-apikey returns empty", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetBaseURL() + if result != tt.expected { + t.Errorf("GetBaseURL() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestGetGeminiBaseURL(t *testing.T) { + const defaultGeminiURL = "https://generativelanguage.googleapis.com" + + tests := []struct { + name string + account Account + expected string + }{ + { + name: "apikey without base_url returns default", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{}, + 
}, + expected: defaultGeminiURL, + }, + { + name: "apikey with custom base_url", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{"base_url": "https://custom-gemini.example.com"}, + }, + expected: "https://custom-gemini.example.com", + }, + { + name: "antigravity apikey auto-appends /antigravity", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity apikey trims trailing slash", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com/"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity oauth does NOT append /antigravity", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com", + }, + { + name: "oauth without base_url returns default", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{}, + }, + expected: defaultGeminiURL, + }, + { + name: "nil credentials returns default", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + }, + expected: defaultGeminiURL, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetGeminiBaseURL(defaultGeminiURL) + if result != tt.expected { + t.Errorf("GetGeminiBaseURL() = %q, want %q", result, tt.expected) + } + }) + } +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 2d96b1ab..4ea73e64 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ 
b/backend/internal/service/antigravity_gateway_service.go @@ -665,9 +665,6 @@ type TestConnectionResult struct { // TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费) // 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择 func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { - if account.Type == AccountTypeUpstream { - return s.testUpstreamConnection(ctx, account, modelID) - } // 获取 token if s.tokenProvider == nil { @@ -986,10 +983,6 @@ func isModelNotFoundError(statusCode int, body []byte) bool { func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := time.Now() - if account.Type == AccountTypeUpstream { - return s.ForwardUpstream(ctx, c, account, body, isStickySession) - } - sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -1610,10 +1603,6 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := time.Now() - if account.Type == AccountTypeUpstream { - return s.ForwardUpstreamGemini(ctx, c, account, originalModel, action, stream, body, isStickySession) - } - sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -3361,378 +3350,3 @@ func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) { payload["contents"] = filtered return json.Marshal(payload) } - -// --------------------------------------------------------------------------- -// Upstream 专用转发方法 -// upstream 账号直接连接上游 Anthropic/Gemini 兼容端点,不走 Antigravity OAuth 协议转换。 -// --------------------------------------------------------------------------- - -// testUpstreamConnection 测试 upstream 账号连接 -func (s 
*AntigravityGatewayService) testUpstreamConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { - baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") - if baseURL == "" { - return nil, errors.New("upstream account missing base_url in credentials") - } - apiKey := strings.TrimSpace(account.GetCredential("api_key")) - if apiKey == "" { - return nil, errors.New("upstream account missing api_key in credentials") - } - - mappedModel := s.getMappedModel(account, modelID) - if mappedModel == "" { - return nil, fmt.Errorf("model %s not in whitelist", modelID) - } - - // 构建最小 Claude 格式请求 - requestBody, _ := json.Marshal(map[string]any{ - "model": mappedModel, - "max_tokens": 1, - "messages": []map[string]any{ - {"role": "user", "content": "."}, - }, - "stream": false, - }) - - apiURL := baseURL + "/antigravity/v1/messages" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(requestBody)) - if err != nil { - return nil, fmt.Errorf("构建请求失败: %w", err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) - req.Header.Set("x-api-key", apiKey) - req.Header.Set("anthropic-version", "2023-06-01") - - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - log.Printf("[antigravity-Test-Upstream] account=%s url=%s", account.Name, apiURL) - - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return nil, fmt.Errorf("请求失败: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - if err != nil { - return nil, fmt.Errorf("读取响应失败: %w", err) - } - - if resp.StatusCode >= 400 { - return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) - } - - // 从 Claude 格式非流式响应中提取文本 - var claudeResp struct { - Content []struct { - Text string 
`json:"text"` - } `json:"content"` - } - text := "" - if json.Unmarshal(respBody, &claudeResp) == nil && len(claudeResp.Content) > 0 { - text = claudeResp.Content[0].Text - } - - return &TestConnectionResult{ - Text: text, - MappedModel: mappedModel, - }, nil -} - -// ForwardUpstream 转发 Claude 协议请求到 upstream(不做协议转换) -func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { - startTime := time.Now() - sessionID := getSessionID(c) - prefix := logPrefix(sessionID, account.Name) - - baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") - if baseURL == "" { - return nil, s.writeClaudeError(c, http.StatusBadGateway, "api_error", "Upstream account missing base_url") - } - apiKey := strings.TrimSpace(account.GetCredential("api_key")) - if apiKey == "" { - return nil, s.writeClaudeError(c, http.StatusBadGateway, "authentication_error", "Upstream account missing api_key") - } - - // 解析请求以获取模型和流式标志 - var claudeReq antigravity.ClaudeRequest - if err := json.Unmarshal(body, &claudeReq); err != nil { - return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid request body") - } - if strings.TrimSpace(claudeReq.Model) == "" { - return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Missing model") - } - - originalModel := claudeReq.Model - mappedModel := s.getMappedModel(account, claudeReq.Model) - if mappedModel == "" { - return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) - } - - // 代理 URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - // 统计模型调用次数 - if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) - } - - apiURL := baseURL + "/antigravity/v1/messages" - log.Printf("%s upstream_forward url=%s 
model=%s", prefix, apiURL, mappedModel) - - // 构建请求:body 原样透传 - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) - if err != nil { - return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request") - } - // 透传客户端所有请求头(排除 hop-by-hop 和认证头) - if c != nil && c.Request != nil { - for key, values := range c.Request.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - req.Header.Add(key, v) - } - } - } - // 覆盖认证头 - req.Header.Set("Authorization", "Bearer "+apiKey) - req.Header.Set("x-api-key", apiKey) - - if c != nil && len(body) > 0 { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - - // 单次发送,不重试 - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("Upstream request failed: %v", err)) - } - defer func() { _ = resp.Body.Close() }() - - // 错误响应处理 - if resp.StatusCode >= 400 { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession) - - if s.shouldFailoverUpstreamError(resp.StatusCode) { - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} - } - - return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) - } - - // 成功响应:透传 response header + body - requestID := resp.Header.Get("x-request-id") - - // 透传上游响应头(排除 hop-by-hop) - for key, values := range resp.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - c.Header(key, v) - } - } - - c.Status(resp.StatusCode) - _, copyErr := io.Copy(c.Writer, resp.Body) - if copyErr != nil { - log.Printf("%s status=copy_error error=%v", prefix, copyErr) - } - - return &ForwardResult{ - RequestID: requestID, 
- Model: originalModel, - Stream: claudeReq.Stream, - Duration: time.Since(startTime), - }, nil -} - -// ForwardUpstreamGemini 转发 Gemini 协议请求到 upstream(不做协议转换) -func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { - startTime := time.Now() - sessionID := getSessionID(c) - prefix := logPrefix(sessionID, account.Name) - - baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") - if baseURL == "" { - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream account missing base_url") - } - apiKey := strings.TrimSpace(account.GetCredential("api_key")) - if apiKey == "" { - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream account missing api_key") - } - - if strings.TrimSpace(originalModel) == "" { - return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL") - } - if strings.TrimSpace(action) == "" { - return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL") - } - if len(body) == 0 { - return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") - } - - imageSize := s.extractImageSize(body) - - switch action { - case "generateContent", "streamGenerateContent": - // ok - case "countTokens": - c.JSON(http.StatusOK, map[string]any{"totalTokens": 0}) - return &ForwardResult{ - RequestID: "", - Usage: ClaudeUsage{}, - Model: originalModel, - Stream: false, - Duration: time.Since(time.Now()), - FirstTokenMs: nil, - }, nil - default: - return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action) - } - - mappedModel := s.getMappedModel(account, originalModel) - if mappedModel == "" { - return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel)) - } - - // 代理 URL - proxyURL := "" - if account.ProxyID != nil && 
account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - // 统计模型调用次数 - if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) - } - - // 构建 upstream URL: base_url + /antigravity/v1beta/models/MODEL:ACTION - apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, action) - if stream || action == "streamGenerateContent" { - apiURL += "?alt=sse" - } - - log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, action) - - // 构建请求:body 原样透传 - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body)) - if err != nil { - return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request") - } - // 透传客户端所有请求头(排除 hop-by-hop 和认证头) - if c != nil && c.Request != nil { - for key, values := range c.Request.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - req.Header.Add(key, v) - } - } - } - // 覆盖认证头 - req.Header.Set("Authorization", "Bearer "+apiKey) - - if c != nil && len(body) > 0 { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - - // 单次发送,不重试 - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("Upstream request failed: %v", err)) - } - defer func() { _ = resp.Body.Close() }() - - // 错误响应处理 - if resp.StatusCode >= 400 { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - contentType := resp.Header.Get("Content-Type") - - requestID := resp.Header.Get("x-request-id") - if requestID != "" { - c.Header("x-request-id", requestID) - } - - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession) - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := 
s.getUpstreamErrorDetail(respBody) - - setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) - - if s.shouldFailoverUpstreamError(resp.StatusCode) { - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: requestID, - Kind: "failover", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} - } - if contentType == "" { - contentType = "application/json" - } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: requestID, - Kind: "http_error", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - log.Printf("[antigravity-Forward-Upstream] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(respBody, 500)) - c.Data(resp.StatusCode, contentType, respBody) - return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) - } - - // 成功响应:透传 response header + body - requestID := resp.Header.Get("x-request-id") - - // 透传上游响应头(排除 hop-by-hop) - for key, values := range resp.Header { - if upstreamHopByHopHeaders[strings.ToLower(key)] { - continue - } - for _, v := range values { - c.Header(key, v) - } - } - - c.Status(resp.StatusCode) - _, copyErr := io.Copy(c.Writer, resp.Body) - if copyErr != nil { - log.Printf("%s status=copy_error error=%v", prefix, copyErr) - } - - imageCount := 0 - if isImageGenerationModel(mappedModel) { - imageCount = 1 - } - - return &ForwardResult{ - RequestID: requestID, - Model: originalModel, - Stream: stream, - Duration: time.Since(startTime), - ImageCount: imageCount, - ImageSize: imageSize, - }, nil -} diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 
0f156c2e..4e0442fd 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -560,10 +560,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex return nil, "", errors.New("gemini api_key not configured") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -640,10 +637,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex return upstreamReq, "x-request-id", nil } else { // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -1026,10 +1020,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return nil, "", errors.New("gemini api_key not configured") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -1097,10 +1088,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. 
return upstreamReq, "x-request-id", nil } else { // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -2420,10 +2408,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac return nil, errors.New("invalid path") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, err diff --git a/backend/internal/service/upstream_header_passthrough_test.go b/backend/internal/service/upstream_header_passthrough_test.go deleted file mode 100644 index 51d8588b..00000000 --- a/backend/internal/service/upstream_header_passthrough_test.go +++ /dev/null @@ -1,285 +0,0 @@ -//go:build unit - -package service - -import ( - "bytes" - "context" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "testing" - - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/require" -) - -// httpUpstreamCapture captures the outgoing *http.Request for assertion. 
-type httpUpstreamCapture struct { - capturedReq *http.Request - resp *http.Response - err error -} - -func (s *httpUpstreamCapture) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) { - s.capturedReq = req - return s.resp, s.err -} - -func (s *httpUpstreamCapture) DoWithTLS(req *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) { - s.capturedReq = req - return s.resp, s.err -} - -func newUpstreamAccount() *Account { - return &Account{ - ID: 100, - Name: "upstream-test", - Platform: PlatformAntigravity, - Type: AccountTypeUpstream, - Status: StatusActive, - Concurrency: 1, - Credentials: map[string]any{ - "base_url": "https://upstream.example.com", - "api_key": "sk-upstream-secret", - }, - } -} - -// makeSSEOKResponse builds a minimal SSE response that -// handleClaudeStreamingResponse / handleGeminiStreamingResponse -// can consume without error. -// We return 502 to bypass streaming and hit the error branch instead, -// which is sufficient for testing header passthrough. 
-func makeUpstreamErrorResponse() *http.Response { - body := []byte(`{"error":{"message":"test error"}}`) - return &http.Response{ - StatusCode: http.StatusBadGateway, - Header: http.Header{"Content-Type": []string{"application/json"}}, - Body: io.NopCloser(bytes.NewReader(body)), - } -} - -// --- ForwardUpstream tests --- - -func TestForwardUpstream_PassthroughHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{{"role": "user", "content": "hi"}}, - "max_tokens": 1, - "stream": false, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("anthropic-version", "2024-10-22") - req.Header.Set("anthropic-beta", "output-128k-2025-02-19") - req.Header.Set("X-Custom-Header", "custom-value") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) - - captured := stub.capturedReq - require.NotNil(t, captured, "upstream request should have been made") - - // 客户端 header 应被透传 - require.Equal(t, "application/json", captured.Header.Get("Content-Type")) - require.Equal(t, "2024-10-22", captured.Header.Get("anthropic-version")) - require.Equal(t, "output-128k-2025-02-19", captured.Header.Get("anthropic-beta")) - require.Equal(t, "custom-value", captured.Header.Get("X-Custom-Header")) -} - -func TestForwardUpstream_OverridesAuthHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{{"role": "user", "content": "hi"}}, 
- "max_tokens": 1, - "stream": false, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - // 客户端发来的认证头应被覆盖 - req.Header.Set("Authorization", "Bearer client-token") - req.Header.Set("x-api-key", "client-api-key") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) - - captured := stub.capturedReq - require.NotNil(t, captured) - - // 认证头应使用上游账号的 api_key,而非客户端的 - require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization")) - require.Equal(t, "sk-upstream-secret", captured.Header.Get("x-api-key")) -} - -func TestForwardUpstream_ExcludesHopByHopHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{{"role": "user", "content": "hi"}}, - "max_tokens": 1, - "stream": false, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Connection", "keep-alive") - req.Header.Set("Keep-Alive", "timeout=5") - req.Header.Set("Transfer-Encoding", "chunked") - req.Header.Set("Upgrade", "websocket") - req.Header.Set("Te", "trailers") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false) - - captured := stub.capturedReq - require.NotNil(t, captured) - - // hop-by-hop header 不应出现 - require.Empty(t, captured.Header.Get("Connection")) - 
require.Empty(t, captured.Header.Get("Keep-Alive")) - require.Empty(t, captured.Header.Get("Transfer-Encoding")) - require.Empty(t, captured.Header.Get("Upgrade")) - require.Empty(t, captured.Header.Get("Te")) - - // 但普通 header 应保留 - require.Equal(t, "application/json", captured.Header.Get("Content-Type")) -} - -// --- ForwardUpstreamGemini tests --- - -func TestForwardUpstreamGemini_PassthroughHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "contents": []map[string]any{ - {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, - }, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-Custom-Gemini", "gemini-value") - req.Header.Set("X-Request-Id", "req-abc-123") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) - - captured := stub.capturedReq - require.NotNil(t, captured, "upstream request should have been made") - - // 客户端 header 应被透传 - require.Equal(t, "application/json", captured.Header.Get("Content-Type")) - require.Equal(t, "gemini-value", captured.Header.Get("X-Custom-Gemini")) - require.Equal(t, "req-abc-123", captured.Header.Get("X-Request-Id")) -} - -func TestForwardUpstreamGemini_OverridesAuthHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "contents": []map[string]any{ - {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, - }, - }) - - req := httptest.NewRequest(http.MethodPost, 
"/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer client-gemini-token") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) - - captured := stub.capturedReq - require.NotNil(t, captured) - - // 认证头应使用上游账号的 api_key - require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization")) -} - -func TestForwardUpstreamGemini_ExcludesHopByHopHeaders(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - - body, _ := json.Marshal(map[string]any{ - "contents": []map[string]any{ - {"role": "user", "parts": []map[string]any{{"text": "hi"}}}, - }, - }) - - req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Connection", "keep-alive") - req.Header.Set("Proxy-Authorization", "Basic dXNlcjpwYXNz") - req.Header.Set("Host", "evil.example.com") - c.Request = req - - stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()} - svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: stub, - } - - _, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false) - - captured := stub.capturedReq - require.NotNil(t, captured) - - // hop-by-hop header 不应出现 - require.Empty(t, captured.Header.Get("Connection")) - require.Empty(t, captured.Header.Get("Proxy-Authorization")) - // Host header 在 Go http.Request 中特殊处理,但我们的黑名单应阻止透传 - require.Empty(t, 
captured.Header.Values("Host")) - - // 普通 header 应保留 - require.Equal(t, "application/json", captured.Header.Get("Content-Type")) -} diff --git a/backend/migrations/052_migrate_upstream_to_apikey.sql b/backend/migrations/052_migrate_upstream_to_apikey.sql new file mode 100644 index 00000000..974f3f3c --- /dev/null +++ b/backend/migrations/052_migrate_upstream_to_apikey.sql @@ -0,0 +1,11 @@ +-- Migrate upstream accounts to apikey type +-- Background: upstream type is no longer needed. Antigravity platform APIKey accounts +-- with base_url pointing to an upstream sub2api instance can reuse the standard +-- APIKey forwarding path. GetBaseURL()/GetGeminiBaseURL() automatically appends +-- /antigravity for Antigravity platform APIKey accounts. + +UPDATE accounts +SET type = 'apikey' +WHERE type = 'upstream' + AND platform = 'antigravity' + AND deleted_at IS NULL; diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 7d759be1..603941c1 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2289,9 +2289,9 @@ watch( watch( [accountCategory, addMethod, antigravityAccountType], ([category, method, agType]) => { - // Antigravity upstream 类型 + // Antigravity upstream 类型(实际创建为 apikey) if (form.platform === 'antigravity' && agType === 'upstream') { - form.type = 'upstream' + form.type = 'apikey' return } if (category === 'oauth-based') { @@ -2715,7 +2715,7 @@ const handleSubmit = async () => { submitting.value = true try { const extra = mixedScheduling.value ? 
{ mixed_scheduling: true } : undefined - await createAccountAndFinish(form.platform, 'upstream', credentials, extra) + await createAccountAndFinish(form.platform, 'apikey', credentials, extra) } catch (error: any) { appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate')) } finally { From 3c936441469d9483bd02c2681fcfbea9fa271f9a Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 13:14:58 +0800 Subject: [PATCH 059/363] chore: bump version to 0.1.74.7 --- backend/cmd/server/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index f0768f09..bc88be6e 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.70 +0.1.74.7 From b4ec65785d9fbf525d9cc0202663c830be1a6791 Mon Sep 17 00:00:00 2001 From: shaw Date: Sun, 8 Feb 2026 13:26:28 +0800 Subject: [PATCH 060/363] =?UTF-8?q?fix:=20apikey=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E8=B4=A6=E5=8F=B7test=E5=8E=BB=E6=8E=89oauth-2025-04-20?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/service/account_test_service.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 3290fe52..899a4498 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -245,7 +245,6 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account // Set common headers req.Header.Set("Content-Type", "application/json") req.Header.Set("anthropic-version", "2023-06-01") - req.Header.Set("anthropic-beta", claude.DefaultBetaHeader) // Apply Claude Code client headers for key, value := range claude.DefaultHeaders { @@ -254,8 +253,10 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account // Set authentication header if useBearer { + 
req.Header.Set("anthropic-beta", claude.DefaultBetaHeader) req.Header.Set("Authorization", "Bearer "+authToken) } else { + req.Header.Set("anthropic-beta", claude.APIKeyBetaHeader) req.Header.Set("x-api-key", authToken) } From 69816f8691e9374adfafde596c5b5a34ec96ddaf Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 8 Feb 2026 13:30:39 +0800 Subject: [PATCH 061/363] fix: remove unused upstreamHopByHopHeaders variable to pass golangci-lint --- .../service/antigravity_gateway_service.go | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 4ea73e64..26b1c530 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -47,21 +47,6 @@ const ( googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED" ) -// upstreamHopByHopHeaders 透传请求头时需要排除的 hop-by-hop 头 -var upstreamHopByHopHeaders = map[string]bool{ - "connection": true, - "keep-alive": true, - "proxy-authenticate": true, - "proxy-authorization": true, - "proxy-connection": true, - "te": true, - "trailer": true, - "transfer-encoding": true, - "upgrade": true, - "host": true, - "content-length": true, -} - // antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写) // 匹配时使用 strings.Contains,无需完全匹配 var antigravityPassthroughErrorMessages = []string{ From b1c30df8e300fb4258352e5384ee8fe3fe3cc240 Mon Sep 17 00:00:00 2001 From: shaw Date: Sun, 8 Feb 2026 14:00:02 +0800 Subject: [PATCH 062/363] fix(ui): unify admin table toolbar layout with search and buttons in single row Standardize filter bar layout across admin pages to place search/filters on left and action buttons on right within the same row, improving visual consistency and space utilization. 
--- .../src/views/admin/AnnouncementsView.vue | 50 ++++++------- frontend/src/views/admin/PromoCodesView.vue | 50 ++++++------- frontend/src/views/admin/ProxiesView.vue | 74 +++++++++---------- frontend/src/views/admin/RedeemView.vue | 56 +++++++------- frontend/src/views/admin/UsersView.vue | 6 +- 5 files changed, 114 insertions(+), 122 deletions(-) diff --git a/frontend/src/views/admin/AnnouncementsView.vue b/frontend/src/views/admin/AnnouncementsView.vue index 38574454..08d7b871 100644 --- a/frontend/src/views/admin/AnnouncementsView.vue +++ b/frontend/src/views/admin/AnnouncementsView.vue @@ -1,26 +1,10 @@ diff --git a/frontend/src/components/user/profile/__tests__/totp-timer-cleanup.spec.ts b/frontend/src/components/user/profile/__tests__/totp-timer-cleanup.spec.ts new file mode 100644 index 00000000..0259f902 --- /dev/null +++ b/frontend/src/components/user/profile/__tests__/totp-timer-cleanup.spec.ts @@ -0,0 +1,108 @@ +import { beforeEach, afterEach, describe, expect, it, vi } from 'vitest' +import { mount } from '@vue/test-utils' +import TotpSetupModal from '@/components/user/profile/TotpSetupModal.vue' +import TotpDisableDialog from '@/components/user/profile/TotpDisableDialog.vue' + +const mocks = vi.hoisted(() => ({ + showSuccess: vi.fn(), + showError: vi.fn(), + getVerificationMethod: vi.fn(), + sendVerifyCode: vi.fn() +})) + +vi.mock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string) => key + }) +})) + +vi.mock('@/stores/app', () => ({ + useAppStore: () => ({ + showSuccess: mocks.showSuccess, + showError: mocks.showError + }) +})) + +vi.mock('@/api', () => ({ + totpAPI: { + getVerificationMethod: mocks.getVerificationMethod, + sendVerifyCode: mocks.sendVerifyCode, + initiateSetup: vi.fn(), + enable: vi.fn(), + disable: vi.fn() + } +})) + +const flushPromises = async () => { + await Promise.resolve() + await Promise.resolve() +} + +describe('TOTP 弹窗定时器清理', () => { + let intervalSeed = 1000 + let setIntervalSpy: ReturnType + let 
clearIntervalSpy: ReturnType + + beforeEach(() => { + intervalSeed = 1000 + mocks.showSuccess.mockReset() + mocks.showError.mockReset() + mocks.getVerificationMethod.mockReset() + mocks.sendVerifyCode.mockReset() + + mocks.getVerificationMethod.mockResolvedValue({ method: 'email' }) + mocks.sendVerifyCode.mockResolvedValue({ success: true }) + + setIntervalSpy = vi.spyOn(window, 'setInterval').mockImplementation(((handler: TimerHandler) => { + void handler + intervalSeed += 1 + return intervalSeed as unknown as number + }) as typeof window.setInterval) + clearIntervalSpy = vi.spyOn(window, 'clearInterval') + }) + + afterEach(() => { + setIntervalSpy.mockRestore() + clearIntervalSpy.mockRestore() + }) + + it('TotpSetupModal 卸载时清理倒计时定时器', async () => { + const wrapper = mount(TotpSetupModal) + await flushPromises() + + const sendButton = wrapper + .findAll('button') + .find((button) => button.text().includes('profile.totp.sendCode')) + + expect(sendButton).toBeTruthy() + await sendButton!.trigger('click') + await flushPromises() + + expect(setIntervalSpy).toHaveBeenCalledTimes(1) + const timerId = setIntervalSpy.mock.results[0]?.value + + wrapper.unmount() + + expect(clearIntervalSpy).toHaveBeenCalledWith(timerId) + }) + + it('TotpDisableDialog 卸载时清理倒计时定时器', async () => { + const wrapper = mount(TotpDisableDialog) + await flushPromises() + + const sendButton = wrapper + .findAll('button') + .find((button) => button.text().includes('profile.totp.sendCode')) + + expect(sendButton).toBeTruthy() + await sendButton!.trigger('click') + await flushPromises() + + expect(setIntervalSpy).toHaveBeenCalledTimes(1) + const timerId = setIntervalSpy.mock.results[0]?.value + + wrapper.unmount() + + expect(clearIntervalSpy).toHaveBeenCalledWith(timerId) + }) +}) diff --git a/frontend/src/composables/__tests__/useKeyedDebouncedSearch.spec.ts b/frontend/src/composables/__tests__/useKeyedDebouncedSearch.spec.ts new file mode 100644 index 00000000..4866746a --- /dev/null +++ 
b/frontend/src/composables/__tests__/useKeyedDebouncedSearch.spec.ts @@ -0,0 +1,100 @@ +import { beforeEach, afterEach, describe, expect, it, vi } from 'vitest' +import { useKeyedDebouncedSearch } from '@/composables/useKeyedDebouncedSearch' + +const flushPromises = () => Promise.resolve() + +describe('useKeyedDebouncedSearch', () => { + beforeEach(() => { + vi.useFakeTimers() + }) + + afterEach(() => { + vi.useRealTimers() + }) + + it('为不同 key 独立防抖触发搜索', async () => { + const search = vi.fn().mockResolvedValue([]) + const onSuccess = vi.fn() + + const searcher = useKeyedDebouncedSearch({ + delay: 100, + search, + onSuccess + }) + + searcher.trigger('a', 'foo') + searcher.trigger('b', 'bar') + + expect(search).not.toHaveBeenCalled() + + vi.advanceTimersByTime(100) + await flushPromises() + + expect(search).toHaveBeenCalledTimes(2) + expect(search).toHaveBeenNthCalledWith( + 1, + 'foo', + expect.objectContaining({ key: 'a', signal: expect.any(AbortSignal) }) + ) + expect(search).toHaveBeenNthCalledWith( + 2, + 'bar', + expect.objectContaining({ key: 'b', signal: expect.any(AbortSignal) }) + ) + expect(onSuccess).toHaveBeenCalledTimes(2) + }) + + it('同 key 新请求会取消旧请求并忽略过期响应', async () => { + const resolves: Array<(value: string[]) => void> = [] + const search = vi.fn().mockImplementation( + () => new Promise((resolve) => { + resolves.push(resolve) + }) + ) + const onSuccess = vi.fn() + + const searcher = useKeyedDebouncedSearch({ + delay: 50, + search, + onSuccess + }) + + searcher.trigger('rule-1', 'first') + vi.advanceTimersByTime(50) + await flushPromises() + + searcher.trigger('rule-1', 'second') + vi.advanceTimersByTime(50) + await flushPromises() + + expect(search).toHaveBeenCalledTimes(2) + + resolves[1](['second']) + await flushPromises() + expect(onSuccess).toHaveBeenCalledTimes(1) + expect(onSuccess).toHaveBeenLastCalledWith('rule-1', ['second']) + + resolves[0](['first']) + await flushPromises() + expect(onSuccess).toHaveBeenCalledTimes(1) + }) + + 
it('clearKey 会取消未执行任务', () => { + const search = vi.fn().mockResolvedValue([]) + const onSuccess = vi.fn() + + const searcher = useKeyedDebouncedSearch({ + delay: 100, + search, + onSuccess + }) + + searcher.trigger('a', 'foo') + searcher.clearKey('a') + + vi.advanceTimersByTime(100) + + expect(search).not.toHaveBeenCalled() + expect(onSuccess).not.toHaveBeenCalled() + }) +}) diff --git a/frontend/src/composables/useKeyedDebouncedSearch.ts b/frontend/src/composables/useKeyedDebouncedSearch.ts new file mode 100644 index 00000000..81133c38 --- /dev/null +++ b/frontend/src/composables/useKeyedDebouncedSearch.ts @@ -0,0 +1,103 @@ +import { getCurrentInstance, onUnmounted } from 'vue' + +export interface KeyedDebouncedSearchContext { + key: string + signal: AbortSignal +} + +interface UseKeyedDebouncedSearchOptions { + delay?: number + search: (keyword: string, context: KeyedDebouncedSearchContext) => Promise + onSuccess: (key: string, result: T) => void + onError?: (key: string, error: unknown) => void +} + +/** + * 多实例隔离的防抖搜索:每个 key 有独立的防抖、请求取消与过期响应保护。 + */ +export function useKeyedDebouncedSearch(options: UseKeyedDebouncedSearchOptions) { + const delay = options.delay ?? 300 + const timers = new Map>() + const controllers = new Map() + const versions = new Map() + + const clearKey = (key: string) => { + const timer = timers.get(key) + if (timer) { + clearTimeout(timer) + timers.delete(key) + } + + const controller = controllers.get(key) + if (controller) { + controller.abort() + controllers.delete(key) + } + + versions.delete(key) + } + + const clearAll = () => { + const allKeys = new Set([ + ...timers.keys(), + ...controllers.keys(), + ...versions.keys() + ]) + + allKeys.forEach((key) => clearKey(key)) + } + + const trigger = (key: string, keyword: string) => { + const nextVersion = (versions.get(key) ?? 
0) + 1 + versions.set(key, nextVersion) + + const existingTimer = timers.get(key) + if (existingTimer) { + clearTimeout(existingTimer) + timers.delete(key) + } + + const inFlight = controllers.get(key) + if (inFlight) { + inFlight.abort() + controllers.delete(key) + } + + const timer = setTimeout(async () => { + timers.delete(key) + + const controller = new AbortController() + controllers.set(key, controller) + const requestVersion = versions.get(key) + + try { + const result = await options.search(keyword, { key, signal: controller.signal }) + if (controller.signal.aborted) return + if (versions.get(key) !== requestVersion) return + options.onSuccess(key, result) + } catch (error) { + if (controller.signal.aborted) return + if (versions.get(key) !== requestVersion) return + options.onError?.(key, error) + } finally { + if (controllers.get(key) === controller) { + controllers.delete(key) + } + } + }, delay) + + timers.set(key, timer) + } + + if (getCurrentInstance()) { + onUnmounted(() => { + clearAll() + }) + } + + return { + trigger, + clearKey, + clearAll + } +} diff --git a/frontend/src/i18n/index.ts b/frontend/src/i18n/index.ts index 486fb3bc..00e34dc2 100644 --- a/frontend/src/i18n/index.ts +++ b/frontend/src/i18n/index.ts @@ -1,53 +1,83 @@ import { createI18n } from 'vue-i18n' -import en from './locales/en' -import zh from './locales/zh' + +type LocaleCode = 'en' | 'zh' + +type LocaleMessages = Record const LOCALE_KEY = 'sub2api_locale' +const DEFAULT_LOCALE: LocaleCode = 'en' -function getDefaultLocale(): string { - // Check localStorage first +const localeLoaders: Record Promise<{ default: LocaleMessages }>> = { + en: () => import('./locales/en'), + zh: () => import('./locales/zh') +} + +function isLocaleCode(value: string): value is LocaleCode { + return value === 'en' || value === 'zh' +} + +function getDefaultLocale(): LocaleCode { const saved = localStorage.getItem(LOCALE_KEY) - if (saved && ['en', 'zh'].includes(saved)) { + if (saved && 
isLocaleCode(saved)) { return saved } - // Check browser language const browserLang = navigator.language.toLowerCase() if (browserLang.startsWith('zh')) { return 'zh' } - return 'en' + return DEFAULT_LOCALE } export const i18n = createI18n({ legacy: false, locale: getDefaultLocale(), - fallbackLocale: 'en', - messages: { - en, - zh - }, + fallbackLocale: DEFAULT_LOCALE, + messages: {}, // 禁用 HTML 消息警告 - 引导步骤使用富文本内容(driver.js 支持 HTML) // 这些内容是内部定义的,不存在 XSS 风险 warnHtmlMessage: false }) -export function setLocale(locale: string) { - if (['en', 'zh'].includes(locale)) { - i18n.global.locale.value = locale as 'en' | 'zh' - localStorage.setItem(LOCALE_KEY, locale) - document.documentElement.setAttribute('lang', locale) +const loadedLocales = new Set() + +export async function loadLocaleMessages(locale: LocaleCode): Promise { + if (loadedLocales.has(locale)) { + return } + + const loader = localeLoaders[locale] + const module = await loader() + i18n.global.setLocaleMessage(locale, module.default) + loadedLocales.add(locale) } -export function getLocale(): string { - return i18n.global.locale.value +export async function initI18n(): Promise { + const current = getLocale() + await loadLocaleMessages(current) + document.documentElement.setAttribute('lang', current) +} + +export async function setLocale(locale: string): Promise { + if (!isLocaleCode(locale)) { + return + } + + await loadLocaleMessages(locale) + i18n.global.locale.value = locale + localStorage.setItem(LOCALE_KEY, locale) + document.documentElement.setAttribute('lang', locale) +} + +export function getLocale(): LocaleCode { + const current = i18n.global.locale.value + return isLocaleCode(current) ? 
current : DEFAULT_LOCALE } export const availableLocales = [ { code: 'en', name: 'English', flag: '🇺🇸' }, { code: 'zh', name: '中文', flag: '🇨🇳' } -] +] as const export default i18n diff --git a/frontend/src/main.ts b/frontend/src/main.ts index 11c0b1e8..23f9d297 100644 --- a/frontend/src/main.ts +++ b/frontend/src/main.ts @@ -2,28 +2,33 @@ import { createApp } from 'vue' import { createPinia } from 'pinia' import App from './App.vue' import router from './router' -import i18n from './i18n' +import i18n, { initI18n } from './i18n' +import { useAppStore } from '@/stores/app' import './style.css' -const app = createApp(App) -const pinia = createPinia() -app.use(pinia) +async function bootstrap() { + const app = createApp(App) + const pinia = createPinia() + app.use(pinia) -// Initialize settings from injected config BEFORE mounting (prevents flash) -// This must happen after pinia is installed but before router and i18n -import { useAppStore } from '@/stores/app' -const appStore = useAppStore() -appStore.initFromInjectedConfig() + // Initialize settings from injected config BEFORE mounting (prevents flash) + // This must happen after pinia is installed but before router and i18n + const appStore = useAppStore() + appStore.initFromInjectedConfig() -// Set document title immediately after config is loaded -if (appStore.siteName && appStore.siteName !== 'Sub2API') { - document.title = `${appStore.siteName} - AI API Gateway` + // Set document title immediately after config is loaded + if (appStore.siteName && appStore.siteName !== 'Sub2API') { + document.title = `${appStore.siteName} - AI API Gateway` + } + + await initI18n() + + app.use(router) + app.use(i18n) + + // 等待路由器完成初始导航后再挂载,避免竞态条件导致的空白渲染 + await router.isReady() + app.mount('#app') } -app.use(router) -app.use(i18n) - -// 等待路由器完成初始导航后再挂载,避免竞态条件导致的空白渲染 -router.isReady().then(() => { - app.mount('#app') -}) +bootstrap() diff --git a/frontend/src/router/__tests__/title.spec.ts 
b/frontend/src/router/__tests__/title.spec.ts new file mode 100644 index 00000000..3a892837 --- /dev/null +++ b/frontend/src/router/__tests__/title.spec.ts @@ -0,0 +1,25 @@ +import { describe, expect, it } from 'vitest' +import { resolveDocumentTitle } from '@/router/title' + +describe('resolveDocumentTitle', () => { + it('路由存在标题时,使用“路由标题 - 站点名”格式', () => { + expect(resolveDocumentTitle('Usage Records', 'My Site')).toBe('Usage Records - My Site') + }) + + it('路由无标题时,回退到站点名', () => { + expect(resolveDocumentTitle(undefined, 'My Site')).toBe('My Site') + }) + + it('站点名为空时,回退默认站点名', () => { + expect(resolveDocumentTitle('Dashboard', '')).toBe('Dashboard - Sub2API') + expect(resolveDocumentTitle(undefined, ' ')).toBe('Sub2API') + }) + + it('站点名变更时仅影响后续路由标题计算', () => { + const before = resolveDocumentTitle('Admin Dashboard', 'Alpha') + const after = resolveDocumentTitle('Admin Dashboard', 'Beta') + + expect(before).toBe('Admin Dashboard - Alpha') + expect(after).toBe('Admin Dashboard - Beta') + }) +}) diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 4bb46cee..1a67cac6 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -8,6 +8,7 @@ import { useAuthStore } from '@/stores/auth' import { useAppStore } from '@/stores/app' import { useNavigationLoadingState } from '@/composables/useNavigationLoading' import { useRoutePrefetch } from '@/composables/useRoutePrefetch' +import { resolveDocumentTitle } from './title' /** * Route definitions with lazy loading @@ -389,12 +390,7 @@ router.beforeEach((to, _from, next) => { // Set page title const appStore = useAppStore() - const siteName = appStore.siteName || 'Sub2API' - if (to.meta.title) { - document.title = `${to.meta.title} - ${siteName}` - } else { - document.title = siteName - } + document.title = resolveDocumentTitle(to.meta.title, appStore.siteName) // Check if route requires authentication const requiresAuth = to.meta.requiresAuth !== false // Default to true diff 
--git a/frontend/src/router/title.ts b/frontend/src/router/title.ts new file mode 100644 index 00000000..e0db24b0 --- /dev/null +++ b/frontend/src/router/title.ts @@ -0,0 +1,12 @@ +/** + * 统一生成页面标题,避免多处写入 document.title 产生覆盖冲突。 + */ +export function resolveDocumentTitle(routeTitle: unknown, siteName?: string): string { + const normalizedSiteName = typeof siteName === 'string' && siteName.trim() ? siteName.trim() : 'Sub2API' + + if (typeof routeTitle === 'string' && routeTitle.trim()) { + return `${routeTitle.trim()} - ${normalizedSiteName}` + } + + return normalizedSiteName +} diff --git a/frontend/src/utils/__tests__/stableObjectKey.spec.ts b/frontend/src/utils/__tests__/stableObjectKey.spec.ts new file mode 100644 index 00000000..5a6f99f4 --- /dev/null +++ b/frontend/src/utils/__tests__/stableObjectKey.spec.ts @@ -0,0 +1,37 @@ +import { describe, expect, it } from 'vitest' +import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' + +describe('createStableObjectKeyResolver', () => { + it('对同一对象返回稳定 key', () => { + const resolve = createStableObjectKeyResolver<{ value: string }>('rule') + const obj = { value: 'a' } + + const key1 = resolve(obj) + const key2 = resolve(obj) + + expect(key1).toBe(key2) + expect(key1.startsWith('rule-')).toBe(true) + }) + + it('不同对象返回不同 key', () => { + const resolve = createStableObjectKeyResolver<{ value: string }>('rule') + + const key1 = resolve({ value: 'a' }) + const key2 = resolve({ value: 'a' }) + + expect(key1).not.toBe(key2) + }) + + it('不同 resolver 互不影响', () => { + const resolveA = createStableObjectKeyResolver<{ id: number }>('a') + const resolveB = createStableObjectKeyResolver<{ id: number }>('b') + const obj = { id: 1 } + + const keyA = resolveA(obj) + const keyB = resolveB(obj) + + expect(keyA).not.toBe(keyB) + expect(keyA.startsWith('a-')).toBe(true) + expect(keyB.startsWith('b-')).toBe(true) + }) +}) diff --git a/frontend/src/utils/stableObjectKey.ts b/frontend/src/utils/stableObjectKey.ts new file mode 
100644 index 00000000..a61414f0 --- /dev/null +++ b/frontend/src/utils/stableObjectKey.ts @@ -0,0 +1,19 @@ +let globalStableObjectKeySeed = 0 + +/** + * 为对象实例生成稳定 key(基于 WeakMap,不污染业务对象) + */ +export function createStableObjectKeyResolver(prefix = 'item') { + const keyMap = new WeakMap() + + return (item: T): string => { + const cached = keyMap.get(item) + if (cached) { + return cached + } + + const key = `${prefix}-${++globalStableObjectKeySeed}` + keyMap.set(item, key) + return key + } +} diff --git a/frontend/src/views/admin/GroupsView.vue b/frontend/src/views/admin/GroupsView.vue index c6d15e2d..4d6dccf6 100644 --- a/frontend/src/views/admin/GroupsView.vue +++ b/frontend/src/views/admin/GroupsView.vue @@ -759,8 +759,8 @@
@@ -786,7 +786,7 @@ {{ account.name }}
+
+ {{ t('admin.accounts.listPendingSyncHint') }} + +
`, "cf-ray", "9cff2d62d83bb98d"), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.Error(t, err) + require.Contains(t, err.Error(), "Cloudflare challenge") + require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d") + body := rec.Body.String() + require.Contains(t, body, `"type":"error"`) + require.Contains(t, body, "Cloudflare challenge") + require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") +} + +func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), + newJSONResponse(http.StatusForbidden, `Just a moment...`), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.NoError(t, err) + body := rec.Body.String() + require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)") + require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") + require.Contains(t, body, `"type":"test_complete","success":true`) +} diff --git a/backend/internal/service/oauth_service.go b/backend/internal/service/oauth_service.go index e247e654..6f6261d8 100644 --- a/backend/internal/service/oauth_service.go +++ b/backend/internal/service/oauth_service.go @@ -14,6 +14,7 @@ import ( type OpenAIOAuthClient interface { ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) 
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) + RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) } // ClaudeOAuthClient handles HTTP requests for Claude OAuth flows diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index ca7470b9..087ad4ec 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -2,13 +2,20 @@ package service import ( "context" + "crypto/subtle" + "encoding/json" + "io" "net/http" + "net/url" + "strings" "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" ) +var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session" + // OpenAIOAuthService handles OpenAI OAuth authentication flows type OpenAIOAuthService struct { sessionStore *openai.SessionStore @@ -92,6 +99,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 type OpenAIExchangeCodeInput struct { SessionID string Code string + State string RedirectURI string ProxyID *int64 } @@ -116,6 +124,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch if !ok { return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired") } + if input.State == "" { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_STATE_REQUIRED", "oauth state is required") + } + if subtle.ConstantTimeCompare([]byte(input.State), []byte(session.State)) != 1 { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_STATE", "invalid oauth state") + } // Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL proxyURL := session.ProxyURL @@ -173,7 +187,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch // 
RefreshToken refreshes an OpenAI OAuth token func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) { - tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL) + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "") +} + +// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id. +func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) { + tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) if err != nil { return nil, err } @@ -205,13 +224,83 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri return tokenInfo, nil } -// RefreshAccountToken refreshes token for an OpenAI account -func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { - if !account.IsOpenAI() { - return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account") +// ExchangeSoraSessionToken exchanges Sora session_token to access_token. 
+func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) { + if strings.TrimSpace(sessionToken) == "" { + return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required") } - refreshToken := account.GetOpenAIRefreshToken() + proxyURL, err := s.resolveProxyURL(ctx, proxyID) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil) + if err != nil { + return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err) + } + req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken)) + req.Header.Set("Accept", "application/json") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + + client := newOpenAIOAuthHTTPClient(proxyURL) + resp, err := client.Do(req) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if resp.StatusCode != http.StatusOK { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var sessionResp struct { + AccessToken string `json:"accessToken"` + Expires string `json:"expires"` + User struct { + Email string `json:"email"` + Name string `json:"name"` + } `json:"user"` + } + if err := json.Unmarshal(body, &sessionResp); err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err) + } + if strings.TrimSpace(sessionResp.AccessToken) == "" { + 
return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token") + } + + expiresAt := time.Now().Add(time.Hour).Unix() + if strings.TrimSpace(sessionResp.Expires) != "" { + if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil { + expiresAt = parsed.Unix() + } + } + expiresIn := expiresAt - time.Now().Unix() + if expiresIn < 0 { + expiresIn = 0 + } + + return &OpenAITokenInfo{ + AccessToken: strings.TrimSpace(sessionResp.AccessToken), + ExpiresIn: expiresIn, + ExpiresAt: expiresAt, + Email: strings.TrimSpace(sessionResp.User.Email), + }, nil +} + +// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account +func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { + if account.Platform != PlatformOpenAI && account.Platform != PlatformSora { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account") + } + if account.Type != AccountTypeOAuth { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account") + } + + refreshToken := account.GetCredential("refresh_token") if refreshToken == "" { return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available") } @@ -224,7 +313,8 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A } } - return s.RefreshToken(ctx, refreshToken, proxyURL) + clientID := account.GetCredential("client_id") + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) } // BuildAccountCredentials builds credentials map from token info @@ -260,3 +350,30 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) func (s *OpenAIOAuthService) Stop() { s.sessionStore.Stop() } + +func (s *OpenAIOAuthService) resolveProxyURL(ctx 
context.Context, proxyID *int64) (string, error) { + if proxyID == nil { + return "", nil + } + proxy, err := s.proxyRepo.GetByID(ctx, *proxyID) + if err != nil { + return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err) + } + if proxy == nil { + return "", nil + } + return proxy.URL(), nil +} + +func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client { + transport := &http.Transport{} + if strings.TrimSpace(proxyURL) != "" { + if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" { + transport.Proxy = http.ProxyURL(parsed) + } + } + return &http.Client{ + Timeout: 120 * time.Second, + Transport: transport, + } +} diff --git a/backend/internal/service/openai_oauth_service_sora_session_test.go b/backend/internal/service/openai_oauth_service_sora_session_test.go new file mode 100644 index 00000000..fb76f6c1 --- /dev/null +++ b/backend/internal/service/openai_oauth_service_sora_session_test.go @@ -0,0 +1,69 @@ +package service + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openaiOAuthClientNoopStub struct{} + +func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil) + require.NoError(t, err) + require.NotNil(t, info) + require.Equal(t, "at-token", info.AccessToken) + require.Equal(t, "demo@example.com", info.Email) + require.Greater(t, info.ExpiresAt, int64(0)) +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + _, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "missing access token") +} diff --git a/backend/internal/service/openai_oauth_service_state_test.go b/backend/internal/service/openai_oauth_service_state_test.go new file mode 100644 index 00000000..0a2a195f --- /dev/null +++ b/backend/internal/service/openai_oauth_service_state_test.go @@ -0,0 +1,102 @@ +package service + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + 
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openaiOAuthClientStateStub struct { + exchangeCalled int32 +} + +func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { + atomic.AddInt32(&s.exchangeCalled, 1) + return &openai.TokenResponse{ + AccessToken: "at", + RefreshToken: "rt", + ExpiresIn: 3600, + }, nil +} + +func (s *openaiOAuthClientStateStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + return s.RefreshToken(ctx, refreshToken, proxyURL) +} + +func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: "verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "oauth state is required") + require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled)) +} + +func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: "verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", 
+ State: "wrong-state", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid oauth state") + require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled)) +} + +func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: "verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + info, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", + State: "expected-state", + }) + require.NoError(t, err) + require.NotNil(t, info) + require.Equal(t, "at", info.AccessToken) + require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled)) + + _, ok := svc.sessionStore.Get("sid") + require.False(t, ok) +} diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index 3842f0a4..a8a6b96c 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -157,7 +157,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } expiresAt = account.GetCredentialAsTime("expires_at") if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { - if p.openAIOAuthService == nil { + if account.Platform == PlatformSora { + slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID) + // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。 + refreshFailed = true + } else if p.openAIOAuthService == nil { slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) p.metrics.refreshFailure.Add(1) refreshFailed = true // 无法刷新,标记失败 @@ -206,7 +210,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou // 仅在 expires_at 已过期/接近过期时才执行无锁刷新 if expiresAt == nil || 
time.Until(*expiresAt) <= openAITokenRefreshSkew { - if p.openAIOAuthService == nil { + if account.Platform == PlatformSora { + slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID) + // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。 + refreshFailed = true + } else if p.openAIOAuthService == nil { slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) p.metrics.refreshFailure.Add(1) refreshFailed = true diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go index de097d5e..38be7a04 100644 --- a/backend/internal/service/sora_client.go +++ b/backend/internal/service/sora_client.go @@ -17,12 +17,15 @@ import ( "net/textproto" "net/url" "path" + "sort" "strconv" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" + openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/logredact" "github.com/google/uuid" "github.com/tidwall/gjson" "golang.org/x/crypto/sha3" @@ -34,6 +37,11 @@ const ( soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)" ) +var ( + soraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session" + soraOAuthTokenURL = "https://auth.openai.com/oauth/token" +) + const ( soraPowMaxIteration = 500000 ) @@ -96,6 +104,7 @@ type SoraClient interface { UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) + EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) } @@ -157,26 +166,94 @@ func (e 
*SoraUpstreamError) Error() string { // SoraDirectClient 直连 Sora 实现 type SoraDirectClient struct { - cfg *config.Config - httpUpstream HTTPUpstream - tokenProvider *OpenAITokenProvider + cfg *config.Config + httpUpstream HTTPUpstream + tokenProvider *OpenAITokenProvider + accountRepo AccountRepository + soraAccountRepo SoraAccountRepository + baseURL string } // NewSoraDirectClient 创建 Sora 直连客户端 func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient { + baseURL := "" + if cfg != nil { + rawBaseURL := strings.TrimRight(strings.TrimSpace(cfg.Sora.Client.BaseURL), "/") + baseURL = normalizeSoraBaseURL(rawBaseURL) + if rawBaseURL != "" && baseURL != rawBaseURL { + log.Printf("[SoraClient] normalized base_url from %s to %s", sanitizeSoraLogURL(rawBaseURL), sanitizeSoraLogURL(baseURL)) + } + } return &SoraDirectClient{ cfg: cfg, httpUpstream: httpUpstream, tokenProvider: tokenProvider, + baseURL: baseURL, } } +func (c *SoraDirectClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) { + if c == nil { + return + } + c.accountRepo = accountRepo + c.soraAccountRepo = soraAccountRepo +} + // Enabled 判断是否启用 Sora 直连 func (c *SoraDirectClient) Enabled() bool { - if c == nil || c.cfg == nil { + if c == nil { return false } - return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != "" + if strings.TrimSpace(c.baseURL) != "" { + return true + } + if c.cfg == nil { + return false + } + return strings.TrimSpace(normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)) != "" +} + +// PreflightCheck 在创建任务前执行账号能力预检。 +// 当前仅对视频模型执行 /nf/check 预检,用于提前识别额度耗尽或能力缺失。 +func (c *SoraDirectClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error { + if modelCfg.Type != "video" { + return nil + } + token, err := c.getAccessToken(ctx, account) + if err != nil { + return err + } + headers := c.buildBaseHeaders(token, 
c.defaultUserAgent()) + headers.Set("Accept", "application/json") + body, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/check"), headers, nil, false) + if err != nil { + var upstreamErr *SoraUpstreamError + if errors.As(err, &upstreamErr) && upstreamErr.StatusCode == http.StatusNotFound { + return &SoraUpstreamError{ + StatusCode: http.StatusForbidden, + Message: "当前账号未开通 Sora2 能力或无可用配额", + Headers: upstreamErr.Headers, + Body: upstreamErr.Body, + } + } + return err + } + + rateLimitReached := gjson.GetBytes(body, "rate_limit_and_credit_balance.rate_limit_reached").Bool() + remaining := gjson.GetBytes(body, "rate_limit_and_credit_balance.estimated_num_videos_remaining") + if rateLimitReached || (remaining.Exists() && remaining.Int() <= 0) { + msg := "当前账号 Sora2 可用配额不足" + if requestedModel != "" { + msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel) + } + return &SoraUpstreamError{ + StatusCode: http.StatusTooManyRequests, + Message: msg, + Headers: http.Header{}, + } + } + return nil } func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) { @@ -347,6 +424,45 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account return taskID, nil } +func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + if strings.TrimSpace(expansionLevel) == "" { + expansionLevel = "medium" + } + if durationS <= 0 { + durationS = 10 + } + + payload := map[string]any{ + "prompt": prompt, + "expansion_level": expansionLevel, + "duration_s": durationS, + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + + headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + headers.Set("Content-Type", "application/json") + headers.Set("Accept", "application/json") + 
headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + + respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/editor/enhance_prompt"), headers, bytes.NewReader(body), false) + if err != nil { + return "", err + } + enhancedPrompt := strings.TrimSpace(gjson.GetBytes(respBody, "enhanced_prompt").String()) + if enhancedPrompt == "" { + return "", errors.New("enhance_prompt response missing enhanced_prompt") + } + return enhancedPrompt, nil +} + func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit()) if err != nil { @@ -512,9 +628,10 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t } func (c *SoraDirectClient) buildURL(endpoint string) string { - base := "" - if c != nil && c.cfg != nil { - base = strings.TrimRight(strings.TrimSpace(c.cfg.Sora.Client.BaseURL), "/") + base := strings.TrimRight(strings.TrimSpace(c.baseURL), "/") + if base == "" && c != nil && c.cfg != nil { + base = normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL) + c.baseURL = base } if base == "" { return endpoint @@ -540,14 +657,257 @@ func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account) if account == nil { return "", errors.New("account is nil") } - if c.tokenProvider != nil { - return c.tokenProvider.GetAccessToken(ctx, account) + + allowProvider := c.allowOpenAITokenProvider(account) + var providerErr error + if allowProvider && c.tokenProvider != nil { + token, err := c.tokenProvider.GetAccessToken(ctx, account) + if err == nil && strings.TrimSpace(token) != "" { + c.logTokenSource(account, "openai_token_provider") + return token, nil + } + providerErr = err + if err != nil && c.debugEnabled() { + c.debugLogf( + "token_provider_failed account_id=%d platform=%s err=%s", + account.ID, + account.Platform, + 
logredact.RedactText(err.Error()), + ) + } } token := strings.TrimSpace(account.GetCredential("access_token")) - if token == "" { - return "", errors.New("access_token not found") + if token != "" { + expiresAt := account.GetCredentialAsTime("expires_at") + if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute { + refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring") + if refreshErr == nil && strings.TrimSpace(refreshed) != "" { + c.logTokenSource(account, "refresh_token_recovered") + return refreshed, nil + } + if refreshErr != nil && c.debugEnabled() { + c.debugLogf("token_refresh_before_use_failed account_id=%d err=%s", account.ID, logredact.RedactText(refreshErr.Error())) + } + } + c.logTokenSource(account, "account_credentials") + return token, nil } - return token, nil + + recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing") + if recoverErr == nil && strings.TrimSpace(recovered) != "" { + c.logTokenSource(account, "session_or_refresh_recovered") + return recovered, nil + } + if recoverErr != nil && c.debugEnabled() { + c.debugLogf("token_recover_failed account_id=%d platform=%s err=%s", account.ID, account.Platform, logredact.RedactText(recoverErr.Error())) + } + if providerErr != nil { + return "", providerErr + } + if c.tokenProvider != nil && !allowProvider { + c.logTokenSource(account, "account_credentials(provider_disabled)") + } + return "", errors.New("access_token not found") +} + +func (c *SoraDirectClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + + if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" { + accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken) + if err == nil && strings.TrimSpace(accessToken) != "" { + c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, 
sessionToken) + c.logTokenRecover(account, "session_token", reason, true, nil) + return accessToken, nil + } + c.logTokenRecover(account, "session_token", reason, false, err) + } + + refreshToken := strings.TrimSpace(account.GetCredential("refresh_token")) + if refreshToken == "" { + return "", errors.New("session_token/refresh_token not found") + } + accessToken, newRefreshToken, expiresAt, err := c.exchangeRefreshToken(ctx, account, refreshToken) + if err != nil { + c.logTokenRecover(account, "refresh_token", reason, false, err) + return "", err + } + if strings.TrimSpace(accessToken) == "" { + return "", errors.New("refreshed access_token is empty") + } + c.applyRecoveredToken(ctx, account, accessToken, newRefreshToken, expiresAt, "") + c.logTokenRecover(account, "refresh_token", reason, true, nil) + return accessToken, nil +} + +func (c *SoraDirectClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) { + headers := http.Header{} + headers.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken) + headers.Set("Accept", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + headers.Set("User-Agent", c.defaultUserAgent()) + body, _, err := c.doRequest(ctx, account, http.MethodGet, soraSessionAuthURL, headers, nil, false) + if err != nil { + return "", "", err + } + accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String()) + if accessToken == "" { + return "", "", errors.New("session exchange missing accessToken") + } + expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String()) + return accessToken, expiresAt, nil +} + +func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Account, refreshToken string) (string, string, string, error) { + clientIDs := []string{ + strings.TrimSpace(account.GetCredential("client_id")), + openaioauth.SoraClientID, + openaioauth.ClientID, + } 
+ tried := make(map[string]struct{}, len(clientIDs)) + var lastErr error + + for _, clientID := range clientIDs { + if clientID == "" { + continue + } + if _, ok := tried[clientID]; ok { + continue + } + tried[clientID] = struct{}{} + + payload := map[string]any{ + "client_id": clientID, + "grant_type": "refresh_token", + "refresh_token": refreshToken, + "redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", + } + bodyBytes, err := json.Marshal(payload) + if err != nil { + return "", "", "", err + } + headers := http.Header{} + headers.Set("Accept", "application/json") + headers.Set("Content-Type", "application/json") + headers.Set("User-Agent", c.defaultUserAgent()) + + respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, bytes.NewReader(bodyBytes), false) + if err != nil { + lastErr = err + if c.debugEnabled() { + c.debugLogf("refresh_token_exchange_failed account_id=%d client_id=%s err=%s", account.ID, clientID, logredact.RedactText(err.Error())) + } + continue + } + accessToken := strings.TrimSpace(gjson.GetBytes(respBody, "access_token").String()) + if accessToken == "" { + lastErr = errors.New("oauth refresh response missing access_token") + continue + } + newRefreshToken := strings.TrimSpace(gjson.GetBytes(respBody, "refresh_token").String()) + expiresIn := gjson.GetBytes(respBody, "expires_in").Int() + expiresAt := "" + if expiresIn > 0 { + expiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339) + } + return accessToken, newRefreshToken, expiresAt, nil + } + + if lastErr != nil { + return "", "", "", lastErr + } + return "", "", "", errors.New("no available client_id for refresh_token exchange") +} + +func (c *SoraDirectClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) { + if account == nil { + return + } + if account.Credentials == nil { + account.Credentials = make(map[string]any) + } 
+ if strings.TrimSpace(accessToken) != "" { + account.Credentials["access_token"] = accessToken + } + if strings.TrimSpace(refreshToken) != "" { + account.Credentials["refresh_token"] = refreshToken + } + if strings.TrimSpace(expiresAt) != "" { + account.Credentials["expires_at"] = expiresAt + } + if strings.TrimSpace(sessionToken) != "" { + account.Credentials["session_token"] = sessionToken + } + + if c.accountRepo != nil { + if err := c.accountRepo.Update(ctx, account); err != nil { + if c.debugEnabled() { + c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) + } + } + } + c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken) +} + +func (c *SoraDirectClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) { + if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 { + return + } + updates := make(map[string]any) + if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" { + updates["access_token"] = accessToken + updates["refresh_token"] = refreshToken + } + if strings.TrimSpace(sessionToken) != "" { + updates["session_token"] = sessionToken + } + if len(updates) == 0 { + return + } + if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() { + c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) + } +} + +func (c *SoraDirectClient) logTokenRecover(account *Account, source, reason string, success bool, err error) { + if !c.debugEnabled() || account == nil { + return + } + if success { + c.debugLogf("token_recover_success account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason) + return + } + if err == nil { + c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, 
reason) + return + } + c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s err=%s", account.ID, account.Platform, source, reason, logredact.RedactText(err.Error())) +} + +func (c *SoraDirectClient) allowOpenAITokenProvider(account *Account) bool { + if c == nil || c.tokenProvider == nil { + return false + } + if account != nil && account.Platform == PlatformSora { + return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider + } + return true +} + +func (c *SoraDirectClient) logTokenSource(account *Account, source string) { + if !c.debugEnabled() || account == nil { + return + } + c.debugLogf( + "token_selected account_id=%d platform=%s account_type=%s source=%s", + account.ID, + account.Platform, + account.Type, + source, + ) } func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header { @@ -600,7 +960,24 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth } attempts := maxRetries + 1 + authRecovered := false + authRecoverExtraAttemptGranted := false + var lastErr error for attempt := 1; attempt <= attempts; attempt++ { + if c.debugEnabled() { + c.debugLogf( + "request_start method=%s url=%s attempt=%d/%d timeout_s=%d body_bytes=%d proxy_bound=%t headers=%s", + method, + sanitizeSoraLogURL(urlStr), + attempt, + attempts, + timeout, + len(bodyBytes), + account != nil && account.ProxyID != nil && account.Proxy != nil, + formatSoraHeaders(headers), + ) + } + var reader io.Reader if bodyBytes != nil { reader = bytes.NewReader(bodyBytes) @@ -618,7 +995,21 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth } resp, err := c.doHTTP(req, proxyURL, account) if err != nil { + lastErr = err + if c.debugEnabled() { + c.debugLogf( + "request_transport_error method=%s url=%s attempt=%d/%d err=%s", + method, + sanitizeSoraLogURL(urlStr), + attempt, + attempts, + logredact.RedactText(err.Error()), + ) + } if attempt < attempts && allowRetry { + if c.debugEnabled() 
{ + c.debugLogf("request_retry_scheduled method=%s url=%s reason=transport_error next_attempt=%d/%d", method, sanitizeSoraLogURL(urlStr), attempt+1, attempts) + } c.sleepRetry(attempt) continue } @@ -632,12 +1023,53 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth } if c.cfg != nil && c.cfg.Sora.Client.Debug { - log.Printf("[SoraClient] %s %s status=%d cost=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, time.Since(start)) + c.debugLogf( + "response_received method=%s url=%s attempt=%d/%d status=%d cost=%s resp_bytes=%d resp_headers=%s", + method, + sanitizeSoraLogURL(urlStr), + attempt, + attempts, + resp.StatusCode, + time.Since(start), + len(respBody), + formatSoraHeaders(resp.Header), + ) } if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { - upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody) + if !authRecovered && shouldAttemptSoraTokenRecover(resp.StatusCode, urlStr) && account != nil { + if recovered, recoverErr := c.recoverAccessToken(ctx, account, fmt.Sprintf("upstream_status_%d", resp.StatusCode)); recoverErr == nil && strings.TrimSpace(recovered) != "" { + headers.Set("Authorization", "Bearer "+recovered) + authRecovered = true + if attempt == attempts && !authRecoverExtraAttemptGranted { + attempts++ + authRecoverExtraAttemptGranted = true + } + if c.debugEnabled() { + c.debugLogf("request_retry_with_recovered_token method=%s url=%s status=%d", method, sanitizeSoraLogURL(urlStr), resp.StatusCode) + } + continue + } else if recoverErr != nil && c.debugEnabled() { + c.debugLogf("request_recover_token_failed method=%s url=%s status=%d err=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, logredact.RedactText(recoverErr.Error())) + } + } + if c.debugEnabled() { + c.debugLogf( + "response_non_success method=%s url=%s attempt=%d/%d status=%d body=%s", + method, + sanitizeSoraLogURL(urlStr), + attempt, + attempts, + resp.StatusCode, + 
summarizeSoraResponseBody(respBody, 512), + ) + } + upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody, urlStr) + lastErr = upstreamErr if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) { + if c.debugEnabled() { + c.debugLogf("request_retry_scheduled method=%s url=%s reason=status_%d next_attempt=%d/%d", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts) + } c.sleepRetry(attempt) continue } @@ -645,9 +1077,34 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth } return respBody, resp.Header, nil } + if lastErr != nil { + return nil, nil, lastErr + } return nil, nil, errors.New("upstream retries exhausted") } +func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool { + switch statusCode { + case http.StatusUnauthorized, http.StatusForbidden: + parsed, err := url.Parse(strings.TrimSpace(rawURL)) + if err != nil { + return false + } + host := strings.ToLower(parsed.Hostname()) + if host != "sora.chatgpt.com" && host != "chatgpt.com" { + return false + } + // 避免在 ST->AT 转换接口上递归触发 token 恢复导致死循环。 + path := strings.ToLower(strings.TrimSpace(parsed.Path)) + if path == "/api/auth/session" { + return false + } + return true + default: + return false + } +} + func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) { enableTLS := c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint if c.httpUpstream != nil { @@ -670,9 +1127,14 @@ func (c *SoraDirectClient) sleepRetry(attempt int) { time.Sleep(backoff) } -func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte) error { +func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte, requestURL string) error { msg := strings.TrimSpace(extractUpstreamErrorMessage(body)) msg = 
sanitizeUpstreamErrorMessage(msg) + if status == http.StatusNotFound && strings.Contains(strings.ToLower(msg), "not found") { + if hint := soraBaseURLNotFoundHint(requestURL); hint != "" { + msg = strings.TrimSpace(msg + " " + hint) + } + } if msg == "" { msg = truncateForLog(body, 256) } @@ -684,6 +1146,45 @@ func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, b } } +func normalizeSoraBaseURL(raw string) string { + trimmed := strings.TrimRight(strings.TrimSpace(raw), "/") + if trimmed == "" { + return "" + } + parsed, err := url.Parse(trimmed) + if err != nil || parsed.Scheme == "" || parsed.Host == "" { + return trimmed + } + host := strings.ToLower(parsed.Hostname()) + if host != "sora.chatgpt.com" && host != "chatgpt.com" { + return trimmed + } + pathVal := strings.TrimRight(strings.TrimSpace(parsed.Path), "/") + switch pathVal { + case "", "/": + parsed.Path = "/backend" + case "/backend-api": + parsed.Path = "/backend" + } + return strings.TrimRight(parsed.String(), "/") +} + +func soraBaseURLNotFoundHint(requestURL string) string { + parsed, err := url.Parse(strings.TrimSpace(requestURL)) + if err != nil || parsed.Host == "" { + return "" + } + host := strings.ToLower(parsed.Hostname()) + if host != "sora.chatgpt.com" && host != "chatgpt.com" { + return "" + } + pathVal := strings.TrimSpace(parsed.Path) + if strings.HasPrefix(pathVal, "/backend/") || pathVal == "/backend" { + return "" + } + return "(请检查 sora.client.base_url,建议配置为 https://sora.chatgpt.com/backend)" +} + func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) { reqID := uuid.NewString() userAgent := soraRandChoice(soraDesktopUserAgents) @@ -901,3 +1402,70 @@ func sanitizeSoraLogURL(raw string) string { parsed.RawQuery = q.Encode() return parsed.String() } + +func (c *SoraDirectClient) debugEnabled() bool { + return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug +} + +func (c *SoraDirectClient) 
debugLogf(format string, args ...any) { + if !c.debugEnabled() { + return + } + log.Printf("[SoraClient] "+format, args...) +} + +func formatSoraHeaders(headers http.Header) string { + if len(headers) == 0 { + return "{}" + } + keys := make([]string, 0, len(headers)) + for key := range headers { + keys = append(keys, key) + } + sort.Strings(keys) + out := make(map[string]string, len(keys)) + for _, key := range keys { + values := headers.Values(key) + if len(values) == 0 { + continue + } + val := strings.Join(values, ",") + if isSensitiveHeader(key) { + out[key] = "***" + continue + } + out[key] = truncateForLog([]byte(logredact.RedactText(val)), 160) + } + encoded, err := json.Marshal(out) + if err != nil { + return "{}" + } + return string(encoded) +} + +func isSensitiveHeader(key string) bool { + k := strings.ToLower(strings.TrimSpace(key)) + switch k { + case "authorization", "openai-sentinel-token", "cookie", "set-cookie", "x-api-key": + return true + default: + return false + } +} + +func summarizeSoraResponseBody(body []byte, maxLen int) string { + if len(body) == 0 { + return "" + } + var text string + if json.Valid(body) { + text = logredact.RedactJSON(body) + } else { + text = logredact.RedactText(string(body)) + } + text = strings.TrimSpace(text) + if maxLen <= 0 || len(text) <= maxLen { + return text + } + return text[:maxLen] + "...(truncated)" +} diff --git a/backend/internal/service/sora_client_test.go b/backend/internal/service/sora_client_test.go index a6bf71cd..3e88c9f9 100644 --- a/backend/internal/service/sora_client_test.go +++ b/backend/internal/service/sora_client_test.go @@ -4,9 +4,13 @@ package service import ( "context" + "encoding/json" "net/http" "net/http/httptest" + "strings" + "sync/atomic" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" @@ -85,3 +89,273 @@ func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) { require.Equal(t, "completed", status.Status) 
require.Equal(t, []string{"https://example.com/a.png"}, status.URLs) } + +func TestNormalizeSoraBaseURL(t *testing.T) { + t.Parallel() + tests := []struct { + name string + raw string + want string + }{ + { + name: "empty", + raw: "", + want: "", + }, + { + name: "append_backend_for_sora_host", + raw: "https://sora.chatgpt.com", + want: "https://sora.chatgpt.com/backend", + }, + { + name: "convert_backend_api_to_backend", + raw: "https://sora.chatgpt.com/backend-api", + want: "https://sora.chatgpt.com/backend", + }, + { + name: "keep_backend", + raw: "https://sora.chatgpt.com/backend", + want: "https://sora.chatgpt.com/backend", + }, + { + name: "keep_custom_host", + raw: "https://example.com/custom-path", + want: "https://example.com/custom-path", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizeSoraBaseURL(tt.raw) + require.Equal(t, tt.want, got) + }) + } +} + +func TestSoraDirectClient_BuildURL_UsesNormalizedBaseURL(t *testing.T) { + t.Parallel() + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com", + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + require.Equal(t, "https://sora.chatgpt.com/backend/video_gen", client.buildURL("/video_gen")) +} + +func TestSoraDirectClient_BuildUpstreamError_NotFoundHint(t *testing.T) { + t.Parallel() + client := NewSoraDirectClient(&config.Config{}, nil, nil) + + err := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/video_gen") + var upstreamErr *SoraUpstreamError + require.ErrorAs(t, err, &upstreamErr) + require.Contains(t, upstreamErr.Message, "请检查 sora.client.base_url") + + errNoHint := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/backend/video_gen") + require.ErrorAs(t, errNoHint, &upstreamErr) + require.NotContains(t, 
upstreamErr.Message, "请检查 sora.client.base_url") +} + +func TestFormatSoraHeaders_RedactsSensitive(t *testing.T) { + t.Parallel() + headers := http.Header{} + headers.Set("Authorization", "Bearer secret-token") + headers.Set("openai-sentinel-token", "sentinel-secret") + headers.Set("X-Test", "ok") + + out := formatSoraHeaders(headers) + require.Contains(t, out, `"Authorization":"***"`) + require.Contains(t, out, `Sentinel-Token":"***"`) + require.Contains(t, out, `"X-Test":"ok"`) + require.NotContains(t, out, "secret-token") + require.NotContains(t, out, "sentinel-secret") +} + +func TestSummarizeSoraResponseBody_RedactsJSON(t *testing.T) { + t.Parallel() + body := []byte(`{"error":{"message":"bad"},"access_token":"abc123"}`) + out := summarizeSoraResponseBody(body, 512) + require.Contains(t, out, `"access_token":"***"`) + require.NotContains(t, out, "abc123") +} + +func TestSummarizeSoraResponseBody_Truncates(t *testing.T) { + t.Parallel() + body := []byte(strings.Repeat("x", 100)) + out := summarizeSoraResponseBody(body, 10) + require.Contains(t, out, "(truncated)") +} + +func TestSoraDirectClient_GetAccessToken_SoraDefaultUseCredentials(t *testing.T) { + t.Parallel() + cache := newOpenAITokenCacheStub() + provider := NewOpenAITokenProvider(nil, cache, nil) + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, nil, provider) + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "sora-credential-token", + }, + } + + token, err := client.getAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "sora-credential-token", token) + require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalled)) +} + +func TestSoraDirectClient_GetAccessToken_SoraCanEnableProvider(t *testing.T) { + t.Parallel() + cache := newOpenAITokenCacheStub() + 
account := &Account{ + ID: 2, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "sora-credential-token", + }, + } + cache.tokens[OpenAITokenCacheKey(account)] = "provider-token" + provider := NewOpenAITokenProvider(nil, cache, nil) + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + UseOpenAITokenProvider: true, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, provider) + + token, err := client.getAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "provider-token", token) + require.Greater(t, atomic.LoadInt32(&cache.getCalled), int32(0)) +} + +func TestSoraDirectClient_GetAccessToken_FromSessionToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=session-token") + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "accessToken": "session-access-token", + "expires": "2099-01-01T00:00:00Z", + }) + })) + defer server.Close() + + origin := soraSessionAuthURL + soraSessionAuthURL = server.URL + defer func() { soraSessionAuthURL = origin }() + + client := NewSoraDirectClient(&config.Config{}, nil, nil) + account := &Account{ + ID: 10, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "session_token": "session-token", + }, + } + + token, err := client.getAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "session-access-token", token) + require.Equal(t, "session-access-token", account.GetCredential("access_token")) +} + +func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + 
require.Equal(t, http.MethodPost, r.Method) + require.Equal(t, "/oauth/token", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "refresh-access-token", + "refresh_token": "refresh-token-new", + "expires_in": 3600, + }) + })) + defer server.Close() + + origin := soraOAuthTokenURL + soraOAuthTokenURL = server.URL + "/oauth/token" + defer func() { soraOAuthTokenURL = origin }() + + client := NewSoraDirectClient(&config.Config{}, nil, nil) + account := &Account{ + ID: 11, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "refresh-token-old", + }, + } + + token, err := client.getAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "refresh-access-token", token) + require.Equal(t, "refresh-token-new", account.GetCredential("refresh_token")) + require.NotNil(t, account.GetCredentialAsTime("expires_at")) +} + +func TestSoraDirectClient_PreflightCheck_VideoQuotaExceeded(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Equal(t, "/nf/check", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "rate_limit_and_credit_balance": map[string]any{ + "estimated_num_videos_remaining": 0, + "rate_limit_reached": true, + }, + }) + })) + defer server.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: server.URL, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + account := &Account{ + ID: 12, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "ok", + "expires_at": time.Now().Add(2 * time.Hour).Format(time.RFC3339), + }, + } + err := client.PreflightCheck(context.Background(), account, "sora2-landscape-10s", 
SoraModelConfig{Type: "video"}) + require.Error(t, err) + var upstreamErr *SoraUpstreamError + require.ErrorAs(t, err, &upstreamErr) + require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode) +} + +func TestShouldAttemptSoraTokenRecover(t *testing.T) { + t.Parallel() + + require.True(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/backend/video_gen")) + require.True(t, shouldAttemptSoraTokenRecover(http.StatusForbidden, "https://chatgpt.com/backend/video_gen")) + require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/api/auth/session")) + require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://auth.openai.com/oauth/token")) + require.False(t, shouldAttemptSoraTokenRecover(http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen")) +} diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index d7ff297c..8ae89f92 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -61,6 +61,10 @@ type SoraGatewayService struct { cfg *config.Config } +type soraPreflightChecker interface { + PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error +} + func NewSoraGatewayService( soraClient SoraClient, mediaStorage *SoraMediaStorage, @@ -112,11 +116,6 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream) return nil, fmt.Errorf("unsupported model: %s", reqModel) } - if modelCfg.Type == "prompt_enhance" { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Prompt-enhance 模型暂未支持", clientStream) - return nil, fmt.Errorf("prompt-enhance not supported") - } - prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody) 
if strings.TrimSpace(prompt) == "" { s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) @@ -131,6 +130,41 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun if cancel != nil { defer cancel() } + if checker, ok := s.soraClient.(soraPreflightChecker); ok { + if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + } + + if modelCfg.Type == "prompt_enhance" { + enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS) + if err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + content := strings.TrimSpace(enhancedPrompt) + if content == "" { + content = prompt + } + var firstTokenMs *int + if clientStream { + ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) + if streamErr != nil { + return nil, streamErr + } + firstTokenMs = ms + } else if c != nil { + c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel)) + } + return &ForwardResult{ + RequestID: "", + Model: reqModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: "prompt", + }, nil + } var imageData []byte imageFilename := "" @@ -267,7 +301,7 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) ( func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool { switch statusCode { - case 401, 402, 403, 429, 529: + case 401, 402, 403, 404, 429, 529: return true default: return statusCode >= 500 @@ -460,7 +494,7 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body) } if 
s.shouldFailoverUpstreamError(upstreamErr.StatusCode) { - return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode} + return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode, ResponseBody: upstreamErr.Body} } msg := upstreamErr.Message if override := soraProErrorMessage(model, msg); override != "" { diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go index d6bf9eae..f706d052 100644 --- a/backend/internal/service/sora_gateway_service_test.go +++ b/backend/internal/service/sora_gateway_service_test.go @@ -18,6 +18,8 @@ type stubSoraClientForPoll struct { videoStatus *SoraVideoTaskStatus imageCalls int videoCalls int + enhanced string + enhanceErr error } func (s *stubSoraClientForPoll) Enabled() bool { return true } @@ -30,6 +32,12 @@ func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Ac func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) { return "task-video", nil } +func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) { + if s.enhanced != "" { + return s.enhanced, s.enhanceErr + } + return "enhanced prompt", s.enhanceErr +} func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { s.imageCalls++ return s.imageStatus, nil @@ -62,6 +70,33 @@ func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) { require.Equal(t, 1, client.imageCalls) } +func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { + client := &stubSoraClientForPoll{ + enhanced: "cinematic prompt", + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ + ID: 1, + 
Platform: PlatformSora, + Status: StatusActive, + } + body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "prompt", result.MediaType) + require.Equal(t, "prompt-enhance-short-10s", result.Model) +} + func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) { client := &stubSoraClientForPoll{ videoStatus: &SoraVideoTaskStatus{ @@ -178,6 +213,7 @@ func TestSoraProErrorMessage(t *testing.T) { func TestShouldFailoverUpstreamError(t *testing.T) { svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) require.True(t, svc.shouldFailoverUpstreamError(401)) + require.True(t, svc.shouldFailoverUpstreamError(404)) require.True(t, svc.shouldFailoverUpstreamError(429)) require.True(t, svc.shouldFailoverUpstreamError(500)) require.True(t, svc.shouldFailoverUpstreamError(502)) diff --git a/backend/internal/service/sora_models.go b/backend/internal/service/sora_models.go index ab095e46..80b20a4b 100644 --- a/backend/internal/service/sora_models.go +++ b/backend/internal/service/sora_models.go @@ -17,6 +17,9 @@ type SoraModelConfig struct { Model string Size string RequirePro bool + // Prompt-enhance 专用参数 + ExpansionLevel string + DurationS int } var soraModelConfigs = map[string]SoraModelConfig{ @@ -160,31 +163,49 @@ var soraModelConfigs = map[string]SoraModelConfig{ RequirePro: true, }, "prompt-enhance-short-10s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "short", + DurationS: 10, }, "prompt-enhance-short-15s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "short", + DurationS: 15, }, "prompt-enhance-short-20s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "short", + DurationS: 20, }, "prompt-enhance-medium-10s": { - Type: "prompt_enhance", + Type: "prompt_enhance", 
+ ExpansionLevel: "medium", + DurationS: 10, }, "prompt-enhance-medium-15s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "medium", + DurationS: 15, }, "prompt-enhance-medium-20s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "medium", + DurationS: 20, }, "prompt-enhance-long-10s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "long", + DurationS: 10, }, "prompt-enhance-long-15s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "long", + DurationS: 15, }, "prompt-enhance-long-20s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "long", + DurationS: 20, }, } diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 9de1c164..a37e0d0a 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -43,10 +43,13 @@ func NewTokenRefreshService( stopCh: make(chan struct{}), } + openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo) + openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts) + // 注册平台特定的刷新器 s.refreshers = []TokenRefresher{ NewClaudeTokenRefresher(oauthService), - NewOpenAITokenRefresher(openaiOAuthService, accountRepo), + openAIRefresher, NewGeminiTokenRefresher(geminiOAuthService), NewAntigravityTokenRefresher(antigravityOAuthService), } diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 46033f75..0dd3cf45 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -86,6 +86,7 @@ type OpenAITokenRefresher struct { openaiOAuthService *OpenAIOAuthService accountRepo AccountRepository soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 + syncLinkedSora bool } // NewOpenAITokenRefresher 创建 OpenAI token刷新器 @@ -103,11 +104,15 @@ func (r 
*OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) { r.soraAccountRepo = repo } +// SetSyncLinkedSoraAccounts 控制是否同步覆盖关联的 Sora 账号 token。 +func (r *OpenAITokenRefresher) SetSyncLinkedSoraAccounts(enabled bool) { + r.syncLinkedSora = enabled +} + // CanRefresh 检查是否能处理此账号 -// 只处理 openai 平台的 oauth 类型账号 +// 只处理 openai 平台的 oauth 类型账号(不直接刷新 sora 平台账号) func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool { - return (account.Platform == PlatformOpenAI || account.Platform == PlatformSora) && - account.Type == AccountTypeOAuth + return account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth } // NeedsRefresh 检查token是否需要刷新 @@ -141,7 +146,7 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m } // 异步同步关联的 Sora 账号(不阻塞主流程) - if r.accountRepo != nil { + if r.accountRepo != nil && r.syncLinkedSora { go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials) } diff --git a/backend/internal/service/token_refresher_test.go b/backend/internal/service/token_refresher_test.go index c7505037..264d7912 100644 --- a/backend/internal/service/token_refresher_test.go +++ b/backend/internal/service/token_refresher_test.go @@ -226,3 +226,43 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) { }) } } + +func TestOpenAITokenRefresher_CanRefresh(t *testing.T) { + refresher := &OpenAITokenRefresher{} + + tests := []struct { + name string + platform string + accType string + want bool + }{ + { + name: "openai oauth - can refresh", + platform: PlatformOpenAI, + accType: AccountTypeOAuth, + want: true, + }, + { + name: "sora oauth - cannot refresh directly", + platform: PlatformSora, + accType: AccountTypeOAuth, + want: false, + }, + { + name: "openai apikey - cannot refresh", + platform: PlatformOpenAI, + accType: AccountTypeAPIKey, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Platform: tt.platform, + Type: tt.accType, + } + 
require.Equal(t, tt.want, refresher.CanRefresh(account)) + }) + } +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 5d712f75..652f9e00 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -206,6 +206,18 @@ func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage { return NewSoraMediaStorage(cfg) } +func ProvideSoraDirectClient( + cfg *config.Config, + httpUpstream HTTPUpstream, + tokenProvider *OpenAITokenProvider, + accountRepo AccountRepository, + soraAccountRepo SoraAccountRepository, +) *SoraDirectClient { + client := NewSoraDirectClient(cfg, httpUpstream, tokenProvider) + client.SetAccountRepositories(accountRepo, soraAccountRepo) + return client +} + // ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务 func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService { svc := NewSoraMediaCleanupService(storage, cfg) @@ -255,7 +267,7 @@ var ProviderSet = wire.NewSet( NewGatewayService, ProvideSoraMediaStorage, ProvideSoraMediaCleanupService, - NewSoraDirectClient, + ProvideSoraDirectClient, wire.Bind(new(SoraClient), new(*SoraDirectClient)), NewSoraGatewayService, NewOpenAIGatewayService, diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go index 7f37d59c..f7ba5c9e 100644 --- a/backend/internal/web/embed_on.go +++ b/backend/internal/web/embed_on.go @@ -86,6 +86,7 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc { if strings.HasPrefix(path, "/api/") || strings.HasPrefix(path, "/v1/") || strings.HasPrefix(path, "/v1beta/") || + strings.HasPrefix(path, "/sora/") || strings.HasPrefix(path, "/antigravity/") || strings.HasPrefix(path, "/setup/") || path == "/health" || @@ -209,6 +210,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { if strings.HasPrefix(path, "/api/") || strings.HasPrefix(path, "/v1/") || strings.HasPrefix(path, "/v1beta/") || + strings.HasPrefix(path, "/sora/") || strings.HasPrefix(path, 
"/antigravity/") || strings.HasPrefix(path, "/setup/") || path == "/health" || diff --git a/backend/internal/web/embed_test.go b/backend/internal/web/embed_test.go index 50f5a323..e2cbcf15 100644 --- a/backend/internal/web/embed_test.go +++ b/backend/internal/web/embed_test.go @@ -362,6 +362,7 @@ func TestFrontendServer_Middleware(t *testing.T) { "/api/v1/users", "/v1/models", "/v1beta/chat", + "/sora/v1/models", "/antigravity/test", "/setup/init", "/health", @@ -537,6 +538,7 @@ func TestServeEmbeddedFrontend(t *testing.T) { "/api/users", "/v1/models", "/v1beta/chat", + "/sora/v1/models", "/antigravity/test", "/setup/init", "/health", diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 9fd2d391..0ff1ec02 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -388,7 +388,11 @@ sora: recent_task_limit_max: 200 # Enable debug logs for Sora upstream requests # 启用 Sora 直连调试日志 + # 调试日志会输出上游请求尝试、重试、响应摘要;Authorization/openai-sentinel-token 等敏感头会自动脱敏 debug: false + # Allow Sora client to fetch token via OpenAI token provider + # 是否允许 Sora 客户端通过 OpenAI token provider 取 token(默认 false,避免误走 OpenAI 刷新链路) + use_openai_token_provider: false # Optional custom headers (key-value) # 额外请求头(键值对) headers: {} @@ -431,6 +435,13 @@ sora: # Cron 调度表达式 schedule: "0 3 * * *" +# Token refresh behavior +# token 刷新行为控制 +token_refresh: + # Whether OpenAI refresh flow is allowed to sync linked Sora accounts + # 是否允许 OpenAI 刷新流程同步覆盖 linked_openai_account_id 关联的 Sora 账号 token + sync_linked_sora_accounts: false + # ============================================================================= # API Key Auth Cache Configuration # API Key 认证缓存配置 diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 36bec4e7..e1f502ec 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -220,7 +220,7 @@ export async function generateAuthUrl( */ export async function exchangeCode( endpoint: string, - 
exchangeData: { session_id: string; code: string; proxy_id?: number } + exchangeData: { session_id: string; code: string; state?: string; proxy_id?: number } ): Promise> { const { data } = await apiClient.post>(endpoint, exchangeData) return data @@ -442,7 +442,8 @@ export async function getAntigravityDefaultModelMapping(): Promise> { const payload: { refresh_token: string; proxy_id?: number } = { refresh_token: refreshToken @@ -450,7 +451,29 @@ export async function refreshOpenAIToken( if (proxyId) { payload.proxy_id = proxyId } - const { data } = await apiClient.post>('/admin/openai/refresh-token', payload) + const { data } = await apiClient.post>(endpoint, payload) + return data +} + +/** + * Validate Sora session token and exchange to access token + * @param sessionToken - Sora session token + * @param proxyId - Optional proxy ID + * @param endpoint - API endpoint path + * @returns Token information including access_token + */ +export async function validateSoraSessionToken( + sessionToken: string, + proxyId?: number | null, + endpoint: string = '/admin/sora/st2at' +): Promise> { + const payload: { session_token: string; proxy_id?: number } = { + session_token: sessionToken + } + if (proxyId) { + payload.proxy_id = proxyId + } + const { data } = await apiClient.post>(endpoint, payload) return data } @@ -475,6 +498,7 @@ export const accountsAPI = { generateAuthUrl, exchangeCode, refreshOpenAIToken, + validateSoraSessionToken, batchCreate, batchUpdateCredentials, bulkUpdate, diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 85785d6a..8024dfb6 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -109,6 +109,28 @@ OpenAI +
+ +
+ +
+ +
+
+
@@ -1747,32 +1801,6 @@
- -
- -
-
@@ -2148,6 +2178,7 @@ interface OAuthFlowExposed { projectId: string sessionKey: string refreshToken: string + sessionToken: string inputMethod: AuthInputMethod reset: () => void } @@ -2156,7 +2187,7 @@ const { t } = useI18n() const authStore = useAuthStore() const oauthStepTitle = computed(() => { - if (form.platform === 'openai') return t('admin.accounts.oauth.openai.title') + if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.oauth.openai.title') if (form.platform === 'gemini') return t('admin.accounts.oauth.gemini.title') if (form.platform === 'antigravity') return t('admin.accounts.oauth.antigravity.title') return t('admin.accounts.oauth.title') @@ -2164,13 +2195,13 @@ const oauthStepTitle = computed(() => { // Platform-specific hints for API Key type const baseUrlHint = computed(() => { - if (form.platform === 'openai') return t('admin.accounts.openai.baseUrlHint') + if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.openai.baseUrlHint') if (form.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint') return t('admin.accounts.baseUrlHint') }) const apiKeyHint = computed(() => { - if (form.platform === 'openai') return t('admin.accounts.openai.apiKeyHint') + if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.openai.apiKeyHint') if (form.platform === 'gemini') return t('admin.accounts.gemini.apiKeyHint') return t('admin.accounts.apiKeyHint') }) @@ -2191,34 +2222,36 @@ const appStore = useAppStore() // OAuth composables const oauth = useAccountOAuth() // For Anthropic OAuth -const openaiOAuth = useOpenAIOAuth() // For OpenAI OAuth +const openaiOAuth = useOpenAIOAuth({ platform: 'openai' }) // For OpenAI OAuth +const soraOAuth = useOpenAIOAuth({ platform: 'sora' }) // For Sora OAuth const geminiOAuth = useGeminiOAuth() // For Gemini OAuth const antigravityOAuth = useAntigravityOAuth() // For Antigravity OAuth +const activeOpenAIOAuth = computed(() => 
(form.platform === 'sora' ? soraOAuth : openaiOAuth)) // Computed: current OAuth state for template binding const currentAuthUrl = computed(() => { - if (form.platform === 'openai') return openaiOAuth.authUrl.value + if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.authUrl.value if (form.platform === 'gemini') return geminiOAuth.authUrl.value if (form.platform === 'antigravity') return antigravityOAuth.authUrl.value return oauth.authUrl.value }) const currentSessionId = computed(() => { - if (form.platform === 'openai') return openaiOAuth.sessionId.value + if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.sessionId.value if (form.platform === 'gemini') return geminiOAuth.sessionId.value if (form.platform === 'antigravity') return antigravityOAuth.sessionId.value return oauth.sessionId.value }) const currentOAuthLoading = computed(() => { - if (form.platform === 'openai') return openaiOAuth.loading.value + if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.loading.value if (form.platform === 'gemini') return geminiOAuth.loading.value if (form.platform === 'antigravity') return antigravityOAuth.loading.value return oauth.loading.value }) const currentOAuthError = computed(() => { - if (form.platform === 'openai') return openaiOAuth.error.value + if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.error.value if (form.platform === 'gemini') return geminiOAuth.error.value if (form.platform === 'antigravity') return antigravityOAuth.error.value return oauth.error.value @@ -2257,7 +2290,6 @@ const interceptWarmupRequests = ref(false) const autoPauseOnExpired = ref(true) const openaiPassthroughEnabled = ref(false) const codexCLIOnlyEnabled = ref(false) -const enableSoraOnOpenAIOAuth = ref(false) // OpenAI OAuth 时同时启用 Sora const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling 
const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream const upstreamBaseUrl = ref('') // For upstream type: base URL @@ -2398,8 +2430,8 @@ const expiresAtInput = computed({ const canExchangeCode = computed(() => { const authCode = oauthFlowRef.value?.authCode || '' - if (form.platform === 'openai') { - return authCode.trim() && openaiOAuth.sessionId.value && !openaiOAuth.loading.value + if (form.platform === 'openai' || form.platform === 'sora') { + return authCode.trim() && activeOpenAIOAuth.value.sessionId.value && !activeOpenAIOAuth.value.loading.value } if (form.platform === 'gemini') { return authCode.trim() && geminiOAuth.sessionId.value && !geminiOAuth.loading.value @@ -2459,7 +2491,7 @@ watch( (newPlatform) => { // Reset base URL based on platform apiKeyBaseUrl.value = - newPlatform === 'openai' + (newPlatform === 'openai' || newPlatform === 'sora') ? 'https://api.openai.com' : newPlatform === 'gemini' ? 'https://generativelanguage.googleapis.com' @@ -2485,6 +2517,11 @@ watch( if (newPlatform !== 'anthropic') { interceptWarmupRequests.value = false } + if (newPlatform === 'sora') { + accountCategory.value = 'oauth-based' + addMethod.value = 'oauth' + form.type = 'oauth' + } if (newPlatform !== 'openai') { openaiPassthroughEnabled.value = false codexCLIOnlyEnabled.value = false @@ -2492,6 +2529,7 @@ watch( // Reset OAuth states oauth.resetState() openaiOAuth.resetState() + soraOAuth.resetState() geminiOAuth.resetState() antigravityOAuth.resetState() } @@ -2753,7 +2791,6 @@ const resetForm = () => { autoPauseOnExpired.value = true openaiPassthroughEnabled.value = false codexCLIOnlyEnabled.value = false - enableSoraOnOpenAIOAuth.value = false // Reset quota control state windowCostEnabled.value = false windowCostLimit.value = null @@ -2776,6 +2813,7 @@ const resetForm = () => { geminiTierAIStudio.value = 'aistudio_free' oauth.resetState() openaiOAuth.resetState() + soraOAuth.resetState() 
geminiOAuth.resetState() antigravityOAuth.resetState() oauthFlowRef.value?.reset() @@ -2807,6 +2845,23 @@ const buildOpenAIExtra = (base?: Record): Record 0 ? extra : undefined } +const buildSoraExtra = ( + base?: Record, + linkedOpenAIAccountId?: string | number +): Record | undefined => { + const extra: Record = { ...(base || {}) } + if (linkedOpenAIAccountId !== undefined && linkedOpenAIAccountId !== null) { + const id = String(linkedOpenAIAccountId).trim() + if (id) { + extra.linked_openai_account_id = id + } + } + delete extra.openai_passthrough + delete extra.openai_oauth_passthrough + delete extra.codex_cli_only + return Object.keys(extra).length > 0 ? extra : undefined +} + // Helper function to create account with mixed channel warning handling const doCreateAccount = async (payload: any) => { submitting.value = true @@ -2922,7 +2977,7 @@ const handleSubmit = async () => { // Determine default base URL based on platform const defaultBaseUrl = - form.platform === 'openai' + (form.platform === 'openai' || form.platform === 'sora') ? 'https://api.openai.com' : form.platform === 'gemini' ? 
'https://generativelanguage.googleapis.com' @@ -2974,14 +3029,15 @@ const goBackToBasicInfo = () => { step.value = 1 oauth.resetState() openaiOAuth.resetState() + soraOAuth.resetState() geminiOAuth.resetState() antigravityOAuth.resetState() oauthFlowRef.value?.reset() } const handleGenerateUrl = async () => { - if (form.platform === 'openai') { - await openaiOAuth.generateAuthUrl(form.proxy_id) + if (form.platform === 'openai' || form.platform === 'sora') { + await activeOpenAIOAuth.value.generateAuthUrl(form.proxy_id) } else if (form.platform === 'gemini') { await geminiOAuth.generateAuthUrl( form.proxy_id, @@ -2997,13 +3053,19 @@ const handleGenerateUrl = async () => { } const handleValidateRefreshToken = (rt: string) => { - if (form.platform === 'openai') { + if (form.platform === 'openai' || form.platform === 'sora') { handleOpenAIValidateRT(rt) } else if (form.platform === 'antigravity') { handleAntigravityValidateRT(rt) } } +const handleValidateSessionToken = (sessionToken: string) => { + if (form.platform === 'sora') { + handleSoraValidateST(sessionToken) + } +} + const formatDateTimeLocal = formatDateTimeLocalInput const parseDateTimeLocal = parseDateTimeLocalInput @@ -3039,100 +3101,101 @@ const createAccountAndFinish = async ( // OpenAI OAuth 授权码兑换 const handleOpenAIExchange = async (authCode: string) => { - if (!authCode.trim() || !openaiOAuth.sessionId.value) return + const oauthClient = activeOpenAIOAuth.value + if (!authCode.trim() || !oauthClient.sessionId.value) return - openaiOAuth.loading.value = true - openaiOAuth.error.value = '' + oauthClient.loading.value = true + oauthClient.error.value = '' try { - const tokenInfo = await openaiOAuth.exchangeAuthCode( + const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim() + if (!stateToUse) { + oauthClient.error.value = t('admin.accounts.oauth.authFailed') + appStore.showError(oauthClient.error.value) + return + } + + const tokenInfo = await 
oauthClient.exchangeAuthCode( authCode.trim(), - openaiOAuth.sessionId.value, + oauthClient.sessionId.value, + stateToUse, form.proxy_id ) if (!tokenInfo) return - const credentials = openaiOAuth.buildCredentials(tokenInfo) - const oauthExtra = openaiOAuth.buildExtraInfo(tokenInfo) as Record | undefined + const credentials = oauthClient.buildCredentials(tokenInfo) + const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record | undefined const extra = buildOpenAIExtra(oauthExtra) + const shouldCreateOpenAI = form.platform === 'openai' + const shouldCreateSora = form.platform === 'sora' // 应用临时不可调度配置 if (!applyTempUnschedConfig(credentials)) { return } - // 1. 创建 OpenAI 账号 - const openaiAccount = await adminAPI.accounts.create({ - name: form.name, - notes: form.notes, - platform: 'openai', - type: 'oauth', - credentials, - extra, - proxy_id: form.proxy_id, - concurrency: form.concurrency, - priority: form.priority, - rate_multiplier: form.rate_multiplier, - group_ids: form.group_ids, - expires_at: form.expires_at, - auto_pause_on_expired: autoPauseOnExpired.value - }) + let openaiAccountId: string | number | undefined - appStore.showSuccess(t('admin.accounts.accountCreated')) + if (shouldCreateOpenAI) { + const openaiAccount = await adminAPI.accounts.create({ + name: form.name, + notes: form.notes, + platform: 'openai', + type: 'oauth', + credentials, + extra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + openaiAccountId = openaiAccount.id + appStore.showSuccess(t('admin.accounts.accountCreated')) + } - // 2. 
如果启用了 Sora,同时创建 Sora 账号 - if (enableSoraOnOpenAIOAuth.value) { - try { - // Sora 使用相同的 OAuth credentials - const soraCredentials = { - access_token: credentials.access_token, - refresh_token: credentials.refresh_token, - expires_at: credentials.expires_at - } - - // 建立关联关系 - const soraExtra: Record = { - ...(extra || {}), - linked_openai_account_id: String(openaiAccount.id) - } - delete soraExtra.openai_passthrough - delete soraExtra.openai_oauth_passthrough - - await adminAPI.accounts.create({ - name: `${form.name} (Sora)`, - notes: form.notes, - platform: 'sora', - type: 'oauth', - credentials: soraCredentials, - extra: soraExtra, - proxy_id: form.proxy_id, - concurrency: form.concurrency, - priority: form.priority, - rate_multiplier: form.rate_multiplier, - group_ids: form.group_ids, - expires_at: form.expires_at, - auto_pause_on_expired: autoPauseOnExpired.value - }) - - appStore.showSuccess(t('admin.accounts.soraAccountCreated')) - } catch (error: any) { - console.error('创建 Sora 账号失败:', error) - appStore.showWarning(t('admin.accounts.soraAccountFailed')) + if (shouldCreateSora) { + const soraCredentials = { + access_token: credentials.access_token, + refresh_token: credentials.refresh_token, + expires_at: credentials.expires_at } + + const soraName = shouldCreateOpenAI ? `${form.name} (Sora)` : form.name + const soraExtra = buildSoraExtra(shouldCreateOpenAI ? 
extra : oauthExtra, openaiAccountId) + await adminAPI.accounts.create({ + name: soraName, + notes: form.notes, + platform: 'sora', + type: 'oauth', + credentials: soraCredentials, + extra: soraExtra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + appStore.showSuccess(t('admin.accounts.accountCreated')) } emit('created') handleClose() } catch (error: any) { - openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') - appStore.showError(openaiOAuth.error.value) + oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') + appStore.showError(oauthClient.error.value) } finally { - openaiOAuth.loading.value = false + oauthClient.loading.value = false } } // OpenAI 手动 RT 批量验证和创建 const handleOpenAIValidateRT = async (refreshTokenInput: string) => { + const oauthClient = activeOpenAIOAuth.value if (!refreshTokenInput.trim()) return // Parse multiple refresh tokens (one per line) @@ -3142,53 +3205,86 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { .filter((rt) => rt) if (refreshTokens.length === 0) { - openaiOAuth.error.value = t('admin.accounts.oauth.openai.pleaseEnterRefreshToken') + oauthClient.error.value = t('admin.accounts.oauth.openai.pleaseEnterRefreshToken') return } - openaiOAuth.loading.value = true - openaiOAuth.error.value = '' + oauthClient.loading.value = true + oauthClient.error.value = '' let successCount = 0 let failedCount = 0 const errors: string[] = [] + const shouldCreateOpenAI = form.platform === 'openai' + const shouldCreateSora = form.platform === 'sora' try { for (let i = 0; i < refreshTokens.length; i++) { try { - const tokenInfo = await openaiOAuth.validateRefreshToken( + const tokenInfo = await oauthClient.validateRefreshToken( refreshTokens[i], 
form.proxy_id ) if (!tokenInfo) { failedCount++ - errors.push(`#${i + 1}: ${openaiOAuth.error.value || 'Validation failed'}`) - openaiOAuth.error.value = '' + errors.push(`#${i + 1}: ${oauthClient.error.value || 'Validation failed'}`) + oauthClient.error.value = '' continue } - const credentials = openaiOAuth.buildCredentials(tokenInfo) - const oauthExtra = openaiOAuth.buildExtraInfo(tokenInfo) as Record | undefined + const credentials = oauthClient.buildCredentials(tokenInfo) + const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record | undefined const extra = buildOpenAIExtra(oauthExtra) // Generate account name with index for batch const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name - await adminAPI.accounts.create({ - name: accountName, - notes: form.notes, - platform: 'openai', - type: 'oauth', - credentials, - extra, - proxy_id: form.proxy_id, - concurrency: form.concurrency, - priority: form.priority, - rate_multiplier: form.rate_multiplier, - group_ids: form.group_ids, - expires_at: form.expires_at, - auto_pause_on_expired: autoPauseOnExpired.value - }) + let openaiAccountId: string | number | undefined + + if (shouldCreateOpenAI) { + const openaiAccount = await adminAPI.accounts.create({ + name: accountName, + notes: form.notes, + platform: 'openai', + type: 'oauth', + credentials, + extra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + openaiAccountId = openaiAccount.id + } + + if (shouldCreateSora) { + const soraCredentials = { + access_token: credentials.access_token, + refresh_token: credentials.refresh_token, + expires_at: credentials.expires_at + } + const soraName = shouldCreateOpenAI ? `${accountName} (Sora)` : accountName + const soraExtra = buildSoraExtra(shouldCreateOpenAI ? 
extra : oauthExtra, openaiAccountId) + await adminAPI.accounts.create({ + name: soraName, + notes: form.notes, + platform: 'sora', + type: 'oauth', + credentials: soraCredentials, + extra: soraExtra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + } + successCount++ } catch (error: any) { failedCount++ @@ -3210,14 +3306,99 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { appStore.showWarning( t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount }) ) - openaiOAuth.error.value = errors.join('\n') + oauthClient.error.value = errors.join('\n') emit('created') } else { - openaiOAuth.error.value = errors.join('\n') + oauthClient.error.value = errors.join('\n') appStore.showError(t('admin.accounts.oauth.batchFailed')) } } finally { - openaiOAuth.loading.value = false + oauthClient.loading.value = false + } +} + +// Sora 手动 ST 批量验证和创建 +const handleSoraValidateST = async (sessionTokenInput: string) => { + const oauthClient = activeOpenAIOAuth.value + if (!sessionTokenInput.trim()) return + + const sessionTokens = sessionTokenInput + .split('\n') + .map((st) => st.trim()) + .filter((st) => st) + + if (sessionTokens.length === 0) { + oauthClient.error.value = t('admin.accounts.oauth.openai.pleaseEnterSessionToken') + return + } + + oauthClient.loading.value = true + oauthClient.error.value = '' + + let successCount = 0 + let failedCount = 0 + const errors: string[] = [] + + try { + for (let i = 0; i < sessionTokens.length; i++) { + try { + const tokenInfo = await oauthClient.validateSessionToken(sessionTokens[i], form.proxy_id) + if (!tokenInfo) { + failedCount++ + errors.push(`#${i + 1}: ${oauthClient.error.value || 'Validation failed'}`) + oauthClient.error.value = '' + continue + } + + const credentials = 
oauthClient.buildCredentials(tokenInfo) + credentials.session_token = sessionTokens[i] + const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record | undefined + const soraExtra = buildSoraExtra(oauthExtra) + + const accountName = sessionTokens.length > 1 ? `${form.name} #${i + 1}` : form.name + await adminAPI.accounts.create({ + name: accountName, + notes: form.notes, + platform: 'sora', + type: 'oauth', + credentials, + extra: soraExtra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + successCount++ + } catch (error: any) { + failedCount++ + const errMsg = error.response?.data?.detail || error.message || 'Unknown error' + errors.push(`#${i + 1}: ${errMsg}`) + } + } + + if (successCount > 0 && failedCount === 0) { + appStore.showSuccess( + sessionTokens.length > 1 + ? t('admin.accounts.oauth.batchSuccess', { count: successCount }) + : t('admin.accounts.accountCreated') + ) + emit('created') + handleClose() + } else if (successCount > 0 && failedCount > 0) { + appStore.showWarning( + t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount }) + ) + oauthClient.error.value = errors.join('\n') + emit('created') + } else { + oauthClient.error.value = errors.join('\n') + appStore.showError(t('admin.accounts.oauth.batchFailed')) + } + } finally { + oauthClient.loading.value = false } } @@ -3462,6 +3643,7 @@ const handleExchangeCode = async () => { switch (form.platform) { case 'openai': + case 'sora': return handleOpenAIExchange(authCode) case 'gemini': return handleGeminiExchange(authCode) diff --git a/frontend/src/components/account/OAuthAuthorizationFlow.vue b/frontend/src/components/account/OAuthAuthorizationFlow.vue index 9c4b7e4b..8e00d25b 100644 --- a/frontend/src/components/account/OAuthAuthorizationFlow.vue +++ 
b/frontend/src/components/account/OAuthAuthorizationFlow.vue @@ -48,6 +48,17 @@ t(getOAuthKey('refreshTokenAuth')) }} +
@@ -135,6 +146,87 @@
+ +
+
+

+ {{ t(getOAuthKey('sessionTokenDesc')) }} +

+ +
+ + +

+ {{ t('admin.accounts.oauth.batchCreateAccounts', { count: parsedSessionTokenCount }) }} +

+
+ +
+

+ {{ error }} +

+
+ + +
+
+
(), { authUrl: '', @@ -540,6 +633,7 @@ const props = withDefaults(defineProps(), { methodLabel: 'Authorization Method', showCookieOption: true, showRefreshTokenOption: false, + showSessionTokenOption: false, platform: 'anthropic', showProjectId: true }) @@ -549,6 +643,7 @@ const emit = defineEmits<{ 'exchange-code': [code: string] 'cookie-auth': [sessionKey: string] 'validate-refresh-token': [refreshToken: string] + 'validate-session-token': [sessionToken: string] 'update:inputMethod': [method: AuthInputMethod] }>() @@ -587,12 +682,13 @@ const inputMethod = ref(props.showCookieOption ? 'manual' : 'ma const authCodeInput = ref('') const sessionKeyInput = ref('') const refreshTokenInput = ref('') +const sessionTokenInput = ref('') const showHelpDialog = ref(false) const oauthState = ref('') const projectId = ref('') // Computed: show method selection when either cookie or refresh token option is enabled -const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption) +const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showSessionTokenOption) // Clipboard const { copied, copyToClipboard } = useClipboard() @@ -613,6 +709,13 @@ const parsedRefreshTokenCount = computed(() => { .filter((rt) => rt).length }) +const parsedSessionTokenCount = computed(() => { + return sessionTokenInput.value + .split('\n') + .map((st) => st.trim()) + .filter((st) => st).length +}) + // Watchers watch(inputMethod, (newVal) => { emit('update:inputMethod', newVal) @@ -631,7 +734,7 @@ watch(authCodeInput, (newVal) => { const url = new URL(trimmed) const code = url.searchParams.get('code') const stateParam = url.searchParams.get('state') - if ((props.platform === 'gemini' || props.platform === 'antigravity') && stateParam) { + if ((props.platform === 'openai' || props.platform === 'sora' || props.platform === 'gemini' || props.platform === 'antigravity') && stateParam) { oauthState.value = stateParam } if 
(code && code !== trimmed) { @@ -642,7 +745,7 @@ watch(authCodeInput, (newVal) => { // If URL parsing fails, try regex extraction const match = trimmed.match(/[?&]code=([^&]+)/) const stateMatch = trimmed.match(/[?&]state=([^&]+)/) - if ((props.platform === 'gemini' || props.platform === 'antigravity') && stateMatch && stateMatch[1]) { + if ((props.platform === 'openai' || props.platform === 'sora' || props.platform === 'gemini' || props.platform === 'antigravity') && stateMatch && stateMatch[1]) { oauthState.value = stateMatch[1] } if (match && match[1] && match[1] !== trimmed) { @@ -680,6 +783,12 @@ const handleValidateRefreshToken = () => { } } +const handleValidateSessionToken = () => { + if (sessionTokenInput.value.trim()) { + emit('validate-session-token', sessionTokenInput.value.trim()) + } +} + // Expose methods and state defineExpose({ authCode: authCodeInput, @@ -687,6 +796,7 @@ defineExpose({ projectId, sessionKey: sessionKeyInput, refreshToken: refreshTokenInput, + sessionToken: sessionTokenInput, inputMethod, reset: () => { authCodeInput.value = '' @@ -694,6 +804,7 @@ defineExpose({ projectId.value = '' sessionKeyInput.value = '' refreshTokenInput.value = '' + sessionTokenInput.value = '' inputMethod.value = 'manual' showHelpDialog.value = false } diff --git a/frontend/src/components/account/ReAuthAccountModal.vue b/frontend/src/components/account/ReAuthAccountModal.vue index b2734b4f..aab0fe7d 100644 --- a/frontend/src/components/account/ReAuthAccountModal.vue +++ b/frontend/src/components/account/ReAuthAccountModal.vue @@ -14,7 +14,7 @@
('code_as // Computed - check platform const isOpenAI = computed(() => props.account?.platform === 'openai') +const isSora = computed(() => props.account?.platform === 'sora') +const isOpenAILike = computed(() => isOpenAI.value || isSora.value) const isGemini = computed(() => props.account?.platform === 'gemini') const isAnthropic = computed(() => props.account?.platform === 'anthropic') const isAntigravity = computed(() => props.account?.platform === 'antigravity') +const activeOpenAIOAuth = computed(() => (isSora.value ? soraOAuth : openaiOAuth)) // Computed - current OAuth state based on platform const currentAuthUrl = computed(() => { - if (isOpenAI.value) return openaiOAuth.authUrl.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.authUrl.value if (isGemini.value) return geminiOAuth.authUrl.value if (isAntigravity.value) return antigravityOAuth.authUrl.value return claudeOAuth.authUrl.value }) const currentSessionId = computed(() => { - if (isOpenAI.value) return openaiOAuth.sessionId.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.sessionId.value if (isGemini.value) return geminiOAuth.sessionId.value if (isAntigravity.value) return antigravityOAuth.sessionId.value return claudeOAuth.sessionId.value }) const currentLoading = computed(() => { - if (isOpenAI.value) return openaiOAuth.loading.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.loading.value if (isGemini.value) return geminiOAuth.loading.value if (isAntigravity.value) return antigravityOAuth.loading.value return claudeOAuth.loading.value }) const currentError = computed(() => { - if (isOpenAI.value) return openaiOAuth.error.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.error.value if (isGemini.value) return geminiOAuth.error.value if (isAntigravity.value) return antigravityOAuth.error.value return claudeOAuth.error.value @@ -269,8 +275,8 @@ const currentError = computed(() => { // Computed const isManualInputMethod = computed(() => { - // 
OpenAI/Gemini/Antigravity always use manual input (no cookie auth option) - return isOpenAI.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual' + // OpenAI/Sora/Gemini/Antigravity always use manual input (no cookie auth option) + return isOpenAILike.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual' }) const canExchangeCode = computed(() => { @@ -313,6 +319,7 @@ const resetState = () => { geminiOAuthType.value = 'code_assist' claudeOAuth.resetState() openaiOAuth.resetState() + soraOAuth.resetState() geminiOAuth.resetState() antigravityOAuth.resetState() oauthFlowRef.value?.reset() @@ -325,8 +332,8 @@ const handleClose = () => { const handleGenerateUrl = async () => { if (!props.account) return - if (isOpenAI.value) { - await openaiOAuth.generateAuthUrl(props.account.proxy_id) + if (isOpenAILike.value) { + await activeOpenAIOAuth.value.generateAuthUrl(props.account.proxy_id) } else if (isGemini.value) { const creds = (props.account.credentials || {}) as Record const tierId = typeof creds.tier_id === 'string' ? 
creds.tier_id : undefined @@ -345,21 +352,29 @@ const handleExchangeCode = async () => { const authCode = oauthFlowRef.value?.authCode || '' if (!authCode.trim()) return - if (isOpenAI.value) { + if (isOpenAILike.value) { // OpenAI OAuth flow - const sessionId = openaiOAuth.sessionId.value + const oauthClient = activeOpenAIOAuth.value + const sessionId = oauthClient.sessionId.value if (!sessionId) return + const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim() + if (!stateToUse) { + oauthClient.error.value = t('admin.accounts.oauth.authFailed') + appStore.showError(oauthClient.error.value) + return + } - const tokenInfo = await openaiOAuth.exchangeAuthCode( + const tokenInfo = await oauthClient.exchangeAuthCode( authCode.trim(), sessionId, + stateToUse, props.account.proxy_id ) if (!tokenInfo) return // Build credentials and extra info - const credentials = openaiOAuth.buildCredentials(tokenInfo) - const extra = openaiOAuth.buildExtraInfo(tokenInfo) + const credentials = oauthClient.buildCredentials(tokenInfo) + const extra = oauthClient.buildExtraInfo(tokenInfo) try { // Update account with new credentials @@ -376,8 +391,8 @@ const handleExchangeCode = async () => { emit('reauthorized') handleClose() } catch (error: any) { - openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') - appStore.showError(openaiOAuth.error.value) + oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') + appStore.showError(oauthClient.error.value) } } else if (isGemini.value) { const sessionId = geminiOAuth.sessionId.value @@ -490,7 +505,7 @@ const handleExchangeCode = async () => { } const handleCookieAuth = async (sessionKey: string) => { - if (!props.account || isOpenAI.value) return + if (!props.account || isOpenAILike.value) return claudeOAuth.loading.value = true claudeOAuth.error.value = '' diff --git 
a/frontend/src/components/admin/account/AccountTestModal.vue b/frontend/src/components/admin/account/AccountTestModal.vue index feb09654..38196781 100644 --- a/frontend/src/components/admin/account/AccountTestModal.vue +++ b/frontend/src/components/admin/account/AccountTestModal.vue @@ -238,6 +238,11 @@ const loadAvailableModels = async () => { availableModels.value.find((m) => m.id === 'gemini-3-flash-preview') || availableModels.value.find((m) => m.id === 'gemini-3-pro-preview') selectedModelId.value = preferred?.id || availableModels.value[0].id + } else if (props.account.platform === 'sora') { + const preferred = + availableModels.value.find((m) => m.id === 'gpt-image') || + availableModels.value.find((m) => !m.id.startsWith('prompt-enhance')) + selectedModelId.value = preferred?.id || availableModels.value[0].id } else { // Try to select Sonnet as default, otherwise use first model const sonnetModel = availableModels.value.find((m) => m.id.includes('sonnet')) diff --git a/frontend/src/components/admin/account/ReAuthAccountModal.vue b/frontend/src/components/admin/account/ReAuthAccountModal.vue index eeb3f288..c269eea4 100644 --- a/frontend/src/components/admin/account/ReAuthAccountModal.vue +++ b/frontend/src/components/admin/account/ReAuthAccountModal.vue @@ -14,7 +14,7 @@
('code_as // Computed - check platform const isOpenAI = computed(() => props.account?.platform === 'openai') +const isSora = computed(() => props.account?.platform === 'sora') +const isOpenAILike = computed(() => isOpenAI.value || isSora.value) const isGemini = computed(() => props.account?.platform === 'gemini') const isAnthropic = computed(() => props.account?.platform === 'anthropic') const isAntigravity = computed(() => props.account?.platform === 'antigravity') +const activeOpenAIOAuth = computed(() => (isSora.value ? soraOAuth : openaiOAuth)) // Computed - current OAuth state based on platform const currentAuthUrl = computed(() => { - if (isOpenAI.value) return openaiOAuth.authUrl.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.authUrl.value if (isGemini.value) return geminiOAuth.authUrl.value if (isAntigravity.value) return antigravityOAuth.authUrl.value return claudeOAuth.authUrl.value }) const currentSessionId = computed(() => { - if (isOpenAI.value) return openaiOAuth.sessionId.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.sessionId.value if (isGemini.value) return geminiOAuth.sessionId.value if (isAntigravity.value) return antigravityOAuth.sessionId.value return claudeOAuth.sessionId.value }) const currentLoading = computed(() => { - if (isOpenAI.value) return openaiOAuth.loading.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.loading.value if (isGemini.value) return geminiOAuth.loading.value if (isAntigravity.value) return antigravityOAuth.loading.value return claudeOAuth.loading.value }) const currentError = computed(() => { - if (isOpenAI.value) return openaiOAuth.error.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.error.value if (isGemini.value) return geminiOAuth.error.value if (isAntigravity.value) return antigravityOAuth.error.value return claudeOAuth.error.value @@ -269,8 +275,8 @@ const currentError = computed(() => { // Computed const isManualInputMethod = computed(() => { - // 
OpenAI/Gemini/Antigravity always use manual input (no cookie auth option) - return isOpenAI.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual' + // OpenAI/Sora/Gemini/Antigravity always use manual input (no cookie auth option) + return isOpenAILike.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual' }) const canExchangeCode = computed(() => { @@ -313,6 +319,7 @@ const resetState = () => { geminiOAuthType.value = 'code_assist' claudeOAuth.resetState() openaiOAuth.resetState() + soraOAuth.resetState() geminiOAuth.resetState() antigravityOAuth.resetState() oauthFlowRef.value?.reset() @@ -325,8 +332,8 @@ const handleClose = () => { const handleGenerateUrl = async () => { if (!props.account) return - if (isOpenAI.value) { - await openaiOAuth.generateAuthUrl(props.account.proxy_id) + if (isOpenAILike.value) { + await activeOpenAIOAuth.value.generateAuthUrl(props.account.proxy_id) } else if (isGemini.value) { const creds = (props.account.credentials || {}) as Record const tierId = typeof creds.tier_id === 'string' ? 
creds.tier_id : undefined @@ -345,21 +352,29 @@ const handleExchangeCode = async () => { const authCode = oauthFlowRef.value?.authCode || '' if (!authCode.trim()) return - if (isOpenAI.value) { + if (isOpenAILike.value) { // OpenAI OAuth flow - const sessionId = openaiOAuth.sessionId.value + const oauthClient = activeOpenAIOAuth.value + const sessionId = oauthClient.sessionId.value if (!sessionId) return + const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim() + if (!stateToUse) { + oauthClient.error.value = t('admin.accounts.oauth.authFailed') + appStore.showError(oauthClient.error.value) + return + } - const tokenInfo = await openaiOAuth.exchangeAuthCode( + const tokenInfo = await oauthClient.exchangeAuthCode( authCode.trim(), sessionId, + stateToUse, props.account.proxy_id ) if (!tokenInfo) return // Build credentials and extra info - const credentials = openaiOAuth.buildCredentials(tokenInfo) - const extra = openaiOAuth.buildExtraInfo(tokenInfo) + const credentials = oauthClient.buildCredentials(tokenInfo) + const extra = oauthClient.buildExtraInfo(tokenInfo) try { // Update account with new credentials @@ -376,8 +391,8 @@ const handleExchangeCode = async () => { emit('reauthorized', updatedAccount) handleClose() } catch (error: any) { - openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') - appStore.showError(openaiOAuth.error.value) + oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') + appStore.showError(oauthClient.error.value) } } else if (isGemini.value) { const sessionId = geminiOAuth.sessionId.value @@ -490,7 +505,7 @@ const handleExchangeCode = async () => { } const handleCookieAuth = async (sessionKey: string) => { - if (!props.account || isOpenAI.value) return + if (!props.account || isOpenAILike.value) return claudeOAuth.loading.value = true claudeOAuth.error.value = '' diff --git 
a/frontend/src/composables/useAccountOAuth.ts b/frontend/src/composables/useAccountOAuth.ts index ca200cb3..6f53404c 100644 --- a/frontend/src/composables/useAccountOAuth.ts +++ b/frontend/src/composables/useAccountOAuth.ts @@ -3,7 +3,7 @@ import { useAppStore } from '@/stores/app' import { adminAPI } from '@/api/admin' export type AddMethod = 'oauth' | 'setup-token' -export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' +export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'session_token' export interface OAuthState { authUrl: string diff --git a/frontend/src/composables/useOpenAIOAuth.ts b/frontend/src/composables/useOpenAIOAuth.ts index 82a77031..32045cbe 100644 --- a/frontend/src/composables/useOpenAIOAuth.ts +++ b/frontend/src/composables/useOpenAIOAuth.ts @@ -19,12 +19,21 @@ export interface OpenAITokenInfo { [key: string]: unknown } -export function useOpenAIOAuth() { +export type OpenAIOAuthPlatform = 'openai' | 'sora' + +interface UseOpenAIOAuthOptions { + platform?: OpenAIOAuthPlatform +} + +export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) { const appStore = useAppStore() + const oauthPlatform = options?.platform ?? 'openai' + const endpointPrefix = oauthPlatform === 'sora' ? 
'/admin/sora' : '/admin/openai' // State const authUrl = ref('') const sessionId = ref('') + const oauthState = ref('') const loading = ref(false) const error = ref('') @@ -32,6 +41,7 @@ export function useOpenAIOAuth() { const resetState = () => { authUrl.value = '' sessionId.value = '' + oauthState.value = '' loading.value = false error.value = '' } @@ -44,6 +54,7 @@ export function useOpenAIOAuth() { loading.value = true authUrl.value = '' sessionId.value = '' + oauthState.value = '' error.value = '' try { @@ -56,11 +67,17 @@ export function useOpenAIOAuth() { } const response = await adminAPI.accounts.generateAuthUrl( - '/admin/openai/generate-auth-url', + `${endpointPrefix}/generate-auth-url`, payload ) authUrl.value = response.auth_url sessionId.value = response.session_id + try { + const parsed = new URL(response.auth_url) + oauthState.value = parsed.searchParams.get('state') || '' + } catch { + oauthState.value = '' + } return true } catch (err: any) { error.value = err.response?.data?.detail || 'Failed to generate OpenAI auth URL' @@ -75,10 +92,11 @@ export function useOpenAIOAuth() { const exchangeAuthCode = async ( code: string, currentSessionId: string, + state: string, proxyId?: number | null ): Promise => { - if (!code.trim() || !currentSessionId) { - error.value = 'Missing auth code or session ID' + if (!code.trim() || !currentSessionId || !state.trim()) { + error.value = 'Missing auth code, session ID, or state' return null } @@ -86,15 +104,16 @@ export function useOpenAIOAuth() { error.value = '' try { - const payload: { session_id: string; code: string; proxy_id?: number } = { + const payload: { session_id: string; code: string; state: string; proxy_id?: number } = { session_id: currentSessionId, - code: code.trim() + code: code.trim(), + state: state.trim() } if (proxyId) { payload.proxy_id = proxyId } - const tokenInfo = await adminAPI.accounts.exchangeCode('/admin/openai/exchange-code', payload) + const tokenInfo = await 
adminAPI.accounts.exchangeCode(`${endpointPrefix}/exchange-code`, payload) return tokenInfo as OpenAITokenInfo } catch (err: any) { error.value = err.response?.data?.detail || 'Failed to exchange OpenAI auth code' @@ -120,7 +139,11 @@ export function useOpenAIOAuth() { try { // Use dedicated refresh-token endpoint - const tokenInfo = await adminAPI.accounts.refreshOpenAIToken(refreshToken.trim(), proxyId) + const tokenInfo = await adminAPI.accounts.refreshOpenAIToken( + refreshToken.trim(), + proxyId, + `${endpointPrefix}/refresh-token` + ) return tokenInfo as OpenAITokenInfo } catch (err: any) { error.value = err.response?.data?.detail || 'Failed to validate refresh token' @@ -131,6 +154,33 @@ export function useOpenAIOAuth() { } } + // Validate Sora session token and get access token + const validateSessionToken = async ( + sessionToken: string, + proxyId?: number | null + ): Promise => { + if (!sessionToken.trim()) { + error.value = 'Missing session token' + return null + } + loading.value = true + error.value = '' + try { + const tokenInfo = await adminAPI.accounts.validateSoraSessionToken( + sessionToken.trim(), + proxyId, + `${endpointPrefix}/st2at` + ) + return tokenInfo as OpenAITokenInfo + } catch (err: any) { + error.value = err.response?.data?.detail || 'Failed to validate session token' + appStore.showError(error.value) + return null + } finally { + loading.value = false + } + } + // Build credentials for OpenAI OAuth account const buildCredentials = (tokenInfo: OpenAITokenInfo): Record => { const creds: Record = { @@ -172,6 +222,7 @@ export function useOpenAIOAuth() { // State authUrl, sessionId, + oauthState, loading, error, // Methods @@ -179,6 +230,7 @@ export function useOpenAIOAuth() { generateAuthUrl, exchangeAuthCode, validateRefreshToken, + validateSessionToken, buildCredentials, buildExtraInfo } diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 293af1da..0dd87f8a 100644 --- a/frontend/src/i18n/locales/en.ts 
+++ b/frontend/src/i18n/locales/en.ts @@ -1740,9 +1740,13 @@ export default { refreshTokenAuth: 'Manual RT Input', refreshTokenDesc: 'Enter your existing OpenAI Refresh Token(s). Supports batch input (one per line). The system will automatically validate and create accounts.', refreshTokenPlaceholder: 'Paste your OpenAI Refresh Token...\nSupports multiple, one per line', + sessionTokenAuth: 'Manual ST Input', + sessionTokenDesc: 'Enter your existing Sora Session Token(s). Supports batch input (one per line). The system will automatically validate and create accounts.', + sessionTokenPlaceholder: 'Paste your Sora Session Token...\nSupports multiple, one per line', validating: 'Validating...', validateAndCreate: 'Validate & Create Account', - pleaseEnterRefreshToken: 'Please enter Refresh Token' + pleaseEnterRefreshToken: 'Please enter Refresh Token', + pleaseEnterSessionToken: 'Please enter Session Token' }, // Gemini specific gemini: { @@ -1963,6 +1967,7 @@ export default { reAuthorizeAccount: 'Re-Authorize Account', claudeCodeAccount: 'Claude Code Account', openaiAccount: 'OpenAI Account', + soraAccount: 'Sora Account', geminiAccount: 'Gemini Account', antigravityAccount: 'Antigravity Account', inputMethod: 'Input Method', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 08f1aeef..f28045e3 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1879,9 +1879,13 @@ export default { refreshTokenAuth: '手动输入 RT', refreshTokenDesc: '输入您已有的 OpenAI Refresh Token,支持批量输入(每行一个),系统将自动验证并创建账号。', refreshTokenPlaceholder: '粘贴您的 OpenAI Refresh Token...\n支持多个,每行一个', + sessionTokenAuth: '手动输入 ST', + sessionTokenDesc: '输入您已有的 Sora Session Token,支持批量输入(每行一个),系统将自动验证并创建账号。', + sessionTokenPlaceholder: '粘贴您的 Sora Session Token...\n支持多个,每行一个', validating: '验证中...', validateAndCreate: '验证并创建账号', - pleaseEnterRefreshToken: '请输入 Refresh Token' + pleaseEnterRefreshToken: '请输入 Refresh Token', + pleaseEnterSessionToken: '请输入 
Session Token' }, // Gemini specific gemini: { @@ -2097,6 +2101,7 @@ export default { reAuthorizeAccount: '重新授权账号', claudeCodeAccount: 'Claude Code 账号', openaiAccount: 'OpenAI 账号', + soraAccount: 'Sora 账号', geminiAccount: 'Gemini 账号', antigravityAccount: 'Antigravity 账号', inputMethod: '输入方式', From 5d2219d299b98e113ac5ca4e8ca5b90e0593c2c2 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Thu, 19 Feb 2026 08:23:00 +0800 Subject: [PATCH 191/363] =?UTF-8?q?fix(sora):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E4=BB=A4=E7=89=8C=E5=88=B7=E6=96=B0=E8=AF=B7=E6=B1=82=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F=E4=B8=8E=E6=B5=81=E5=BC=8F=E9=94=99=E8=AF=AF=E8=BD=AC?= =?UTF-8?q?=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 refresh_token 恢复请求改为表单编码并匹配 OAuth 约定 - 流式错误改为 JSON 序列化,避免消息含引号或换行导致 SSE 非法 - 补充 Sora token 恢复与 failover 流式错误透传回归测试 Co-Authored-By: Claude Opus 4.6 --- .../internal/handler/sora_gateway_handler.go | 14 +++- .../handler/sora_gateway_handler_test.go | 81 +++++++++++++++++++ backend/internal/service/sora_client.go | 19 ++--- backend/internal/service/sora_client_test.go | 6 ++ 4 files changed, 107 insertions(+), 13 deletions(-) diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 9c9f53b1..3a5ddcb0 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "encoding/hex" + "encoding/json" "errors" "fmt" "io" @@ -442,7 +443,18 @@ func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status in if streamStarted { flusher, ok := c.Writer.(http.Flusher) if ok { - errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) + errorData := map[string]any{ + "error": map[string]string{ + "type": errType, + "message": message, + }, + } + jsonBytes, err := json.Marshal(errorData) + if err 
!= nil { + _ = c.Error(err) + return + } + errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes)) if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { _ = c.Error(err) } diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 39e2eed6..d80b959c 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -498,3 +498,84 @@ func TestGenerateOpenAISessionHash_WithBody(t *testing.T) { require.NotEmpty(t, hash3) require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash } + +func TestSoraHandleStreamingAwareError_JSONEscaping(t *testing.T) { + tests := []struct { + name string + errType string + message string + }{ + { + name: "包含双引号", + errType: "upstream_error", + message: `upstream returned "invalid" payload`, + }, + { + name: "包含换行和制表符", + errType: "rate_limit_error", + message: "line1\nline2\ttab", + }, + { + name: "包含反斜杠", + errType: "upstream_error", + message: `path C:\Users\test\file.txt not found`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &SoraGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true) + + body := w.Body.String() + require.True(t, strings.HasPrefix(body, "event: error\n"), "应以 SSE error 事件开头") + require.True(t, strings.HasSuffix(body, "\n\n"), "应以 SSE 结束分隔符结尾") + + lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n") + require.Len(t, lines, 2, "SSE 错误事件应包含 event 行和 data 行") + require.Equal(t, "event: error", lines[0]) + require.True(t, strings.HasPrefix(lines[1], "data: "), "第二行应为 data 前缀") + + jsonStr := strings.TrimPrefix(lines[1], "data: ") + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data 
行必须是合法 JSON") + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok, "JSON 中应包含 error 对象") + require.Equal(t, tt.errType, errorObj["type"]) + require.Equal(t, tt.message, errorObj["message"]) + }) + } +} + +func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &SoraGatewayHandler{} + resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`) + h.handleFailoverExhausted(c, http.StatusBadGateway, resp, true) + + body := w.Body.String() + require.True(t, strings.HasPrefix(body, "event: error\n")) + require.True(t, strings.HasSuffix(body, "\n\n")) + + lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n") + require.Len(t, lines, 2) + jsonStr := strings.TrimPrefix(lines[1], "data: ") + + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed)) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "upstream_error", errorObj["type"]) + require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"]) +} diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go index 38be7a04..38c1b3cc 100644 --- a/backend/internal/service/sora_client.go +++ b/backend/internal/service/sora_client.go @@ -779,22 +779,17 @@ func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Ac } tried[clientID] = struct{}{} - payload := map[string]any{ - "client_id": clientID, - "grant_type": "refresh_token", - "refresh_token": refreshToken, - "redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", - } - bodyBytes, err := json.Marshal(payload) - if err != nil { - return "", "", "", err - } + formData := url.Values{} + formData.Set("client_id", clientID) + formData.Set("grant_type", "refresh_token") + 
formData.Set("refresh_token", refreshToken) + formData.Set("redirect_uri", "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback") headers := http.Header{} headers.Set("Accept", "application/json") - headers.Set("Content-Type", "application/json") + headers.Set("Content-Type", "application/x-www-form-urlencoded") headers.Set("User-Agent", c.defaultUserAgent()) - respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, bytes.NewReader(bodyBytes), false) + respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, strings.NewReader(formData.Encode()), false) if err != nil { lastErr = err if c.debugEnabled() { diff --git a/backend/internal/service/sora_client_test.go b/backend/internal/service/sora_client_test.go index 3e88c9f9..e566f06b 100644 --- a/backend/internal/service/sora_client_test.go +++ b/backend/internal/service/sora_client_test.go @@ -281,6 +281,12 @@ func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, http.MethodPost, r.Method) require.Equal(t, "/oauth/token", r.URL.Path) + require.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + require.NoError(t, r.ParseForm()) + require.Equal(t, "refresh_token", r.FormValue("grant_type")) + require.Equal(t, "refresh-token-old", r.FormValue("refresh_token")) + require.NotEmpty(t, r.FormValue("client_id")) + require.Equal(t, "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", r.FormValue("redirect_uri")) w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]any{ "access_token": "refresh-access-token", From be09188bdaff441cd8cfc680f573cf22448abe31 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Thu, 19 Feb 2026 08:29:51 +0800 Subject: [PATCH 192/363] =?UTF-8?q?feat(account-test):=20=E5=A2=9E?= 
=?UTF-8?q?=E5=BC=BA=20Sora=20=E8=B4=A6=E5=8F=B7=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E8=83=BD=E5=8A=9B=E6=8E=A2=E6=B5=8B=E4=B8=8E=E5=BC=B9=E7=AA=97?= =?UTF-8?q?=E4=BA=A4=E4=BA=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 后端新增 Sora2 邀请码与剩余额度探测,并补充对应结果解析 - Sora 测试流程补齐请求头与 Cloudflare 场景提示,完善单测覆盖 - 前端测试弹窗对 Sora 账号改为免选模型流程,并新增中英文提示文案 Co-Authored-By: Claude Opus 4.6 --- .../internal/service/account_test_service.go | 175 ++++++++++++++++++ .../service/account_test_service_sora_test.go | 17 +- .../components/account/AccountTestModal.vue | 36 +++- .../admin/account/AccountTestModal.vue | 41 ++-- frontend/src/i18n/locales/en.ts | 4 + frontend/src/i18n/locales/zh.ts | 4 + 6 files changed, 251 insertions(+), 26 deletions(-) diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 67c9ef0c..e6c1cf4c 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -34,6 +34,9 @@ const ( chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接 soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions" + soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine" + soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap" + soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check" ) // TestEvent represents a SSE event for account testing @@ -498,6 +501,9 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * req.Header.Set("Authorization", "Bearer "+authToken) req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") // 
Get proxy URL proxyURL := "" @@ -543,6 +549,9 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * subReq.Header.Set("Authorization", "Bearer "+authToken) subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") subReq.Header.Set("Accept", "application/json") + subReq.Header.Set("Accept-Language", "en-US,en;q=0.9") + subReq.Header.Set("Origin", "https://sora.chatgpt.com") + subReq.Header.Set("Referer", "https://sora.chatgpt.com/") subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint) if subErr != nil { @@ -566,10 +575,134 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * } } + // 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。 + s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) return nil } +func (s *AccountTestService) testSora2Capabilities( + c *gin.Context, + ctx context.Context, + account *Account, + authToken string, + proxyURL string, + enableTLSFingerprint bool, +) { + inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraInviteMineURL, + proxyURL, + enableTLSFingerprint, + ) + if err != nil { + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())}) + return + } + + if inviteStatus == http.StatusUnauthorized { + bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraBootstrapURL, + proxyURL, + enableTLSFingerprint, + ) + if bootstrapErr == nil && bootstrapStatus == http.StatusOK { + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"}) + inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraInviteMineURL, + proxyURL, + enableTLSFingerprint, + ) + if err 
!= nil { + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())}) + return + } + } + } + + if inviteStatus != http.StatusOK { + if isCloudflareChallengeResponse(inviteStatus, inviteBody) { + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Sora2 invite check blocked by Cloudflare challenge (HTTP 403)", inviteHeader, inviteBody)}) + return + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)}) + return + } + + if summary := parseSoraInviteSummary(inviteBody); summary != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: summary}) + } else { + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"}) + } + + remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraRemainingURL, + proxyURL, + enableTLSFingerprint, + ) + if remainingErr != nil { + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())}) + return + } + if remainingStatus != http.StatusOK { + if isCloudflareChallengeResponse(remainingStatus, remainingBody) { + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Sora2 remaining check blocked by Cloudflare challenge (HTTP 403)", remainingHeader, remainingBody)}) + return + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)}) + return + } + if summary := parseSoraRemainingSummary(remainingBody); summary != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: summary}) + } else { + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"}) + } +} + +func (s *AccountTestService) fetchSoraTestEndpoint( + ctx context.Context, + account *Account, + authToken string, + url string, + proxyURL string, + enableTLSFingerprint bool, +) (int, http.Header, 
[]byte, error) { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return 0, nil, nil, err + } + req.Header.Set("Authorization", "Bearer "+authToken) + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableTLSFingerprint) + if err != nil { + return 0, nil, nil, err + } + defer func() { _ = resp.Body.Close() }() + + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return resp.StatusCode, resp.Header, nil, readErr + } + return resp.StatusCode, resp.Header, body, nil +} + func parseSoraSubscriptionSummary(body []byte) string { var subResp struct { Data []struct { @@ -604,6 +737,48 @@ func parseSoraSubscriptionSummary(body []byte) string { return "Subscription: " + strings.Join(parts, " | ") } +func parseSoraInviteSummary(body []byte) string { + var inviteResp struct { + InviteCode string `json:"invite_code"` + RedeemedCount int64 `json:"redeemed_count"` + TotalCount int64 `json:"total_count"` + } + if err := json.Unmarshal(body, &inviteResp); err != nil { + return "" + } + + parts := []string{"Sora2: supported"} + if inviteResp.InviteCode != "" { + parts = append(parts, "invite="+inviteResp.InviteCode) + } + if inviteResp.TotalCount > 0 { + parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount)) + } + return strings.Join(parts, " | ") +} + +func parseSoraRemainingSummary(body []byte) string { + var remainingResp struct { + RateLimitAndCreditBalance struct { + EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"` + RateLimitReached bool `json:"rate_limit_reached"` + AccessResetsInSeconds int64 
`json:"access_resets_in_seconds"` + } `json:"rate_limit_and_credit_balance"` + } + if err := json.Unmarshal(body, &remainingResp); err != nil { + return "" + } + info := remainingResp.RateLimitAndCreditBalance + parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)} + if info.RateLimitReached { + parts = append(parts, "rate_limited=true") + } + if info.AccessResetsInSeconds > 0 { + parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds)) + } + return strings.Join(parts, " | ") +} + func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool { if s == nil || s.cfg == nil { return false diff --git a/backend/internal/service/account_test_service_sora_test.go b/backend/internal/service/account_test_service_sora_test.go index fbbc8ff1..0c09bf18 100644 --- a/backend/internal/service/account_test_service_sora_test.go +++ b/backend/internal/service/account_test_service_sora_test.go @@ -61,6 +61,8 @@ func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testin responses: []*http.Response{ newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`), newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`), + newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`), + newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`), }, } svc := &AccountTestService{ @@ -92,17 +94,21 @@ func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testin err := svc.testSoraAccountConnection(c, account) require.NoError(t, err) - require.Len(t, upstream.requests, 2) + require.Len(t, upstream.requests, 4) require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String()) require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String()) + require.Equal(t, 
soraInviteMineURL, upstream.requests[2].URL.String()) + require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String()) require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization")) require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization")) - require.Equal(t, []bool{true, true}, upstream.tlsFlags) + require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags) body := rec.Body.String() require.Contains(t, body, `"type":"test_start"`) require.Contains(t, body, "Sora connection OK - Email: demo@example.com") require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z") + require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50") + require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s") require.Contains(t, body, `"type":"test_complete","success":true`) } @@ -111,6 +117,8 @@ func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuc responses: []*http.Response{ newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`), + newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`), + newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`), }, } svc := &AccountTestService{httpUpstream: upstream} @@ -128,10 +136,11 @@ func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuc err := svc.testSoraAccountConnection(c, account) require.NoError(t, err) - require.Len(t, upstream.requests, 2) + require.Len(t, upstream.requests, 4) body := rec.Body.String() require.Contains(t, body, "Sora connection OK - User: demo-user") require.Contains(t, body, "Subscription check returned 403") + require.Contains(t, body, "Sora2 invite check returned 401") require.Contains(t, body, `"type":"test_complete","success":true`) } @@ -169,6 +178,7 @@ func 
TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChal responses: []*http.Response{ newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), newJSONResponse(http.StatusForbidden, `Just a moment...`), + newJSONResponse(http.StatusForbidden, `Just a moment...`), }, } svc := &AccountTestService{httpUpstream: upstream} @@ -188,6 +198,7 @@ func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChal require.NoError(t, err) body := rec.Body.String() require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)") + require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)") require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") require.Contains(t, body, `"type":"test_complete","success":true`) } diff --git a/frontend/src/components/account/AccountTestModal.vue b/frontend/src/components/account/AccountTestModal.vue index dfa1503e..792a8f45 100644 --- a/frontend/src/components/account/AccountTestModal.vue +++ b/frontend/src/components/account/AccountTestModal.vue @@ -41,7 +41,7 @@
-
+
@@ -54,6 +54,12 @@ :placeholder="loadingModels ? t('common.loading') + '...' : t('admin.accounts.selectTestModel')" />
+
+ {{ t('admin.accounts.soraTestHint') }} +
@@ -135,12 +141,12 @@
- {{ t('admin.accounts.testModel') }} + {{ isSoraAccount ? t('admin.accounts.soraTestTarget') : t('admin.accounts.testModel') }}
- {{ t('admin.accounts.testPrompt') }} + {{ isSoraAccount ? t('admin.accounts.soraTestMode') : t('admin.accounts.testPrompt') }}
@@ -156,10 +162,10 @@
-
+
@@ -54,6 +54,12 @@ :placeholder="loadingModels ? t('common.loading') + '...' : t('admin.accounts.selectTestModel')" />
+
+ {{ t('admin.accounts.soraTestHint') }} +
@@ -114,12 +120,12 @@
- {{ t('admin.accounts.testModel') }} + {{ isSoraAccount ? t('admin.accounts.soraTestTarget') : t('admin.accounts.testModel') }}
- {{ t('admin.accounts.testPrompt') }} + {{ isSoraAccount ? t('admin.accounts.soraTestMode') : t('admin.accounts.testPrompt') }}
@@ -135,10 +141,10 @@ + +
-
{{ t('admin.users.newBalance') }}:${{ formatBalance(calculateNewBalance()) }}
+
{{ t('admin.users.newBalance') }}:${{ formatBalance(calculateNewBalance()) }}
@@ -271,6 +271,7 @@ import { ref } from 'vue' import { useI18n } from 'vue-i18n' import { formatDateTime, formatReasoningEffort } from '@/utils/format' +import { resolveUsageRequestType } from '@/utils/usageRequestType' import DataTable from '@/components/common/DataTable.vue' import EmptyState from '@/components/common/EmptyState.vue' import Icon from '@/components/icons/Icon.vue' @@ -289,6 +290,21 @@ const tokenTooltipVisible = ref(false) const tokenTooltipPosition = ref({ x: 0, y: 0 }) const tokenTooltipData = ref(null) +const getRequestTypeLabel = (row: AdminUsageLog): string => { + const requestType = resolveUsageRequestType(row) + if (requestType === 'ws_v2') return t('usage.ws') + if (requestType === 'stream') return t('usage.stream') + if (requestType === 'sync') return t('usage.sync') + return t('usage.unknown') +} + +const getRequestTypeBadgeClass = (row: AdminUsageLog): string => { + const requestType = resolveUsageRequestType(row) + if (requestType === 'ws_v2') return 'bg-violet-100 text-violet-800 dark:bg-violet-900 dark:text-violet-200' + if (requestType === 'stream') return 'bg-blue-100 text-blue-800 dark:bg-blue-900 dark:text-blue-200' + if (requestType === 'sync') return 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-200' + return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200' +} const formatCacheTokens = (tokens: number): string => { if (tokens >= 1000000) return `${(tokens / 1000000).toFixed(1)}M` if (tokens >= 1000) return `${(tokens / 1000).toFixed(1)}K` diff --git a/frontend/src/components/admin/user/UserEditModal.vue b/frontend/src/components/admin/user/UserEditModal.vue index 70ebd2d3..e537dbf6 100644 --- a/frontend/src/components/admin/user/UserEditModal.vue +++ b/frontend/src/components/admin/user/UserEditModal.vue @@ -37,6 +37,14 @@ +
+ +
+ + GB +
+

{{ t('admin.users.soraStorageQuotaHint') }}

+