diff --git a/README_CN.md b/README_CN.md index 8129c3b2..707f0201 100644 --- a/README_CN.md +++ b/README_CN.md @@ -300,6 +300,27 @@ default: rate_multiplier: 1.0 ``` +### Sora 媒体签名 URL(可选) + +当配置 `gateway.sora_media_signing_key` 且 `gateway.sora_media_signed_url_ttl_seconds > 0` 时,网关会将 Sora 输出的媒体地址改写为临时签名 URL(`/sora/media-signed/...`)。这样无需 API Key 即可在浏览器中直接访问,且具备过期控制与防篡改能力(签名包含 path + query)。 + +```yaml +gateway: + # /sora/media 是否强制要求 API Key(默认 false) + sora_media_require_api_key: false + # 媒体临时签名密钥(为空则禁用签名) + sora_media_signing_key: "your-signing-key" + # 临时签名 URL 有效期(秒) + sora_media_signed_url_ttl_seconds: 900 +``` + +> 若未配置签名密钥,`/sora/media-signed` 将返回 503。 +> 如需更严格的访问控制,可将 `sora_media_require_api_key` 设为 true,仅允许携带 API Key 的 `/sora/media` 访问。 + +访问策略说明: +- `/sora/media`:内部调用或客户端携带 API Key 才能下载 +- `/sora/media-signed`:外部可访问,但有签名 + 过期控制 + `config.yaml` 还支持以下安全相关配置: - `cors.allowed_origins` 配置 CORS 白名单 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index b8668665..1d88b612 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -87,10 +87,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { schedulerCache := repository.NewSchedulerCache(redisClient) accountRepository := repository.NewAccountRepository(client, db, schedulerCache) soraAccountRepository := repository.NewSoraAccountRepository(db) + sora2APIService := service.NewSora2APIService(configConfig) + sora2APISyncService := service.NewSora2APISyncService(sora2APIService, accountRepository) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, sora2APISyncService, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) adminUserHandler := admin.NewUserHandler(adminService) groupHandler := admin.NewGroupHandler(adminService) claudeOAuthClient := repository.NewClaudeOAuthClient() @@ -162,11 +164,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig) + modelHandler := admin.NewModelHandler(sora2APIService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, modelHandler) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, sora2APIService, concurrencyService, billingCacheService, configConfig) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig) + soraGatewayService := service.NewSoraGatewayService(sora2APIService, httpUpstream, rateLimitService, configConfig) + soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) @@ -177,7 +182,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, sora2APISyncService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) application := &Application{ diff --git a/backend/ent/group.go b/backend/ent/group.go index 0d0c0538..0a32543b 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -52,6 +52,14 @@ type Group struct { ImagePrice2k *float64 `json:"image_price_2k,omitempty"` // ImagePrice4k holds the value of the "image_price_4k" field. ImagePrice4k *float64 `json:"image_price_4k,omitempty"` + // SoraImagePrice360 holds the value of the "sora_image_price_360" field. + SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` + // SoraImagePrice540 holds the value of the "sora_image_price_540" field. + SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` + // SoraVideoPricePerRequest holds the value of the "sora_video_price_per_request" field. + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` + // SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field. + SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"` // 是否仅允许 Claude Code 客户端 ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` // 非 Claude Code 请求降级使用的分组 ID @@ -170,7 +178,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { values[i] = new([]byte) case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled: values[i] = new(sql.NullBool) - case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: + case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd: values[i] = new(sql.NullFloat64) case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID: values[i] = new(sql.NullInt64) @@ -309,6 +317,34 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.ImagePrice4k = new(float64) *_m.ImagePrice4k = value.Float64 } + case group.FieldSoraImagePrice360: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_image_price_360", values[i]) + } else if value.Valid { + _m.SoraImagePrice360 = new(float64) + *_m.SoraImagePrice360 = value.Float64 + } + case group.FieldSoraImagePrice540: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_image_price_540", values[i]) + } else if value.Valid { + _m.SoraImagePrice540 = new(float64) + *_m.SoraImagePrice540 = value.Float64 + } + case group.FieldSoraVideoPricePerRequest: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_video_price_per_request", values[i]) + } else if value.Valid { + _m.SoraVideoPricePerRequest = new(float64) + *_m.SoraVideoPricePerRequest = value.Float64 + } + case group.FieldSoraVideoPricePerRequestHd: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_video_price_per_request_hd", values[i]) + } else if value.Valid { + _m.SoraVideoPricePerRequestHd = new(float64) + *_m.SoraVideoPricePerRequestHd = value.Float64 + } case group.FieldClaudeCodeOnly: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field claude_code_only", values[i]) @@ -479,6 +515,26 @@ func (_m *Group) String() string { builder.WriteString(fmt.Sprintf("%v", *v)) } builder.WriteString(", ") + if v := _m.SoraImagePrice360; v != nil { + builder.WriteString("sora_image_price_360=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.SoraImagePrice540; v != nil { + builder.WriteString("sora_image_price_540=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.SoraVideoPricePerRequest; v != nil { + builder.WriteString("sora_video_price_per_request=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.SoraVideoPricePerRequestHd; v != nil { + builder.WriteString("sora_video_price_per_request_hd=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") builder.WriteString("claude_code_only=") builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly)) builder.WriteString(", ") diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index d66d3edc..7470dd82 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -49,6 +49,14 @@ const ( FieldImagePrice2k = "image_price_2k" // FieldImagePrice4k holds the string denoting the image_price_4k field in the database. FieldImagePrice4k = "image_price_4k" + // FieldSoraImagePrice360 holds the string denoting the sora_image_price_360 field in the database. + FieldSoraImagePrice360 = "sora_image_price_360" + // FieldSoraImagePrice540 holds the string denoting the sora_image_price_540 field in the database. + FieldSoraImagePrice540 = "sora_image_price_540" + // FieldSoraVideoPricePerRequest holds the string denoting the sora_video_price_per_request field in the database. + FieldSoraVideoPricePerRequest = "sora_video_price_per_request" + // FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database. + FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd" // FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database. FieldClaudeCodeOnly = "claude_code_only" // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. @@ -149,6 +157,10 @@ var Columns = []string{ FieldImagePrice1k, FieldImagePrice2k, FieldImagePrice4k, + FieldSoraImagePrice360, + FieldSoraImagePrice540, + FieldSoraVideoPricePerRequest, + FieldSoraVideoPricePerRequestHd, FieldClaudeCodeOnly, FieldFallbackGroupID, FieldModelRouting, @@ -307,6 +319,26 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc() } +// BySoraImagePrice360 orders the results by the sora_image_price_360 field. +func BySoraImagePrice360(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraImagePrice360, opts...).ToFunc() +} + +// BySoraImagePrice540 orders the results by the sora_image_price_540 field. +func BySoraImagePrice540(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraImagePrice540, opts...).ToFunc() +} + +// BySoraVideoPricePerRequest orders the results by the sora_video_price_per_request field. +func BySoraVideoPricePerRequest(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraVideoPricePerRequest, opts...).ToFunc() +} + +// BySoraVideoPricePerRequestHd orders the results by the sora_video_price_per_request_hd field. +func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc() +} + // ByClaudeCodeOnly orders the results by the claude_code_only field. func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc() diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 6ce9e4c6..3f8f4c04 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -140,6 +140,26 @@ func ImagePrice4k(v float64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v)) } +// SoraImagePrice360 applies equality check predicate on the "sora_image_price_360" field. It's identical to SoraImagePrice360EQ. +func SoraImagePrice360(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice540 applies equality check predicate on the "sora_image_price_540" field. It's identical to SoraImagePrice540EQ. +func SoraImagePrice540(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v)) +} + +// SoraVideoPricePerRequest applies equality check predicate on the "sora_video_price_per_request" field. It's identical to SoraVideoPricePerRequestEQ. +func SoraVideoPricePerRequest(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestHd applies equality check predicate on the "sora_video_price_per_request_hd" field. It's identical to SoraVideoPricePerRequestHdEQ. +func SoraVideoPricePerRequestHd(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v)) +} + // ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ. func ClaudeCodeOnly(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) @@ -1010,6 +1030,206 @@ func ImagePrice4kNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldImagePrice4k)) } +// SoraImagePrice360EQ applies the EQ predicate on the "sora_image_price_360" field. +func SoraImagePrice360EQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360NEQ applies the NEQ predicate on the "sora_image_price_360" field. +func SoraImagePrice360NEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360In applies the In predicate on the "sora_image_price_360" field. +func SoraImagePrice360In(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraImagePrice360, vs...)) +} + +// SoraImagePrice360NotIn applies the NotIn predicate on the "sora_image_price_360" field. +func SoraImagePrice360NotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice360, vs...)) +} + +// SoraImagePrice360GT applies the GT predicate on the "sora_image_price_360" field. +func SoraImagePrice360GT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360GTE applies the GTE predicate on the "sora_image_price_360" field. +func SoraImagePrice360GTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360LT applies the LT predicate on the "sora_image_price_360" field. +func SoraImagePrice360LT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360LTE applies the LTE predicate on the "sora_image_price_360" field. +func SoraImagePrice360LTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360IsNil applies the IsNil predicate on the "sora_image_price_360" field. +func SoraImagePrice360IsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice360)) +} + +// SoraImagePrice360NotNil applies the NotNil predicate on the "sora_image_price_360" field. +func SoraImagePrice360NotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice360)) +} + +// SoraImagePrice540EQ applies the EQ predicate on the "sora_image_price_540" field. +func SoraImagePrice540EQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540NEQ applies the NEQ predicate on the "sora_image_price_540" field. +func SoraImagePrice540NEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540In applies the In predicate on the "sora_image_price_540" field. +func SoraImagePrice540In(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraImagePrice540, vs...)) +} + +// SoraImagePrice540NotIn applies the NotIn predicate on the "sora_image_price_540" field. +func SoraImagePrice540NotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice540, vs...)) +} + +// SoraImagePrice540GT applies the GT predicate on the "sora_image_price_540" field. +func SoraImagePrice540GT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540GTE applies the GTE predicate on the "sora_image_price_540" field. +func SoraImagePrice540GTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540LT applies the LT predicate on the "sora_image_price_540" field. +func SoraImagePrice540LT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540LTE applies the LTE predicate on the "sora_image_price_540" field. +func SoraImagePrice540LTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540IsNil applies the IsNil predicate on the "sora_image_price_540" field. +func SoraImagePrice540IsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice540)) +} + +// SoraImagePrice540NotNil applies the NotNil predicate on the "sora_image_price_540" field. +func SoraImagePrice540NotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice540)) +} + +// SoraVideoPricePerRequestEQ applies the EQ predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestNEQ applies the NEQ predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestNEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestIn applies the In predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequest, vs...)) +} + +// SoraVideoPricePerRequestNotIn applies the NotIn predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestNotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequest, vs...)) +} + +// SoraVideoPricePerRequestGT applies the GT predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestGT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestGTE applies the GTE predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestGTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestLT applies the LT predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestLT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestLTE applies the LTE predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestLTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestIsNil applies the IsNil predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequest)) +} + +// SoraVideoPricePerRequestNotNil applies the NotNil predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequest)) +} + +// SoraVideoPricePerRequestHdEQ applies the EQ predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdNEQ applies the NEQ predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdNEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdIn applies the In predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequestHd, vs...)) +} + +// SoraVideoPricePerRequestHdNotIn applies the NotIn predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdNotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequestHd, vs...)) +} + +// SoraVideoPricePerRequestHdGT applies the GT predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdGT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdGTE applies the GTE predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdGTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdLT applies the LT predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdLT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdLTE applies the LTE predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdLTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdIsNil applies the IsNil predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequestHd)) +} + +// SoraVideoPricePerRequestHdNotNil applies the NotNil predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd)) +} + // ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field. func ClaudeCodeOnlyEQ(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 0f251e0b..ac5cb4d5 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -258,6 +258,62 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate { return _c } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (_c *GroupCreate) SetSoraImagePrice360(v float64) *GroupCreate { + _c.mutation.SetSoraImagePrice360(v) + return _c +} + +// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraImagePrice360(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraImagePrice360(*v) + } + return _c +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (_c *GroupCreate) SetSoraImagePrice540(v float64) *GroupCreate { + _c.mutation.SetSoraImagePrice540(v) + return _c +} + +// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraImagePrice540(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraImagePrice540(*v) + } + return _c +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (_c *GroupCreate) SetSoraVideoPricePerRequest(v float64) *GroupCreate { + _c.mutation.SetSoraVideoPricePerRequest(v) + return _c +} + +// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraVideoPricePerRequest(*v) + } + return _c +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (_c *GroupCreate) SetSoraVideoPricePerRequestHd(v float64) *GroupCreate { + _c.mutation.SetSoraVideoPricePerRequestHd(v) + return _c +} + +// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraVideoPricePerRequestHd(*v) + } + return _c +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate { _c.mutation.SetClaudeCodeOnly(v) @@ -632,6 +688,22 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value) _node.ImagePrice4k = &value } + if value, ok := _c.mutation.SoraImagePrice360(); ok { + _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + _node.SoraImagePrice360 = &value + } + if value, ok := _c.mutation.SoraImagePrice540(); ok { + _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + _node.SoraImagePrice540 = &value + } + if value, ok := _c.mutation.SoraVideoPricePerRequest(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + _node.SoraVideoPricePerRequest = &value + } + if value, ok := _c.mutation.SoraVideoPricePerRequestHd(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + _node.SoraVideoPricePerRequestHd = &value + } if value, ok := _c.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) _node.ClaudeCodeOnly = value @@ -1092,6 +1164,102 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert { return u } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (u *GroupUpsert) SetSoraImagePrice360(v float64) *GroupUpsert { + u.Set(group.FieldSoraImagePrice360, v) + return u +} + +// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraImagePrice360() *GroupUpsert { + u.SetExcluded(group.FieldSoraImagePrice360) + return u +} + +// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. +func (u *GroupUpsert) AddSoraImagePrice360(v float64) *GroupUpsert { + u.Add(group.FieldSoraImagePrice360, v) + return u +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (u *GroupUpsert) ClearSoraImagePrice360() *GroupUpsert { + u.SetNull(group.FieldSoraImagePrice360) + return u +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (u *GroupUpsert) SetSoraImagePrice540(v float64) *GroupUpsert { + u.Set(group.FieldSoraImagePrice540, v) + return u +} + +// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraImagePrice540() *GroupUpsert { + u.SetExcluded(group.FieldSoraImagePrice540) + return u +} + +// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. +func (u *GroupUpsert) AddSoraImagePrice540(v float64) *GroupUpsert { + u.Add(group.FieldSoraImagePrice540, v) + return u +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (u *GroupUpsert) ClearSoraImagePrice540() *GroupUpsert { + u.SetNull(group.FieldSoraImagePrice540) + return u +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (u *GroupUpsert) SetSoraVideoPricePerRequest(v float64) *GroupUpsert { + u.Set(group.FieldSoraVideoPricePerRequest, v) + return u +} + +// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraVideoPricePerRequest() *GroupUpsert { + u.SetExcluded(group.FieldSoraVideoPricePerRequest) + return u +} + +// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. +func (u *GroupUpsert) AddSoraVideoPricePerRequest(v float64) *GroupUpsert { + u.Add(group.FieldSoraVideoPricePerRequest, v) + return u +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (u *GroupUpsert) ClearSoraVideoPricePerRequest() *GroupUpsert { + u.SetNull(group.FieldSoraVideoPricePerRequest) + return u +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (u *GroupUpsert) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsert { + u.Set(group.FieldSoraVideoPricePerRequestHd, v) + return u +} + +// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraVideoPricePerRequestHd() *GroupUpsert { + u.SetExcluded(group.FieldSoraVideoPricePerRequestHd) + return u +} + +// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. +func (u *GroupUpsert) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsert { + u.Add(group.FieldSoraVideoPricePerRequestHd, v) + return u +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert { + u.SetNull(group.FieldSoraVideoPricePerRequestHd) + return u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert { u.Set(group.FieldClaudeCodeOnly, v) @@ -1539,6 +1707,118 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne { }) } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (u *GroupUpsertOne) SetSoraImagePrice360(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice360(v) + }) +} + +// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. +func (u *GroupUpsertOne) AddSoraImagePrice360(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice360(v) + }) +} + +// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraImagePrice360() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice360() + }) +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (u *GroupUpsertOne) ClearSoraImagePrice360() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice360() + }) +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (u *GroupUpsertOne) SetSoraImagePrice540(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice540(v) + }) +} + +// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. +func (u *GroupUpsertOne) AddSoraImagePrice540(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice540(v) + }) +} + +// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraImagePrice540() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice540() + }) +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (u *GroupUpsertOne) ClearSoraImagePrice540() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice540() + }) +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (u *GroupUpsertOne) SetSoraVideoPricePerRequest(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequest(v) + }) +} + +// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. +func (u *GroupUpsertOne) AddSoraVideoPricePerRequest(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequest(v) + }) +} + +// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequest() + }) +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (u *GroupUpsertOne) ClearSoraVideoPricePerRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequest() + }) +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequestHd(v) + }) +} + +// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequestHd(v) + }) +} + +// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequestHd() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequestHd() + }) +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequestHd() + }) +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -2163,6 +2443,118 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk { }) } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (u *GroupUpsertBulk) SetSoraImagePrice360(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice360(v) + }) +} + +// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. +func (u *GroupUpsertBulk) AddSoraImagePrice360(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice360(v) + }) +} + +// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraImagePrice360() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice360() + }) +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (u *GroupUpsertBulk) ClearSoraImagePrice360() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice360() + }) +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (u *GroupUpsertBulk) SetSoraImagePrice540(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice540(v) + }) +} + +// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. +func (u *GroupUpsertBulk) AddSoraImagePrice540(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice540(v) + }) +} + +// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraImagePrice540() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice540() + }) +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (u *GroupUpsertBulk) ClearSoraImagePrice540() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice540() + }) +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (u *GroupUpsertBulk) SetSoraVideoPricePerRequest(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequest(v) + }) +} + +// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. +func (u *GroupUpsertBulk) AddSoraVideoPricePerRequest(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequest(v) + }) +} + +// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequest() + }) +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequest() + }) +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertBulk) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequestHd(v) + }) +} + +// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertBulk) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequestHd(v) + }) +} + +// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequestHd() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequestHd() + }) +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequestHd() + }) +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index c3cc2708..528a7fe9 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -354,6 +354,114 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate { return _u } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (_u *GroupUpdate) SetSoraImagePrice360(v float64) *GroupUpdate { + _u.mutation.ResetSoraImagePrice360() + _u.mutation.SetSoraImagePrice360(v) + return _u +} + +// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraImagePrice360(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraImagePrice360(*v) + } + return _u +} + +// AddSoraImagePrice360 adds value to the "sora_image_price_360" field. +func (_u *GroupUpdate) AddSoraImagePrice360(v float64) *GroupUpdate { + _u.mutation.AddSoraImagePrice360(v) + return _u +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (_u *GroupUpdate) ClearSoraImagePrice360() *GroupUpdate { + _u.mutation.ClearSoraImagePrice360() + return _u +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (_u *GroupUpdate) SetSoraImagePrice540(v float64) *GroupUpdate { + _u.mutation.ResetSoraImagePrice540() + _u.mutation.SetSoraImagePrice540(v) + return _u +} + +// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraImagePrice540(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraImagePrice540(*v) + } + return _u +} + +// AddSoraImagePrice540 adds value to the "sora_image_price_540" field. +func (_u *GroupUpdate) AddSoraImagePrice540(v float64) *GroupUpdate { + _u.mutation.AddSoraImagePrice540(v) + return _u +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (_u *GroupUpdate) ClearSoraImagePrice540() *GroupUpdate { + _u.mutation.ClearSoraImagePrice540() + return _u +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (_u *GroupUpdate) SetSoraVideoPricePerRequest(v float64) *GroupUpdate { + _u.mutation.ResetSoraVideoPricePerRequest() + _u.mutation.SetSoraVideoPricePerRequest(v) + return _u +} + +// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraVideoPricePerRequest(*v) + } + return _u +} + +// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field. +func (_u *GroupUpdate) AddSoraVideoPricePerRequest(v float64) *GroupUpdate { + _u.mutation.AddSoraVideoPricePerRequest(v) + return _u +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (_u *GroupUpdate) ClearSoraVideoPricePerRequest() *GroupUpdate { + _u.mutation.ClearSoraVideoPricePerRequest() + return _u +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdate) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdate { + _u.mutation.ResetSoraVideoPricePerRequestHd() + _u.mutation.SetSoraVideoPricePerRequestHd(v) + return _u +} + +// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraVideoPricePerRequestHd(*v) + } + return _u +} + +// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdate) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdate { + _u.mutation.AddSoraVideoPricePerRequestHd(v) + return _u +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate { + _u.mutation.ClearSoraVideoPricePerRequestHd() + return _u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate { _u.mutation.SetClaudeCodeOnly(v) @@ -817,6 +925,42 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } + if value, ok := _u.mutation.SoraImagePrice360(); ok { + _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice360(); ok { + _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice360Cleared() { + _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraImagePrice540(); ok { + _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice540(); ok { + _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice540Cleared() { + _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestHdCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) + } if value, ok := _u.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } @@ -1472,6 +1616,114 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne { return _u } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (_u *GroupUpdateOne) SetSoraImagePrice360(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraImagePrice360() + _u.mutation.SetSoraImagePrice360(v) + return _u +} + +// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraImagePrice360(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraImagePrice360(*v) + } + return _u +} + +// AddSoraImagePrice360 adds value to the "sora_image_price_360" field. +func (_u *GroupUpdateOne) AddSoraImagePrice360(v float64) *GroupUpdateOne { + _u.mutation.AddSoraImagePrice360(v) + return _u +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (_u *GroupUpdateOne) ClearSoraImagePrice360() *GroupUpdateOne { + _u.mutation.ClearSoraImagePrice360() + return _u +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (_u *GroupUpdateOne) SetSoraImagePrice540(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraImagePrice540() + _u.mutation.SetSoraImagePrice540(v) + return _u +} + +// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraImagePrice540(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraImagePrice540(*v) + } + return _u +} + +// AddSoraImagePrice540 adds value to the "sora_image_price_540" field. +func (_u *GroupUpdateOne) AddSoraImagePrice540(v float64) *GroupUpdateOne { + _u.mutation.AddSoraImagePrice540(v) + return _u +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (_u *GroupUpdateOne) ClearSoraImagePrice540() *GroupUpdateOne { + _u.mutation.ClearSoraImagePrice540() + return _u +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (_u *GroupUpdateOne) SetSoraVideoPricePerRequest(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraVideoPricePerRequest() + _u.mutation.SetSoraVideoPricePerRequest(v) + return _u +} + +// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraVideoPricePerRequest(*v) + } + return _u +} + +// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field. +func (_u *GroupUpdateOne) AddSoraVideoPricePerRequest(v float64) *GroupUpdateOne { + _u.mutation.AddSoraVideoPricePerRequest(v) + return _u +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequest() *GroupUpdateOne { + _u.mutation.ClearSoraVideoPricePerRequest() + return _u +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdateOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraVideoPricePerRequestHd() + _u.mutation.SetSoraVideoPricePerRequestHd(v) + return _u +} + +// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraVideoPricePerRequestHd(*v) + } + return _u +} + +// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdateOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne { + _u.mutation.AddSoraVideoPricePerRequestHd(v) + return _u +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne { + _u.mutation.ClearSoraVideoPricePerRequestHd() + return _u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne { _u.mutation.SetClaudeCodeOnly(v) @@ -1965,6 +2217,42 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } + if value, ok := _u.mutation.SoraImagePrice360(); ok { + _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice360(); ok { + _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice360Cleared() { + _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraImagePrice540(); ok { + _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice540(); ok { + _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice540Cleared() { + _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestHdCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) + } if value, ok := _u.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index d1f05186..fe1f80a8 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -224,6 +224,10 @@ var ( {Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_image_price_360", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, {Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, @@ -499,6 +503,7 @@ var ( {Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45}, {Name: "image_count", Type: field.TypeInt, Default: 0}, {Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10}, + {Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "api_key_id", Type: field.TypeInt64}, {Name: "account_id", Type: field.TypeInt64}, @@ -514,31 +519,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[27]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -547,32 +552,32 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, }, { Name: "usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[27]}, }, { Name: "usagelog_account_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[26]}, }, { Name: "usagelog_model", @@ -587,12 +592,12 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[26]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[26], UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[26]}, }, }, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 9b330616..b3d1e410 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -3836,61 +3836,69 @@ func (m *AccountGroupMutation) ResetEdge(name string) error { // GroupMutation represents an operation that mutates the Group nodes in the graph. type GroupMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - description *string - rate_multiplier *float64 - addrate_multiplier *float64 - is_exclusive *bool - status *string - platform *string - subscription_type *string - daily_limit_usd *float64 - adddaily_limit_usd *float64 - weekly_limit_usd *float64 - addweekly_limit_usd *float64 - monthly_limit_usd *float64 - addmonthly_limit_usd *float64 - default_validity_days *int - adddefault_validity_days *int - image_price_1k *float64 - addimage_price_1k *float64 - image_price_2k *float64 - addimage_price_2k *float64 - image_price_4k *float64 - addimage_price_4k *float64 - claude_code_only *bool - fallback_group_id *int64 - addfallback_group_id *int64 - model_routing *map[string][]int64 - model_routing_enabled *bool - clearedFields map[string]struct{} - api_keys map[int64]struct{} - removedapi_keys map[int64]struct{} - clearedapi_keys bool - redeem_codes map[int64]struct{} - removedredeem_codes map[int64]struct{} - clearedredeem_codes bool - subscriptions map[int64]struct{} - removedsubscriptions map[int64]struct{} - clearedsubscriptions bool - usage_logs map[int64]struct{} - removedusage_logs map[int64]struct{} - clearedusage_logs bool - accounts map[int64]struct{} - removedaccounts map[int64]struct{} - clearedaccounts bool - allowed_users map[int64]struct{} - removedallowed_users map[int64]struct{} - clearedallowed_users bool - done bool - oldValue func(context.Context) (*Group, error) - predicates []predicate.Group + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + description *string + rate_multiplier *float64 + addrate_multiplier *float64 + is_exclusive *bool + status *string + platform *string + subscription_type *string + daily_limit_usd *float64 + adddaily_limit_usd *float64 + weekly_limit_usd *float64 + addweekly_limit_usd *float64 + monthly_limit_usd *float64 + addmonthly_limit_usd *float64 + default_validity_days *int + adddefault_validity_days *int + image_price_1k *float64 + addimage_price_1k *float64 + image_price_2k *float64 + addimage_price_2k *float64 + image_price_4k *float64 + addimage_price_4k *float64 + sora_image_price_360 *float64 + addsora_image_price_360 *float64 + sora_image_price_540 *float64 + addsora_image_price_540 *float64 + sora_video_price_per_request *float64 + addsora_video_price_per_request *float64 + sora_video_price_per_request_hd *float64 + addsora_video_price_per_request_hd *float64 + claude_code_only *bool + fallback_group_id *int64 + addfallback_group_id *int64 + model_routing *map[string][]int64 + model_routing_enabled *bool + clearedFields map[string]struct{} + api_keys map[int64]struct{} + removedapi_keys map[int64]struct{} + clearedapi_keys bool + redeem_codes map[int64]struct{} + removedredeem_codes map[int64]struct{} + clearedredeem_codes bool + subscriptions map[int64]struct{} + removedsubscriptions map[int64]struct{} + clearedsubscriptions bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + accounts map[int64]struct{} + removedaccounts map[int64]struct{} + clearedaccounts bool + allowed_users map[int64]struct{} + removedallowed_users map[int64]struct{} + clearedallowed_users bool + done bool + oldValue func(context.Context) (*Group, error) + predicates []predicate.Group } var _ ent.Mutation = (*GroupMutation)(nil) @@ -4873,6 +4881,286 @@ func (m *GroupMutation) ResetImagePrice4k() { delete(m.clearedFields, group.FieldImagePrice4k) } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (m *GroupMutation) SetSoraImagePrice360(f float64) { + m.sora_image_price_360 = &f + m.addsora_image_price_360 = nil +} + +// SoraImagePrice360 returns the value of the "sora_image_price_360" field in the mutation. +func (m *GroupMutation) SoraImagePrice360() (r float64, exists bool) { + v := m.sora_image_price_360 + if v == nil { + return + } + return *v, true +} + +// OldSoraImagePrice360 returns the old "sora_image_price_360" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraImagePrice360(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraImagePrice360 is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraImagePrice360 requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraImagePrice360: %w", err) + } + return oldValue.SoraImagePrice360, nil +} + +// AddSoraImagePrice360 adds f to the "sora_image_price_360" field. +func (m *GroupMutation) AddSoraImagePrice360(f float64) { + if m.addsora_image_price_360 != nil { + *m.addsora_image_price_360 += f + } else { + m.addsora_image_price_360 = &f + } +} + +// AddedSoraImagePrice360 returns the value that was added to the "sora_image_price_360" field in this mutation. +func (m *GroupMutation) AddedSoraImagePrice360() (r float64, exists bool) { + v := m.addsora_image_price_360 + if v == nil { + return + } + return *v, true +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (m *GroupMutation) ClearSoraImagePrice360() { + m.sora_image_price_360 = nil + m.addsora_image_price_360 = nil + m.clearedFields[group.FieldSoraImagePrice360] = struct{}{} +} + +// SoraImagePrice360Cleared returns if the "sora_image_price_360" field was cleared in this mutation. +func (m *GroupMutation) SoraImagePrice360Cleared() bool { + _, ok := m.clearedFields[group.FieldSoraImagePrice360] + return ok +} + +// ResetSoraImagePrice360 resets all changes to the "sora_image_price_360" field. +func (m *GroupMutation) ResetSoraImagePrice360() { + m.sora_image_price_360 = nil + m.addsora_image_price_360 = nil + delete(m.clearedFields, group.FieldSoraImagePrice360) +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (m *GroupMutation) SetSoraImagePrice540(f float64) { + m.sora_image_price_540 = &f + m.addsora_image_price_540 = nil +} + +// SoraImagePrice540 returns the value of the "sora_image_price_540" field in the mutation. +func (m *GroupMutation) SoraImagePrice540() (r float64, exists bool) { + v := m.sora_image_price_540 + if v == nil { + return + } + return *v, true +} + +// OldSoraImagePrice540 returns the old "sora_image_price_540" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraImagePrice540(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraImagePrice540 is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraImagePrice540 requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraImagePrice540: %w", err) + } + return oldValue.SoraImagePrice540, nil +} + +// AddSoraImagePrice540 adds f to the "sora_image_price_540" field. +func (m *GroupMutation) AddSoraImagePrice540(f float64) { + if m.addsora_image_price_540 != nil { + *m.addsora_image_price_540 += f + } else { + m.addsora_image_price_540 = &f + } +} + +// AddedSoraImagePrice540 returns the value that was added to the "sora_image_price_540" field in this mutation. +func (m *GroupMutation) AddedSoraImagePrice540() (r float64, exists bool) { + v := m.addsora_image_price_540 + if v == nil { + return + } + return *v, true +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (m *GroupMutation) ClearSoraImagePrice540() { + m.sora_image_price_540 = nil + m.addsora_image_price_540 = nil + m.clearedFields[group.FieldSoraImagePrice540] = struct{}{} +} + +// SoraImagePrice540Cleared returns if the "sora_image_price_540" field was cleared in this mutation. +func (m *GroupMutation) SoraImagePrice540Cleared() bool { + _, ok := m.clearedFields[group.FieldSoraImagePrice540] + return ok +} + +// ResetSoraImagePrice540 resets all changes to the "sora_image_price_540" field. +func (m *GroupMutation) ResetSoraImagePrice540() { + m.sora_image_price_540 = nil + m.addsora_image_price_540 = nil + delete(m.clearedFields, group.FieldSoraImagePrice540) +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (m *GroupMutation) SetSoraVideoPricePerRequest(f float64) { + m.sora_video_price_per_request = &f + m.addsora_video_price_per_request = nil +} + +// SoraVideoPricePerRequest returns the value of the "sora_video_price_per_request" field in the mutation. +func (m *GroupMutation) SoraVideoPricePerRequest() (r float64, exists bool) { + v := m.sora_video_price_per_request + if v == nil { + return + } + return *v, true +} + +// OldSoraVideoPricePerRequest returns the old "sora_video_price_per_request" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraVideoPricePerRequest(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraVideoPricePerRequest is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraVideoPricePerRequest requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequest: %w", err) + } + return oldValue.SoraVideoPricePerRequest, nil +} + +// AddSoraVideoPricePerRequest adds f to the "sora_video_price_per_request" field. +func (m *GroupMutation) AddSoraVideoPricePerRequest(f float64) { + if m.addsora_video_price_per_request != nil { + *m.addsora_video_price_per_request += f + } else { + m.addsora_video_price_per_request = &f + } +} + +// AddedSoraVideoPricePerRequest returns the value that was added to the "sora_video_price_per_request" field in this mutation. +func (m *GroupMutation) AddedSoraVideoPricePerRequest() (r float64, exists bool) { + v := m.addsora_video_price_per_request + if v == nil { + return + } + return *v, true +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (m *GroupMutation) ClearSoraVideoPricePerRequest() { + m.sora_video_price_per_request = nil + m.addsora_video_price_per_request = nil + m.clearedFields[group.FieldSoraVideoPricePerRequest] = struct{}{} +} + +// SoraVideoPricePerRequestCleared returns if the "sora_video_price_per_request" field was cleared in this mutation. +func (m *GroupMutation) SoraVideoPricePerRequestCleared() bool { + _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequest] + return ok +} + +// ResetSoraVideoPricePerRequest resets all changes to the "sora_video_price_per_request" field. +func (m *GroupMutation) ResetSoraVideoPricePerRequest() { + m.sora_video_price_per_request = nil + m.addsora_video_price_per_request = nil + delete(m.clearedFields, group.FieldSoraVideoPricePerRequest) +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (m *GroupMutation) SetSoraVideoPricePerRequestHd(f float64) { + m.sora_video_price_per_request_hd = &f + m.addsora_video_price_per_request_hd = nil +} + +// SoraVideoPricePerRequestHd returns the value of the "sora_video_price_per_request_hd" field in the mutation. +func (m *GroupMutation) SoraVideoPricePerRequestHd() (r float64, exists bool) { + v := m.sora_video_price_per_request_hd + if v == nil { + return + } + return *v, true +} + +// OldSoraVideoPricePerRequestHd returns the old "sora_video_price_per_request_hd" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraVideoPricePerRequestHd(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraVideoPricePerRequestHd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraVideoPricePerRequestHd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequestHd: %w", err) + } + return oldValue.SoraVideoPricePerRequestHd, nil +} + +// AddSoraVideoPricePerRequestHd adds f to the "sora_video_price_per_request_hd" field. +func (m *GroupMutation) AddSoraVideoPricePerRequestHd(f float64) { + if m.addsora_video_price_per_request_hd != nil { + *m.addsora_video_price_per_request_hd += f + } else { + m.addsora_video_price_per_request_hd = &f + } +} + +// AddedSoraVideoPricePerRequestHd returns the value that was added to the "sora_video_price_per_request_hd" field in this mutation. +func (m *GroupMutation) AddedSoraVideoPricePerRequestHd() (r float64, exists bool) { + v := m.addsora_video_price_per_request_hd + if v == nil { + return + } + return *v, true +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (m *GroupMutation) ClearSoraVideoPricePerRequestHd() { + m.sora_video_price_per_request_hd = nil + m.addsora_video_price_per_request_hd = nil + m.clearedFields[group.FieldSoraVideoPricePerRequestHd] = struct{}{} +} + +// SoraVideoPricePerRequestHdCleared returns if the "sora_video_price_per_request_hd" field was cleared in this mutation. +func (m *GroupMutation) SoraVideoPricePerRequestHdCleared() bool { + _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequestHd] + return ok +} + +// ResetSoraVideoPricePerRequestHd resets all changes to the "sora_video_price_per_request_hd" field. +func (m *GroupMutation) ResetSoraVideoPricePerRequestHd() { + m.sora_video_price_per_request_hd = nil + m.addsora_video_price_per_request_hd = nil + delete(m.clearedFields, group.FieldSoraVideoPricePerRequestHd) +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (m *GroupMutation) SetClaudeCodeOnly(b bool) { m.claude_code_only = &b @@ -5422,7 +5710,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 21) + fields := make([]string, 0, 25) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -5474,6 +5762,18 @@ func (m *GroupMutation) Fields() []string { if m.image_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } + if m.sora_image_price_360 != nil { + fields = append(fields, group.FieldSoraImagePrice360) + } + if m.sora_image_price_540 != nil { + fields = append(fields, group.FieldSoraImagePrice540) + } + if m.sora_video_price_per_request != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequest) + } + if m.sora_video_price_per_request_hd != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequestHd) + } if m.claude_code_only != nil { fields = append(fields, group.FieldClaudeCodeOnly) } @@ -5528,6 +5828,14 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.ImagePrice2k() case group.FieldImagePrice4k: return m.ImagePrice4k() + case group.FieldSoraImagePrice360: + return m.SoraImagePrice360() + case group.FieldSoraImagePrice540: + return m.SoraImagePrice540() + case group.FieldSoraVideoPricePerRequest: + return m.SoraVideoPricePerRequest() + case group.FieldSoraVideoPricePerRequestHd: + return m.SoraVideoPricePerRequestHd() case group.FieldClaudeCodeOnly: return m.ClaudeCodeOnly() case group.FieldFallbackGroupID: @@ -5579,6 +5887,14 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldImagePrice2k(ctx) case group.FieldImagePrice4k: return m.OldImagePrice4k(ctx) + case group.FieldSoraImagePrice360: + return m.OldSoraImagePrice360(ctx) + case group.FieldSoraImagePrice540: + return m.OldSoraImagePrice540(ctx) + case group.FieldSoraVideoPricePerRequest: + return m.OldSoraVideoPricePerRequest(ctx) + case group.FieldSoraVideoPricePerRequestHd: + return m.OldSoraVideoPricePerRequestHd(ctx) case group.FieldClaudeCodeOnly: return m.OldClaudeCodeOnly(ctx) case group.FieldFallbackGroupID: @@ -5715,6 +6031,34 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetImagePrice4k(v) return nil + case group.FieldSoraImagePrice360: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraImagePrice360(v) + return nil + case group.FieldSoraImagePrice540: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraImagePrice540(v) + return nil + case group.FieldSoraVideoPricePerRequest: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraVideoPricePerRequest(v) + return nil + case group.FieldSoraVideoPricePerRequestHd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraVideoPricePerRequestHd(v) + return nil case group.FieldClaudeCodeOnly: v, ok := value.(bool) if !ok { @@ -5775,6 +6119,18 @@ func (m *GroupMutation) AddedFields() []string { if m.addimage_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } + if m.addsora_image_price_360 != nil { + fields = append(fields, group.FieldSoraImagePrice360) + } + if m.addsora_image_price_540 != nil { + fields = append(fields, group.FieldSoraImagePrice540) + } + if m.addsora_video_price_per_request != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequest) + } + if m.addsora_video_price_per_request_hd != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequestHd) + } if m.addfallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } @@ -5802,6 +6158,14 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedImagePrice2k() case group.FieldImagePrice4k: return m.AddedImagePrice4k() + case group.FieldSoraImagePrice360: + return m.AddedSoraImagePrice360() + case group.FieldSoraImagePrice540: + return m.AddedSoraImagePrice540() + case group.FieldSoraVideoPricePerRequest: + return m.AddedSoraVideoPricePerRequest() + case group.FieldSoraVideoPricePerRequestHd: + return m.AddedSoraVideoPricePerRequestHd() case group.FieldFallbackGroupID: return m.AddedFallbackGroupID() } @@ -5869,6 +6233,34 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddImagePrice4k(v) return nil + case group.FieldSoraImagePrice360: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraImagePrice360(v) + return nil + case group.FieldSoraImagePrice540: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraImagePrice540(v) + return nil + case group.FieldSoraVideoPricePerRequest: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraVideoPricePerRequest(v) + return nil + case group.FieldSoraVideoPricePerRequestHd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraVideoPricePerRequestHd(v) + return nil case group.FieldFallbackGroupID: v, ok := value.(int64) if !ok { @@ -5908,6 +6300,18 @@ func (m *GroupMutation) ClearedFields() []string { if m.FieldCleared(group.FieldImagePrice4k) { fields = append(fields, group.FieldImagePrice4k) } + if m.FieldCleared(group.FieldSoraImagePrice360) { + fields = append(fields, group.FieldSoraImagePrice360) + } + if m.FieldCleared(group.FieldSoraImagePrice540) { + fields = append(fields, group.FieldSoraImagePrice540) + } + if m.FieldCleared(group.FieldSoraVideoPricePerRequest) { + fields = append(fields, group.FieldSoraVideoPricePerRequest) + } + if m.FieldCleared(group.FieldSoraVideoPricePerRequestHd) { + fields = append(fields, group.FieldSoraVideoPricePerRequestHd) + } if m.FieldCleared(group.FieldFallbackGroupID) { fields = append(fields, group.FieldFallbackGroupID) } @@ -5952,6 +6356,18 @@ func (m *GroupMutation) ClearField(name string) error { case group.FieldImagePrice4k: m.ClearImagePrice4k() return nil + case group.FieldSoraImagePrice360: + m.ClearSoraImagePrice360() + return nil + case group.FieldSoraImagePrice540: + m.ClearSoraImagePrice540() + return nil + case group.FieldSoraVideoPricePerRequest: + m.ClearSoraVideoPricePerRequest() + return nil + case group.FieldSoraVideoPricePerRequestHd: + m.ClearSoraVideoPricePerRequestHd() + return nil case group.FieldFallbackGroupID: m.ClearFallbackGroupID() return nil @@ -6017,6 +6433,18 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldImagePrice4k: m.ResetImagePrice4k() return nil + case group.FieldSoraImagePrice360: + m.ResetSoraImagePrice360() + return nil + case group.FieldSoraImagePrice540: + m.ResetSoraImagePrice540() + return nil + case group.FieldSoraVideoPricePerRequest: + m.ResetSoraVideoPricePerRequest() + return nil + case group.FieldSoraVideoPricePerRequestHd: + m.ResetSoraVideoPricePerRequestHd() + return nil case group.FieldClaudeCodeOnly: m.ResetClaudeCodeOnly() return nil @@ -11504,6 +11932,7 @@ type UsageLogMutation struct { image_count *int addimage_count *int image_size *string + media_type *string created_at *time.Time clearedFields map[string]struct{} user *int64 @@ -13130,6 +13559,55 @@ func (m *UsageLogMutation) ResetImageSize() { delete(m.clearedFields, usagelog.FieldImageSize) } +// SetMediaType sets the "media_type" field. +func (m *UsageLogMutation) SetMediaType(s string) { + m.media_type = &s +} + +// MediaType returns the value of the "media_type" field in the mutation. +func (m *UsageLogMutation) MediaType() (r string, exists bool) { + v := m.media_type + if v == nil { + return + } + return *v, true +} + +// OldMediaType returns the old "media_type" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldMediaType(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMediaType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMediaType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMediaType: %w", err) + } + return oldValue.MediaType, nil +} + +// ClearMediaType clears the value of the "media_type" field. +func (m *UsageLogMutation) ClearMediaType() { + m.media_type = nil + m.clearedFields[usagelog.FieldMediaType] = struct{}{} +} + +// MediaTypeCleared returns if the "media_type" field was cleared in this mutation. +func (m *UsageLogMutation) MediaTypeCleared() bool { + _, ok := m.clearedFields[usagelog.FieldMediaType] + return ok +} + +// ResetMediaType resets all changes to the "media_type" field. +func (m *UsageLogMutation) ResetMediaType() { + m.media_type = nil + delete(m.clearedFields, usagelog.FieldMediaType) +} + // SetCreatedAt sets the "created_at" field. func (m *UsageLogMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -13335,7 +13813,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 30) + fields := make([]string, 0, 31) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -13423,6 +13901,9 @@ func (m *UsageLogMutation) Fields() []string { if m.image_size != nil { fields = append(fields, usagelog.FieldImageSize) } + if m.media_type != nil { + fields = append(fields, usagelog.FieldMediaType) + } if m.created_at != nil { fields = append(fields, usagelog.FieldCreatedAt) } @@ -13492,6 +13973,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.ImageCount() case usagelog.FieldImageSize: return m.ImageSize() + case usagelog.FieldMediaType: + return m.MediaType() case usagelog.FieldCreatedAt: return m.CreatedAt() } @@ -13561,6 +14044,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldImageCount(ctx) case usagelog.FieldImageSize: return m.OldImageSize(ctx) + case usagelog.FieldMediaType: + return m.OldMediaType(ctx) case usagelog.FieldCreatedAt: return m.OldCreatedAt(ctx) } @@ -13775,6 +14260,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetImageSize(v) return nil + case usagelog.FieldMediaType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMediaType(v) + return nil case usagelog.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -14055,6 +14547,9 @@ func (m *UsageLogMutation) ClearedFields() []string { if m.FieldCleared(usagelog.FieldImageSize) { fields = append(fields, usagelog.FieldImageSize) } + if m.FieldCleared(usagelog.FieldMediaType) { + fields = append(fields, usagelog.FieldMediaType) + } return fields } @@ -14093,6 +14588,9 @@ func (m *UsageLogMutation) ClearField(name string) error { case usagelog.FieldImageSize: m.ClearImageSize() return nil + case usagelog.FieldMediaType: + m.ClearMediaType() + return nil } return fmt.Errorf("unknown UsageLog nullable field %s", name) } @@ -14188,6 +14686,9 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldImageSize: m.ResetImageSize() return nil + case usagelog.FieldMediaType: + m.ResetMediaType() + return nil case usagelog.FieldCreatedAt: m.ResetCreatedAt() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 1e3f4cbe..15b02ad1 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -278,11 +278,11 @@ func init() { // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) // groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field. - groupDescClaudeCodeOnly := groupFields[14].Descriptor() + groupDescClaudeCodeOnly := groupFields[18].Descriptor() // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. - groupDescModelRoutingEnabled := groupFields[17].Descriptor() + groupDescModelRoutingEnabled := groupFields[21].Descriptor() // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) promocodeFields := schema.PromoCode{}.Fields() @@ -647,8 +647,12 @@ func init() { usagelogDescImageSize := usagelogFields[28].Descriptor() // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) + // usagelogDescMediaType is the schema descriptor for media_type field. + usagelogDescMediaType := usagelogFields[29].Descriptor() + // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. + usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) // usagelogDescCreatedAt is the schema descriptor for created_at field. - usagelogDescCreatedAt := usagelogFields[29].Descriptor() + usagelogDescCreatedAt := usagelogFields[30].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 5d0a1e9a..7fa04b8a 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -87,6 +87,24 @@ func (Group) Fields() []ent.Field { Nillable(). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + // Sora 按次计费配置(阶段 1) + field.Float("sora_image_price_360"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + field.Float("sora_image_price_540"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + field.Float("sora_video_price_per_request"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + field.Float("sora_video_price_per_request_hd"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + // Claude Code 客户端限制 (added by migration 029) field.Bool("claude_code_only"). Default(false). diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index fc7c7165..602f23f6 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -118,6 +118,11 @@ func (UsageLog) Fields() []ent.Field { MaxLen(10). Optional(). Nillable(), + // 媒体类型字段(sora 使用) + field.String("media_type"). + MaxLen(16). + Optional(). + Nillable(), // 时间戳(只有 created_at,日志不可修改) field.Time("created_at"). diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index 81c466b4..63a14197 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -80,6 +80,8 @@ type UsageLog struct { ImageCount int `json:"image_count,omitempty"` // ImageSize holds the value of the "image_size" field. ImageSize *string `json:"image_size,omitempty"` + // MediaType holds the value of the "media_type" field. + MediaType *string `json:"media_type,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // Edges holds the relations/edges for other nodes in the graph. @@ -171,7 +173,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: values[i] = new(sql.NullInt64) - case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize: + case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: values[i] = new(sql.NullString) case usagelog.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -378,6 +380,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { _m.ImageSize = new(string) *_m.ImageSize = value.String } + case usagelog.FieldMediaType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field media_type", values[i]) + } else if value.Valid { + _m.MediaType = new(string) + *_m.MediaType = value.String + } case usagelog.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -548,6 +557,11 @@ func (_m *UsageLog) String() string { builder.WriteString(*v) } builder.WriteString(", ") + if v := _m.MediaType; v != nil { + builder.WriteString("media_type=") + builder.WriteString(*v) + } + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteByte(')') diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index 980f1e58..3ea5d054 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -72,6 +72,8 @@ const ( FieldImageCount = "image_count" // FieldImageSize holds the string denoting the image_size field in the database. FieldImageSize = "image_size" + // FieldMediaType holds the string denoting the media_type field in the database. + FieldMediaType = "media_type" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // EdgeUser holds the string denoting the user edge name in mutations. @@ -155,6 +157,7 @@ var Columns = []string{ FieldIPAddress, FieldImageCount, FieldImageSize, + FieldMediaType, FieldCreatedAt, } @@ -211,6 +214,8 @@ var ( DefaultImageCount int // ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. ImageSizeValidator func(string) error + // MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. + MediaTypeValidator func(string) error // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time ) @@ -368,6 +373,11 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImageSize, opts...).ToFunc() } +// ByMediaType orders the results by the media_type field. +func ByMediaType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMediaType, opts...).ToFunc() +} + // ByCreatedAt orders the results by the created_at field. func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index 28e2ab4c..0a33dba2 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -200,6 +200,11 @@ func ImageSize(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v)) } +// MediaType applies equality check predicate on the "media_type" field. It's identical to MediaTypeEQ. +func MediaType(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) @@ -1440,6 +1445,81 @@ func ImageSizeContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v)) } +// MediaTypeEQ applies the EQ predicate on the "media_type" field. +func MediaTypeEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v)) +} + +// MediaTypeNEQ applies the NEQ predicate on the "media_type" field. +func MediaTypeNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldMediaType, v)) +} + +// MediaTypeIn applies the In predicate on the "media_type" field. +func MediaTypeIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldMediaType, vs...)) +} + +// MediaTypeNotIn applies the NotIn predicate on the "media_type" field. +func MediaTypeNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldMediaType, vs...)) +} + +// MediaTypeGT applies the GT predicate on the "media_type" field. +func MediaTypeGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldMediaType, v)) +} + +// MediaTypeGTE applies the GTE predicate on the "media_type" field. +func MediaTypeGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldMediaType, v)) +} + +// MediaTypeLT applies the LT predicate on the "media_type" field. +func MediaTypeLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldMediaType, v)) +} + +// MediaTypeLTE applies the LTE predicate on the "media_type" field. +func MediaTypeLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldMediaType, v)) +} + +// MediaTypeContains applies the Contains predicate on the "media_type" field. +func MediaTypeContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldMediaType, v)) +} + +// MediaTypeHasPrefix applies the HasPrefix predicate on the "media_type" field. +func MediaTypeHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldMediaType, v)) +} + +// MediaTypeHasSuffix applies the HasSuffix predicate on the "media_type" field. +func MediaTypeHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldMediaType, v)) +} + +// MediaTypeIsNil applies the IsNil predicate on the "media_type" field. +func MediaTypeIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldMediaType)) +} + +// MediaTypeNotNil applies the NotNil predicate on the "media_type" field. +func MediaTypeNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldMediaType)) +} + +// MediaTypeEqualFold applies the EqualFold predicate on the "media_type" field. +func MediaTypeEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldMediaType, v)) +} + +// MediaTypeContainsFold applies the ContainsFold predicate on the "media_type" field. +func MediaTypeContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index a17d6507..668a0ede 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -393,6 +393,20 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate { return _c } +// SetMediaType sets the "media_type" field. +func (_c *UsageLogCreate) SetMediaType(v string) *UsageLogCreate { + _c.mutation.SetMediaType(v) + return _c +} + +// SetNillableMediaType sets the "media_type" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate { + if v != nil { + _c.SetMediaType(*v) + } + return _c +} + // SetCreatedAt sets the "created_at" field. func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate { _c.mutation.SetCreatedAt(v) @@ -627,6 +641,11 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _c.mutation.MediaType(); ok { + if err := usagelog.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} + } + } if _, ok := _c.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)} } @@ -762,6 +781,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldImageSize, field.TypeString, value) _node.ImageSize = &value } + if value, ok := _c.mutation.MediaType(); ok { + _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) + _node.MediaType = &value + } if value, ok := _c.mutation.CreatedAt(); ok { _spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value @@ -1407,6 +1430,24 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert { return u } +// SetMediaType sets the "media_type" field. +func (u *UsageLogUpsert) SetMediaType(v string) *UsageLogUpsert { + u.Set(usagelog.FieldMediaType, v) + return u +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateMediaType() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldMediaType) + return u +} + +// ClearMediaType clears the value of the "media_type" field. +func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert { + u.SetNull(usagelog.FieldMediaType) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -2040,6 +2081,27 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne { }) } +// SetMediaType sets the "media_type" field. +func (u *UsageLogUpsertOne) SetMediaType(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetMediaType(v) + }) +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateMediaType() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateMediaType() + }) +} + +// ClearMediaType clears the value of the "media_type" field. +func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearMediaType() + }) +} + // Exec executes the query. func (u *UsageLogUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2839,6 +2901,27 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk { }) } +// SetMediaType sets the "media_type" field. +func (u *UsageLogUpsertBulk) SetMediaType(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetMediaType(v) + }) +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateMediaType() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateMediaType() + }) +} + +// ClearMediaType clears the value of the "media_type" field. +func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearMediaType() + }) +} + // Exec executes the query. func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index 571a7b3c..22f2613f 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -612,6 +612,26 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate { return _u } +// SetMediaType sets the "media_type" field. +func (_u *UsageLogUpdate) SetMediaType(v string) *UsageLogUpdate { + _u.mutation.SetMediaType(v) + return _u +} + +// SetNillableMediaType sets the "media_type" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableMediaType(v *string) *UsageLogUpdate { + if v != nil { + _u.SetMediaType(*v) + } + return _u +} + +// ClearMediaType clears the value of the "media_type" field. +func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate { + _u.mutation.ClearMediaType() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate { return _u.SetUserID(v.ID) @@ -726,6 +746,11 @@ func (_u *UsageLogUpdate) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _u.mutation.MediaType(); ok { + if err := usagelog.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} + } + } if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) } @@ -894,6 +919,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } + if value, ok := _u.mutation.MediaType(); ok { + _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) + } + if _u.mutation.MediaTypeCleared() { + _spec.ClearField(usagelog.FieldMediaType, field.TypeString) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -1639,6 +1670,26 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne { return _u } +// SetMediaType sets the "media_type" field. +func (_u *UsageLogUpdateOne) SetMediaType(v string) *UsageLogUpdateOne { + _u.mutation.SetMediaType(v) + return _u +} + +// SetNillableMediaType sets the "media_type" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableMediaType(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetMediaType(*v) + } + return _u +} + +// ClearMediaType clears the value of the "media_type" field. +func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne { + _u.mutation.ClearMediaType() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne { return _u.SetUserID(v.ID) @@ -1766,6 +1817,11 @@ func (_u *UsageLogUpdateOne) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _u.mutation.MediaType(); ok { + if err := usagelog.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} + } + } if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) } @@ -1951,6 +2007,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } + if value, ok := _u.mutation.MediaType(); ok { + _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) + } + if _u.mutation.MediaTypeCleared() { + _spec.ClearField(usagelog.FieldMediaType, field.TypeString) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 00a78480..5dd2b415 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -58,6 +58,7 @@ type Config struct { UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` Concurrency ConcurrencyConfig `mapstructure:"concurrency"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + Sora2API Sora2APIConfig `mapstructure:"sora2api"` RunMode string `mapstructure:"run_mode" yaml:"run_mode"` Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" Gemini GeminiConfig `mapstructure:"gemini"` @@ -204,6 +205,24 @@ type ConcurrencyConfig struct { PingInterval int `mapstructure:"ping_interval"` } +// Sora2APIConfig Sora2API 服务配置 +type Sora2APIConfig struct { + // BaseURL Sora2API 服务地址(例如 http://localhost:8000) + BaseURL string `mapstructure:"base_url"` + // APIKey Sora2API OpenAI 兼容接口的 API Key + APIKey string `mapstructure:"api_key"` + // AdminUsername 管理员用户名(用于 token 同步) + AdminUsername string `mapstructure:"admin_username"` + // AdminPassword 管理员密码(用于 token 同步) + AdminPassword string `mapstructure:"admin_password"` + // AdminTokenTTLSeconds 管理员 Token 缓存时长(秒) + AdminTokenTTLSeconds int `mapstructure:"admin_token_ttl_seconds"` + // AdminTimeoutSeconds 管理接口请求超时(秒) + AdminTimeoutSeconds int `mapstructure:"admin_timeout_seconds"` + // TokenImportMode token 导入模式:at/offline + TokenImportMode string `mapstructure:"token_import_mode"` +} + // GatewayConfig API网关相关配置 type GatewayConfig struct { // 等待上游响应头的超时时间(秒),0表示无超时 @@ -258,6 +277,24 @@ type GatewayConfig struct { // 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义) FailoverOn400 bool `mapstructure:"failover_on_400"` + // Sora 专用配置 + // SoraMaxBodySize: Sora 请求体最大字节数(0 表示使用 gateway.max_body_size) + SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"` + // SoraStreamTimeoutSeconds: Sora 流式请求总超时(秒,0 表示不限制) + SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"` + // SoraRequestTimeoutSeconds: Sora 非流式请求超时(秒,0 表示不限制) + SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"` + // SoraStreamMode: stream 强制策略(force/error) + SoraStreamMode string `mapstructure:"sora_stream_mode"` + // SoraModelFilters: 模型列表过滤配置 + SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"` + // SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key + SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"` + // SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名) + SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"` + // SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用) + SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"` + // 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限) MaxAccountSwitches int `mapstructure:"max_account_switches"` // Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格) @@ -273,6 +310,12 @@ type GatewayConfig struct { TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"` } +// SoraModelFiltersConfig Sora 模型过滤配置 +type SoraModelFiltersConfig struct { + // HidePromptEnhance 是否隐藏 prompt-enhance 模型 + HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"` +} + // TLSFingerprintConfig TLS指纹伪装配置 // 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端 type TLSFingerprintConfig struct { @@ -823,6 +866,13 @@ func setDefaults() { viper.SetDefault("gateway.max_account_switches_gemini", 3) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) + viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024)) + viper.SetDefault("gateway.sora_stream_timeout_seconds", 900) + viper.SetDefault("gateway.sora_request_timeout_seconds", 180) + viper.SetDefault("gateway.sora_stream_mode", "force") + viper.SetDefault("gateway.sora_model_filters.hide_prompt_enhance", true) + viper.SetDefault("gateway.sora_media_require_api_key", true) + viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) // HTTP 上游连接池配置(针对 5000+ 并发用户优化) viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认) @@ -869,6 +919,15 @@ func setDefaults() { viper.SetDefault("gemini.oauth.client_secret", "") viper.SetDefault("gemini.oauth.scopes", "") viper.SetDefault("gemini.quota.policy", "") + + // Sora2API + viper.SetDefault("sora2api.base_url", "") + viper.SetDefault("sora2api.api_key", "") + viper.SetDefault("sora2api.admin_username", "") + viper.SetDefault("sora2api.admin_password", "") + viper.SetDefault("sora2api.admin_token_ttl_seconds", 900) + viper.SetDefault("sora2api.admin_timeout_seconds", 10) + viper.SetDefault("sora2api.token_import_mode", "at") } func (c *Config) Validate() error { @@ -1085,6 +1144,25 @@ func (c *Config) Validate() error { if c.Gateway.MaxBodySize <= 0 { return fmt.Errorf("gateway.max_body_size must be positive") } + if c.Gateway.SoraMaxBodySize < 0 { + return fmt.Errorf("gateway.sora_max_body_size must be non-negative") + } + if c.Gateway.SoraStreamTimeoutSeconds < 0 { + return fmt.Errorf("gateway.sora_stream_timeout_seconds must be non-negative") + } + if c.Gateway.SoraRequestTimeoutSeconds < 0 { + return fmt.Errorf("gateway.sora_request_timeout_seconds must be non-negative") + } + if c.Gateway.SoraMediaSignedURLTTLSeconds < 0 { + return fmt.Errorf("gateway.sora_media_signed_url_ttl_seconds must be non-negative") + } + if mode := strings.TrimSpace(strings.ToLower(c.Gateway.SoraStreamMode)); mode != "" { + switch mode { + case "force", "error": + default: + return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error") + } + } if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" { switch c.Gateway.ConnectionPoolIsolation { case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy: @@ -1181,6 +1259,25 @@ func (c *Config) Validate() error { c.Gateway.Scheduling.OutboxLagRebuildSeconds < c.Gateway.Scheduling.OutboxLagWarnSeconds { return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be >= outbox_lag_warn_seconds") } + if strings.TrimSpace(c.Sora2API.BaseURL) != "" { + if err := ValidateAbsoluteHTTPURL(c.Sora2API.BaseURL); err != nil { + return fmt.Errorf("sora2api.base_url invalid: %w", err) + } + warnIfInsecureURL("sora2api.base_url", c.Sora2API.BaseURL) + } + if mode := strings.TrimSpace(strings.ToLower(c.Sora2API.TokenImportMode)); mode != "" { + switch mode { + case "at", "offline": + default: + return fmt.Errorf("sora2api.token_import_mode must be one of: at/offline") + } + } + if c.Sora2API.AdminTokenTTLSeconds < 0 { + return fmt.Errorf("sora2api.admin_token_ttl_seconds must be non-negative") + } + if c.Sora2API.AdminTimeoutSeconds < 0 { + return fmt.Errorf("sora2api.admin_timeout_seconds must be non-negative") + } if c.Ops.MetricsCollectorCache.TTL < 0 { return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative") } diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 926624d2..1af570d9 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler { type CreateGroupRequest struct { Name string `json:"name" binding:"required"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` RateMultiplier float64 `json:"rate_multiplier"` IsExclusive bool `json:"is_exclusive"` SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` @@ -38,6 +38,10 @@ type CreateGroupRequest struct { ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` // 模型路由配置(仅 anthropic 平台使用) @@ -49,7 +53,7 @@ type CreateGroupRequest struct { type UpdateGroupRequest struct { Name string `json:"name"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` RateMultiplier *float64 `json:"rate_multiplier"` IsExclusive *bool `json:"is_exclusive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"` @@ -61,6 +65,10 @@ type UpdateGroupRequest struct { ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` ClaudeCodeOnly *bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` // 模型路由配置(仅 anthropic 平台使用) @@ -167,6 +175,10 @@ func (h *GroupHandler) Create(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, + SoraImagePrice360: req.SoraImagePrice360, + SoraImagePrice540: req.SoraImagePrice540, + SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, ClaudeCodeOnly: req.ClaudeCodeOnly, FallbackGroupID: req.FallbackGroupID, ModelRouting: req.ModelRouting, @@ -209,6 +221,10 @@ func (h *GroupHandler) Update(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, + SoraImagePrice360: req.SoraImagePrice360, + SoraImagePrice540: req.SoraImagePrice540, + SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, ClaudeCodeOnly: req.ClaudeCodeOnly, FallbackGroupID: req.FallbackGroupID, ModelRouting: req.ModelRouting, diff --git a/backend/internal/handler/admin/model_handler.go b/backend/internal/handler/admin/model_handler.go new file mode 100644 index 00000000..035b09bd --- /dev/null +++ b/backend/internal/handler/admin/model_handler.go @@ -0,0 +1,55 @@ +package admin + +import ( + "net/http" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// ModelHandler handles admin model listing requests. +type ModelHandler struct { + sora2apiService *service.Sora2APIService +} + +// NewModelHandler creates a new ModelHandler. +func NewModelHandler(sora2apiService *service.Sora2APIService) *ModelHandler { + return &ModelHandler{ + sora2apiService: sora2apiService, + } +} + +// List handles listing models for a specific platform +// GET /api/v1/admin/models?platform=sora +func (h *ModelHandler) List(c *gin.Context) { + platform := strings.TrimSpace(strings.ToLower(c.Query("platform"))) + if platform == "" { + response.BadRequest(c, "platform is required") + return + } + + switch platform { + case service.PlatformSora: + if h.sora2apiService == nil || !h.sora2apiService.Enabled() { + response.Error(c, http.StatusServiceUnavailable, "sora2api not configured") + return + } + models, err := h.sora2apiService.ListModels(c.Request.Context()) + if err != nil { + response.Error(c, http.StatusServiceUnavailable, "failed to fetch sora models") + return + } + ids := make([]string, 0, len(models)) + for _, m := range models { + if strings.TrimSpace(m.ID) != "" { + ids = append(ids, m.ID) + } + } + response.Success(c, ids) + default: + response.BadRequest(c, "unsupported platform") + } +} diff --git a/backend/internal/handler/admin/model_handler_test.go b/backend/internal/handler/admin/model_handler_test.go new file mode 100644 index 00000000..e61dc064 --- /dev/null +++ b/backend/internal/handler/admin/model_handler_test.go @@ -0,0 +1,87 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +func TestModelHandlerListSoraSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"object":"list","data":[{"id":"m1"},{"id":"m2"}]}`)) + })) + t.Cleanup(upstream.Close) + + cfg := &config.Config{} + cfg.Sora2API.BaseURL = upstream.URL + cfg.Sora2API.APIKey = "test-key" + soraService := service.NewSora2APIService(cfg) + + h := NewModelHandler(soraService) + router := gin.New() + router.GET("/admin/models", h.List) + + req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=sora", nil) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String()) + } + var resp response.Response + if err := json.Unmarshal(recorder.Body.Bytes(), &resp); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + if resp.Code != 0 { + t.Fatalf("响应 code=%d", resp.Code) + } + data, ok := resp.Data.([]any) + if !ok { + t.Fatalf("响应 data 类型错误") + } + if len(data) != 2 { + t.Fatalf("模型数量不符: %d", len(data)) + } +} + +func TestModelHandlerListSoraNotConfigured(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := NewModelHandler(&service.Sora2APIService{}) + router := gin.New() + router.GET("/admin/models", h.List) + + req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=sora", nil) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusServiceUnavailable { + t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String()) + } +} + +func TestModelHandlerListInvalidPlatform(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := NewModelHandler(&service.Sora2APIService{}) + router := gin.New() + router.GET("/admin/models", h.List) + + req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=unknown", nil) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusBadRequest { + t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String()) + } +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index d58a8a29..b44c3225 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -136,6 +136,10 @@ func groupFromServiceBase(g *service.Group) Group { ImagePrice1K: g.ImagePrice1K, ImagePrice2K: g.ImagePrice2K, ImagePrice4K: g.ImagePrice4K, + SoraImagePrice360: g.SoraImagePrice360, + SoraImagePrice540: g.SoraImagePrice540, + SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, CreatedAt: g.CreatedAt, @@ -379,6 +383,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { FirstTokenMs: l.FirstTokenMs, ImageCount: l.ImageCount, ImageSize: l.ImageSize, + MediaType: l.MediaType, UserAgent: l.UserAgent, CreatedAt: l.CreatedAt, User: UserFromServiceShallow(l.User), diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 938d707c..3ae899ee 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -61,6 +61,12 @@ type Group struct { ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + // Sora 按次计费配置 + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` + // Claude Code 客户端限制 ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` @@ -246,6 +252,7 @@ type UsageLog struct { // 图片生成字段 ImageCount int `json:"image_count"` ImageSize *string `json:"image_size"` + MediaType *string `json:"media_type"` // User-Agent UserAgent *string `json:"user_agent"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 70ea51bf..983cc6b3 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -29,6 +29,7 @@ type GatewayHandler struct { geminiCompatService *service.GeminiMessagesCompatService antigravityGatewayService *service.AntigravityGatewayService userService *service.UserService + sora2apiService *service.Sora2APIService billingCacheService *service.BillingCacheService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int @@ -41,6 +42,7 @@ func NewGatewayHandler( geminiCompatService *service.GeminiMessagesCompatService, antigravityGatewayService *service.AntigravityGatewayService, userService *service.UserService, + sora2apiService *service.Sora2APIService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, cfg *config.Config, @@ -62,6 +64,7 @@ func NewGatewayHandler( geminiCompatService: geminiCompatService, antigravityGatewayService: antigravityGatewayService, userService: userService, + sora2apiService: sora2apiService, billingCacheService: billingCacheService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), maxAccountSwitches: maxAccountSwitches, @@ -478,6 +481,26 @@ func (h *GatewayHandler) Models(c *gin.Context) { groupID = &apiKey.Group.ID platform = apiKey.Group.Platform } + if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok && strings.TrimSpace(forcedPlatform) != "" { + platform = forcedPlatform + } + + if platform == service.PlatformSora { + if h.sora2apiService == nil || !h.sora2apiService.Enabled() { + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "sora2api not configured") + return + } + models, err := h.sora2apiService.ListModels(c.Request.Context()) + if err != nil { + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Failed to fetch Sora models") + return + } + c.JSON(http.StatusOK, gin.H{ + "object": "list", + "data": models, + }) + return + } // Get available models from account configurations (without platform filter) availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "") diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 5b1b317d..7905148c 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -23,6 +23,7 @@ type AdminHandlers struct { Subscription *admin.SubscriptionHandler Usage *admin.UsageHandler UserAttribute *admin.UserAttributeHandler + Model *admin.ModelHandler } // Handlers contains all HTTP handlers @@ -36,6 +37,7 @@ type Handlers struct { Admin *AdminHandlers Gateway *GatewayHandler OpenAIGateway *OpenAIGatewayHandler + SoraGateway *SoraGatewayHandler Setting *SettingHandler } diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go new file mode 100644 index 00000000..94f712df --- /dev/null +++ b/backend/internal/handler/sora_gateway_handler.go @@ -0,0 +1,474 @@ +package handler + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "path" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// SoraGatewayHandler handles Sora chat completions requests +type SoraGatewayHandler struct { + gatewayService *service.GatewayService + soraGatewayService *service.SoraGatewayService + billingCacheService *service.BillingCacheService + concurrencyHelper *ConcurrencyHelper + maxAccountSwitches int + streamMode string + sora2apiBaseURL string + soraMediaSigningKey string +} + +// NewSoraGatewayHandler creates a new SoraGatewayHandler +func NewSoraGatewayHandler( + gatewayService *service.GatewayService, + soraGatewayService *service.SoraGatewayService, + concurrencyService *service.ConcurrencyService, + billingCacheService *service.BillingCacheService, + cfg *config.Config, +) *SoraGatewayHandler { + pingInterval := time.Duration(0) + maxAccountSwitches := 3 + streamMode := "force" + signKey := "" + if cfg != nil { + pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second + if cfg.Gateway.MaxAccountSwitches > 0 { + maxAccountSwitches = cfg.Gateway.MaxAccountSwitches + } + if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" { + streamMode = mode + } + signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey) + } + baseURL := "" + if cfg != nil { + baseURL = strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/") + } + return &SoraGatewayHandler{ + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + billingCacheService: billingCacheService, + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + maxAccountSwitches: maxAccountSwitches, + streamMode: strings.ToLower(streamMode), + sora2apiBaseURL: baseURL, + soraMediaSigningKey: signKey, + } +} + +// ChatCompletions handles Sora /v1/chat/completions endpoint +func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + + body, err := io.ReadAll(c.Request.Body) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + if len(body) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + setOpsRequestContext(c, "", false, body) + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + reqModel, _ := reqBody["model"].(string) + if reqModel == "" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + reqMessages, _ := reqBody["messages"].([]any) + if len(reqMessages) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required") + return + } + + clientStream, _ := reqBody["stream"].(bool) + if !clientStream { + if h.streamMode == "error" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true") + return + } + reqBody["stream"] = true + updated, err := json.Marshal(reqBody) + if err != nil { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") + return + } + body = updated + } + + setOpsRequestContext(c, reqModel, clientStream, body) + + platform := "" + if forced, ok := middleware2.GetForcePlatformFromContext(c); ok { + platform = forced + } else if apiKey.Group != nil { + platform = apiKey.Group.Platform + } + if platform != service.PlatformSora { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform") + return + } + + streamStarted := false + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + maxWait := service.CalculateMaxWait(subject.Concurrency) + canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) + waitCounted := false + if err != nil { + log.Printf("Increment wait count failed: %v", err) + } else if !canWait { + h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return + } + if err == nil && canWait { + waitCounted = true + } + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + } + }() + + userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted) + if err != nil { + log.Printf("User concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "user", streamStarted) + return + } + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + waitCounted = false + } + userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + log.Printf("Billing eligibility check failed after wait: %v", err) + status, code, message := billingErrorDetails(err) + h.handleStreamingAwareError(c, status, code, message, streamStarted) + return + } + + sessionHash := generateOpenAISessionHash(c, reqBody) + + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + lastFailoverStatus := 0 + + for { + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "") + if err != nil { + log.Printf("[Sora Handler] SelectAccount failed: %v", err) + if len(failedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + return + } + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return + } + account := selection.Account + setOpsSelectedAccount(c, account.ID) + + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + accountWaitCounted := false + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + log.Printf("Increment account wait count failed: %v", err) + } else if !canWait { + log.Printf("Account wait queue full: account=%d", account.ID) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) + return + } + if err == nil && canWait { + accountWaitCounted = true + } + defer func() { + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + }() + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + clientStream, + &streamStarted, + ) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false + } + } + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + + result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream) + if accountReleaseFunc != nil { + accountReleaseFunc() + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + failedAccountIDs[account.ID] = struct{}{} + if switchCount >= maxAccountSwitches { + lastFailoverStatus = failoverErr.StatusCode + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return + } + lastFailoverStatus = failoverErr.StatusCode + switchCount++ + log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + continue + } + log.Printf("Account %d: Forward request failed: %v", account.ID, err) + return + } + + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + + go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + UserAgent: ua, + IPAddress: ip, + }); err != nil { + log.Printf("Record usage failed: %v", err) + } + }(result, account, userAgent, clientIP) + return + } +} + +func generateOpenAISessionHash(c *gin.Context, reqBody map[string]any) string { + if c == nil { + return "" + } + sessionID := strings.TrimSpace(c.GetHeader("session_id")) + if sessionID == "" { + sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) + } + if sessionID == "" && reqBody != nil { + if v, ok := reqBody["prompt_cache_key"].(string); ok { + sessionID = strings.TrimSpace(v) + } + } + if sessionID == "" { + return "" + } + hash := sha256.Sum256([]byte(sessionID)) + return hex.EncodeToString(hash[:]) +} + +func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", + fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) +} + +func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { + status, errType, errMsg := h.mapUpstreamError(statusCode) + h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) +} + +func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) { + switch statusCode { + case 401: + return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" + case 403: + return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" + case 429: + return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" + case 529: + return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later" + case 500, 502, 503, 504: + return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable" + default: + return http.StatusBadGateway, "upstream_error", "Upstream request failed" + } +} + +func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { + if streamStarted { + flusher, ok := c.Writer.(http.Flusher) + if ok { + errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) + if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { + _ = c.Error(err) + } + flusher.Flush() + } + return + } + h.errorResponse(c, status, errType, message) +} + +func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { + c.JSON(status, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +// MediaProxy proxies /tmp or /static media files from sora2api +func (h *SoraGatewayHandler) MediaProxy(c *gin.Context) { + h.proxySoraMedia(c, false) +} + +// MediaProxySigned proxies /tmp or /static media files with signature verification +func (h *SoraGatewayHandler) MediaProxySigned(c *gin.Context) { + h.proxySoraMedia(c, true) +} + +func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature bool) { + if h.sora2apiBaseURL == "" { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "sora2api 未配置", + }, + }) + return + } + + rawPath := c.Param("filepath") + if rawPath == "" { + c.Status(http.StatusNotFound) + return + } + cleaned := path.Clean(rawPath) + if !strings.HasPrefix(cleaned, "/tmp/") && !strings.HasPrefix(cleaned, "/static/") { + c.Status(http.StatusNotFound) + return + } + + query := c.Request.URL.Query() + if requireSignature { + if h.soraMediaSigningKey == "" { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "Sora 媒体签名未配置", + }, + }) + return + } + expiresStr := strings.TrimSpace(query.Get("expires")) + signature := strings.TrimSpace(query.Get("sig")) + expires, err := strconv.ParseInt(expiresStr, 10, 64) + if err != nil || expires <= time.Now().Unix() { + c.JSON(http.StatusUnauthorized, gin.H{ + "error": gin.H{ + "type": "authentication_error", + "message": "Sora 媒体签名已过期", + }, + }) + return + } + query.Del("sig") + query.Del("expires") + signingQuery := query.Encode() + if !service.VerifySoraMediaURL(cleaned, signingQuery, expires, signature, h.soraMediaSigningKey) { + c.JSON(http.StatusUnauthorized, gin.H{ + "error": gin.H{ + "type": "authentication_error", + "message": "Sora 媒体签名无效", + }, + }) + return + } + } + + targetURL := h.sora2apiBaseURL + cleaned + if rawQuery := query.Encode(); rawQuery != "" { + targetURL += "?" + rawQuery + } + + req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL, nil) + if err != nil { + c.Status(http.StatusBadGateway) + return + } + copyHeaders := []string{"Range", "If-Range", "If-Modified-Since", "If-None-Match", "Accept", "User-Agent"} + for _, key := range copyHeaders { + if val := c.GetHeader(key); val != "" { + req.Header.Set(key, val) + } + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + c.Status(http.StatusBadGateway) + return + } + defer func() { _ = resp.Body.Close() }() + + for _, key := range []string{"Content-Type", "Content-Length", "Accept-Ranges", "Content-Range", "Cache-Control", "Last-Modified", "ETag"} { + if val := resp.Header.Get(key); val != "" { + c.Header(key, val) + } + } + c.Status(resp.StatusCode) + _, _ = io.Copy(c.Writer, resp.Body) +} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 2af7905e..1e3ef17d 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -26,6 +26,7 @@ func ProvideAdminHandlers( subscriptionHandler *admin.SubscriptionHandler, usageHandler *admin.UsageHandler, userAttributeHandler *admin.UserAttributeHandler, + modelHandler *admin.ModelHandler, ) *AdminHandlers { return &AdminHandlers{ Dashboard: dashboardHandler, @@ -45,6 +46,7 @@ func ProvideAdminHandlers( Subscription: subscriptionHandler, Usage: usageHandler, UserAttribute: userAttributeHandler, + Model: modelHandler, } } @@ -69,6 +71,7 @@ func ProvideHandlers( adminHandlers *AdminHandlers, gatewayHandler *GatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler, + soraGatewayHandler *SoraGatewayHandler, settingHandler *SettingHandler, ) *Handlers { return &Handlers{ @@ -81,6 +84,7 @@ func ProvideHandlers( Admin: adminHandlers, Gateway: gatewayHandler, OpenAIGateway: openaiGatewayHandler, + SoraGateway: soraGatewayHandler, Setting: settingHandler, } } @@ -96,6 +100,7 @@ var ProviderSet = wire.NewSet( NewSubscriptionHandler, NewGatewayHandler, NewOpenAIGatewayHandler, + NewSoraGatewayHandler, ProvideSettingHandler, // Admin handlers @@ -116,6 +121,7 @@ var ProviderSet = wire.NewSet( admin.NewSubscriptionHandler, admin.NewUsageHandler, admin.NewUserAttributeHandler, + admin.NewModelHandler, // AdminHandlers and Handlers constructors ProvideAdminHandlers, diff --git a/backend/internal/pkg/tlsfingerprint/dialer_test.go b/backend/internal/pkg/tlsfingerprint/dialer_test.go index 845d51e5..31a59fc7 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer_test.go +++ b/backend/internal/pkg/tlsfingerprint/dialer_test.go @@ -13,6 +13,7 @@ import ( "io" "net/http" "net/url" + "os" "strings" "testing" "time" @@ -38,9 +39,7 @@ type TLSInfo struct { // TestDialerBasicConnection tests that the dialer can establish TLS connections. func TestDialerBasicConnection(t *testing.T) { - if testing.Short() { - t.Skip("skipping network test in short mode") - } + skipNetworkTest(t) // Create a dialer with default profile profile := &Profile{ @@ -74,10 +73,7 @@ func TestDialerBasicConnection(t *testing.T) { // Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x) // Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP) func TestJA3Fingerprint(t *testing.T) { - // Skip if network is unavailable or if running in short mode - if testing.Short() { - t.Skip("skipping integration test in short mode") - } + skipNetworkTest(t) profile := &Profile{ Name: "Claude CLI Test", @@ -178,6 +174,15 @@ func TestJA3Fingerprint(t *testing.T) { } } +func skipNetworkTest(t *testing.T) { + if testing.Short() { + t.Skip("跳过网络测试(short 模式)") + } + if os.Getenv("TLSFINGERPRINT_NETWORK_TESTS") != "1" { + t.Skip("跳过网络测试(需要设置 TLSFINGERPRINT_NETWORK_TESTS=1)") + } +} + // TestDialerWithProfile tests that different profiles produce different fingerprints. func TestDialerWithProfile(t *testing.T) { // Create two dialers with different profiles @@ -317,9 +322,7 @@ type TestProfileExpectation struct { // TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws. // Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/... func TestAllProfiles(t *testing.T) { - if testing.Short() { - t.Skip("skipping integration test in short mode") - } + skipNetworkTest(t) // Define all profiles to test with their expected fingerprints // These profiles are from config.yaml gateway.tls_fingerprint.profiles diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index ab890844..9308326b 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -134,6 +134,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, + group.FieldSoraImagePrice360, + group.FieldSoraImagePrice540, + group.FieldSoraVideoPricePerRequest, + group.FieldSoraVideoPricePerRequestHd, group.FieldClaudeCodeOnly, group.FieldFallbackGroupID, group.FieldModelRoutingEnabled, @@ -421,6 +425,10 @@ func groupEntityToService(g *dbent.Group) *service.Group { ImagePrice1K: g.ImagePrice1k, ImagePrice2K: g.ImagePrice2k, ImagePrice4K: g.ImagePrice4k, + SoraImagePrice360: g.SoraImagePrice360, + SoraImagePrice540: g.SoraImagePrice540, + SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd, DefaultValidityDays: g.DefaultValidityDays, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 5c4d6cf4..75684fc9 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -47,6 +47,10 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). + SetNillableSoraImagePrice360(groupIn.SoraImagePrice360). + SetNillableSoraImagePrice540(groupIn.SoraImagePrice540). + SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest). + SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetNillableFallbackGroupID(groupIn.FallbackGroupID). @@ -106,6 +110,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). + SetNillableSoraImagePrice360(groupIn.SoraImagePrice360). + SetNillableSoraImagePrice540(groupIn.SoraImagePrice540). + SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest). + SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 963db7ba..0696c958 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -22,7 +22,7 @@ import ( "github.com/lib/pq" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, created_at" type usageLogRepository struct { client *dbent.Client @@ -114,6 +114,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ip_address, image_count, image_size, + media_type, created_at ) VALUES ( $1, $2, $3, $4, $5, @@ -121,7 +122,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30 + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -134,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) userAgent := nullString(log.UserAgent) ipAddress := nullString(log.IPAddress) imageSize := nullString(log.ImageSize) + mediaType := nullString(log.MediaType) var requestIDArg any if requestID != "" { @@ -170,6 +172,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ipAddress, log.ImageCount, imageSize, + mediaType, createdAt, } if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { @@ -2090,6 +2093,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ipAddress sql.NullString imageCount int imageSize sql.NullString + mediaType sql.NullString createdAt time.Time ) @@ -2124,6 +2128,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &ipAddress, &imageCount, &imageSize, + &mediaType, &createdAt, ); err != nil { return nil, err @@ -2183,6 +2188,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if imageSize.Valid { log.ImageSize = &imageSize.String } + if mediaType.Valid { + log.MediaType = &mediaType.String + } return log, nil } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 050e724d..2c1762d3 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -64,6 +64,9 @@ func RegisterAdminRoutes( // 用户属性管理 registerUserAttributeRoutes(admin, h) + + // 模型列表 + registerModelRoutes(admin, h) } } @@ -371,3 +374,7 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) { attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition) } } + +func registerModelRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + admin.GET("/models", h.Admin.Model.List) +} diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index bf019ce3..32f34e0c 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -20,6 +20,11 @@ func RegisterGatewayRoutes( cfg *config.Config, ) { bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize) + soraMaxBodySize := cfg.Gateway.SoraMaxBodySize + if soraMaxBodySize <= 0 { + soraMaxBodySize = cfg.Gateway.MaxBodySize + } + soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize) clientRequestID := middleware.ClientRequestID() opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService) @@ -38,6 +43,16 @@ func RegisterGatewayRoutes( gateway.POST("/responses", h.OpenAIGateway.Responses) } + // Sora Chat Completions + soraGateway := r.Group("/v1") + soraGateway.Use(soraBodyLimit) + soraGateway.Use(clientRequestID) + soraGateway.Use(opsErrorLogger) + soraGateway.Use(gin.HandlerFunc(apiKeyAuth)) + { + soraGateway.POST("/chat/completions", h.SoraGateway.ChatCompletions) + } + // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) gemini := r.Group("/v1beta") gemini.Use(bodyLimit) @@ -82,4 +97,25 @@ func RegisterGatewayRoutes( antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) } + + // Sora 专用路由(强制使用 sora 平台) + soraV1 := r.Group("/sora/v1") + soraV1.Use(soraBodyLimit) + soraV1.Use(clientRequestID) + soraV1.Use(opsErrorLogger) + soraV1.Use(middleware.ForcePlatform(service.PlatformSora)) + soraV1.Use(gin.HandlerFunc(apiKeyAuth)) + { + soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions) + soraV1.GET("/models", h.Gateway.Models) + } + + // Sora 媒体代理(可选 API Key 验证) + if cfg.Gateway.SoraMediaRequireAPIKey { + r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy) + } else { + r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy) + } + // Sora 媒体代理(签名 URL,无需 API Key) + r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned) } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 398de0e0..a29bf4db 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -102,11 +102,16 @@ type CreateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + // Sora 按次计费配置 + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 + SoraVideoPricePerRequestHD *float64 + ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 ModelRoutingEnabled bool // 是否启用模型路由 @@ -124,11 +129,16 @@ type UpdateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + // Sora 按次计费配置 + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 + SoraVideoPricePerRequestHD *float64 + ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 ModelRoutingEnabled *bool // 是否启用模型路由 @@ -273,6 +283,7 @@ type adminServiceImpl struct { groupRepo GroupRepository accountRepo AccountRepository soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储 + soraSyncService *Sora2APISyncService // Sora2API 同步服务 proxyRepo ProxyRepository apiKeyRepo APIKeyRepository redeemCodeRepo RedeemCodeRepository @@ -288,6 +299,7 @@ func NewAdminService( groupRepo GroupRepository, accountRepo AccountRepository, soraAccountRepo SoraAccountRepository, + soraSyncService *Sora2APISyncService, proxyRepo ProxyRepository, apiKeyRepo APIKeyRepository, redeemCodeRepo RedeemCodeRepository, @@ -301,6 +313,7 @@ func NewAdminService( groupRepo: groupRepo, accountRepo: accountRepo, soraAccountRepo: soraAccountRepo, + soraSyncService: soraSyncService, proxyRepo: proxyRepo, apiKeyRepo: apiKeyRepo, redeemCodeRepo: redeemCodeRepo, @@ -567,6 +580,10 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn imagePrice1K := normalizePrice(input.ImagePrice1K) imagePrice2K := normalizePrice(input.ImagePrice2K) imagePrice4K := normalizePrice(input.ImagePrice4K) + soraImagePrice360 := normalizePrice(input.SoraImagePrice360) + soraImagePrice540 := normalizePrice(input.SoraImagePrice540) + soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest) + soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD) // 校验降级分组 if input.FallbackGroupID != nil { @@ -576,22 +593,26 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn } group := &Group{ - Name: input.Name, - Description: input.Description, - Platform: platform, - RateMultiplier: input.RateMultiplier, - IsExclusive: input.IsExclusive, - Status: StatusActive, - SubscriptionType: subscriptionType, - DailyLimitUSD: dailyLimit, - WeeklyLimitUSD: weeklyLimit, - MonthlyLimitUSD: monthlyLimit, - ImagePrice1K: imagePrice1K, - ImagePrice2K: imagePrice2K, - ImagePrice4K: imagePrice4K, - ClaudeCodeOnly: input.ClaudeCodeOnly, - FallbackGroupID: input.FallbackGroupID, - ModelRouting: input.ModelRouting, + Name: input.Name, + Description: input.Description, + Platform: platform, + RateMultiplier: input.RateMultiplier, + IsExclusive: input.IsExclusive, + Status: StatusActive, + SubscriptionType: subscriptionType, + DailyLimitUSD: dailyLimit, + WeeklyLimitUSD: weeklyLimit, + MonthlyLimitUSD: monthlyLimit, + ImagePrice1K: imagePrice1K, + ImagePrice2K: imagePrice2K, + ImagePrice4K: imagePrice4K, + SoraImagePrice360: soraImagePrice360, + SoraImagePrice540: soraImagePrice540, + SoraVideoPricePerRequest: soraVideoPrice, + SoraVideoPricePerRequestHD: soraVideoPriceHD, + ClaudeCodeOnly: input.ClaudeCodeOnly, + FallbackGroupID: input.FallbackGroupID, + ModelRouting: input.ModelRouting, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -702,6 +723,18 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.ImagePrice4K != nil { group.ImagePrice4K = normalizePrice(input.ImagePrice4K) } + if input.SoraImagePrice360 != nil { + group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360) + } + if input.SoraImagePrice540 != nil { + group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540) + } + if input.SoraVideoPricePerRequest != nil { + group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest) + } + if input.SoraVideoPricePerRequestHD != nil { + group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD) + } // Claude Code 客户端限制 if input.ClaudeCodeOnly != nil { @@ -884,6 +917,9 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou } } + // 同步到 sora2api(异步,不阻塞创建) + s.syncSoraAccountAsync(account) + return account, nil } @@ -974,7 +1010,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U } // 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象) - return s.accountRepo.GetByID(ctx, id) + updated, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + s.syncSoraAccountAsync(updated) + return updated, nil } // BulkUpdateAccounts updates multiple accounts in one request. @@ -990,16 +1031,23 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp return result, nil } - // Preload account platforms for mixed channel risk checks if group bindings are requested. + needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck + needSoraSync := s != nil && s.soraSyncService != nil + + // 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。 platformByID := map[int64]string{} - if input.GroupIDs != nil && !input.SkipMixedChannelCheck { + if needMixedChannelCheck || needSoraSync { accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs) if err != nil { - return nil, err - } - for _, account := range accounts { - if account != nil { - platformByID[account.ID] = account.Platform + if needMixedChannelCheck { + return nil, err + } + log.Printf("[AdminService] 预加载账号平台信息失败,将逐个降级同步: err=%v", err) + } else { + for _, account := range accounts { + if account != nil { + platformByID[account.ID] = account.Platform + } } } } @@ -1086,13 +1134,46 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp result.Success++ result.SuccessIDs = append(result.SuccessIDs, accountID) result.Results = append(result.Results, entry) + + // 批量更新后同步 sora2api + if needSoraSync { + platform := platformByID[accountID] + if platform == "" { + updated, err := s.accountRepo.GetByID(ctx, accountID) + if err != nil { + log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err) + continue + } + if updated.Platform == PlatformSora { + s.syncSoraAccountAsync(updated) + } + continue + } + + if platform == PlatformSora { + updated, err := s.accountRepo.GetByID(ctx, accountID) + if err != nil { + log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err) + continue + } + s.syncSoraAccountAsync(updated) + } + } } return result, nil } func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { - return s.accountRepo.Delete(ctx, id) + account, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return err + } + if err := s.accountRepo.Delete(ctx, id); err != nil { + return err + } + s.deleteSoraAccountAsync(account) + return nil } func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) { @@ -1125,7 +1206,46 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { return nil, err } - return s.accountRepo.GetByID(ctx, id) + updated, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + s.syncSoraAccountAsync(updated) + return updated, nil +} + +func (s *adminServiceImpl) syncSoraAccountAsync(account *Account) { + if s == nil || s.soraSyncService == nil || account == nil { + return + } + if account.Platform != PlatformSora { + return + } + syncAccount := *account + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := s.soraSyncService.SyncAccount(ctx, &syncAccount); err != nil { + log.Printf("[AdminService] 同步 sora2api 失败: account_id=%d err=%v", syncAccount.ID, err) + } + }() +} + +func (s *adminServiceImpl) deleteSoraAccountAsync(account *Account) { + if s == nil || s.soraSyncService == nil || account == nil { + return + } + if account.Platform != PlatformSora { + return + } + syncAccount := *account + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := s.soraSyncService.DeleteAccount(ctx, &syncAccount); err != nil { + log.Printf("[AdminService] 删除 sora2api token 失败: account_id=%d err=%v", syncAccount.ID, err) + } + }() } // Proxy management implementations diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go index 662b95fb..cbdbe625 100644 --- a/backend/internal/service/admin_service_bulk_update_test.go +++ b/backend/internal/service/admin_service_bulk_update_test.go @@ -15,6 +15,13 @@ type accountRepoStubForBulkUpdate struct { bulkUpdateErr error bulkUpdateIDs []int64 bindGroupErrByID map[int64]error + getByIDsAccounts []*Account + getByIDsErr error + getByIDsCalled bool + getByIDsIDs []int64 + getByIDAccounts map[int64]*Account + getByIDErrByID map[int64]error + getByIDCalled []int64 } func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) { @@ -32,6 +39,26 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i return nil } +func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) { + s.getByIDsCalled = true + s.getByIDsIDs = append([]int64{}, ids...) + if s.getByIDsErr != nil { + return nil, s.getByIDsErr + } + return s.getByIDsAccounts, nil +} + +func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Account, error) { + s.getByIDCalled = append(s.getByIDCalled, id) + if err, ok := s.getByIDErrByID[id]; ok { + return nil, err + } + if account, ok := s.getByIDAccounts[id]; ok { + return account, nil + } + return nil, errors.New("account not found") +} + // TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。 func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) { repo := &accountRepoStubForBulkUpdate{} @@ -78,3 +105,31 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) { require.ElementsMatch(t, []int64{2}, result.FailedIDs) require.Len(t, result.Results, 3) } + +// TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs 验证无分组更新时仍会触发 Sora 同步。 +func TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{ + getByIDsAccounts: []*Account{ + {ID: 1, Platform: PlatformSora}, + }, + getByIDAccounts: map[int64]*Account{ + 1: {ID: 1, Platform: PlatformSora}, + }, + } + svc := &adminServiceImpl{ + accountRepo: repo, + soraSyncService: &Sora2APISyncService{}, + } + + schedulable := true + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1}, + Schedulable: &schedulable, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.NoError(t, err) + require.Equal(t, 1, result.Success) + require.True(t, repo.getByIDsCalled) + require.ElementsMatch(t, []int64{1}, repo.getByIDCalled) +} diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 5b476dbc..6247da00 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -35,6 +35,10 @@ type APIKeyAuthGroupSnapshot struct { ImagePrice1K *float64 `json:"image_price_1k,omitempty"` ImagePrice2K *float64 `json:"image_price_2k,omitempty"` ImagePrice4K *float64 `json:"image_price_4k,omitempty"` + SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` + SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index eb5c7534..5569a503 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -235,6 +235,10 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ImagePrice1K: apiKey.Group.ImagePrice1K, ImagePrice2K: apiKey.Group.ImagePrice2K, ImagePrice4K: apiKey.Group.ImagePrice4K, + SoraImagePrice360: apiKey.Group.SoraImagePrice360, + SoraImagePrice540: apiKey.Group.SoraImagePrice540, + SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, FallbackGroupID: apiKey.Group.FallbackGroupID, ModelRouting: apiKey.Group.ModelRouting, @@ -279,6 +283,10 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ImagePrice1K: snapshot.Group.ImagePrice1K, ImagePrice2K: snapshot.Group.ImagePrice2K, ImagePrice4K: snapshot.Group.ImagePrice4K, + SoraImagePrice360: snapshot.Group.SoraImagePrice360, + SoraImagePrice540: snapshot.Group.SoraImagePrice540, + SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD, ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, FallbackGroupID: snapshot.Group.FallbackGroupID, ModelRouting: snapshot.Group.ModelRouting, diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index f2afc343..9b72bf6e 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -303,6 +303,14 @@ type ImagePriceConfig struct { Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值) } +// SoraPriceConfig Sora 按次计费配置 +type SoraPriceConfig struct { + ImagePrice360 *float64 + ImagePrice540 *float64 + VideoPricePerRequest *float64 + VideoPricePerRequestHD *float64 +} + // CalculateImageCost 计算图片生成费用 // model: 请求的模型名称(用于获取 LiteLLM 默认价格) // imageSize: 图片尺寸 "1K", "2K", "4K" @@ -332,6 +340,65 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag } } +// CalculateSoraImageCost 计算 Sora 图片按次费用 +func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown { + if imageCount <= 0 { + return &CostBreakdown{} + } + + unitPrice := 0.0 + if groupConfig != nil { + switch imageSize { + case "540": + if groupConfig.ImagePrice540 != nil { + unitPrice = *groupConfig.ImagePrice540 + } + default: + if groupConfig.ImagePrice360 != nil { + unitPrice = *groupConfig.ImagePrice360 + } + } + } + + totalCost := unitPrice * float64(imageCount) + if rateMultiplier <= 0 { + rateMultiplier = 1.0 + } + actualCost := totalCost * rateMultiplier + + return &CostBreakdown{ + TotalCost: totalCost, + ActualCost: actualCost, + } +} + +// CalculateSoraVideoCost 计算 Sora 视频按次费用 +func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown { + unitPrice := 0.0 + if groupConfig != nil { + modelLower := strings.ToLower(model) + if strings.Contains(modelLower, "sora2pro-hd") { + if groupConfig.VideoPricePerRequestHD != nil { + unitPrice = *groupConfig.VideoPricePerRequestHD + } + } + if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil { + unitPrice = *groupConfig.VideoPricePerRequest + } + } + + totalCost := unitPrice + if rateMultiplier <= 0 { + rateMultiplier = 1.0 + } + actualCost := totalCost * rateMultiplier + + return &CostBreakdown{ + TotalCost: totalCost, + ActualCost: actualCost, + } +} + // getImageUnitPrice 获取图片单价 func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 { // 优先使用分组配置的价格 diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 9565da29..f0933ae3 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -184,6 +184,10 @@ type ForwardResult struct { // 图片生成计费字段(仅 gemini-3-pro-image 使用) ImageCount int // 生成的图片数量 ImageSize string // 图片尺寸 "1K", "2K", "4K" + + // Sora 媒体字段 + MediaType string // image / video / prompt + MediaURL string // 生成后的媒体地址(可选) } // UpstreamFailoverError indicates an upstream error that should trigger account failover. @@ -3461,7 +3465,22 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu var cost *CostBreakdown // 根据请求类型选择计费方式 - if result.ImageCount > 0 { + if result.MediaType == "image" || result.MediaType == "video" || result.MediaType == "prompt" { + var soraConfig *SoraPriceConfig + if apiKey.Group != nil { + soraConfig = &SoraPriceConfig{ + ImagePrice360: apiKey.Group.SoraImagePrice360, + ImagePrice540: apiKey.Group.SoraImagePrice540, + VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, + VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, + } + } + if result.MediaType == "image" { + cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) + } else { + cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier) + } + } else if result.ImageCount > 0 { // 图片生成计费 var groupConfig *ImagePriceConfig if apiKey.Group != nil { @@ -3501,6 +3520,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu if result.ImageSize != "" { imageSize = &result.ImageSize } + var mediaType *string + if strings.TrimSpace(result.MediaType) != "" { + mediaType = &result.MediaType + } accountRateMultiplier := account.BillingRateMultiplier() usageLog := &UsageLog{ UserID: user.ID, @@ -3526,6 +3549,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: imageSize, + MediaType: mediaType, CreatedAt: time.Now(), } diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index d6d1269b..bc97e062 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -26,6 +26,12 @@ type Group struct { ImagePrice2K *float64 ImagePrice4K *float64 + // Sora 按次计费配置(阶段 1) + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 + SoraVideoPricePerRequestHD *float64 + // Claude Code 客户端限制 ClaudeCodeOnly bool FallbackGroupID *int64 @@ -83,6 +89,18 @@ func (g *Group) GetImagePrice(imageSize string) *float64 { } } +// GetSoraImagePrice 根据 Sora 图片尺寸返回价格(360/540) +func (g *Group) GetSoraImagePrice(imageSize string) *float64 { + switch imageSize { + case "360": + return g.SoraImagePrice360 + case "540": + return g.SoraImagePrice540 + default: + return g.SoraImagePrice360 + } +} + // IsGroupContextValid reports whether a group from context has the fields required for routing decisions. func IsGroupContextValid(group *Group) bool { if group == nil { diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index 87a7713b..026d9061 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -41,8 +41,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou if account == nil { return "", errors.New("account is nil") } - if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth { - return "", errors.New("not an openai oauth account") + if (account.Platform != PlatformOpenAI && account.Platform != PlatformSora) || account.Type != AccountTypeOAuth { + return "", errors.New("not an openai/sora oauth account") } cacheKey := OpenAITokenCacheKey(account) @@ -157,7 +157,7 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } } - accessToken := account.GetOpenAIAccessToken() + accessToken := account.GetCredential("access_token") if strings.TrimSpace(accessToken) == "" { return "", errors.New("access_token not found in credentials") } diff --git a/backend/internal/service/openai_token_provider_test.go b/backend/internal/service/openai_token_provider_test.go index c2e3dbb0..3c649a7e 100644 --- a/backend/internal/service/openai_token_provider_test.go +++ b/backend/internal/service/openai_token_provider_test.go @@ -375,7 +375,7 @@ func TestOpenAITokenProvider_WrongPlatform(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an openai oauth account") + require.Contains(t, err.Error(), "not an openai/sora oauth account") require.Empty(t, token) } @@ -389,7 +389,7 @@ func TestOpenAITokenProvider_WrongAccountType(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an openai oauth account") + require.Contains(t, err.Error(), "not an openai/sora oauth account") require.Empty(t, token) } diff --git a/backend/internal/service/sora2api_service.go b/backend/internal/service/sora2api_service.go new file mode 100644 index 00000000..d4bf9ba4 --- /dev/null +++ b/backend/internal/service/sora2api_service.go @@ -0,0 +1,355 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net/http" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +// Sora2APIModel represents a model entry returned by sora2api. +type Sora2APIModel struct { + ID string `json:"id"` + Object string `json:"object"` + OwnedBy string `json:"owned_by,omitempty"` + Description string `json:"description,omitempty"` +} + +// Sora2APIModelList represents /v1/models response. +type Sora2APIModelList struct { + Object string `json:"object"` + Data []Sora2APIModel `json:"data"` +} + +// Sora2APIImportTokenItem mirrors sora2api ImportTokenItem. +type Sora2APIImportTokenItem struct { + Email string `json:"email"` + AccessToken string `json:"access_token,omitempty"` + SessionToken string `json:"session_token,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + ClientID string `json:"client_id,omitempty"` + ProxyURL string `json:"proxy_url,omitempty"` + Remark string `json:"remark,omitempty"` + IsActive bool `json:"is_active"` + ImageEnabled bool `json:"image_enabled"` + VideoEnabled bool `json:"video_enabled"` + ImageConcurrency int `json:"image_concurrency"` + VideoConcurrency int `json:"video_concurrency"` +} + +// Sora2APIToken represents minimal fields for admin list. +type Sora2APIToken struct { + ID int64 `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Remark string `json:"remark"` +} + +// Sora2APIService provides access to sora2api endpoints. +type Sora2APIService struct { + cfg *config.Config + + baseURL string + apiKey string + adminUsername string + adminPassword string + adminTokenTTL time.Duration + adminTimeout time.Duration + tokenImportMode string + + client *http.Client + adminClient *http.Client + + adminToken string + adminTokenAt time.Time + adminMu sync.Mutex + + modelCache []Sora2APIModel + modelCacheAt time.Time + modelMu sync.RWMutex +} + +func NewSora2APIService(cfg *config.Config) *Sora2APIService { + if cfg == nil { + return &Sora2APIService{} + } + adminTTL := time.Duration(cfg.Sora2API.AdminTokenTTLSeconds) * time.Second + if adminTTL <= 0 { + adminTTL = 15 * time.Minute + } + adminTimeout := time.Duration(cfg.Sora2API.AdminTimeoutSeconds) * time.Second + if adminTimeout <= 0 { + adminTimeout = 10 * time.Second + } + return &Sora2APIService{ + cfg: cfg, + baseURL: strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/"), + apiKey: strings.TrimSpace(cfg.Sora2API.APIKey), + adminUsername: strings.TrimSpace(cfg.Sora2API.AdminUsername), + adminPassword: strings.TrimSpace(cfg.Sora2API.AdminPassword), + adminTokenTTL: adminTTL, + adminTimeout: adminTimeout, + tokenImportMode: strings.ToLower(strings.TrimSpace(cfg.Sora2API.TokenImportMode)), + client: &http.Client{}, + adminClient: &http.Client{Timeout: adminTimeout}, + } +} + +func (s *Sora2APIService) Enabled() bool { + return s != nil && s.baseURL != "" && s.apiKey != "" +} + +func (s *Sora2APIService) AdminEnabled() bool { + return s != nil && s.baseURL != "" && s.adminUsername != "" && s.adminPassword != "" +} + +func (s *Sora2APIService) buildURL(path string) string { + if s.baseURL == "" { + return path + } + if strings.HasPrefix(path, "/") { + return s.baseURL + path + } + return s.baseURL + "/" + path +} + +// BuildURL 返回完整的 sora2api URL(用于代理媒体) +func (s *Sora2APIService) BuildURL(path string) string { + return s.buildURL(path) +} + +func (s *Sora2APIService) NewAPIRequest(ctx context.Context, method string, path string, body []byte) (*http.Request, error) { + if !s.Enabled() { + return nil, errors.New("sora2api not configured") + } + req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+s.apiKey) + req.Header.Set("Content-Type", "application/json") + return req, nil +} + +func (s *Sora2APIService) ListModels(ctx context.Context) ([]Sora2APIModel, error) { + if !s.Enabled() { + return nil, errors.New("sora2api not configured") + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.buildURL("/v1/models"), nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+s.apiKey) + resp, err := s.client.Do(req) + if err != nil { + return s.cachedModelsOnError(err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return s.cachedModelsOnError(fmt.Errorf("sora2api models status: %d", resp.StatusCode)) + } + + var payload Sora2APIModelList + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return s.cachedModelsOnError(err) + } + models := payload.Data + if s.cfg != nil && s.cfg.Gateway.SoraModelFilters.HidePromptEnhance { + filtered := make([]Sora2APIModel, 0, len(models)) + for _, m := range models { + if strings.HasPrefix(strings.ToLower(m.ID), "prompt-enhance") { + continue + } + filtered = append(filtered, m) + } + models = filtered + } + + s.modelMu.Lock() + s.modelCache = models + s.modelCacheAt = time.Now() + s.modelMu.Unlock() + + return models, nil +} + +func (s *Sora2APIService) cachedModelsOnError(err error) ([]Sora2APIModel, error) { + s.modelMu.RLock() + cached := append([]Sora2APIModel(nil), s.modelCache...) + s.modelMu.RUnlock() + if len(cached) > 0 { + log.Printf("[Sora2API] 模型列表拉取失败,回退缓存: %v", err) + return cached, nil + } + return nil, err +} + +func (s *Sora2APIService) ImportTokens(ctx context.Context, items []Sora2APIImportTokenItem) error { + if !s.AdminEnabled() { + return errors.New("sora2api admin not configured") + } + mode := s.tokenImportMode + if mode == "" { + mode = "at" + } + payload := map[string]any{ + "tokens": items, + "mode": mode, + } + _, err := s.doAdminRequest(ctx, http.MethodPost, "/api/tokens/import", payload, nil) + return err +} + +func (s *Sora2APIService) ListTokens(ctx context.Context) ([]Sora2APIToken, error) { + if !s.AdminEnabled() { + return nil, errors.New("sora2api admin not configured") + } + var tokens []Sora2APIToken + _, err := s.doAdminRequest(ctx, http.MethodGet, "/api/tokens", nil, &tokens) + return tokens, err +} + +func (s *Sora2APIService) DisableToken(ctx context.Context, tokenID int64) error { + if !s.AdminEnabled() { + return errors.New("sora2api admin not configured") + } + path := fmt.Sprintf("/api/tokens/%d/disable", tokenID) + _, err := s.doAdminRequest(ctx, http.MethodPost, path, nil, nil) + return err +} + +func (s *Sora2APIService) DeleteToken(ctx context.Context, tokenID int64) error { + if !s.AdminEnabled() { + return errors.New("sora2api admin not configured") + } + path := fmt.Sprintf("/api/tokens/%d", tokenID) + _, err := s.doAdminRequest(ctx, http.MethodDelete, path, nil, nil) + return err +} + +func (s *Sora2APIService) doAdminRequest(ctx context.Context, method string, path string, body any, out any) (*http.Response, error) { + if !s.AdminEnabled() { + return nil, errors.New("sora2api admin not configured") + } + token, err := s.getAdminToken(ctx) + if err != nil { + return nil, err + } + resp, err := s.doAdminRequestWithToken(ctx, method, path, token, body, out) + if err == nil && resp != nil && resp.StatusCode != http.StatusUnauthorized { + return resp, nil + } + if resp != nil && resp.StatusCode == http.StatusUnauthorized { + s.invalidateAdminToken() + token, err = s.getAdminToken(ctx) + if err != nil { + return resp, err + } + return s.doAdminRequestWithToken(ctx, method, path, token, body, out) + } + return resp, err +} + +func (s *Sora2APIService) doAdminRequestWithToken(ctx context.Context, method string, path string, token string, body any, out any) (*http.Response, error) { + var reader *bytes.Reader + if body != nil { + buf, err := json.Marshal(body) + if err != nil { + return nil, err + } + reader = bytes.NewReader(buf) + } else { + reader = bytes.NewReader(nil) + } + req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), reader) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token) + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + resp, err := s.adminClient.Do(req) + if err != nil { + return resp, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return resp, fmt.Errorf("sora2api admin status: %d", resp.StatusCode) + } + if out != nil { + if err := json.NewDecoder(resp.Body).Decode(out); err != nil { + return resp, err + } + } + return resp, nil +} + +func (s *Sora2APIService) getAdminToken(ctx context.Context) (string, error) { + s.adminMu.Lock() + defer s.adminMu.Unlock() + + if s.adminToken != "" && time.Since(s.adminTokenAt) < s.adminTokenTTL { + return s.adminToken, nil + } + + if !s.AdminEnabled() { + return "", errors.New("sora2api admin not configured") + } + + payload := map[string]string{ + "username": s.adminUsername, + "password": s.adminPassword, + } + buf, err := json.Marshal(payload) + if err != nil { + return "", err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.buildURL("/api/login"), bytes.NewReader(buf)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + resp, err := s.adminClient.Do(req) + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("sora2api login failed: %d", resp.StatusCode) + } + var result struct { + Success bool `json:"success"` + Token string `json:"token"` + Message string `json:"message"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", err + } + if !result.Success || result.Token == "" { + if result.Message == "" { + result.Message = "sora2api login failed" + } + return "", errors.New(result.Message) + } + s.adminToken = result.Token + s.adminTokenAt = time.Now() + return result.Token, nil +} + +func (s *Sora2APIService) invalidateAdminToken() { + s.adminMu.Lock() + defer s.adminMu.Unlock() + s.adminToken = "" + s.adminTokenAt = time.Time{} +} diff --git a/backend/internal/service/sora2api_sync_service.go b/backend/internal/service/sora2api_sync_service.go new file mode 100644 index 00000000..33978432 --- /dev/null +++ b/backend/internal/service/sora2api_sync_service.go @@ -0,0 +1,255 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net/http" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// Sora2APISyncService 用于同步 Sora 账号到 sora2api token 池 +type Sora2APISyncService struct { + sora2api *Sora2APIService + accountRepo AccountRepository + httpClient *http.Client +} + +func NewSora2APISyncService(sora2api *Sora2APIService, accountRepo AccountRepository) *Sora2APISyncService { + return &Sora2APISyncService{ + sora2api: sora2api, + accountRepo: accountRepo, + httpClient: &http.Client{Timeout: 10 * time.Second}, + } +} + +func (s *Sora2APISyncService) Enabled() bool { + return s != nil && s.sora2api != nil && s.sora2api.AdminEnabled() +} + +// SyncAccount 将 Sora 账号同步到 sora2api(导入或更新) +func (s *Sora2APISyncService) SyncAccount(ctx context.Context, account *Account) error { + if !s.Enabled() { + return nil + } + if account == nil || account.Platform != PlatformSora { + return nil + } + + accessToken := strings.TrimSpace(account.GetCredential("access_token")) + if accessToken == "" { + return errors.New("sora 账号缺少 access_token") + } + + email, updated := s.resolveAccountEmail(ctx, account) + if email == "" { + return errors.New("无法解析 Sora 账号邮箱") + } + if updated && s.accountRepo != nil { + if err := s.accountRepo.Update(ctx, account); err != nil { + log.Printf("[SoraSync] 更新账号邮箱失败: account_id=%d err=%v", account.ID, err) + } + } + + item := Sora2APIImportTokenItem{ + Email: email, + AccessToken: accessToken, + SessionToken: strings.TrimSpace(account.GetCredential("session_token")), + RefreshToken: strings.TrimSpace(account.GetCredential("refresh_token")), + ClientID: strings.TrimSpace(account.GetCredential("client_id")), + Remark: account.Name, + IsActive: account.IsActive() && account.Schedulable, + ImageEnabled: true, + VideoEnabled: true, + ImageConcurrency: normalizeSoraConcurrency(account.Concurrency), + VideoConcurrency: normalizeSoraConcurrency(account.Concurrency), + } + + if err := s.sora2api.ImportTokens(ctx, []Sora2APIImportTokenItem{item}); err != nil { + return err + } + return nil +} + +// DisableAccount 禁用 sora2api 中的 token +func (s *Sora2APISyncService) DisableAccount(ctx context.Context, account *Account) error { + if !s.Enabled() { + return nil + } + if account == nil || account.Platform != PlatformSora { + return nil + } + tokenID, err := s.resolveTokenID(ctx, account) + if err != nil { + return err + } + return s.sora2api.DisableToken(ctx, tokenID) +} + +// DeleteAccount 删除 sora2api 中的 token +func (s *Sora2APISyncService) DeleteAccount(ctx context.Context, account *Account) error { + if !s.Enabled() { + return nil + } + if account == nil || account.Platform != PlatformSora { + return nil + } + tokenID, err := s.resolveTokenID(ctx, account) + if err != nil { + return err + } + return s.sora2api.DeleteToken(ctx, tokenID) +} + +func normalizeSoraConcurrency(value int) int { + if value <= 0 { + return -1 + } + return value +} + +func (s *Sora2APISyncService) resolveAccountEmail(ctx context.Context, account *Account) (string, bool) { + if account == nil { + return "", false + } + if email := strings.TrimSpace(account.GetCredential("email")); email != "" { + return email, false + } + if email := strings.TrimSpace(account.GetExtraString("email")); email != "" { + if account.Credentials == nil { + account.Credentials = map[string]any{} + } + account.Credentials["email"] = email + return email, true + } + if email := strings.TrimSpace(account.GetExtraString("sora_email")); email != "" { + if account.Credentials == nil { + account.Credentials = map[string]any{} + } + account.Credentials["email"] = email + return email, true + } + + accessToken := strings.TrimSpace(account.GetCredential("access_token")) + if accessToken != "" { + if email := extractEmailFromAccessToken(accessToken); email != "" { + if account.Credentials == nil { + account.Credentials = map[string]any{} + } + account.Credentials["email"] = email + return email, true + } + if email := s.fetchEmailFromSora(ctx, accessToken); email != "" { + if account.Credentials == nil { + account.Credentials = map[string]any{} + } + account.Credentials["email"] = email + return email, true + } + } + + return "", false +} + +func (s *Sora2APISyncService) resolveTokenID(ctx context.Context, account *Account) (int64, error) { + if account == nil { + return 0, errors.New("account is nil") + } + + if account.Extra != nil { + if v, ok := account.Extra["sora2api_token_id"]; ok { + if id, ok := v.(float64); ok && id > 0 { + return int64(id), nil + } + if id, ok := v.(int64); ok && id > 0 { + return id, nil + } + if id, ok := v.(int); ok && id > 0 { + return int64(id), nil + } + } + } + + email := strings.TrimSpace(account.GetCredential("email")) + if email == "" { + email, _ = s.resolveAccountEmail(ctx, account) + } + if email == "" { + return 0, errors.New("sora2api token email missing") + } + + tokenID, err := s.findTokenIDByEmail(ctx, email) + if err != nil { + return 0, err + } + return tokenID, nil +} + +func (s *Sora2APISyncService) findTokenIDByEmail(ctx context.Context, email string) (int64, error) { + if !s.Enabled() { + return 0, errors.New("sora2api admin not configured") + } + tokens, err := s.sora2api.ListTokens(ctx) + if err != nil { + return 0, err + } + for _, token := range tokens { + if strings.EqualFold(strings.TrimSpace(token.Email), strings.TrimSpace(email)) { + return token.ID, nil + } + } + return 0, fmt.Errorf("sora2api token not found for email: %s", email) +} + +func extractEmailFromAccessToken(accessToken string) string { + parser := jwt.NewParser(jwt.WithoutClaimsValidation()) + claims := jwt.MapClaims{} + _, _, err := parser.ParseUnverified(accessToken, claims) + if err != nil { + return "" + } + if email, ok := claims["email"].(string); ok && strings.TrimSpace(email) != "" { + return email + } + if profile, ok := claims["https://api.openai.com/profile"].(map[string]any); ok { + if email, ok := profile["email"].(string); ok && strings.TrimSpace(email) != "" { + return email + } + } + return "" +} + +func (s *Sora2APISyncService) fetchEmailFromSora(ctx context.Context, accessToken string) string { + if s.httpClient == nil { + return "" + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, soraMeAPIURL, nil) + if err != nil { + return "" + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + req.Header.Set("Accept", "application/json") + + resp, err := s.httpClient.Do(req) + if err != nil { + return "" + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return "" + } + var payload map[string]any + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return "" + } + if email, ok := payload["email"].(string); ok && strings.TrimSpace(email) != "" { + return email + } + return "" +} diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go new file mode 100644 index 00000000..82f4eaaa --- /dev/null +++ b/backend/internal/service/sora_gateway_service.go @@ -0,0 +1,660 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" +) + +var soraSSEDataRe = regexp.MustCompile(`^data:\s*`) +var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`) +var soraVideoHTMLRe = regexp.MustCompile(`(?i)]+src=['"]([^'"]+)['"]`) + +var soraImageSizeMap = map[string]string{ + "gpt-image": "360", + "gpt-image-landscape": "540", + "gpt-image-portrait": "540", +} + +type soraStreamingResult struct { + content string + mediaType string + mediaURLs []string + imageCount int + imageSize string + firstTokenMs *int +} + +// SoraGatewayService handles forwarding requests to sora2api. +type SoraGatewayService struct { + sora2api *Sora2APIService + httpUpstream HTTPUpstream + rateLimitService *RateLimitService + cfg *config.Config +} + +func NewSoraGatewayService( + sora2api *Sora2APIService, + httpUpstream HTTPUpstream, + rateLimitService *RateLimitService, + cfg *config.Config, +) *SoraGatewayService { + return &SoraGatewayService{ + sora2api: sora2api, + httpUpstream: httpUpstream, + rateLimitService: rateLimitService, + cfg: cfg, + } +} + +func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) { + startTime := time.Now() + + if s.sora2api == nil || !s.sora2api.Enabled() { + if c != nil { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "sora2api 未配置", + }, + }) + } + return nil, errors.New("sora2api not configured") + } + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + return nil, fmt.Errorf("parse request: %w", err) + } + reqModel, _ := reqBody["model"].(string) + reqStream, _ := reqBody["stream"].(bool) + + mappedModel := account.GetMappedModel(reqModel) + if mappedModel != reqModel && mappedModel != "" { + reqBody["model"] = mappedModel + if updated, err := json.Marshal(reqBody); err == nil { + body = updated + } + } + + reqCtx, cancel := s.withSoraTimeout(ctx, reqStream) + if cancel != nil { + defer cancel() + } + + upstreamReq, err := s.sora2api.NewAPIRequest(reqCtx, http.MethodPost, "/v1/chat/completions", body) + if err != nil { + return nil, err + } + if c != nil { + if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" { + upstreamReq.Header.Set("User-Agent", ua) + } + } + if reqStream { + upstreamReq.Header.Set("Accept", "text/event-stream") + } + + if c != nil { + c.Set(OpsUpstreamRequestBodyKey, string(body)) + } + + proxyURL := "" + if account != nil && account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + var resp *http.Response + if s.httpUpstream != nil { + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + } else { + resp, err = http.DefaultClient.Do(upstreamReq) + } + if err != nil { + s.setUpstreamRequestError(c, account, err) + return nil, fmt.Errorf("upstream request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + }) + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } + return s.handleErrorResponse(ctx, resp, c, account, reqModel) + } + + streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, reqModel, clientStream) + if err != nil { + return nil, err + } + + result := &ForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Model: reqModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: streamResult.firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: streamResult.mediaType, + MediaURL: firstMediaURL(streamResult.mediaURLs), + ImageCount: streamResult.imageCount, + ImageSize: streamResult.imageSize, + } + + return result, nil +} + +func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { + if s == nil || s.cfg == nil { + return ctx, nil + } + timeoutSeconds := s.cfg.Gateway.SoraRequestTimeoutSeconds + if stream { + timeoutSeconds = s.cfg.Gateway.SoraStreamTimeoutSeconds + } + if timeoutSeconds <= 0 { + return ctx, nil + } + return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second) +} + +func (s *SoraGatewayService) setUpstreamRequestError(c *gin.Context, account *Account, err error) { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + if c != nil { + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + } +} + +func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool { + switch statusCode { + case 401, 402, 403, 429, 529: + return true + default: + return statusCode >= 500 + } +} + +func (s *SoraGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { + if s.rateLimitService == nil || account == nil || resp == nil { + return + } + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) +} + +func (s *SoraGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, reqModel string) (*ForwardResult, error) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if msg := soraProErrorMessage(reqModel, upstreamMsg); msg != "" { + upstreamMsg = msg + } + + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + if c != nil { + responsePayload := s.buildErrorPayload(respBody, upstreamMsg) + c.JSON(resp.StatusCode, responsePayload) + } + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) +} + +func (s *SoraGatewayService) buildErrorPayload(respBody []byte, overrideMessage string) map[string]any { + if len(respBody) > 0 { + var payload map[string]any + if err := json.Unmarshal(respBody, &payload); err == nil { + if errObj, ok := payload["error"].(map[string]any); ok { + if overrideMessage != "" { + errObj["message"] = overrideMessage + } + payload["error"] = errObj + return payload + } + } + } + return map[string]any{ + "error": map[string]any{ + "type": "upstream_error", + "message": overrideMessage, + }, + } +} + +func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel string, clientStream bool) (*soraStreamingResult, error) { + if resp == nil { + return nil, errors.New("empty response") + } + + if clientStream { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + if v := resp.Header.Get("x-request-id"); v != "" { + c.Header("x-request-id", v) + } + } + + w := c.Writer + flusher, _ := w.(http.Flusher) + + contentBuilder := strings.Builder{} + var firstTokenMs *int + var upstreamError error + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 64*1024), maxLineSize) + + sendLine := func(line string) error { + if !clientStream { + return nil + } + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + return err + } + if flusher != nil { + flusher.Flush() + } + return nil + } + + for scanner.Scan() { + line := scanner.Text() + if soraSSEDataRe.MatchString(line) { + data := soraSSEDataRe.ReplaceAllString(line, "") + if data == "[DONE]" { + if err := sendLine("data: [DONE]"); err != nil { + return nil, err + } + break + } + updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel) + if errEvent != nil && upstreamError == nil { + upstreamError = errEvent + } + if contentDelta != "" { + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + contentBuilder.WriteString(contentDelta) + } + if err := sendLine(updatedLine); err != nil { + return nil, err + } + continue + } + if err := sendLine(line); err != nil { + return nil, err + } + } + + if err := scanner.Err(); err != nil { + if errors.Is(err, bufio.ErrTooLong) { + if clientStream { + _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"response_too_large\"}\n\n") + if flusher != nil { + flusher.Flush() + } + } + return nil, err + } + if ctx.Err() == context.DeadlineExceeded && s.rateLimitService != nil && account != nil { + s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) + } + if clientStream { + _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"stream_read_error\"}\n\n") + if flusher != nil { + flusher.Flush() + } + } + return nil, err + } + + content := contentBuilder.String() + mediaType, mediaURLs := s.extractSoraMedia(content) + if mediaType == "" && isSoraPromptEnhanceModel(originalModel) { + mediaType = "prompt" + } + imageSize := "" + imageCount := 0 + if mediaType == "image" { + imageSize = soraImageSizeFromModel(originalModel) + imageCount = len(mediaURLs) + } + + if upstreamError != nil && !clientStream { + if c != nil { + c.JSON(http.StatusBadGateway, map[string]any{ + "error": map[string]any{ + "type": "upstream_error", + "message": upstreamError.Error(), + }, + }) + } + return nil, upstreamError + } + + if !clientStream { + response := buildSoraNonStreamResponse(content, originalModel) + if len(mediaURLs) > 0 { + response["media_url"] = mediaURLs[0] + if len(mediaURLs) > 1 { + response["media_urls"] = mediaURLs + } + } + c.JSON(http.StatusOK, response) + } + + return &soraStreamingResult{ + content: content, + mediaType: mediaType, + mediaURLs: mediaURLs, + imageCount: imageCount, + imageSize: imageSize, + firstTokenMs: firstTokenMs, + }, nil +} + +func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string) (string, string, error) { + if strings.TrimSpace(data) == "" { + return "data: ", "", nil + } + + var payload map[string]any + if err := json.Unmarshal([]byte(data), &payload); err != nil { + return "data: " + data, "", nil + } + + if errObj, ok := payload["error"].(map[string]any); ok { + if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" { + return "data: " + data, "", errors.New(msg) + } + } + + if model, ok := payload["model"].(string); ok && model != "" && originalModel != "" { + payload["model"] = originalModel + } + + contentDelta, updated := extractSoraContent(payload) + if updated { + rewritten := s.rewriteSoraContent(contentDelta) + if rewritten != contentDelta { + applySoraContent(payload, rewritten) + contentDelta = rewritten + } + } + + updatedData, err := json.Marshal(payload) + if err != nil { + return "data: " + data, contentDelta, nil + } + return "data: " + string(updatedData), contentDelta, nil +} + +func extractSoraContent(payload map[string]any) (string, bool) { + choices, ok := payload["choices"].([]any) + if !ok || len(choices) == 0 { + return "", false + } + choice, ok := choices[0].(map[string]any) + if !ok { + return "", false + } + if delta, ok := choice["delta"].(map[string]any); ok { + if content, ok := delta["content"].(string); ok { + return content, true + } + } + if message, ok := choice["message"].(map[string]any); ok { + if content, ok := message["content"].(string); ok { + return content, true + } + } + return "", false +} + +func applySoraContent(payload map[string]any, content string) { + choices, ok := payload["choices"].([]any) + if !ok || len(choices) == 0 { + return + } + choice, ok := choices[0].(map[string]any) + if !ok { + return + } + if delta, ok := choice["delta"].(map[string]any); ok { + delta["content"] = content + choice["delta"] = delta + return + } + if message, ok := choice["message"].(map[string]any); ok { + message["content"] = content + choice["message"] = message + } +} + +func (s *SoraGatewayService) rewriteSoraContent(content string) string { + if content == "" { + return content + } + content = soraImageMarkdownRe.ReplaceAllStringFunc(content, func(match string) string { + sub := soraImageMarkdownRe.FindStringSubmatch(match) + if len(sub) < 2 { + return match + } + rewritten := s.rewriteSoraURL(sub[1]) + if rewritten == sub[1] { + return match + } + return strings.Replace(match, sub[1], rewritten, 1) + }) + content = soraVideoHTMLRe.ReplaceAllStringFunc(content, func(match string) string { + sub := soraVideoHTMLRe.FindStringSubmatch(match) + if len(sub) < 2 { + return match + } + rewritten := s.rewriteSoraURL(sub[1]) + if rewritten == sub[1] { + return match + } + return strings.Replace(match, sub[1], rewritten, 1) + }) + return content +} + +func (s *SoraGatewayService) rewriteSoraURL(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return raw + } + parsed, err := url.Parse(raw) + if err != nil { + return raw + } + path := parsed.Path + if !strings.HasPrefix(path, "/tmp/") && !strings.HasPrefix(path, "/static/") { + return raw + } + return s.buildSoraMediaURL(path, parsed.RawQuery) +} + +func (s *SoraGatewayService) extractSoraMedia(content string) (string, []string) { + if content == "" { + return "", nil + } + if match := soraVideoHTMLRe.FindStringSubmatch(content); len(match) > 1 { + return "video", []string{match[1]} + } + imageMatches := soraImageMarkdownRe.FindAllStringSubmatch(content, -1) + if len(imageMatches) == 0 { + return "", nil + } + urls := make([]string, 0, len(imageMatches)) + for _, match := range imageMatches { + if len(match) > 1 { + urls = append(urls, match[1]) + } + } + return "image", urls +} + +func buildSoraNonStreamResponse(content, model string) map[string]any { + return map[string]any{ + "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), + "object": "chat.completion", + "created": time.Now().Unix(), + "model": model, + "choices": []any{ + map[string]any{ + "index": 0, + "message": map[string]any{ + "role": "assistant", + "content": content, + }, + "finish_reason": "stop", + }, + }, + } +} + +func soraImageSizeFromModel(model string) string { + modelLower := strings.ToLower(model) + if size, ok := soraImageSizeMap[modelLower]; ok { + return size + } + if strings.Contains(modelLower, "landscape") || strings.Contains(modelLower, "portrait") { + return "540" + } + return "360" +} + +func isSoraPromptEnhanceModel(model string) bool { + return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "prompt-enhance") +} + +func soraProErrorMessage(model, upstreamMsg string) string { + modelLower := strings.ToLower(model) + if strings.Contains(modelLower, "sora2pro-hd") { + return "当前账号无法使用 Sora Pro-HD 模型,请更换模型或账号" + } + if strings.Contains(modelLower, "sora2pro") { + return "当前账号无法使用 Sora Pro 模型,请更换模型或账号" + } + return "" +} + +func firstMediaURL(urls []string) string { + if len(urls) == 0 { + return "" + } + return urls[0] +} + +func (s *SoraGatewayService) buildSoraMediaURL(path string, rawQuery string) string { + if path == "" { + return path + } + prefix := "/sora/media" + values := url.Values{} + if rawQuery != "" { + if parsed, err := url.ParseQuery(rawQuery); err == nil { + values = parsed + } + } + + signKey := "" + ttlSeconds := 0 + if s != nil && s.cfg != nil { + signKey = strings.TrimSpace(s.cfg.Gateway.SoraMediaSigningKey) + ttlSeconds = s.cfg.Gateway.SoraMediaSignedURLTTLSeconds + } + values.Del("sig") + values.Del("expires") + signingQuery := values.Encode() + if signKey != "" && ttlSeconds > 0 { + expires := time.Now().Add(time.Duration(ttlSeconds) * time.Second).Unix() + signature := SignSoraMediaURL(path, signingQuery, expires, signKey) + if signature != "" { + values.Set("expires", strconv.FormatInt(expires, 10)) + values.Set("sig", signature) + prefix = "/sora/media-signed" + } + } + + encoded := values.Encode() + if encoded == "" { + return prefix + path + } + return prefix + path + "?" + encoded +} diff --git a/backend/internal/service/sora_media_sign.go b/backend/internal/service/sora_media_sign.go new file mode 100644 index 00000000..5d4a8d88 --- /dev/null +++ b/backend/internal/service/sora_media_sign.go @@ -0,0 +1,42 @@ +package service + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "strconv" + "strings" +) + +// SignSoraMediaURL 生成 Sora 媒体临时签名 +func SignSoraMediaURL(path string, query string, expires int64, key string) string { + key = strings.TrimSpace(key) + if key == "" { + return "" + } + mac := hmac.New(sha256.New, []byte(key)) + mac.Write([]byte(buildSoraMediaSignPayload(path, query))) + mac.Write([]byte("|")) + mac.Write([]byte(strconv.FormatInt(expires, 10))) + return hex.EncodeToString(mac.Sum(nil)) +} + +// VerifySoraMediaURL 校验 Sora 媒体签名 +func VerifySoraMediaURL(path string, query string, expires int64, signature string, key string) bool { + signature = strings.TrimSpace(signature) + if signature == "" { + return false + } + expected := SignSoraMediaURL(path, query, expires, key) + if expected == "" { + return false + } + return hmac.Equal([]byte(signature), []byte(expected)) +} + +func buildSoraMediaSignPayload(path string, query string) string { + if strings.TrimSpace(query) == "" { + return path + } + return path + "?" + query +} diff --git a/backend/internal/service/sora_media_sign_test.go b/backend/internal/service/sora_media_sign_test.go new file mode 100644 index 00000000..2bbba987 --- /dev/null +++ b/backend/internal/service/sora_media_sign_test.go @@ -0,0 +1,34 @@ +package service + +import "testing" + +func TestSoraMediaSignVerify(t *testing.T) { + key := "test-key" + path := "/tmp/abc.png" + query := "a=1&b=2" + expires := int64(1700000000) + + signature := SignSoraMediaURL(path, query, expires, key) + if signature == "" { + t.Fatal("签名为空") + } + if !VerifySoraMediaURL(path, query, expires, signature, key) { + t.Fatal("签名校验失败") + } + if VerifySoraMediaURL(path, "a=1", expires, signature, key) { + t.Fatal("签名参数不同仍然通过") + } + if VerifySoraMediaURL(path, query, expires+1, signature, key) { + t.Fatal("签名过期校验未失败") + } +} + +func TestSoraMediaSignWithEmptyKey(t *testing.T) { + signature := SignSoraMediaURL("/tmp/a.png", "a=1", 1, "") + if signature != "" { + t.Fatalf("空密钥不应生成签名") + } + if VerifySoraMediaURL("/tmp/a.png", "a=1", 1, "sig", "") { + t.Fatalf("空密钥不应通过校验") + } +} diff --git a/backend/internal/service/token_cache_invalidator.go b/backend/internal/service/token_cache_invalidator.go index 74c9edc3..5c7ae8e9 100644 --- a/backend/internal/service/token_cache_invalidator.go +++ b/backend/internal/service/token_cache_invalidator.go @@ -42,7 +42,7 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac // Antigravity 同样可能有两种缓存键 keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account)) keysToDelete = append(keysToDelete, "ag:"+accountIDKey) - case PlatformOpenAI: + case PlatformOpenAI, PlatformSora: keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account)) case PlatformAnthropic: keysToDelete = append(keysToDelete, ClaudeTokenCacheKey(account)) diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 797ab721..167d2b54 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -19,6 +19,7 @@ type TokenRefreshService struct { refreshers []TokenRefresher cfg *config.TokenRefreshConfig cacheInvalidator TokenCacheInvalidator + soraSyncService *Sora2APISyncService stopCh chan struct{} wg sync.WaitGroup @@ -65,6 +66,17 @@ func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) { } } +// SetSoraSyncService 设置 Sora2API 同步服务 +// 需要在 Start() 之前调用 +func (s *TokenRefreshService) SetSoraSyncService(svc *Sora2APISyncService) { + s.soraSyncService = svc + for _, refresher := range s.refreshers { + if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok { + openaiRefresher.SetSoraSyncService(svc) + } + } +} + // Start 启动后台刷新服务 func (s *TokenRefreshService) Start() { if !s.cfg.Enabled { diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 807524fd..9699092d 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -86,6 +86,7 @@ type OpenAITokenRefresher struct { openaiOAuthService *OpenAIOAuthService accountRepo AccountRepository soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 + soraSyncService *Sora2APISyncService // Sora2API 同步服务 } // NewOpenAITokenRefresher 创建 OpenAI token刷新器 @@ -103,17 +104,22 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) { r.soraAccountRepo = repo } +// SetSoraSyncService 设置 Sora2API 同步服务 +func (r *OpenAITokenRefresher) SetSoraSyncService(svc *Sora2APISyncService) { + r.soraSyncService = svc +} + // CanRefresh 检查是否能处理此账号 // 只处理 openai 平台的 oauth 类型账号 func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool { - return account.Platform == PlatformOpenAI && + return (account.Platform == PlatformOpenAI || account.Platform == PlatformSora) && account.Type == AccountTypeOAuth } // NeedsRefresh 检查token是否需要刷新 // 基于 expires_at 字段判断是否在刷新窗口内 func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool { - expiresAt := account.GetOpenAITokenExpiresAt() + expiresAt := account.GetCredentialAsTime("expires_at") if expiresAt == nil { return false } @@ -145,6 +151,17 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials) } + // 如果是 Sora 平台账号,同步到 sora2api(不阻塞主流程) + if account.Platform == PlatformSora && r.soraSyncService != nil { + syncAccount := *account + syncAccount.Credentials = newCredentials + go func() { + if err := r.soraSyncService.SyncAccount(context.Background(), &syncAccount); err != nil { + log.Printf("[TokenSync] 同步 Sora2API 失败: account_id=%d err=%v", syncAccount.ID, err) + } + }() + } + return newCredentials, nil } @@ -201,6 +218,13 @@ func (r *OpenAITokenRefresher) syncLinkedSoraAccounts(ctx context.Context, opena } } + // 2.3 同步到 sora2api(如果配置) + if r.soraSyncService != nil { + if err := r.soraSyncService.SyncAccount(ctx, &soraAccount); err != nil { + log.Printf("[TokenSync] 同步 sora2api 失败: account_id=%d err=%v", soraAccount.ID, err) + } + } + log.Printf("[TokenSync] 成功同步 Sora 账号 token: sora_account_id=%d openai_account_id=%d dual_table=%v", soraAccount.ID, openaiAccountID, r.soraAccountRepo != nil) } diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index 3b0e934f..4be35501 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -46,6 +46,7 @@ type UsageLog struct { // 图片生成字段 ImageCount int ImageSize *string + MediaType *string CreatedAt time.Time diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 73d23025..689fa5d7 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -40,6 +40,7 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService { func ProvideTokenRefreshService( accountRepo AccountRepository, soraAccountRepo SoraAccountRepository, // Sora 扩展表仓储,用于双表同步 + soraSyncService *Sora2APISyncService, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, geminiOAuthService *GeminiOAuthService, @@ -50,6 +51,9 @@ func ProvideTokenRefreshService( svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg) // 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表 svc.SetSoraAccountRepo(soraAccountRepo) + if soraSyncService != nil { + svc.SetSoraSyncService(soraSyncService) + } svc.Start() return svc } @@ -224,6 +228,7 @@ var ProviderSet = wire.NewSet( NewBillingCacheService, NewAdminService, NewGatewayService, + NewSoraGatewayService, NewOpenAIGatewayService, NewOAuthService, NewOpenAIOAuthService, @@ -237,6 +242,8 @@ var ProviderSet = wire.NewSet( NewAntigravityTokenProvider, NewOpenAITokenProvider, NewClaudeTokenProvider, + NewSora2APIService, + NewSora2APISyncService, NewAntigravityGatewayService, ProvideRateLimitService, NewAccountUsageService, diff --git a/backend/migrations/047_add_sora_pricing_and_media_type.sql b/backend/migrations/047_add_sora_pricing_and_media_type.sql new file mode 100644 index 00000000..d70e37c5 --- /dev/null +++ b/backend/migrations/047_add_sora_pricing_and_media_type.sql @@ -0,0 +1,11 @@ +-- Migration: 047_add_sora_pricing_and_media_type +-- 新增 Sora 按次计费字段与 usage_logs.media_type + +ALTER TABLE groups + ADD COLUMN IF NOT EXISTS sora_image_price_360 decimal(20,8), + ADD COLUMN IF NOT EXISTS sora_image_price_540 decimal(20,8), + ADD COLUMN IF NOT EXISTS sora_video_price_per_request decimal(20,8), + ADD COLUMN IF NOT EXISTS sora_video_price_per_request_hd decimal(20,8); + +ALTER TABLE usage_logs + ADD COLUMN IF NOT EXISTS media_type VARCHAR(16); diff --git a/deploy/Caddyfile b/deploy/Caddyfile index d4144057..fce88654 100644 --- a/deploy/Caddyfile +++ b/deploy/Caddyfile @@ -1,39 +1,6 @@ -# ============================================================================= -# Sub2API Caddy Reverse Proxy Configuration (宿主机部署) -# ============================================================================= -# 使用方法: -# 1. 安装 Caddy: https://caddyserver.com/docs/install -# 2. 修改下方 example.com 为你的域名 -# 3. 确保域名 DNS 已指向服务器 -# 4. 复制配置: sudo cp Caddyfile /etc/caddy/Caddyfile -# 5. 重载配置: sudo systemctl reload caddy -# -# Caddy 会自动申请和续期 Let's Encrypt SSL 证书 -# ============================================================================= - -# 全局配置 -{ - # Let's Encrypt 邮箱通知 - email admin@example.com - - # 服务器配置 - servers { - # 启用 HTTP/2 和 HTTP/3 - protocols h1 h2 h3 - - # 超时配置 - timeouts { - read_body 30s - read_header 10s - write 300s - idle 300s - } - } -} - # 修改为你的域名 -example.com { - # ========================================================================= +api.sub2api.com { + # ========================================================================= # 静态资源长期缓存(高优先级,放在最前面) # 带 hash 的文件可以永久缓存,浏览器和 CDN 都会缓存 # ========================================================================= @@ -87,17 +54,13 @@ example.com { # 连接池优化 transport http { - versions h2c h1 keepalive 120s keepalive_idle_conns 256 read_buffer 16KB write_buffer 16KB compression off } - - # SSE/流式传输优化:禁用响应缓冲,立即刷新数据给客户端 - flush_interval -1 - + # 故障转移 fail_duration 30s max_fails 3 @@ -112,10 +75,6 @@ example.com { gzip 6 minimum_length 256 match { - # SSE 请求通常会带 Accept: text/event-stream,需排除压缩 - not header Accept text/event-stream* - # 排除已知 SSE 路径(即便 Accept 缺失) - not path /v1/messages /v1/responses /responses /antigravity/v1/messages /v1beta/models/* /antigravity/v1beta/models/* header Content-Type text/* header Content-Type application/json* header Content-Type application/javascript* @@ -199,7 +158,3 @@ example.com { respond "{err.status_code} {err.status_text}" } } - -# ============================================================================= -# HTTP 重定向到 HTTPS (Caddy 默认自动处理,此处显式声明) -# ============================================================================= diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 558b8ef0..99386fc9 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -116,6 +116,33 @@ gateway: # Max request body size in bytes (default: 100MB) # 请求体最大字节数(默认 100MB) max_body_size: 104857600 + # Sora max request body size in bytes (0=use max_body_size) + # Sora 请求体最大字节数(0=使用 max_body_size) + sora_max_body_size: 268435456 + # Sora stream timeout (seconds, 0=disable) + # Sora 流式请求总超时(秒,0=禁用) + sora_stream_timeout_seconds: 900 + # Sora non-stream timeout (seconds, 0=disable) + # Sora 非流式请求超时(秒,0=禁用) + sora_request_timeout_seconds: 180 + # Sora stream enforcement mode: force/error + # Sora stream 强制策略:force/error + sora_stream_mode: "force" + # Sora model filters + # Sora 模型过滤配置 + sora_model_filters: + # Hide prompt-enhance models by default + # 默认隐藏 prompt-enhance 模型 + hide_prompt_enhance: true + # Require API key for /sora/media proxy (default: false) + # /sora/media 是否强制要求 API Key(默认 true) + sora_media_require_api_key: true + # Sora media temporary signing key (empty disables signed URL) + # Sora 媒体临时签名密钥(为空则禁用签名) + sora_media_signing_key: "" + # Signed URL TTL seconds (<=0 disables) + # 临时签名 URL 有效期(秒,<=0 表示禁用) + sora_media_signed_url_ttl_seconds: 900 # Connection pool isolation strategy: # 连接池隔离策略: # - proxy: Isolate by proxy, same proxy shares connection pool (suitable for few proxies, many accounts) @@ -220,6 +247,31 @@ gateway: # name: "Custom Profile 1" # profile_2: # name: "Custom Profile 2" + +# ============================================================================= +# Sora2API Configuration +# Sora2API 配置 +# ============================================================================= +sora2api: + # Sora2API base URL + # Sora2API 服务地址 + base_url: "http://127.0.0.1:8000" + # Sora2API API Key (for /v1/chat/completions and /v1/models) + # Sora2API API Key(用于生成/模型列表) + api_key: "" + # Admin username/password (for token sync) + # 管理口用户名/密码(用于 token 同步) + admin_username: "admin" + admin_password: "admin" + # Admin token cache ttl (seconds) + # 管理口 token 缓存时长(秒) + admin_token_ttl_seconds: 900 + # Admin request timeout (seconds) + # 管理口请求超时(秒) + admin_timeout_seconds: 10 + # Token import mode: at/offline + # Token 导入模式:at/offline + token_import_mode: "at" # cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196] # curves: [29, 23, 24] # point_formats: [0] diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index e86f6348..505c1419 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -18,6 +18,7 @@ import geminiAPI from './gemini' import antigravityAPI from './antigravity' import userAttributesAPI from './userAttributes' import opsAPI from './ops' +import modelsAPI from './models' /** * Unified admin API object for convenient access @@ -37,7 +38,8 @@ export const adminAPI = { gemini: geminiAPI, antigravity: antigravityAPI, userAttributes: userAttributesAPI, - ops: opsAPI + ops: opsAPI, + models: modelsAPI } export { @@ -55,7 +57,8 @@ export { geminiAPI, antigravityAPI, userAttributesAPI, - opsAPI + opsAPI, + modelsAPI } export default adminAPI diff --git a/frontend/src/api/admin/models.ts b/frontend/src/api/admin/models.ts new file mode 100644 index 00000000..897304ac --- /dev/null +++ b/frontend/src/api/admin/models.ts @@ -0,0 +1,14 @@ +import { apiClient } from '@/api/client' + +export async function getPlatformModels(platform: string): Promise { + const { data } = await apiClient.get('/admin/models', { + params: { platform } + }) + return data +} + +export const modelsAPI = { + getPlatformModels +} + +export default modelsAPI diff --git a/frontend/src/components/account/ModelWhitelistSelector.vue b/frontend/src/components/account/ModelWhitelistSelector.vue index c8c1b852..227e6e61 100644 --- a/frontend/src/components/account/ModelWhitelistSelector.vue +++ b/frontend/src/components/account/ModelWhitelistSelector.vue @@ -45,6 +45,19 @@ :placeholder="t('admin.accounts.searchModels')" @click.stop /> +
+ + {{ t('admin.accounts.soraModelsLoading') }} + + +
+ +
+ +

