From 399dd78b2ac62e6b008a16da5564ba423e7f8bbd Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sun, 1 Feb 2026 21:37:10 +0800 Subject: [PATCH] =?UTF-8?q?feat(Sora):=20=E7=9B=B4=E8=BF=9E=E7=94=9F?= =?UTF-8?q?=E6=88=90=E5=B9=B6=E7=A7=BB=E9=99=A4sora2api=E4=BE=9D=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 实现直连 Sora 客户端、媒体落地与清理策略\n更新网关与前端配置以支持 Sora 平台\n补齐单元测试与契约测试,新增 curl 测试脚本\n\n测试: go test ./... -tags=unit --- backend/cmd/server/wire.go | 7 + backend/cmd/server/wire_gen.go | 25 +- backend/internal/config/config.go | 116 ++- .../internal/handler/admin/model_handler.go | 55 -- .../handler/admin/model_handler_test.go | 87 -- backend/internal/handler/gateway_handler.go | 16 +- backend/internal/handler/handler.go | 1 - .../internal/handler/sora_gateway_handler.go | 83 +- .../handler/sora_gateway_handler_test.go | 441 +++++++++ backend/internal/handler/wire.go | 3 - backend/internal/server/api_contract_test.go | 9 + backend/internal/server/routes/admin.go | 7 - .../internal/service/account_test_service.go | 2 +- backend/internal/service/admin_service.go | 76 +- .../service/admin_service_bulk_update_test.go | 28 - backend/internal/service/sora2api_service.go | 351 ------- .../internal/service/sora2api_sync_service.go | 255 ----- backend/internal/service/sora_client.go | 884 ++++++++++++++++++ backend/internal/service/sora_client_test.go | 54 ++ .../internal/service/sora_gateway_service.go | 618 ++++++++++-- .../service/sora_gateway_service_test.go | 99 ++ .../service/sora_media_cleanup_service.go | 117 +++ .../sora_media_cleanup_service_test.go | 46 + .../internal/service/sora_media_storage.go | 256 +++++ .../service/sora_media_storage_test.go | 69 ++ backend/internal/service/sora_models.go | 252 +++++ .../internal/service/token_refresh_service.go | 10 - backend/internal/service/token_refresher.go | 24 - backend/internal/service/wire.go | 18 +- build_image.sh | 8 + deploy/Dockerfile | 111 +++ deploy/config.example.yaml | 82 +- frontend/src/api/admin/index.ts | 7 +- frontend/src/api/admin/models.ts | 14 - .../components/account/CreateAccountModal.vue | 6 +- .../account/ModelWhitelistSelector.vue | 54 +- .../account/OAuthAuthorizationFlow.vue | 11 +- frontend/src/composables/useModelWhitelist.ts | 2 +- frontend/src/views/admin/GroupsView.vue | 5 - 39 files changed, 3120 insertions(+), 1189 deletions(-) delete mode 100644 backend/internal/handler/admin/model_handler.go delete mode 100644 backend/internal/handler/admin/model_handler_test.go create mode 100644 backend/internal/handler/sora_gateway_handler_test.go delete mode 100644 backend/internal/service/sora2api_service.go delete mode 100644 backend/internal/service/sora2api_sync_service.go create mode 100644 backend/internal/service/sora_client.go create mode 100644 backend/internal/service/sora_client_test.go create mode 100644 backend/internal/service/sora_gateway_service_test.go create mode 100644 backend/internal/service/sora_media_cleanup_service.go create mode 100644 backend/internal/service/sora_media_cleanup_service_test.go create mode 100644 backend/internal/service/sora_media_storage.go create mode 100644 backend/internal/service/sora_media_storage_test.go create mode 100644 backend/internal/service/sora_models.go create mode 100755 build_image.sh create mode 100644 deploy/Dockerfile delete mode 100644 frontend/src/api/admin/models.ts diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 5ef04a66..1e9e440e 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -67,6 +67,7 @@ func provideCleanup( opsAlertEvaluator *service.OpsAlertEvaluatorService, opsCleanup *service.OpsCleanupService, opsScheduledReport *service.OpsScheduledReportService, + soraMediaCleanup *service.SoraMediaCleanupService, schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, accountExpiry *service.AccountExpiryService, @@ -100,6 +101,12 @@ func provideCleanup( } return nil }}, + {"SoraMediaCleanupService", func() error { + if soraMediaCleanup != nil { + soraMediaCleanup.Stop() + } + return nil + }}, {"OpsAlertEvaluatorService", func() error { if opsAlertEvaluator != nil { opsAlertEvaluator.Stop() diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 1d88b612..dd0eb0d9 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -87,12 +87,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { schedulerCache := repository.NewSchedulerCache(redisClient) accountRepository := repository.NewAccountRepository(client, db, schedulerCache) soraAccountRepository := repository.NewSoraAccountRepository(db) - sora2APIService := service.NewSora2APIService(configConfig) - sora2APISyncService := service.NewSora2APISyncService(sora2APIService, accountRepository) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, sora2APISyncService, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) adminUserHandler := admin.NewUserHandler(adminService) groupHandler := admin.NewGroupHandler(adminService) claudeOAuthClient := repository.NewClaudeOAuthClient() @@ -164,11 +162,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) - modelHandler := admin.NewModelHandler(sora2APIService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, modelHandler) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, sora2APIService, concurrencyService, billingCacheService, configConfig) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig) - soraGatewayService := service.NewSoraGatewayService(sora2APIService, httpUpstream, rateLimitService, configConfig) + soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider) + soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) + soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig) soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler) @@ -182,9 +181,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, sora2APISyncService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig) + soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) application := &Application{ Server: httpServer, Cleanup: v, @@ -214,6 +214,7 @@ func provideCleanup( opsAlertEvaluator *service.OpsAlertEvaluatorService, opsCleanup *service.OpsCleanupService, opsScheduledReport *service.OpsScheduledReportService, + soraMediaCleanup *service.SoraMediaCleanupService, schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, accountExpiry *service.AccountExpiryService, @@ -246,6 +247,12 @@ func provideCleanup( } return nil }}, + {"SoraMediaCleanupService", func() error { + if soraMediaCleanup != nil { + soraMediaCleanup.Stop() + } + return nil + }}, {"OpsAlertEvaluatorService", func() error { if opsAlertEvaluator != nil { opsAlertEvaluator.Stop() diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index f3dec213..147cc3e9 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -58,7 +58,7 @@ type Config struct { UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` Concurrency ConcurrencyConfig `mapstructure:"concurrency"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - Sora2API Sora2APIConfig `mapstructure:"sora2api"` + Sora SoraConfig `mapstructure:"sora"` RunMode string `mapstructure:"run_mode" yaml:"run_mode"` Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" Gemini GeminiConfig `mapstructure:"gemini"` @@ -205,22 +205,40 @@ type ConcurrencyConfig struct { PingInterval int `mapstructure:"ping_interval"` } -// Sora2APIConfig Sora2API 服务配置 -type Sora2APIConfig struct { - // BaseURL Sora2API 服务地址(例如 http://localhost:8000) - BaseURL string `mapstructure:"base_url"` - // APIKey Sora2API OpenAI 兼容接口的 API Key - APIKey string `mapstructure:"api_key"` - // AdminUsername 管理员用户名(用于 token 同步) - AdminUsername string `mapstructure:"admin_username"` - // AdminPassword 管理员密码(用于 token 同步) - AdminPassword string `mapstructure:"admin_password"` - // AdminTokenTTLSeconds 管理员 Token 缓存时长(秒) - AdminTokenTTLSeconds int `mapstructure:"admin_token_ttl_seconds"` - // AdminTimeoutSeconds 管理接口请求超时(秒) - AdminTimeoutSeconds int `mapstructure:"admin_timeout_seconds"` - // TokenImportMode token 导入模式:at/offline - TokenImportMode string `mapstructure:"token_import_mode"` +// SoraConfig 直连 Sora 配置 +type SoraConfig struct { + Client SoraClientConfig `mapstructure:"client"` + Storage SoraStorageConfig `mapstructure:"storage"` +} + +// SoraClientConfig 直连 Sora 客户端配置 +type SoraClientConfig struct { + BaseURL string `mapstructure:"base_url"` + TimeoutSeconds int `mapstructure:"timeout_seconds"` + MaxRetries int `mapstructure:"max_retries"` + PollIntervalSeconds int `mapstructure:"poll_interval_seconds"` + MaxPollAttempts int `mapstructure:"max_poll_attempts"` + Debug bool `mapstructure:"debug"` + Headers map[string]string `mapstructure:"headers"` + UserAgent string `mapstructure:"user_agent"` + DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"` +} + +// SoraStorageConfig 媒体存储配置 +type SoraStorageConfig struct { + Type string `mapstructure:"type"` + LocalPath string `mapstructure:"local_path"` + FallbackToUpstream bool `mapstructure:"fallback_to_upstream"` + MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"` + Debug bool `mapstructure:"debug"` + Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"` +} + +// SoraStorageCleanupConfig 媒体清理配置 +type SoraStorageCleanupConfig struct { + Enabled bool `mapstructure:"enabled"` + Schedule string `mapstructure:"schedule"` + RetentionDays int `mapstructure:"retention_days"` } // GatewayConfig API网关相关配置 @@ -905,6 +923,26 @@ func setDefaults() { viper.SetDefault("gateway.tls_fingerprint.enabled", true) viper.SetDefault("concurrency.ping_interval", 10) + // Sora 直连配置 + viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend") + viper.SetDefault("sora.client.timeout_seconds", 120) + viper.SetDefault("sora.client.max_retries", 3) + viper.SetDefault("sora.client.poll_interval_seconds", 2) + viper.SetDefault("sora.client.max_poll_attempts", 600) + viper.SetDefault("sora.client.debug", false) + viper.SetDefault("sora.client.headers", map[string]string{}) + viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + viper.SetDefault("sora.client.disable_tls_fingerprint", false) + + viper.SetDefault("sora.storage.type", "local") + viper.SetDefault("sora.storage.local_path", "") + viper.SetDefault("sora.storage.fallback_to_upstream", true) + viper.SetDefault("sora.storage.max_concurrent_downloads", 4) + viper.SetDefault("sora.storage.debug", false) + viper.SetDefault("sora.storage.cleanup.enabled", true) + viper.SetDefault("sora.storage.cleanup.retention_days", 7) + viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *") + // TokenRefresh viper.SetDefault("token_refresh.enabled", true) viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次 @@ -920,15 +958,6 @@ func setDefaults() { viper.SetDefault("gemini.oauth.scopes", "") viper.SetDefault("gemini.quota.policy", "") - // Sora2API - viper.SetDefault("sora2api.base_url", "") - viper.SetDefault("sora2api.api_key", "") - viper.SetDefault("sora2api.admin_username", "") - viper.SetDefault("sora2api.admin_password", "") - viper.SetDefault("sora2api.admin_token_ttl_seconds", 900) - viper.SetDefault("sora2api.admin_timeout_seconds", 10) - viper.SetDefault("sora2api.token_import_mode", "at") - } func (c *Config) Validate() error { @@ -1164,6 +1193,36 @@ func (c *Config) Validate() error { return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error") } } + if c.Sora.Client.TimeoutSeconds < 0 { + return fmt.Errorf("sora.client.timeout_seconds must be non-negative") + } + if c.Sora.Client.MaxRetries < 0 { + return fmt.Errorf("sora.client.max_retries must be non-negative") + } + if c.Sora.Client.PollIntervalSeconds < 0 { + return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative") + } + if c.Sora.Client.MaxPollAttempts < 0 { + return fmt.Errorf("sora.client.max_poll_attempts must be non-negative") + } + if c.Sora.Storage.MaxConcurrentDownloads < 0 { + return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative") + } + if c.Sora.Storage.Cleanup.Enabled { + if c.Sora.Storage.Cleanup.RetentionDays <= 0 { + return fmt.Errorf("sora.storage.cleanup.retention_days must be positive") + } + if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" { + return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled") + } + } else { + if c.Sora.Storage.Cleanup.RetentionDays < 0 { + return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative") + } + } + if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" { + return fmt.Errorf("sora.storage.type must be 'local'") + } if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" { switch c.Gateway.ConnectionPoolIsolation { case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy: @@ -1260,11 +1319,6 @@ func (c *Config) Validate() error { c.Gateway.Scheduling.OutboxLagRebuildSeconds < c.Gateway.Scheduling.OutboxLagWarnSeconds { return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be >= outbox_lag_warn_seconds") } - if strings.TrimSpace(c.Sora2API.BaseURL) != "" { - if err := ValidateAbsoluteHTTPURL(c.Sora2API.BaseURL); err != nil { - return fmt.Errorf("sora2api.base_url invalid: %w", err) - } - } if c.Ops.MetricsCollectorCache.TTL < 0 { return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative") } diff --git a/backend/internal/handler/admin/model_handler.go b/backend/internal/handler/admin/model_handler.go deleted file mode 100644 index 035b09bd..00000000 --- a/backend/internal/handler/admin/model_handler.go +++ /dev/null @@ -1,55 +0,0 @@ -package admin - -import ( - "net/http" - "strings" - - "github.com/Wei-Shaw/sub2api/internal/pkg/response" - "github.com/Wei-Shaw/sub2api/internal/service" - - "github.com/gin-gonic/gin" -) - -// ModelHandler handles admin model listing requests. -type ModelHandler struct { - sora2apiService *service.Sora2APIService -} - -// NewModelHandler creates a new ModelHandler. -func NewModelHandler(sora2apiService *service.Sora2APIService) *ModelHandler { - return &ModelHandler{ - sora2apiService: sora2apiService, - } -} - -// List handles listing models for a specific platform -// GET /api/v1/admin/models?platform=sora -func (h *ModelHandler) List(c *gin.Context) { - platform := strings.TrimSpace(strings.ToLower(c.Query("platform"))) - if platform == "" { - response.BadRequest(c, "platform is required") - return - } - - switch platform { - case service.PlatformSora: - if h.sora2apiService == nil || !h.sora2apiService.Enabled() { - response.Error(c, http.StatusServiceUnavailable, "sora2api not configured") - return - } - models, err := h.sora2apiService.ListModels(c.Request.Context()) - if err != nil { - response.Error(c, http.StatusServiceUnavailable, "failed to fetch sora models") - return - } - ids := make([]string, 0, len(models)) - for _, m := range models { - if strings.TrimSpace(m.ID) != "" { - ids = append(ids, m.ID) - } - } - response.Success(c, ids) - default: - response.BadRequest(c, "unsupported platform") - } -} diff --git a/backend/internal/handler/admin/model_handler_test.go b/backend/internal/handler/admin/model_handler_test.go deleted file mode 100644 index e61dc064..00000000 --- a/backend/internal/handler/admin/model_handler_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package admin - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/response" - "github.com/Wei-Shaw/sub2api/internal/service" - - "github.com/gin-gonic/gin" -) - -func TestModelHandlerListSoraSuccess(t *testing.T) { - gin.SetMode(gin.TestMode) - - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"object":"list","data":[{"id":"m1"},{"id":"m2"}]}`)) - })) - t.Cleanup(upstream.Close) - - cfg := &config.Config{} - cfg.Sora2API.BaseURL = upstream.URL - cfg.Sora2API.APIKey = "test-key" - soraService := service.NewSora2APIService(cfg) - - h := NewModelHandler(soraService) - router := gin.New() - router.GET("/admin/models", h.List) - - req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=sora", nil) - recorder := httptest.NewRecorder() - router.ServeHTTP(recorder, req) - - if recorder.Code != http.StatusOK { - t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String()) - } - var resp response.Response - if err := json.Unmarshal(recorder.Body.Bytes(), &resp); err != nil { - t.Fatalf("解析响应失败: %v", err) - } - if resp.Code != 0 { - t.Fatalf("响应 code=%d", resp.Code) - } - data, ok := resp.Data.([]any) - if !ok { - t.Fatalf("响应 data 类型错误") - } - if len(data) != 2 { - t.Fatalf("模型数量不符: %d", len(data)) - } -} - -func TestModelHandlerListSoraNotConfigured(t *testing.T) { - gin.SetMode(gin.TestMode) - - h := NewModelHandler(&service.Sora2APIService{}) - router := gin.New() - router.GET("/admin/models", h.List) - - req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=sora", nil) - recorder := httptest.NewRecorder() - router.ServeHTTP(recorder, req) - - if recorder.Code != http.StatusServiceUnavailable { - t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String()) - } -} - -func TestModelHandlerListInvalidPlatform(t *testing.T) { - gin.SetMode(gin.TestMode) - - h := NewModelHandler(&service.Sora2APIService{}) - router := gin.New() - router.GET("/admin/models", h.List) - - req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=unknown", nil) - recorder := httptest.NewRecorder() - router.ServeHTTP(recorder, req) - - if recorder.Code != http.StatusBadRequest { - t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String()) - } -} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 983cc6b3..a7b98940 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -29,11 +29,11 @@ type GatewayHandler struct { geminiCompatService *service.GeminiMessagesCompatService antigravityGatewayService *service.AntigravityGatewayService userService *service.UserService - sora2apiService *service.Sora2APIService billingCacheService *service.BillingCacheService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int maxAccountSwitchesGemini int + cfg *config.Config } // NewGatewayHandler creates a new GatewayHandler @@ -42,7 +42,6 @@ func NewGatewayHandler( geminiCompatService *service.GeminiMessagesCompatService, antigravityGatewayService *service.AntigravityGatewayService, userService *service.UserService, - sora2apiService *service.Sora2APIService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, cfg *config.Config, @@ -64,11 +63,11 @@ func NewGatewayHandler( geminiCompatService: geminiCompatService, antigravityGatewayService: antigravityGatewayService, userService: userService, - sora2apiService: sora2apiService, billingCacheService: billingCacheService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), maxAccountSwitches: maxAccountSwitches, maxAccountSwitchesGemini: maxAccountSwitchesGemini, + cfg: cfg, } } @@ -486,18 +485,9 @@ func (h *GatewayHandler) Models(c *gin.Context) { } if platform == service.PlatformSora { - if h.sora2apiService == nil || !h.sora2apiService.Enabled() { - h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "sora2api not configured") - return - } - models, err := h.sora2apiService.ListModels(c.Request.Context()) - if err != nil { - h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Failed to fetch Sora models") - return - } c.JSON(http.StatusOK, gin.H{ "object": "list", - "data": models, + "data": service.DefaultSoraModels(h.cfg), }) return } diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 7905148c..d7014a22 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -23,7 +23,6 @@ type AdminHandlers struct { Subscription *admin.SubscriptionHandler Usage *admin.UsageHandler UserAttribute *admin.UserAttributeHandler - Model *admin.ModelHandler } // Handlers contains all HTTP handlers diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 05833144..faed3b33 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -10,7 +10,9 @@ import ( "io" "log" "net/http" + "os" "path" + "path/filepath" "strconv" "strings" "time" @@ -31,9 +33,8 @@ type SoraGatewayHandler struct { concurrencyHelper *ConcurrencyHelper maxAccountSwitches int streamMode string - sora2apiBaseURL string soraMediaSigningKey string - mediaClient *http.Client + soraMediaRoot string } // NewSoraGatewayHandler creates a new SoraGatewayHandler @@ -48,6 +49,7 @@ func NewSoraGatewayHandler( maxAccountSwitches := 3 streamMode := "force" signKey := "" + mediaRoot := "/app/data/sora" if cfg != nil { pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second if cfg.Gateway.MaxAccountSwitches > 0 { @@ -57,14 +59,9 @@ func NewSoraGatewayHandler( streamMode = mode } signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey) - } - baseURL := "" - if cfg != nil { - baseURL = strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/") - } - mediaTimeout := 180 * time.Second - if cfg != nil && cfg.Gateway.SoraRequestTimeoutSeconds > 0 { - mediaTimeout = time.Duration(cfg.Gateway.SoraRequestTimeoutSeconds) * time.Second + if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" { + mediaRoot = root + } } return &SoraGatewayHandler{ gatewayService: gatewayService, @@ -73,9 +70,8 @@ func NewSoraGatewayHandler( concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), maxAccountSwitches: maxAccountSwitches, streamMode: strings.ToLower(streamMode), - sora2apiBaseURL: baseURL, soraMediaSigningKey: signKey, - mediaClient: &http.Client{Timeout: mediaTimeout}, + soraMediaRoot: mediaRoot, } } @@ -377,34 +373,24 @@ func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, }) } -// MediaProxy proxies /tmp or /static media files from sora2api +// MediaProxy serves local Sora media files. func (h *SoraGatewayHandler) MediaProxy(c *gin.Context) { h.proxySoraMedia(c, false) } -// MediaProxySigned proxies /tmp or /static media files with signature verification +// MediaProxySigned serves local Sora media files with signature verification. func (h *SoraGatewayHandler) MediaProxySigned(c *gin.Context) { h.proxySoraMedia(c, true) } func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature bool) { - if h.sora2apiBaseURL == "" { - c.JSON(http.StatusServiceUnavailable, gin.H{ - "error": gin.H{ - "type": "api_error", - "message": "sora2api 未配置", - }, - }) - return - } - rawPath := c.Param("filepath") if rawPath == "" { c.Status(http.StatusNotFound) return } cleaned := path.Clean(rawPath) - if !strings.HasPrefix(cleaned, "/tmp/") && !strings.HasPrefix(cleaned, "/static/") { + if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") { c.Status(http.StatusNotFound) return } @@ -445,40 +431,25 @@ func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature boo return } } - - targetURL := h.sora2apiBaseURL + cleaned - if rawQuery := query.Encode(); rawQuery != "" { - targetURL += "?" + rawQuery - } - - req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL, nil) - if err != nil { - c.Status(http.StatusBadGateway) + if strings.TrimSpace(h.soraMediaRoot) == "" { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "Sora 媒体目录未配置", + }, + }) return } - copyHeaders := []string{"Range", "If-Range", "If-Modified-Since", "If-None-Match", "Accept", "User-Agent"} - for _, key := range copyHeaders { - if val := c.GetHeader(key); val != "" { - req.Header.Set(key, val) - } - } - client := h.mediaClient - if client == nil { - client = http.DefaultClient - } - resp, err := client.Do(req) - if err != nil { - c.Status(http.StatusBadGateway) + relative := strings.TrimPrefix(cleaned, "/") + localPath := filepath.Join(h.soraMediaRoot, filepath.FromSlash(relative)) + if _, err := os.Stat(localPath); err != nil { + if os.IsNotExist(err) { + c.Status(http.StatusNotFound) + return + } + c.Status(http.StatusInternalServerError) return } - defer func() { _ = resp.Body.Close() }() - - for _, key := range []string{"Content-Type", "Content-Length", "Accept-Ranges", "Content-Range", "Cache-Control", "Last-Modified", "ETag"} { - if val := resp.Header.Get(key); val != "" { - c.Header(key, val) - } - } - c.Status(resp.StatusCode) - _, _ = io.Copy(c.Writer, resp.Body) + c.File(localPath) } diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go new file mode 100644 index 00000000..91881dec --- /dev/null +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -0,0 +1,441 @@ +//go:build unit + +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type stubSoraClient struct { + imageURLs []string +} + +func (s *stubSoraClient) Enabled() bool { return true } +func (s *stubSoraClient) UploadImage(ctx context.Context, account *service.Account, data []byte, filename string) (string, error) { + return "upload", nil +} +func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.Account, req service.SoraImageRequest) (string, error) { + return "task-image", nil +} +func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) { + return "task-video", nil +} +func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) { + return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil +} +func (s *stubSoraClient) GetVideoTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraVideoTaskStatus, error) { + return &service.SoraVideoTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil +} + +type stubConcurrencyCache struct{} + +func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil +} +func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { + return nil +} +func (c stubConcurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} +func (c stubConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { + return true, nil +} +func (c stubConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { + return nil +} +func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} +func (c stubConcurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil +} +func (c stubConcurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error { + return nil +} +func (c stubConcurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { + return 0, nil +} +func (c stubConcurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { + return true, nil +} +func (c stubConcurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error { + return nil +} +func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { + result := make(map[int64]*service.AccountLoadInfo, len(accounts)) + for _, acc := range accounts { + result[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID, LoadRate: 0} + } + return result, nil +} +func (c stubConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { + return nil +} + +type stubAccountRepo struct { + accounts map[int64]*service.Account +} + +func (r *stubAccountRepo) Create(ctx context.Context, account *service.Account) error { return nil } +func (r *stubAccountRepo) GetByID(ctx context.Context, id int64) (*service.Account, error) { + if acc, ok := r.accounts[id]; ok { + return acc, nil + } + return nil, service.ErrAccountNotFound +} +func (r *stubAccountRepo) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) { + var result []*service.Account + for _, id := range ids { + if acc, ok := r.accounts[id]; ok { + result = append(result, acc) + } + } + return result, nil +} +func (r *stubAccountRepo) ExistsByID(ctx context.Context, id int64) (bool, error) { + _, ok := r.accounts[id] + return ok, nil +} +func (r *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) { + return nil, nil +} +func (r *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) { + return nil, nil +} +func (r *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return nil } +func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error { return nil } +func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { + return nil, nil +} +func (r *stubAccountRepo) ListActive(ctx context.Context) ([]service.Account, error) { return nil, nil } +func (r *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return r.listSchedulableByPlatform(platform), nil +} +func (r *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error { return nil } +func (r *stubAccountRepo) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error { + return nil +} +func (r *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { return nil } +func (r *stubAccountRepo) ClearError(ctx context.Context, id int64) error { return nil } +func (r *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { + return nil +} +func (r *stubAccountRepo) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) { + return 0, nil +} +func (r *stubAccountRepo) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { + return nil +} +func (r *stubAccountRepo) ListSchedulable(ctx context.Context) ([]service.Account, error) { + return r.listSchedulable(), nil +} +func (r *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) { + return r.listSchedulable(), nil +} +func (r *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return r.listSchedulableByPlatform(platform), nil +} +func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) { + return r.listSchedulableByPlatform(platform), nil +} +func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + var result []service.Account + for _, acc := range r.accounts { + for _, platform := range platforms { + if acc.Platform == platform && acc.IsSchedulable() { + result = append(result, *acc) + break + } + } + } + return result, nil +} +func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { + return r.ListSchedulableByPlatforms(ctx, platforms) +} +func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { + return nil +} +func (r *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error { + return nil +} +func (r *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { + return nil +} +func (r *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error { + return nil +} +func (r *stubAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + return nil +} +func (r *stubAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error { return nil } +func (r *stubAccountRepo) ClearRateLimit(ctx context.Context, id int64) error { return nil } +func (r *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error { + return nil +} +func (r *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error { return nil } +func (r *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { + return nil +} +func (r *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { + return nil +} +func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { + return 0, nil +} + +func (r *stubAccountRepo) listSchedulable() []service.Account { + var result []service.Account + for _, acc := range r.accounts { + if acc.IsSchedulable() { + result = append(result, *acc) + } + } + return result +} + +func (r *stubAccountRepo) listSchedulableByPlatform(platform string) []service.Account { + var result []service.Account + for _, acc := range r.accounts { + if acc.Platform == platform && acc.IsSchedulable() { + result = append(result, *acc) + } + } + return result +} + +type stubGroupRepo struct { + group *service.Group +} + +func (r *stubGroupRepo) Create(ctx context.Context, group *service.Group) error { return nil } +func (r *stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, error) { + return r.group, nil +} +func (r *stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) { + return r.group, nil +} +func (r *stubGroupRepo) Update(ctx context.Context, group *service.Group) error { return nil } +func (r *stubGroupRepo) Delete(ctx context.Context, id int64) error { return nil } +func (r *stubGroupRepo) DeleteCascade(ctx context.Context, id int64) ([]int64, error) { + return nil, nil +} +func (r *stubGroupRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) { return nil, nil } +func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) { + return nil, nil +} +func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) { + return false, nil +} +func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { + return 0, nil +} +func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { + return 0, nil +} + +type stubUsageLogRepo struct{} + +func (s *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) (bool, error) { + return true, nil +} +func (s *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) { + return nil, nil +} +func (s *stubUsageLogRepo) Delete(ctx context.Context, id int64) error { return nil } +func (s *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubUsageLogRepo) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubUsageLogRepo) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubUsageLogRepo) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubUsageLogRepo) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) { + return nil, nil +} +func (s *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubUsageLogRepo) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + return nil, nil +} +func (s *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) { + return nil, nil +} + +func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + RunMode: config.RunModeSimple, + Gateway: config.GatewayConfig{ + SoraStreamMode: "force", + MaxAccountSwitches: 1, + Scheduling: config.GatewaySchedulingConfig{ + LoadBatchEnabled: false, + }, + }, + Concurrency: config.ConcurrencyConfig{PingInterval: 0}, + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.test", + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + + account := &service.Account{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Concurrency: 1, Priority: 1} + accountRepo := &stubAccountRepo{accounts: map[int64]*service.Account{account.ID: account}} + group := &service.Group{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Hydrated: true} + groupRepo := &stubGroupRepo{group: group} + + usageLogRepo := &stubUsageLogRepo{} + deferredService := service.NewDeferredService(accountRepo, nil, 0) + billingService := service.NewBillingService(cfg, nil) + concurrencyService := service.NewConcurrencyService(stubConcurrencyCache{}) + billingCacheService := service.NewBillingCacheService(nil, nil, nil, cfg) + t.Cleanup(func() { + billingCacheService.Stop() + }) + + gatewayService := service.NewGatewayService( + accountRepo, + groupRepo, + usageLogRepo, + nil, + nil, + nil, + cfg, + nil, + concurrencyService, + billingService, + nil, + billingCacheService, + nil, + nil, + deferredService, + nil, + nil, + ) + + soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}} + soraGatewayService := service.NewSoraGatewayService(soraClient, nil, nil, cfg) + + handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, cfg) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := `{"model":"gpt-image","messages":[{"role":"user","content":"hello"}]}` + c.Request = httptest.NewRequest(http.MethodPost, "/sora/v1/chat/completions", strings.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + apiKey := &service.APIKey{ + ID: 1, + UserID: 1, + Status: service.StatusActive, + GroupID: &group.ID, + User: &service.User{ID: 1, Concurrency: 1, Status: service.StatusActive}, + Group: group, + } + c.Set(string(middleware.ContextKeyAPIKey), apiKey) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: apiKey.User.Concurrency}) + + handler.ChatCompletions(c) + + require.Equal(t, http.StatusOK, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.NotEmpty(t, resp["media_url"]) +} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 1e3ef17d..c20b7fbc 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -26,7 +26,6 @@ func ProvideAdminHandlers( subscriptionHandler *admin.SubscriptionHandler, usageHandler *admin.UsageHandler, userAttributeHandler *admin.UserAttributeHandler, - modelHandler *admin.ModelHandler, ) *AdminHandlers { return &AdminHandlers{ Dashboard: dashboardHandler, @@ -46,7 +45,6 @@ func ProvideAdminHandlers( Subscription: subscriptionHandler, Usage: usageHandler, UserAttribute: userAttributeHandler, - Model: modelHandler, } } @@ -121,7 +119,6 @@ var ProviderSet = wire.NewSet( admin.NewSubscriptionHandler, admin.NewUsageHandler, admin.NewUserAttributeHandler, - admin.NewModelHandler, // AdminHandlers and Handlers constructors ProvideAdminHandlers, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index f3eebd41..409a7625 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -178,6 +178,10 @@ func TestAPIContracts(t *testing.T) { "image_price_1k": null, "image_price_2k": null, "image_price_4k": null, + "sora_image_price_360": null, + "sora_image_price_540": null, + "sora_video_price_per_request": null, + "sora_video_price_per_request_hd": null, "claude_code_only": false, "fallback_group_id": null, "created_at": "2025-01-02T03:04:05Z", @@ -394,6 +398,7 @@ func TestAPIContracts(t *testing.T) { "first_token_ms": 50, "image_count": 0, "image_size": null, + "media_type": null, "created_at": "2025-01-02T03:04:05Z", "user_agent": null } @@ -887,6 +892,10 @@ func (s *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID st return nil, errors.New("not implemented") } +func (s *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return errors.New("not implemented") } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 2c1762d3..050e724d 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -64,9 +64,6 @@ func RegisterAdminRoutes( // 用户属性管理 registerUserAttributeRoutes(admin, h) - - // 模型列表 - registerModelRoutes(admin, h) } } @@ -374,7 +371,3 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) { attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition) } } - -func registerModelRoutes(admin *gin.RouterGroup, h *handler.Handlers) { - admin.GET("/models", h.Admin.Model.List) -} diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index f80a2af8..a76c4d20 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -491,7 +491,7 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * return s.sendErrorAndEnd(c, "Failed to create request") } - // 使用 Sora 客户端标准请求头(参考 sora2api) + // 使用 Sora 客户端标准请求头 req.Header.Set("Authorization", "Bearer "+authToken) req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") req.Header.Set("Accept", "application/json") diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index a29bf4db..94b18322 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -283,7 +283,6 @@ type adminServiceImpl struct { groupRepo GroupRepository accountRepo AccountRepository soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储 - soraSyncService *Sora2APISyncService // Sora2API 同步服务 proxyRepo ProxyRepository apiKeyRepo APIKeyRepository redeemCodeRepo RedeemCodeRepository @@ -299,7 +298,6 @@ func NewAdminService( groupRepo GroupRepository, accountRepo AccountRepository, soraAccountRepo SoraAccountRepository, - soraSyncService *Sora2APISyncService, proxyRepo ProxyRepository, apiKeyRepo APIKeyRepository, redeemCodeRepo RedeemCodeRepository, @@ -313,7 +311,6 @@ func NewAdminService( groupRepo: groupRepo, accountRepo: accountRepo, soraAccountRepo: soraAccountRepo, - soraSyncService: soraSyncService, proxyRepo: proxyRepo, apiKeyRepo: apiKeyRepo, redeemCodeRepo: redeemCodeRepo, @@ -917,9 +914,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou } } - // 同步到 sora2api(异步,不阻塞创建) - s.syncSoraAccountAsync(account) - return account, nil } @@ -1014,7 +1008,6 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U if err != nil { return nil, err } - s.syncSoraAccountAsync(updated) return updated, nil } @@ -1032,17 +1025,15 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp } needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck - needSoraSync := s != nil && s.soraSyncService != nil // 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。 platformByID := map[int64]string{} - if needMixedChannelCheck || needSoraSync { + if needMixedChannelCheck { accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs) if err != nil { if needMixedChannelCheck { return nil, err } - log.Printf("[AdminService] 预加载账号平台信息失败,将逐个降级同步: err=%v", err) } else { for _, account := range accounts { if account != nil { @@ -1134,45 +1125,15 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp result.Success++ result.SuccessIDs = append(result.SuccessIDs, accountID) result.Results = append(result.Results, entry) - - // 批量更新后同步 sora2api - if needSoraSync { - platform := platformByID[accountID] - if platform == "" { - updated, err := s.accountRepo.GetByID(ctx, accountID) - if err != nil { - log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err) - continue - } - if updated.Platform == PlatformSora { - s.syncSoraAccountAsync(updated) - } - continue - } - - if platform == PlatformSora { - updated, err := s.accountRepo.GetByID(ctx, accountID) - if err != nil { - log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err) - continue - } - s.syncSoraAccountAsync(updated) - } - } } return result, nil } func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { - account, err := s.accountRepo.GetByID(ctx, id) - if err != nil { - return err - } if err := s.accountRepo.Delete(ctx, id); err != nil { return err } - s.deleteSoraAccountAsync(account) return nil } @@ -1210,44 +1171,9 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, if err != nil { return nil, err } - s.syncSoraAccountAsync(updated) return updated, nil } -func (s *adminServiceImpl) syncSoraAccountAsync(account *Account) { - if s == nil || s.soraSyncService == nil || account == nil { - return - } - if account.Platform != PlatformSora { - return - } - syncAccount := *account - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - if err := s.soraSyncService.SyncAccount(ctx, &syncAccount); err != nil { - log.Printf("[AdminService] 同步 sora2api 失败: account_id=%d err=%v", syncAccount.ID, err) - } - }() -} - -func (s *adminServiceImpl) deleteSoraAccountAsync(account *Account) { - if s == nil || s.soraSyncService == nil || account == nil { - return - } - if account.Platform != PlatformSora { - return - } - syncAccount := *account - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - if err := s.soraSyncService.DeleteAccount(ctx, &syncAccount); err != nil { - log.Printf("[AdminService] 删除 sora2api token 失败: account_id=%d err=%v", syncAccount.ID, err) - } - }() -} - // Proxy management implementations func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go index cbdbe625..0dccacbb 100644 --- a/backend/internal/service/admin_service_bulk_update_test.go +++ b/backend/internal/service/admin_service_bulk_update_test.go @@ -105,31 +105,3 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) { require.ElementsMatch(t, []int64{2}, result.FailedIDs) require.Len(t, result.Results, 3) } - -// TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs 验证无分组更新时仍会触发 Sora 同步。 -func TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs(t *testing.T) { - repo := &accountRepoStubForBulkUpdate{ - getByIDsAccounts: []*Account{ - {ID: 1, Platform: PlatformSora}, - }, - getByIDAccounts: map[int64]*Account{ - 1: {ID: 1, Platform: PlatformSora}, - }, - } - svc := &adminServiceImpl{ - accountRepo: repo, - soraSyncService: &Sora2APISyncService{}, - } - - schedulable := true - input := &BulkUpdateAccountsInput{ - AccountIDs: []int64{1}, - Schedulable: &schedulable, - } - - result, err := svc.BulkUpdateAccounts(context.Background(), input) - require.NoError(t, err) - require.Equal(t, 1, result.Success) - require.True(t, repo.getByIDsCalled) - require.ElementsMatch(t, []int64{1}, repo.getByIDCalled) -} diff --git a/backend/internal/service/sora2api_service.go b/backend/internal/service/sora2api_service.go deleted file mode 100644 index c047cd40..00000000 --- a/backend/internal/service/sora2api_service.go +++ /dev/null @@ -1,351 +0,0 @@ -package service - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "log" - "net/http" - "strings" - "sync" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" -) - -// Sora2APIModel represents a model entry returned by sora2api. -type Sora2APIModel struct { - ID string `json:"id"` - Object string `json:"object"` - OwnedBy string `json:"owned_by,omitempty"` - Description string `json:"description,omitempty"` -} - -// Sora2APIModelList represents /v1/models response. -type Sora2APIModelList struct { - Object string `json:"object"` - Data []Sora2APIModel `json:"data"` -} - -// Sora2APIImportTokenItem mirrors sora2api ImportTokenItem. -type Sora2APIImportTokenItem struct { - Email string `json:"email"` - AccessToken string `json:"access_token,omitempty"` - SessionToken string `json:"session_token,omitempty"` - RefreshToken string `json:"refresh_token,omitempty"` - ClientID string `json:"client_id,omitempty"` - ProxyURL string `json:"proxy_url,omitempty"` - Remark string `json:"remark,omitempty"` - IsActive bool `json:"is_active"` - ImageEnabled bool `json:"image_enabled"` - VideoEnabled bool `json:"video_enabled"` - ImageConcurrency int `json:"image_concurrency"` - VideoConcurrency int `json:"video_concurrency"` -} - -// Sora2APIToken represents minimal fields for admin list. -type Sora2APIToken struct { - ID int64 `json:"id"` - Email string `json:"email"` - Name string `json:"name"` - Remark string `json:"remark"` -} - -// Sora2APIService provides access to sora2api endpoints. -type Sora2APIService struct { - cfg *config.Config - - baseURL string - apiKey string - adminUsername string - adminPassword string - adminTokenTTL time.Duration - tokenImportMode string - - client *http.Client - adminClient *http.Client - - adminToken string - adminTokenAt time.Time - adminMu sync.Mutex - - modelCache []Sora2APIModel - modelMu sync.RWMutex -} - -func NewSora2APIService(cfg *config.Config) *Sora2APIService { - if cfg == nil { - return &Sora2APIService{} - } - adminTTL := time.Duration(cfg.Sora2API.AdminTokenTTLSeconds) * time.Second - if adminTTL <= 0 { - adminTTL = 15 * time.Minute - } - adminTimeout := time.Duration(cfg.Sora2API.AdminTimeoutSeconds) * time.Second - if adminTimeout <= 0 { - adminTimeout = 10 * time.Second - } - return &Sora2APIService{ - cfg: cfg, - baseURL: strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/"), - apiKey: strings.TrimSpace(cfg.Sora2API.APIKey), - adminUsername: strings.TrimSpace(cfg.Sora2API.AdminUsername), - adminPassword: strings.TrimSpace(cfg.Sora2API.AdminPassword), - adminTokenTTL: adminTTL, - tokenImportMode: strings.ToLower(strings.TrimSpace(cfg.Sora2API.TokenImportMode)), - client: &http.Client{}, - adminClient: &http.Client{Timeout: adminTimeout}, - } -} - -func (s *Sora2APIService) Enabled() bool { - return s != nil && s.baseURL != "" && s.apiKey != "" -} - -func (s *Sora2APIService) AdminEnabled() bool { - return s != nil && s.baseURL != "" && s.adminUsername != "" && s.adminPassword != "" -} - -func (s *Sora2APIService) buildURL(path string) string { - if s.baseURL == "" { - return path - } - if strings.HasPrefix(path, "/") { - return s.baseURL + path - } - return s.baseURL + "/" + path -} - -// BuildURL 返回完整的 sora2api URL(用于代理媒体) -func (s *Sora2APIService) BuildURL(path string) string { - return s.buildURL(path) -} - -func (s *Sora2APIService) NewAPIRequest(ctx context.Context, method string, path string, body []byte) (*http.Request, error) { - if !s.Enabled() { - return nil, errors.New("sora2api not configured") - } - req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), bytes.NewReader(body)) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", "Bearer "+s.apiKey) - req.Header.Set("Content-Type", "application/json") - return req, nil -} - -func (s *Sora2APIService) ListModels(ctx context.Context) ([]Sora2APIModel, error) { - if !s.Enabled() { - return nil, errors.New("sora2api not configured") - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.buildURL("/v1/models"), nil) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", "Bearer "+s.apiKey) - resp, err := s.client.Do(req) - if err != nil { - return s.cachedModelsOnError(err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - return s.cachedModelsOnError(fmt.Errorf("sora2api models status: %d", resp.StatusCode)) - } - - var payload Sora2APIModelList - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { - return s.cachedModelsOnError(err) - } - models := payload.Data - if s.cfg != nil && s.cfg.Gateway.SoraModelFilters.HidePromptEnhance { - filtered := make([]Sora2APIModel, 0, len(models)) - for _, m := range models { - if strings.HasPrefix(strings.ToLower(m.ID), "prompt-enhance") { - continue - } - filtered = append(filtered, m) - } - models = filtered - } - - s.modelMu.Lock() - s.modelCache = models - s.modelMu.Unlock() - - return models, nil -} - -func (s *Sora2APIService) cachedModelsOnError(err error) ([]Sora2APIModel, error) { - s.modelMu.RLock() - cached := append([]Sora2APIModel(nil), s.modelCache...) - s.modelMu.RUnlock() - if len(cached) > 0 { - log.Printf("[Sora2API] 模型列表拉取失败,回退缓存: %v", err) - return cached, nil - } - return nil, err -} - -func (s *Sora2APIService) ImportTokens(ctx context.Context, items []Sora2APIImportTokenItem) error { - if !s.AdminEnabled() { - return errors.New("sora2api admin not configured") - } - mode := s.tokenImportMode - if mode == "" { - mode = "at" - } - payload := map[string]any{ - "tokens": items, - "mode": mode, - } - _, err := s.doAdminRequest(ctx, http.MethodPost, "/api/tokens/import", payload, nil) - return err -} - -func (s *Sora2APIService) ListTokens(ctx context.Context) ([]Sora2APIToken, error) { - if !s.AdminEnabled() { - return nil, errors.New("sora2api admin not configured") - } - var tokens []Sora2APIToken - _, err := s.doAdminRequest(ctx, http.MethodGet, "/api/tokens", nil, &tokens) - return tokens, err -} - -func (s *Sora2APIService) DisableToken(ctx context.Context, tokenID int64) error { - if !s.AdminEnabled() { - return errors.New("sora2api admin not configured") - } - path := fmt.Sprintf("/api/tokens/%d/disable", tokenID) - _, err := s.doAdminRequest(ctx, http.MethodPost, path, nil, nil) - return err -} - -func (s *Sora2APIService) DeleteToken(ctx context.Context, tokenID int64) error { - if !s.AdminEnabled() { - return errors.New("sora2api admin not configured") - } - path := fmt.Sprintf("/api/tokens/%d", tokenID) - _, err := s.doAdminRequest(ctx, http.MethodDelete, path, nil, nil) - return err -} - -func (s *Sora2APIService) doAdminRequest(ctx context.Context, method string, path string, body any, out any) (*http.Response, error) { - if !s.AdminEnabled() { - return nil, errors.New("sora2api admin not configured") - } - token, err := s.getAdminToken(ctx) - if err != nil { - return nil, err - } - resp, err := s.doAdminRequestWithToken(ctx, method, path, token, body, out) - if err == nil && resp != nil && resp.StatusCode != http.StatusUnauthorized { - return resp, nil - } - if resp != nil && resp.StatusCode == http.StatusUnauthorized { - s.invalidateAdminToken() - token, err = s.getAdminToken(ctx) - if err != nil { - return resp, err - } - return s.doAdminRequestWithToken(ctx, method, path, token, body, out) - } - return resp, err -} - -func (s *Sora2APIService) doAdminRequestWithToken(ctx context.Context, method string, path string, token string, body any, out any) (*http.Response, error) { - var reader *bytes.Reader - if body != nil { - buf, err := json.Marshal(body) - if err != nil { - return nil, err - } - reader = bytes.NewReader(buf) - } else { - reader = bytes.NewReader(nil) - } - req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), reader) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", "Bearer "+token) - if body != nil { - req.Header.Set("Content-Type", "application/json") - } - resp, err := s.adminClient.Do(req) - if err != nil { - return resp, err - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return resp, fmt.Errorf("sora2api admin status: %d", resp.StatusCode) - } - if out != nil { - if err := json.NewDecoder(resp.Body).Decode(out); err != nil { - return resp, err - } - } - return resp, nil -} - -func (s *Sora2APIService) getAdminToken(ctx context.Context) (string, error) { - s.adminMu.Lock() - defer s.adminMu.Unlock() - - if s.adminToken != "" && time.Since(s.adminTokenAt) < s.adminTokenTTL { - return s.adminToken, nil - } - - if !s.AdminEnabled() { - return "", errors.New("sora2api admin not configured") - } - - payload := map[string]string{ - "username": s.adminUsername, - "password": s.adminPassword, - } - buf, err := json.Marshal(payload) - if err != nil { - return "", err - } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.buildURL("/api/login"), bytes.NewReader(buf)) - if err != nil { - return "", err - } - req.Header.Set("Content-Type", "application/json") - resp, err := s.adminClient.Do(req) - if err != nil { - return "", err - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("sora2api login failed: %d", resp.StatusCode) - } - var result struct { - Success bool `json:"success"` - Token string `json:"token"` - Message string `json:"message"` - } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return "", err - } - if !result.Success || result.Token == "" { - if result.Message == "" { - result.Message = "sora2api login failed" - } - return "", errors.New(result.Message) - } - s.adminToken = result.Token - s.adminTokenAt = time.Now() - return result.Token, nil -} - -func (s *Sora2APIService) invalidateAdminToken() { - s.adminMu.Lock() - defer s.adminMu.Unlock() - s.adminToken = "" - s.adminTokenAt = time.Time{} -} diff --git a/backend/internal/service/sora2api_sync_service.go b/backend/internal/service/sora2api_sync_service.go deleted file mode 100644 index 33978432..00000000 --- a/backend/internal/service/sora2api_sync_service.go +++ /dev/null @@ -1,255 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "log" - "net/http" - "strings" - "time" - - "github.com/golang-jwt/jwt/v5" -) - -// Sora2APISyncService 用于同步 Sora 账号到 sora2api token 池 -type Sora2APISyncService struct { - sora2api *Sora2APIService - accountRepo AccountRepository - httpClient *http.Client -} - -func NewSora2APISyncService(sora2api *Sora2APIService, accountRepo AccountRepository) *Sora2APISyncService { - return &Sora2APISyncService{ - sora2api: sora2api, - accountRepo: accountRepo, - httpClient: &http.Client{Timeout: 10 * time.Second}, - } -} - -func (s *Sora2APISyncService) Enabled() bool { - return s != nil && s.sora2api != nil && s.sora2api.AdminEnabled() -} - -// SyncAccount 将 Sora 账号同步到 sora2api(导入或更新) -func (s *Sora2APISyncService) SyncAccount(ctx context.Context, account *Account) error { - if !s.Enabled() { - return nil - } - if account == nil || account.Platform != PlatformSora { - return nil - } - - accessToken := strings.TrimSpace(account.GetCredential("access_token")) - if accessToken == "" { - return errors.New("sora 账号缺少 access_token") - } - - email, updated := s.resolveAccountEmail(ctx, account) - if email == "" { - return errors.New("无法解析 Sora 账号邮箱") - } - if updated && s.accountRepo != nil { - if err := s.accountRepo.Update(ctx, account); err != nil { - log.Printf("[SoraSync] 更新账号邮箱失败: account_id=%d err=%v", account.ID, err) - } - } - - item := Sora2APIImportTokenItem{ - Email: email, - AccessToken: accessToken, - SessionToken: strings.TrimSpace(account.GetCredential("session_token")), - RefreshToken: strings.TrimSpace(account.GetCredential("refresh_token")), - ClientID: strings.TrimSpace(account.GetCredential("client_id")), - Remark: account.Name, - IsActive: account.IsActive() && account.Schedulable, - ImageEnabled: true, - VideoEnabled: true, - ImageConcurrency: normalizeSoraConcurrency(account.Concurrency), - VideoConcurrency: normalizeSoraConcurrency(account.Concurrency), - } - - if err := s.sora2api.ImportTokens(ctx, []Sora2APIImportTokenItem{item}); err != nil { - return err - } - return nil -} - -// DisableAccount 禁用 sora2api 中的 token -func (s *Sora2APISyncService) DisableAccount(ctx context.Context, account *Account) error { - if !s.Enabled() { - return nil - } - if account == nil || account.Platform != PlatformSora { - return nil - } - tokenID, err := s.resolveTokenID(ctx, account) - if err != nil { - return err - } - return s.sora2api.DisableToken(ctx, tokenID) -} - -// DeleteAccount 删除 sora2api 中的 token -func (s *Sora2APISyncService) DeleteAccount(ctx context.Context, account *Account) error { - if !s.Enabled() { - return nil - } - if account == nil || account.Platform != PlatformSora { - return nil - } - tokenID, err := s.resolveTokenID(ctx, account) - if err != nil { - return err - } - return s.sora2api.DeleteToken(ctx, tokenID) -} - -func normalizeSoraConcurrency(value int) int { - if value <= 0 { - return -1 - } - return value -} - -func (s *Sora2APISyncService) resolveAccountEmail(ctx context.Context, account *Account) (string, bool) { - if account == nil { - return "", false - } - if email := strings.TrimSpace(account.GetCredential("email")); email != "" { - return email, false - } - if email := strings.TrimSpace(account.GetExtraString("email")); email != "" { - if account.Credentials == nil { - account.Credentials = map[string]any{} - } - account.Credentials["email"] = email - return email, true - } - if email := strings.TrimSpace(account.GetExtraString("sora_email")); email != "" { - if account.Credentials == nil { - account.Credentials = map[string]any{} - } - account.Credentials["email"] = email - return email, true - } - - accessToken := strings.TrimSpace(account.GetCredential("access_token")) - if accessToken != "" { - if email := extractEmailFromAccessToken(accessToken); email != "" { - if account.Credentials == nil { - account.Credentials = map[string]any{} - } - account.Credentials["email"] = email - return email, true - } - if email := s.fetchEmailFromSora(ctx, accessToken); email != "" { - if account.Credentials == nil { - account.Credentials = map[string]any{} - } - account.Credentials["email"] = email - return email, true - } - } - - return "", false -} - -func (s *Sora2APISyncService) resolveTokenID(ctx context.Context, account *Account) (int64, error) { - if account == nil { - return 0, errors.New("account is nil") - } - - if account.Extra != nil { - if v, ok := account.Extra["sora2api_token_id"]; ok { - if id, ok := v.(float64); ok && id > 0 { - return int64(id), nil - } - if id, ok := v.(int64); ok && id > 0 { - return id, nil - } - if id, ok := v.(int); ok && id > 0 { - return int64(id), nil - } - } - } - - email := strings.TrimSpace(account.GetCredential("email")) - if email == "" { - email, _ = s.resolveAccountEmail(ctx, account) - } - if email == "" { - return 0, errors.New("sora2api token email missing") - } - - tokenID, err := s.findTokenIDByEmail(ctx, email) - if err != nil { - return 0, err - } - return tokenID, nil -} - -func (s *Sora2APISyncService) findTokenIDByEmail(ctx context.Context, email string) (int64, error) { - if !s.Enabled() { - return 0, errors.New("sora2api admin not configured") - } - tokens, err := s.sora2api.ListTokens(ctx) - if err != nil { - return 0, err - } - for _, token := range tokens { - if strings.EqualFold(strings.TrimSpace(token.Email), strings.TrimSpace(email)) { - return token.ID, nil - } - } - return 0, fmt.Errorf("sora2api token not found for email: %s", email) -} - -func extractEmailFromAccessToken(accessToken string) string { - parser := jwt.NewParser(jwt.WithoutClaimsValidation()) - claims := jwt.MapClaims{} - _, _, err := parser.ParseUnverified(accessToken, claims) - if err != nil { - return "" - } - if email, ok := claims["email"].(string); ok && strings.TrimSpace(email) != "" { - return email - } - if profile, ok := claims["https://api.openai.com/profile"].(map[string]any); ok { - if email, ok := profile["email"].(string); ok && strings.TrimSpace(email) != "" { - return email - } - } - return "" -} - -func (s *Sora2APISyncService) fetchEmailFromSora(ctx context.Context, accessToken string) string { - if s.httpClient == nil { - return "" - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, soraMeAPIURL, nil) - if err != nil { - return "" - } - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") - req.Header.Set("Accept", "application/json") - - resp, err := s.httpClient.Do(req) - if err != nil { - return "" - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - return "" - } - var payload map[string]any - if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { - return "" - } - if email, ok := payload["email"].(string); ok && strings.TrimSpace(email) != "" { - return email - } - return "" -} diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go new file mode 100644 index 00000000..9ecb4688 --- /dev/null +++ b/backend/internal/service/sora_client.go @@ -0,0 +1,884 @@ +package service + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "math/rand" + "mime" + "mime/multipart" + "net/http" + "net/textproto" + "net/url" + "path" + "strconv" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/google/uuid" + "golang.org/x/crypto/sha3" +) + +const ( + soraChatGPTBaseURL = "https://chatgpt.com" + soraSentinelFlow = "sora_2_create_task" + soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)" +) + +const ( + soraPowMaxIteration = 500000 +) + +var soraPowCores = []int{8, 16, 24, 32} + +var soraPowScripts = []string{ + "https://cdn.oaistatic.com/_next/static/cXh69klOLzS0Gy2joLDRS/_ssgManifest.js?dpl=453ebaec0d44c2decab71692e1bfe39be35a24b3", +} + +var soraPowDPL = []string{ + "prod-f501fe933b3edf57aea882da888e1a544df99840", +} + +var soraPowNavigatorKeys = []string{ + "registerProtocolHandler−function registerProtocolHandler() { [native code] }", + "storage−[object StorageManager]", + "locks−[object LockManager]", + "appCodeName−Mozilla", + "permissions−[object Permissions]", + "webdriver−false", + "vendor−Google Inc.", + "mediaDevices−[object MediaDevices]", + "cookieEnabled−true", + "product−Gecko", + "productSub−20030107", + "hardwareConcurrency−32", + "onLine−true", +} + +var soraPowDocumentKeys = []string{ + "_reactListeningo743lnnpvdg", + "location", +} + +var soraPowWindowKeys = []string{ + "0", "window", "self", "document", "name", "location", + "navigator", "screen", "innerWidth", "innerHeight", + "localStorage", "sessionStorage", "crypto", "performance", + "fetch", "setTimeout", "setInterval", "console", +} + +var soraDesktopUserAgents = []string{ + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36", + "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", + "Mozilla/5.0 (Windows NT 11.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", +} + +var soraRand = rand.New(rand.NewSource(time.Now().UnixNano())) +var soraRandMu sync.Mutex +var soraPerfStart = time.Now() + +// SoraClient 定义直连 Sora 的任务操作接口。 +type SoraClient interface { + Enabled() bool + UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) + CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) + CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) + GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) + GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) +} + +// SoraImageRequest 图片生成请求参数 +type SoraImageRequest struct { + Prompt string + Width int + Height int + MediaID string +} + +// SoraVideoRequest 视频生成请求参数 +type SoraVideoRequest struct { + Prompt string + Orientation string + Frames int + Model string + Size string + MediaID string + RemixTargetID string +} + +// SoraImageTaskStatus 图片任务状态 +type SoraImageTaskStatus struct { + ID string + Status string + ProgressPct float64 + URLs []string + ErrorMsg string +} + +// SoraVideoTaskStatus 视频任务状态 +type SoraVideoTaskStatus struct { + ID string + Status string + ProgressPct int + URLs []string + ErrorMsg string +} + +// SoraUpstreamError 上游错误 +type SoraUpstreamError struct { + StatusCode int + Message string + Headers http.Header + Body []byte +} + +func (e *SoraUpstreamError) Error() string { + if e == nil { + return "sora upstream error" + } + if e.Message != "" { + return fmt.Sprintf("sora upstream error: %d %s", e.StatusCode, e.Message) + } + return fmt.Sprintf("sora upstream error: %d", e.StatusCode) +} + +// SoraDirectClient 直连 Sora 实现 +type SoraDirectClient struct { + cfg *config.Config + httpUpstream HTTPUpstream + tokenProvider *OpenAITokenProvider +} + +// NewSoraDirectClient 创建 Sora 直连客户端 +func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient { + return &SoraDirectClient{ + cfg: cfg, + httpUpstream: httpUpstream, + tokenProvider: tokenProvider, + } +} + +// Enabled 判断是否启用 Sora 直连 +func (c *SoraDirectClient) Enabled() bool { + if c == nil || c.cfg == nil { + return false + } + return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != "" +} + +func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) { + if len(data) == 0 { + return "", errors.New("empty image data") + } + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + if filename == "" { + filename = "image.png" + } + var body bytes.Buffer + writer := multipart.NewWriter(&body) + contentType := mime.TypeByExtension(path.Ext(filename)) + if contentType == "" { + contentType = "application/octet-stream" + } + partHeader := make(textproto.MIMEHeader) + partHeader.Set("Content-Disposition", fmt.Sprintf(`form-data; name="file"; filename="%s"`, filename)) + partHeader.Set("Content-Type", contentType) + part, err := writer.CreatePart(partHeader) + if err != nil { + return "", err + } + if _, err := part.Write(data); err != nil { + return "", err + } + if err := writer.WriteField("file_name", filename); err != nil { + return "", err + } + if err := writer.Close(); err != nil { + return "", err + } + + headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + headers.Set("Content-Type", writer.FormDataContentType()) + + respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/uploads"), headers, &body, false) + if err != nil { + return "", err + } + var payload map[string]any + if err := json.Unmarshal(respBody, &payload); err != nil { + return "", fmt.Errorf("parse upload response: %w", err) + } + id, _ := payload["id"].(string) + if strings.TrimSpace(id) == "" { + return "", errors.New("upload response missing id") + } + return id, nil +} + +func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + operation := "simple_compose" + inpaintItems := []map[string]any{} + if strings.TrimSpace(req.MediaID) != "" { + operation = "remix" + inpaintItems = append(inpaintItems, map[string]any{ + "type": "image", + "frame_index": 0, + "upload_media_id": req.MediaID, + }) + } + payload := map[string]any{ + "type": "image_gen", + "operation": operation, + "prompt": req.Prompt, + "width": req.Width, + "height": req.Height, + "n_variants": 1, + "n_frames": 1, + "inpaint_items": inpaintItems, + } + headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + sentinel, err := c.generateSentinelToken(ctx, account, token) + if err != nil { + return "", err + } + headers.Set("openai-sentinel-token", sentinel) + + respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/video_gen"), headers, bytes.NewReader(body), true) + if err != nil { + return "", err + } + var resp map[string]any + if err := json.Unmarshal(respBody, &resp); err != nil { + return "", err + } + taskID, _ := resp["id"].(string) + if strings.TrimSpace(taskID) == "" { + return "", errors.New("image task response missing id") + } + return taskID, nil +} + +func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + orientation := req.Orientation + if orientation == "" { + orientation = "landscape" + } + nFrames := req.Frames + if nFrames <= 0 { + nFrames = 450 + } + model := req.Model + if model == "" { + model = "sy_8" + } + size := req.Size + if size == "" { + size = "small" + } + + inpaintItems := []map[string]any{} + if strings.TrimSpace(req.MediaID) != "" { + inpaintItems = append(inpaintItems, map[string]any{ + "kind": "upload", + "upload_id": req.MediaID, + }) + } + payload := map[string]any{ + "kind": "video", + "prompt": req.Prompt, + "orientation": orientation, + "size": size, + "n_frames": nFrames, + "model": model, + "inpaint_items": inpaintItems, + } + if strings.TrimSpace(req.RemixTargetID) != "" { + payload["remix_target_id"] = req.RemixTargetID + payload["cameo_ids"] = []string{} + payload["cameo_replacements"] = map[string]any{} + } + + headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + sentinel, err := c.generateSentinelToken(ctx, account, token) + if err != nil { + return "", err + } + headers.Set("openai-sentinel-token", sentinel) + + respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/nf/create"), headers, bytes.NewReader(body), true) + if err != nil { + return "", err + } + var resp map[string]any + if err := json.Unmarshal(respBody, &resp); err != nil { + return "", err + } + taskID, _ := resp["id"].(string) + if strings.TrimSpace(taskID) == "" { + return "", errors.New("video task response missing id") + } + return taskID, nil +} + +func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return nil, err + } + headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/v2/recent_tasks?limit=20"), headers, nil, false) + if err != nil { + return nil, err + } + var resp map[string]any + if err := json.Unmarshal(respBody, &resp); err != nil { + return nil, err + } + taskResponses, _ := resp["task_responses"].([]any) + for _, item := range taskResponses { + taskResp, ok := item.(map[string]any) + if !ok { + continue + } + if id, _ := taskResp["id"].(string); id == taskID { + status := strings.TrimSpace(fmt.Sprintf("%v", taskResp["status"])) + progress := 0.0 + if v, ok := taskResp["progress_pct"].(float64); ok { + progress = v + } + urls := []string{} + if generations, ok := taskResp["generations"].([]any); ok { + for _, genItem := range generations { + gen, ok := genItem.(map[string]any) + if !ok { + continue + } + if urlStr, ok := gen["url"].(string); ok && strings.TrimSpace(urlStr) != "" { + urls = append(urls, urlStr) + } + } + } + return &SoraImageTaskStatus{ + ID: taskID, + Status: status, + ProgressPct: progress, + URLs: urls, + }, nil + } + } + return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, nil +} + +func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return nil, err + } + headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + + respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/pending/v2"), headers, nil, false) + if err != nil { + return nil, err + } + var pending any + if err := json.Unmarshal(respBody, &pending); err == nil { + if list, ok := pending.([]any); ok { + for _, item := range list { + task, ok := item.(map[string]any) + if !ok { + continue + } + if id, _ := task["id"].(string); id == taskID { + progress := 0 + if v, ok := task["progress_pct"].(float64); ok { + progress = int(v * 100) + } + status := strings.TrimSpace(fmt.Sprintf("%v", task["status"])) + return &SoraVideoTaskStatus{ + ID: taskID, + Status: status, + ProgressPct: progress, + }, nil + } + } + } + } + + respBody, _, err = c.doRequest(ctx, account, http.MethodGet, c.buildURL("/project_y/profile/drafts?limit=15"), headers, nil, false) + if err != nil { + return nil, err + } + var draftsResp map[string]any + if err := json.Unmarshal(respBody, &draftsResp); err != nil { + return nil, err + } + items, _ := draftsResp["items"].([]any) + for _, item := range items { + draft, ok := item.(map[string]any) + if !ok { + continue + } + if id, _ := draft["task_id"].(string); id == taskID { + kind := strings.TrimSpace(fmt.Sprintf("%v", draft["kind"])) + reason := strings.TrimSpace(fmt.Sprintf("%v", draft["reason_str"])) + if reason == "" { + reason = strings.TrimSpace(fmt.Sprintf("%v", draft["markdown_reason_str"])) + } + urlStr := strings.TrimSpace(fmt.Sprintf("%v", draft["downloadable_url"])) + if urlStr == "" { + urlStr = strings.TrimSpace(fmt.Sprintf("%v", draft["url"])) + } + + if kind == "sora_content_violation" || reason != "" || urlStr == "" { + msg := reason + if msg == "" { + msg = "Content violates guardrails" + } + return &SoraVideoTaskStatus{ + ID: taskID, + Status: "failed", + ErrorMsg: msg, + }, nil + } + return &SoraVideoTaskStatus{ + ID: taskID, + Status: "completed", + URLs: []string{urlStr}, + }, nil + } + } + + return &SoraVideoTaskStatus{ID: taskID, Status: "processing"}, nil +} + +func (c *SoraDirectClient) buildURL(endpoint string) string { + base := "" + if c != nil && c.cfg != nil { + base = strings.TrimRight(strings.TrimSpace(c.cfg.Sora.Client.BaseURL), "/") + } + if base == "" { + return endpoint + } + if strings.HasPrefix(endpoint, "/") { + return base + endpoint + } + return base + "/" + endpoint +} + +func (c *SoraDirectClient) defaultUserAgent() string { + if c == nil || c.cfg == nil { + return soraDefaultUserAgent + } + ua := strings.TrimSpace(c.cfg.Sora.Client.UserAgent) + if ua == "" { + return soraDefaultUserAgent + } + return ua +} + +func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + if c.tokenProvider != nil { + return c.tokenProvider.GetAccessToken(ctx, account) + } + token := strings.TrimSpace(account.GetCredential("access_token")) + if token == "" { + return "", errors.New("access_token not found") + } + return token, nil +} + +func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header { + headers := http.Header{} + if token != "" { + headers.Set("Authorization", "Bearer "+token) + } + if userAgent != "" { + headers.Set("User-Agent", userAgent) + } + if c != nil && c.cfg != nil { + for key, value := range c.cfg.Sora.Client.Headers { + if strings.EqualFold(key, "authorization") || strings.EqualFold(key, "openai-sentinel-token") { + continue + } + headers.Set(key, value) + } + } + return headers +} + +func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, method, urlStr string, headers http.Header, body io.Reader, allowRetry bool) ([]byte, http.Header, error) { + if strings.TrimSpace(urlStr) == "" { + return nil, nil, errors.New("empty upstream url") + } + timeout := 0 + if c != nil && c.cfg != nil { + timeout = c.cfg.Sora.Client.TimeoutSeconds + } + if timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + defer cancel() + } + maxRetries := 0 + if allowRetry && c != nil && c.cfg != nil { + maxRetries = c.cfg.Sora.Client.MaxRetries + } + if maxRetries < 0 { + maxRetries = 0 + } + + var bodyBytes []byte + if body != nil { + b, err := io.ReadAll(body) + if err != nil { + return nil, nil, err + } + bodyBytes = b + } + + attempts := maxRetries + 1 + for attempt := 1; attempt <= attempts; attempt++ { + var reader io.Reader + if bodyBytes != nil { + reader = bytes.NewReader(bodyBytes) + } + req, err := http.NewRequestWithContext(ctx, method, urlStr, reader) + if err != nil { + return nil, nil, err + } + req.Header = headers.Clone() + start := time.Now() + + proxyURL := "" + if account != nil && account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + resp, err := c.doHTTP(req, proxyURL, account) + if err != nil { + if attempt < attempts && allowRetry { + c.sleepRetry(attempt) + continue + } + return nil, nil, err + } + + respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + if readErr != nil { + return nil, resp.Header, readErr + } + + if c.cfg != nil && c.cfg.Sora.Client.Debug { + log.Printf("[SoraClient] %s %s status=%d cost=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, time.Since(start)) + } + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody) + if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) { + c.sleepRetry(attempt) + continue + } + return nil, resp.Header, upstreamErr + } + return respBody, resp.Header, nil + } + return nil, nil, errors.New("upstream retries exhausted") +} + +func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) { + enableTLS := false + if c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint { + enableTLS = true + } + if c.httpUpstream != nil { + accountID := int64(0) + accountConcurrency := 0 + if account != nil { + accountID = account.ID + accountConcurrency = account.Concurrency + } + return c.httpUpstream.DoWithTLS(req, proxyURL, accountID, accountConcurrency, enableTLS) + } + return http.DefaultClient.Do(req) +} + +func (c *SoraDirectClient) sleepRetry(attempt int) { + backoff := time.Duration(attempt*attempt) * time.Second + if backoff > 10*time.Second { + backoff = 10 * time.Second + } + time.Sleep(backoff) +} + +func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte) error { + msg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + msg = sanitizeUpstreamErrorMessage(msg) + if msg == "" { + msg = truncateForLog(body, 256) + } + return &SoraUpstreamError{ + StatusCode: status, + Message: msg, + Headers: headers, + Body: body, + } +} + +func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) { + reqID := uuid.NewString() + userAgent := soraRandChoice(soraDesktopUserAgents) + powToken := soraGetPowToken(userAgent) + payload := map[string]any{ + "p": powToken, + "flow": soraSentinelFlow, + "id": reqID, + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + headers := http.Header{} + headers.Set("Accept", "application/json, text/plain, */*") + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + headers.Set("User-Agent", userAgent) + if accessToken != "" { + headers.Set("Authorization", "Bearer "+accessToken) + } + + urlStr := soraChatGPTBaseURL + "/backend-api/sentinel/req" + respBody, _, err := c.doRequest(ctx, account, http.MethodPost, urlStr, headers, bytes.NewReader(body), true) + if err != nil { + return "", err + } + var resp map[string]any + if err := json.Unmarshal(respBody, &resp); err != nil { + return "", err + } + + sentinel := soraBuildSentinelToken(soraSentinelFlow, reqID, powToken, resp, userAgent) + if sentinel == "" { + return "", errors.New("failed to build sentinel token") + } + return sentinel, nil +} + +func soraRandChoice(items []string) string { + if len(items) == 0 { + return "" + } + soraRandMu.Lock() + idx := soraRand.Intn(len(items)) + soraRandMu.Unlock() + return items[idx] +} + +func soraGetPowToken(userAgent string) string { + configList := soraBuildPowConfig(userAgent) + seed := strconv.FormatFloat(soraRandFloat(), 'f', -1, 64) + difficulty := "0fffff" + solution, _ := soraSolvePow(seed, difficulty, configList) + return "gAAAAAC" + solution +} + +func soraRandFloat() float64 { + soraRandMu.Lock() + defer soraRandMu.Unlock() + return soraRand.Float64() +} + +func soraBuildPowConfig(userAgent string) []any { + screen := soraRandChoice([]string{ + strconv.Itoa(1920 + 1080), + strconv.Itoa(2560 + 1440), + strconv.Itoa(1920 + 1200), + strconv.Itoa(2560 + 1600), + }) + screenVal, _ := strconv.Atoi(screen) + perfMs := float64(time.Since(soraPerfStart).Milliseconds()) + wallMs := float64(time.Now().UnixNano()) / 1e6 + diff := wallMs - perfMs + return []any{ + screenVal, + soraPowParseTime(), + 4294705152, + 0, + userAgent, + soraRandChoice(soraPowScripts), + soraRandChoice(soraPowDPL), + "en-US", + "en-US,es-US,en,es", + 0, + soraRandChoice(soraPowNavigatorKeys), + soraRandChoice(soraPowDocumentKeys), + soraRandChoice(soraPowWindowKeys), + perfMs, + uuid.NewString(), + "", + soraRandChoiceInt(soraPowCores), + diff, + } +} + +func soraRandChoiceInt(items []int) int { + if len(items) == 0 { + return 0 + } + soraRandMu.Lock() + idx := soraRand.Intn(len(items)) + soraRandMu.Unlock() + return items[idx] +} + +func soraPowParseTime() string { + loc := time.FixedZone("EST", -5*3600) + return time.Now().In(loc).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (Eastern Standard Time)") +} + +func soraSolvePow(seed, difficulty string, configList []any) (string, bool) { + diffLen := len(difficulty) / 2 + target, err := hexDecodeString(difficulty) + if err != nil { + return "", false + } + seedBytes := []byte(seed) + + part1 := mustMarshalJSON(configList[:3]) + part2 := mustMarshalJSON(configList[4:9]) + part3 := mustMarshalJSON(configList[10:]) + + staticPart1 := append(part1[:len(part1)-1], ',') + staticPart2 := append([]byte(","), append(part2[1:len(part2)-1], ',')...) + staticPart3 := append([]byte(","), part3[1:]...) + + for i := 0; i < soraPowMaxIteration; i++ { + dynamicI := []byte(strconv.Itoa(i)) + dynamicJ := []byte(strconv.Itoa(i >> 1)) + finalJSON := make([]byte, 0, len(staticPart1)+len(dynamicI)+len(staticPart2)+len(dynamicJ)+len(staticPart3)) + finalJSON = append(finalJSON, staticPart1...) + finalJSON = append(finalJSON, dynamicI...) + finalJSON = append(finalJSON, staticPart2...) + finalJSON = append(finalJSON, dynamicJ...) + finalJSON = append(finalJSON, staticPart3...) + + b64 := base64.StdEncoding.EncodeToString(finalJSON) + hash := sha3.Sum512(append(seedBytes, []byte(b64)...)) + if bytes.Compare(hash[:diffLen], target[:diffLen]) <= 0 { + return b64, true + } + } + + errorToken := "wQ8Lk5FbGpA2NcR9dShT6gYjU7VxZ4D" + base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("\"%s\"", seed))) + return errorToken, false +} + +func soraBuildSentinelToken(flow, reqID, powToken string, resp map[string]any, userAgent string) string { + finalPow := powToken + proof, _ := resp["proofofwork"].(map[string]any) + if required, _ := proof["required"].(bool); required { + seed, _ := proof["seed"].(string) + difficulty, _ := proof["difficulty"].(string) + if seed != "" && difficulty != "" { + configList := soraBuildPowConfig(userAgent) + solution, _ := soraSolvePow(seed, difficulty, configList) + finalPow = "gAAAAAB" + solution + } + } + if !strings.HasSuffix(finalPow, "~S") { + finalPow += "~S" + } + turnstile, _ := resp["turnstile"].(map[string]any) + tokenPayload := map[string]any{ + "p": finalPow, + "t": safeMapString(turnstile, "dx"), + "c": safeString(resp["token"]), + "id": reqID, + "flow": flow, + } + encoded, _ := json.Marshal(tokenPayload) + return string(encoded) +} + +func safeMapString(m map[string]any, key string) string { + if m == nil { + return "" + } + if v, ok := m[key]; ok { + return safeString(v) + } + return "" +} + +func safeString(v any) string { + switch val := v.(type) { + case string: + return val + default: + return fmt.Sprintf("%v", val) + } +} + +func mustMarshalJSON(v any) []byte { + b, _ := json.Marshal(v) + return b +} + +func hexDecodeString(s string) ([]byte, error) { + dst := make([]byte, len(s)/2) + _, err := hex.Decode(dst, []byte(s)) + return dst, err +} + +func sanitizeSoraLogURL(raw string) string { + parsed, err := url.Parse(raw) + if err != nil { + return raw + } + q := parsed.Query() + q.Del("sig") + q.Del("expires") + parsed.RawQuery = q.Encode() + return parsed.String() +} diff --git a/backend/internal/service/sora_client_test.go b/backend/internal/service/sora_client_test.go new file mode 100644 index 00000000..abbe47a1 --- /dev/null +++ b/backend/internal/service/sora_client_test.go @@ -0,0 +1,54 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestSoraDirectClient_DoRequestSuccess(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer server.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{BaseURL: server.URL}, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + + body, _, err := client.doRequest(context.Background(), &Account{ID: 1}, http.MethodGet, server.URL, http.Header{}, nil, false) + require.NoError(t, err) + require.Contains(t, string(body), "ok") +} + +func TestSoraDirectClient_BuildBaseHeaders(t *testing.T) { + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + Headers: map[string]string{ + "X-Test": "yes", + "Authorization": "should-ignore", + "openai-sentinel-token": "skip", + }, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + + headers := client.buildBaseHeaders("token-123", "UA") + require.Equal(t, "Bearer token-123", headers.Get("Authorization")) + require.Equal(t, "UA", headers.Get("User-Agent")) + require.Equal(t, "yes", headers.Get("X-Test")) + require.Empty(t, headers.Get("openai-sentinel-token")) +} diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index 2909a76f..49cd7bba 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -4,10 +4,12 @@ import ( "bufio" "bytes" "context" + "encoding/base64" "encoding/json" "errors" "fmt" "io" + "mime" "net/http" "net/url" "regexp" @@ -39,23 +41,23 @@ type soraStreamingResult struct { firstTokenMs *int } -// SoraGatewayService handles forwarding requests to sora2api. +// SoraGatewayService handles forwarding requests to Sora upstream. type SoraGatewayService struct { - sora2api *Sora2APIService - httpUpstream HTTPUpstream + soraClient SoraClient + mediaStorage *SoraMediaStorage rateLimitService *RateLimitService cfg *config.Config } func NewSoraGatewayService( - sora2api *Sora2APIService, - httpUpstream HTTPUpstream, + soraClient SoraClient, + mediaStorage *SoraMediaStorage, rateLimitService *RateLimitService, cfg *config.Config, ) *SoraGatewayService { return &SoraGatewayService{ - sora2api: sora2api, - httpUpstream: httpUpstream, + soraClient: soraClient, + mediaStorage: mediaStorage, rateLimitService: rateLimitService, cfg: cfg, } @@ -64,31 +66,53 @@ func NewSoraGatewayService( func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) { startTime := time.Now() - if s.sora2api == nil || !s.sora2api.Enabled() { + if s.soraClient == nil || !s.soraClient.Enabled() { if c != nil { c.JSON(http.StatusServiceUnavailable, gin.H{ "error": gin.H{ "type": "api_error", - "message": "sora2api 未配置", + "message": "Sora 上游未配置", }, }) } - return nil, errors.New("sora2api not configured") + return nil, errors.New("sora upstream not configured") } var reqBody map[string]any if err := json.Unmarshal(body, &reqBody); err != nil { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body", clientStream) return nil, fmt.Errorf("parse request: %w", err) } reqModel, _ := reqBody["model"].(string) reqStream, _ := reqBody["stream"].(bool) + if strings.TrimSpace(reqModel) == "" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream) + return nil, errors.New("model is required") + } mappedModel := account.GetMappedModel(reqModel) - if mappedModel != reqModel && mappedModel != "" { - reqBody["model"] = mappedModel - if updated, err := json.Marshal(reqBody); err == nil { - body = updated - } + if mappedModel != "" && mappedModel != reqModel { + reqModel = mappedModel + } + + modelCfg, ok := GetSoraModelConfig(reqModel) + if !ok { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream) + return nil, fmt.Errorf("unsupported model: %s", reqModel) + } + if modelCfg.Type == "prompt_enhance" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Prompt-enhance 模型暂未支持", clientStream) + return nil, fmt.Errorf("prompt-enhance not supported") + } + + prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody) + if strings.TrimSpace(prompt) == "" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) + return nil, errors.New("prompt is required") + } + if strings.TrimSpace(videoInput) != "" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Video input is not supported yet", clientStream) + return nil, errors.New("video input not supported") } reqCtx, cancel := s.withSoraTimeout(ctx, reqStream) @@ -96,81 +120,122 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun defer cancel() } - upstreamReq, err := s.sora2api.NewAPIRequest(reqCtx, http.MethodPost, "/v1/chat/completions", body) - if err != nil { - return nil, err - } - if c != nil { - if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" { - upstreamReq.Header.Set("User-Agent", ua) + var imageData []byte + imageFilename := "" + if strings.TrimSpace(imageInput) != "" { + decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput) + if err != nil { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream) + return nil, err } - } - if reqStream { - upstreamReq.Header.Set("Accept", "text/event-stream") + imageData = decoded + imageFilename = filename } - if c != nil { - c.Set(OpsUpstreamRequestBodyKey, string(body)) + mediaID := "" + if len(imageData) > 0 { + uploadID, err := s.soraClient.UploadImage(reqCtx, account, imageData, imageFilename) + if err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + mediaID = uploadID } - proxyURL := "" - if account != nil && account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() + taskID := "" + var err error + switch modelCfg.Type { + case "image": + taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{ + Prompt: prompt, + Width: modelCfg.Width, + Height: modelCfg.Height, + MediaID: mediaID, + }) + case "video": + taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{ + Prompt: prompt, + Orientation: modelCfg.Orientation, + Frames: modelCfg.Frames, + Model: modelCfg.Model, + Size: modelCfg.Size, + MediaID: mediaID, + RemixTargetID: remixTargetID, + }) + default: + err = fmt.Errorf("unsupported model type: %s", modelCfg.Type) + } + if err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) } - var resp *http.Response - if s.httpUpstream != nil { - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if clientStream && c != nil { + s.prepareSoraStream(c, taskID) + } + + var mediaURLs []string + mediaType := modelCfg.Type + imageCount := 0 + imageSize := "" + if modelCfg.Type == "image" { + urls, pollErr := s.pollImageTask(reqCtx, c, account, taskID, clientStream) + if pollErr != nil { + return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) + } + mediaURLs = urls + imageCount = len(urls) + imageSize = soraImageSizeFromModel(reqModel) + } else if modelCfg.Type == "video" { + urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream) + if pollErr != nil { + return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) + } + mediaURLs = urls } else { - resp, err = http.DefaultClient.Do(upstreamReq) + mediaType = "prompt" } - if err != nil { - s.setUpstreamRequestError(c, account, err) - return nil, fmt.Errorf("upstream request failed: %w", err) - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode >= 400 { - if s.shouldFailoverUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "failover", - Message: upstreamMsg, - }) - s.handleFailoverSideEffects(ctx, resp, account) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + finalURLs := mediaURLs + if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() { + stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs) + if storeErr != nil { + return nil, s.handleSoraRequestError(ctx, account, storeErr, reqModel, c, clientStream) } - return s.handleErrorResponse(ctx, resp, c, account, reqModel) + finalURLs = s.normalizeSoraMediaURLs(stored) + } else { + finalURLs = s.normalizeSoraMediaURLs(mediaURLs) } - streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, reqModel, clientStream) - if err != nil { - return nil, err + content := buildSoraContent(mediaType, finalURLs) + var firstTokenMs *int + if clientStream { + ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) + if streamErr != nil { + return nil, streamErr + } + firstTokenMs = ms + } else if c != nil { + response := buildSoraNonStreamResponse(content, reqModel) + if len(finalURLs) > 0 { + response["media_url"] = finalURLs[0] + if len(finalURLs) > 1 { + response["media_urls"] = finalURLs + } + } + c.JSON(http.StatusOK, response) } - result := &ForwardResult{ - RequestID: resp.Header.Get("x-request-id"), + return &ForwardResult{ + RequestID: taskID, Model: reqModel, Stream: clientStream, Duration: time.Since(startTime), - FirstTokenMs: streamResult.firstTokenMs, + FirstTokenMs: firstTokenMs, Usage: ClaudeUsage{}, - MediaType: streamResult.mediaType, - MediaURL: firstMediaURL(streamResult.mediaURLs), - ImageCount: streamResult.imageCount, - ImageSize: streamResult.imageSize, - } - - return result, nil + MediaType: mediaType, + MediaURL: firstMediaURL(finalURLs), + ImageCount: imageCount, + ImageSize: imageSize, + }, nil } func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { @@ -780,3 +845,414 @@ func (s *SoraGatewayService) buildSoraMediaURL(path string, rawQuery string) str } return prefix + path + "?" + encoded } + +func (s *SoraGatewayService) prepareSoraStream(c *gin.Context, requestID string) { + if c == nil { + return + } + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + if strings.TrimSpace(requestID) != "" { + c.Header("x-request-id", requestID) + } +} + +func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content string, startTime time.Time) (*int, error) { + if c == nil { + return nil, nil + } + writer := c.Writer + flusher, _ := writer.(http.Flusher) + + chunk := map[string]any{ + "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{ + "content": content, + }, + }, + }, + } + encoded, _ := json.Marshal(chunk) + if _, err := fmt.Fprintf(writer, "data: %s\n\n", encoded); err != nil { + return nil, err + } + if flusher != nil { + flusher.Flush() + } + ms := int(time.Since(startTime).Milliseconds()) + finalChunk := map[string]any{ + "id": chunk["id"], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{}, + "finish_reason": "stop", + }, + }, + } + finalEncoded, _ := json.Marshal(finalChunk) + if _, err := fmt.Fprintf(writer, "data: %s\n\n", finalEncoded); err != nil { + return &ms, err + } + if _, err := fmt.Fprint(writer, "data: [DONE]\n\n"); err != nil { + return &ms, err + } + if flusher != nil { + flusher.Flush() + } + return &ms, nil +} + +func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType, message string, stream bool) { + if c == nil { + return + } + if stream { + flusher, _ := c.Writer.(http.Flusher) + errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) + _, _ = fmt.Fprint(c.Writer, errorEvent) + _, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n") + if flusher != nil { + flusher.Flush() + } + return + } + c.JSON(status, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account *Account, err error, model string, c *gin.Context, stream bool) error { + if err == nil { + return nil + } + var upstreamErr *SoraUpstreamError + if errors.As(err, &upstreamErr) { + if s.rateLimitService != nil && account != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body) + } + if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) { + return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode} + } + msg := upstreamErr.Message + if override := soraProErrorMessage(model, msg); override != "" { + msg = override + } + s.writeSoraError(c, upstreamErr.StatusCode, "upstream_error", msg, stream) + return err + } + if errors.Is(err, context.DeadlineExceeded) { + s.writeSoraError(c, http.StatusGatewayTimeout, "timeout_error", "Sora generation timeout", stream) + return err + } + s.writeSoraError(c, http.StatusBadGateway, "api_error", err.Error(), stream) + return err +} + +func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) { + interval := s.pollInterval() + maxAttempts := s.pollMaxAttempts() + lastPing := time.Now() + for attempt := 0; attempt < maxAttempts; attempt++ { + status, err := s.soraClient.GetImageTask(ctx, account, taskID) + if err != nil { + return nil, err + } + switch strings.ToLower(status.Status) { + case "succeeded", "completed": + return status.URLs, nil + case "failed": + if status.ErrorMsg != "" { + return nil, errors.New(status.ErrorMsg) + } + return nil, errors.New("Sora image generation failed") + } + if stream { + s.maybeSendPing(c, &lastPing) + } + if err := sleepWithContext(ctx, interval); err != nil { + return nil, err + } + } + return nil, errors.New("Sora image generation timeout") +} + +func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) { + interval := s.pollInterval() + maxAttempts := s.pollMaxAttempts() + lastPing := time.Now() + for attempt := 0; attempt < maxAttempts; attempt++ { + status, err := s.soraClient.GetVideoTask(ctx, account, taskID) + if err != nil { + return nil, err + } + switch strings.ToLower(status.Status) { + case "completed", "succeeded": + return status.URLs, nil + case "failed": + if status.ErrorMsg != "" { + return nil, errors.New(status.ErrorMsg) + } + return nil, errors.New("Sora video generation failed") + } + if stream { + s.maybeSendPing(c, &lastPing) + } + if err := sleepWithContext(ctx, interval); err != nil { + return nil, err + } + } + return nil, errors.New("Sora video generation timeout") +} + +func (s *SoraGatewayService) pollInterval() time.Duration { + if s == nil || s.cfg == nil { + return 2 * time.Second + } + interval := s.cfg.Sora.Client.PollIntervalSeconds + if interval <= 0 { + interval = 2 + } + return time.Duration(interval) * time.Second +} + +func (s *SoraGatewayService) pollMaxAttempts() int { + if s == nil || s.cfg == nil { + return 600 + } + maxAttempts := s.cfg.Sora.Client.MaxPollAttempts + if maxAttempts <= 0 { + maxAttempts = 600 + } + return maxAttempts +} + +func (s *SoraGatewayService) maybeSendPing(c *gin.Context, lastPing *time.Time) { + if c == nil { + return + } + interval := 10 * time.Second + if s != nil && s.cfg != nil && s.cfg.Concurrency.PingInterval > 0 { + interval = time.Duration(s.cfg.Concurrency.PingInterval) * time.Second + } + if time.Since(*lastPing) < interval { + return + } + if _, err := fmt.Fprint(c.Writer, ":\n\n"); err == nil { + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } + *lastPing = time.Now() + } +} + +func (s *SoraGatewayService) normalizeSoraMediaURLs(urls []string) []string { + if len(urls) == 0 { + return urls + } + output := make([]string, 0, len(urls)) + for _, raw := range urls { + raw = strings.TrimSpace(raw) + if raw == "" { + continue + } + if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { + output = append(output, raw) + continue + } + pathVal := raw + if !strings.HasPrefix(pathVal, "/") { + pathVal = "/" + pathVal + } + output = append(output, s.buildSoraMediaURL(pathVal, "")) + } + return output +} + +func buildSoraContent(mediaType string, urls []string) string { + switch mediaType { + case "image": + parts := make([]string, 0, len(urls)) + for _, u := range urls { + parts = append(parts, fmt.Sprintf("![image](%s)", u)) + } + return strings.Join(parts, "\n") + case "video": + if len(urls) == 0 { + return "" + } + return fmt.Sprintf("```html\n\n```", urls[0]) + default: + return "" + } +} + +func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remixTargetID string) { + if body == nil { + return "", "", "", "" + } + if v, ok := body["remix_target_id"].(string); ok { + remixTargetID = v + } + if v, ok := body["image"].(string); ok { + imageInput = v + } + if v, ok := body["video"].(string); ok { + videoInput = v + } + if v, ok := body["prompt"].(string); ok && strings.TrimSpace(v) != "" { + prompt = v + } + if messages, ok := body["messages"].([]any); ok { + builder := strings.Builder{} + for _, raw := range messages { + msg, ok := raw.(map[string]any) + if !ok { + continue + } + role, _ := msg["role"].(string) + if role != "" && role != "user" { + continue + } + content := msg["content"] + text, img, vid := parseSoraMessageContent(content) + if text != "" { + if builder.Len() > 0 { + builder.WriteString("\n") + } + builder.WriteString(text) + } + if imageInput == "" && img != "" { + imageInput = img + } + if videoInput == "" && vid != "" { + videoInput = vid + } + } + if prompt == "" { + prompt = builder.String() + } + } + return prompt, imageInput, videoInput, remixTargetID +} + +func parseSoraMessageContent(content any) (text, imageInput, videoInput string) { + switch val := content.(type) { + case string: + return val, "", "" + case []any: + builder := strings.Builder{} + for _, item := range val { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + t, _ := itemMap["type"].(string) + switch t { + case "text": + if txt, ok := itemMap["text"].(string); ok && strings.TrimSpace(txt) != "" { + if builder.Len() > 0 { + builder.WriteString("\n") + } + builder.WriteString(txt) + } + case "image_url": + if imageInput == "" { + if urlVal, ok := itemMap["image_url"].(map[string]any); ok { + imageInput = fmt.Sprintf("%v", urlVal["url"]) + } else if urlStr, ok := itemMap["image_url"].(string); ok { + imageInput = urlStr + } + } + case "video_url": + if videoInput == "" { + if urlVal, ok := itemMap["video_url"].(map[string]any); ok { + videoInput = fmt.Sprintf("%v", urlVal["url"]) + } else if urlStr, ok := itemMap["video_url"].(string); ok { + videoInput = urlStr + } + } + } + } + return builder.String(), imageInput, videoInput + default: + return "", "", "" + } +} + +func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) { + raw := strings.TrimSpace(input) + if raw == "" { + return nil, "", errors.New("empty image input") + } + if strings.HasPrefix(raw, "data:") { + parts := strings.SplitN(raw, ",", 2) + if len(parts) != 2 { + return nil, "", errors.New("invalid data url") + } + meta := parts[0] + payload := parts[1] + decoded, err := base64.StdEncoding.DecodeString(payload) + if err != nil { + return nil, "", err + } + ext := "" + if strings.HasPrefix(meta, "data:") { + metaParts := strings.SplitN(meta[5:], ";", 2) + if len(metaParts) > 0 { + if exts, err := mime.ExtensionsByType(metaParts[0]); err == nil && len(exts) > 0 { + ext = exts[0] + } + } + } + filename := "image" + ext + return decoded, filename, nil + } + if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { + return downloadSoraImageInput(ctx, raw) + } + decoded, err := base64.StdEncoding.DecodeString(raw) + if err != nil { + return nil, "", errors.New("invalid base64 image") + } + return decoded, "image.png", nil +} + +func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) + if err != nil { + return nil, "", err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, "", err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode) + } + data, err := io.ReadAll(io.LimitReader(resp.Body, 20<<20)) + if err != nil { + return nil, "", err + } + ext := fileExtFromURL(rawURL) + if ext == "" { + ext = fileExtFromContentType(resp.Header.Get("Content-Type")) + } + filename := "image" + ext + return data, filename, nil +} diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go new file mode 100644 index 00000000..e4de8256 --- /dev/null +++ b/backend/internal/service/sora_gateway_service_test.go @@ -0,0 +1,99 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type stubSoraClientForPoll struct { + imageStatus *SoraImageTaskStatus + videoStatus *SoraVideoTaskStatus + imageCalls int + videoCalls int +} + +func (s *stubSoraClientForPoll) Enabled() bool { return true } +func (s *stubSoraClientForPoll) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) { + return "", nil +} +func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) { + return "task-image", nil +} +func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) { + return "task-video", nil +} +func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { + s.imageCalls++ + return s.imageStatus, nil +} +func (s *stubSoraClientForPoll) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) { + s.videoCalls++ + return s.videoStatus, nil +} + +func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) { + client := &stubSoraClientForPoll{ + imageStatus: &SoraImageTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/a.png"}, + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + service := NewSoraGatewayService(client, nil, nil, cfg) + + urls, err := service.pollImageTask(context.Background(), nil, &Account{ID: 1}, "task", false) + require.NoError(t, err) + require.Equal(t, []string{"https://example.com/a.png"}, urls) + require.Equal(t, 1, client.imageCalls) +} + +func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "failed", + ErrorMsg: "reject", + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + service := NewSoraGatewayService(client, nil, nil, cfg) + + urls, err := service.pollVideoTask(context.Background(), nil, &Account{ID: 1}, "task", false) + require.Error(t, err) + require.Empty(t, urls) + require.Contains(t, err.Error(), "reject") + require.Equal(t, 1, client.videoCalls) +} + +func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) { + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + SoraMediaSigningKey: "test-key", + SoraMediaSignedURLTTLSeconds: 600, + }, + } + service := NewSoraGatewayService(nil, nil, nil, cfg) + + url := service.buildSoraMediaURL("/image/2025/01/01/a.png", "") + require.Contains(t, url, "/sora/media-signed") + require.Contains(t, url, "expires=") + require.Contains(t, url, "sig=") +} diff --git a/backend/internal/service/sora_media_cleanup_service.go b/backend/internal/service/sora_media_cleanup_service.go new file mode 100644 index 00000000..7de0f1c4 --- /dev/null +++ b/backend/internal/service/sora_media_cleanup_service.go @@ -0,0 +1,117 @@ +package service + +import ( + "log" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/robfig/cron/v3" +) + +var soraCleanupCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow) + +// SoraMediaCleanupService 定期清理本地媒体文件 +type SoraMediaCleanupService struct { + storage *SoraMediaStorage + cfg *config.Config + + cron *cron.Cron + + startOnce sync.Once + stopOnce sync.Once +} + +func NewSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService { + return &SoraMediaCleanupService{ + storage: storage, + cfg: cfg, + } +} + +func (s *SoraMediaCleanupService) Start() { + if s == nil || s.cfg == nil { + return + } + if !s.cfg.Sora.Storage.Cleanup.Enabled { + log.Printf("[SoraCleanup] not started (disabled)") + return + } + if s.storage == nil || !s.storage.Enabled() { + log.Printf("[SoraCleanup] not started (storage disabled)") + return + } + + s.startOnce.Do(func() { + schedule := strings.TrimSpace(s.cfg.Sora.Storage.Cleanup.Schedule) + if schedule == "" { + log.Printf("[SoraCleanup] not started (empty schedule)") + return + } + loc := time.Local + if strings.TrimSpace(s.cfg.Timezone) != "" { + if parsed, err := time.LoadLocation(strings.TrimSpace(s.cfg.Timezone)); err == nil && parsed != nil { + loc = parsed + } + } + c := cron.New(cron.WithParser(soraCleanupCronParser), cron.WithLocation(loc)) + if _, err := c.AddFunc(schedule, func() { s.runCleanup() }); err != nil { + log.Printf("[SoraCleanup] not started (invalid schedule=%q): %v", schedule, err) + return + } + s.cron = c + s.cron.Start() + log.Printf("[SoraCleanup] started (schedule=%q tz=%s)", schedule, loc.String()) + }) +} + +func (s *SoraMediaCleanupService) Stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + if s.cron != nil { + ctx := s.cron.Stop() + select { + case <-ctx.Done(): + case <-time.After(3 * time.Second): + log.Printf("[SoraCleanup] cron stop timed out") + } + } + }) +} + +func (s *SoraMediaCleanupService) runCleanup() { + retention := s.cfg.Sora.Storage.Cleanup.RetentionDays + if retention <= 0 { + log.Printf("[SoraCleanup] skipped (retention_days=%d)", retention) + return + } + cutoff := time.Now().AddDate(0, 0, -retention) + deleted := 0 + + roots := []string{s.storage.ImageRoot(), s.storage.VideoRoot()} + for _, root := range roots { + if root == "" { + continue + } + _ = filepath.Walk(root, func(p string, info os.FileInfo, err error) error { + if err != nil { + return nil + } + if info.IsDir() { + return nil + } + if info.ModTime().Before(cutoff) { + if rmErr := os.Remove(p); rmErr == nil { + deleted++ + } + } + return nil + }) + } + log.Printf("[SoraCleanup] cleanup finished, deleted=%d", deleted) +} diff --git a/backend/internal/service/sora_media_cleanup_service_test.go b/backend/internal/service/sora_media_cleanup_service_test.go new file mode 100644 index 00000000..63204104 --- /dev/null +++ b/backend/internal/service/sora_media_cleanup_service_test.go @@ -0,0 +1,46 @@ +//go:build unit + +package service + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestSoraMediaCleanupService_RunCleanup(t *testing.T) { + tmpDir := t.TempDir() + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + Cleanup: config.SoraStorageCleanupConfig{ + Enabled: true, + RetentionDays: 1, + }, + }, + }, + } + + storage := NewSoraMediaStorage(cfg) + require.NoError(t, storage.EnsureLocalDirs()) + + oldImage := filepath.Join(storage.ImageRoot(), "old.png") + newVideo := filepath.Join(storage.VideoRoot(), "new.mp4") + require.NoError(t, os.WriteFile(oldImage, []byte("old"), 0o644)) + require.NoError(t, os.WriteFile(newVideo, []byte("new"), 0o644)) + + oldTime := time.Now().Add(-48 * time.Hour) + require.NoError(t, os.Chtimes(oldImage, oldTime, oldTime)) + + cleanup := NewSoraMediaCleanupService(storage, cfg) + cleanup.runCleanup() + + require.NoFileExists(t, oldImage) + require.FileExists(t, newVideo) +} diff --git a/backend/internal/service/sora_media_storage.go b/backend/internal/service/sora_media_storage.go new file mode 100644 index 00000000..53214bb7 --- /dev/null +++ b/backend/internal/service/sora_media_storage.go @@ -0,0 +1,256 @@ +package service + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "mime" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/google/uuid" +) + +const ( + soraStorageDefaultRoot = "/app/data/sora" +) + +// SoraMediaStorage 负责下载并落地 Sora 媒体 +type SoraMediaStorage struct { + cfg *config.Config + root string + imageRoot string + videoRoot string + maxConcurrent int + fallbackToUpstream bool + debug bool + sem chan struct{} + ready bool +} + +func NewSoraMediaStorage(cfg *config.Config) *SoraMediaStorage { + storage := &SoraMediaStorage{cfg: cfg} + storage.refreshConfig() + if storage.Enabled() { + if err := storage.EnsureLocalDirs(); err != nil { + log.Printf("[SoraStorage] 初始化失败: %v", err) + } + } + return storage +} + +func (s *SoraMediaStorage) Enabled() bool { + if s == nil || s.cfg == nil { + return false + } + return strings.ToLower(strings.TrimSpace(s.cfg.Sora.Storage.Type)) == "local" +} + +func (s *SoraMediaStorage) Root() string { + if s == nil { + return "" + } + return s.root +} + +func (s *SoraMediaStorage) ImageRoot() string { + if s == nil { + return "" + } + return s.imageRoot +} + +func (s *SoraMediaStorage) VideoRoot() string { + if s == nil { + return "" + } + return s.videoRoot +} + +func (s *SoraMediaStorage) refreshConfig() { + if s == nil || s.cfg == nil { + return + } + root := strings.TrimSpace(s.cfg.Sora.Storage.LocalPath) + if root == "" { + root = soraStorageDefaultRoot + } + s.root = root + s.imageRoot = filepath.Join(root, "image") + s.videoRoot = filepath.Join(root, "video") + + maxConcurrent := s.cfg.Sora.Storage.MaxConcurrentDownloads + if maxConcurrent <= 0 { + maxConcurrent = 4 + } + s.maxConcurrent = maxConcurrent + s.fallbackToUpstream = s.cfg.Sora.Storage.FallbackToUpstream + s.debug = s.cfg.Sora.Storage.Debug + s.sem = make(chan struct{}, maxConcurrent) +} + +// EnsureLocalDirs 创建并校验本地目录 +func (s *SoraMediaStorage) EnsureLocalDirs() error { + if s == nil || !s.Enabled() { + return nil + } + if err := os.MkdirAll(s.imageRoot, 0o755); err != nil { + return fmt.Errorf("create image dir: %w", err) + } + if err := os.MkdirAll(s.videoRoot, 0o755); err != nil { + return fmt.Errorf("create video dir: %w", err) + } + s.ready = true + return nil +} + +// StoreFromURLs 下载并存储媒体,返回相对路径或回退 URL +func (s *SoraMediaStorage) StoreFromURLs(ctx context.Context, mediaType string, urls []string) ([]string, error) { + if len(urls) == 0 { + return nil, nil + } + if s == nil || !s.Enabled() { + return urls, nil + } + if !s.ready { + if err := s.EnsureLocalDirs(); err != nil { + return nil, err + } + } + results := make([]string, 0, len(urls)) + for _, raw := range urls { + relative, err := s.downloadAndStore(ctx, mediaType, raw) + if err != nil { + if s.fallbackToUpstream { + results = append(results, raw) + continue + } + return nil, err + } + results = append(results, relative) + } + return results, nil +} + +func (s *SoraMediaStorage) downloadAndStore(ctx context.Context, mediaType, rawURL string) (string, error) { + if strings.TrimSpace(rawURL) == "" { + return "", errors.New("empty url") + } + root := s.imageRoot + if mediaType == "video" { + root = s.videoRoot + } + if root == "" { + return "", errors.New("storage root not configured") + } + + retries := 3 + for attempt := 1; attempt <= retries; attempt++ { + release, err := s.acquire(ctx) + if err != nil { + return "", err + } + relative, err := s.downloadOnce(ctx, root, mediaType, rawURL) + release() + if err == nil { + return relative, nil + } + if s.debug { + log.Printf("[SoraStorage] 下载失败(%d/%d): %s err=%v", attempt, retries, sanitizeSoraLogURL(rawURL), err) + } + if attempt < retries { + time.Sleep(time.Duration(attempt*attempt) * time.Second) + continue + } + return "", err + } + return "", errors.New("download retries exhausted") +} + +func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, rawURL string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) + if err != nil { + return "", err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + return "", fmt.Errorf("download failed: %d %s", resp.StatusCode, string(body)) + } + + ext := fileExtFromURL(rawURL) + if ext == "" { + ext = fileExtFromContentType(resp.Header.Get("Content-Type")) + } + if ext == "" { + ext = ".bin" + } + + datePath := time.Now().Format("2006/01/02") + destDir := filepath.Join(root, filepath.FromSlash(datePath)) + if err := os.MkdirAll(destDir, 0o755); err != nil { + return "", err + } + filename := uuid.NewString() + ext + destPath := filepath.Join(destDir, filename) + out, err := os.Create(destPath) + if err != nil { + return "", err + } + defer func() { _ = out.Close() }() + + if _, err := io.Copy(out, resp.Body); err != nil { + _ = os.Remove(destPath) + return "", err + } + + relative := path.Join("/", mediaType, datePath, filename) + if s.debug { + log.Printf("[SoraStorage] 已落地 %s -> %s", sanitizeSoraLogURL(rawURL), relative) + } + return relative, nil +} + +func (s *SoraMediaStorage) acquire(ctx context.Context) (func(), error) { + if s.sem == nil { + return func() {}, nil + } + select { + case s.sem <- struct{}{}: + return func() { <-s.sem }, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func fileExtFromURL(raw string) string { + parsed, err := url.Parse(raw) + if err != nil { + return "" + } + ext := path.Ext(parsed.Path) + return strings.ToLower(ext) +} + +func fileExtFromContentType(ct string) string { + if ct == "" { + return "" + } + if exts, err := mime.ExtensionsByType(ct); err == nil && len(exts) > 0 { + return strings.ToLower(exts[0]) + } + return "" +} diff --git a/backend/internal/service/sora_media_storage_test.go b/backend/internal/service/sora_media_storage_test.go new file mode 100644 index 00000000..f86234d2 --- /dev/null +++ b/backend/internal/service/sora_media_storage_test.go @@ -0,0 +1,69 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestSoraMediaStorage_StoreFromURLs(t *testing.T) { + tmpDir := t.TempDir() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "image/png") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("data")) + })) + defer server.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + MaxConcurrentDownloads: 1, + }, + }, + } + + storage := NewSoraMediaStorage(cfg) + urls, err := storage.StoreFromURLs(context.Background(), "image", []string{server.URL + "/img.png"}) + require.NoError(t, err) + require.Len(t, urls, 1) + require.True(t, strings.HasPrefix(urls[0], "/image/")) + require.True(t, strings.HasSuffix(urls[0], ".png")) + + localPath := filepath.Join(tmpDir, filepath.FromSlash(strings.TrimPrefix(urls[0], "/"))) + require.FileExists(t, localPath) +} + +func TestSoraMediaStorage_FallbackToUpstream(t *testing.T) { + tmpDir := t.TempDir() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + FallbackToUpstream: true, + }, + }, + } + + storage := NewSoraMediaStorage(cfg) + url := server.URL + "/broken.png" + urls, err := storage.StoreFromURLs(context.Background(), "image", []string{url}) + require.NoError(t, err) + require.Equal(t, []string{url}, urls) +} diff --git a/backend/internal/service/sora_models.go b/backend/internal/service/sora_models.go new file mode 100644 index 00000000..ab095e46 --- /dev/null +++ b/backend/internal/service/sora_models.go @@ -0,0 +1,252 @@ +package service + +import ( + "strings" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" +) + +// SoraModelConfig Sora 模型配置 +type SoraModelConfig struct { + Type string + Width int + Height int + Orientation string + Frames int + Model string + Size string + RequirePro bool +} + +var soraModelConfigs = map[string]SoraModelConfig{ + "gpt-image": { + Type: "image", + Width: 360, + Height: 360, + }, + "gpt-image-landscape": { + Type: "image", + Width: 540, + Height: 360, + }, + "gpt-image-portrait": { + Type: "image", + Width: 360, + Height: 540, + }, + "sora2-landscape-10s": { + Type: "video", + Orientation: "landscape", + Frames: 300, + Model: "sy_8", + Size: "small", + }, + "sora2-portrait-10s": { + Type: "video", + Orientation: "portrait", + Frames: 300, + Model: "sy_8", + Size: "small", + }, + "sora2-landscape-15s": { + Type: "video", + Orientation: "landscape", + Frames: 450, + Model: "sy_8", + Size: "small", + }, + "sora2-portrait-15s": { + Type: "video", + Orientation: "portrait", + Frames: 450, + Model: "sy_8", + Size: "small", + }, + "sora2-landscape-25s": { + Type: "video", + Orientation: "landscape", + Frames: 750, + Model: "sy_8", + Size: "small", + RequirePro: true, + }, + "sora2-portrait-25s": { + Type: "video", + Orientation: "portrait", + Frames: 750, + Model: "sy_8", + Size: "small", + RequirePro: true, + }, + "sora2pro-landscape-10s": { + Type: "video", + Orientation: "landscape", + Frames: 300, + Model: "sy_ore", + Size: "small", + RequirePro: true, + }, + "sora2pro-portrait-10s": { + Type: "video", + Orientation: "portrait", + Frames: 300, + Model: "sy_ore", + Size: "small", + RequirePro: true, + }, + "sora2pro-landscape-15s": { + Type: "video", + Orientation: "landscape", + Frames: 450, + Model: "sy_ore", + Size: "small", + RequirePro: true, + }, + "sora2pro-portrait-15s": { + Type: "video", + Orientation: "portrait", + Frames: 450, + Model: "sy_ore", + Size: "small", + RequirePro: true, + }, + "sora2pro-landscape-25s": { + Type: "video", + Orientation: "landscape", + Frames: 750, + Model: "sy_ore", + Size: "small", + RequirePro: true, + }, + "sora2pro-portrait-25s": { + Type: "video", + Orientation: "portrait", + Frames: 750, + Model: "sy_ore", + Size: "small", + RequirePro: true, + }, + "sora2pro-hd-landscape-10s": { + Type: "video", + Orientation: "landscape", + Frames: 300, + Model: "sy_ore", + Size: "large", + RequirePro: true, + }, + "sora2pro-hd-portrait-10s": { + Type: "video", + Orientation: "portrait", + Frames: 300, + Model: "sy_ore", + Size: "large", + RequirePro: true, + }, + "sora2pro-hd-landscape-15s": { + Type: "video", + Orientation: "landscape", + Frames: 450, + Model: "sy_ore", + Size: "large", + RequirePro: true, + }, + "sora2pro-hd-portrait-15s": { + Type: "video", + Orientation: "portrait", + Frames: 450, + Model: "sy_ore", + Size: "large", + RequirePro: true, + }, + "prompt-enhance-short-10s": { + Type: "prompt_enhance", + }, + "prompt-enhance-short-15s": { + Type: "prompt_enhance", + }, + "prompt-enhance-short-20s": { + Type: "prompt_enhance", + }, + "prompt-enhance-medium-10s": { + Type: "prompt_enhance", + }, + "prompt-enhance-medium-15s": { + Type: "prompt_enhance", + }, + "prompt-enhance-medium-20s": { + Type: "prompt_enhance", + }, + "prompt-enhance-long-10s": { + Type: "prompt_enhance", + }, + "prompt-enhance-long-15s": { + Type: "prompt_enhance", + }, + "prompt-enhance-long-20s": { + Type: "prompt_enhance", + }, +} + +var soraModelIDs = []string{ + "gpt-image", + "gpt-image-landscape", + "gpt-image-portrait", + "sora2-landscape-10s", + "sora2-portrait-10s", + "sora2-landscape-15s", + "sora2-portrait-15s", + "sora2-landscape-25s", + "sora2-portrait-25s", + "sora2pro-landscape-10s", + "sora2pro-portrait-10s", + "sora2pro-landscape-15s", + "sora2pro-portrait-15s", + "sora2pro-landscape-25s", + "sora2pro-portrait-25s", + "sora2pro-hd-landscape-10s", + "sora2pro-hd-portrait-10s", + "sora2pro-hd-landscape-15s", + "sora2pro-hd-portrait-15s", + "prompt-enhance-short-10s", + "prompt-enhance-short-15s", + "prompt-enhance-short-20s", + "prompt-enhance-medium-10s", + "prompt-enhance-medium-15s", + "prompt-enhance-medium-20s", + "prompt-enhance-long-10s", + "prompt-enhance-long-15s", + "prompt-enhance-long-20s", +} + +// GetSoraModelConfig 返回 Sora 模型配置 +func GetSoraModelConfig(model string) (SoraModelConfig, bool) { + key := strings.ToLower(strings.TrimSpace(model)) + cfg, ok := soraModelConfigs[key] + return cfg, ok +} + +// DefaultSoraModels returns the default Sora model list. +func DefaultSoraModels(cfg *config.Config) []openai.Model { + models := make([]openai.Model, 0, len(soraModelIDs)) + for _, id := range soraModelIDs { + models = append(models, openai.Model{ + ID: id, + Object: "model", + OwnedBy: "openai", + Type: "model", + DisplayName: id, + }) + } + if cfg != nil && cfg.Gateway.SoraModelFilters.HidePromptEnhance { + filtered := models[:0] + for _, model := range models { + if strings.HasPrefix(strings.ToLower(model.ID), "prompt-enhance") { + continue + } + filtered = append(filtered, model) + } + models = filtered + } + return models +} diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 435056ab..7dccf393 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -63,16 +63,6 @@ func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) { } } -// SetSoraSyncService 设置 Sora2API 同步服务 -// 需要在 Start() 之前调用 -func (s *TokenRefreshService) SetSoraSyncService(svc *Sora2APISyncService) { - for _, refresher := range s.refreshers { - if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok { - openaiRefresher.SetSoraSyncService(svc) - } - } -} - // Start 启动后台刷新服务 func (s *TokenRefreshService) Start() { if !s.cfg.Enabled { diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 7e084bd5..46033f75 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -86,7 +86,6 @@ type OpenAITokenRefresher struct { openaiOAuthService *OpenAIOAuthService accountRepo AccountRepository soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 - soraSyncService *Sora2APISyncService // Sora2API 同步服务 } // NewOpenAITokenRefresher 创建 OpenAI token刷新器 @@ -104,11 +103,6 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) { r.soraAccountRepo = repo } -// SetSoraSyncService 设置 Sora2API 同步服务 -func (r *OpenAITokenRefresher) SetSoraSyncService(svc *Sora2APISyncService) { - r.soraSyncService = svc -} - // CanRefresh 检查是否能处理此账号 // 只处理 openai 平台的 oauth 类型账号 func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool { @@ -151,17 +145,6 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials) } - // 如果是 Sora 平台账号,同步到 sora2api(不阻塞主流程) - if account.Platform == PlatformSora && r.soraSyncService != nil { - syncAccount := *account - syncAccount.Credentials = newCredentials - go func() { - if err := r.soraSyncService.SyncAccount(context.Background(), &syncAccount); err != nil { - log.Printf("[TokenSync] 同步 Sora2API 失败: account_id=%d err=%v", syncAccount.ID, err) - } - }() - } - return newCredentials, nil } @@ -218,13 +201,6 @@ func (r *OpenAITokenRefresher) syncLinkedSoraAccounts(ctx context.Context, opena } } - // 2.3 同步到 sora2api(如果配置) - if r.soraSyncService != nil { - if err := r.soraSyncService.SyncAccount(ctx, &soraAccount); err != nil { - log.Printf("[TokenSync] 同步 sora2api 失败: account_id=%d err=%v", soraAccount.ID, err) - } - } - log.Printf("[TokenSync] 成功同步 Sora 账号 token: sora_account_id=%d openai_account_id=%d dual_table=%v", soraAccount.ID, openaiAccountID, r.soraAccountRepo != nil) } diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index fb0946d2..9c13be93 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -40,7 +40,6 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService { func ProvideTokenRefreshService( accountRepo AccountRepository, soraAccountRepo SoraAccountRepository, // Sora 扩展表仓储,用于双表同步 - soraSyncService *Sora2APISyncService, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, geminiOAuthService *GeminiOAuthService, @@ -51,7 +50,6 @@ func ProvideTokenRefreshService( svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg) // 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表 svc.SetSoraAccountRepo(soraAccountRepo) - svc.SetSoraSyncService(soraSyncService) svc.Start() return svc } @@ -187,6 +185,18 @@ func ProvideOpsCleanupService( return svc } +// ProvideSoraMediaStorage 初始化 Sora 媒体存储 +func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage { + return NewSoraMediaStorage(cfg) +} + +// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务 +func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService { + svc := NewSoraMediaCleanupService(storage, cfg) + svc.Start() + return svc +} + // ProvideOpsScheduledReportService creates and starts OpsScheduledReportService. func ProvideOpsScheduledReportService( opsService *OpsService, @@ -226,6 +236,10 @@ var ProviderSet = wire.NewSet( NewBillingCacheService, NewAdminService, NewGatewayService, + ProvideSoraMediaStorage, + ProvideSoraMediaCleanupService, + NewSoraDirectClient, + wire.Bind(new(SoraClient), new(*SoraDirectClient)), NewSoraGatewayService, NewOpenAIGatewayService, NewOAuthService, diff --git a/build_image.sh b/build_image.sh new file mode 100755 index 00000000..2cea4925 --- /dev/null +++ b/build_image.sh @@ -0,0 +1,8 @@ +#!/bin/bash +# 本地构建镜像的快速脚本,避免在命令行反复输入构建参数。 + +docker build -t sub2api:latest \ + --build-arg GOPROXY=https://goproxy.cn,direct \ + --build-arg GOSUMDB=sum.golang.google.cn \ + -f Dockerfile \ + . diff --git a/deploy/Dockerfile b/deploy/Dockerfile new file mode 100644 index 00000000..b3320300 --- /dev/null +++ b/deploy/Dockerfile @@ -0,0 +1,111 @@ +# ============================================================================= +# Sub2API Multi-Stage Dockerfile +# ============================================================================= +# Stage 1: Build frontend +# Stage 2: Build Go backend with embedded frontend +# Stage 3: Final minimal image +# ============================================================================= + +ARG NODE_IMAGE=node:24-alpine +ARG GOLANG_IMAGE=golang:1.25.5-alpine +ARG ALPINE_IMAGE=alpine:3.20 +ARG GOPROXY=https://goproxy.cn,direct +ARG GOSUMDB=sum.golang.google.cn + +# ----------------------------------------------------------------------------- +# Stage 1: Frontend Builder +# ----------------------------------------------------------------------------- +FROM ${NODE_IMAGE} AS frontend-builder + +WORKDIR /app/frontend + +# Install pnpm +RUN corepack enable && corepack prepare pnpm@latest --activate + +# Install dependencies first (better caching) +COPY frontend/package.json frontend/pnpm-lock.yaml ./ +RUN pnpm install --frozen-lockfile + +# Copy frontend source and build +COPY frontend/ ./ +RUN pnpm run build + +# ----------------------------------------------------------------------------- +# Stage 2: Backend Builder +# ----------------------------------------------------------------------------- +FROM ${GOLANG_IMAGE} AS backend-builder + +# Build arguments for version info (set by CI) +ARG VERSION=docker +ARG COMMIT=docker +ARG DATE +ARG GOPROXY +ARG GOSUMDB + +ENV GOPROXY=${GOPROXY} +ENV GOSUMDB=${GOSUMDB} + +# Install build dependencies +RUN apk add --no-cache git ca-certificates tzdata + +WORKDIR /app/backend + +# Copy go mod files first (better caching) +COPY backend/go.mod backend/go.sum ./ +RUN go mod download + +# Copy backend source first +COPY backend/ ./ + +# Copy frontend dist from previous stage (must be after backend copy to avoid being overwritten) +COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist + +# Build the binary (BuildType=release for CI builds, embed frontend) +RUN CGO_ENABLED=0 GOOS=linux go build \ + -tags embed \ + -ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \ + -o /app/sub2api \ + ./cmd/server + +# ----------------------------------------------------------------------------- +# Stage 3: Final Runtime Image +# ----------------------------------------------------------------------------- +FROM ${ALPINE_IMAGE} + +# Labels +LABEL maintainer="Wei-Shaw " +LABEL description="Sub2API - AI API Gateway Platform" +LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api" + +# Install runtime dependencies +RUN apk add --no-cache \ + ca-certificates \ + tzdata \ + curl \ + && rm -rf /var/cache/apk/* + +# Create non-root user +RUN addgroup -g 1000 sub2api && \ + adduser -u 1000 -G sub2api -s /bin/sh -D sub2api + +# Set working directory +WORKDIR /app + +# Copy binary from builder +COPY --from=backend-builder /app/sub2api /app/sub2api + +# Create data directory +RUN mkdir -p /app/data && chown -R sub2api:sub2api /app + +# Switch to non-root user +USER sub2api + +# Expose port (can be overridden by SERVER_PORT env var) +EXPOSE 8080 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \ + CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1 + +# Run the application +ENTRYPOINT ["/app/sub2api"] diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 99386fc9..2c7a1778 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -249,32 +249,64 @@ gateway: # name: "Custom Profile 2" # ============================================================================= -# Sora2API Configuration -# Sora2API 配置 +# Sora Direct Client Configuration +# Sora 直连配置 # ============================================================================= -sora2api: - # Sora2API base URL - # Sora2API 服务地址 - base_url: "http://127.0.0.1:8000" - # Sora2API API Key (for /v1/chat/completions and /v1/models) - # Sora2API API Key(用于生成/模型列表) - api_key: "" - # Admin username/password (for token sync) - # 管理口用户名/密码(用于 token 同步) - admin_username: "admin" - admin_password: "admin" - # Admin token cache ttl (seconds) - # 管理口 token 缓存时长(秒) - admin_token_ttl_seconds: 900 - # Admin request timeout (seconds) - # 管理口请求超时(秒) - admin_timeout_seconds: 10 - # Token import mode: at/offline - # Token 导入模式:at/offline - token_import_mode: "at" - # cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196] - # curves: [29, 23, 24] - # point_formats: [0] +sora: + client: + # Sora backend base URL + # Sora 上游 Base URL + base_url: "https://sora.chatgpt.com/backend" + # Request timeout (seconds) + # 请求超时(秒) + timeout_seconds: 120 + # Max retries for upstream requests + # 上游请求最大重试次数 + max_retries: 3 + # Poll interval (seconds) + # 轮询间隔(秒) + poll_interval_seconds: 2 + # Max poll attempts + # 最大轮询次数 + max_poll_attempts: 600 + # Enable debug logs for Sora upstream requests + # 启用 Sora 直连调试日志 + debug: false + # Optional custom headers (key-value) + # 额外请求头(键值对) + headers: {} + # Default User-Agent for Sora requests + # Sora 默认 User-Agent + user_agent: "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)" + # Disable TLS fingerprint for Sora upstream + # 关闭 Sora 上游 TLS 指纹伪装 + disable_tls_fingerprint: false + storage: + # Storage type (local only for now) + # 存储类型(首发仅支持 local) + type: "local" + # Local base path; empty uses /app/data/sora + # 本地存储基础路径;为空使用 /app/data/sora + local_path: "" + # Fallback to upstream URL when download fails + # 下载失败时回退到上游 URL + fallback_to_upstream: true + # Max concurrent downloads + # 并发下载上限 + max_concurrent_downloads: 4 + # Enable debug logs for media storage + # 启用媒体存储调试日志 + debug: false + cleanup: + # Enable cleanup task + # 启用清理任务 + enabled: true + # Retention days + # 保留天数 + retention_days: 7 + # Cron schedule + # Cron 调度表达式 + schedule: "0 3 * * *" # ============================================================================= # API Key Auth Cache Configuration diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index 505c1419..e86f6348 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -18,7 +18,6 @@ import geminiAPI from './gemini' import antigravityAPI from './antigravity' import userAttributesAPI from './userAttributes' import opsAPI from './ops' -import modelsAPI from './models' /** * Unified admin API object for convenient access @@ -38,8 +37,7 @@ export const adminAPI = { gemini: geminiAPI, antigravity: antigravityAPI, userAttributes: userAttributesAPI, - ops: opsAPI, - models: modelsAPI + ops: opsAPI } export { @@ -57,8 +55,7 @@ export { geminiAPI, antigravityAPI, userAttributesAPI, - opsAPI, - modelsAPI + opsAPI } export default adminAPI diff --git a/frontend/src/api/admin/models.ts b/frontend/src/api/admin/models.ts deleted file mode 100644 index 897304ac..00000000 --- a/frontend/src/api/admin/models.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { apiClient } from '@/api/client' - -export async function getPlatformModels(platform: string): Promise { - const { data } = await apiClient.get('/admin/models', { - params: { platform } - }) - return data -} - -export const modelsAPI = { - getPlatformModels -} - -export default modelsAPI diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 0e81a717..30ec9e63 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1501,9 +1501,9 @@ - diff --git a/frontend/src/components/account/ModelWhitelistSelector.vue b/frontend/src/components/account/ModelWhitelistSelector.vue index 227e6e61..16ffa225 100644 --- a/frontend/src/components/account/ModelWhitelistSelector.vue +++ b/frontend/src/components/account/ModelWhitelistSelector.vue @@ -45,19 +45,6 @@ :placeholder="t('admin.accounts.searchModels')" @click.stop /> -
- - {{ t('admin.accounts.soraModelsLoading') }} - - -