feat(Sora): 直连生成并移除sora2api依赖
实现直连 Sora 客户端、媒体落地与清理策略\n更新网关与前端配置以支持 Sora 平台\n补齐单元测试与契约测试,新增 curl 测试脚本\n\n测试: go test ./... -tags=unit
This commit is contained in:
@@ -67,6 +67,7 @@ func provideCleanup(
|
|||||||
opsAlertEvaluator *service.OpsAlertEvaluatorService,
|
opsAlertEvaluator *service.OpsAlertEvaluatorService,
|
||||||
opsCleanup *service.OpsCleanupService,
|
opsCleanup *service.OpsCleanupService,
|
||||||
opsScheduledReport *service.OpsScheduledReportService,
|
opsScheduledReport *service.OpsScheduledReportService,
|
||||||
|
soraMediaCleanup *service.SoraMediaCleanupService,
|
||||||
schedulerSnapshot *service.SchedulerSnapshotService,
|
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||||
tokenRefresh *service.TokenRefreshService,
|
tokenRefresh *service.TokenRefreshService,
|
||||||
accountExpiry *service.AccountExpiryService,
|
accountExpiry *service.AccountExpiryService,
|
||||||
@@ -100,6 +101,12 @@ func provideCleanup(
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"SoraMediaCleanupService", func() error {
|
||||||
|
if soraMediaCleanup != nil {
|
||||||
|
soraMediaCleanup.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
{"OpsAlertEvaluatorService", func() error {
|
{"OpsAlertEvaluatorService", func() error {
|
||||||
if opsAlertEvaluator != nil {
|
if opsAlertEvaluator != nil {
|
||||||
opsAlertEvaluator.Stop()
|
opsAlertEvaluator.Stop()
|
||||||
|
|||||||
@@ -87,12 +87,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
schedulerCache := repository.NewSchedulerCache(redisClient)
|
schedulerCache := repository.NewSchedulerCache(redisClient)
|
||||||
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
|
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
|
||||||
soraAccountRepository := repository.NewSoraAccountRepository(db)
|
soraAccountRepository := repository.NewSoraAccountRepository(db)
|
||||||
sora2APIService := service.NewSora2APIService(configConfig)
|
|
||||||
sora2APISyncService := service.NewSora2APISyncService(sora2APIService, accountRepository)
|
|
||||||
proxyRepository := repository.NewProxyRepository(client, db)
|
proxyRepository := repository.NewProxyRepository(client, db)
|
||||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, sora2APISyncService, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
|
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
|
||||||
adminUserHandler := admin.NewUserHandler(adminService)
|
adminUserHandler := admin.NewUserHandler(adminService)
|
||||||
groupHandler := admin.NewGroupHandler(adminService)
|
groupHandler := admin.NewGroupHandler(adminService)
|
||||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||||
@@ -164,11 +162,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
|
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
|
||||||
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
|
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
|
||||||
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
|
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
|
||||||
modelHandler := admin.NewModelHandler(sora2APIService)
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
|
||||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, modelHandler)
|
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
|
||||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, sora2APIService, concurrencyService, billingCacheService, configConfig)
|
|
||||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
|
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
|
||||||
soraGatewayService := service.NewSoraGatewayService(sora2APIService, httpUpstream, rateLimitService, configConfig)
|
soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider)
|
||||||
|
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
|
||||||
|
soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)
|
||||||
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig)
|
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig)
|
||||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler)
|
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler)
|
||||||
@@ -182,9 +181,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
||||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
|
||||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, sora2APISyncService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
|
soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
|
||||||
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
|
||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
||||||
application := &Application{
|
application := &Application{
|
||||||
Server: httpServer,
|
Server: httpServer,
|
||||||
Cleanup: v,
|
Cleanup: v,
|
||||||
@@ -214,6 +214,7 @@ func provideCleanup(
|
|||||||
opsAlertEvaluator *service.OpsAlertEvaluatorService,
|
opsAlertEvaluator *service.OpsAlertEvaluatorService,
|
||||||
opsCleanup *service.OpsCleanupService,
|
opsCleanup *service.OpsCleanupService,
|
||||||
opsScheduledReport *service.OpsScheduledReportService,
|
opsScheduledReport *service.OpsScheduledReportService,
|
||||||
|
soraMediaCleanup *service.SoraMediaCleanupService,
|
||||||
schedulerSnapshot *service.SchedulerSnapshotService,
|
schedulerSnapshot *service.SchedulerSnapshotService,
|
||||||
tokenRefresh *service.TokenRefreshService,
|
tokenRefresh *service.TokenRefreshService,
|
||||||
accountExpiry *service.AccountExpiryService,
|
accountExpiry *service.AccountExpiryService,
|
||||||
@@ -246,6 +247,12 @@ func provideCleanup(
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}},
|
}},
|
||||||
|
{"SoraMediaCleanupService", func() error {
|
||||||
|
if soraMediaCleanup != nil {
|
||||||
|
soraMediaCleanup.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
{"OpsAlertEvaluatorService", func() error {
|
{"OpsAlertEvaluatorService", func() error {
|
||||||
if opsAlertEvaluator != nil {
|
if opsAlertEvaluator != nil {
|
||||||
opsAlertEvaluator.Stop()
|
opsAlertEvaluator.Stop()
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ type Config struct {
|
|||||||
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
|
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
|
||||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||||
Sora2API Sora2APIConfig `mapstructure:"sora2api"`
|
Sora SoraConfig `mapstructure:"sora"`
|
||||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||||
@@ -205,22 +205,40 @@ type ConcurrencyConfig struct {
|
|||||||
PingInterval int `mapstructure:"ping_interval"`
|
PingInterval int `mapstructure:"ping_interval"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sora2APIConfig Sora2API 服务配置
|
// SoraConfig 直连 Sora 配置
|
||||||
type Sora2APIConfig struct {
|
type SoraConfig struct {
|
||||||
// BaseURL Sora2API 服务地址(例如 http://localhost:8000)
|
Client SoraClientConfig `mapstructure:"client"`
|
||||||
|
Storage SoraStorageConfig `mapstructure:"storage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraClientConfig 直连 Sora 客户端配置
|
||||||
|
type SoraClientConfig struct {
|
||||||
BaseURL string `mapstructure:"base_url"`
|
BaseURL string `mapstructure:"base_url"`
|
||||||
// APIKey Sora2API OpenAI 兼容接口的 API Key
|
TimeoutSeconds int `mapstructure:"timeout_seconds"`
|
||||||
APIKey string `mapstructure:"api_key"`
|
MaxRetries int `mapstructure:"max_retries"`
|
||||||
// AdminUsername 管理员用户名(用于 token 同步)
|
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
|
||||||
AdminUsername string `mapstructure:"admin_username"`
|
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
|
||||||
// AdminPassword 管理员密码(用于 token 同步)
|
Debug bool `mapstructure:"debug"`
|
||||||
AdminPassword string `mapstructure:"admin_password"`
|
Headers map[string]string `mapstructure:"headers"`
|
||||||
// AdminTokenTTLSeconds 管理员 Token 缓存时长(秒)
|
UserAgent string `mapstructure:"user_agent"`
|
||||||
AdminTokenTTLSeconds int `mapstructure:"admin_token_ttl_seconds"`
|
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
|
||||||
// AdminTimeoutSeconds 管理接口请求超时(秒)
|
}
|
||||||
AdminTimeoutSeconds int `mapstructure:"admin_timeout_seconds"`
|
|
||||||
// TokenImportMode token 导入模式:at/offline
|
// SoraStorageConfig 媒体存储配置
|
||||||
TokenImportMode string `mapstructure:"token_import_mode"`
|
type SoraStorageConfig struct {
|
||||||
|
Type string `mapstructure:"type"`
|
||||||
|
LocalPath string `mapstructure:"local_path"`
|
||||||
|
FallbackToUpstream bool `mapstructure:"fallback_to_upstream"`
|
||||||
|
MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"`
|
||||||
|
Debug bool `mapstructure:"debug"`
|
||||||
|
Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraStorageCleanupConfig 媒体清理配置
|
||||||
|
type SoraStorageCleanupConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
Schedule string `mapstructure:"schedule"`
|
||||||
|
RetentionDays int `mapstructure:"retention_days"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GatewayConfig API网关相关配置
|
// GatewayConfig API网关相关配置
|
||||||
@@ -905,6 +923,26 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
|
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
|
||||||
viper.SetDefault("concurrency.ping_interval", 10)
|
viper.SetDefault("concurrency.ping_interval", 10)
|
||||||
|
|
||||||
|
// Sora 直连配置
|
||||||
|
viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend")
|
||||||
|
viper.SetDefault("sora.client.timeout_seconds", 120)
|
||||||
|
viper.SetDefault("sora.client.max_retries", 3)
|
||||||
|
viper.SetDefault("sora.client.poll_interval_seconds", 2)
|
||||||
|
viper.SetDefault("sora.client.max_poll_attempts", 600)
|
||||||
|
viper.SetDefault("sora.client.debug", false)
|
||||||
|
viper.SetDefault("sora.client.headers", map[string]string{})
|
||||||
|
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||||
|
viper.SetDefault("sora.client.disable_tls_fingerprint", false)
|
||||||
|
|
||||||
|
viper.SetDefault("sora.storage.type", "local")
|
||||||
|
viper.SetDefault("sora.storage.local_path", "")
|
||||||
|
viper.SetDefault("sora.storage.fallback_to_upstream", true)
|
||||||
|
viper.SetDefault("sora.storage.max_concurrent_downloads", 4)
|
||||||
|
viper.SetDefault("sora.storage.debug", false)
|
||||||
|
viper.SetDefault("sora.storage.cleanup.enabled", true)
|
||||||
|
viper.SetDefault("sora.storage.cleanup.retention_days", 7)
|
||||||
|
viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *")
|
||||||
|
|
||||||
// TokenRefresh
|
// TokenRefresh
|
||||||
viper.SetDefault("token_refresh.enabled", true)
|
viper.SetDefault("token_refresh.enabled", true)
|
||||||
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||||
@@ -920,15 +958,6 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gemini.oauth.scopes", "")
|
viper.SetDefault("gemini.oauth.scopes", "")
|
||||||
viper.SetDefault("gemini.quota.policy", "")
|
viper.SetDefault("gemini.quota.policy", "")
|
||||||
|
|
||||||
// Sora2API
|
|
||||||
viper.SetDefault("sora2api.base_url", "")
|
|
||||||
viper.SetDefault("sora2api.api_key", "")
|
|
||||||
viper.SetDefault("sora2api.admin_username", "")
|
|
||||||
viper.SetDefault("sora2api.admin_password", "")
|
|
||||||
viper.SetDefault("sora2api.admin_token_ttl_seconds", 900)
|
|
||||||
viper.SetDefault("sora2api.admin_timeout_seconds", 10)
|
|
||||||
viper.SetDefault("sora2api.token_import_mode", "at")
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) Validate() error {
|
func (c *Config) Validate() error {
|
||||||
@@ -1164,6 +1193,36 @@ func (c *Config) Validate() error {
|
|||||||
return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error")
|
return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if c.Sora.Client.TimeoutSeconds < 0 {
|
||||||
|
return fmt.Errorf("sora.client.timeout_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Sora.Client.MaxRetries < 0 {
|
||||||
|
return fmt.Errorf("sora.client.max_retries must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Sora.Client.PollIntervalSeconds < 0 {
|
||||||
|
return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Sora.Client.MaxPollAttempts < 0 {
|
||||||
|
return fmt.Errorf("sora.client.max_poll_attempts must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Sora.Storage.MaxConcurrentDownloads < 0 {
|
||||||
|
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Sora.Storage.Cleanup.Enabled {
|
||||||
|
if c.Sora.Storage.Cleanup.RetentionDays <= 0 {
|
||||||
|
return fmt.Errorf("sora.storage.cleanup.retention_days must be positive")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" {
|
||||||
|
return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if c.Sora.Storage.Cleanup.RetentionDays < 0 {
|
||||||
|
return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" {
|
||||||
|
return fmt.Errorf("sora.storage.type must be 'local'")
|
||||||
|
}
|
||||||
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
|
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
|
||||||
switch c.Gateway.ConnectionPoolIsolation {
|
switch c.Gateway.ConnectionPoolIsolation {
|
||||||
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
|
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
|
||||||
@@ -1260,11 +1319,6 @@ func (c *Config) Validate() error {
|
|||||||
c.Gateway.Scheduling.OutboxLagRebuildSeconds < c.Gateway.Scheduling.OutboxLagWarnSeconds {
|
c.Gateway.Scheduling.OutboxLagRebuildSeconds < c.Gateway.Scheduling.OutboxLagWarnSeconds {
|
||||||
return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be >= outbox_lag_warn_seconds")
|
return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be >= outbox_lag_warn_seconds")
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(c.Sora2API.BaseURL) != "" {
|
|
||||||
if err := ValidateAbsoluteHTTPURL(c.Sora2API.BaseURL); err != nil {
|
|
||||||
return fmt.Errorf("sora2api.base_url invalid: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if c.Ops.MetricsCollectorCache.TTL < 0 {
|
if c.Ops.MetricsCollectorCache.TTL < 0 {
|
||||||
return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative")
|
return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,55 +0,0 @@
|
|||||||
package admin
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ModelHandler handles admin model listing requests.
|
|
||||||
type ModelHandler struct {
|
|
||||||
sora2apiService *service.Sora2APIService
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewModelHandler creates a new ModelHandler.
|
|
||||||
func NewModelHandler(sora2apiService *service.Sora2APIService) *ModelHandler {
|
|
||||||
return &ModelHandler{
|
|
||||||
sora2apiService: sora2apiService,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// List handles listing models for a specific platform
|
|
||||||
// GET /api/v1/admin/models?platform=sora
|
|
||||||
func (h *ModelHandler) List(c *gin.Context) {
|
|
||||||
platform := strings.TrimSpace(strings.ToLower(c.Query("platform")))
|
|
||||||
if platform == "" {
|
|
||||||
response.BadRequest(c, "platform is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch platform {
|
|
||||||
case service.PlatformSora:
|
|
||||||
if h.sora2apiService == nil || !h.sora2apiService.Enabled() {
|
|
||||||
response.Error(c, http.StatusServiceUnavailable, "sora2api not configured")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
models, err := h.sora2apiService.ListModels(c.Request.Context())
|
|
||||||
if err != nil {
|
|
||||||
response.Error(c, http.StatusServiceUnavailable, "failed to fetch sora models")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ids := make([]string, 0, len(models))
|
|
||||||
for _, m := range models {
|
|
||||||
if strings.TrimSpace(m.ID) != "" {
|
|
||||||
ids = append(ids, m.ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
response.Success(c, ids)
|
|
||||||
default:
|
|
||||||
response.BadRequest(c, "unsupported platform")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,87 +0,0 @@
|
|||||||
package admin
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestModelHandlerListSoraSuccess(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_, _ = w.Write([]byte(`{"object":"list","data":[{"id":"m1"},{"id":"m2"}]}`))
|
|
||||||
}))
|
|
||||||
t.Cleanup(upstream.Close)
|
|
||||||
|
|
||||||
cfg := &config.Config{}
|
|
||||||
cfg.Sora2API.BaseURL = upstream.URL
|
|
||||||
cfg.Sora2API.APIKey = "test-key"
|
|
||||||
soraService := service.NewSora2APIService(cfg)
|
|
||||||
|
|
||||||
h := NewModelHandler(soraService)
|
|
||||||
router := gin.New()
|
|
||||||
router.GET("/admin/models", h.List)
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=sora", nil)
|
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
if recorder.Code != http.StatusOK {
|
|
||||||
t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String())
|
|
||||||
}
|
|
||||||
var resp response.Response
|
|
||||||
if err := json.Unmarshal(recorder.Body.Bytes(), &resp); err != nil {
|
|
||||||
t.Fatalf("解析响应失败: %v", err)
|
|
||||||
}
|
|
||||||
if resp.Code != 0 {
|
|
||||||
t.Fatalf("响应 code=%d", resp.Code)
|
|
||||||
}
|
|
||||||
data, ok := resp.Data.([]any)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("响应 data 类型错误")
|
|
||||||
}
|
|
||||||
if len(data) != 2 {
|
|
||||||
t.Fatalf("模型数量不符: %d", len(data))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestModelHandlerListSoraNotConfigured(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
h := NewModelHandler(&service.Sora2APIService{})
|
|
||||||
router := gin.New()
|
|
||||||
router.GET("/admin/models", h.List)
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=sora", nil)
|
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
if recorder.Code != http.StatusServiceUnavailable {
|
|
||||||
t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestModelHandlerListInvalidPlatform(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
h := NewModelHandler(&service.Sora2APIService{})
|
|
||||||
router := gin.New()
|
|
||||||
router.GET("/admin/models", h.List)
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/admin/models?platform=unknown", nil)
|
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
if recorder.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("status=%d body=%s", recorder.Code, recorder.Body.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -29,11 +29,11 @@ type GatewayHandler struct {
|
|||||||
geminiCompatService *service.GeminiMessagesCompatService
|
geminiCompatService *service.GeminiMessagesCompatService
|
||||||
antigravityGatewayService *service.AntigravityGatewayService
|
antigravityGatewayService *service.AntigravityGatewayService
|
||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
sora2apiService *service.Sora2APIService
|
|
||||||
billingCacheService *service.BillingCacheService
|
billingCacheService *service.BillingCacheService
|
||||||
concurrencyHelper *ConcurrencyHelper
|
concurrencyHelper *ConcurrencyHelper
|
||||||
maxAccountSwitches int
|
maxAccountSwitches int
|
||||||
maxAccountSwitchesGemini int
|
maxAccountSwitchesGemini int
|
||||||
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGatewayHandler creates a new GatewayHandler
|
// NewGatewayHandler creates a new GatewayHandler
|
||||||
@@ -42,7 +42,6 @@ func NewGatewayHandler(
|
|||||||
geminiCompatService *service.GeminiMessagesCompatService,
|
geminiCompatService *service.GeminiMessagesCompatService,
|
||||||
antigravityGatewayService *service.AntigravityGatewayService,
|
antigravityGatewayService *service.AntigravityGatewayService,
|
||||||
userService *service.UserService,
|
userService *service.UserService,
|
||||||
sora2apiService *service.Sora2APIService,
|
|
||||||
concurrencyService *service.ConcurrencyService,
|
concurrencyService *service.ConcurrencyService,
|
||||||
billingCacheService *service.BillingCacheService,
|
billingCacheService *service.BillingCacheService,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
@@ -64,11 +63,11 @@ func NewGatewayHandler(
|
|||||||
geminiCompatService: geminiCompatService,
|
geminiCompatService: geminiCompatService,
|
||||||
antigravityGatewayService: antigravityGatewayService,
|
antigravityGatewayService: antigravityGatewayService,
|
||||||
userService: userService,
|
userService: userService,
|
||||||
sora2apiService: sora2apiService,
|
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||||
maxAccountSwitches: maxAccountSwitches,
|
maxAccountSwitches: maxAccountSwitches,
|
||||||
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
||||||
|
cfg: cfg,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -486,18 +485,9 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if platform == service.PlatformSora {
|
if platform == service.PlatformSora {
|
||||||
if h.sora2apiService == nil || !h.sora2apiService.Enabled() {
|
|
||||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "sora2api not configured")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
models, err := h.sora2apiService.ListModels(c.Request.Context())
|
|
||||||
if err != nil {
|
|
||||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Failed to fetch Sora models")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"object": "list",
|
"object": "list",
|
||||||
"data": models,
|
"data": service.DefaultSoraModels(h.cfg),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ type AdminHandlers struct {
|
|||||||
Subscription *admin.SubscriptionHandler
|
Subscription *admin.SubscriptionHandler
|
||||||
Usage *admin.UsageHandler
|
Usage *admin.UsageHandler
|
||||||
UserAttribute *admin.UserAttributeHandler
|
UserAttribute *admin.UserAttributeHandler
|
||||||
Model *admin.ModelHandler
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handlers contains all HTTP handlers
|
// Handlers contains all HTTP handlers
|
||||||
|
|||||||
@@ -10,7 +10,9 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -31,9 +33,8 @@ type SoraGatewayHandler struct {
|
|||||||
concurrencyHelper *ConcurrencyHelper
|
concurrencyHelper *ConcurrencyHelper
|
||||||
maxAccountSwitches int
|
maxAccountSwitches int
|
||||||
streamMode string
|
streamMode string
|
||||||
sora2apiBaseURL string
|
|
||||||
soraMediaSigningKey string
|
soraMediaSigningKey string
|
||||||
mediaClient *http.Client
|
soraMediaRoot string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSoraGatewayHandler creates a new SoraGatewayHandler
|
// NewSoraGatewayHandler creates a new SoraGatewayHandler
|
||||||
@@ -48,6 +49,7 @@ func NewSoraGatewayHandler(
|
|||||||
maxAccountSwitches := 3
|
maxAccountSwitches := 3
|
||||||
streamMode := "force"
|
streamMode := "force"
|
||||||
signKey := ""
|
signKey := ""
|
||||||
|
mediaRoot := "/app/data/sora"
|
||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
||||||
if cfg.Gateway.MaxAccountSwitches > 0 {
|
if cfg.Gateway.MaxAccountSwitches > 0 {
|
||||||
@@ -57,14 +59,9 @@ func NewSoraGatewayHandler(
|
|||||||
streamMode = mode
|
streamMode = mode
|
||||||
}
|
}
|
||||||
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
|
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
|
||||||
|
if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
|
||||||
|
mediaRoot = root
|
||||||
}
|
}
|
||||||
baseURL := ""
|
|
||||||
if cfg != nil {
|
|
||||||
baseURL = strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/")
|
|
||||||
}
|
|
||||||
mediaTimeout := 180 * time.Second
|
|
||||||
if cfg != nil && cfg.Gateway.SoraRequestTimeoutSeconds > 0 {
|
|
||||||
mediaTimeout = time.Duration(cfg.Gateway.SoraRequestTimeoutSeconds) * time.Second
|
|
||||||
}
|
}
|
||||||
return &SoraGatewayHandler{
|
return &SoraGatewayHandler{
|
||||||
gatewayService: gatewayService,
|
gatewayService: gatewayService,
|
||||||
@@ -73,9 +70,8 @@ func NewSoraGatewayHandler(
|
|||||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||||
maxAccountSwitches: maxAccountSwitches,
|
maxAccountSwitches: maxAccountSwitches,
|
||||||
streamMode: strings.ToLower(streamMode),
|
streamMode: strings.ToLower(streamMode),
|
||||||
sora2apiBaseURL: baseURL,
|
|
||||||
soraMediaSigningKey: signKey,
|
soraMediaSigningKey: signKey,
|
||||||
mediaClient: &http.Client{Timeout: mediaTimeout},
|
soraMediaRoot: mediaRoot,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -377,34 +373,24 @@ func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType,
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// MediaProxy proxies /tmp or /static media files from sora2api
|
// MediaProxy serves local Sora media files.
|
||||||
func (h *SoraGatewayHandler) MediaProxy(c *gin.Context) {
|
func (h *SoraGatewayHandler) MediaProxy(c *gin.Context) {
|
||||||
h.proxySoraMedia(c, false)
|
h.proxySoraMedia(c, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MediaProxySigned proxies /tmp or /static media files with signature verification
|
// MediaProxySigned serves local Sora media files with signature verification.
|
||||||
func (h *SoraGatewayHandler) MediaProxySigned(c *gin.Context) {
|
func (h *SoraGatewayHandler) MediaProxySigned(c *gin.Context) {
|
||||||
h.proxySoraMedia(c, true)
|
h.proxySoraMedia(c, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature bool) {
|
func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature bool) {
|
||||||
if h.sora2apiBaseURL == "" {
|
|
||||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
|
||||||
"error": gin.H{
|
|
||||||
"type": "api_error",
|
|
||||||
"message": "sora2api 未配置",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rawPath := c.Param("filepath")
|
rawPath := c.Param("filepath")
|
||||||
if rawPath == "" {
|
if rawPath == "" {
|
||||||
c.Status(http.StatusNotFound)
|
c.Status(http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cleaned := path.Clean(rawPath)
|
cleaned := path.Clean(rawPath)
|
||||||
if !strings.HasPrefix(cleaned, "/tmp/") && !strings.HasPrefix(cleaned, "/static/") {
|
if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") {
|
||||||
c.Status(http.StatusNotFound)
|
c.Status(http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -445,40 +431,25 @@ func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature boo
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(h.soraMediaRoot) == "" {
|
||||||
targetURL := h.sora2apiBaseURL + cleaned
|
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||||
if rawQuery := query.Encode(); rawQuery != "" {
|
"error": gin.H{
|
||||||
targetURL += "?" + rawQuery
|
"type": "api_error",
|
||||||
}
|
"message": "Sora 媒体目录未配置",
|
||||||
|
},
|
||||||
req, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL, nil)
|
})
|
||||||
if err != nil {
|
|
||||||
c.Status(http.StatusBadGateway)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
copyHeaders := []string{"Range", "If-Range", "If-Modified-Since", "If-None-Match", "Accept", "User-Agent"}
|
|
||||||
for _, key := range copyHeaders {
|
|
||||||
if val := c.GetHeader(key); val != "" {
|
|
||||||
req.Header.Set(key, val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
client := h.mediaClient
|
relative := strings.TrimPrefix(cleaned, "/")
|
||||||
if client == nil {
|
localPath := filepath.Join(h.soraMediaRoot, filepath.FromSlash(relative))
|
||||||
client = http.DefaultClient
|
if _, err := os.Stat(localPath); err != nil {
|
||||||
}
|
if os.IsNotExist(err) {
|
||||||
resp, err := client.Do(req)
|
c.Status(http.StatusNotFound)
|
||||||
if err != nil {
|
|
||||||
c.Status(http.StatusBadGateway)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
c.Status(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
for _, key := range []string{"Content-Type", "Content-Length", "Accept-Ranges", "Content-Range", "Cache-Control", "Last-Modified", "ETag"} {
|
|
||||||
if val := resp.Header.Get(key); val != "" {
|
|
||||||
c.Header(key, val)
|
|
||||||
}
|
}
|
||||||
}
|
c.File(localPath)
|
||||||
c.Status(resp.StatusCode)
|
|
||||||
_, _ = io.Copy(c.Writer, resp.Body)
|
|
||||||
}
|
}
|
||||||
|
|||||||
441
backend/internal/handler/sora_gateway_handler_test.go
Normal file
441
backend/internal/handler/sora_gateway_handler_test.go
Normal file
@@ -0,0 +1,441 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type stubSoraClient struct {
|
||||||
|
imageURLs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubSoraClient) Enabled() bool { return true }
|
||||||
|
func (s *stubSoraClient) UploadImage(ctx context.Context, account *service.Account, data []byte, filename string) (string, error) {
|
||||||
|
return "upload", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.Account, req service.SoraImageRequest) (string, error) {
|
||||||
|
return "task-image", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
|
||||||
|
return "task-video", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
|
||||||
|
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClient) GetVideoTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraVideoTaskStatus, error) {
|
||||||
|
return &service.SoraVideoTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubConcurrencyCache struct{}
|
||||||
|
|
||||||
|
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (c stubConcurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
func (c stubConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
func (c stubConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
func (c stubConcurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
func (c stubConcurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (c stubConcurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
func (c stubConcurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
func (c stubConcurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||||||
|
result := make(map[int64]*service.AccountLoadInfo, len(accounts))
|
||||||
|
for _, acc := range accounts {
|
||||||
|
result[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
func (c stubConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubAccountRepo struct {
|
||||||
|
accounts map[int64]*service.Account
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubAccountRepo) Create(ctx context.Context, account *service.Account) error { return nil }
|
||||||
|
func (r *stubAccountRepo) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
||||||
|
if acc, ok := r.accounts[id]; ok {
|
||||||
|
return acc, nil
|
||||||
|
}
|
||||||
|
return nil, service.ErrAccountNotFound
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
|
||||||
|
var result []*service.Account
|
||||||
|
for _, id := range ids {
|
||||||
|
if acc, ok := r.accounts[id]; ok {
|
||||||
|
result = append(result, acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
||||||
|
_, ok := r.accounts[id]
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return nil }
|
||||||
|
func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error { return nil }
|
||||||
|
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) ListActive(ctx context.Context) ([]service.Account, error) { return nil, nil }
|
||||||
|
func (r *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||||
|
return r.listSchedulableByPlatform(platform), nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error { return nil }
|
||||||
|
func (r *stubAccountRepo) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { return nil }
|
||||||
|
func (r *stubAccountRepo) ClearError(ctx context.Context, id int64) error { return nil }
|
||||||
|
func (r *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) ListSchedulable(ctx context.Context) ([]service.Account, error) {
|
||||||
|
return r.listSchedulable(), nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||||
|
return r.listSchedulable(), nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||||
|
return r.listSchedulableByPlatform(platform), nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
||||||
|
return r.listSchedulableByPlatform(platform), nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||||
|
var result []service.Account
|
||||||
|
for _, acc := range r.accounts {
|
||||||
|
for _, platform := range platforms {
|
||||||
|
if acc.Platform == platform && acc.IsSchedulable() {
|
||||||
|
result = append(result, *acc)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
|
||||||
|
return r.ListSchedulableByPlatforms(ctx, platforms)
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error { return nil }
|
||||||
|
func (r *stubAccountRepo) ClearRateLimit(ctx context.Context, id int64) error { return nil }
|
||||||
|
func (r *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error { return nil }
|
||||||
|
func (r *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubAccountRepo) listSchedulable() []service.Account {
|
||||||
|
var result []service.Account
|
||||||
|
for _, acc := range r.accounts {
|
||||||
|
if acc.IsSchedulable() {
|
||||||
|
result = append(result, *acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubAccountRepo) listSchedulableByPlatform(platform string) []service.Account {
|
||||||
|
var result []service.Account
|
||||||
|
for _, acc := range r.accounts {
|
||||||
|
if acc.Platform == platform && acc.IsSchedulable() {
|
||||||
|
result = append(result, *acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubGroupRepo struct {
|
||||||
|
group *service.Group
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *stubGroupRepo) Create(ctx context.Context, group *service.Group) error { return nil }
|
||||||
|
func (r *stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, error) {
|
||||||
|
return r.group, nil
|
||||||
|
}
|
||||||
|
func (r *stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
|
||||||
|
return r.group, nil
|
||||||
|
}
|
||||||
|
func (r *stubGroupRepo) Update(ctx context.Context, group *service.Group) error { return nil }
|
||||||
|
func (r *stubGroupRepo) Delete(ctx context.Context, id int64) error { return nil }
|
||||||
|
func (r *stubGroupRepo) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (r *stubGroupRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
func (r *stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) { return nil, nil }
|
||||||
|
func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubUsageLogRepo struct{}
|
||||||
|
|
||||||
|
func (s *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) Delete(ctx context.Context, id int64) error { return nil }
|
||||||
|
func (s *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (s *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
cfg := &config.Config{
|
||||||
|
RunMode: config.RunModeSimple,
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
SoraStreamMode: "force",
|
||||||
|
MaxAccountSwitches: 1,
|
||||||
|
Scheduling: config.GatewaySchedulingConfig{
|
||||||
|
LoadBatchEnabled: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Concurrency: config.ConcurrencyConfig{PingInterval: 0},
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
BaseURL: "https://sora.test",
|
||||||
|
PollIntervalSeconds: 1,
|
||||||
|
MaxPollAttempts: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &service.Account{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}
|
||||||
|
accountRepo := &stubAccountRepo{accounts: map[int64]*service.Account{account.ID: account}}
|
||||||
|
group := &service.Group{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Hydrated: true}
|
||||||
|
groupRepo := &stubGroupRepo{group: group}
|
||||||
|
|
||||||
|
usageLogRepo := &stubUsageLogRepo{}
|
||||||
|
deferredService := service.NewDeferredService(accountRepo, nil, 0)
|
||||||
|
billingService := service.NewBillingService(cfg, nil)
|
||||||
|
concurrencyService := service.NewConcurrencyService(stubConcurrencyCache{})
|
||||||
|
billingCacheService := service.NewBillingCacheService(nil, nil, nil, cfg)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
billingCacheService.Stop()
|
||||||
|
})
|
||||||
|
|
||||||
|
gatewayService := service.NewGatewayService(
|
||||||
|
accountRepo,
|
||||||
|
groupRepo,
|
||||||
|
usageLogRepo,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
cfg,
|
||||||
|
nil,
|
||||||
|
concurrencyService,
|
||||||
|
billingService,
|
||||||
|
nil,
|
||||||
|
billingCacheService,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
deferredService,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
|
||||||
|
soraGatewayService := service.NewSoraGatewayService(soraClient, nil, nil, cfg)
|
||||||
|
|
||||||
|
handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, cfg)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
body := `{"model":"gpt-image","messages":[{"role":"user","content":"hello"}]}`
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/sora/v1/chat/completions", strings.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
apiKey := &service.APIKey{
|
||||||
|
ID: 1,
|
||||||
|
UserID: 1,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
GroupID: &group.ID,
|
||||||
|
User: &service.User{ID: 1, Concurrency: 1, Status: service.StatusActive},
|
||||||
|
Group: group,
|
||||||
|
}
|
||||||
|
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||||
|
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: apiKey.User.Concurrency})
|
||||||
|
|
||||||
|
handler.ChatCompletions(c)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
var resp map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
|
require.NotEmpty(t, resp["media_url"])
|
||||||
|
}
|
||||||
@@ -26,7 +26,6 @@ func ProvideAdminHandlers(
|
|||||||
subscriptionHandler *admin.SubscriptionHandler,
|
subscriptionHandler *admin.SubscriptionHandler,
|
||||||
usageHandler *admin.UsageHandler,
|
usageHandler *admin.UsageHandler,
|
||||||
userAttributeHandler *admin.UserAttributeHandler,
|
userAttributeHandler *admin.UserAttributeHandler,
|
||||||
modelHandler *admin.ModelHandler,
|
|
||||||
) *AdminHandlers {
|
) *AdminHandlers {
|
||||||
return &AdminHandlers{
|
return &AdminHandlers{
|
||||||
Dashboard: dashboardHandler,
|
Dashboard: dashboardHandler,
|
||||||
@@ -46,7 +45,6 @@ func ProvideAdminHandlers(
|
|||||||
Subscription: subscriptionHandler,
|
Subscription: subscriptionHandler,
|
||||||
Usage: usageHandler,
|
Usage: usageHandler,
|
||||||
UserAttribute: userAttributeHandler,
|
UserAttribute: userAttributeHandler,
|
||||||
Model: modelHandler,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,7 +119,6 @@ var ProviderSet = wire.NewSet(
|
|||||||
admin.NewSubscriptionHandler,
|
admin.NewSubscriptionHandler,
|
||||||
admin.NewUsageHandler,
|
admin.NewUsageHandler,
|
||||||
admin.NewUserAttributeHandler,
|
admin.NewUserAttributeHandler,
|
||||||
admin.NewModelHandler,
|
|
||||||
|
|
||||||
// AdminHandlers and Handlers constructors
|
// AdminHandlers and Handlers constructors
|
||||||
ProvideAdminHandlers,
|
ProvideAdminHandlers,
|
||||||
|
|||||||
@@ -178,6 +178,10 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"image_price_1k": null,
|
"image_price_1k": null,
|
||||||
"image_price_2k": null,
|
"image_price_2k": null,
|
||||||
"image_price_4k": null,
|
"image_price_4k": null,
|
||||||
|
"sora_image_price_360": null,
|
||||||
|
"sora_image_price_540": null,
|
||||||
|
"sora_video_price_per_request": null,
|
||||||
|
"sora_video_price_per_request_hd": null,
|
||||||
"claude_code_only": false,
|
"claude_code_only": false,
|
||||||
"fallback_group_id": null,
|
"fallback_group_id": null,
|
||||||
"created_at": "2025-01-02T03:04:05Z",
|
"created_at": "2025-01-02T03:04:05Z",
|
||||||
@@ -394,6 +398,7 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"first_token_ms": 50,
|
"first_token_ms": 50,
|
||||||
"image_count": 0,
|
"image_count": 0,
|
||||||
"image_size": null,
|
"image_size": null,
|
||||||
|
"media_type": null,
|
||||||
"created_at": "2025-01-02T03:04:05Z",
|
"created_at": "2025-01-02T03:04:05Z",
|
||||||
"user_agent": null
|
"user_agent": null
|
||||||
}
|
}
|
||||||
@@ -887,6 +892,10 @@ func (s *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID st
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error {
|
func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error {
|
||||||
return errors.New("not implemented")
|
return errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,9 +64,6 @@ func RegisterAdminRoutes(
|
|||||||
|
|
||||||
// 用户属性管理
|
// 用户属性管理
|
||||||
registerUserAttributeRoutes(admin, h)
|
registerUserAttributeRoutes(admin, h)
|
||||||
|
|
||||||
// 模型列表
|
|
||||||
registerModelRoutes(admin, h)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -374,7 +371,3 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
|
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerModelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|
||||||
admin.GET("/models", h.Admin.Model.List)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -491,7 +491,7 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
|||||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用 Sora 客户端标准请求头(参考 sora2api)
|
// 使用 Sora 客户端标准请求头
|
||||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||||
req.Header.Set("Accept", "application/json")
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|||||||
@@ -283,7 +283,6 @@ type adminServiceImpl struct {
|
|||||||
groupRepo GroupRepository
|
groupRepo GroupRepository
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储
|
soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储
|
||||||
soraSyncService *Sora2APISyncService // Sora2API 同步服务
|
|
||||||
proxyRepo ProxyRepository
|
proxyRepo ProxyRepository
|
||||||
apiKeyRepo APIKeyRepository
|
apiKeyRepo APIKeyRepository
|
||||||
redeemCodeRepo RedeemCodeRepository
|
redeemCodeRepo RedeemCodeRepository
|
||||||
@@ -299,7 +298,6 @@ func NewAdminService(
|
|||||||
groupRepo GroupRepository,
|
groupRepo GroupRepository,
|
||||||
accountRepo AccountRepository,
|
accountRepo AccountRepository,
|
||||||
soraAccountRepo SoraAccountRepository,
|
soraAccountRepo SoraAccountRepository,
|
||||||
soraSyncService *Sora2APISyncService,
|
|
||||||
proxyRepo ProxyRepository,
|
proxyRepo ProxyRepository,
|
||||||
apiKeyRepo APIKeyRepository,
|
apiKeyRepo APIKeyRepository,
|
||||||
redeemCodeRepo RedeemCodeRepository,
|
redeemCodeRepo RedeemCodeRepository,
|
||||||
@@ -313,7 +311,6 @@ func NewAdminService(
|
|||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
soraAccountRepo: soraAccountRepo,
|
soraAccountRepo: soraAccountRepo,
|
||||||
soraSyncService: soraSyncService,
|
|
||||||
proxyRepo: proxyRepo,
|
proxyRepo: proxyRepo,
|
||||||
apiKeyRepo: apiKeyRepo,
|
apiKeyRepo: apiKeyRepo,
|
||||||
redeemCodeRepo: redeemCodeRepo,
|
redeemCodeRepo: redeemCodeRepo,
|
||||||
@@ -917,9 +914,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 同步到 sora2api(异步,不阻塞创建)
|
|
||||||
s.syncSoraAccountAsync(account)
|
|
||||||
|
|
||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1014,7 +1008,6 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
s.syncSoraAccountAsync(updated)
|
|
||||||
return updated, nil
|
return updated, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1032,17 +1025,15 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
|||||||
}
|
}
|
||||||
|
|
||||||
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
|
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
|
||||||
needSoraSync := s != nil && s.soraSyncService != nil
|
|
||||||
|
|
||||||
// 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。
|
// 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。
|
||||||
platformByID := map[int64]string{}
|
platformByID := map[int64]string{}
|
||||||
if needMixedChannelCheck || needSoraSync {
|
if needMixedChannelCheck {
|
||||||
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
|
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if needMixedChannelCheck {
|
if needMixedChannelCheck {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
log.Printf("[AdminService] 预加载账号平台信息失败,将逐个降级同步: err=%v", err)
|
|
||||||
} else {
|
} else {
|
||||||
for _, account := range accounts {
|
for _, account := range accounts {
|
||||||
if account != nil {
|
if account != nil {
|
||||||
@@ -1134,45 +1125,15 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
|||||||
result.Success++
|
result.Success++
|
||||||
result.SuccessIDs = append(result.SuccessIDs, accountID)
|
result.SuccessIDs = append(result.SuccessIDs, accountID)
|
||||||
result.Results = append(result.Results, entry)
|
result.Results = append(result.Results, entry)
|
||||||
|
|
||||||
// 批量更新后同步 sora2api
|
|
||||||
if needSoraSync {
|
|
||||||
platform := platformByID[accountID]
|
|
||||||
if platform == "" {
|
|
||||||
updated, err := s.accountRepo.GetByID(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if updated.Platform == PlatformSora {
|
|
||||||
s.syncSoraAccountAsync(updated)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if platform == PlatformSora {
|
|
||||||
updated, err := s.accountRepo.GetByID(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[AdminService] 批量更新后获取账号失败,无法同步 sora2api: account_id=%d err=%v", accountID, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
s.syncSoraAccountAsync(updated)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
|
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
|
||||||
account, err := s.accountRepo.GetByID(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := s.accountRepo.Delete(ctx, id); err != nil {
|
if err := s.accountRepo.Delete(ctx, id); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.deleteSoraAccountAsync(account)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1210,44 +1171,9 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
s.syncSoraAccountAsync(updated)
|
|
||||||
return updated, nil
|
return updated, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) syncSoraAccountAsync(account *Account) {
|
|
||||||
if s == nil || s.soraSyncService == nil || account == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if account.Platform != PlatformSora {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
syncAccount := *account
|
|
||||||
go func() {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := s.soraSyncService.SyncAccount(ctx, &syncAccount); err != nil {
|
|
||||||
log.Printf("[AdminService] 同步 sora2api 失败: account_id=%d err=%v", syncAccount.ID, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *adminServiceImpl) deleteSoraAccountAsync(account *Account) {
|
|
||||||
if s == nil || s.soraSyncService == nil || account == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if account.Platform != PlatformSora {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
syncAccount := *account
|
|
||||||
go func() {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := s.soraSyncService.DeleteAccount(ctx, &syncAccount); err != nil {
|
|
||||||
log.Printf("[AdminService] 删除 sora2api token 失败: account_id=%d err=%v", syncAccount.ID, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Proxy management implementations
|
// Proxy management implementations
|
||||||
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) {
|
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) {
|
||||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||||
|
|||||||
@@ -105,31 +105,3 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
|
|||||||
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
|
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
|
||||||
require.Len(t, result.Results, 3)
|
require.Len(t, result.Results, 3)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs 验证无分组更新时仍会触发 Sora 同步。
|
|
||||||
func TestAdminService_BulkUpdateAccounts_SoraSyncWithoutGroupIDs(t *testing.T) {
|
|
||||||
repo := &accountRepoStubForBulkUpdate{
|
|
||||||
getByIDsAccounts: []*Account{
|
|
||||||
{ID: 1, Platform: PlatformSora},
|
|
||||||
},
|
|
||||||
getByIDAccounts: map[int64]*Account{
|
|
||||||
1: {ID: 1, Platform: PlatformSora},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
svc := &adminServiceImpl{
|
|
||||||
accountRepo: repo,
|
|
||||||
soraSyncService: &Sora2APISyncService{},
|
|
||||||
}
|
|
||||||
|
|
||||||
schedulable := true
|
|
||||||
input := &BulkUpdateAccountsInput{
|
|
||||||
AccountIDs: []int64{1},
|
|
||||||
Schedulable: &schedulable,
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, 1, result.Success)
|
|
||||||
require.True(t, repo.getByIDsCalled)
|
|
||||||
require.ElementsMatch(t, []int64{1}, repo.getByIDCalled)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,351 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Sora2APIModel represents a model entry returned by sora2api.
|
|
||||||
type Sora2APIModel struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
OwnedBy string `json:"owned_by,omitempty"`
|
|
||||||
Description string `json:"description,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sora2APIModelList represents /v1/models response.
|
|
||||||
type Sora2APIModelList struct {
|
|
||||||
Object string `json:"object"`
|
|
||||||
Data []Sora2APIModel `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sora2APIImportTokenItem mirrors sora2api ImportTokenItem.
|
|
||||||
type Sora2APIImportTokenItem struct {
|
|
||||||
Email string `json:"email"`
|
|
||||||
AccessToken string `json:"access_token,omitempty"`
|
|
||||||
SessionToken string `json:"session_token,omitempty"`
|
|
||||||
RefreshToken string `json:"refresh_token,omitempty"`
|
|
||||||
ClientID string `json:"client_id,omitempty"`
|
|
||||||
ProxyURL string `json:"proxy_url,omitempty"`
|
|
||||||
Remark string `json:"remark,omitempty"`
|
|
||||||
IsActive bool `json:"is_active"`
|
|
||||||
ImageEnabled bool `json:"image_enabled"`
|
|
||||||
VideoEnabled bool `json:"video_enabled"`
|
|
||||||
ImageConcurrency int `json:"image_concurrency"`
|
|
||||||
VideoConcurrency int `json:"video_concurrency"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sora2APIToken represents minimal fields for admin list.
|
|
||||||
type Sora2APIToken struct {
|
|
||||||
ID int64 `json:"id"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Remark string `json:"remark"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sora2APIService provides access to sora2api endpoints.
|
|
||||||
type Sora2APIService struct {
|
|
||||||
cfg *config.Config
|
|
||||||
|
|
||||||
baseURL string
|
|
||||||
apiKey string
|
|
||||||
adminUsername string
|
|
||||||
adminPassword string
|
|
||||||
adminTokenTTL time.Duration
|
|
||||||
tokenImportMode string
|
|
||||||
|
|
||||||
client *http.Client
|
|
||||||
adminClient *http.Client
|
|
||||||
|
|
||||||
adminToken string
|
|
||||||
adminTokenAt time.Time
|
|
||||||
adminMu sync.Mutex
|
|
||||||
|
|
||||||
modelCache []Sora2APIModel
|
|
||||||
modelMu sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSora2APIService(cfg *config.Config) *Sora2APIService {
|
|
||||||
if cfg == nil {
|
|
||||||
return &Sora2APIService{}
|
|
||||||
}
|
|
||||||
adminTTL := time.Duration(cfg.Sora2API.AdminTokenTTLSeconds) * time.Second
|
|
||||||
if adminTTL <= 0 {
|
|
||||||
adminTTL = 15 * time.Minute
|
|
||||||
}
|
|
||||||
adminTimeout := time.Duration(cfg.Sora2API.AdminTimeoutSeconds) * time.Second
|
|
||||||
if adminTimeout <= 0 {
|
|
||||||
adminTimeout = 10 * time.Second
|
|
||||||
}
|
|
||||||
return &Sora2APIService{
|
|
||||||
cfg: cfg,
|
|
||||||
baseURL: strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/"),
|
|
||||||
apiKey: strings.TrimSpace(cfg.Sora2API.APIKey),
|
|
||||||
adminUsername: strings.TrimSpace(cfg.Sora2API.AdminUsername),
|
|
||||||
adminPassword: strings.TrimSpace(cfg.Sora2API.AdminPassword),
|
|
||||||
adminTokenTTL: adminTTL,
|
|
||||||
tokenImportMode: strings.ToLower(strings.TrimSpace(cfg.Sora2API.TokenImportMode)),
|
|
||||||
client: &http.Client{},
|
|
||||||
adminClient: &http.Client{Timeout: adminTimeout},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APIService) Enabled() bool {
|
|
||||||
return s != nil && s.baseURL != "" && s.apiKey != ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APIService) AdminEnabled() bool {
|
|
||||||
return s != nil && s.baseURL != "" && s.adminUsername != "" && s.adminPassword != ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APIService) buildURL(path string) string {
|
|
||||||
if s.baseURL == "" {
|
|
||||||
return path
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(path, "/") {
|
|
||||||
return s.baseURL + path
|
|
||||||
}
|
|
||||||
return s.baseURL + "/" + path
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildURL 返回完整的 sora2api URL(用于代理媒体)
|
|
||||||
func (s *Sora2APIService) BuildURL(path string) string {
|
|
||||||
return s.buildURL(path)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APIService) NewAPIRequest(ctx context.Context, method string, path string, body []byte) (*http.Request, error) {
|
|
||||||
if !s.Enabled() {
|
|
||||||
return nil, errors.New("sora2api not configured")
|
|
||||||
}
|
|
||||||
req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", "Bearer "+s.apiKey)
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
return req, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APIService) ListModels(ctx context.Context) ([]Sora2APIModel, error) {
|
|
||||||
if !s.Enabled() {
|
|
||||||
return nil, errors.New("sora2api not configured")
|
|
||||||
}
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.buildURL("/v1/models"), nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", "Bearer "+s.apiKey)
|
|
||||||
resp, err := s.client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return s.cachedModelsOnError(err)
|
|
||||||
}
|
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return s.cachedModelsOnError(fmt.Errorf("sora2api models status: %d", resp.StatusCode))
|
|
||||||
}
|
|
||||||
|
|
||||||
var payload Sora2APIModelList
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
|
||||||
return s.cachedModelsOnError(err)
|
|
||||||
}
|
|
||||||
models := payload.Data
|
|
||||||
if s.cfg != nil && s.cfg.Gateway.SoraModelFilters.HidePromptEnhance {
|
|
||||||
filtered := make([]Sora2APIModel, 0, len(models))
|
|
||||||
for _, m := range models {
|
|
||||||
if strings.HasPrefix(strings.ToLower(m.ID), "prompt-enhance") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
filtered = append(filtered, m)
|
|
||||||
}
|
|
||||||
models = filtered
|
|
||||||
}
|
|
||||||
|
|
||||||
s.modelMu.Lock()
|
|
||||||
s.modelCache = models
|
|
||||||
s.modelMu.Unlock()
|
|
||||||
|
|
||||||
return models, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APIService) cachedModelsOnError(err error) ([]Sora2APIModel, error) {
|
|
||||||
s.modelMu.RLock()
|
|
||||||
cached := append([]Sora2APIModel(nil), s.modelCache...)
|
|
||||||
s.modelMu.RUnlock()
|
|
||||||
if len(cached) > 0 {
|
|
||||||
log.Printf("[Sora2API] 模型列表拉取失败,回退缓存: %v", err)
|
|
||||||
return cached, nil
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APIService) ImportTokens(ctx context.Context, items []Sora2APIImportTokenItem) error {
|
|
||||||
if !s.AdminEnabled() {
|
|
||||||
return errors.New("sora2api admin not configured")
|
|
||||||
}
|
|
||||||
mode := s.tokenImportMode
|
|
||||||
if mode == "" {
|
|
||||||
mode = "at"
|
|
||||||
}
|
|
||||||
payload := map[string]any{
|
|
||||||
"tokens": items,
|
|
||||||
"mode": mode,
|
|
||||||
}
|
|
||||||
_, err := s.doAdminRequest(ctx, http.MethodPost, "/api/tokens/import", payload, nil)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APIService) ListTokens(ctx context.Context) ([]Sora2APIToken, error) {
|
|
||||||
if !s.AdminEnabled() {
|
|
||||||
return nil, errors.New("sora2api admin not configured")
|
|
||||||
}
|
|
||||||
var tokens []Sora2APIToken
|
|
||||||
_, err := s.doAdminRequest(ctx, http.MethodGet, "/api/tokens", nil, &tokens)
|
|
||||||
return tokens, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APIService) DisableToken(ctx context.Context, tokenID int64) error {
|
|
||||||
if !s.AdminEnabled() {
|
|
||||||
return errors.New("sora2api admin not configured")
|
|
||||||
}
|
|
||||||
path := fmt.Sprintf("/api/tokens/%d/disable", tokenID)
|
|
||||||
_, err := s.doAdminRequest(ctx, http.MethodPost, path, nil, nil)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APIService) DeleteToken(ctx context.Context, tokenID int64) error {
|
|
||||||
if !s.AdminEnabled() {
|
|
||||||
return errors.New("sora2api admin not configured")
|
|
||||||
}
|
|
||||||
path := fmt.Sprintf("/api/tokens/%d", tokenID)
|
|
||||||
_, err := s.doAdminRequest(ctx, http.MethodDelete, path, nil, nil)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APIService) doAdminRequest(ctx context.Context, method string, path string, body any, out any) (*http.Response, error) {
|
|
||||||
if !s.AdminEnabled() {
|
|
||||||
return nil, errors.New("sora2api admin not configured")
|
|
||||||
}
|
|
||||||
token, err := s.getAdminToken(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
resp, err := s.doAdminRequestWithToken(ctx, method, path, token, body, out)
|
|
||||||
if err == nil && resp != nil && resp.StatusCode != http.StatusUnauthorized {
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
if resp != nil && resp.StatusCode == http.StatusUnauthorized {
|
|
||||||
s.invalidateAdminToken()
|
|
||||||
token, err = s.getAdminToken(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return resp, err
|
|
||||||
}
|
|
||||||
return s.doAdminRequestWithToken(ctx, method, path, token, body, out)
|
|
||||||
}
|
|
||||||
return resp, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APIService) doAdminRequestWithToken(ctx context.Context, method string, path string, token string, body any, out any) (*http.Response, error) {
|
|
||||||
var reader *bytes.Reader
|
|
||||||
if body != nil {
|
|
||||||
buf, err := json.Marshal(body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
reader = bytes.NewReader(buf)
|
|
||||||
} else {
|
|
||||||
reader = bytes.NewReader(nil)
|
|
||||||
}
|
|
||||||
req, err := http.NewRequestWithContext(ctx, method, s.buildURL(path), reader)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", "Bearer "+token)
|
|
||||||
if body != nil {
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
}
|
|
||||||
resp, err := s.adminClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return resp, err
|
|
||||||
}
|
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
||||||
return resp, fmt.Errorf("sora2api admin status: %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
if out != nil {
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
|
|
||||||
return resp, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APIService) getAdminToken(ctx context.Context) (string, error) {
|
|
||||||
s.adminMu.Lock()
|
|
||||||
defer s.adminMu.Unlock()
|
|
||||||
|
|
||||||
if s.adminToken != "" && time.Since(s.adminTokenAt) < s.adminTokenTTL {
|
|
||||||
return s.adminToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !s.AdminEnabled() {
|
|
||||||
return "", errors.New("sora2api admin not configured")
|
|
||||||
}
|
|
||||||
|
|
||||||
payload := map[string]string{
|
|
||||||
"username": s.adminUsername,
|
|
||||||
"password": s.adminPassword,
|
|
||||||
}
|
|
||||||
buf, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.buildURL("/api/login"), bytes.NewReader(buf))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
resp, err := s.adminClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return "", fmt.Errorf("sora2api login failed: %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
var result struct {
|
|
||||||
Success bool `json:"success"`
|
|
||||||
Token string `json:"token"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
}
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if !result.Success || result.Token == "" {
|
|
||||||
if result.Message == "" {
|
|
||||||
result.Message = "sora2api login failed"
|
|
||||||
}
|
|
||||||
return "", errors.New(result.Message)
|
|
||||||
}
|
|
||||||
s.adminToken = result.Token
|
|
||||||
s.adminTokenAt = time.Now()
|
|
||||||
return result.Token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APIService) invalidateAdminToken() {
|
|
||||||
s.adminMu.Lock()
|
|
||||||
defer s.adminMu.Unlock()
|
|
||||||
s.adminToken = ""
|
|
||||||
s.adminTokenAt = time.Time{}
|
|
||||||
}
|
|
||||||
@@ -1,255 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Sora2APISyncService 用于同步 Sora 账号到 sora2api token 池
|
|
||||||
type Sora2APISyncService struct {
|
|
||||||
sora2api *Sora2APIService
|
|
||||||
accountRepo AccountRepository
|
|
||||||
httpClient *http.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSora2APISyncService(sora2api *Sora2APIService, accountRepo AccountRepository) *Sora2APISyncService {
|
|
||||||
return &Sora2APISyncService{
|
|
||||||
sora2api: sora2api,
|
|
||||||
accountRepo: accountRepo,
|
|
||||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APISyncService) Enabled() bool {
|
|
||||||
return s != nil && s.sora2api != nil && s.sora2api.AdminEnabled()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SyncAccount 将 Sora 账号同步到 sora2api(导入或更新)
|
|
||||||
func (s *Sora2APISyncService) SyncAccount(ctx context.Context, account *Account) error {
|
|
||||||
if !s.Enabled() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if account == nil || account.Platform != PlatformSora {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
|
|
||||||
if accessToken == "" {
|
|
||||||
return errors.New("sora 账号缺少 access_token")
|
|
||||||
}
|
|
||||||
|
|
||||||
email, updated := s.resolveAccountEmail(ctx, account)
|
|
||||||
if email == "" {
|
|
||||||
return errors.New("无法解析 Sora 账号邮箱")
|
|
||||||
}
|
|
||||||
if updated && s.accountRepo != nil {
|
|
||||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
|
||||||
log.Printf("[SoraSync] 更新账号邮箱失败: account_id=%d err=%v", account.ID, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
item := Sora2APIImportTokenItem{
|
|
||||||
Email: email,
|
|
||||||
AccessToken: accessToken,
|
|
||||||
SessionToken: strings.TrimSpace(account.GetCredential("session_token")),
|
|
||||||
RefreshToken: strings.TrimSpace(account.GetCredential("refresh_token")),
|
|
||||||
ClientID: strings.TrimSpace(account.GetCredential("client_id")),
|
|
||||||
Remark: account.Name,
|
|
||||||
IsActive: account.IsActive() && account.Schedulable,
|
|
||||||
ImageEnabled: true,
|
|
||||||
VideoEnabled: true,
|
|
||||||
ImageConcurrency: normalizeSoraConcurrency(account.Concurrency),
|
|
||||||
VideoConcurrency: normalizeSoraConcurrency(account.Concurrency),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.sora2api.ImportTokens(ctx, []Sora2APIImportTokenItem{item}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DisableAccount 禁用 sora2api 中的 token
|
|
||||||
func (s *Sora2APISyncService) DisableAccount(ctx context.Context, account *Account) error {
|
|
||||||
if !s.Enabled() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if account == nil || account.Platform != PlatformSora {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
tokenID, err := s.resolveTokenID(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return s.sora2api.DisableToken(ctx, tokenID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteAccount 删除 sora2api 中的 token
|
|
||||||
func (s *Sora2APISyncService) DeleteAccount(ctx context.Context, account *Account) error {
|
|
||||||
if !s.Enabled() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if account == nil || account.Platform != PlatformSora {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
tokenID, err := s.resolveTokenID(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return s.sora2api.DeleteToken(ctx, tokenID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeSoraConcurrency(value int) int {
|
|
||||||
if value <= 0 {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APISyncService) resolveAccountEmail(ctx context.Context, account *Account) (string, bool) {
|
|
||||||
if account == nil {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
if email := strings.TrimSpace(account.GetCredential("email")); email != "" {
|
|
||||||
return email, false
|
|
||||||
}
|
|
||||||
if email := strings.TrimSpace(account.GetExtraString("email")); email != "" {
|
|
||||||
if account.Credentials == nil {
|
|
||||||
account.Credentials = map[string]any{}
|
|
||||||
}
|
|
||||||
account.Credentials["email"] = email
|
|
||||||
return email, true
|
|
||||||
}
|
|
||||||
if email := strings.TrimSpace(account.GetExtraString("sora_email")); email != "" {
|
|
||||||
if account.Credentials == nil {
|
|
||||||
account.Credentials = map[string]any{}
|
|
||||||
}
|
|
||||||
account.Credentials["email"] = email
|
|
||||||
return email, true
|
|
||||||
}
|
|
||||||
|
|
||||||
accessToken := strings.TrimSpace(account.GetCredential("access_token"))
|
|
||||||
if accessToken != "" {
|
|
||||||
if email := extractEmailFromAccessToken(accessToken); email != "" {
|
|
||||||
if account.Credentials == nil {
|
|
||||||
account.Credentials = map[string]any{}
|
|
||||||
}
|
|
||||||
account.Credentials["email"] = email
|
|
||||||
return email, true
|
|
||||||
}
|
|
||||||
if email := s.fetchEmailFromSora(ctx, accessToken); email != "" {
|
|
||||||
if account.Credentials == nil {
|
|
||||||
account.Credentials = map[string]any{}
|
|
||||||
}
|
|
||||||
account.Credentials["email"] = email
|
|
||||||
return email, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APISyncService) resolveTokenID(ctx context.Context, account *Account) (int64, error) {
|
|
||||||
if account == nil {
|
|
||||||
return 0, errors.New("account is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
if account.Extra != nil {
|
|
||||||
if v, ok := account.Extra["sora2api_token_id"]; ok {
|
|
||||||
if id, ok := v.(float64); ok && id > 0 {
|
|
||||||
return int64(id), nil
|
|
||||||
}
|
|
||||||
if id, ok := v.(int64); ok && id > 0 {
|
|
||||||
return id, nil
|
|
||||||
}
|
|
||||||
if id, ok := v.(int); ok && id > 0 {
|
|
||||||
return int64(id), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
email := strings.TrimSpace(account.GetCredential("email"))
|
|
||||||
if email == "" {
|
|
||||||
email, _ = s.resolveAccountEmail(ctx, account)
|
|
||||||
}
|
|
||||||
if email == "" {
|
|
||||||
return 0, errors.New("sora2api token email missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenID, err := s.findTokenIDByEmail(ctx, email)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return tokenID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APISyncService) findTokenIDByEmail(ctx context.Context, email string) (int64, error) {
|
|
||||||
if !s.Enabled() {
|
|
||||||
return 0, errors.New("sora2api admin not configured")
|
|
||||||
}
|
|
||||||
tokens, err := s.sora2api.ListTokens(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
for _, token := range tokens {
|
|
||||||
if strings.EqualFold(strings.TrimSpace(token.Email), strings.TrimSpace(email)) {
|
|
||||||
return token.ID, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0, fmt.Errorf("sora2api token not found for email: %s", email)
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractEmailFromAccessToken(accessToken string) string {
|
|
||||||
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
|
|
||||||
claims := jwt.MapClaims{}
|
|
||||||
_, _, err := parser.ParseUnverified(accessToken, claims)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
if email, ok := claims["email"].(string); ok && strings.TrimSpace(email) != "" {
|
|
||||||
return email
|
|
||||||
}
|
|
||||||
if profile, ok := claims["https://api.openai.com/profile"].(map[string]any); ok {
|
|
||||||
if email, ok := profile["email"].(string); ok && strings.TrimSpace(email) != "" {
|
|
||||||
return email
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sora2APISyncService) fetchEmailFromSora(ctx context.Context, accessToken string) string {
|
|
||||||
if s.httpClient == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, soraMeAPIURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
|
|
||||||
resp, err := s.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
var payload map[string]any
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
if email, ok := payload["email"].(string); ok && strings.TrimSpace(email) != "" {
|
|
||||||
return email
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
884
backend/internal/service/sora_client.go
Normal file
884
backend/internal/service/sora_client.go
Normal file
@@ -0,0 +1,884 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"math/rand"
|
||||||
|
"mime"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/textproto"
|
||||||
|
"net/url"
|
||||||
|
"path"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"golang.org/x/crypto/sha3"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
soraChatGPTBaseURL = "https://chatgpt.com"
|
||||||
|
soraSentinelFlow = "sora_2_create_task"
|
||||||
|
soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
soraPowMaxIteration = 500000
|
||||||
|
)
|
||||||
|
|
||||||
|
var soraPowCores = []int{8, 16, 24, 32}
|
||||||
|
|
||||||
|
var soraPowScripts = []string{
|
||||||
|
"https://cdn.oaistatic.com/_next/static/cXh69klOLzS0Gy2joLDRS/_ssgManifest.js?dpl=453ebaec0d44c2decab71692e1bfe39be35a24b3",
|
||||||
|
}
|
||||||
|
|
||||||
|
var soraPowDPL = []string{
|
||||||
|
"prod-f501fe933b3edf57aea882da888e1a544df99840",
|
||||||
|
}
|
||||||
|
|
||||||
|
var soraPowNavigatorKeys = []string{
|
||||||
|
"registerProtocolHandler−function registerProtocolHandler() { [native code] }",
|
||||||
|
"storage−[object StorageManager]",
|
||||||
|
"locks−[object LockManager]",
|
||||||
|
"appCodeName−Mozilla",
|
||||||
|
"permissions−[object Permissions]",
|
||||||
|
"webdriver−false",
|
||||||
|
"vendor−Google Inc.",
|
||||||
|
"mediaDevices−[object MediaDevices]",
|
||||||
|
"cookieEnabled−true",
|
||||||
|
"product−Gecko",
|
||||||
|
"productSub−20030107",
|
||||||
|
"hardwareConcurrency−32",
|
||||||
|
"onLine−true",
|
||||||
|
}
|
||||||
|
|
||||||
|
var soraPowDocumentKeys = []string{
|
||||||
|
"_reactListeningo743lnnpvdg",
|
||||||
|
"location",
|
||||||
|
}
|
||||||
|
|
||||||
|
var soraPowWindowKeys = []string{
|
||||||
|
"0", "window", "self", "document", "name", "location",
|
||||||
|
"navigator", "screen", "innerWidth", "innerHeight",
|
||||||
|
"localStorage", "sessionStorage", "crypto", "performance",
|
||||||
|
"fetch", "setTimeout", "setInterval", "console",
|
||||||
|
}
|
||||||
|
|
||||||
|
var soraDesktopUserAgents = []string{
|
||||||
|
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
|
||||||
|
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36",
|
||||||
|
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36",
|
||||||
|
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
|
||||||
|
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36",
|
||||||
|
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
|
||||||
|
"Mozilla/5.0 (Windows NT 11.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
|
||||||
|
}
|
||||||
|
|
||||||
|
var soraRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
var soraRandMu sync.Mutex
|
||||||
|
var soraPerfStart = time.Now()
|
||||||
|
|
||||||
|
// SoraClient 定义直连 Sora 的任务操作接口。
|
||||||
|
type SoraClient interface {
|
||||||
|
Enabled() bool
|
||||||
|
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
|
||||||
|
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
|
||||||
|
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
|
||||||
|
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
|
||||||
|
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraImageRequest 图片生成请求参数
|
||||||
|
type SoraImageRequest struct {
|
||||||
|
Prompt string
|
||||||
|
Width int
|
||||||
|
Height int
|
||||||
|
MediaID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraVideoRequest 视频生成请求参数
|
||||||
|
type SoraVideoRequest struct {
|
||||||
|
Prompt string
|
||||||
|
Orientation string
|
||||||
|
Frames int
|
||||||
|
Model string
|
||||||
|
Size string
|
||||||
|
MediaID string
|
||||||
|
RemixTargetID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraImageTaskStatus 图片任务状态
|
||||||
|
type SoraImageTaskStatus struct {
|
||||||
|
ID string
|
||||||
|
Status string
|
||||||
|
ProgressPct float64
|
||||||
|
URLs []string
|
||||||
|
ErrorMsg string
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraVideoTaskStatus 视频任务状态
|
||||||
|
type SoraVideoTaskStatus struct {
|
||||||
|
ID string
|
||||||
|
Status string
|
||||||
|
ProgressPct int
|
||||||
|
URLs []string
|
||||||
|
ErrorMsg string
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraUpstreamError 上游错误
|
||||||
|
type SoraUpstreamError struct {
|
||||||
|
StatusCode int
|
||||||
|
Message string
|
||||||
|
Headers http.Header
|
||||||
|
Body []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SoraUpstreamError) Error() string {
|
||||||
|
if e == nil {
|
||||||
|
return "sora upstream error"
|
||||||
|
}
|
||||||
|
if e.Message != "" {
|
||||||
|
return fmt.Sprintf("sora upstream error: %d %s", e.StatusCode, e.Message)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("sora upstream error: %d", e.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SoraDirectClient 直连 Sora 实现
|
||||||
|
type SoraDirectClient struct {
|
||||||
|
cfg *config.Config
|
||||||
|
httpUpstream HTTPUpstream
|
||||||
|
tokenProvider *OpenAITokenProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSoraDirectClient 创建 Sora 直连客户端
|
||||||
|
func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient {
|
||||||
|
return &SoraDirectClient{
|
||||||
|
cfg: cfg,
|
||||||
|
httpUpstream: httpUpstream,
|
||||||
|
tokenProvider: tokenProvider,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enabled 判断是否启用 Sora 直连
|
||||||
|
func (c *SoraDirectClient) Enabled() bool {
|
||||||
|
if c == nil || c.cfg == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return "", errors.New("empty image data")
|
||||||
|
}
|
||||||
|
token, err := c.getAccessToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if filename == "" {
|
||||||
|
filename = "image.png"
|
||||||
|
}
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
contentType := mime.TypeByExtension(path.Ext(filename))
|
||||||
|
if contentType == "" {
|
||||||
|
contentType = "application/octet-stream"
|
||||||
|
}
|
||||||
|
partHeader := make(textproto.MIMEHeader)
|
||||||
|
partHeader.Set("Content-Disposition", fmt.Sprintf(`form-data; name="file"; filename="%s"`, filename))
|
||||||
|
partHeader.Set("Content-Type", contentType)
|
||||||
|
part, err := writer.CreatePart(partHeader)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if _, err := part.Write(data); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := writer.WriteField("file_name", filename); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := writer.Close(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
||||||
|
headers.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
|
||||||
|
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/uploads"), headers, &body, false)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(respBody, &payload); err != nil {
|
||||||
|
return "", fmt.Errorf("parse upload response: %w", err)
|
||||||
|
}
|
||||||
|
id, _ := payload["id"].(string)
|
||||||
|
if strings.TrimSpace(id) == "" {
|
||||||
|
return "", errors.New("upload response missing id")
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) {
|
||||||
|
token, err := c.getAccessToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
operation := "simple_compose"
|
||||||
|
inpaintItems := []map[string]any{}
|
||||||
|
if strings.TrimSpace(req.MediaID) != "" {
|
||||||
|
operation = "remix"
|
||||||
|
inpaintItems = append(inpaintItems, map[string]any{
|
||||||
|
"type": "image",
|
||||||
|
"frame_index": 0,
|
||||||
|
"upload_media_id": req.MediaID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
payload := map[string]any{
|
||||||
|
"type": "image_gen",
|
||||||
|
"operation": operation,
|
||||||
|
"prompt": req.Prompt,
|
||||||
|
"width": req.Width,
|
||||||
|
"height": req.Height,
|
||||||
|
"n_variants": 1,
|
||||||
|
"n_frames": 1,
|
||||||
|
"inpaint_items": inpaintItems,
|
||||||
|
}
|
||||||
|
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
||||||
|
headers.Set("Content-Type", "application/json")
|
||||||
|
headers.Set("Origin", "https://sora.chatgpt.com")
|
||||||
|
headers.Set("Referer", "https://sora.chatgpt.com/")
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
sentinel, err := c.generateSentinelToken(ctx, account, token)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
headers.Set("openai-sentinel-token", sentinel)
|
||||||
|
|
||||||
|
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/video_gen"), headers, bytes.NewReader(body), true)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
var resp map[string]any
|
||||||
|
if err := json.Unmarshal(respBody, &resp); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
taskID, _ := resp["id"].(string)
|
||||||
|
if strings.TrimSpace(taskID) == "" {
|
||||||
|
return "", errors.New("image task response missing id")
|
||||||
|
}
|
||||||
|
return taskID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
|
||||||
|
token, err := c.getAccessToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
orientation := req.Orientation
|
||||||
|
if orientation == "" {
|
||||||
|
orientation = "landscape"
|
||||||
|
}
|
||||||
|
nFrames := req.Frames
|
||||||
|
if nFrames <= 0 {
|
||||||
|
nFrames = 450
|
||||||
|
}
|
||||||
|
model := req.Model
|
||||||
|
if model == "" {
|
||||||
|
model = "sy_8"
|
||||||
|
}
|
||||||
|
size := req.Size
|
||||||
|
if size == "" {
|
||||||
|
size = "small"
|
||||||
|
}
|
||||||
|
|
||||||
|
inpaintItems := []map[string]any{}
|
||||||
|
if strings.TrimSpace(req.MediaID) != "" {
|
||||||
|
inpaintItems = append(inpaintItems, map[string]any{
|
||||||
|
"kind": "upload",
|
||||||
|
"upload_id": req.MediaID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
payload := map[string]any{
|
||||||
|
"kind": "video",
|
||||||
|
"prompt": req.Prompt,
|
||||||
|
"orientation": orientation,
|
||||||
|
"size": size,
|
||||||
|
"n_frames": nFrames,
|
||||||
|
"model": model,
|
||||||
|
"inpaint_items": inpaintItems,
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.RemixTargetID) != "" {
|
||||||
|
payload["remix_target_id"] = req.RemixTargetID
|
||||||
|
payload["cameo_ids"] = []string{}
|
||||||
|
payload["cameo_replacements"] = map[string]any{}
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
||||||
|
headers.Set("Content-Type", "application/json")
|
||||||
|
headers.Set("Origin", "https://sora.chatgpt.com")
|
||||||
|
headers.Set("Referer", "https://sora.chatgpt.com/")
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
sentinel, err := c.generateSentinelToken(ctx, account, token)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
headers.Set("openai-sentinel-token", sentinel)
|
||||||
|
|
||||||
|
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/nf/create"), headers, bytes.NewReader(body), true)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
var resp map[string]any
|
||||||
|
if err := json.Unmarshal(respBody, &resp); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
taskID, _ := resp["id"].(string)
|
||||||
|
if strings.TrimSpace(taskID) == "" {
|
||||||
|
return "", errors.New("video task response missing id")
|
||||||
|
}
|
||||||
|
return taskID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
||||||
|
token, err := c.getAccessToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
||||||
|
respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/v2/recent_tasks?limit=20"), headers, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var resp map[string]any
|
||||||
|
if err := json.Unmarshal(respBody, &resp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
taskResponses, _ := resp["task_responses"].([]any)
|
||||||
|
for _, item := range taskResponses {
|
||||||
|
taskResp, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if id, _ := taskResp["id"].(string); id == taskID {
|
||||||
|
status := strings.TrimSpace(fmt.Sprintf("%v", taskResp["status"]))
|
||||||
|
progress := 0.0
|
||||||
|
if v, ok := taskResp["progress_pct"].(float64); ok {
|
||||||
|
progress = v
|
||||||
|
}
|
||||||
|
urls := []string{}
|
||||||
|
if generations, ok := taskResp["generations"].([]any); ok {
|
||||||
|
for _, genItem := range generations {
|
||||||
|
gen, ok := genItem.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if urlStr, ok := gen["url"].(string); ok && strings.TrimSpace(urlStr) != "" {
|
||||||
|
urls = append(urls, urlStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &SoraImageTaskStatus{
|
||||||
|
ID: taskID,
|
||||||
|
Status: status,
|
||||||
|
ProgressPct: progress,
|
||||||
|
URLs: urls,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
|
||||||
|
token, err := c.getAccessToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
||||||
|
|
||||||
|
respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/pending/v2"), headers, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var pending any
|
||||||
|
if err := json.Unmarshal(respBody, &pending); err == nil {
|
||||||
|
if list, ok := pending.([]any); ok {
|
||||||
|
for _, item := range list {
|
||||||
|
task, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if id, _ := task["id"].(string); id == taskID {
|
||||||
|
progress := 0
|
||||||
|
if v, ok := task["progress_pct"].(float64); ok {
|
||||||
|
progress = int(v * 100)
|
||||||
|
}
|
||||||
|
status := strings.TrimSpace(fmt.Sprintf("%v", task["status"]))
|
||||||
|
return &SoraVideoTaskStatus{
|
||||||
|
ID: taskID,
|
||||||
|
Status: status,
|
||||||
|
ProgressPct: progress,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, _, err = c.doRequest(ctx, account, http.MethodGet, c.buildURL("/project_y/profile/drafts?limit=15"), headers, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var draftsResp map[string]any
|
||||||
|
if err := json.Unmarshal(respBody, &draftsResp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items, _ := draftsResp["items"].([]any)
|
||||||
|
for _, item := range items {
|
||||||
|
draft, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if id, _ := draft["task_id"].(string); id == taskID {
|
||||||
|
kind := strings.TrimSpace(fmt.Sprintf("%v", draft["kind"]))
|
||||||
|
reason := strings.TrimSpace(fmt.Sprintf("%v", draft["reason_str"]))
|
||||||
|
if reason == "" {
|
||||||
|
reason = strings.TrimSpace(fmt.Sprintf("%v", draft["markdown_reason_str"]))
|
||||||
|
}
|
||||||
|
urlStr := strings.TrimSpace(fmt.Sprintf("%v", draft["downloadable_url"]))
|
||||||
|
if urlStr == "" {
|
||||||
|
urlStr = strings.TrimSpace(fmt.Sprintf("%v", draft["url"]))
|
||||||
|
}
|
||||||
|
|
||||||
|
if kind == "sora_content_violation" || reason != "" || urlStr == "" {
|
||||||
|
msg := reason
|
||||||
|
if msg == "" {
|
||||||
|
msg = "Content violates guardrails"
|
||||||
|
}
|
||||||
|
return &SoraVideoTaskStatus{
|
||||||
|
ID: taskID,
|
||||||
|
Status: "failed",
|
||||||
|
ErrorMsg: msg,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return &SoraVideoTaskStatus{
|
||||||
|
ID: taskID,
|
||||||
|
Status: "completed",
|
||||||
|
URLs: []string{urlStr},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SoraVideoTaskStatus{ID: taskID, Status: "processing"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) buildURL(endpoint string) string {
|
||||||
|
base := ""
|
||||||
|
if c != nil && c.cfg != nil {
|
||||||
|
base = strings.TrimRight(strings.TrimSpace(c.cfg.Sora.Client.BaseURL), "/")
|
||||||
|
}
|
||||||
|
if base == "" {
|
||||||
|
return endpoint
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(endpoint, "/") {
|
||||||
|
return base + endpoint
|
||||||
|
}
|
||||||
|
return base + "/" + endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) defaultUserAgent() string {
|
||||||
|
if c == nil || c.cfg == nil {
|
||||||
|
return soraDefaultUserAgent
|
||||||
|
}
|
||||||
|
ua := strings.TrimSpace(c.cfg.Sora.Client.UserAgent)
|
||||||
|
if ua == "" {
|
||||||
|
return soraDefaultUserAgent
|
||||||
|
}
|
||||||
|
return ua
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||||
|
if account == nil {
|
||||||
|
return "", errors.New("account is nil")
|
||||||
|
}
|
||||||
|
if c.tokenProvider != nil {
|
||||||
|
return c.tokenProvider.GetAccessToken(ctx, account)
|
||||||
|
}
|
||||||
|
token := strings.TrimSpace(account.GetCredential("access_token"))
|
||||||
|
if token == "" {
|
||||||
|
return "", errors.New("access_token not found")
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header {
|
||||||
|
headers := http.Header{}
|
||||||
|
if token != "" {
|
||||||
|
headers.Set("Authorization", "Bearer "+token)
|
||||||
|
}
|
||||||
|
if userAgent != "" {
|
||||||
|
headers.Set("User-Agent", userAgent)
|
||||||
|
}
|
||||||
|
if c != nil && c.cfg != nil {
|
||||||
|
for key, value := range c.cfg.Sora.Client.Headers {
|
||||||
|
if strings.EqualFold(key, "authorization") || strings.EqualFold(key, "openai-sentinel-token") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
headers.Set(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, method, urlStr string, headers http.Header, body io.Reader, allowRetry bool) ([]byte, http.Header, error) {
|
||||||
|
if strings.TrimSpace(urlStr) == "" {
|
||||||
|
return nil, nil, errors.New("empty upstream url")
|
||||||
|
}
|
||||||
|
timeout := 0
|
||||||
|
if c != nil && c.cfg != nil {
|
||||||
|
timeout = c.cfg.Sora.Client.TimeoutSeconds
|
||||||
|
}
|
||||||
|
if timeout > 0 {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
maxRetries := 0
|
||||||
|
if allowRetry && c != nil && c.cfg != nil {
|
||||||
|
maxRetries = c.cfg.Sora.Client.MaxRetries
|
||||||
|
}
|
||||||
|
if maxRetries < 0 {
|
||||||
|
maxRetries = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var bodyBytes []byte
|
||||||
|
if body != nil {
|
||||||
|
b, err := io.ReadAll(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
bodyBytes = b
|
||||||
|
}
|
||||||
|
|
||||||
|
attempts := maxRetries + 1
|
||||||
|
for attempt := 1; attempt <= attempts; attempt++ {
|
||||||
|
var reader io.Reader
|
||||||
|
if bodyBytes != nil {
|
||||||
|
reader = bytes.NewReader(bodyBytes)
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, urlStr, reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
req.Header = headers.Clone()
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
proxyURL := ""
|
||||||
|
if account != nil && account.ProxyID != nil && account.Proxy != nil {
|
||||||
|
proxyURL = account.Proxy.URL()
|
||||||
|
}
|
||||||
|
resp, err := c.doHTTP(req, proxyURL, account)
|
||||||
|
if err != nil {
|
||||||
|
if attempt < attempts && allowRetry {
|
||||||
|
c.sleepRetry(attempt)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if readErr != nil {
|
||||||
|
return nil, resp.Header, readErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.cfg != nil && c.cfg.Sora.Client.Debug {
|
||||||
|
log.Printf("[SoraClient] %s %s status=%d cost=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, time.Since(start))
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||||
|
upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody)
|
||||||
|
if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) {
|
||||||
|
c.sleepRetry(attempt)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, resp.Header, upstreamErr
|
||||||
|
}
|
||||||
|
return respBody, resp.Header, nil
|
||||||
|
}
|
||||||
|
return nil, nil, errors.New("upstream retries exhausted")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
|
||||||
|
enableTLS := false
|
||||||
|
if c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint {
|
||||||
|
enableTLS = true
|
||||||
|
}
|
||||||
|
if c.httpUpstream != nil {
|
||||||
|
accountID := int64(0)
|
||||||
|
accountConcurrency := 0
|
||||||
|
if account != nil {
|
||||||
|
accountID = account.ID
|
||||||
|
accountConcurrency = account.Concurrency
|
||||||
|
}
|
||||||
|
return c.httpUpstream.DoWithTLS(req, proxyURL, accountID, accountConcurrency, enableTLS)
|
||||||
|
}
|
||||||
|
return http.DefaultClient.Do(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) sleepRetry(attempt int) {
|
||||||
|
backoff := time.Duration(attempt*attempt) * time.Second
|
||||||
|
if backoff > 10*time.Second {
|
||||||
|
backoff = 10 * time.Second
|
||||||
|
}
|
||||||
|
time.Sleep(backoff)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte) error {
|
||||||
|
msg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||||
|
msg = sanitizeUpstreamErrorMessage(msg)
|
||||||
|
if msg == "" {
|
||||||
|
msg = truncateForLog(body, 256)
|
||||||
|
}
|
||||||
|
return &SoraUpstreamError{
|
||||||
|
StatusCode: status,
|
||||||
|
Message: msg,
|
||||||
|
Headers: headers,
|
||||||
|
Body: body,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) {
|
||||||
|
reqID := uuid.NewString()
|
||||||
|
userAgent := soraRandChoice(soraDesktopUserAgents)
|
||||||
|
powToken := soraGetPowToken(userAgent)
|
||||||
|
payload := map[string]any{
|
||||||
|
"p": powToken,
|
||||||
|
"flow": soraSentinelFlow,
|
||||||
|
"id": reqID,
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
headers := http.Header{}
|
||||||
|
headers.Set("Accept", "application/json, text/plain, */*")
|
||||||
|
headers.Set("Content-Type", "application/json")
|
||||||
|
headers.Set("Origin", "https://sora.chatgpt.com")
|
||||||
|
headers.Set("Referer", "https://sora.chatgpt.com/")
|
||||||
|
headers.Set("User-Agent", userAgent)
|
||||||
|
if accessToken != "" {
|
||||||
|
headers.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
urlStr := soraChatGPTBaseURL + "/backend-api/sentinel/req"
|
||||||
|
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, urlStr, headers, bytes.NewReader(body), true)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
var resp map[string]any
|
||||||
|
if err := json.Unmarshal(respBody, &resp); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
sentinel := soraBuildSentinelToken(soraSentinelFlow, reqID, powToken, resp, userAgent)
|
||||||
|
if sentinel == "" {
|
||||||
|
return "", errors.New("failed to build sentinel token")
|
||||||
|
}
|
||||||
|
return sentinel, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func soraRandChoice(items []string) string {
|
||||||
|
if len(items) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
soraRandMu.Lock()
|
||||||
|
idx := soraRand.Intn(len(items))
|
||||||
|
soraRandMu.Unlock()
|
||||||
|
return items[idx]
|
||||||
|
}
|
||||||
|
|
||||||
|
func soraGetPowToken(userAgent string) string {
|
||||||
|
configList := soraBuildPowConfig(userAgent)
|
||||||
|
seed := strconv.FormatFloat(soraRandFloat(), 'f', -1, 64)
|
||||||
|
difficulty := "0fffff"
|
||||||
|
solution, _ := soraSolvePow(seed, difficulty, configList)
|
||||||
|
return "gAAAAAC" + solution
|
||||||
|
}
|
||||||
|
|
||||||
|
func soraRandFloat() float64 {
|
||||||
|
soraRandMu.Lock()
|
||||||
|
defer soraRandMu.Unlock()
|
||||||
|
return soraRand.Float64()
|
||||||
|
}
|
||||||
|
|
||||||
|
func soraBuildPowConfig(userAgent string) []any {
|
||||||
|
screen := soraRandChoice([]string{
|
||||||
|
strconv.Itoa(1920 + 1080),
|
||||||
|
strconv.Itoa(2560 + 1440),
|
||||||
|
strconv.Itoa(1920 + 1200),
|
||||||
|
strconv.Itoa(2560 + 1600),
|
||||||
|
})
|
||||||
|
screenVal, _ := strconv.Atoi(screen)
|
||||||
|
perfMs := float64(time.Since(soraPerfStart).Milliseconds())
|
||||||
|
wallMs := float64(time.Now().UnixNano()) / 1e6
|
||||||
|
diff := wallMs - perfMs
|
||||||
|
return []any{
|
||||||
|
screenVal,
|
||||||
|
soraPowParseTime(),
|
||||||
|
4294705152,
|
||||||
|
0,
|
||||||
|
userAgent,
|
||||||
|
soraRandChoice(soraPowScripts),
|
||||||
|
soraRandChoice(soraPowDPL),
|
||||||
|
"en-US",
|
||||||
|
"en-US,es-US,en,es",
|
||||||
|
0,
|
||||||
|
soraRandChoice(soraPowNavigatorKeys),
|
||||||
|
soraRandChoice(soraPowDocumentKeys),
|
||||||
|
soraRandChoice(soraPowWindowKeys),
|
||||||
|
perfMs,
|
||||||
|
uuid.NewString(),
|
||||||
|
"",
|
||||||
|
soraRandChoiceInt(soraPowCores),
|
||||||
|
diff,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func soraRandChoiceInt(items []int) int {
|
||||||
|
if len(items) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
soraRandMu.Lock()
|
||||||
|
idx := soraRand.Intn(len(items))
|
||||||
|
soraRandMu.Unlock()
|
||||||
|
return items[idx]
|
||||||
|
}
|
||||||
|
|
||||||
|
func soraPowParseTime() string {
|
||||||
|
loc := time.FixedZone("EST", -5*3600)
|
||||||
|
return time.Now().In(loc).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (Eastern Standard Time)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func soraSolvePow(seed, difficulty string, configList []any) (string, bool) {
|
||||||
|
diffLen := len(difficulty) / 2
|
||||||
|
target, err := hexDecodeString(difficulty)
|
||||||
|
if err != nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
seedBytes := []byte(seed)
|
||||||
|
|
||||||
|
part1 := mustMarshalJSON(configList[:3])
|
||||||
|
part2 := mustMarshalJSON(configList[4:9])
|
||||||
|
part3 := mustMarshalJSON(configList[10:])
|
||||||
|
|
||||||
|
staticPart1 := append(part1[:len(part1)-1], ',')
|
||||||
|
staticPart2 := append([]byte(","), append(part2[1:len(part2)-1], ',')...)
|
||||||
|
staticPart3 := append([]byte(","), part3[1:]...)
|
||||||
|
|
||||||
|
for i := 0; i < soraPowMaxIteration; i++ {
|
||||||
|
dynamicI := []byte(strconv.Itoa(i))
|
||||||
|
dynamicJ := []byte(strconv.Itoa(i >> 1))
|
||||||
|
finalJSON := make([]byte, 0, len(staticPart1)+len(dynamicI)+len(staticPart2)+len(dynamicJ)+len(staticPart3))
|
||||||
|
finalJSON = append(finalJSON, staticPart1...)
|
||||||
|
finalJSON = append(finalJSON, dynamicI...)
|
||||||
|
finalJSON = append(finalJSON, staticPart2...)
|
||||||
|
finalJSON = append(finalJSON, dynamicJ...)
|
||||||
|
finalJSON = append(finalJSON, staticPart3...)
|
||||||
|
|
||||||
|
b64 := base64.StdEncoding.EncodeToString(finalJSON)
|
||||||
|
hash := sha3.Sum512(append(seedBytes, []byte(b64)...))
|
||||||
|
if bytes.Compare(hash[:diffLen], target[:diffLen]) <= 0 {
|
||||||
|
return b64, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
errorToken := "wQ8Lk5FbGpA2NcR9dShT6gYjU7VxZ4D" + base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("\"%s\"", seed)))
|
||||||
|
return errorToken, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func soraBuildSentinelToken(flow, reqID, powToken string, resp map[string]any, userAgent string) string {
|
||||||
|
finalPow := powToken
|
||||||
|
proof, _ := resp["proofofwork"].(map[string]any)
|
||||||
|
if required, _ := proof["required"].(bool); required {
|
||||||
|
seed, _ := proof["seed"].(string)
|
||||||
|
difficulty, _ := proof["difficulty"].(string)
|
||||||
|
if seed != "" && difficulty != "" {
|
||||||
|
configList := soraBuildPowConfig(userAgent)
|
||||||
|
solution, _ := soraSolvePow(seed, difficulty, configList)
|
||||||
|
finalPow = "gAAAAAB" + solution
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(finalPow, "~S") {
|
||||||
|
finalPow += "~S"
|
||||||
|
}
|
||||||
|
turnstile, _ := resp["turnstile"].(map[string]any)
|
||||||
|
tokenPayload := map[string]any{
|
||||||
|
"p": finalPow,
|
||||||
|
"t": safeMapString(turnstile, "dx"),
|
||||||
|
"c": safeString(resp["token"]),
|
||||||
|
"id": reqID,
|
||||||
|
"flow": flow,
|
||||||
|
}
|
||||||
|
encoded, _ := json.Marshal(tokenPayload)
|
||||||
|
return string(encoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func safeMapString(m map[string]any, key string) string {
|
||||||
|
if m == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if v, ok := m[key]; ok {
|
||||||
|
return safeString(v)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func safeString(v any) string {
|
||||||
|
switch val := v.(type) {
|
||||||
|
case string:
|
||||||
|
return val
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("%v", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustMarshalJSON(v any) []byte {
|
||||||
|
b, _ := json.Marshal(v)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func hexDecodeString(s string) ([]byte, error) {
|
||||||
|
dst := make([]byte, len(s)/2)
|
||||||
|
_, err := hex.Decode(dst, []byte(s))
|
||||||
|
return dst, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeSoraLogURL(raw string) string {
|
||||||
|
parsed, err := url.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
q := parsed.Query()
|
||||||
|
q.Del("sig")
|
||||||
|
q.Del("expires")
|
||||||
|
parsed.RawQuery = q.Encode()
|
||||||
|
return parsed.String()
|
||||||
|
}
|
||||||
54
backend/internal/service/sora_client_test.go
Normal file
54
backend/internal/service/sora_client_test.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSoraDirectClient_DoRequestSuccess(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte(`{"ok":true}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{BaseURL: server.URL},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := NewSoraDirectClient(cfg, nil, nil)
|
||||||
|
|
||||||
|
body, _, err := client.doRequest(context.Background(), &Account{ID: 1}, http.MethodGet, server.URL, http.Header{}, nil, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, string(body), "ok")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraDirectClient_BuildBaseHeaders(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
Headers: map[string]string{
|
||||||
|
"X-Test": "yes",
|
||||||
|
"Authorization": "should-ignore",
|
||||||
|
"openai-sentinel-token": "skip",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := NewSoraDirectClient(cfg, nil, nil)
|
||||||
|
|
||||||
|
headers := client.buildBaseHeaders("token-123", "UA")
|
||||||
|
require.Equal(t, "Bearer token-123", headers.Get("Authorization"))
|
||||||
|
require.Equal(t, "UA", headers.Get("User-Agent"))
|
||||||
|
require.Equal(t, "yes", headers.Get("X-Test"))
|
||||||
|
require.Empty(t, headers.Get("openai-sentinel-token"))
|
||||||
|
}
|
||||||
@@ -4,10 +4,12 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"mime"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -39,23 +41,23 @@ type soraStreamingResult struct {
|
|||||||
firstTokenMs *int
|
firstTokenMs *int
|
||||||
}
|
}
|
||||||
|
|
||||||
// SoraGatewayService handles forwarding requests to sora2api.
|
// SoraGatewayService handles forwarding requests to Sora upstream.
|
||||||
type SoraGatewayService struct {
|
type SoraGatewayService struct {
|
||||||
sora2api *Sora2APIService
|
soraClient SoraClient
|
||||||
httpUpstream HTTPUpstream
|
mediaStorage *SoraMediaStorage
|
||||||
rateLimitService *RateLimitService
|
rateLimitService *RateLimitService
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSoraGatewayService(
|
func NewSoraGatewayService(
|
||||||
sora2api *Sora2APIService,
|
soraClient SoraClient,
|
||||||
httpUpstream HTTPUpstream,
|
mediaStorage *SoraMediaStorage,
|
||||||
rateLimitService *RateLimitService,
|
rateLimitService *RateLimitService,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
) *SoraGatewayService {
|
) *SoraGatewayService {
|
||||||
return &SoraGatewayService{
|
return &SoraGatewayService{
|
||||||
sora2api: sora2api,
|
soraClient: soraClient,
|
||||||
httpUpstream: httpUpstream,
|
mediaStorage: mediaStorage,
|
||||||
rateLimitService: rateLimitService,
|
rateLimitService: rateLimitService,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
}
|
}
|
||||||
@@ -64,31 +66,53 @@ func NewSoraGatewayService(
|
|||||||
func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) {
|
func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|
||||||
if s.sora2api == nil || !s.sora2api.Enabled() {
|
if s.soraClient == nil || !s.soraClient.Enabled() {
|
||||||
if c != nil {
|
if c != nil {
|
||||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"type": "api_error",
|
"type": "api_error",
|
||||||
"message": "sora2api 未配置",
|
"message": "Sora 上游未配置",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return nil, errors.New("sora2api not configured")
|
return nil, errors.New("sora upstream not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
var reqBody map[string]any
|
var reqBody map[string]any
|
||||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||||
|
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body", clientStream)
|
||||||
return nil, fmt.Errorf("parse request: %w", err)
|
return nil, fmt.Errorf("parse request: %w", err)
|
||||||
}
|
}
|
||||||
reqModel, _ := reqBody["model"].(string)
|
reqModel, _ := reqBody["model"].(string)
|
||||||
reqStream, _ := reqBody["stream"].(bool)
|
reqStream, _ := reqBody["stream"].(bool)
|
||||||
|
if strings.TrimSpace(reqModel) == "" {
|
||||||
|
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream)
|
||||||
|
return nil, errors.New("model is required")
|
||||||
|
}
|
||||||
|
|
||||||
mappedModel := account.GetMappedModel(reqModel)
|
mappedModel := account.GetMappedModel(reqModel)
|
||||||
if mappedModel != reqModel && mappedModel != "" {
|
if mappedModel != "" && mappedModel != reqModel {
|
||||||
reqBody["model"] = mappedModel
|
reqModel = mappedModel
|
||||||
if updated, err := json.Marshal(reqBody); err == nil {
|
|
||||||
body = updated
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
modelCfg, ok := GetSoraModelConfig(reqModel)
|
||||||
|
if !ok {
|
||||||
|
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream)
|
||||||
|
return nil, fmt.Errorf("unsupported model: %s", reqModel)
|
||||||
|
}
|
||||||
|
if modelCfg.Type == "prompt_enhance" {
|
||||||
|
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Prompt-enhance 模型暂未支持", clientStream)
|
||||||
|
return nil, fmt.Errorf("prompt-enhance not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
|
||||||
|
if strings.TrimSpace(prompt) == "" {
|
||||||
|
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
|
||||||
|
return nil, errors.New("prompt is required")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(videoInput) != "" {
|
||||||
|
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Video input is not supported yet", clientStream)
|
||||||
|
return nil, errors.New("video input not supported")
|
||||||
}
|
}
|
||||||
|
|
||||||
reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
|
reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
|
||||||
@@ -96,81 +120,122 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
upstreamReq, err := s.sora2api.NewAPIRequest(reqCtx, http.MethodPost, "/v1/chat/completions", body)
|
var imageData []byte
|
||||||
|
imageFilename := ""
|
||||||
|
if strings.TrimSpace(imageInput) != "" {
|
||||||
|
decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if c != nil {
|
imageData = decoded
|
||||||
if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" {
|
imageFilename = filename
|
||||||
upstreamReq.Header.Set("User-Agent", ua)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if reqStream {
|
|
||||||
upstreamReq.Header.Set("Accept", "text/event-stream")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if c != nil {
|
mediaID := ""
|
||||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
if len(imageData) > 0 {
|
||||||
}
|
uploadID, err := s.soraClient.UploadImage(reqCtx, account, imageData, imageFilename)
|
||||||
|
|
||||||
proxyURL := ""
|
|
||||||
if account != nil && account.ProxyID != nil && account.Proxy != nil {
|
|
||||||
proxyURL = account.Proxy.URL()
|
|
||||||
}
|
|
||||||
|
|
||||||
var resp *http.Response
|
|
||||||
if s.httpUpstream != nil {
|
|
||||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
|
||||||
} else {
|
|
||||||
resp, err = http.DefaultClient.Do(upstreamReq)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.setUpstreamRequestError(c, account, err)
|
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
|
||||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
}
|
||||||
|
mediaID = uploadID
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
taskID := ""
|
||||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
var err error
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
switch modelCfg.Type {
|
||||||
_ = resp.Body.Close()
|
case "image":
|
||||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{
|
||||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
Prompt: prompt,
|
||||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
Width: modelCfg.Width,
|
||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
Height: modelCfg.Height,
|
||||||
Platform: account.Platform,
|
MediaID: mediaID,
|
||||||
AccountID: account.ID,
|
|
||||||
AccountName: account.Name,
|
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
|
||||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
|
||||||
Kind: "failover",
|
|
||||||
Message: upstreamMsg,
|
|
||||||
})
|
})
|
||||||
s.handleFailoverSideEffects(ctx, resp, account)
|
case "video":
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
|
||||||
|
Prompt: prompt,
|
||||||
|
Orientation: modelCfg.Orientation,
|
||||||
|
Frames: modelCfg.Frames,
|
||||||
|
Model: modelCfg.Model,
|
||||||
|
Size: modelCfg.Size,
|
||||||
|
MediaID: mediaID,
|
||||||
|
RemixTargetID: remixTargetID,
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("unsupported model type: %s", modelCfg.Type)
|
||||||
}
|
}
|
||||||
return s.handleErrorResponse(ctx, resp, c, account, reqModel)
|
|
||||||
}
|
|
||||||
|
|
||||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, reqModel, clientStream)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
result := &ForwardResult{
|
if clientStream && c != nil {
|
||||||
RequestID: resp.Header.Get("x-request-id"),
|
s.prepareSoraStream(c, taskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var mediaURLs []string
|
||||||
|
mediaType := modelCfg.Type
|
||||||
|
imageCount := 0
|
||||||
|
imageSize := ""
|
||||||
|
if modelCfg.Type == "image" {
|
||||||
|
urls, pollErr := s.pollImageTask(reqCtx, c, account, taskID, clientStream)
|
||||||
|
if pollErr != nil {
|
||||||
|
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
|
||||||
|
}
|
||||||
|
mediaURLs = urls
|
||||||
|
imageCount = len(urls)
|
||||||
|
imageSize = soraImageSizeFromModel(reqModel)
|
||||||
|
} else if modelCfg.Type == "video" {
|
||||||
|
urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream)
|
||||||
|
if pollErr != nil {
|
||||||
|
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
|
||||||
|
}
|
||||||
|
mediaURLs = urls
|
||||||
|
} else {
|
||||||
|
mediaType = "prompt"
|
||||||
|
}
|
||||||
|
|
||||||
|
finalURLs := mediaURLs
|
||||||
|
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
|
||||||
|
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
|
||||||
|
if storeErr != nil {
|
||||||
|
return nil, s.handleSoraRequestError(ctx, account, storeErr, reqModel, c, clientStream)
|
||||||
|
}
|
||||||
|
finalURLs = s.normalizeSoraMediaURLs(stored)
|
||||||
|
} else {
|
||||||
|
finalURLs = s.normalizeSoraMediaURLs(mediaURLs)
|
||||||
|
}
|
||||||
|
|
||||||
|
content := buildSoraContent(mediaType, finalURLs)
|
||||||
|
var firstTokenMs *int
|
||||||
|
if clientStream {
|
||||||
|
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
|
||||||
|
if streamErr != nil {
|
||||||
|
return nil, streamErr
|
||||||
|
}
|
||||||
|
firstTokenMs = ms
|
||||||
|
} else if c != nil {
|
||||||
|
response := buildSoraNonStreamResponse(content, reqModel)
|
||||||
|
if len(finalURLs) > 0 {
|
||||||
|
response["media_url"] = finalURLs[0]
|
||||||
|
if len(finalURLs) > 1 {
|
||||||
|
response["media_urls"] = finalURLs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ForwardResult{
|
||||||
|
RequestID: taskID,
|
||||||
Model: reqModel,
|
Model: reqModel,
|
||||||
Stream: clientStream,
|
Stream: clientStream,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
FirstTokenMs: streamResult.firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
Usage: ClaudeUsage{},
|
Usage: ClaudeUsage{},
|
||||||
MediaType: streamResult.mediaType,
|
MediaType: mediaType,
|
||||||
MediaURL: firstMediaURL(streamResult.mediaURLs),
|
MediaURL: firstMediaURL(finalURLs),
|
||||||
ImageCount: streamResult.imageCount,
|
ImageCount: imageCount,
|
||||||
ImageSize: streamResult.imageSize,
|
ImageSize: imageSize,
|
||||||
}
|
}, nil
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
|
func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
|
||||||
@@ -780,3 +845,414 @@ func (s *SoraGatewayService) buildSoraMediaURL(path string, rawQuery string) str
|
|||||||
}
|
}
|
||||||
return prefix + path + "?" + encoded
|
return prefix + path + "?" + encoded
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) prepareSoraStream(c *gin.Context, requestID string) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("X-Accel-Buffering", "no")
|
||||||
|
if strings.TrimSpace(requestID) != "" {
|
||||||
|
c.Header("x-request-id", requestID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content string, startTime time.Time) (*int, error) {
|
||||||
|
if c == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
writer := c.Writer
|
||||||
|
flusher, _ := writer.(http.Flusher)
|
||||||
|
|
||||||
|
chunk := map[string]any{
|
||||||
|
"id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"model": model,
|
||||||
|
"choices": []any{
|
||||||
|
map[string]any{
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]any{
|
||||||
|
"content": content,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
encoded, _ := json.Marshal(chunk)
|
||||||
|
if _, err := fmt.Fprintf(writer, "data: %s\n\n", encoded); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if flusher != nil {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
ms := int(time.Since(startTime).Milliseconds())
|
||||||
|
finalChunk := map[string]any{
|
||||||
|
"id": chunk["id"],
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"model": model,
|
||||||
|
"choices": []any{
|
||||||
|
map[string]any{
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]any{},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
finalEncoded, _ := json.Marshal(finalChunk)
|
||||||
|
if _, err := fmt.Fprintf(writer, "data: %s\n\n", finalEncoded); err != nil {
|
||||||
|
return &ms, err
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(writer, "data: [DONE]\n\n"); err != nil {
|
||||||
|
return &ms, err
|
||||||
|
}
|
||||||
|
if flusher != nil {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
return &ms, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType, message string, stream bool) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if stream {
|
||||||
|
flusher, _ := c.Writer.(http.Flusher)
|
||||||
|
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errorEvent)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
|
||||||
|
if flusher != nil {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(status, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": errType,
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account *Account, err error, model string, c *gin.Context, stream bool) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var upstreamErr *SoraUpstreamError
|
||||||
|
if errors.As(err, &upstreamErr) {
|
||||||
|
if s.rateLimitService != nil && account != nil {
|
||||||
|
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
|
||||||
|
}
|
||||||
|
if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
|
||||||
|
return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode}
|
||||||
|
}
|
||||||
|
msg := upstreamErr.Message
|
||||||
|
if override := soraProErrorMessage(model, msg); override != "" {
|
||||||
|
msg = override
|
||||||
|
}
|
||||||
|
s.writeSoraError(c, upstreamErr.StatusCode, "upstream_error", msg, stream)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
s.writeSoraError(c, http.StatusGatewayTimeout, "timeout_error", "Sora generation timeout", stream)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.writeSoraError(c, http.StatusBadGateway, "api_error", err.Error(), stream)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
|
||||||
|
interval := s.pollInterval()
|
||||||
|
maxAttempts := s.pollMaxAttempts()
|
||||||
|
lastPing := time.Now()
|
||||||
|
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||||
|
status, err := s.soraClient.GetImageTask(ctx, account, taskID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch strings.ToLower(status.Status) {
|
||||||
|
case "succeeded", "completed":
|
||||||
|
return status.URLs, nil
|
||||||
|
case "failed":
|
||||||
|
if status.ErrorMsg != "" {
|
||||||
|
return nil, errors.New(status.ErrorMsg)
|
||||||
|
}
|
||||||
|
return nil, errors.New("Sora image generation failed")
|
||||||
|
}
|
||||||
|
if stream {
|
||||||
|
s.maybeSendPing(c, &lastPing)
|
||||||
|
}
|
||||||
|
if err := sleepWithContext(ctx, interval); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, errors.New("Sora image generation timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
|
||||||
|
interval := s.pollInterval()
|
||||||
|
maxAttempts := s.pollMaxAttempts()
|
||||||
|
lastPing := time.Now()
|
||||||
|
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||||
|
status, err := s.soraClient.GetVideoTask(ctx, account, taskID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch strings.ToLower(status.Status) {
|
||||||
|
case "completed", "succeeded":
|
||||||
|
return status.URLs, nil
|
||||||
|
case "failed":
|
||||||
|
if status.ErrorMsg != "" {
|
||||||
|
return nil, errors.New(status.ErrorMsg)
|
||||||
|
}
|
||||||
|
return nil, errors.New("Sora video generation failed")
|
||||||
|
}
|
||||||
|
if stream {
|
||||||
|
s.maybeSendPing(c, &lastPing)
|
||||||
|
}
|
||||||
|
if err := sleepWithContext(ctx, interval); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, errors.New("Sora video generation timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) pollInterval() time.Duration {
|
||||||
|
if s == nil || s.cfg == nil {
|
||||||
|
return 2 * time.Second
|
||||||
|
}
|
||||||
|
interval := s.cfg.Sora.Client.PollIntervalSeconds
|
||||||
|
if interval <= 0 {
|
||||||
|
interval = 2
|
||||||
|
}
|
||||||
|
return time.Duration(interval) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) pollMaxAttempts() int {
|
||||||
|
if s == nil || s.cfg == nil {
|
||||||
|
return 600
|
||||||
|
}
|
||||||
|
maxAttempts := s.cfg.Sora.Client.MaxPollAttempts
|
||||||
|
if maxAttempts <= 0 {
|
||||||
|
maxAttempts = 600
|
||||||
|
}
|
||||||
|
return maxAttempts
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) maybeSendPing(c *gin.Context, lastPing *time.Time) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
interval := 10 * time.Second
|
||||||
|
if s != nil && s.cfg != nil && s.cfg.Concurrency.PingInterval > 0 {
|
||||||
|
interval = time.Duration(s.cfg.Concurrency.PingInterval) * time.Second
|
||||||
|
}
|
||||||
|
if time.Since(*lastPing) < interval {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(c.Writer, ":\n\n"); err == nil {
|
||||||
|
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
*lastPing = time.Now()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) normalizeSoraMediaURLs(urls []string) []string {
|
||||||
|
if len(urls) == 0 {
|
||||||
|
return urls
|
||||||
|
}
|
||||||
|
output := make([]string, 0, len(urls))
|
||||||
|
for _, raw := range urls {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
|
||||||
|
output = append(output, raw)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pathVal := raw
|
||||||
|
if !strings.HasPrefix(pathVal, "/") {
|
||||||
|
pathVal = "/" + pathVal
|
||||||
|
}
|
||||||
|
output = append(output, s.buildSoraMediaURL(pathVal, ""))
|
||||||
|
}
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildSoraContent(mediaType string, urls []string) string {
|
||||||
|
switch mediaType {
|
||||||
|
case "image":
|
||||||
|
parts := make([]string, 0, len(urls))
|
||||||
|
for _, u := range urls {
|
||||||
|
parts = append(parts, fmt.Sprintf("", u))
|
||||||
|
}
|
||||||
|
return strings.Join(parts, "\n")
|
||||||
|
case "video":
|
||||||
|
if len(urls) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("```html\n<video src='%s' controls></video>\n```", urls[0])
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remixTargetID string) {
|
||||||
|
if body == nil {
|
||||||
|
return "", "", "", ""
|
||||||
|
}
|
||||||
|
if v, ok := body["remix_target_id"].(string); ok {
|
||||||
|
remixTargetID = v
|
||||||
|
}
|
||||||
|
if v, ok := body["image"].(string); ok {
|
||||||
|
imageInput = v
|
||||||
|
}
|
||||||
|
if v, ok := body["video"].(string); ok {
|
||||||
|
videoInput = v
|
||||||
|
}
|
||||||
|
if v, ok := body["prompt"].(string); ok && strings.TrimSpace(v) != "" {
|
||||||
|
prompt = v
|
||||||
|
}
|
||||||
|
if messages, ok := body["messages"].([]any); ok {
|
||||||
|
builder := strings.Builder{}
|
||||||
|
for _, raw := range messages {
|
||||||
|
msg, ok := raw.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
role, _ := msg["role"].(string)
|
||||||
|
if role != "" && role != "user" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content := msg["content"]
|
||||||
|
text, img, vid := parseSoraMessageContent(content)
|
||||||
|
if text != "" {
|
||||||
|
if builder.Len() > 0 {
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
builder.WriteString(text)
|
||||||
|
}
|
||||||
|
if imageInput == "" && img != "" {
|
||||||
|
imageInput = img
|
||||||
|
}
|
||||||
|
if videoInput == "" && vid != "" {
|
||||||
|
videoInput = vid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if prompt == "" {
|
||||||
|
prompt = builder.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return prompt, imageInput, videoInput, remixTargetID
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSoraMessageContent(content any) (text, imageInput, videoInput string) {
|
||||||
|
switch val := content.(type) {
|
||||||
|
case string:
|
||||||
|
return val, "", ""
|
||||||
|
case []any:
|
||||||
|
builder := strings.Builder{}
|
||||||
|
for _, item := range val {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t, _ := itemMap["type"].(string)
|
||||||
|
switch t {
|
||||||
|
case "text":
|
||||||
|
if txt, ok := itemMap["text"].(string); ok && strings.TrimSpace(txt) != "" {
|
||||||
|
if builder.Len() > 0 {
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
builder.WriteString(txt)
|
||||||
|
}
|
||||||
|
case "image_url":
|
||||||
|
if imageInput == "" {
|
||||||
|
if urlVal, ok := itemMap["image_url"].(map[string]any); ok {
|
||||||
|
imageInput = fmt.Sprintf("%v", urlVal["url"])
|
||||||
|
} else if urlStr, ok := itemMap["image_url"].(string); ok {
|
||||||
|
imageInput = urlStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "video_url":
|
||||||
|
if videoInput == "" {
|
||||||
|
if urlVal, ok := itemMap["video_url"].(map[string]any); ok {
|
||||||
|
videoInput = fmt.Sprintf("%v", urlVal["url"])
|
||||||
|
} else if urlStr, ok := itemMap["video_url"].(string); ok {
|
||||||
|
videoInput = urlStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return builder.String(), imageInput, videoInput
|
||||||
|
default:
|
||||||
|
return "", "", ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) {
|
||||||
|
raw := strings.TrimSpace(input)
|
||||||
|
if raw == "" {
|
||||||
|
return nil, "", errors.New("empty image input")
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(raw, "data:") {
|
||||||
|
parts := strings.SplitN(raw, ",", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return nil, "", errors.New("invalid data url")
|
||||||
|
}
|
||||||
|
meta := parts[0]
|
||||||
|
payload := parts[1]
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
ext := ""
|
||||||
|
if strings.HasPrefix(meta, "data:") {
|
||||||
|
metaParts := strings.SplitN(meta[5:], ";", 2)
|
||||||
|
if len(metaParts) > 0 {
|
||||||
|
if exts, err := mime.ExtensionsByType(metaParts[0]); err == nil && len(exts) > 0 {
|
||||||
|
ext = exts[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filename := "image" + ext
|
||||||
|
return decoded, filename, nil
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
|
||||||
|
return downloadSoraImageInput(ctx, raw)
|
||||||
|
}
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(raw)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", errors.New("invalid base64 image")
|
||||||
|
}
|
||||||
|
return decoded, "image.png", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
data, err := io.ReadAll(io.LimitReader(resp.Body, 20<<20))
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
ext := fileExtFromURL(rawURL)
|
||||||
|
if ext == "" {
|
||||||
|
ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
|
||||||
|
}
|
||||||
|
filename := "image" + ext
|
||||||
|
return data, filename, nil
|
||||||
|
}
|
||||||
|
|||||||
99
backend/internal/service/sora_gateway_service_test.go
Normal file
99
backend/internal/service/sora_gateway_service_test.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type stubSoraClientForPoll struct {
|
||||||
|
imageStatus *SoraImageTaskStatus
|
||||||
|
videoStatus *SoraVideoTaskStatus
|
||||||
|
imageCalls int
|
||||||
|
videoCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubSoraClientForPoll) Enabled() bool { return true }
|
||||||
|
func (s *stubSoraClientForPoll) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) {
|
||||||
|
return "task-image", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
|
||||||
|
return "task-video", nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
||||||
|
s.imageCalls++
|
||||||
|
return s.imageStatus, nil
|
||||||
|
}
|
||||||
|
func (s *stubSoraClientForPoll) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
|
||||||
|
s.videoCalls++
|
||||||
|
return s.videoStatus, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) {
|
||||||
|
client := &stubSoraClientForPoll{
|
||||||
|
imageStatus: &SoraImageTaskStatus{
|
||||||
|
Status: "completed",
|
||||||
|
URLs: []string{"https://example.com/a.png"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
PollIntervalSeconds: 1,
|
||||||
|
MaxPollAttempts: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewSoraGatewayService(client, nil, nil, cfg)
|
||||||
|
|
||||||
|
urls, err := service.pollImageTask(context.Background(), nil, &Account{ID: 1}, "task", false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []string{"https://example.com/a.png"}, urls)
|
||||||
|
require.Equal(t, 1, client.imageCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
|
||||||
|
client := &stubSoraClientForPoll{
|
||||||
|
videoStatus: &SoraVideoTaskStatus{
|
||||||
|
Status: "failed",
|
||||||
|
ErrorMsg: "reject",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
PollIntervalSeconds: 1,
|
||||||
|
MaxPollAttempts: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewSoraGatewayService(client, nil, nil, cfg)
|
||||||
|
|
||||||
|
urls, err := service.pollVideoTask(context.Background(), nil, &Account{ID: 1}, "task", false)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Empty(t, urls)
|
||||||
|
require.Contains(t, err.Error(), "reject")
|
||||||
|
require.Equal(t, 1, client.videoCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
SoraMediaSigningKey: "test-key",
|
||||||
|
SoraMediaSignedURLTTLSeconds: 600,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
service := NewSoraGatewayService(nil, nil, nil, cfg)
|
||||||
|
|
||||||
|
url := service.buildSoraMediaURL("/image/2025/01/01/a.png", "")
|
||||||
|
require.Contains(t, url, "/sora/media-signed")
|
||||||
|
require.Contains(t, url, "expires=")
|
||||||
|
require.Contains(t, url, "sig=")
|
||||||
|
}
|
||||||
117
backend/internal/service/sora_media_cleanup_service.go
Normal file
117
backend/internal/service/sora_media_cleanup_service.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/robfig/cron/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
var soraCleanupCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
||||||
|
|
||||||
|
// SoraMediaCleanupService 定期清理本地媒体文件
|
||||||
|
type SoraMediaCleanupService struct {
|
||||||
|
storage *SoraMediaStorage
|
||||||
|
cfg *config.Config
|
||||||
|
|
||||||
|
cron *cron.Cron
|
||||||
|
|
||||||
|
startOnce sync.Once
|
||||||
|
stopOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
|
||||||
|
return &SoraMediaCleanupService{
|
||||||
|
storage: storage,
|
||||||
|
cfg: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraMediaCleanupService) Start() {
|
||||||
|
if s == nil || s.cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !s.cfg.Sora.Storage.Cleanup.Enabled {
|
||||||
|
log.Printf("[SoraCleanup] not started (disabled)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.storage == nil || !s.storage.Enabled() {
|
||||||
|
log.Printf("[SoraCleanup] not started (storage disabled)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.startOnce.Do(func() {
|
||||||
|
schedule := strings.TrimSpace(s.cfg.Sora.Storage.Cleanup.Schedule)
|
||||||
|
if schedule == "" {
|
||||||
|
log.Printf("[SoraCleanup] not started (empty schedule)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
loc := time.Local
|
||||||
|
if strings.TrimSpace(s.cfg.Timezone) != "" {
|
||||||
|
if parsed, err := time.LoadLocation(strings.TrimSpace(s.cfg.Timezone)); err == nil && parsed != nil {
|
||||||
|
loc = parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c := cron.New(cron.WithParser(soraCleanupCronParser), cron.WithLocation(loc))
|
||||||
|
if _, err := c.AddFunc(schedule, func() { s.runCleanup() }); err != nil {
|
||||||
|
log.Printf("[SoraCleanup] not started (invalid schedule=%q): %v", schedule, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.cron = c
|
||||||
|
s.cron.Start()
|
||||||
|
log.Printf("[SoraCleanup] started (schedule=%q tz=%s)", schedule, loc.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraMediaCleanupService) Stop() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.stopOnce.Do(func() {
|
||||||
|
if s.cron != nil {
|
||||||
|
ctx := s.cron.Stop()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
log.Printf("[SoraCleanup] cron stop timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraMediaCleanupService) runCleanup() {
|
||||||
|
retention := s.cfg.Sora.Storage.Cleanup.RetentionDays
|
||||||
|
if retention <= 0 {
|
||||||
|
log.Printf("[SoraCleanup] skipped (retention_days=%d)", retention)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cutoff := time.Now().AddDate(0, 0, -retention)
|
||||||
|
deleted := 0
|
||||||
|
|
||||||
|
roots := []string{s.storage.ImageRoot(), s.storage.VideoRoot()}
|
||||||
|
for _, root := range roots {
|
||||||
|
if root == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_ = filepath.Walk(root, func(p string, info os.FileInfo, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if info.IsDir() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if info.ModTime().Before(cutoff) {
|
||||||
|
if rmErr := os.Remove(p); rmErr == nil {
|
||||||
|
deleted++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
log.Printf("[SoraCleanup] cleanup finished, deleted=%d", deleted)
|
||||||
|
}
|
||||||
46
backend/internal/service/sora_media_cleanup_service_test.go
Normal file
46
backend/internal/service/sora_media_cleanup_service_test.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSoraMediaCleanupService_RunCleanup(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Storage: config.SoraStorageConfig{
|
||||||
|
Type: "local",
|
||||||
|
LocalPath: tmpDir,
|
||||||
|
Cleanup: config.SoraStorageCleanupConfig{
|
||||||
|
Enabled: true,
|
||||||
|
RetentionDays: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
storage := NewSoraMediaStorage(cfg)
|
||||||
|
require.NoError(t, storage.EnsureLocalDirs())
|
||||||
|
|
||||||
|
oldImage := filepath.Join(storage.ImageRoot(), "old.png")
|
||||||
|
newVideo := filepath.Join(storage.VideoRoot(), "new.mp4")
|
||||||
|
require.NoError(t, os.WriteFile(oldImage, []byte("old"), 0o644))
|
||||||
|
require.NoError(t, os.WriteFile(newVideo, []byte("new"), 0o644))
|
||||||
|
|
||||||
|
oldTime := time.Now().Add(-48 * time.Hour)
|
||||||
|
require.NoError(t, os.Chtimes(oldImage, oldTime, oldTime))
|
||||||
|
|
||||||
|
cleanup := NewSoraMediaCleanupService(storage, cfg)
|
||||||
|
cleanup.runCleanup()
|
||||||
|
|
||||||
|
require.NoFileExists(t, oldImage)
|
||||||
|
require.FileExists(t, newVideo)
|
||||||
|
}
|
||||||
256
backend/internal/service/sora_media_storage.go
Normal file
256
backend/internal/service/sora_media_storage.go
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"mime"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
soraStorageDefaultRoot = "/app/data/sora"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SoraMediaStorage 负责下载并落地 Sora 媒体
|
||||||
|
type SoraMediaStorage struct {
|
||||||
|
cfg *config.Config
|
||||||
|
root string
|
||||||
|
imageRoot string
|
||||||
|
videoRoot string
|
||||||
|
maxConcurrent int
|
||||||
|
fallbackToUpstream bool
|
||||||
|
debug bool
|
||||||
|
sem chan struct{}
|
||||||
|
ready bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
|
||||||
|
storage := &SoraMediaStorage{cfg: cfg}
|
||||||
|
storage.refreshConfig()
|
||||||
|
if storage.Enabled() {
|
||||||
|
if err := storage.EnsureLocalDirs(); err != nil {
|
||||||
|
log.Printf("[SoraStorage] 初始化失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return storage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraMediaStorage) Enabled() bool {
|
||||||
|
if s == nil || s.cfg == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.ToLower(strings.TrimSpace(s.cfg.Sora.Storage.Type)) == "local"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraMediaStorage) Root() string {
|
||||||
|
if s == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return s.root
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraMediaStorage) ImageRoot() string {
|
||||||
|
if s == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return s.imageRoot
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraMediaStorage) VideoRoot() string {
|
||||||
|
if s == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return s.videoRoot
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraMediaStorage) refreshConfig() {
|
||||||
|
if s == nil || s.cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
root := strings.TrimSpace(s.cfg.Sora.Storage.LocalPath)
|
||||||
|
if root == "" {
|
||||||
|
root = soraStorageDefaultRoot
|
||||||
|
}
|
||||||
|
s.root = root
|
||||||
|
s.imageRoot = filepath.Join(root, "image")
|
||||||
|
s.videoRoot = filepath.Join(root, "video")
|
||||||
|
|
||||||
|
maxConcurrent := s.cfg.Sora.Storage.MaxConcurrentDownloads
|
||||||
|
if maxConcurrent <= 0 {
|
||||||
|
maxConcurrent = 4
|
||||||
|
}
|
||||||
|
s.maxConcurrent = maxConcurrent
|
||||||
|
s.fallbackToUpstream = s.cfg.Sora.Storage.FallbackToUpstream
|
||||||
|
s.debug = s.cfg.Sora.Storage.Debug
|
||||||
|
s.sem = make(chan struct{}, maxConcurrent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureLocalDirs 创建并校验本地目录
|
||||||
|
func (s *SoraMediaStorage) EnsureLocalDirs() error {
|
||||||
|
if s == nil || !s.Enabled() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(s.imageRoot, 0o755); err != nil {
|
||||||
|
return fmt.Errorf("create image dir: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(s.videoRoot, 0o755); err != nil {
|
||||||
|
return fmt.Errorf("create video dir: %w", err)
|
||||||
|
}
|
||||||
|
s.ready = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StoreFromURLs 下载并存储媒体,返回相对路径或回退 URL
|
||||||
|
func (s *SoraMediaStorage) StoreFromURLs(ctx context.Context, mediaType string, urls []string) ([]string, error) {
|
||||||
|
if len(urls) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if s == nil || !s.Enabled() {
|
||||||
|
return urls, nil
|
||||||
|
}
|
||||||
|
if !s.ready {
|
||||||
|
if err := s.EnsureLocalDirs(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results := make([]string, 0, len(urls))
|
||||||
|
for _, raw := range urls {
|
||||||
|
relative, err := s.downloadAndStore(ctx, mediaType, raw)
|
||||||
|
if err != nil {
|
||||||
|
if s.fallbackToUpstream {
|
||||||
|
results = append(results, raw)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
results = append(results, relative)
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraMediaStorage) downloadAndStore(ctx context.Context, mediaType, rawURL string) (string, error) {
|
||||||
|
if strings.TrimSpace(rawURL) == "" {
|
||||||
|
return "", errors.New("empty url")
|
||||||
|
}
|
||||||
|
root := s.imageRoot
|
||||||
|
if mediaType == "video" {
|
||||||
|
root = s.videoRoot
|
||||||
|
}
|
||||||
|
if root == "" {
|
||||||
|
return "", errors.New("storage root not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
retries := 3
|
||||||
|
for attempt := 1; attempt <= retries; attempt++ {
|
||||||
|
release, err := s.acquire(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
relative, err := s.downloadOnce(ctx, root, mediaType, rawURL)
|
||||||
|
release()
|
||||||
|
if err == nil {
|
||||||
|
return relative, nil
|
||||||
|
}
|
||||||
|
if s.debug {
|
||||||
|
log.Printf("[SoraStorage] 下载失败(%d/%d): %s err=%v", attempt, retries, sanitizeSoraLogURL(rawURL), err)
|
||||||
|
}
|
||||||
|
if attempt < retries {
|
||||||
|
time.Sleep(time.Duration(attempt*attempt) * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return "", errors.New("download retries exhausted")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, rawURL string) (string, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
return "", fmt.Errorf("download failed: %d %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
ext := fileExtFromURL(rawURL)
|
||||||
|
if ext == "" {
|
||||||
|
ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
|
||||||
|
}
|
||||||
|
if ext == "" {
|
||||||
|
ext = ".bin"
|
||||||
|
}
|
||||||
|
|
||||||
|
datePath := time.Now().Format("2006/01/02")
|
||||||
|
destDir := filepath.Join(root, filepath.FromSlash(datePath))
|
||||||
|
if err := os.MkdirAll(destDir, 0o755); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
filename := uuid.NewString() + ext
|
||||||
|
destPath := filepath.Join(destDir, filename)
|
||||||
|
out, err := os.Create(destPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer func() { _ = out.Close() }()
|
||||||
|
|
||||||
|
if _, err := io.Copy(out, resp.Body); err != nil {
|
||||||
|
_ = os.Remove(destPath)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
relative := path.Join("/", mediaType, datePath, filename)
|
||||||
|
if s.debug {
|
||||||
|
log.Printf("[SoraStorage] 已落地 %s -> %s", sanitizeSoraLogURL(rawURL), relative)
|
||||||
|
}
|
||||||
|
return relative, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraMediaStorage) acquire(ctx context.Context) (func(), error) {
|
||||||
|
if s.sem == nil {
|
||||||
|
return func() {}, nil
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case s.sem <- struct{}{}:
|
||||||
|
return func() { <-s.sem }, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileExtFromURL(raw string) string {
|
||||||
|
parsed, err := url.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
ext := path.Ext(parsed.Path)
|
||||||
|
return strings.ToLower(ext)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileExtFromContentType(ct string) string {
|
||||||
|
if ct == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if exts, err := mime.ExtensionsByType(ct); err == nil && len(exts) > 0 {
|
||||||
|
return strings.ToLower(exts[0])
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
69
backend/internal/service/sora_media_storage_test.go
Normal file
69
backend/internal/service/sora_media_storage_test.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSoraMediaStorage_StoreFromURLs(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "image/png")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("data"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Storage: config.SoraStorageConfig{
|
||||||
|
Type: "local",
|
||||||
|
LocalPath: tmpDir,
|
||||||
|
MaxConcurrentDownloads: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
storage := NewSoraMediaStorage(cfg)
|
||||||
|
urls, err := storage.StoreFromURLs(context.Background(), "image", []string{server.URL + "/img.png"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, urls, 1)
|
||||||
|
require.True(t, strings.HasPrefix(urls[0], "/image/"))
|
||||||
|
require.True(t, strings.HasSuffix(urls[0], ".png"))
|
||||||
|
|
||||||
|
localPath := filepath.Join(tmpDir, filepath.FromSlash(strings.TrimPrefix(urls[0], "/")))
|
||||||
|
require.FileExists(t, localPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraMediaStorage_FallbackToUpstream(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Storage: config.SoraStorageConfig{
|
||||||
|
Type: "local",
|
||||||
|
LocalPath: tmpDir,
|
||||||
|
FallbackToUpstream: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
storage := NewSoraMediaStorage(cfg)
|
||||||
|
url := server.URL + "/broken.png"
|
||||||
|
urls, err := storage.StoreFromURLs(context.Background(), "image", []string{url})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []string{url}, urls)
|
||||||
|
}
|
||||||
252
backend/internal/service/sora_models.go
Normal file
252
backend/internal/service/sora_models.go
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SoraModelConfig Sora 模型配置
|
||||||
|
type SoraModelConfig struct {
|
||||||
|
Type string
|
||||||
|
Width int
|
||||||
|
Height int
|
||||||
|
Orientation string
|
||||||
|
Frames int
|
||||||
|
Model string
|
||||||
|
Size string
|
||||||
|
RequirePro bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var soraModelConfigs = map[string]SoraModelConfig{
|
||||||
|
"gpt-image": {
|
||||||
|
Type: "image",
|
||||||
|
Width: 360,
|
||||||
|
Height: 360,
|
||||||
|
},
|
||||||
|
"gpt-image-landscape": {
|
||||||
|
Type: "image",
|
||||||
|
Width: 540,
|
||||||
|
Height: 360,
|
||||||
|
},
|
||||||
|
"gpt-image-portrait": {
|
||||||
|
Type: "image",
|
||||||
|
Width: 360,
|
||||||
|
Height: 540,
|
||||||
|
},
|
||||||
|
"sora2-landscape-10s": {
|
||||||
|
Type: "video",
|
||||||
|
Orientation: "landscape",
|
||||||
|
Frames: 300,
|
||||||
|
Model: "sy_8",
|
||||||
|
Size: "small",
|
||||||
|
},
|
||||||
|
"sora2-portrait-10s": {
|
||||||
|
Type: "video",
|
||||||
|
Orientation: "portrait",
|
||||||
|
Frames: 300,
|
||||||
|
Model: "sy_8",
|
||||||
|
Size: "small",
|
||||||
|
},
|
||||||
|
"sora2-landscape-15s": {
|
||||||
|
Type: "video",
|
||||||
|
Orientation: "landscape",
|
||||||
|
Frames: 450,
|
||||||
|
Model: "sy_8",
|
||||||
|
Size: "small",
|
||||||
|
},
|
||||||
|
"sora2-portrait-15s": {
|
||||||
|
Type: "video",
|
||||||
|
Orientation: "portrait",
|
||||||
|
Frames: 450,
|
||||||
|
Model: "sy_8",
|
||||||
|
Size: "small",
|
||||||
|
},
|
||||||
|
"sora2-landscape-25s": {
|
||||||
|
Type: "video",
|
||||||
|
Orientation: "landscape",
|
||||||
|
Frames: 750,
|
||||||
|
Model: "sy_8",
|
||||||
|
Size: "small",
|
||||||
|
RequirePro: true,
|
||||||
|
},
|
||||||
|
"sora2-portrait-25s": {
|
||||||
|
Type: "video",
|
||||||
|
Orientation: "portrait",
|
||||||
|
Frames: 750,
|
||||||
|
Model: "sy_8",
|
||||||
|
Size: "small",
|
||||||
|
RequirePro: true,
|
||||||
|
},
|
||||||
|
"sora2pro-landscape-10s": {
|
||||||
|
Type: "video",
|
||||||
|
Orientation: "landscape",
|
||||||
|
Frames: 300,
|
||||||
|
Model: "sy_ore",
|
||||||
|
Size: "small",
|
||||||
|
RequirePro: true,
|
||||||
|
},
|
||||||
|
"sora2pro-portrait-10s": {
|
||||||
|
Type: "video",
|
||||||
|
Orientation: "portrait",
|
||||||
|
Frames: 300,
|
||||||
|
Model: "sy_ore",
|
||||||
|
Size: "small",
|
||||||
|
RequirePro: true,
|
||||||
|
},
|
||||||
|
"sora2pro-landscape-15s": {
|
||||||
|
Type: "video",
|
||||||
|
Orientation: "landscape",
|
||||||
|
Frames: 450,
|
||||||
|
Model: "sy_ore",
|
||||||
|
Size: "small",
|
||||||
|
RequirePro: true,
|
||||||
|
},
|
||||||
|
"sora2pro-portrait-15s": {
|
||||||
|
Type: "video",
|
||||||
|
Orientation: "portrait",
|
||||||
|
Frames: 450,
|
||||||
|
Model: "sy_ore",
|
||||||
|
Size: "small",
|
||||||
|
RequirePro: true,
|
||||||
|
},
|
||||||
|
"sora2pro-landscape-25s": {
|
||||||
|
Type: "video",
|
||||||
|
Orientation: "landscape",
|
||||||
|
Frames: 750,
|
||||||
|
Model: "sy_ore",
|
||||||
|
Size: "small",
|
||||||
|
RequirePro: true,
|
||||||
|
},
|
||||||
|
"sora2pro-portrait-25s": {
|
||||||
|
Type: "video",
|
||||||
|
Orientation: "portrait",
|
||||||
|
Frames: 750,
|
||||||
|
Model: "sy_ore",
|
||||||
|
Size: "small",
|
||||||
|
RequirePro: true,
|
||||||
|
},
|
||||||
|
"sora2pro-hd-landscape-10s": {
|
||||||
|
Type: "video",
|
||||||
|
Orientation: "landscape",
|
||||||
|
Frames: 300,
|
||||||
|
Model: "sy_ore",
|
||||||
|
Size: "large",
|
||||||
|
RequirePro: true,
|
||||||
|
},
|
||||||
|
"sora2pro-hd-portrait-10s": {
|
||||||
|
Type: "video",
|
||||||
|
Orientation: "portrait",
|
||||||
|
Frames: 300,
|
||||||
|
Model: "sy_ore",
|
||||||
|
Size: "large",
|
||||||
|
RequirePro: true,
|
||||||
|
},
|
||||||
|
"sora2pro-hd-landscape-15s": {
|
||||||
|
Type: "video",
|
||||||
|
Orientation: "landscape",
|
||||||
|
Frames: 450,
|
||||||
|
Model: "sy_ore",
|
||||||
|
Size: "large",
|
||||||
|
RequirePro: true,
|
||||||
|
},
|
||||||
|
"sora2pro-hd-portrait-15s": {
|
||||||
|
Type: "video",
|
||||||
|
Orientation: "portrait",
|
||||||
|
Frames: 450,
|
||||||
|
Model: "sy_ore",
|
||||||
|
Size: "large",
|
||||||
|
RequirePro: true,
|
||||||
|
},
|
||||||
|
"prompt-enhance-short-10s": {
|
||||||
|
Type: "prompt_enhance",
|
||||||
|
},
|
||||||
|
"prompt-enhance-short-15s": {
|
||||||
|
Type: "prompt_enhance",
|
||||||
|
},
|
||||||
|
"prompt-enhance-short-20s": {
|
||||||
|
Type: "prompt_enhance",
|
||||||
|
},
|
||||||
|
"prompt-enhance-medium-10s": {
|
||||||
|
Type: "prompt_enhance",
|
||||||
|
},
|
||||||
|
"prompt-enhance-medium-15s": {
|
||||||
|
Type: "prompt_enhance",
|
||||||
|
},
|
||||||
|
"prompt-enhance-medium-20s": {
|
||||||
|
Type: "prompt_enhance",
|
||||||
|
},
|
||||||
|
"prompt-enhance-long-10s": {
|
||||||
|
Type: "prompt_enhance",
|
||||||
|
},
|
||||||
|
"prompt-enhance-long-15s": {
|
||||||
|
Type: "prompt_enhance",
|
||||||
|
},
|
||||||
|
"prompt-enhance-long-20s": {
|
||||||
|
Type: "prompt_enhance",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var soraModelIDs = []string{
|
||||||
|
"gpt-image",
|
||||||
|
"gpt-image-landscape",
|
||||||
|
"gpt-image-portrait",
|
||||||
|
"sora2-landscape-10s",
|
||||||
|
"sora2-portrait-10s",
|
||||||
|
"sora2-landscape-15s",
|
||||||
|
"sora2-portrait-15s",
|
||||||
|
"sora2-landscape-25s",
|
||||||
|
"sora2-portrait-25s",
|
||||||
|
"sora2pro-landscape-10s",
|
||||||
|
"sora2pro-portrait-10s",
|
||||||
|
"sora2pro-landscape-15s",
|
||||||
|
"sora2pro-portrait-15s",
|
||||||
|
"sora2pro-landscape-25s",
|
||||||
|
"sora2pro-portrait-25s",
|
||||||
|
"sora2pro-hd-landscape-10s",
|
||||||
|
"sora2pro-hd-portrait-10s",
|
||||||
|
"sora2pro-hd-landscape-15s",
|
||||||
|
"sora2pro-hd-portrait-15s",
|
||||||
|
"prompt-enhance-short-10s",
|
||||||
|
"prompt-enhance-short-15s",
|
||||||
|
"prompt-enhance-short-20s",
|
||||||
|
"prompt-enhance-medium-10s",
|
||||||
|
"prompt-enhance-medium-15s",
|
||||||
|
"prompt-enhance-medium-20s",
|
||||||
|
"prompt-enhance-long-10s",
|
||||||
|
"prompt-enhance-long-15s",
|
||||||
|
"prompt-enhance-long-20s",
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSoraModelConfig 返回 Sora 模型配置
|
||||||
|
func GetSoraModelConfig(model string) (SoraModelConfig, bool) {
|
||||||
|
key := strings.ToLower(strings.TrimSpace(model))
|
||||||
|
cfg, ok := soraModelConfigs[key]
|
||||||
|
return cfg, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultSoraModels returns the default Sora model list.
|
||||||
|
func DefaultSoraModels(cfg *config.Config) []openai.Model {
|
||||||
|
models := make([]openai.Model, 0, len(soraModelIDs))
|
||||||
|
for _, id := range soraModelIDs {
|
||||||
|
models = append(models, openai.Model{
|
||||||
|
ID: id,
|
||||||
|
Object: "model",
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Type: "model",
|
||||||
|
DisplayName: id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if cfg != nil && cfg.Gateway.SoraModelFilters.HidePromptEnhance {
|
||||||
|
filtered := models[:0]
|
||||||
|
for _, model := range models {
|
||||||
|
if strings.HasPrefix(strings.ToLower(model.ID), "prompt-enhance") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, model)
|
||||||
|
}
|
||||||
|
models = filtered
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
@@ -63,16 +63,6 @@ func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSoraSyncService 设置 Sora2API 同步服务
|
|
||||||
// 需要在 Start() 之前调用
|
|
||||||
func (s *TokenRefreshService) SetSoraSyncService(svc *Sora2APISyncService) {
|
|
||||||
for _, refresher := range s.refreshers {
|
|
||||||
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
|
|
||||||
openaiRefresher.SetSoraSyncService(svc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start 启动后台刷新服务
|
// Start 启动后台刷新服务
|
||||||
func (s *TokenRefreshService) Start() {
|
func (s *TokenRefreshService) Start() {
|
||||||
if !s.cfg.Enabled {
|
if !s.cfg.Enabled {
|
||||||
|
|||||||
@@ -86,7 +86,6 @@ type OpenAITokenRefresher struct {
|
|||||||
openaiOAuthService *OpenAIOAuthService
|
openaiOAuthService *OpenAIOAuthService
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
|
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
|
||||||
soraSyncService *Sora2APISyncService // Sora2API 同步服务
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
|
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
|
||||||
@@ -104,11 +103,6 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) {
|
|||||||
r.soraAccountRepo = repo
|
r.soraAccountRepo = repo
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSoraSyncService 设置 Sora2API 同步服务
|
|
||||||
func (r *OpenAITokenRefresher) SetSoraSyncService(svc *Sora2APISyncService) {
|
|
||||||
r.soraSyncService = svc
|
|
||||||
}
|
|
||||||
|
|
||||||
// CanRefresh 检查是否能处理此账号
|
// CanRefresh 检查是否能处理此账号
|
||||||
// 只处理 openai 平台的 oauth 类型账号
|
// 只处理 openai 平台的 oauth 类型账号
|
||||||
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
|
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
|
||||||
@@ -151,17 +145,6 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
|
|||||||
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
|
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果是 Sora 平台账号,同步到 sora2api(不阻塞主流程)
|
|
||||||
if account.Platform == PlatformSora && r.soraSyncService != nil {
|
|
||||||
syncAccount := *account
|
|
||||||
syncAccount.Credentials = newCredentials
|
|
||||||
go func() {
|
|
||||||
if err := r.soraSyncService.SyncAccount(context.Background(), &syncAccount); err != nil {
|
|
||||||
log.Printf("[TokenSync] 同步 Sora2API 失败: account_id=%d err=%v", syncAccount.ID, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
return newCredentials, nil
|
return newCredentials, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -218,13 +201,6 @@ func (r *OpenAITokenRefresher) syncLinkedSoraAccounts(ctx context.Context, opena
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2.3 同步到 sora2api(如果配置)
|
|
||||||
if r.soraSyncService != nil {
|
|
||||||
if err := r.soraSyncService.SyncAccount(ctx, &soraAccount); err != nil {
|
|
||||||
log.Printf("[TokenSync] 同步 sora2api 失败: account_id=%d err=%v", soraAccount.ID, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("[TokenSync] 成功同步 Sora 账号 token: sora_account_id=%d openai_account_id=%d dual_table=%v",
|
log.Printf("[TokenSync] 成功同步 Sora 账号 token: sora_account_id=%d openai_account_id=%d dual_table=%v",
|
||||||
soraAccount.ID, openaiAccountID, r.soraAccountRepo != nil)
|
soraAccount.ID, openaiAccountID, r.soraAccountRepo != nil)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
|
|||||||
func ProvideTokenRefreshService(
|
func ProvideTokenRefreshService(
|
||||||
accountRepo AccountRepository,
|
accountRepo AccountRepository,
|
||||||
soraAccountRepo SoraAccountRepository, // Sora 扩展表仓储,用于双表同步
|
soraAccountRepo SoraAccountRepository, // Sora 扩展表仓储,用于双表同步
|
||||||
soraSyncService *Sora2APISyncService,
|
|
||||||
oauthService *OAuthService,
|
oauthService *OAuthService,
|
||||||
openaiOAuthService *OpenAIOAuthService,
|
openaiOAuthService *OpenAIOAuthService,
|
||||||
geminiOAuthService *GeminiOAuthService,
|
geminiOAuthService *GeminiOAuthService,
|
||||||
@@ -51,7 +50,6 @@ func ProvideTokenRefreshService(
|
|||||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg)
|
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg)
|
||||||
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
|
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
|
||||||
svc.SetSoraAccountRepo(soraAccountRepo)
|
svc.SetSoraAccountRepo(soraAccountRepo)
|
||||||
svc.SetSoraSyncService(soraSyncService)
|
|
||||||
svc.Start()
|
svc.Start()
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
@@ -187,6 +185,18 @@ func ProvideOpsCleanupService(
|
|||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProvideSoraMediaStorage 初始化 Sora 媒体存储
|
||||||
|
func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
|
||||||
|
return NewSoraMediaStorage(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务
|
||||||
|
func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
|
||||||
|
svc := NewSoraMediaCleanupService(storage, cfg)
|
||||||
|
svc.Start()
|
||||||
|
return svc
|
||||||
|
}
|
||||||
|
|
||||||
// ProvideOpsScheduledReportService creates and starts OpsScheduledReportService.
|
// ProvideOpsScheduledReportService creates and starts OpsScheduledReportService.
|
||||||
func ProvideOpsScheduledReportService(
|
func ProvideOpsScheduledReportService(
|
||||||
opsService *OpsService,
|
opsService *OpsService,
|
||||||
@@ -226,6 +236,10 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewBillingCacheService,
|
NewBillingCacheService,
|
||||||
NewAdminService,
|
NewAdminService,
|
||||||
NewGatewayService,
|
NewGatewayService,
|
||||||
|
ProvideSoraMediaStorage,
|
||||||
|
ProvideSoraMediaCleanupService,
|
||||||
|
NewSoraDirectClient,
|
||||||
|
wire.Bind(new(SoraClient), new(*SoraDirectClient)),
|
||||||
NewSoraGatewayService,
|
NewSoraGatewayService,
|
||||||
NewOpenAIGatewayService,
|
NewOpenAIGatewayService,
|
||||||
NewOAuthService,
|
NewOAuthService,
|
||||||
|
|||||||
8
build_image.sh
Executable file
8
build_image.sh
Executable file
@@ -0,0 +1,8 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# 本地构建镜像的快速脚本,避免在命令行反复输入构建参数。
|
||||||
|
|
||||||
|
docker build -t sub2api:latest \
|
||||||
|
--build-arg GOPROXY=https://goproxy.cn,direct \
|
||||||
|
--build-arg GOSUMDB=sum.golang.google.cn \
|
||||||
|
-f Dockerfile \
|
||||||
|
.
|
||||||
111
deploy/Dockerfile
Normal file
111
deploy/Dockerfile
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
# =============================================================================
|
||||||
|
# Sub2API Multi-Stage Dockerfile
|
||||||
|
# =============================================================================
|
||||||
|
# Stage 1: Build frontend
|
||||||
|
# Stage 2: Build Go backend with embedded frontend
|
||||||
|
# Stage 3: Final minimal image
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
ARG NODE_IMAGE=node:24-alpine
|
||||||
|
ARG GOLANG_IMAGE=golang:1.25.5-alpine
|
||||||
|
ARG ALPINE_IMAGE=alpine:3.20
|
||||||
|
ARG GOPROXY=https://goproxy.cn,direct
|
||||||
|
ARG GOSUMDB=sum.golang.google.cn
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Stage 1: Frontend Builder
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
FROM ${NODE_IMAGE} AS frontend-builder
|
||||||
|
|
||||||
|
WORKDIR /app/frontend
|
||||||
|
|
||||||
|
# Install pnpm
|
||||||
|
RUN corepack enable && corepack prepare pnpm@latest --activate
|
||||||
|
|
||||||
|
# Install dependencies first (better caching)
|
||||||
|
COPY frontend/package.json frontend/pnpm-lock.yaml ./
|
||||||
|
RUN pnpm install --frozen-lockfile
|
||||||
|
|
||||||
|
# Copy frontend source and build
|
||||||
|
COPY frontend/ ./
|
||||||
|
RUN pnpm run build
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Stage 2: Backend Builder
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
FROM ${GOLANG_IMAGE} AS backend-builder
|
||||||
|
|
||||||
|
# Build arguments for version info (set by CI)
|
||||||
|
ARG VERSION=docker
|
||||||
|
ARG COMMIT=docker
|
||||||
|
ARG DATE
|
||||||
|
ARG GOPROXY
|
||||||
|
ARG GOSUMDB
|
||||||
|
|
||||||
|
ENV GOPROXY=${GOPROXY}
|
||||||
|
ENV GOSUMDB=${GOSUMDB}
|
||||||
|
|
||||||
|
# Install build dependencies
|
||||||
|
RUN apk add --no-cache git ca-certificates tzdata
|
||||||
|
|
||||||
|
WORKDIR /app/backend
|
||||||
|
|
||||||
|
# Copy go mod files first (better caching)
|
||||||
|
COPY backend/go.mod backend/go.sum ./
|
||||||
|
RUN go mod download
|
||||||
|
|
||||||
|
# Copy backend source first
|
||||||
|
COPY backend/ ./
|
||||||
|
|
||||||
|
# Copy frontend dist from previous stage (must be after backend copy to avoid being overwritten)
|
||||||
|
COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist
|
||||||
|
|
||||||
|
# Build the binary (BuildType=release for CI builds, embed frontend)
|
||||||
|
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||||
|
-tags embed \
|
||||||
|
-ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \
|
||||||
|
-o /app/sub2api \
|
||||||
|
./cmd/server
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Stage 3: Final Runtime Image
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
FROM ${ALPINE_IMAGE}
|
||||||
|
|
||||||
|
# Labels
|
||||||
|
LABEL maintainer="Wei-Shaw <github.com/Wei-Shaw>"
|
||||||
|
LABEL description="Sub2API - AI API Gateway Platform"
|
||||||
|
LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
|
||||||
|
|
||||||
|
# Install runtime dependencies
|
||||||
|
RUN apk add --no-cache \
|
||||||
|
ca-certificates \
|
||||||
|
tzdata \
|
||||||
|
curl \
|
||||||
|
&& rm -rf /var/cache/apk/*
|
||||||
|
|
||||||
|
# Create non-root user
|
||||||
|
RUN addgroup -g 1000 sub2api && \
|
||||||
|
adduser -u 1000 -G sub2api -s /bin/sh -D sub2api
|
||||||
|
|
||||||
|
# Set working directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy binary from builder
|
||||||
|
COPY --from=backend-builder /app/sub2api /app/sub2api
|
||||||
|
|
||||||
|
# Create data directory
|
||||||
|
RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
|
||||||
|
|
||||||
|
# Switch to non-root user
|
||||||
|
USER sub2api
|
||||||
|
|
||||||
|
# Expose port (can be overridden by SERVER_PORT env var)
|
||||||
|
EXPOSE 8080
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||||
|
CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
|
||||||
|
|
||||||
|
# Run the application
|
||||||
|
ENTRYPOINT ["/app/sub2api"]
|
||||||
@@ -249,32 +249,64 @@ gateway:
|
|||||||
# name: "Custom Profile 2"
|
# name: "Custom Profile 2"
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Sora2API Configuration
|
# Sora Direct Client Configuration
|
||||||
# Sora2API 配置
|
# Sora 直连配置
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
sora2api:
|
sora:
|
||||||
# Sora2API base URL
|
client:
|
||||||
# Sora2API 服务地址
|
# Sora backend base URL
|
||||||
base_url: "http://127.0.0.1:8000"
|
# Sora 上游 Base URL
|
||||||
# Sora2API API Key (for /v1/chat/completions and /v1/models)
|
base_url: "https://sora.chatgpt.com/backend"
|
||||||
# Sora2API API Key(用于生成/模型列表)
|
# Request timeout (seconds)
|
||||||
api_key: ""
|
# 请求超时(秒)
|
||||||
# Admin username/password (for token sync)
|
timeout_seconds: 120
|
||||||
# 管理口用户名/密码(用于 token 同步)
|
# Max retries for upstream requests
|
||||||
admin_username: "admin"
|
# 上游请求最大重试次数
|
||||||
admin_password: "admin"
|
max_retries: 3
|
||||||
# Admin token cache ttl (seconds)
|
# Poll interval (seconds)
|
||||||
# 管理口 token 缓存时长(秒)
|
# 轮询间隔(秒)
|
||||||
admin_token_ttl_seconds: 900
|
poll_interval_seconds: 2
|
||||||
# Admin request timeout (seconds)
|
# Max poll attempts
|
||||||
# 管理口请求超时(秒)
|
# 最大轮询次数
|
||||||
admin_timeout_seconds: 10
|
max_poll_attempts: 600
|
||||||
# Token import mode: at/offline
|
# Enable debug logs for Sora upstream requests
|
||||||
# Token 导入模式:at/offline
|
# 启用 Sora 直连调试日志
|
||||||
token_import_mode: "at"
|
debug: false
|
||||||
# cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196]
|
# Optional custom headers (key-value)
|
||||||
# curves: [29, 23, 24]
|
# 额外请求头(键值对)
|
||||||
# point_formats: [0]
|
headers: {}
|
||||||
|
# Default User-Agent for Sora requests
|
||||||
|
# Sora 默认 User-Agent
|
||||||
|
user_agent: "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"
|
||||||
|
# Disable TLS fingerprint for Sora upstream
|
||||||
|
# 关闭 Sora 上游 TLS 指纹伪装
|
||||||
|
disable_tls_fingerprint: false
|
||||||
|
storage:
|
||||||
|
# Storage type (local only for now)
|
||||||
|
# 存储类型(首发仅支持 local)
|
||||||
|
type: "local"
|
||||||
|
# Local base path; empty uses /app/data/sora
|
||||||
|
# 本地存储基础路径;为空使用 /app/data/sora
|
||||||
|
local_path: ""
|
||||||
|
# Fallback to upstream URL when download fails
|
||||||
|
# 下载失败时回退到上游 URL
|
||||||
|
fallback_to_upstream: true
|
||||||
|
# Max concurrent downloads
|
||||||
|
# 并发下载上限
|
||||||
|
max_concurrent_downloads: 4
|
||||||
|
# Enable debug logs for media storage
|
||||||
|
# 启用媒体存储调试日志
|
||||||
|
debug: false
|
||||||
|
cleanup:
|
||||||
|
# Enable cleanup task
|
||||||
|
# 启用清理任务
|
||||||
|
enabled: true
|
||||||
|
# Retention days
|
||||||
|
# 保留天数
|
||||||
|
retention_days: 7
|
||||||
|
# Cron schedule
|
||||||
|
# Cron 调度表达式
|
||||||
|
schedule: "0 3 * * *"
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# API Key Auth Cache Configuration
|
# API Key Auth Cache Configuration
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import geminiAPI from './gemini'
|
|||||||
import antigravityAPI from './antigravity'
|
import antigravityAPI from './antigravity'
|
||||||
import userAttributesAPI from './userAttributes'
|
import userAttributesAPI from './userAttributes'
|
||||||
import opsAPI from './ops'
|
import opsAPI from './ops'
|
||||||
import modelsAPI from './models'
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Unified admin API object for convenient access
|
* Unified admin API object for convenient access
|
||||||
@@ -38,8 +37,7 @@ export const adminAPI = {
|
|||||||
gemini: geminiAPI,
|
gemini: geminiAPI,
|
||||||
antigravity: antigravityAPI,
|
antigravity: antigravityAPI,
|
||||||
userAttributes: userAttributesAPI,
|
userAttributes: userAttributesAPI,
|
||||||
ops: opsAPI,
|
ops: opsAPI
|
||||||
models: modelsAPI
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export {
|
export {
|
||||||
@@ -57,8 +55,7 @@ export {
|
|||||||
geminiAPI,
|
geminiAPI,
|
||||||
antigravityAPI,
|
antigravityAPI,
|
||||||
userAttributesAPI,
|
userAttributesAPI,
|
||||||
opsAPI,
|
opsAPI
|
||||||
modelsAPI
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export default adminAPI
|
export default adminAPI
|
||||||
|
|||||||
@@ -1,14 +0,0 @@
|
|||||||
import { apiClient } from '@/api/client'
|
|
||||||
|
|
||||||
export async function getPlatformModels(platform: string): Promise<string[]> {
|
|
||||||
const { data } = await apiClient.get<string[]>('/admin/models', {
|
|
||||||
params: { platform }
|
|
||||||
})
|
|
||||||
return data
|
|
||||||
}
|
|
||||||
|
|
||||||
export const modelsAPI = {
|
|
||||||
getPlatformModels
|
|
||||||
}
|
|
||||||
|
|
||||||
export default modelsAPI
|
|
||||||
@@ -1501,9 +1501,9 @@
|
|||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<label class="switch">
|
<label :class="['switch', { 'switch-active': enableSoraOnOpenAIOAuth }]">
|
||||||
<input type="checkbox" v-model="enableSoraOnOpenAIOAuth" />
|
<input type="checkbox" v-model="enableSoraOnOpenAIOAuth" class="sr-only" />
|
||||||
<span class="slider"></span>
|
<span class="switch-thumb"></span>
|
||||||
</label>
|
</label>
|
||||||
</label>
|
</label>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -45,19 +45,6 @@
|
|||||||
:placeholder="t('admin.accounts.searchModels')"
|
:placeholder="t('admin.accounts.searchModels')"
|
||||||
@click.stop
|
@click.stop
|
||||||
/>
|
/>
|
||||||
<div v-if="props.platform === 'sora'" class="mt-2 flex items-center gap-2 text-xs">
|
|
||||||
<span v-if="loadingSoraModels" class="text-gray-500">
|
|
||||||
{{ t('admin.accounts.soraModelsLoading') }}
|
|
||||||
</span>
|
|
||||||
<button
|
|
||||||
v-else-if="soraLoadError"
|
|
||||||
type="button"
|
|
||||||
class="text-primary-600 hover:underline dark:text-primary-400"
|
|
||||||
@click.stop="loadSoraModels"
|
|
||||||
>
|
|
||||||
{{ t('admin.accounts.soraModelsRetry') }}
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
<div class="max-h-52 overflow-auto">
|
<div class="max-h-52 overflow-auto">
|
||||||
<button
|
<button
|
||||||
@@ -133,13 +120,12 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed, watch } from 'vue'
|
import { ref, computed } from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import { useAppStore } from '@/stores/app'
|
import { useAppStore } from '@/stores/app'
|
||||||
import ModelIcon from '@/components/common/ModelIcon.vue'
|
import ModelIcon from '@/components/common/ModelIcon.vue'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
import { allModels, getModelsByPlatform } from '@/composables/useModelWhitelist'
|
import { allModels, getModelsByPlatform } from '@/composables/useModelWhitelist'
|
||||||
import { adminAPI } from '@/api/admin'
|
|
||||||
|
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
|
|
||||||
@@ -158,15 +144,8 @@ const showDropdown = ref(false)
|
|||||||
const searchQuery = ref('')
|
const searchQuery = ref('')
|
||||||
const customModel = ref('')
|
const customModel = ref('')
|
||||||
const isComposing = ref(false)
|
const isComposing = ref(false)
|
||||||
const soraModelOptions = ref<{ value: string; label: string }[]>([])
|
|
||||||
const loadingSoraModels = ref(false)
|
|
||||||
const soraLoadError = ref(false)
|
|
||||||
|
|
||||||
const availableOptions = computed(() => {
|
const availableOptions = computed(() => {
|
||||||
if (props.platform === 'sora') {
|
if (props.platform === 'sora') {
|
||||||
if (soraModelOptions.value.length > 0) {
|
|
||||||
return soraModelOptions.value
|
|
||||||
}
|
|
||||||
return getModelsByPlatform('sora').map(m => ({ value: m, label: m }))
|
return getModelsByPlatform('sora').map(m => ({ value: m, label: m }))
|
||||||
}
|
}
|
||||||
return allModels
|
return allModels
|
||||||
@@ -213,9 +192,7 @@ const handleEnter = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const fillRelated = () => {
|
const fillRelated = () => {
|
||||||
const models = props.platform === 'sora' && soraModelOptions.value.length > 0
|
const models = getModelsByPlatform(props.platform)
|
||||||
? soraModelOptions.value.map(m => m.value)
|
|
||||||
: getModelsByPlatform(props.platform)
|
|
||||||
const newModels = [...props.modelValue]
|
const newModels = [...props.modelValue]
|
||||||
for (const model of models) {
|
for (const model of models) {
|
||||||
if (!newModels.includes(model)) newModels.push(model)
|
if (!newModels.includes(model)) newModels.push(model)
|
||||||
@@ -227,31 +204,4 @@ const clearAll = () => {
|
|||||||
emit('update:modelValue', [])
|
emit('update:modelValue', [])
|
||||||
}
|
}
|
||||||
|
|
||||||
const loadSoraModels = async () => {
|
|
||||||
if (props.platform !== 'sora') {
|
|
||||||
soraModelOptions.value = []
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if (loadingSoraModels.value) return
|
|
||||||
soraLoadError.value = false
|
|
||||||
loadingSoraModels.value = true
|
|
||||||
try {
|
|
||||||
const models = await adminAPI.models.getPlatformModels('sora')
|
|
||||||
soraModelOptions.value = (models || []).map((m) => ({ value: m, label: m }))
|
|
||||||
} catch (error) {
|
|
||||||
console.warn('加载 Sora 模型列表失败', error)
|
|
||||||
soraLoadError.value = true
|
|
||||||
appStore.showWarning(t('admin.accounts.soraModelsLoadFailed'))
|
|
||||||
} finally {
|
|
||||||
loadingSoraModels.value = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
watch(
|
|
||||||
() => props.platform,
|
|
||||||
() => {
|
|
||||||
loadSoraModels()
|
|
||||||
},
|
|
||||||
{ immediate: true }
|
|
||||||
)
|
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -416,6 +416,7 @@ import { useI18n } from 'vue-i18n'
|
|||||||
import { useClipboard } from '@/composables/useClipboard'
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
import type { AddMethod, AuthInputMethod } from '@/composables/useAccountOAuth'
|
import type { AddMethod, AuthInputMethod } from '@/composables/useAccountOAuth'
|
||||||
|
import type { AccountPlatform } from '@/types'
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
addMethod: AddMethod
|
addMethod: AddMethod
|
||||||
@@ -428,7 +429,7 @@ interface Props {
|
|||||||
allowMultiple?: boolean
|
allowMultiple?: boolean
|
||||||
methodLabel?: string
|
methodLabel?: string
|
||||||
showCookieOption?: boolean // Whether to show cookie auto-auth option
|
showCookieOption?: boolean // Whether to show cookie auto-auth option
|
||||||
platform?: 'anthropic' | 'openai' | 'gemini' | 'antigravity' // Platform type for different UI/text
|
platform?: AccountPlatform // Platform type for different UI/text
|
||||||
showProjectId?: boolean // New prop to control project ID visibility
|
showProjectId?: boolean // New prop to control project ID visibility
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -455,11 +456,11 @@ const emit = defineEmits<{
|
|||||||
|
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
|
|
||||||
const isOpenAI = computed(() => props.platform === 'openai')
|
const isOpenAI = computed(() => props.platform === 'openai' || props.platform === 'sora')
|
||||||
|
|
||||||
// Get translation key based on platform
|
// Get translation key based on platform
|
||||||
const getOAuthKey = (key: string) => {
|
const getOAuthKey = (key: string) => {
|
||||||
if (props.platform === 'openai') return `admin.accounts.oauth.openai.${key}`
|
if (props.platform === 'openai' || props.platform === 'sora') return `admin.accounts.oauth.openai.${key}`
|
||||||
if (props.platform === 'gemini') return `admin.accounts.oauth.gemini.${key}`
|
if (props.platform === 'gemini') return `admin.accounts.oauth.gemini.${key}`
|
||||||
if (props.platform === 'antigravity') return `admin.accounts.oauth.antigravity.${key}`
|
if (props.platform === 'antigravity') return `admin.accounts.oauth.antigravity.${key}`
|
||||||
return `admin.accounts.oauth.${key}`
|
return `admin.accounts.oauth.${key}`
|
||||||
@@ -478,7 +479,7 @@ const oauthAuthCode = computed(() => t(getOAuthKey('authCode')))
|
|||||||
const oauthAuthCodePlaceholder = computed(() => t(getOAuthKey('authCodePlaceholder')))
|
const oauthAuthCodePlaceholder = computed(() => t(getOAuthKey('authCodePlaceholder')))
|
||||||
const oauthAuthCodeHint = computed(() => t(getOAuthKey('authCodeHint')))
|
const oauthAuthCodeHint = computed(() => t(getOAuthKey('authCodeHint')))
|
||||||
const oauthImportantNotice = computed(() => {
|
const oauthImportantNotice = computed(() => {
|
||||||
if (props.platform === 'openai') return t('admin.accounts.oauth.openai.importantNotice')
|
if (props.platform === 'openai' || props.platform === 'sora') return t('admin.accounts.oauth.openai.importantNotice')
|
||||||
if (props.platform === 'antigravity') return t('admin.accounts.oauth.antigravity.importantNotice')
|
if (props.platform === 'antigravity') return t('admin.accounts.oauth.antigravity.importantNotice')
|
||||||
return ''
|
return ''
|
||||||
})
|
})
|
||||||
@@ -510,7 +511,7 @@ watch(inputMethod, (newVal) => {
|
|||||||
// Auto-extract code from callback URL (OpenAI/Gemini/Antigravity)
|
// Auto-extract code from callback URL (OpenAI/Gemini/Antigravity)
|
||||||
// e.g., http://localhost:8085/callback?code=xxx...&state=...
|
// e.g., http://localhost:8085/callback?code=xxx...&state=...
|
||||||
watch(authCodeInput, (newVal) => {
|
watch(authCodeInput, (newVal) => {
|
||||||
if (props.platform !== 'openai' && props.platform !== 'gemini' && props.platform !== 'antigravity') return
|
if (props.platform !== 'openai' && props.platform !== 'gemini' && props.platform !== 'antigravity' && props.platform !== 'sora') return
|
||||||
|
|
||||||
const trimmed = newVal.trim()
|
const trimmed = newVal.trim()
|
||||||
// Check if it looks like a URL with code parameter
|
// Check if it looks like a URL with code parameter
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ const geminiModels = [
|
|||||||
'gemini-3-pro-preview'
|
'gemini-3-pro-preview'
|
||||||
]
|
]
|
||||||
|
|
||||||
// Sora (sora2api)
|
// Sora
|
||||||
const soraModels = [
|
const soraModels = [
|
||||||
'gpt-image', 'gpt-image-landscape', 'gpt-image-portrait',
|
'gpt-image', 'gpt-image-landscape', 'gpt-image-portrait',
|
||||||
'sora2-landscape-10s', 'sora2-portrait-10s',
|
'sora2-landscape-10s', 'sora2-portrait-10s',
|
||||||
|
|||||||
@@ -1363,11 +1363,6 @@ const createForm = reactive({
|
|||||||
sora_image_price_540: null as number | null,
|
sora_image_price_540: null as number | null,
|
||||||
sora_video_price_per_request: null as number | null,
|
sora_video_price_per_request: null as number | null,
|
||||||
sora_video_price_per_request_hd: null as number | null,
|
sora_video_price_per_request_hd: null as number | null,
|
||||||
// Sora 按次计费配置
|
|
||||||
sora_image_price_360: null as number | null,
|
|
||||||
sora_image_price_540: null as number | null,
|
|
||||||
sora_video_price_per_request: null as number | null,
|
|
||||||
sora_video_price_per_request_hd: null as number | null,
|
|
||||||
// Claude Code 客户端限制(仅 anthropic 平台使用)
|
// Claude Code 客户端限制(仅 anthropic 平台使用)
|
||||||
claude_code_only: false,
|
claude_code_only: false,
|
||||||
fallback_group_id: null as number | null,
|
fallback_group_id: null as number | null,
|
||||||
|
|||||||
Reference in New Issue
Block a user