+ {{ t('admin.groups.soraPricing.description') }} +

+
+
+ + +
+
+ + +
+
+
+
+ + +
+
+ + +
+
+
+
@@ -848,6 +906,64 @@
+ +
+ +

+ {{ t('admin.groups.soraPricing.description') }} +

+
+
+ + +
+
+ + +
+
+
+
+ + +
+
+ + +
+
+
+
@@ -1152,7 +1268,8 @@ const platformOptions = computed(() => [ { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, - { value: 'antigravity', label: 'Antigravity' } + { value: 'antigravity', label: 'Antigravity' }, + { value: 'sora', label: 'Sora' } ]) const platformFilterOptions = computed(() => [ @@ -1160,7 +1277,8 @@ const platformFilterOptions = computed(() => [ { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, - { value: 'antigravity', label: 'Antigravity' } + { value: 'antigravity', label: 'Antigravity' }, + { value: 'sora', label: 'Sora' } ]) const editStatusOptions = computed(() => [ @@ -1240,6 +1358,16 @@ const createForm = reactive({ image_price_1k: null as number | null, image_price_2k: null as number | null, image_price_4k: null as number | null, + // Sora 按次计费配置 + sora_image_price_360: null as number | null, + sora_image_price_540: null as number | null, + sora_video_price_per_request: null as number | null, + sora_video_price_per_request_hd: null as number | null, + // Sora 按次计费配置 + sora_image_price_360: null as number | null, + sora_image_price_540: null as number | null, + sora_video_price_per_request: null as number | null, + sora_video_price_per_request_hd: null as number | null, // Claude Code 客户端限制(仅 anthropic 平台使用) claude_code_only: false, fallback_group_id: null as number | null, @@ -1411,6 +1539,11 @@ const editForm = reactive({ image_price_1k: null as number | null, image_price_2k: null as number | null, image_price_4k: null as number | null, + // Sora 按次计费配置 + sora_image_price_360: null as number | null, + sora_image_price_540: null as number | null, + sora_video_price_per_request: null as number | null, + sora_video_price_per_request_hd: null as number | null, // Claude Code 客户端限制(仅 anthropic 平台使用) claude_code_only: false, fallback_group_id: null as number | null, @@ -1495,6 +1628,10 @@ const closeCreateModal = () => { createForm.image_price_1k = null createForm.image_price_2k = null createForm.image_price_4k = null + createForm.sora_image_price_360 = null + createForm.sora_image_price_540 = null + createForm.sora_video_price_per_request = null + createForm.sora_video_price_per_request_hd = null createForm.claude_code_only = false createForm.fallback_group_id = null createModelRoutingRules.value = [] @@ -1544,6 +1681,10 @@ const handleEdit = async (group: AdminGroup) => { editForm.image_price_1k = group.image_price_1k editForm.image_price_2k = group.image_price_2k editForm.image_price_4k = group.image_price_4k + editForm.sora_image_price_360 = group.sora_image_price_360 + editForm.sora_image_price_540 = group.sora_image_price_540 + editForm.sora_video_price_per_request = group.sora_video_price_per_request + editForm.sora_video_price_per_request_hd = group.sora_video_price_per_request_hd editForm.claude_code_only = group.claude_code_only || false editForm.fallback_group_id = group.fallback_group_id editForm.model_routing_enabled = group.model_routing_enabled || false