feat(sora): 新增 Sora 平台支持并修复高危安全和性能问题
新增功能: - 新增 Sora 账号管理和 OAuth 认证 - 新增 Sora 视频/图片生成 API 网关 - 新增 Sora 任务调度和缓存机制 - 新增 Sora 使用统计和计费支持 - 前端增加 Sora 平台配置界面 安全修复(代码审核): - [SEC-001] 限制媒体下载响应体大小(图片 20MB、视频 200MB),防止 DoS 攻击 - [SEC-002] 限制 SDK API 响应大小(1MB),防止内存耗尽 - [SEC-003] 修复 SSRF 风险,添加 URL 验证并强制使用代理配置 BUG 修复(代码审核): - [BUG-001] 修复 for 循环内 defer 累积导致的资源泄漏 - [BUG-002] 修复图片并发槽位获取失败时已持有锁未释放的永久泄漏 性能优化(代码审核): - [PERF-001] 添加 Sentinel Token 缓存(3 分钟有效期),减少 PoW 计算开销 技术细节: - 使用 io.LimitReader 限制所有外部输入的大小 - 添加 urlvalidator 验证防止 SSRF 攻击 - 使用 sync.Map 实现线程安全的包级缓存 - 优化并发槽位管理,添加 releaseAll 模式防止泄漏 影响范围: - 后端:新增 Sora 相关数据模型、服务、网关和管理接口 - 前端:新增 Sora 平台配置、账号管理和监控界面 - 配置:新增 Sora 相关配置项和环境变量 Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -58,6 +58,7 @@ type Config struct {
|
||||
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
|
||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
Sora SoraConfig `mapstructure:"sora"`
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
@@ -69,6 +70,38 @@ type GeminiConfig struct {
|
||||
Quota GeminiQuotaConfig `mapstructure:"quota"`
|
||||
}
|
||||
|
||||
type SoraConfig struct {
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
Timeout int `mapstructure:"timeout"`
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
PollInterval float64 `mapstructure:"poll_interval"`
|
||||
CallLogicMode string `mapstructure:"call_logic_mode"`
|
||||
Cache SoraCacheConfig `mapstructure:"cache"`
|
||||
WatermarkFree SoraWatermarkFreeConfig `mapstructure:"watermark_free"`
|
||||
TokenRefresh SoraTokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
}
|
||||
|
||||
type SoraCacheConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
BaseDir string `mapstructure:"base_dir"`
|
||||
VideoDir string `mapstructure:"video_dir"`
|
||||
MaxBytes int64 `mapstructure:"max_bytes"`
|
||||
AllowedHosts []string `mapstructure:"allowed_hosts"`
|
||||
UserDirEnabled bool `mapstructure:"user_dir_enabled"`
|
||||
}
|
||||
|
||||
type SoraWatermarkFreeConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
ParseMethod string `mapstructure:"parse_method"`
|
||||
CustomParseURL string `mapstructure:"custom_parse_url"`
|
||||
CustomParseToken string `mapstructure:"custom_parse_token"`
|
||||
FallbackOnFailure bool `mapstructure:"fallback_on_failure"`
|
||||
}
|
||||
|
||||
type SoraTokenRefreshConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
}
|
||||
|
||||
type GeminiOAuthConfig struct {
|
||||
ClientID string `mapstructure:"client_id"`
|
||||
ClientSecret string `mapstructure:"client_secret"`
|
||||
@@ -862,6 +895,24 @@ func setDefaults() {
|
||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||
|
||||
viper.SetDefault("sora.base_url", "https://sora.chatgpt.com/backend")
|
||||
viper.SetDefault("sora.timeout", 120)
|
||||
viper.SetDefault("sora.max_retries", 3)
|
||||
viper.SetDefault("sora.poll_interval", 2.5)
|
||||
viper.SetDefault("sora.call_logic_mode", "default")
|
||||
viper.SetDefault("sora.cache.enabled", false)
|
||||
viper.SetDefault("sora.cache.base_dir", "tmp/sora")
|
||||
viper.SetDefault("sora.cache.video_dir", "data/video")
|
||||
viper.SetDefault("sora.cache.max_bytes", int64(0))
|
||||
viper.SetDefault("sora.cache.allowed_hosts", []string{})
|
||||
viper.SetDefault("sora.cache.user_dir_enabled", true)
|
||||
viper.SetDefault("sora.watermark_free.enabled", false)
|
||||
viper.SetDefault("sora.watermark_free.parse_method", "third_party")
|
||||
viper.SetDefault("sora.watermark_free.custom_parse_url", "")
|
||||
viper.SetDefault("sora.watermark_free.custom_parse_token", "")
|
||||
viper.SetDefault("sora.watermark_free.fallback_on_failure", true)
|
||||
viper.SetDefault("sora.token_refresh.enabled", false)
|
||||
|
||||
// Gemini OAuth - configure via environment variables or config file
|
||||
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
|
||||
// Default: uses Gemini CLI public credentials (set via environment)
|
||||
|
||||
@@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
@@ -49,7 +49,7 @@ type CreateGroupRequest struct {
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
|
||||
@@ -79,6 +79,23 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
FallbackModelAntigravity: settings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: settings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: settings.IdentityPatchPrompt,
|
||||
SoraBaseURL: settings.SoraBaseURL,
|
||||
SoraTimeout: settings.SoraTimeout,
|
||||
SoraMaxRetries: settings.SoraMaxRetries,
|
||||
SoraPollInterval: settings.SoraPollInterval,
|
||||
SoraCallLogicMode: settings.SoraCallLogicMode,
|
||||
SoraCacheEnabled: settings.SoraCacheEnabled,
|
||||
SoraCacheBaseDir: settings.SoraCacheBaseDir,
|
||||
SoraCacheVideoDir: settings.SoraCacheVideoDir,
|
||||
SoraCacheMaxBytes: settings.SoraCacheMaxBytes,
|
||||
SoraCacheAllowedHosts: settings.SoraCacheAllowedHosts,
|
||||
SoraCacheUserDirEnabled: settings.SoraCacheUserDirEnabled,
|
||||
SoraWatermarkFreeEnabled: settings.SoraWatermarkFreeEnabled,
|
||||
SoraWatermarkFreeParseMethod: settings.SoraWatermarkFreeParseMethod,
|
||||
SoraWatermarkFreeCustomParseURL: settings.SoraWatermarkFreeCustomParseURL,
|
||||
SoraWatermarkFreeCustomParseToken: settings.SoraWatermarkFreeCustomParseToken,
|
||||
SoraWatermarkFreeFallbackOnFailure: settings.SoraWatermarkFreeFallbackOnFailure,
|
||||
SoraTokenRefreshEnabled: settings.SoraTokenRefreshEnabled,
|
||||
OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled,
|
||||
OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
|
||||
OpsQueryModeDefault: settings.OpsQueryModeDefault,
|
||||
@@ -138,6 +155,25 @@ type UpdateSettingsRequest struct {
|
||||
EnableIdentityPatch bool `json:"enable_identity_patch"`
|
||||
IdentityPatchPrompt string `json:"identity_patch_prompt"`
|
||||
|
||||
// Sora configuration
|
||||
SoraBaseURL string `json:"sora_base_url"`
|
||||
SoraTimeout int `json:"sora_timeout"`
|
||||
SoraMaxRetries int `json:"sora_max_retries"`
|
||||
SoraPollInterval float64 `json:"sora_poll_interval"`
|
||||
SoraCallLogicMode string `json:"sora_call_logic_mode"`
|
||||
SoraCacheEnabled bool `json:"sora_cache_enabled"`
|
||||
SoraCacheBaseDir string `json:"sora_cache_base_dir"`
|
||||
SoraCacheVideoDir string `json:"sora_cache_video_dir"`
|
||||
SoraCacheMaxBytes int64 `json:"sora_cache_max_bytes"`
|
||||
SoraCacheAllowedHosts []string `json:"sora_cache_allowed_hosts"`
|
||||
SoraCacheUserDirEnabled bool `json:"sora_cache_user_dir_enabled"`
|
||||
SoraWatermarkFreeEnabled bool `json:"sora_watermark_free_enabled"`
|
||||
SoraWatermarkFreeParseMethod string `json:"sora_watermark_free_parse_method"`
|
||||
SoraWatermarkFreeCustomParseURL string `json:"sora_watermark_free_custom_parse_url"`
|
||||
SoraWatermarkFreeCustomParseToken string `json:"sora_watermark_free_custom_parse_token"`
|
||||
SoraWatermarkFreeFallbackOnFailure bool `json:"sora_watermark_free_fallback_on_failure"`
|
||||
SoraTokenRefreshEnabled bool `json:"sora_token_refresh_enabled"`
|
||||
|
||||
// Ops monitoring (vNext)
|
||||
OpsMonitoringEnabled *bool `json:"ops_monitoring_enabled"`
|
||||
OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"`
|
||||
@@ -227,6 +263,32 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Sora 参数校验与清理
|
||||
req.SoraBaseURL = strings.TrimSpace(req.SoraBaseURL)
|
||||
if req.SoraBaseURL == "" {
|
||||
req.SoraBaseURL = previousSettings.SoraBaseURL
|
||||
}
|
||||
if req.SoraBaseURL != "" {
|
||||
if err := config.ValidateAbsoluteHTTPURL(req.SoraBaseURL); err != nil {
|
||||
response.BadRequest(c, "Sora Base URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
}
|
||||
if req.SoraTimeout <= 0 {
|
||||
req.SoraTimeout = previousSettings.SoraTimeout
|
||||
}
|
||||
if req.SoraMaxRetries < 0 {
|
||||
req.SoraMaxRetries = previousSettings.SoraMaxRetries
|
||||
}
|
||||
if req.SoraPollInterval <= 0 {
|
||||
req.SoraPollInterval = previousSettings.SoraPollInterval
|
||||
}
|
||||
if req.SoraCacheMaxBytes < 0 {
|
||||
req.SoraCacheMaxBytes = 0
|
||||
}
|
||||
req.SoraCacheAllowedHosts = normalizeStringList(req.SoraCacheAllowedHosts)
|
||||
req.SoraWatermarkFreeCustomParseURL = strings.TrimSpace(req.SoraWatermarkFreeCustomParseURL)
|
||||
|
||||
// Ops metrics collector interval validation (seconds).
|
||||
if req.OpsMetricsIntervalSeconds != nil {
|
||||
v := *req.OpsMetricsIntervalSeconds
|
||||
@@ -240,40 +302,57 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
SoraBaseURL: req.SoraBaseURL,
|
||||
SoraTimeout: req.SoraTimeout,
|
||||
SoraMaxRetries: req.SoraMaxRetries,
|
||||
SoraPollInterval: req.SoraPollInterval,
|
||||
SoraCallLogicMode: req.SoraCallLogicMode,
|
||||
SoraCacheEnabled: req.SoraCacheEnabled,
|
||||
SoraCacheBaseDir: req.SoraCacheBaseDir,
|
||||
SoraCacheVideoDir: req.SoraCacheVideoDir,
|
||||
SoraCacheMaxBytes: req.SoraCacheMaxBytes,
|
||||
SoraCacheAllowedHosts: req.SoraCacheAllowedHosts,
|
||||
SoraCacheUserDirEnabled: req.SoraCacheUserDirEnabled,
|
||||
SoraWatermarkFreeEnabled: req.SoraWatermarkFreeEnabled,
|
||||
SoraWatermarkFreeParseMethod: req.SoraWatermarkFreeParseMethod,
|
||||
SoraWatermarkFreeCustomParseURL: req.SoraWatermarkFreeCustomParseURL,
|
||||
SoraWatermarkFreeCustomParseToken: req.SoraWatermarkFreeCustomParseToken,
|
||||
SoraWatermarkFreeFallbackOnFailure: req.SoraWatermarkFreeFallbackOnFailure,
|
||||
SoraTokenRefreshEnabled: req.SoraTokenRefreshEnabled,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@@ -349,6 +428,23 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
|
||||
SoraBaseURL: updatedSettings.SoraBaseURL,
|
||||
SoraTimeout: updatedSettings.SoraTimeout,
|
||||
SoraMaxRetries: updatedSettings.SoraMaxRetries,
|
||||
SoraPollInterval: updatedSettings.SoraPollInterval,
|
||||
SoraCallLogicMode: updatedSettings.SoraCallLogicMode,
|
||||
SoraCacheEnabled: updatedSettings.SoraCacheEnabled,
|
||||
SoraCacheBaseDir: updatedSettings.SoraCacheBaseDir,
|
||||
SoraCacheVideoDir: updatedSettings.SoraCacheVideoDir,
|
||||
SoraCacheMaxBytes: updatedSettings.SoraCacheMaxBytes,
|
||||
SoraCacheAllowedHosts: updatedSettings.SoraCacheAllowedHosts,
|
||||
SoraCacheUserDirEnabled: updatedSettings.SoraCacheUserDirEnabled,
|
||||
SoraWatermarkFreeEnabled: updatedSettings.SoraWatermarkFreeEnabled,
|
||||
SoraWatermarkFreeParseMethod: updatedSettings.SoraWatermarkFreeParseMethod,
|
||||
SoraWatermarkFreeCustomParseURL: updatedSettings.SoraWatermarkFreeCustomParseURL,
|
||||
SoraWatermarkFreeCustomParseToken: updatedSettings.SoraWatermarkFreeCustomParseToken,
|
||||
SoraWatermarkFreeFallbackOnFailure: updatedSettings.SoraWatermarkFreeFallbackOnFailure,
|
||||
SoraTokenRefreshEnabled: updatedSettings.SoraTokenRefreshEnabled,
|
||||
OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled,
|
||||
OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
|
||||
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
|
||||
@@ -477,6 +573,57 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.IdentityPatchPrompt != after.IdentityPatchPrompt {
|
||||
changed = append(changed, "identity_patch_prompt")
|
||||
}
|
||||
if before.SoraBaseURL != after.SoraBaseURL {
|
||||
changed = append(changed, "sora_base_url")
|
||||
}
|
||||
if before.SoraTimeout != after.SoraTimeout {
|
||||
changed = append(changed, "sora_timeout")
|
||||
}
|
||||
if before.SoraMaxRetries != after.SoraMaxRetries {
|
||||
changed = append(changed, "sora_max_retries")
|
||||
}
|
||||
if before.SoraPollInterval != after.SoraPollInterval {
|
||||
changed = append(changed, "sora_poll_interval")
|
||||
}
|
||||
if before.SoraCallLogicMode != after.SoraCallLogicMode {
|
||||
changed = append(changed, "sora_call_logic_mode")
|
||||
}
|
||||
if before.SoraCacheEnabled != after.SoraCacheEnabled {
|
||||
changed = append(changed, "sora_cache_enabled")
|
||||
}
|
||||
if before.SoraCacheBaseDir != after.SoraCacheBaseDir {
|
||||
changed = append(changed, "sora_cache_base_dir")
|
||||
}
|
||||
if before.SoraCacheVideoDir != after.SoraCacheVideoDir {
|
||||
changed = append(changed, "sora_cache_video_dir")
|
||||
}
|
||||
if before.SoraCacheMaxBytes != after.SoraCacheMaxBytes {
|
||||
changed = append(changed, "sora_cache_max_bytes")
|
||||
}
|
||||
if strings.Join(before.SoraCacheAllowedHosts, ",") != strings.Join(after.SoraCacheAllowedHosts, ",") {
|
||||
changed = append(changed, "sora_cache_allowed_hosts")
|
||||
}
|
||||
if before.SoraCacheUserDirEnabled != after.SoraCacheUserDirEnabled {
|
||||
changed = append(changed, "sora_cache_user_dir_enabled")
|
||||
}
|
||||
if before.SoraWatermarkFreeEnabled != after.SoraWatermarkFreeEnabled {
|
||||
changed = append(changed, "sora_watermark_free_enabled")
|
||||
}
|
||||
if before.SoraWatermarkFreeParseMethod != after.SoraWatermarkFreeParseMethod {
|
||||
changed = append(changed, "sora_watermark_free_parse_method")
|
||||
}
|
||||
if before.SoraWatermarkFreeCustomParseURL != after.SoraWatermarkFreeCustomParseURL {
|
||||
changed = append(changed, "sora_watermark_free_custom_parse_url")
|
||||
}
|
||||
if before.SoraWatermarkFreeCustomParseToken != after.SoraWatermarkFreeCustomParseToken {
|
||||
changed = append(changed, "sora_watermark_free_custom_parse_token")
|
||||
}
|
||||
if before.SoraWatermarkFreeFallbackOnFailure != after.SoraWatermarkFreeFallbackOnFailure {
|
||||
changed = append(changed, "sora_watermark_free_fallback_on_failure")
|
||||
}
|
||||
if before.SoraTokenRefreshEnabled != after.SoraTokenRefreshEnabled {
|
||||
changed = append(changed, "sora_token_refresh_enabled")
|
||||
}
|
||||
if before.OpsMonitoringEnabled != after.OpsMonitoringEnabled {
|
||||
changed = append(changed, "ops_monitoring_enabled")
|
||||
}
|
||||
@@ -492,6 +639,19 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
return changed
|
||||
}
|
||||
|
||||
func normalizeStringList(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
normalized := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
// TestSMTPRequest 测试SMTP连接请求
|
||||
type TestSMTPRequest struct {
|
||||
SMTPHost string `json:"smtp_host" binding:"required"`
|
||||
|
||||
355
backend/internal/handler/admin/sora_account_handler.go
Normal file
355
backend/internal/handler/admin/sora_account_handler.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SoraAccountHandler Sora 账号扩展管理
|
||||
// 提供 Sora 扩展表的查询与更新能力。
|
||||
type SoraAccountHandler struct {
|
||||
adminService service.AdminService
|
||||
soraAccountRepo service.SoraAccountRepository
|
||||
usageRepo service.SoraUsageStatRepository
|
||||
}
|
||||
|
||||
// NewSoraAccountHandler 创建 SoraAccountHandler
|
||||
func NewSoraAccountHandler(adminService service.AdminService, soraAccountRepo service.SoraAccountRepository, usageRepo service.SoraUsageStatRepository) *SoraAccountHandler {
|
||||
return &SoraAccountHandler{
|
||||
adminService: adminService,
|
||||
soraAccountRepo: soraAccountRepo,
|
||||
usageRepo: usageRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// SoraAccountUpdateRequest 更新/创建 Sora 账号扩展请求
|
||||
// 使用指针类型区分未提供与设置为空值。
|
||||
type SoraAccountUpdateRequest struct {
|
||||
AccessToken *string `json:"access_token"`
|
||||
SessionToken *string `json:"session_token"`
|
||||
RefreshToken *string `json:"refresh_token"`
|
||||
ClientID *string `json:"client_id"`
|
||||
Email *string `json:"email"`
|
||||
Username *string `json:"username"`
|
||||
Remark *string `json:"remark"`
|
||||
UseCount *int `json:"use_count"`
|
||||
PlanType *string `json:"plan_type"`
|
||||
PlanTitle *string `json:"plan_title"`
|
||||
SubscriptionEnd *int64 `json:"subscription_end"`
|
||||
SoraSupported *bool `json:"sora_supported"`
|
||||
SoraInviteCode *string `json:"sora_invite_code"`
|
||||
SoraRedeemedCount *int `json:"sora_redeemed_count"`
|
||||
SoraRemainingCount *int `json:"sora_remaining_count"`
|
||||
SoraTotalCount *int `json:"sora_total_count"`
|
||||
SoraCooldownUntil *int64 `json:"sora_cooldown_until"`
|
||||
CooledUntil *int64 `json:"cooled_until"`
|
||||
ImageEnabled *bool `json:"image_enabled"`
|
||||
VideoEnabled *bool `json:"video_enabled"`
|
||||
ImageConcurrency *int `json:"image_concurrency"`
|
||||
VideoConcurrency *int `json:"video_concurrency"`
|
||||
IsExpired *bool `json:"is_expired"`
|
||||
}
|
||||
|
||||
// SoraAccountBatchRequest 批量导入请求
|
||||
// accounts 支持批量 upsert。
|
||||
type SoraAccountBatchRequest struct {
|
||||
Accounts []SoraAccountBatchItem `json:"accounts"`
|
||||
}
|
||||
|
||||
// SoraAccountBatchItem 批量导入条目
|
||||
type SoraAccountBatchItem struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
SoraAccountUpdateRequest
|
||||
}
|
||||
|
||||
// SoraAccountBatchResult 批量导入结果
|
||||
// 仅返回成功/失败数量与明细。
|
||||
type SoraAccountBatchResult struct {
|
||||
Success int `json:"success"`
|
||||
Failed int `json:"failed"`
|
||||
Results []SoraAccountBatchItemResult `json:"results"`
|
||||
}
|
||||
|
||||
// SoraAccountBatchItemResult 批量导入单条结果
|
||||
type SoraAccountBatchItemResult struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// List 获取 Sora 账号扩展列表
|
||||
// GET /api/v1/admin/sora/accounts
|
||||
func (h *SoraAccountHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
search := strings.TrimSpace(c.Query("search"))
|
||||
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, service.PlatformSora, "", "", search)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
accountIDs := make([]int64, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
accountIDs = append(accountIDs, accounts[i].ID)
|
||||
}
|
||||
|
||||
soraMap := map[int64]*service.SoraAccount{}
|
||||
if h.soraAccountRepo != nil {
|
||||
soraMap, _ = h.soraAccountRepo.GetByAccountIDs(c.Request.Context(), accountIDs)
|
||||
}
|
||||
|
||||
usageMap := map[int64]*service.SoraUsageStat{}
|
||||
if h.usageRepo != nil {
|
||||
usageMap, _ = h.usageRepo.GetByAccountIDs(c.Request.Context(), accountIDs)
|
||||
}
|
||||
|
||||
result := make([]dto.SoraAccount, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
acc := accounts[i]
|
||||
item := dto.SoraAccountFromService(&acc, soraMap[acc.ID], usageMap[acc.ID])
|
||||
if item != nil {
|
||||
result = append(result, *item)
|
||||
}
|
||||
}
|
||||
|
||||
response.Paginated(c, result, total, page, pageSize)
|
||||
}
|
||||
|
||||
// Get 获取单个 Sora 账号扩展
|
||||
// GET /api/v1/admin/sora/accounts/:id
|
||||
func (h *SoraAccountHandler) Get(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "账号 ID 无效")
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if account.Platform != service.PlatformSora {
|
||||
response.BadRequest(c, "账号不是 Sora 平台")
|
||||
return
|
||||
}
|
||||
|
||||
var soraAcc *service.SoraAccount
|
||||
if h.soraAccountRepo != nil {
|
||||
soraAcc, _ = h.soraAccountRepo.GetByAccountID(c.Request.Context(), accountID)
|
||||
}
|
||||
var usage *service.SoraUsageStat
|
||||
if h.usageRepo != nil {
|
||||
usage, _ = h.usageRepo.GetByAccountID(c.Request.Context(), accountID)
|
||||
}
|
||||
|
||||
response.Success(c, dto.SoraAccountFromService(account, soraAcc, usage))
|
||||
}
|
||||
|
||||
// Upsert 更新或创建 Sora 账号扩展
|
||||
// PUT /api/v1/admin/sora/accounts/:id
|
||||
func (h *SoraAccountHandler) Upsert(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "账号 ID 无效")
|
||||
return
|
||||
}
|
||||
|
||||
var req SoraAccountUpdateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求参数无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if account.Platform != service.PlatformSora {
|
||||
response.BadRequest(c, "账号不是 Sora 平台")
|
||||
return
|
||||
}
|
||||
|
||||
updates := buildSoraAccountUpdates(&req)
|
||||
if h.soraAccountRepo != nil && len(updates) > 0 {
|
||||
if err := h.soraAccountRepo.Upsert(c.Request.Context(), accountID, updates); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var soraAcc *service.SoraAccount
|
||||
if h.soraAccountRepo != nil {
|
||||
soraAcc, _ = h.soraAccountRepo.GetByAccountID(c.Request.Context(), accountID)
|
||||
}
|
||||
var usage *service.SoraUsageStat
|
||||
if h.usageRepo != nil {
|
||||
usage, _ = h.usageRepo.GetByAccountID(c.Request.Context(), accountID)
|
||||
}
|
||||
|
||||
response.Success(c, dto.SoraAccountFromService(account, soraAcc, usage))
|
||||
}
|
||||
|
||||
// BatchUpsert 批量导入 Sora 账号扩展
|
||||
// POST /api/v1/admin/sora/accounts/import
|
||||
func (h *SoraAccountHandler) BatchUpsert(c *gin.Context) {
|
||||
var req SoraAccountBatchRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求参数无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.Accounts) == 0 {
|
||||
response.BadRequest(c, "accounts 不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
ids := make([]int64, 0, len(req.Accounts))
|
||||
for _, item := range req.Accounts {
|
||||
if item.AccountID > 0 {
|
||||
ids = append(ids, item.AccountID)
|
||||
}
|
||||
}
|
||||
|
||||
accountMap := make(map[int64]*service.Account, len(ids))
|
||||
if len(ids) > 0 {
|
||||
accounts, _ := h.adminService.GetAccountsByIDs(c.Request.Context(), ids)
|
||||
for _, acc := range accounts {
|
||||
if acc != nil {
|
||||
accountMap[acc.ID] = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := SoraAccountBatchResult{
|
||||
Results: make([]SoraAccountBatchItemResult, 0, len(req.Accounts)),
|
||||
}
|
||||
|
||||
for _, item := range req.Accounts {
|
||||
entry := SoraAccountBatchItemResult{AccountID: item.AccountID}
|
||||
acc := accountMap[item.AccountID]
|
||||
if acc == nil {
|
||||
entry.Error = "账号不存在"
|
||||
result.Results = append(result.Results, entry)
|
||||
result.Failed++
|
||||
continue
|
||||
}
|
||||
if acc.Platform != service.PlatformSora {
|
||||
entry.Error = "账号不是 Sora 平台"
|
||||
result.Results = append(result.Results, entry)
|
||||
result.Failed++
|
||||
continue
|
||||
}
|
||||
updates := buildSoraAccountUpdates(&item.SoraAccountUpdateRequest)
|
||||
if h.soraAccountRepo != nil && len(updates) > 0 {
|
||||
if err := h.soraAccountRepo.Upsert(c.Request.Context(), item.AccountID, updates); err != nil {
|
||||
entry.Error = err.Error()
|
||||
result.Results = append(result.Results, entry)
|
||||
result.Failed++
|
||||
continue
|
||||
}
|
||||
}
|
||||
entry.Success = true
|
||||
result.Results = append(result.Results, entry)
|
||||
result.Success++
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ListUsage 获取 Sora 调用统计
|
||||
// GET /api/v1/admin/sora/usage
|
||||
func (h *SoraAccountHandler) ListUsage(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
if h.usageRepo == nil {
|
||||
response.Paginated(c, []dto.SoraUsageStat{}, 0, page, pageSize)
|
||||
return
|
||||
}
|
||||
stats, paginationResult, err := h.usageRepo.List(c.Request.Context(), params)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
result := make([]dto.SoraUsageStat, 0, len(stats))
|
||||
for _, stat := range stats {
|
||||
item := dto.SoraUsageStatFromService(stat)
|
||||
if item != nil {
|
||||
result = append(result, *item)
|
||||
}
|
||||
}
|
||||
response.Paginated(c, result, paginationResult.Total, paginationResult.Page, paginationResult.PageSize)
|
||||
}
|
||||
|
||||
func buildSoraAccountUpdates(req *SoraAccountUpdateRequest) map[string]any {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
updates := make(map[string]any)
|
||||
setString := func(key string, value *string) {
|
||||
if value == nil {
|
||||
return
|
||||
}
|
||||
updates[key] = strings.TrimSpace(*value)
|
||||
}
|
||||
setString("access_token", req.AccessToken)
|
||||
setString("session_token", req.SessionToken)
|
||||
setString("refresh_token", req.RefreshToken)
|
||||
setString("client_id", req.ClientID)
|
||||
setString("email", req.Email)
|
||||
setString("username", req.Username)
|
||||
setString("remark", req.Remark)
|
||||
setString("plan_type", req.PlanType)
|
||||
setString("plan_title", req.PlanTitle)
|
||||
setString("sora_invite_code", req.SoraInviteCode)
|
||||
|
||||
if req.UseCount != nil {
|
||||
updates["use_count"] = *req.UseCount
|
||||
}
|
||||
if req.SoraSupported != nil {
|
||||
updates["sora_supported"] = *req.SoraSupported
|
||||
}
|
||||
if req.SoraRedeemedCount != nil {
|
||||
updates["sora_redeemed_count"] = *req.SoraRedeemedCount
|
||||
}
|
||||
if req.SoraRemainingCount != nil {
|
||||
updates["sora_remaining_count"] = *req.SoraRemainingCount
|
||||
}
|
||||
if req.SoraTotalCount != nil {
|
||||
updates["sora_total_count"] = *req.SoraTotalCount
|
||||
}
|
||||
if req.ImageEnabled != nil {
|
||||
updates["image_enabled"] = *req.ImageEnabled
|
||||
}
|
||||
if req.VideoEnabled != nil {
|
||||
updates["video_enabled"] = *req.VideoEnabled
|
||||
}
|
||||
if req.ImageConcurrency != nil {
|
||||
updates["image_concurrency"] = *req.ImageConcurrency
|
||||
}
|
||||
if req.VideoConcurrency != nil {
|
||||
updates["video_concurrency"] = *req.VideoConcurrency
|
||||
}
|
||||
if req.IsExpired != nil {
|
||||
updates["is_expired"] = *req.IsExpired
|
||||
}
|
||||
if req.SubscriptionEnd != nil && *req.SubscriptionEnd > 0 {
|
||||
updates["subscription_end"] = time.Unix(*req.SubscriptionEnd, 0).UTC()
|
||||
}
|
||||
if req.SoraCooldownUntil != nil && *req.SoraCooldownUntil > 0 {
|
||||
updates["sora_cooldown_until"] = time.Unix(*req.SoraCooldownUntil, 0).UTC()
|
||||
}
|
||||
if req.CooledUntil != nil && *req.CooledUntil > 0 {
|
||||
updates["cooled_until"] = time.Unix(*req.CooledUntil, 0).UTC()
|
||||
}
|
||||
return updates
|
||||
}
|
||||
@@ -287,6 +287,72 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
|
||||
}
|
||||
}
|
||||
|
||||
func SoraUsageStatFromService(stat *service.SoraUsageStat) *SoraUsageStat {
|
||||
if stat == nil {
|
||||
return nil
|
||||
}
|
||||
return &SoraUsageStat{
|
||||
AccountID: stat.AccountID,
|
||||
ImageCount: stat.ImageCount,
|
||||
VideoCount: stat.VideoCount,
|
||||
ErrorCount: stat.ErrorCount,
|
||||
LastErrorAt: timeToUnixSeconds(stat.LastErrorAt),
|
||||
TodayImageCount: stat.TodayImageCount,
|
||||
TodayVideoCount: stat.TodayVideoCount,
|
||||
TodayErrorCount: stat.TodayErrorCount,
|
||||
TodayDate: timeToUnixSeconds(stat.TodayDate),
|
||||
ConsecutiveErrorCount: stat.ConsecutiveErrorCount,
|
||||
CreatedAt: stat.CreatedAt,
|
||||
UpdatedAt: stat.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func SoraAccountFromService(account *service.Account, soraAcc *service.SoraAccount, usage *service.SoraUsageStat) *SoraAccount {
|
||||
if account == nil {
|
||||
return nil
|
||||
}
|
||||
out := &SoraAccount{
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
AccountStatus: account.Status,
|
||||
AccountType: account.Type,
|
||||
AccountConcurrency: account.Concurrency,
|
||||
ProxyID: account.ProxyID,
|
||||
Usage: SoraUsageStatFromService(usage),
|
||||
CreatedAt: account.CreatedAt,
|
||||
UpdatedAt: account.UpdatedAt,
|
||||
}
|
||||
if soraAcc == nil {
|
||||
return out
|
||||
}
|
||||
out.AccessToken = soraAcc.AccessToken
|
||||
out.SessionToken = soraAcc.SessionToken
|
||||
out.RefreshToken = soraAcc.RefreshToken
|
||||
out.ClientID = soraAcc.ClientID
|
||||
out.Email = soraAcc.Email
|
||||
out.Username = soraAcc.Username
|
||||
out.Remark = soraAcc.Remark
|
||||
out.UseCount = soraAcc.UseCount
|
||||
out.PlanType = soraAcc.PlanType
|
||||
out.PlanTitle = soraAcc.PlanTitle
|
||||
out.SubscriptionEnd = timeToUnixSeconds(soraAcc.SubscriptionEnd)
|
||||
out.SoraSupported = soraAcc.SoraSupported
|
||||
out.SoraInviteCode = soraAcc.SoraInviteCode
|
||||
out.SoraRedeemedCount = soraAcc.SoraRedeemedCount
|
||||
out.SoraRemainingCount = soraAcc.SoraRemainingCount
|
||||
out.SoraTotalCount = soraAcc.SoraTotalCount
|
||||
out.SoraCooldownUntil = timeToUnixSeconds(soraAcc.SoraCooldownUntil)
|
||||
out.CooledUntil = timeToUnixSeconds(soraAcc.CooledUntil)
|
||||
out.ImageEnabled = soraAcc.ImageEnabled
|
||||
out.VideoEnabled = soraAcc.VideoEnabled
|
||||
out.ImageConcurrency = soraAcc.ImageConcurrency
|
||||
out.VideoConcurrency = soraAcc.VideoConcurrency
|
||||
out.IsExpired = soraAcc.IsExpired
|
||||
out.CreatedAt = soraAcc.CreatedAt
|
||||
out.UpdatedAt = soraAcc.UpdatedAt
|
||||
return out
|
||||
}
|
||||
|
||||
func ProxyAccountSummaryFromService(a *service.ProxyAccountSummary) *ProxyAccountSummary {
|
||||
if a == nil {
|
||||
return nil
|
||||
|
||||
@@ -46,6 +46,25 @@ type SystemSettings struct {
|
||||
EnableIdentityPatch bool `json:"enable_identity_patch"`
|
||||
IdentityPatchPrompt string `json:"identity_patch_prompt"`
|
||||
|
||||
// Sora configuration
|
||||
SoraBaseURL string `json:"sora_base_url"`
|
||||
SoraTimeout int `json:"sora_timeout"`
|
||||
SoraMaxRetries int `json:"sora_max_retries"`
|
||||
SoraPollInterval float64 `json:"sora_poll_interval"`
|
||||
SoraCallLogicMode string `json:"sora_call_logic_mode"`
|
||||
SoraCacheEnabled bool `json:"sora_cache_enabled"`
|
||||
SoraCacheBaseDir string `json:"sora_cache_base_dir"`
|
||||
SoraCacheVideoDir string `json:"sora_cache_video_dir"`
|
||||
SoraCacheMaxBytes int64 `json:"sora_cache_max_bytes"`
|
||||
SoraCacheAllowedHosts []string `json:"sora_cache_allowed_hosts"`
|
||||
SoraCacheUserDirEnabled bool `json:"sora_cache_user_dir_enabled"`
|
||||
SoraWatermarkFreeEnabled bool `json:"sora_watermark_free_enabled"`
|
||||
SoraWatermarkFreeParseMethod string `json:"sora_watermark_free_parse_method"`
|
||||
SoraWatermarkFreeCustomParseURL string `json:"sora_watermark_free_custom_parse_url"`
|
||||
SoraWatermarkFreeCustomParseToken string `json:"sora_watermark_free_custom_parse_token"`
|
||||
SoraWatermarkFreeFallbackOnFailure bool `json:"sora_watermark_free_fallback_on_failure"`
|
||||
SoraTokenRefreshEnabled bool `json:"sora_token_refresh_enabled"`
|
||||
|
||||
// Ops monitoring (vNext)
|
||||
OpsMonitoringEnabled bool `json:"ops_monitoring_enabled"`
|
||||
OpsRealtimeMonitoringEnabled bool `json:"ops_realtime_monitoring_enabled"`
|
||||
|
||||
@@ -141,6 +141,56 @@ type Account struct {
|
||||
Groups []*Group `json:"groups,omitempty"`
|
||||
}
|
||||
|
||||
type SoraUsageStat struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
ImageCount int `json:"image_count"`
|
||||
VideoCount int `json:"video_count"`
|
||||
ErrorCount int `json:"error_count"`
|
||||
LastErrorAt *int64 `json:"last_error_at"`
|
||||
TodayImageCount int `json:"today_image_count"`
|
||||
TodayVideoCount int `json:"today_video_count"`
|
||||
TodayErrorCount int `json:"today_error_count"`
|
||||
TodayDate *int64 `json:"today_date"`
|
||||
ConsecutiveErrorCount int `json:"consecutive_error_count"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type SoraAccount struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
AccountName string `json:"account_name"`
|
||||
AccountStatus string `json:"account_status"`
|
||||
AccountType string `json:"account_type"`
|
||||
AccountConcurrency int `json:"account_concurrency"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
SessionToken string `json:"session_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ClientID string `json:"client_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Remark string `json:"remark"`
|
||||
UseCount int `json:"use_count"`
|
||||
PlanType string `json:"plan_type"`
|
||||
PlanTitle string `json:"plan_title"`
|
||||
SubscriptionEnd *int64 `json:"subscription_end"`
|
||||
SoraSupported bool `json:"sora_supported"`
|
||||
SoraInviteCode string `json:"sora_invite_code"`
|
||||
SoraRedeemedCount int `json:"sora_redeemed_count"`
|
||||
SoraRemainingCount int `json:"sora_remaining_count"`
|
||||
SoraTotalCount int `json:"sora_total_count"`
|
||||
SoraCooldownUntil *int64 `json:"sora_cooldown_until"`
|
||||
CooledUntil *int64 `json:"cooled_until"`
|
||||
ImageEnabled bool `json:"image_enabled"`
|
||||
VideoEnabled bool `json:"video_enabled"`
|
||||
ImageConcurrency int `json:"image_concurrency"`
|
||||
VideoConcurrency int `json:"video_concurrency"`
|
||||
IsExpired bool `json:"is_expired"`
|
||||
Usage *SoraUsageStat `json:"usage,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type AccountGroup struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/sora"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -508,6 +509,13 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if platform == service.PlatformSora {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": sora.ListModels(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
|
||||
@@ -17,6 +17,7 @@ type AdminHandlers struct {
|
||||
Proxy *admin.ProxyHandler
|
||||
Redeem *admin.RedeemHandler
|
||||
Promo *admin.PromoHandler
|
||||
SoraAccount *admin.SoraAccountHandler
|
||||
Setting *admin.SettingHandler
|
||||
Ops *admin.OpsHandler
|
||||
System *admin.SystemHandler
|
||||
@@ -36,6 +37,7 @@ type Handlers struct {
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
SoraGateway *SoraGatewayHandler
|
||||
Setting *SettingHandler
|
||||
}
|
||||
|
||||
|
||||
@@ -814,6 +814,8 @@ func guessPlatformFromPath(path string) string {
|
||||
return service.PlatformAntigravity
|
||||
case strings.HasPrefix(p, "/v1beta/"):
|
||||
return service.PlatformGemini
|
||||
case strings.Contains(p, "/chat/completions"):
|
||||
return service.PlatformSora
|
||||
case strings.Contains(p, "/responses"):
|
||||
return service.PlatformOpenAI
|
||||
default:
|
||||
|
||||
364
backend/internal/handler/sora_gateway_handler.go
Normal file
364
backend/internal/handler/sora_gateway_handler.go
Normal file
@@ -0,0 +1,364 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/sora"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SoraGatewayHandler handles Sora OpenAI compatible endpoints.
|
||||
type SoraGatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
soraGatewayService *service.SoraGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
}
|
||||
|
||||
// NewSoraGatewayHandler creates a new SoraGatewayHandler.
|
||||
func NewSoraGatewayHandler(
|
||||
gatewayService *service.GatewayService,
|
||||
soraGatewayService *service.SoraGatewayService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
cfg *config.Config,
|
||||
) *SoraGatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
maxAccountSwitches := 3
|
||||
if cfg != nil {
|
||||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
||||
if cfg.Gateway.MaxAccountSwitches > 0 {
|
||||
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
|
||||
}
|
||||
}
|
||||
return &SoraGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
soraGatewayService: soraGatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatCompletions handles Sora OpenAI-compatible chat completions endpoint.
|
||||
// POST /v1/chat/completions
|
||||
func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
model, _ := reqBody["model"].(string)
|
||||
if strings.TrimSpace(model) == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
stream, _ := reqBody["stream"].(bool)
|
||||
|
||||
prompt, imageData, videoData, remixID, err := parseSoraPrompt(reqBody)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error())
|
||||
return
|
||||
}
|
||||
if remixID == "" {
|
||||
remixID = sora.ExtractRemixID(prompt)
|
||||
}
|
||||
if remixID != "" {
|
||||
prompt = strings.ReplaceAll(prompt, remixID, "")
|
||||
}
|
||||
|
||||
if apiKey.Group != nil && apiKey.Group.Platform != service.PlatformSora {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "当前分组不支持 Sora 平台")
|
||||
return
|
||||
}
|
||||
|
||||
streamStarted := false
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
if err == nil && !canWait {
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
waitCounted = false
|
||||
}
|
||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
maxSwitches := h.maxAccountSwitches
|
||||
if mode := h.soraGatewayService.CallLogicMode(c.Request.Context()); strings.EqualFold(mode, "native") {
|
||||
maxSwitches = 1
|
||||
}
|
||||
|
||||
for switchCount := 0; switchCount < maxSwitches; switchCount++ {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, "", model, failedAccountIDs, "")
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "server_error", err.Error())
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
releaseFunc := selection.ReleaseFunc
|
||||
|
||||
result, err := h.soraGatewayService.Generate(c.Request.Context(), account, service.SoraGenerationRequest{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
Image: imageData,
|
||||
Video: videoData,
|
||||
RemixTargetID: remixID,
|
||||
Stream: stream,
|
||||
UserID: subject.UserID,
|
||||
})
|
||||
if err != nil {
|
||||
// 失败路径:立即释放槽位,而非 defer
|
||||
if releaseFunc != nil {
|
||||
releaseFunc()
|
||||
}
|
||||
|
||||
if errors.Is(err, service.ErrSoraAccountMissingToken) || errors.Is(err, service.ErrSoraAccountNotEligible) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
continue
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "server_error", err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 成功路径:使用 defer 在函数退出时释放
|
||||
if releaseFunc != nil {
|
||||
defer releaseFunc()
|
||||
}
|
||||
|
||||
h.respondCompletion(c, model, result, stream)
|
||||
return
|
||||
}
|
||||
|
||||
h.handleFailoverExhausted(c, http.StatusServiceUnavailable, streamStarted)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) respondCompletion(c *gin.Context, model string, result *service.SoraGenerationResult, stream bool) {
|
||||
if result == nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Empty response")
|
||||
return
|
||||
}
|
||||
if stream {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
first := buildSoraStreamChunk(model, "", true, "")
|
||||
if _, err := c.Writer.WriteString(first); err != nil {
|
||||
return
|
||||
}
|
||||
final := buildSoraStreamChunk(model, result.Content, false, "stop")
|
||||
if _, err := c.Writer.WriteString(final); err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = c.Writer.WriteString("data: [DONE]\n\n")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, buildSoraNonStreamResponse(model, result.Content))
|
||||
}
|
||||
|
||||
func buildSoraStreamChunk(model, content string, isFirst bool, finishReason string) string {
|
||||
chunkID := fmt.Sprintf("chatcmpl-%d", time.Now().UnixMilli())
|
||||
delta := map[string]any{}
|
||||
if isFirst {
|
||||
delta["role"] = "assistant"
|
||||
}
|
||||
if content != "" {
|
||||
delta["content"] = content
|
||||
} else {
|
||||
delta["content"] = nil
|
||||
}
|
||||
response := map[string]any{
|
||||
"id": chunkID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
"finish_reason": finishReason,
|
||||
},
|
||||
},
|
||||
}
|
||||
payload, _ := json.Marshal(response)
|
||||
return "data: " + string(payload) + "\n\n"
|
||||
}
|
||||
|
||||
func buildSoraNonStreamResponse(model, content string) map[string]any {
|
||||
return map[string]any{
|
||||
"id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixMilli()),
|
||||
"object": "chat.completion",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"message": map[string]any{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func parseSoraPrompt(req map[string]any) (prompt, imageData, videoData, remixID string, err error) {
|
||||
messages, ok := req["messages"].([]any)
|
||||
if !ok || len(messages) == 0 {
|
||||
return "", "", "", "", fmt.Errorf("messages is required")
|
||||
}
|
||||
last := messages[len(messages)-1]
|
||||
msg, ok := last.(map[string]any)
|
||||
if !ok {
|
||||
return "", "", "", "", fmt.Errorf("invalid message format")
|
||||
}
|
||||
content, ok := msg["content"]
|
||||
if !ok {
|
||||
return "", "", "", "", fmt.Errorf("content is required")
|
||||
}
|
||||
|
||||
if v, ok := req["image"].(string); ok && v != "" {
|
||||
imageData = v
|
||||
}
|
||||
if v, ok := req["video"].(string); ok && v != "" {
|
||||
videoData = v
|
||||
}
|
||||
if v, ok := req["remix_target_id"].(string); ok {
|
||||
remixID = v
|
||||
}
|
||||
|
||||
switch value := content.(type) {
|
||||
case string:
|
||||
prompt = value
|
||||
case []any:
|
||||
for _, item := range value {
|
||||
part, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch part["type"] {
|
||||
case "text":
|
||||
if text, ok := part["text"].(string); ok {
|
||||
prompt = text
|
||||
}
|
||||
case "image_url":
|
||||
if image, ok := part["image_url"].(map[string]any); ok {
|
||||
if url, ok := image["url"].(string); ok {
|
||||
imageData = url
|
||||
}
|
||||
}
|
||||
case "video_url":
|
||||
if video, ok := part["video_url"].(map[string]any); ok {
|
||||
if url, ok := video["url"].(string); ok {
|
||||
videoData = url
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
return "", "", "", "", fmt.Errorf("invalid content format")
|
||||
}
|
||||
if strings.TrimSpace(prompt) == "" && strings.TrimSpace(videoData) == "" {
|
||||
return "", "", "", "", fmt.Errorf("prompt is required")
|
||||
}
|
||||
return prompt, imageData, videoData, remixID, nil
|
||||
}
|
||||
|
||||
func looksLikeURL(value string) bool {
|
||||
trimmed := strings.ToLower(strings.TrimSpace(value))
|
||||
return strings.HasPrefix(trimmed, "http://") || strings.HasPrefix(trimmed, "https://")
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", err.Error(), true)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
message := "No available Sora accounts"
|
||||
h.handleStreamingAwareError(c, statusCode, "server_error", message, streamStarted)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
payload := map[string]any{"error": map[string]any{"message": message, "type": errType, "param": nil, "code": nil}}
|
||||
data, _ := json.Marshal(payload)
|
||||
_, _ = c.Writer.WriteString("data: " + string(data) + "\n\n")
|
||||
_, _ = c.Writer.WriteString("data: [DONE]\n\n")
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"message": message,
|
||||
"type": errType,
|
||||
"param": nil,
|
||||
"code": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -20,6 +20,7 @@ func ProvideAdminHandlers(
|
||||
proxyHandler *admin.ProxyHandler,
|
||||
redeemHandler *admin.RedeemHandler,
|
||||
promoHandler *admin.PromoHandler,
|
||||
soraAccountHandler *admin.SoraAccountHandler,
|
||||
settingHandler *admin.SettingHandler,
|
||||
opsHandler *admin.OpsHandler,
|
||||
systemHandler *admin.SystemHandler,
|
||||
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
|
||||
Proxy: proxyHandler,
|
||||
Redeem: redeemHandler,
|
||||
Promo: promoHandler,
|
||||
SoraAccount: soraAccountHandler,
|
||||
Setting: settingHandler,
|
||||
Ops: opsHandler,
|
||||
System: systemHandler,
|
||||
@@ -69,6 +71,7 @@ func ProvideHandlers(
|
||||
adminHandlers *AdminHandlers,
|
||||
gatewayHandler *GatewayHandler,
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
soraGatewayHandler *SoraGatewayHandler,
|
||||
settingHandler *SettingHandler,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
@@ -81,6 +84,7 @@ func ProvideHandlers(
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
SoraGateway: soraGatewayHandler,
|
||||
Setting: settingHandler,
|
||||
}
|
||||
}
|
||||
@@ -96,6 +100,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewSubscriptionHandler,
|
||||
NewGatewayHandler,
|
||||
NewOpenAIGatewayHandler,
|
||||
NewSoraGatewayHandler,
|
||||
ProvideSettingHandler,
|
||||
|
||||
// Admin handlers
|
||||
@@ -110,6 +115,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewProxyHandler,
|
||||
admin.NewRedeemHandler,
|
||||
admin.NewPromoHandler,
|
||||
admin.NewSoraAccountHandler,
|
||||
admin.NewSettingHandler,
|
||||
admin.NewOpsHandler,
|
||||
ProvideSystemHandler,
|
||||
|
||||
148
backend/internal/pkg/sora/character.go
Normal file
148
backend/internal/pkg/sora/character.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package sora
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
)
|
||||
|
||||
// UploadCharacterVideo uploads a character video and returns cameo ID.
|
||||
func (c *Client) UploadCharacterVideo(ctx context.Context, opts RequestOptions, data []byte) (string, error) {
|
||||
if len(data) == 0 {
|
||||
return "", errors.New("video data empty")
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
if err := writeMultipartFile(writer, "file", "video.mp4", "video/mp4", data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := writer.WriteField("timestamps", "0,3"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := c.doRequest(ctx, "POST", "/characters/upload", opts, &buf, writer.FormDataContentType(), false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return stringFromJSON(resp, "id"), nil
|
||||
}
|
||||
|
||||
// GetCameoStatus returns cameo processing status.
|
||||
func (c *Client) GetCameoStatus(ctx context.Context, opts RequestOptions, cameoID string) (map[string]any, error) {
|
||||
if cameoID == "" {
|
||||
return nil, errors.New("cameo id empty")
|
||||
}
|
||||
return c.doRequest(ctx, "GET", "/project_y/cameos/in_progress/"+cameoID, opts, nil, "", false)
|
||||
}
|
||||
|
||||
// DownloadCharacterImage downloads character avatar image data.
|
||||
func (c *Client) DownloadCharacterImage(ctx context.Context, opts RequestOptions, imageURL string) ([]byte, error) {
|
||||
if c.upstream == nil {
|
||||
return nil, errors.New("upstream is nil")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", imageURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("User-Agent", defaultDesktopUA)
|
||||
resp, err := c.upstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, opts.AccountConcurrency, c.enableTLSFingerprint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("download image failed: %d", resp.StatusCode)
|
||||
}
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// UploadCharacterImage uploads character avatar and returns asset pointer.
|
||||
func (c *Client) UploadCharacterImage(ctx context.Context, opts RequestOptions, data []byte) (string, error) {
|
||||
if len(data) == 0 {
|
||||
return "", errors.New("image data empty")
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
if err := writeMultipartFile(writer, "file", "profile.webp", "image/webp", data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := writer.WriteField("use_case", "profile"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := c.doRequest(ctx, "POST", "/project_y/file/upload", opts, &buf, writer.FormDataContentType(), false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return stringFromJSON(resp, "asset_pointer"), nil
|
||||
}
|
||||
|
||||
// FinalizeCharacter finalizes character creation and returns character ID.
|
||||
func (c *Client) FinalizeCharacter(ctx context.Context, opts RequestOptions, cameoID, username, displayName, assetPointer string) (string, error) {
|
||||
payload := map[string]any{
|
||||
"cameo_id": cameoID,
|
||||
"username": username,
|
||||
"display_name": displayName,
|
||||
"profile_asset_pointer": assetPointer,
|
||||
"instruction_set": nil,
|
||||
"safety_instruction_set": nil,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := c.doRequest(ctx, "POST", "/characters/finalize", opts, bytes.NewReader(body), "application/json", false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if character, ok := resp["character"].(map[string]any); ok {
|
||||
if id, ok := character["character_id"].(string); ok {
|
||||
return id, nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// SetCharacterPublic marks character as public.
|
||||
func (c *Client) SetCharacterPublic(ctx context.Context, opts RequestOptions, cameoID string) error {
|
||||
payload := map[string]any{"visibility": "public"}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = c.doRequest(ctx, "POST", "/project_y/cameos/by_id/"+cameoID+"/update_v2", opts, bytes.NewReader(body), "application/json", false)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteCharacter deletes a character by ID.
|
||||
func (c *Client) DeleteCharacter(ctx context.Context, opts RequestOptions, characterID string) error {
|
||||
if characterID == "" {
|
||||
return nil
|
||||
}
|
||||
_, err := c.doRequest(ctx, "DELETE", "/project_y/characters/"+characterID, opts, nil, "", false)
|
||||
return err
|
||||
}
|
||||
|
||||
func writeMultipartFile(writer *multipart.Writer, field, filename, contentType string, data []byte) error {
|
||||
header := make(textproto.MIMEHeader)
|
||||
header.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, field, filename))
|
||||
if contentType != "" {
|
||||
header.Set("Content-Type", contentType)
|
||||
}
|
||||
part, err := writer.CreatePart(header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = part.Write(data)
|
||||
return err
|
||||
}
|
||||
612
backend/internal/pkg/sora/client.go
Normal file
612
backend/internal/pkg/sora/client.go
Normal file
@@ -0,0 +1,612 @@
|
||||
package sora
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha3"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
chatGPTBaseURL = "https://chatgpt.com"
|
||||
sentinelFlow = "sora_2_create_task"
|
||||
maxAPIResponseSize = 1 * 1024 * 1024 // 1MB
|
||||
)
|
||||
|
||||
var (
|
||||
defaultMobileUA = "Sora/1.2026.007 (Android 15; Pixel 8 Pro; build 2600700)"
|
||||
defaultDesktopUA = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
|
||||
sentinelCache sync.Map // 包级缓存,存储 Sentinel Token,key 为 accountID
|
||||
)
|
||||
|
||||
// sentinelCacheEntry 是 Sentinel Token 缓存条目
|
||||
type sentinelCacheEntry struct {
|
||||
token string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// UpstreamClient defines the HTTP client interface for Sora requests.
|
||||
type UpstreamClient interface {
|
||||
Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error)
|
||||
DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error)
|
||||
}
|
||||
|
||||
// Client is a minimal Sora API client.
|
||||
type Client struct {
|
||||
baseURL string
|
||||
timeout time.Duration
|
||||
upstream UpstreamClient
|
||||
enableTLSFingerprint bool
|
||||
}
|
||||
|
||||
// RequestOptions configures per-request context.
|
||||
type RequestOptions struct {
|
||||
AccountID int64
|
||||
AccountConcurrency int
|
||||
ProxyURL string
|
||||
AccessToken string
|
||||
}
|
||||
|
||||
// getCachedSentinel 从缓存中获取 Sentinel Token
|
||||
func getCachedSentinel(accountID int64) (string, bool) {
|
||||
v, ok := sentinelCache.Load(accountID)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
entry := v.(*sentinelCacheEntry)
|
||||
if time.Now().After(entry.expiresAt) {
|
||||
sentinelCache.Delete(accountID)
|
||||
return "", false
|
||||
}
|
||||
return entry.token, true
|
||||
}
|
||||
|
||||
// cacheSentinel 缓存 Sentinel Token
|
||||
func cacheSentinel(accountID int64, token string) {
|
||||
sentinelCache.Store(accountID, &sentinelCacheEntry{
|
||||
token: token,
|
||||
expiresAt: time.Now().Add(3 * time.Minute), // 3分钟有效期
|
||||
})
|
||||
}
|
||||
|
||||
// NewClient creates a Sora client.
|
||||
func NewClient(baseURL string, timeout time.Duration, upstream UpstreamClient, enableTLSFingerprint bool) *Client {
|
||||
return &Client{
|
||||
baseURL: strings.TrimRight(baseURL, "/"),
|
||||
timeout: timeout,
|
||||
upstream: upstream,
|
||||
enableTLSFingerprint: enableTLSFingerprint,
|
||||
}
|
||||
}
|
||||
|
||||
// UploadImage uploads an image and returns media ID.
|
||||
func (c *Client) UploadImage(ctx context.Context, opts RequestOptions, data []byte, filename string) (string, error) {
|
||||
if filename == "" {
|
||||
filename = "image.png"
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := part.Write(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := writer.WriteField("file_name", filename); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := c.doRequest(ctx, "POST", "/uploads", opts, &buf, writer.FormDataContentType(), false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return stringFromJSON(resp, "id"), nil
|
||||
}
|
||||
|
||||
// GenerateImage creates an image generation task.
|
||||
func (c *Client) GenerateImage(ctx context.Context, opts RequestOptions, prompt string, width, height int, mediaID string) (string, error) {
|
||||
operation := "simple_compose"
|
||||
var inpaint []map[string]any
|
||||
if mediaID != "" {
|
||||
operation = "remix"
|
||||
inpaint = []map[string]any{
|
||||
{
|
||||
"type": "image",
|
||||
"frame_index": 0,
|
||||
"upload_media_id": mediaID,
|
||||
},
|
||||
}
|
||||
}
|
||||
payload := map[string]any{
|
||||
"type": "image_gen",
|
||||
"operation": operation,
|
||||
"prompt": prompt,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"n_variants": 1,
|
||||
"n_frames": 1,
|
||||
"inpaint_items": inpaint,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := c.doRequest(ctx, "POST", "/video_gen", opts, bytes.NewReader(body), "application/json", true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return stringFromJSON(resp, "id"), nil
|
||||
}
|
||||
|
||||
// GenerateVideo creates a video generation task.
|
||||
func (c *Client) GenerateVideo(ctx context.Context, opts RequestOptions, prompt, orientation string, nFrames int, mediaID, styleID, model, size string) (string, error) {
|
||||
var inpaint []map[string]any
|
||||
if mediaID != "" {
|
||||
inpaint = []map[string]any{{"kind": "upload", "upload_id": mediaID}}
|
||||
}
|
||||
payload := map[string]any{
|
||||
"kind": "video",
|
||||
"prompt": prompt,
|
||||
"orientation": orientation,
|
||||
"size": size,
|
||||
"n_frames": nFrames,
|
||||
"model": model,
|
||||
"inpaint_items": inpaint,
|
||||
"style_id": styleID,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := c.doRequest(ctx, "POST", "/nf/create", opts, bytes.NewReader(body), "application/json", true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return stringFromJSON(resp, "id"), nil
|
||||
}
|
||||
|
||||
// GenerateStoryboard creates a storyboard video task.
|
||||
func (c *Client) GenerateStoryboard(ctx context.Context, opts RequestOptions, prompt, orientation string, nFrames int, mediaID, styleID string) (string, error) {
|
||||
var inpaint []map[string]any
|
||||
if mediaID != "" {
|
||||
inpaint = []map[string]any{{"kind": "upload", "upload_id": mediaID}}
|
||||
}
|
||||
payload := map[string]any{
|
||||
"kind": "video",
|
||||
"prompt": prompt,
|
||||
"title": "Draft your video",
|
||||
"orientation": orientation,
|
||||
"size": "small",
|
||||
"n_frames": nFrames,
|
||||
"storyboard_id": nil,
|
||||
"inpaint_items": inpaint,
|
||||
"remix_target_id": nil,
|
||||
"model": "sy_8",
|
||||
"metadata": nil,
|
||||
"style_id": styleID,
|
||||
"cameo_ids": nil,
|
||||
"cameo_replacements": nil,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := c.doRequest(ctx, "POST", "/nf/create/storyboard", opts, bytes.NewReader(body), "application/json", true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return stringFromJSON(resp, "id"), nil
|
||||
}
|
||||
|
||||
// RemixVideo creates a remix task.
|
||||
func (c *Client) RemixVideo(ctx context.Context, opts RequestOptions, remixTargetID, prompt, orientation string, nFrames int, styleID string) (string, error) {
|
||||
payload := map[string]any{
|
||||
"kind": "video",
|
||||
"prompt": prompt,
|
||||
"inpaint_items": []map[string]any{},
|
||||
"remix_target_id": remixTargetID,
|
||||
"cameo_ids": []string{},
|
||||
"cameo_replacements": map[string]any{},
|
||||
"model": "sy_8",
|
||||
"orientation": orientation,
|
||||
"n_frames": nFrames,
|
||||
"style_id": styleID,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := c.doRequest(ctx, "POST", "/nf/create", opts, bytes.NewReader(body), "application/json", true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return stringFromJSON(resp, "id"), nil
|
||||
}
|
||||
|
||||
// GetImageTasks returns recent image tasks.
|
||||
func (c *Client) GetImageTasks(ctx context.Context, opts RequestOptions) (map[string]any, error) {
|
||||
return c.doRequest(ctx, "GET", "/v2/recent_tasks?limit=20", opts, nil, "", false)
|
||||
}
|
||||
|
||||
// GetPendingTasks returns pending video tasks.
|
||||
func (c *Client) GetPendingTasks(ctx context.Context, opts RequestOptions) ([]map[string]any, error) {
|
||||
resp, err := c.doRequestAny(ctx, "GET", "/nf/pending/v2", opts, nil, "", false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch v := resp.(type) {
|
||||
case []any:
|
||||
return convertList(v), nil
|
||||
case map[string]any:
|
||||
if list, ok := v["items"].([]any); ok {
|
||||
return convertList(list), nil
|
||||
}
|
||||
if arr, ok := v["data"].([]any); ok {
|
||||
return convertList(arr), nil
|
||||
}
|
||||
return convertListFromAny(v), nil
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetVideoDrafts returns recent video drafts.
|
||||
func (c *Client) GetVideoDrafts(ctx context.Context, opts RequestOptions) (map[string]any, error) {
|
||||
return c.doRequest(ctx, "GET", "/project_y/profile/drafts?limit=15", opts, nil, "", false)
|
||||
}
|
||||
|
||||
// EnhancePrompt calls prompt enhancement API.
|
||||
func (c *Client) EnhancePrompt(ctx context.Context, opts RequestOptions, prompt, expansionLevel string, durationS int) (string, error) {
|
||||
payload := map[string]any{
|
||||
"prompt": prompt,
|
||||
"expansion_level": expansionLevel,
|
||||
"duration_s": durationS,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := c.doRequest(ctx, "POST", "/editor/enhance_prompt", opts, bytes.NewReader(body), "application/json", false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return stringFromJSON(resp, "enhanced_prompt"), nil
|
||||
}
|
||||
|
||||
// PostVideoForWatermarkFree publishes a video for watermark-free parsing.
|
||||
func (c *Client) PostVideoForWatermarkFree(ctx context.Context, opts RequestOptions, generationID string) (string, error) {
|
||||
payload := map[string]any{
|
||||
"attachments_to_create": []map[string]any{{
|
||||
"generation_id": generationID,
|
||||
"kind": "sora",
|
||||
}},
|
||||
"post_text": "",
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := c.doRequest(ctx, "POST", "/project_y/post", opts, bytes.NewReader(body), "application/json", true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
post, _ := resp["post"].(map[string]any)
|
||||
if post == nil {
|
||||
return "", nil
|
||||
}
|
||||
return stringFromJSON(post, "id"), nil
|
||||
}
|
||||
|
||||
// DeletePost deletes a Sora post.
|
||||
func (c *Client) DeletePost(ctx context.Context, opts RequestOptions, postID string) error {
|
||||
if postID == "" {
|
||||
return nil
|
||||
}
|
||||
_, err := c.doRequest(ctx, "DELETE", "/project_y/post/"+postID, opts, nil, "", false)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) doRequest(ctx context.Context, method, endpoint string, opts RequestOptions, body io.Reader, contentType string, addSentinel bool) (map[string]any, error) {
|
||||
resp, err := c.doRequestAny(ctx, method, endpoint, opts, body, contentType, addSentinel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsed, ok := resp.(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("unexpected response format")
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func (c *Client) doRequestAny(ctx context.Context, method, endpoint string, opts RequestOptions, body io.Reader, contentType string, addSentinel bool) (any, error) {
|
||||
if c.upstream == nil {
|
||||
return nil, errors.New("upstream is nil")
|
||||
}
|
||||
url := c.baseURL + endpoint
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if contentType != "" {
|
||||
req.Header.Set("Content-Type", contentType)
|
||||
}
|
||||
if opts.AccessToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+opts.AccessToken)
|
||||
}
|
||||
req.Header.Set("User-Agent", defaultMobileUA)
|
||||
if addSentinel {
|
||||
sentinel, err := c.generateSentinelToken(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("openai-sentinel-token", sentinel)
|
||||
}
|
||||
resp, err := c.upstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, opts.AccountConcurrency, c.enableTLSFingerprint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 使用 LimitReader 限制最大响应大小,防止 DoS 攻击
|
||||
limitedReader := io.LimitReader(resp.Body, maxAPIResponseSize+1)
|
||||
data, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查是否超过大小限制
|
||||
if int64(len(data)) > maxAPIResponseSize {
|
||||
return nil, fmt.Errorf("API 响应过大 (最大 %d 字节)", maxAPIResponseSize)
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("sora api error: %d %s", resp.StatusCode, strings.TrimSpace(string(data)))
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
var parsed any
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func (c *Client) generateSentinelToken(ctx context.Context, opts RequestOptions) (string, error) {
|
||||
// 尝试从缓存获取
|
||||
if token, ok := getCachedSentinel(opts.AccountID); ok {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
reqID := uuid.New().String()
|
||||
powToken, err := generatePowToken(defaultDesktopUA)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
payload := map[string]any{"p": powToken, "flow": sentinelFlow, "id": reqID}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
url := chatGPTBaseURL + "/backend-api/sentinel/req"
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||
req.Header.Set("User-Agent", defaultDesktopUA)
|
||||
if opts.AccessToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+opts.AccessToken)
|
||||
}
|
||||
resp, err := c.upstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, opts.AccountConcurrency, c.enableTLSFingerprint)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 使用 LimitReader 限制最大响应大小,防止 DoS 攻击
|
||||
limitedReader := io.LimitReader(resp.Body, maxAPIResponseSize+1)
|
||||
data, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 检查是否超过大小限制
|
||||
if int64(len(data)) > maxAPIResponseSize {
|
||||
return "", fmt.Errorf("API 响应过大 (最大 %d 字节)", maxAPIResponseSize)
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("sentinel request failed: %d %s", resp.StatusCode, strings.TrimSpace(string(data)))
|
||||
}
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
return "", err
|
||||
}
|
||||
token := buildSentinelToken(reqID, powToken, parsed)
|
||||
|
||||
// 缓存结果
|
||||
cacheSentinel(opts.AccountID, token)
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func buildSentinelToken(reqID, powToken string, resp map[string]any) string {
|
||||
finalPow := powToken
|
||||
pow, _ := resp["proofofwork"].(map[string]any)
|
||||
if pow != nil {
|
||||
required, _ := pow["required"].(bool)
|
||||
if required {
|
||||
seed, _ := pow["seed"].(string)
|
||||
difficulty, _ := pow["difficulty"].(string)
|
||||
if seed != "" && difficulty != "" {
|
||||
candidate, _ := solvePow(seed, difficulty, defaultDesktopUA)
|
||||
if candidate != "" {
|
||||
finalPow = "gAAAAAB" + candidate
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !strings.HasSuffix(finalPow, "~S") {
|
||||
finalPow += "~S"
|
||||
}
|
||||
turnstile := ""
|
||||
if t, ok := resp["turnstile"].(map[string]any); ok {
|
||||
turnstile, _ = t["dx"].(string)
|
||||
}
|
||||
token := ""
|
||||
if v, ok := resp["token"].(string); ok {
|
||||
token = v
|
||||
}
|
||||
payload := map[string]any{
|
||||
"p": finalPow,
|
||||
"t": turnstile,
|
||||
"c": token,
|
||||
"id": reqID,
|
||||
"flow": sentinelFlow,
|
||||
}
|
||||
data, _ := json.Marshal(payload)
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func generatePowToken(userAgent string) (string, error) {
|
||||
seed := fmt.Sprintf("%f", float64(time.Now().UnixNano())/1e9)
|
||||
candidate, _ := solvePow(seed, "0fffff", userAgent)
|
||||
if candidate == "" {
|
||||
return "", errors.New("pow generation failed")
|
||||
}
|
||||
return "gAAAAAC" + candidate, nil
|
||||
}
|
||||
|
||||
func solvePow(seed, difficulty, userAgent string) (string, bool) {
|
||||
config := powConfig(userAgent)
|
||||
seedBytes := []byte(seed)
|
||||
diffBytes, err := hexDecode(difficulty)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
configBytes, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
prefix := configBytes[:len(configBytes)-1]
|
||||
for i := 0; i < 500000; i++ {
|
||||
payload := append(prefix, []byte(fmt.Sprintf(",%d,%d]", i, i>>1))...)
|
||||
b64 := base64.StdEncoding.EncodeToString(payload)
|
||||
h := sha3.Sum512(append(seedBytes, []byte(b64)...))
|
||||
if bytes.Compare(h[:len(diffBytes)], diffBytes) <= 0 {
|
||||
return b64, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func powConfig(userAgent string) []any {
|
||||
return []any{
|
||||
3000,
|
||||
formatPowTime(),
|
||||
4294705152,
|
||||
0,
|
||||
userAgent,
|
||||
"",
|
||||
nil,
|
||||
"en-US",
|
||||
"en-US,es-US,en,es",
|
||||
0,
|
||||
"webdriver-false",
|
||||
"location",
|
||||
"window",
|
||||
time.Now().UnixMilli(),
|
||||
uuid.New().String(),
|
||||
"",
|
||||
16,
|
||||
float64(time.Now().UnixMilli()),
|
||||
}
|
||||
}
|
||||
|
||||
func formatPowTime() string {
|
||||
loc := time.FixedZone("EST", -5*60*60)
|
||||
return time.Now().In(loc).Format("Mon Jan 02 2006 15:04:05") + " GMT-0500 (Eastern Standard Time)"
|
||||
}
|
||||
|
||||
func hexDecode(s string) ([]byte, error) {
|
||||
if len(s)%2 != 0 {
|
||||
return nil, errors.New("invalid hex length")
|
||||
}
|
||||
out := make([]byte, len(s)/2)
|
||||
for i := 0; i < len(out); i++ {
|
||||
byteVal, err := hexPair(s[i*2 : i*2+2])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out[i] = byteVal
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func hexPair(pair string) (byte, error) {
|
||||
var v byte
|
||||
for i := 0; i < 2; i++ {
|
||||
c := pair[i]
|
||||
var n byte
|
||||
switch {
|
||||
case c >= '0' && c <= '9':
|
||||
n = c - '0'
|
||||
case c >= 'a' && c <= 'f':
|
||||
n = c - 'a' + 10
|
||||
case c >= 'A' && c <= 'F':
|
||||
n = c - 'A' + 10
|
||||
default:
|
||||
return 0, errors.New("invalid hex")
|
||||
}
|
||||
v = v<<4 | n
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func stringFromJSON(data map[string]any, key string) string {
|
||||
if data == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := data[key].(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func convertList(list []any) []map[string]any {
|
||||
results := make([]map[string]any, 0, len(list))
|
||||
for _, item := range list {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
results = append(results, m)
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
func convertListFromAny(data map[string]any) []map[string]any {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
items, ok := data["items"].([]any)
|
||||
if ok {
|
||||
return convertList(items)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
263
backend/internal/pkg/sora/models.go
Normal file
263
backend/internal/pkg/sora/models.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package sora
|
||||
|
||||
// ModelConfig 定义 Sora 模型配置。
|
||||
type ModelConfig struct {
|
||||
Type string
|
||||
Width int
|
||||
Height int
|
||||
Orientation string
|
||||
NFrames int
|
||||
Model string
|
||||
Size string
|
||||
RequirePro bool
|
||||
ExpansionLevel string
|
||||
DurationS int
|
||||
}
|
||||
|
||||
// ModelConfigs 定义所有模型配置。
|
||||
var ModelConfigs = map[string]ModelConfig{
|
||||
"gpt-image": {
|
||||
Type: "image",
|
||||
Width: 360,
|
||||
Height: 360,
|
||||
},
|
||||
"gpt-image-landscape": {
|
||||
Type: "image",
|
||||
Width: 540,
|
||||
Height: 360,
|
||||
},
|
||||
"gpt-image-portrait": {
|
||||
Type: "image",
|
||||
Width: 360,
|
||||
Height: 540,
|
||||
},
|
||||
"sora2-landscape-10s": {
|
||||
Type: "video",
|
||||
Orientation: "landscape",
|
||||
NFrames: 300,
|
||||
},
|
||||
"sora2-portrait-10s": {
|
||||
Type: "video",
|
||||
Orientation: "portrait",
|
||||
NFrames: 300,
|
||||
},
|
||||
"sora2-landscape-15s": {
|
||||
Type: "video",
|
||||
Orientation: "landscape",
|
||||
NFrames: 450,
|
||||
},
|
||||
"sora2-portrait-15s": {
|
||||
Type: "video",
|
||||
Orientation: "portrait",
|
||||
NFrames: 450,
|
||||
},
|
||||
"sora2-landscape-25s": {
|
||||
Type: "video",
|
||||
Orientation: "landscape",
|
||||
NFrames: 750,
|
||||
Model: "sy_8",
|
||||
Size: "small",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2-portrait-25s": {
|
||||
Type: "video",
|
||||
Orientation: "portrait",
|
||||
NFrames: 750,
|
||||
Model: "sy_8",
|
||||
Size: "small",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-landscape-10s": {
|
||||
Type: "video",
|
||||
Orientation: "landscape",
|
||||
NFrames: 300,
|
||||
Model: "sy_ore",
|
||||
Size: "small",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-portrait-10s": {
|
||||
Type: "video",
|
||||
Orientation: "portrait",
|
||||
NFrames: 300,
|
||||
Model: "sy_ore",
|
||||
Size: "small",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-landscape-15s": {
|
||||
Type: "video",
|
||||
Orientation: "landscape",
|
||||
NFrames: 450,
|
||||
Model: "sy_ore",
|
||||
Size: "small",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-portrait-15s": {
|
||||
Type: "video",
|
||||
Orientation: "portrait",
|
||||
NFrames: 450,
|
||||
Model: "sy_ore",
|
||||
Size: "small",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-landscape-25s": {
|
||||
Type: "video",
|
||||
Orientation: "landscape",
|
||||
NFrames: 750,
|
||||
Model: "sy_ore",
|
||||
Size: "small",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-portrait-25s": {
|
||||
Type: "video",
|
||||
Orientation: "portrait",
|
||||
NFrames: 750,
|
||||
Model: "sy_ore",
|
||||
Size: "small",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-hd-landscape-10s": {
|
||||
Type: "video",
|
||||
Orientation: "landscape",
|
||||
NFrames: 300,
|
||||
Model: "sy_ore",
|
||||
Size: "large",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-hd-portrait-10s": {
|
||||
Type: "video",
|
||||
Orientation: "portrait",
|
||||
NFrames: 300,
|
||||
Model: "sy_ore",
|
||||
Size: "large",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-hd-landscape-15s": {
|
||||
Type: "video",
|
||||
Orientation: "landscape",
|
||||
NFrames: 450,
|
||||
Model: "sy_ore",
|
||||
Size: "large",
|
||||
RequirePro: true,
|
||||
},
|
||||
"sora2pro-hd-portrait-15s": {
|
||||
Type: "video",
|
||||
Orientation: "portrait",
|
||||
NFrames: 450,
|
||||
Model: "sy_ore",
|
||||
Size: "large",
|
||||
RequirePro: true,
|
||||
},
|
||||
"prompt-enhance-short-10s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "short",
|
||||
DurationS: 10,
|
||||
},
|
||||
"prompt-enhance-short-15s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "short",
|
||||
DurationS: 15,
|
||||
},
|
||||
"prompt-enhance-short-20s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "short",
|
||||
DurationS: 20,
|
||||
},
|
||||
"prompt-enhance-medium-10s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "medium",
|
||||
DurationS: 10,
|
||||
},
|
||||
"prompt-enhance-medium-15s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "medium",
|
||||
DurationS: 15,
|
||||
},
|
||||
"prompt-enhance-medium-20s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "medium",
|
||||
DurationS: 20,
|
||||
},
|
||||
"prompt-enhance-long-10s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "long",
|
||||
DurationS: 10,
|
||||
},
|
||||
"prompt-enhance-long-15s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "long",
|
||||
DurationS: 15,
|
||||
},
|
||||
"prompt-enhance-long-20s": {
|
||||
Type: "prompt_enhance",
|
||||
ExpansionLevel: "long",
|
||||
DurationS: 20,
|
||||
},
|
||||
}
|
||||
|
||||
// ModelListItem 返回模型列表条目。
|
||||
type ModelListItem struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// ListModels 生成模型列表。
|
||||
func ListModels() []ModelListItem {
|
||||
models := make([]ModelListItem, 0, len(ModelConfigs))
|
||||
for id, cfg := range ModelConfigs {
|
||||
description := ""
|
||||
switch cfg.Type {
|
||||
case "image":
|
||||
description = "Image generation"
|
||||
if cfg.Width > 0 && cfg.Height > 0 {
|
||||
description += " - " + itoa(cfg.Width) + "x" + itoa(cfg.Height)
|
||||
}
|
||||
case "video":
|
||||
description = "Video generation"
|
||||
if cfg.Orientation != "" {
|
||||
description += " - " + cfg.Orientation
|
||||
}
|
||||
case "prompt_enhance":
|
||||
description = "Prompt enhancement"
|
||||
if cfg.ExpansionLevel != "" {
|
||||
description += " - " + cfg.ExpansionLevel
|
||||
}
|
||||
if cfg.DurationS > 0 {
|
||||
description += " (" + itoa(cfg.DurationS) + "s)"
|
||||
}
|
||||
default:
|
||||
description = "Sora model"
|
||||
}
|
||||
models = append(models, ModelListItem{
|
||||
ID: id,
|
||||
Object: "model",
|
||||
OwnedBy: "sora",
|
||||
Description: description,
|
||||
})
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
func itoa(val int) string {
|
||||
if val == 0 {
|
||||
return "0"
|
||||
}
|
||||
neg := false
|
||||
if val < 0 {
|
||||
neg = true
|
||||
val = -val
|
||||
}
|
||||
buf := [12]byte{}
|
||||
i := len(buf)
|
||||
for val > 0 {
|
||||
i--
|
||||
buf[i] = byte('0' + val%10)
|
||||
val /= 10
|
||||
}
|
||||
if neg {
|
||||
i--
|
||||
buf[i] = '-'
|
||||
}
|
||||
return string(buf[i:])
|
||||
}
|
||||
63
backend/internal/pkg/sora/prompt.go
Normal file
63
backend/internal/pkg/sora/prompt.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package sora
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var storyboardRe = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]`)
|
||||
|
||||
// IsStoryboardPrompt 检测是否为分镜提示词。
|
||||
func IsStoryboardPrompt(prompt string) bool {
|
||||
if strings.TrimSpace(prompt) == "" {
|
||||
return false
|
||||
}
|
||||
return storyboardRe.MatchString(prompt)
|
||||
}
|
||||
|
||||
// FormatStoryboardPrompt 将分镜提示词转换为 API 需要的格式。
|
||||
func FormatStoryboardPrompt(prompt string) string {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
return prompt
|
||||
}
|
||||
matches := storyboardRe.FindAllStringSubmatchIndex(prompt, -1)
|
||||
if len(matches) == 0 {
|
||||
return prompt
|
||||
}
|
||||
firstIdx := matches[0][0]
|
||||
instructions := strings.TrimSpace(prompt[:firstIdx])
|
||||
|
||||
shotPattern := regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`)
|
||||
shotMatches := shotPattern.FindAllStringSubmatch(prompt, -1)
|
||||
if len(shotMatches) == 0 {
|
||||
return prompt
|
||||
}
|
||||
|
||||
shots := make([]string, 0, len(shotMatches))
|
||||
for i, sm := range shotMatches {
|
||||
if len(sm) < 3 {
|
||||
continue
|
||||
}
|
||||
duration := strings.TrimSpace(sm[1])
|
||||
scene := strings.TrimSpace(sm[2])
|
||||
shots = append(shots, "Shot "+itoa(i+1)+":\nduration: "+duration+"sec\nScene: "+scene)
|
||||
}
|
||||
|
||||
timeline := strings.Join(shots, "\n\n")
|
||||
if instructions != "" {
|
||||
return "current timeline:\n" + timeline + "\n\ninstructions:\n" + instructions
|
||||
}
|
||||
return timeline
|
||||
}
|
||||
|
||||
// ExtractRemixID 提取分享链接中的 remix ID。
|
||||
func ExtractRemixID(text string) string {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
re := regexp.MustCompile(`s_[a-f0-9]{32}`)
|
||||
match := re.FindString(text)
|
||||
return match
|
||||
}
|
||||
31
backend/internal/pkg/uuidv7/uuidv7.go
Normal file
31
backend/internal/pkg/uuidv7/uuidv7.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package uuidv7
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// New returns a UUIDv7 string.
|
||||
func New() (string, error) {
|
||||
var b [16]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
return "", err
|
||||
}
|
||||
ms := uint64(time.Now().UnixMilli())
|
||||
b[0] = byte(ms >> 40)
|
||||
b[1] = byte(ms >> 32)
|
||||
b[2] = byte(ms >> 24)
|
||||
b[3] = byte(ms >> 16)
|
||||
b[4] = byte(ms >> 8)
|
||||
b[5] = byte(ms)
|
||||
b[6] = (b[6] & 0x0f) | 0x70
|
||||
b[8] = (b[8] & 0x3f) | 0x80
|
||||
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
|
||||
uint32(b[0])<<24|uint32(b[1])<<16|uint32(b[2])<<8|uint32(b[3]),
|
||||
uint16(b[4])<<8|uint16(b[5]),
|
||||
uint16(b[6])<<8|uint16(b[7]),
|
||||
uint16(b[8])<<8|uint16(b[9]),
|
||||
uint64(b[10])<<40|uint64(b[11])<<32|uint64(b[12])<<24|uint64(b[13])<<16|uint64(b[14])<<8|uint64(b[15]),
|
||||
), nil
|
||||
}
|
||||
498
backend/internal/repository/sora_repo.go
Normal file
498
backend/internal/repository/sora_repo.go
Normal file
@@ -0,0 +1,498 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
dbsoraaccount "github.com/Wei-Shaw/sub2api/ent/soraaccount"
|
||||
dbsoracachefile "github.com/Wei-Shaw/sub2api/ent/soracachefile"
|
||||
dbsoratask "github.com/Wei-Shaw/sub2api/ent/soratask"
|
||||
dbsorausagestat "github.com/Wei-Shaw/sub2api/ent/sorausagestat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
// SoraAccount
|
||||
|
||||
type soraAccountRepository struct {
|
||||
client *ent.Client
|
||||
}
|
||||
|
||||
func NewSoraAccountRepository(client *ent.Client) service.SoraAccountRepository {
|
||||
return &soraAccountRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
|
||||
if accountID <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
acc, err := r.client.SoraAccount.Query().Where(dbsoraaccount.AccountIDEQ(accountID)).Only(ctx)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return mapSoraAccount(acc), nil
|
||||
}
|
||||
|
||||
func (r *soraAccountRepository) GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*service.SoraAccount, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return map[int64]*service.SoraAccount{}, nil
|
||||
}
|
||||
records, err := r.client.SoraAccount.Query().Where(dbsoraaccount.AccountIDIn(accountIDs...)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make(map[int64]*service.SoraAccount, len(records))
|
||||
for _, record := range records {
|
||||
if record == nil {
|
||||
continue
|
||||
}
|
||||
result[record.AccountID] = mapSoraAccount(record)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
|
||||
if accountID <= 0 {
|
||||
return errors.New("invalid account_id")
|
||||
}
|
||||
acc, err := r.client.SoraAccount.Query().Where(dbsoraaccount.AccountIDEQ(accountID)).Only(ctx)
|
||||
if err != nil && !ent.IsNotFound(err) {
|
||||
return err
|
||||
}
|
||||
if acc == nil {
|
||||
builder := r.client.SoraAccount.Create().SetAccountID(accountID)
|
||||
applySoraAccountUpdates(builder.Mutation(), updates)
|
||||
return builder.Exec(ctx)
|
||||
}
|
||||
updater := r.client.SoraAccount.UpdateOneID(acc.ID)
|
||||
applySoraAccountUpdates(updater.Mutation(), updates)
|
||||
return updater.Exec(ctx)
|
||||
}
|
||||
|
||||
func applySoraAccountUpdates(m *ent.SoraAccountMutation, updates map[string]any) {
|
||||
if updates == nil {
|
||||
return
|
||||
}
|
||||
for key, val := range updates {
|
||||
switch key {
|
||||
case "access_token":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetAccessToken(v)
|
||||
}
|
||||
case "session_token":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetSessionToken(v)
|
||||
}
|
||||
case "refresh_token":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetRefreshToken(v)
|
||||
}
|
||||
case "client_id":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetClientID(v)
|
||||
}
|
||||
case "email":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetEmail(v)
|
||||
}
|
||||
case "username":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetUsername(v)
|
||||
}
|
||||
case "remark":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetRemark(v)
|
||||
}
|
||||
case "plan_type":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetPlanType(v)
|
||||
}
|
||||
case "plan_title":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetPlanTitle(v)
|
||||
}
|
||||
case "subscription_end":
|
||||
if v, ok := val.(time.Time); ok {
|
||||
m.SetSubscriptionEnd(v)
|
||||
}
|
||||
if v, ok := val.(*time.Time); ok && v != nil {
|
||||
m.SetSubscriptionEnd(*v)
|
||||
}
|
||||
case "sora_supported":
|
||||
if v, ok := val.(bool); ok {
|
||||
m.SetSoraSupported(v)
|
||||
}
|
||||
case "sora_invite_code":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetSoraInviteCode(v)
|
||||
}
|
||||
case "sora_redeemed_count":
|
||||
if v, ok := val.(int); ok {
|
||||
m.SetSoraRedeemedCount(v)
|
||||
}
|
||||
case "sora_remaining_count":
|
||||
if v, ok := val.(int); ok {
|
||||
m.SetSoraRemainingCount(v)
|
||||
}
|
||||
case "sora_total_count":
|
||||
if v, ok := val.(int); ok {
|
||||
m.SetSoraTotalCount(v)
|
||||
}
|
||||
case "sora_cooldown_until":
|
||||
if v, ok := val.(time.Time); ok {
|
||||
m.SetSoraCooldownUntil(v)
|
||||
}
|
||||
if v, ok := val.(*time.Time); ok && v != nil {
|
||||
m.SetSoraCooldownUntil(*v)
|
||||
}
|
||||
case "cooled_until":
|
||||
if v, ok := val.(time.Time); ok {
|
||||
m.SetCooledUntil(v)
|
||||
}
|
||||
if v, ok := val.(*time.Time); ok && v != nil {
|
||||
m.SetCooledUntil(*v)
|
||||
}
|
||||
case "image_enabled":
|
||||
if v, ok := val.(bool); ok {
|
||||
m.SetImageEnabled(v)
|
||||
}
|
||||
case "video_enabled":
|
||||
if v, ok := val.(bool); ok {
|
||||
m.SetVideoEnabled(v)
|
||||
}
|
||||
case "image_concurrency":
|
||||
if v, ok := val.(int); ok {
|
||||
m.SetImageConcurrency(v)
|
||||
}
|
||||
case "video_concurrency":
|
||||
if v, ok := val.(int); ok {
|
||||
m.SetVideoConcurrency(v)
|
||||
}
|
||||
case "is_expired":
|
||||
if v, ok := val.(bool); ok {
|
||||
m.SetIsExpired(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func mapSoraAccount(acc *ent.SoraAccount) *service.SoraAccount {
|
||||
if acc == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.SoraAccount{
|
||||
AccountID: acc.AccountID,
|
||||
AccessToken: derefString(acc.AccessToken),
|
||||
SessionToken: derefString(acc.SessionToken),
|
||||
RefreshToken: derefString(acc.RefreshToken),
|
||||
ClientID: derefString(acc.ClientID),
|
||||
Email: derefString(acc.Email),
|
||||
Username: derefString(acc.Username),
|
||||
Remark: derefString(acc.Remark),
|
||||
UseCount: acc.UseCount,
|
||||
PlanType: derefString(acc.PlanType),
|
||||
PlanTitle: derefString(acc.PlanTitle),
|
||||
SubscriptionEnd: acc.SubscriptionEnd,
|
||||
SoraSupported: acc.SoraSupported,
|
||||
SoraInviteCode: derefString(acc.SoraInviteCode),
|
||||
SoraRedeemedCount: acc.SoraRedeemedCount,
|
||||
SoraRemainingCount: acc.SoraRemainingCount,
|
||||
SoraTotalCount: acc.SoraTotalCount,
|
||||
SoraCooldownUntil: acc.SoraCooldownUntil,
|
||||
CooledUntil: acc.CooledUntil,
|
||||
ImageEnabled: acc.ImageEnabled,
|
||||
VideoEnabled: acc.VideoEnabled,
|
||||
ImageConcurrency: acc.ImageConcurrency,
|
||||
VideoConcurrency: acc.VideoConcurrency,
|
||||
IsExpired: acc.IsExpired,
|
||||
CreatedAt: acc.CreatedAt,
|
||||
UpdatedAt: acc.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func mapSoraUsageStat(stat *ent.SoraUsageStat) *service.SoraUsageStat {
|
||||
if stat == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.SoraUsageStat{
|
||||
AccountID: stat.AccountID,
|
||||
ImageCount: stat.ImageCount,
|
||||
VideoCount: stat.VideoCount,
|
||||
ErrorCount: stat.ErrorCount,
|
||||
LastErrorAt: stat.LastErrorAt,
|
||||
TodayImageCount: stat.TodayImageCount,
|
||||
TodayVideoCount: stat.TodayVideoCount,
|
||||
TodayErrorCount: stat.TodayErrorCount,
|
||||
TodayDate: stat.TodayDate,
|
||||
ConsecutiveErrorCount: stat.ConsecutiveErrorCount,
|
||||
CreatedAt: stat.CreatedAt,
|
||||
UpdatedAt: stat.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func mapSoraCacheFile(file *ent.SoraCacheFile) *service.SoraCacheFile {
|
||||
if file == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.SoraCacheFile{
|
||||
ID: int64(file.ID),
|
||||
TaskID: derefString(file.TaskID),
|
||||
AccountID: file.AccountID,
|
||||
UserID: file.UserID,
|
||||
MediaType: file.MediaType,
|
||||
OriginalURL: file.OriginalURL,
|
||||
CachePath: file.CachePath,
|
||||
CacheURL: file.CacheURL,
|
||||
SizeBytes: file.SizeBytes,
|
||||
CreatedAt: file.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// SoraUsageStat
|
||||
|
||||
type soraUsageStatRepository struct {
|
||||
client *ent.Client
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
func NewSoraUsageStatRepository(client *ent.Client, sqlDB *sql.DB) service.SoraUsageStatRepository {
|
||||
return &soraUsageStatRepository{client: client, sql: sqlDB}
|
||||
}
|
||||
|
||||
func (r *soraUsageStatRepository) RecordSuccess(ctx context.Context, accountID int64, isVideo bool) error {
|
||||
if accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
field := "image_count"
|
||||
todayField := "today_image_count"
|
||||
if isVideo {
|
||||
field = "video_count"
|
||||
todayField = "today_video_count"
|
||||
}
|
||||
today := time.Now().UTC().Truncate(24 * time.Hour)
|
||||
query := "INSERT INTO sora_usage_stats (account_id, " + field + ", " + todayField + ", today_date, consecutive_error_count, created_at, updated_at) " +
|
||||
"VALUES ($1, 1, 1, $2, 0, NOW(), NOW()) " +
|
||||
"ON CONFLICT (account_id) DO UPDATE SET " +
|
||||
field + " = sora_usage_stats." + field + " + 1, " +
|
||||
todayField + " = CASE WHEN sora_usage_stats.today_date = $2 THEN sora_usage_stats." + todayField + " + 1 ELSE 1 END, " +
|
||||
"today_date = $2, consecutive_error_count = 0, updated_at = NOW()"
|
||||
_, err := r.sql.ExecContext(ctx, query, accountID, today)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *soraUsageStatRepository) RecordError(ctx context.Context, accountID int64) (int, error) {
|
||||
if accountID <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
today := time.Now().UTC().Truncate(24 * time.Hour)
|
||||
query := "INSERT INTO sora_usage_stats (account_id, error_count, today_error_count, today_date, consecutive_error_count, last_error_at, created_at, updated_at) " +
|
||||
"VALUES ($1, 1, 1, $2, 1, NOW(), NOW(), NOW()) " +
|
||||
"ON CONFLICT (account_id) DO UPDATE SET " +
|
||||
"error_count = sora_usage_stats.error_count + 1, " +
|
||||
"today_error_count = CASE WHEN sora_usage_stats.today_date = $2 THEN sora_usage_stats.today_error_count + 1 ELSE 1 END, " +
|
||||
"today_date = $2, consecutive_error_count = sora_usage_stats.consecutive_error_count + 1, last_error_at = NOW(), updated_at = NOW() " +
|
||||
"RETURNING consecutive_error_count"
|
||||
var consecutive int
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{accountID, today}, &consecutive)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return consecutive, nil
|
||||
}
|
||||
|
||||
func (r *soraUsageStatRepository) ResetConsecutiveErrors(ctx context.Context, accountID int64) error {
|
||||
if accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
err := r.client.SoraUsageStat.Update().Where(dbsorausagestat.AccountIDEQ(accountID)).
|
||||
SetConsecutiveErrorCount(0).
|
||||
Exec(ctx)
|
||||
if ent.IsNotFound(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *soraUsageStatRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraUsageStat, error) {
|
||||
if accountID <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
stat, err := r.client.SoraUsageStat.Query().Where(dbsorausagestat.AccountIDEQ(accountID)).Only(ctx)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return mapSoraUsageStat(stat), nil
|
||||
}
|
||||
|
||||
func (r *soraUsageStatRepository) GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*service.SoraUsageStat, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return map[int64]*service.SoraUsageStat{}, nil
|
||||
}
|
||||
stats, err := r.client.SoraUsageStat.Query().Where(dbsorausagestat.AccountIDIn(accountIDs...)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make(map[int64]*service.SoraUsageStat, len(stats))
|
||||
for _, stat := range stats {
|
||||
if stat == nil {
|
||||
continue
|
||||
}
|
||||
result[stat.AccountID] = mapSoraUsageStat(stat)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *soraUsageStatRepository) List(ctx context.Context, params pagination.PaginationParams) ([]*service.SoraUsageStat, *pagination.PaginationResult, error) {
|
||||
query := r.client.SoraUsageStat.Query()
|
||||
total, err := query.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
stats, err := query.Order(ent.Desc(dbsorausagestat.FieldUpdatedAt)).
|
||||
Limit(params.Limit()).
|
||||
Offset(params.Offset()).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
result := make([]*service.SoraUsageStat, 0, len(stats))
|
||||
for _, stat := range stats {
|
||||
result = append(result, mapSoraUsageStat(stat))
|
||||
}
|
||||
return result, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// SoraTask
|
||||
|
||||
type soraTaskRepository struct {
|
||||
client *ent.Client
|
||||
}
|
||||
|
||||
func NewSoraTaskRepository(client *ent.Client) service.SoraTaskRepository {
|
||||
return &soraTaskRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *soraTaskRepository) Create(ctx context.Context, task *service.SoraTask) error {
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
builder := r.client.SoraTask.Create().
|
||||
SetTaskID(task.TaskID).
|
||||
SetAccountID(task.AccountID).
|
||||
SetModel(task.Model).
|
||||
SetPrompt(task.Prompt).
|
||||
SetStatus(task.Status).
|
||||
SetProgress(task.Progress).
|
||||
SetRetryCount(task.RetryCount)
|
||||
if task.ResultURLs != "" {
|
||||
builder.SetResultUrls(task.ResultURLs)
|
||||
}
|
||||
if task.ErrorMessage != "" {
|
||||
builder.SetErrorMessage(task.ErrorMessage)
|
||||
}
|
||||
if task.CreatedAt.IsZero() {
|
||||
builder.SetCreatedAt(time.Now())
|
||||
} else {
|
||||
builder.SetCreatedAt(task.CreatedAt)
|
||||
}
|
||||
if task.CompletedAt != nil {
|
||||
builder.SetCompletedAt(*task.CompletedAt)
|
||||
}
|
||||
return builder.Exec(ctx)
|
||||
}
|
||||
|
||||
func (r *soraTaskRepository) UpdateStatus(ctx context.Context, taskID string, status string, progress float64, resultURLs string, errorMessage string, completedAt *time.Time) error {
|
||||
if taskID == "" {
|
||||
return nil
|
||||
}
|
||||
builder := r.client.SoraTask.Update().Where(dbsoratask.TaskIDEQ(taskID)).
|
||||
SetStatus(status).
|
||||
SetProgress(progress)
|
||||
if resultURLs != "" {
|
||||
builder.SetResultUrls(resultURLs)
|
||||
}
|
||||
if errorMessage != "" {
|
||||
builder.SetErrorMessage(errorMessage)
|
||||
}
|
||||
if completedAt != nil {
|
||||
builder.SetCompletedAt(*completedAt)
|
||||
}
|
||||
_, err := builder.Save(ctx)
|
||||
if ent.IsNotFound(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// SoraCacheFile
|
||||
|
||||
type soraCacheFileRepository struct {
|
||||
client *ent.Client
|
||||
}
|
||||
|
||||
func NewSoraCacheFileRepository(client *ent.Client) service.SoraCacheFileRepository {
|
||||
return &soraCacheFileRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *soraCacheFileRepository) Create(ctx context.Context, file *service.SoraCacheFile) error {
|
||||
if file == nil {
|
||||
return nil
|
||||
}
|
||||
builder := r.client.SoraCacheFile.Create().
|
||||
SetAccountID(file.AccountID).
|
||||
SetUserID(file.UserID).
|
||||
SetMediaType(file.MediaType).
|
||||
SetOriginalURL(file.OriginalURL).
|
||||
SetCachePath(file.CachePath).
|
||||
SetCacheURL(file.CacheURL).
|
||||
SetSizeBytes(file.SizeBytes)
|
||||
if file.TaskID != "" {
|
||||
builder.SetTaskID(file.TaskID)
|
||||
}
|
||||
if file.CreatedAt.IsZero() {
|
||||
builder.SetCreatedAt(time.Now())
|
||||
} else {
|
||||
builder.SetCreatedAt(file.CreatedAt)
|
||||
}
|
||||
return builder.Exec(ctx)
|
||||
}
|
||||
|
||||
func (r *soraCacheFileRepository) ListOldest(ctx context.Context, limit int) ([]*service.SoraCacheFile, error) {
|
||||
if limit <= 0 {
|
||||
return []*service.SoraCacheFile{}, nil
|
||||
}
|
||||
records, err := r.client.SoraCacheFile.Query().
|
||||
Order(dbsoracachefile.ByCreatedAt(entsql.OrderAsc())).
|
||||
Limit(limit).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]*service.SoraCacheFile, 0, len(records))
|
||||
for _, record := range records {
|
||||
if record == nil {
|
||||
continue
|
||||
}
|
||||
result = append(result, mapSoraCacheFile(record))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *soraCacheFileRepository) DeleteByIDs(ctx context.Context, ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := r.client.SoraCacheFile.Delete().Where(dbsoracachefile.IDIn(ids...)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
@@ -64,6 +64,10 @@ var ProviderSet = wire.NewSet(
|
||||
NewUserSubscriptionRepository,
|
||||
NewUserAttributeDefinitionRepository,
|
||||
NewUserAttributeValueRepository,
|
||||
NewSoraAccountRepository,
|
||||
NewSoraUsageStatRepository,
|
||||
NewSoraTaskRepository,
|
||||
NewSoraCacheFileRepository,
|
||||
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
@@ -46,6 +49,22 @@ func SetupRouter(
|
||||
}
|
||||
}
|
||||
|
||||
// Serve Sora cached videos when enabled
|
||||
cacheVideoDir := ""
|
||||
cacheEnabled := false
|
||||
if settingService != nil {
|
||||
soraCfg := settingService.GetSoraConfig(context.Background())
|
||||
cacheEnabled = soraCfg.Cache.Enabled
|
||||
cacheVideoDir = strings.TrimSpace(soraCfg.Cache.VideoDir)
|
||||
} else if cfg != nil {
|
||||
cacheEnabled = cfg.Sora.Cache.Enabled
|
||||
cacheVideoDir = strings.TrimSpace(cfg.Sora.Cache.VideoDir)
|
||||
}
|
||||
if cacheEnabled && cacheVideoDir != "" {
|
||||
videoDir := filepath.Clean(cacheVideoDir)
|
||||
r.Static("/data/video", videoDir)
|
||||
}
|
||||
|
||||
// 注册路由
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg, redisClient)
|
||||
|
||||
|
||||
@@ -29,6 +29,9 @@ func RegisterAdminRoutes(
|
||||
// 账号管理
|
||||
registerAccountRoutes(admin, h)
|
||||
|
||||
// Sora 账号扩展
|
||||
registerSoraRoutes(admin, h)
|
||||
|
||||
// OpenAI OAuth
|
||||
registerOpenAIOAuthRoutes(admin, h)
|
||||
|
||||
@@ -229,6 +232,17 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerSoraRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
sora := admin.Group("/sora")
|
||||
{
|
||||
sora.GET("/accounts", h.Admin.SoraAccount.List)
|
||||
sora.GET("/accounts/:id", h.Admin.SoraAccount.Get)
|
||||
sora.PUT("/accounts/:id", h.Admin.SoraAccount.Upsert)
|
||||
sora.POST("/accounts/import", h.Admin.SoraAccount.BatchUpsert)
|
||||
sora.GET("/usage", h.Admin.SoraAccount.ListUsage)
|
||||
}
|
||||
}
|
||||
|
||||
func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
openai := admin.Group("/openai")
|
||||
{
|
||||
|
||||
@@ -33,6 +33,7 @@ func RegisterGatewayRoutes(
|
||||
gateway.POST("/messages", h.Gateway.Messages)
|
||||
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
|
||||
gateway.GET("/models", h.Gateway.Models)
|
||||
gateway.POST("/chat/completions", h.SoraGateway.ChatCompletions)
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
// OpenAI Responses API
|
||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||
|
||||
@@ -22,6 +22,7 @@ const (
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
PlatformSora = "sora"
|
||||
)
|
||||
|
||||
// Account type constants
|
||||
@@ -124,6 +125,28 @@ const (
|
||||
SettingKeyEnableIdentityPatch = "enable_identity_patch"
|
||||
SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
|
||||
|
||||
// =========================
|
||||
// Sora Settings
|
||||
// =========================
|
||||
|
||||
SettingKeySoraBaseURL = "sora_base_url"
|
||||
SettingKeySoraTimeout = "sora_timeout"
|
||||
SettingKeySoraMaxRetries = "sora_max_retries"
|
||||
SettingKeySoraPollInterval = "sora_poll_interval"
|
||||
SettingKeySoraCallLogicMode = "sora_call_logic_mode"
|
||||
SettingKeySoraCacheEnabled = "sora_cache_enabled"
|
||||
SettingKeySoraCacheBaseDir = "sora_cache_base_dir"
|
||||
SettingKeySoraCacheVideoDir = "sora_cache_video_dir"
|
||||
SettingKeySoraCacheMaxBytes = "sora_cache_max_bytes"
|
||||
SettingKeySoraCacheAllowedHosts = "sora_cache_allowed_hosts"
|
||||
SettingKeySoraCacheUserDirEnabled = "sora_cache_user_dir_enabled"
|
||||
SettingKeySoraWatermarkFreeEnabled = "sora_watermark_free_enabled"
|
||||
SettingKeySoraWatermarkFreeParseMethod = "sora_watermark_free_parse_method"
|
||||
SettingKeySoraWatermarkFreeCustomParseURL = "sora_watermark_free_custom_parse_url"
|
||||
SettingKeySoraWatermarkFreeCustomParseToken = "sora_watermark_free_custom_parse_token"
|
||||
SettingKeySoraWatermarkFreeFallbackOnFailure = "sora_watermark_free_fallback_on_failure"
|
||||
SettingKeySoraTokenRefreshEnabled = "sora_token_refresh_enabled"
|
||||
|
||||
// =========================
|
||||
// Ops Monitoring (vNext)
|
||||
// =========================
|
||||
|
||||
@@ -378,7 +378,7 @@ func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupI
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformSora, PlatformAntigravity}
|
||||
var firstErr error
|
||||
for _, platform := range platforms {
|
||||
if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason); err != nil && firstErr == nil {
|
||||
@@ -661,7 +661,7 @@ func (s *SchedulerSnapshotService) fullRebuildInterval() time.Duration {
|
||||
|
||||
func (s *SchedulerSnapshotService) defaultBuckets(ctx context.Context) ([]SchedulerBucket, error) {
|
||||
buckets := make([]SchedulerBucket, 0)
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformSora, PlatformAntigravity}
|
||||
for _, platform := range platforms {
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeSingle})
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeForced})
|
||||
|
||||
@@ -219,6 +219,29 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyEnableIdentityPatch] = strconv.FormatBool(settings.EnableIdentityPatch)
|
||||
updates[SettingKeyIdentityPatchPrompt] = settings.IdentityPatchPrompt
|
||||
|
||||
// Sora settings
|
||||
updates[SettingKeySoraBaseURL] = strings.TrimSpace(settings.SoraBaseURL)
|
||||
updates[SettingKeySoraTimeout] = strconv.Itoa(settings.SoraTimeout)
|
||||
updates[SettingKeySoraMaxRetries] = strconv.Itoa(settings.SoraMaxRetries)
|
||||
updates[SettingKeySoraPollInterval] = strconv.FormatFloat(settings.SoraPollInterval, 'f', -1, 64)
|
||||
updates[SettingKeySoraCallLogicMode] = settings.SoraCallLogicMode
|
||||
updates[SettingKeySoraCacheEnabled] = strconv.FormatBool(settings.SoraCacheEnabled)
|
||||
updates[SettingKeySoraCacheBaseDir] = settings.SoraCacheBaseDir
|
||||
updates[SettingKeySoraCacheVideoDir] = settings.SoraCacheVideoDir
|
||||
updates[SettingKeySoraCacheMaxBytes] = strconv.FormatInt(settings.SoraCacheMaxBytes, 10)
|
||||
allowedHostsRaw, err := marshalStringSliceSetting(settings.SoraCacheAllowedHosts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal sora cache allowed hosts: %w", err)
|
||||
}
|
||||
updates[SettingKeySoraCacheAllowedHosts] = allowedHostsRaw
|
||||
updates[SettingKeySoraCacheUserDirEnabled] = strconv.FormatBool(settings.SoraCacheUserDirEnabled)
|
||||
updates[SettingKeySoraWatermarkFreeEnabled] = strconv.FormatBool(settings.SoraWatermarkFreeEnabled)
|
||||
updates[SettingKeySoraWatermarkFreeParseMethod] = settings.SoraWatermarkFreeParseMethod
|
||||
updates[SettingKeySoraWatermarkFreeCustomParseURL] = strings.TrimSpace(settings.SoraWatermarkFreeCustomParseURL)
|
||||
updates[SettingKeySoraWatermarkFreeCustomParseToken] = settings.SoraWatermarkFreeCustomParseToken
|
||||
updates[SettingKeySoraWatermarkFreeFallbackOnFailure] = strconv.FormatBool(settings.SoraWatermarkFreeFallbackOnFailure)
|
||||
updates[SettingKeySoraTokenRefreshEnabled] = strconv.FormatBool(settings.SoraTokenRefreshEnabled)
|
||||
|
||||
// Ops monitoring (vNext)
|
||||
updates[SettingKeyOpsMonitoringEnabled] = strconv.FormatBool(settings.OpsMonitoringEnabled)
|
||||
updates[SettingKeyOpsRealtimeMonitoringEnabled] = strconv.FormatBool(settings.OpsRealtimeMonitoringEnabled)
|
||||
@@ -227,7 +250,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyOpsMetricsIntervalSeconds] = strconv.Itoa(settings.OpsMetricsIntervalSeconds)
|
||||
}
|
||||
|
||||
err := s.settingRepo.SetMultiple(ctx, updates)
|
||||
err = s.settingRepo.SetMultiple(ctx, updates)
|
||||
if err == nil && s.onUpdate != nil {
|
||||
s.onUpdate() // Invalidate cache after settings update
|
||||
}
|
||||
@@ -295,6 +318,41 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
|
||||
return s.cfg.Default.UserBalance
|
||||
}
|
||||
|
||||
// GetSoraConfig 获取 Sora 配置(优先读取 DB 设置,回退 config.yaml)
|
||||
func (s *SettingService) GetSoraConfig(ctx context.Context) config.SoraConfig {
|
||||
base := config.SoraConfig{}
|
||||
if s.cfg != nil {
|
||||
base = s.cfg.Sora
|
||||
}
|
||||
if s.settingRepo == nil {
|
||||
return base
|
||||
}
|
||||
keys := []string{
|
||||
SettingKeySoraBaseURL,
|
||||
SettingKeySoraTimeout,
|
||||
SettingKeySoraMaxRetries,
|
||||
SettingKeySoraPollInterval,
|
||||
SettingKeySoraCallLogicMode,
|
||||
SettingKeySoraCacheEnabled,
|
||||
SettingKeySoraCacheBaseDir,
|
||||
SettingKeySoraCacheVideoDir,
|
||||
SettingKeySoraCacheMaxBytes,
|
||||
SettingKeySoraCacheAllowedHosts,
|
||||
SettingKeySoraCacheUserDirEnabled,
|
||||
SettingKeySoraWatermarkFreeEnabled,
|
||||
SettingKeySoraWatermarkFreeParseMethod,
|
||||
SettingKeySoraWatermarkFreeCustomParseURL,
|
||||
SettingKeySoraWatermarkFreeCustomParseToken,
|
||||
SettingKeySoraWatermarkFreeFallbackOnFailure,
|
||||
SettingKeySoraTokenRefreshEnabled,
|
||||
}
|
||||
values, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
if err != nil {
|
||||
return base
|
||||
}
|
||||
return mergeSoraConfig(base, values)
|
||||
}
|
||||
|
||||
// InitializeDefaultSettings 初始化默认设置
|
||||
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
// 检查是否已有设置
|
||||
@@ -308,6 +366,12 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// 初始化默认设置
|
||||
soraCfg := config.SoraConfig{}
|
||||
if s.cfg != nil {
|
||||
soraCfg = s.cfg.Sora
|
||||
}
|
||||
allowedHostsRaw, _ := marshalStringSliceSetting(soraCfg.Cache.AllowedHosts)
|
||||
|
||||
defaults := map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "false",
|
||||
@@ -328,6 +392,25 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyEnableIdentityPatch: "true",
|
||||
SettingKeyIdentityPatchPrompt: "",
|
||||
|
||||
// Sora defaults
|
||||
SettingKeySoraBaseURL: soraCfg.BaseURL,
|
||||
SettingKeySoraTimeout: strconv.Itoa(soraCfg.Timeout),
|
||||
SettingKeySoraMaxRetries: strconv.Itoa(soraCfg.MaxRetries),
|
||||
SettingKeySoraPollInterval: strconv.FormatFloat(soraCfg.PollInterval, 'f', -1, 64),
|
||||
SettingKeySoraCallLogicMode: soraCfg.CallLogicMode,
|
||||
SettingKeySoraCacheEnabled: strconv.FormatBool(soraCfg.Cache.Enabled),
|
||||
SettingKeySoraCacheBaseDir: soraCfg.Cache.BaseDir,
|
||||
SettingKeySoraCacheVideoDir: soraCfg.Cache.VideoDir,
|
||||
SettingKeySoraCacheMaxBytes: strconv.FormatInt(soraCfg.Cache.MaxBytes, 10),
|
||||
SettingKeySoraCacheAllowedHosts: allowedHostsRaw,
|
||||
SettingKeySoraCacheUserDirEnabled: strconv.FormatBool(soraCfg.Cache.UserDirEnabled),
|
||||
SettingKeySoraWatermarkFreeEnabled: strconv.FormatBool(soraCfg.WatermarkFree.Enabled),
|
||||
SettingKeySoraWatermarkFreeParseMethod: soraCfg.WatermarkFree.ParseMethod,
|
||||
SettingKeySoraWatermarkFreeCustomParseURL: soraCfg.WatermarkFree.CustomParseURL,
|
||||
SettingKeySoraWatermarkFreeCustomParseToken: soraCfg.WatermarkFree.CustomParseToken,
|
||||
SettingKeySoraWatermarkFreeFallbackOnFailure: strconv.FormatBool(soraCfg.WatermarkFree.FallbackOnFailure),
|
||||
SettingKeySoraTokenRefreshEnabled: strconv.FormatBool(soraCfg.TokenRefresh.Enabled),
|
||||
|
||||
// Ops monitoring defaults (vNext)
|
||||
SettingKeyOpsMonitoringEnabled: "true",
|
||||
SettingKeyOpsRealtimeMonitoringEnabled: "true",
|
||||
@@ -434,6 +517,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
}
|
||||
result.IdentityPatchPrompt = settings[SettingKeyIdentityPatchPrompt]
|
||||
|
||||
// Sora settings
|
||||
soraCfg := s.parseSoraConfig(settings)
|
||||
result.SoraBaseURL = soraCfg.BaseURL
|
||||
result.SoraTimeout = soraCfg.Timeout
|
||||
result.SoraMaxRetries = soraCfg.MaxRetries
|
||||
result.SoraPollInterval = soraCfg.PollInterval
|
||||
result.SoraCallLogicMode = soraCfg.CallLogicMode
|
||||
result.SoraCacheEnabled = soraCfg.Cache.Enabled
|
||||
result.SoraCacheBaseDir = soraCfg.Cache.BaseDir
|
||||
result.SoraCacheVideoDir = soraCfg.Cache.VideoDir
|
||||
result.SoraCacheMaxBytes = soraCfg.Cache.MaxBytes
|
||||
result.SoraCacheAllowedHosts = soraCfg.Cache.AllowedHosts
|
||||
result.SoraCacheUserDirEnabled = soraCfg.Cache.UserDirEnabled
|
||||
result.SoraWatermarkFreeEnabled = soraCfg.WatermarkFree.Enabled
|
||||
result.SoraWatermarkFreeParseMethod = soraCfg.WatermarkFree.ParseMethod
|
||||
result.SoraWatermarkFreeCustomParseURL = soraCfg.WatermarkFree.CustomParseURL
|
||||
result.SoraWatermarkFreeCustomParseToken = soraCfg.WatermarkFree.CustomParseToken
|
||||
result.SoraWatermarkFreeFallbackOnFailure = soraCfg.WatermarkFree.FallbackOnFailure
|
||||
result.SoraTokenRefreshEnabled = soraCfg.TokenRefresh.Enabled
|
||||
|
||||
// Ops monitoring settings (default: enabled, fail-open)
|
||||
result.OpsMonitoringEnabled = !isFalseSettingValue(settings[SettingKeyOpsMonitoringEnabled])
|
||||
result.OpsRealtimeMonitoringEnabled = !isFalseSettingValue(settings[SettingKeyOpsRealtimeMonitoringEnabled])
|
||||
@@ -471,6 +574,131 @@ func (s *SettingService) getStringOrDefault(settings map[string]string, key, def
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func (s *SettingService) parseSoraConfig(settings map[string]string) config.SoraConfig {
|
||||
base := config.SoraConfig{}
|
||||
if s.cfg != nil {
|
||||
base = s.cfg.Sora
|
||||
}
|
||||
return mergeSoraConfig(base, settings)
|
||||
}
|
||||
|
||||
func mergeSoraConfig(base config.SoraConfig, settings map[string]string) config.SoraConfig {
|
||||
cfg := base
|
||||
if settings == nil {
|
||||
return cfg
|
||||
}
|
||||
if raw, ok := settings[SettingKeySoraBaseURL]; ok {
|
||||
if trimmed := strings.TrimSpace(raw); trimmed != "" {
|
||||
cfg.BaseURL = trimmed
|
||||
}
|
||||
}
|
||||
if raw, ok := settings[SettingKeySoraTimeout]; ok {
|
||||
if v, err := strconv.Atoi(strings.TrimSpace(raw)); err == nil && v > 0 {
|
||||
cfg.Timeout = v
|
||||
}
|
||||
}
|
||||
if raw, ok := settings[SettingKeySoraMaxRetries]; ok {
|
||||
if v, err := strconv.Atoi(strings.TrimSpace(raw)); err == nil && v >= 0 {
|
||||
cfg.MaxRetries = v
|
||||
}
|
||||
}
|
||||
if raw, ok := settings[SettingKeySoraPollInterval]; ok {
|
||||
if v, err := strconv.ParseFloat(strings.TrimSpace(raw), 64); err == nil && v > 0 {
|
||||
cfg.PollInterval = v
|
||||
}
|
||||
}
|
||||
if raw, ok := settings[SettingKeySoraCallLogicMode]; ok && strings.TrimSpace(raw) != "" {
|
||||
cfg.CallLogicMode = strings.TrimSpace(raw)
|
||||
}
|
||||
if raw, ok := settings[SettingKeySoraCacheEnabled]; ok {
|
||||
cfg.Cache.Enabled = parseBoolSetting(raw, cfg.Cache.Enabled)
|
||||
}
|
||||
if raw, ok := settings[SettingKeySoraCacheBaseDir]; ok && strings.TrimSpace(raw) != "" {
|
||||
cfg.Cache.BaseDir = strings.TrimSpace(raw)
|
||||
}
|
||||
if raw, ok := settings[SettingKeySoraCacheVideoDir]; ok && strings.TrimSpace(raw) != "" {
|
||||
cfg.Cache.VideoDir = strings.TrimSpace(raw)
|
||||
}
|
||||
if raw, ok := settings[SettingKeySoraCacheMaxBytes]; ok {
|
||||
if v, err := strconv.ParseInt(strings.TrimSpace(raw), 10, 64); err == nil && v >= 0 {
|
||||
cfg.Cache.MaxBytes = v
|
||||
}
|
||||
}
|
||||
if raw, ok := settings[SettingKeySoraCacheAllowedHosts]; ok {
|
||||
cfg.Cache.AllowedHosts = parseStringSliceSetting(raw)
|
||||
}
|
||||
if raw, ok := settings[SettingKeySoraCacheUserDirEnabled]; ok {
|
||||
cfg.Cache.UserDirEnabled = parseBoolSetting(raw, cfg.Cache.UserDirEnabled)
|
||||
}
|
||||
if raw, ok := settings[SettingKeySoraWatermarkFreeEnabled]; ok {
|
||||
cfg.WatermarkFree.Enabled = parseBoolSetting(raw, cfg.WatermarkFree.Enabled)
|
||||
}
|
||||
if raw, ok := settings[SettingKeySoraWatermarkFreeParseMethod]; ok && strings.TrimSpace(raw) != "" {
|
||||
cfg.WatermarkFree.ParseMethod = strings.TrimSpace(raw)
|
||||
}
|
||||
if raw, ok := settings[SettingKeySoraWatermarkFreeCustomParseURL]; ok && strings.TrimSpace(raw) != "" {
|
||||
cfg.WatermarkFree.CustomParseURL = strings.TrimSpace(raw)
|
||||
}
|
||||
if raw, ok := settings[SettingKeySoraWatermarkFreeCustomParseToken]; ok {
|
||||
cfg.WatermarkFree.CustomParseToken = raw
|
||||
}
|
||||
if raw, ok := settings[SettingKeySoraWatermarkFreeFallbackOnFailure]; ok {
|
||||
cfg.WatermarkFree.FallbackOnFailure = parseBoolSetting(raw, cfg.WatermarkFree.FallbackOnFailure)
|
||||
}
|
||||
if raw, ok := settings[SettingKeySoraTokenRefreshEnabled]; ok {
|
||||
cfg.TokenRefresh.Enabled = parseBoolSetting(raw, cfg.TokenRefresh.Enabled)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func parseBoolSetting(raw string, fallback bool) bool {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return fallback
|
||||
}
|
||||
if v, err := strconv.ParseBool(trimmed); err == nil {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func parseStringSliceSetting(raw string) []string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return []string{}
|
||||
}
|
||||
var values []string
|
||||
if err := json.Unmarshal([]byte(trimmed), &values); err == nil {
|
||||
return normalizeStringSlice(values)
|
||||
}
|
||||
parts := strings.FieldsFunc(trimmed, func(r rune) bool {
|
||||
return r == ',' || r == '\n' || r == ';'
|
||||
})
|
||||
return normalizeStringSlice(parts)
|
||||
}
|
||||
|
||||
func marshalStringSliceSetting(values []string) (string, error) {
|
||||
normalized := normalizeStringSlice(values)
|
||||
data, err := json.Marshal(normalized)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func normalizeStringSlice(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
normalized := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
// IsTurnstileEnabled 检查是否启用 Turnstile 验证
|
||||
func (s *SettingService) IsTurnstileEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyTurnstileEnabled)
|
||||
|
||||
@@ -49,6 +49,25 @@ type SystemSettings struct {
|
||||
EnableIdentityPatch bool `json:"enable_identity_patch"`
|
||||
IdentityPatchPrompt string `json:"identity_patch_prompt"`
|
||||
|
||||
// Sora configuration
|
||||
SoraBaseURL string
|
||||
SoraTimeout int
|
||||
SoraMaxRetries int
|
||||
SoraPollInterval float64
|
||||
SoraCallLogicMode string
|
||||
SoraCacheEnabled bool
|
||||
SoraCacheBaseDir string
|
||||
SoraCacheVideoDir string
|
||||
SoraCacheMaxBytes int64
|
||||
SoraCacheAllowedHosts []string
|
||||
SoraCacheUserDirEnabled bool
|
||||
SoraWatermarkFreeEnabled bool
|
||||
SoraWatermarkFreeParseMethod string
|
||||
SoraWatermarkFreeCustomParseURL string
|
||||
SoraWatermarkFreeCustomParseToken string
|
||||
SoraWatermarkFreeFallbackOnFailure bool
|
||||
SoraTokenRefreshEnabled bool
|
||||
|
||||
// Ops monitoring (vNext)
|
||||
OpsMonitoringEnabled bool
|
||||
OpsRealtimeMonitoringEnabled bool
|
||||
|
||||
156
backend/internal/service/sora_cache_cleanup_service.go
Normal file
156
backend/internal/service/sora_cache_cleanup_service.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
soraCacheCleanupInterval = time.Hour
|
||||
soraCacheCleanupBatch = 200
|
||||
)
|
||||
|
||||
// SoraCacheCleanupService 负责清理 Sora 视频缓存文件。
|
||||
type SoraCacheCleanupService struct {
|
||||
cacheRepo SoraCacheFileRepository
|
||||
settingService *SettingService
|
||||
cfg *config.Config
|
||||
stopCh chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
func NewSoraCacheCleanupService(cacheRepo SoraCacheFileRepository, settingService *SettingService, cfg *config.Config) *SoraCacheCleanupService {
|
||||
return &SoraCacheCleanupService{
|
||||
cacheRepo: cacheRepo,
|
||||
settingService: settingService,
|
||||
cfg: cfg,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraCacheCleanupService) Start() {
|
||||
if s == nil || s.cacheRepo == nil {
|
||||
return
|
||||
}
|
||||
go s.cleanupLoop()
|
||||
}
|
||||
|
||||
func (s *SoraCacheCleanupService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SoraCacheCleanupService) cleanupLoop() {
|
||||
ticker := time.NewTicker(soraCacheCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
s.cleanupOnce()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.cleanupOnce()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraCacheCleanupService) cleanupOnce() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
if s.cacheRepo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
cfg := s.getSoraConfig(ctx)
|
||||
videoDir := strings.TrimSpace(cfg.Cache.VideoDir)
|
||||
if videoDir == "" {
|
||||
return
|
||||
}
|
||||
maxBytes := cfg.Cache.MaxBytes
|
||||
if maxBytes <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
size, err := dirSize(videoDir)
|
||||
if err != nil {
|
||||
log.Printf("[SoraCacheCleanup] 计算目录大小失败: %v", err)
|
||||
return
|
||||
}
|
||||
if size <= maxBytes {
|
||||
return
|
||||
}
|
||||
|
||||
for size > maxBytes {
|
||||
entries, err := s.cacheRepo.ListOldest(ctx, soraCacheCleanupBatch)
|
||||
if err != nil {
|
||||
log.Printf("[SoraCacheCleanup] 读取缓存记录失败: %v", err)
|
||||
return
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
log.Printf("[SoraCacheCleanup] 无缓存记录但目录仍超限: size=%d max=%d", size, maxBytes)
|
||||
return
|
||||
}
|
||||
|
||||
ids := make([]int64, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
if entry == nil {
|
||||
continue
|
||||
}
|
||||
removedSize := entry.SizeBytes
|
||||
if entry.CachePath != "" {
|
||||
if info, err := os.Stat(entry.CachePath); err == nil {
|
||||
if removedSize <= 0 {
|
||||
removedSize = info.Size()
|
||||
}
|
||||
}
|
||||
if err := os.Remove(entry.CachePath); err != nil && !os.IsNotExist(err) {
|
||||
log.Printf("[SoraCacheCleanup] 删除缓存文件失败: path=%s err=%v", entry.CachePath, err)
|
||||
}
|
||||
}
|
||||
|
||||
if entry.ID > 0 {
|
||||
ids = append(ids, entry.ID)
|
||||
}
|
||||
if removedSize > 0 {
|
||||
size -= removedSize
|
||||
if size < 0 {
|
||||
size = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(ids) > 0 {
|
||||
if err := s.cacheRepo.DeleteByIDs(ctx, ids); err != nil {
|
||||
log.Printf("[SoraCacheCleanup] 删除缓存记录失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if size > maxBytes {
|
||||
if refreshed, err := dirSize(videoDir); err == nil {
|
||||
size = refreshed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraCacheCleanupService) getSoraConfig(ctx context.Context) config.SoraConfig {
|
||||
if s.settingService != nil {
|
||||
return s.settingService.GetSoraConfig(ctx)
|
||||
}
|
||||
if s.cfg != nil {
|
||||
return s.cfg.Sora
|
||||
}
|
||||
return config.SoraConfig{}
|
||||
}
|
||||
246
backend/internal/service/sora_cache_service.go
Normal file
246
backend/internal/service/sora_cache_service.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/uuidv7"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
|
||||
// SoraCacheService 提供 Sora 视频缓存能力。
|
||||
type SoraCacheService struct {
|
||||
cfg *config.Config
|
||||
cacheRepo SoraCacheFileRepository
|
||||
settingService *SettingService
|
||||
accountRepo AccountRepository
|
||||
httpUpstream HTTPUpstream
|
||||
}
|
||||
|
||||
// NewSoraCacheService 创建 SoraCacheService。
|
||||
func NewSoraCacheService(cfg *config.Config, cacheRepo SoraCacheFileRepository, settingService *SettingService, accountRepo AccountRepository, httpUpstream HTTPUpstream) *SoraCacheService {
|
||||
return &SoraCacheService{
|
||||
cfg: cfg,
|
||||
cacheRepo: cacheRepo,
|
||||
settingService: settingService,
|
||||
accountRepo: accountRepo,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraCacheService) CacheVideo(ctx context.Context, accountID, userID int64, taskID, mediaURL string) (*SoraCacheFile, error) {
|
||||
cfg := s.getSoraConfig(ctx)
|
||||
if !cfg.Cache.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
trimmed := strings.TrimSpace(mediaURL)
|
||||
if trimmed == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
allowedHosts := cfg.Cache.AllowedHosts
|
||||
useAllowlist := true
|
||||
if len(allowedHosts) == 0 {
|
||||
if s.cfg != nil {
|
||||
allowedHosts = s.cfg.Security.URLAllowlist.UpstreamHosts
|
||||
useAllowlist = s.cfg.Security.URLAllowlist.Enabled
|
||||
} else {
|
||||
useAllowlist = false
|
||||
}
|
||||
}
|
||||
|
||||
if useAllowlist {
|
||||
if _, err := urlvalidator.ValidateHTTPSURL(trimmed, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: allowedHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg != nil && s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("缓存下载地址不合法: %w", err)
|
||||
}
|
||||
} else {
|
||||
allowInsecure := false
|
||||
if s.cfg != nil {
|
||||
allowInsecure = s.cfg.Security.URLAllowlist.AllowInsecureHTTP
|
||||
}
|
||||
if _, err := urlvalidator.ValidateURLFormat(trimmed, allowInsecure); err != nil {
|
||||
return nil, fmt.Errorf("缓存下载地址不合法: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
videoDir := strings.TrimSpace(cfg.Cache.VideoDir)
|
||||
if videoDir == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if cfg.Cache.MaxBytes > 0 {
|
||||
size, err := dirSize(videoDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if size >= cfg.Cache.MaxBytes {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
relativeDir := ""
|
||||
if cfg.Cache.UserDirEnabled && userID > 0 {
|
||||
relativeDir = fmt.Sprintf("u_%d", userID)
|
||||
}
|
||||
|
||||
targetDir := filepath.Join(videoDir, relativeDir)
|
||||
if err := os.MkdirAll(targetDir, 0o755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uuid, err := uuidv7.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
name := deriveFileName(trimmed)
|
||||
if name == "" {
|
||||
name = "video.mp4"
|
||||
}
|
||||
name = sanitizeFileName(name)
|
||||
filename := uuid + "_" + name
|
||||
cachePath := filepath.Join(targetDir, filename)
|
||||
|
||||
resp, err := s.downloadMedia(ctx, accountID, trimmed, time.Duration(cfg.Timeout)*time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("缓存下载失败: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
out, err := os.Create(cachePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
written, err := io.Copy(out, resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cacheURL := buildCacheURL(relativeDir, filename)
|
||||
|
||||
record := &SoraCacheFile{
|
||||
TaskID: taskID,
|
||||
AccountID: accountID,
|
||||
UserID: userID,
|
||||
MediaType: "video",
|
||||
OriginalURL: trimmed,
|
||||
CachePath: cachePath,
|
||||
CacheURL: cacheURL,
|
||||
SizeBytes: written,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if s.cacheRepo != nil {
|
||||
if err := s.cacheRepo.Create(ctx, record); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func buildCacheURL(relativeDir, filename string) string {
|
||||
base := "/data/video"
|
||||
if relativeDir != "" {
|
||||
return path.Join(base, relativeDir, filename)
|
||||
}
|
||||
return path.Join(base, filename)
|
||||
}
|
||||
|
||||
func (s *SoraCacheService) getSoraConfig(ctx context.Context) config.SoraConfig {
|
||||
if s.settingService != nil {
|
||||
return s.settingService.GetSoraConfig(ctx)
|
||||
}
|
||||
if s.cfg != nil {
|
||||
return s.cfg.Sora
|
||||
}
|
||||
return config.SoraConfig{}
|
||||
}
|
||||
|
||||
func (s *SoraCacheService) downloadMedia(ctx context.Context, accountID int64, mediaURL string, timeout time.Duration) (*http.Response, error) {
|
||||
if timeout <= 0 {
|
||||
timeout = 120 * time.Second
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", mediaURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
|
||||
|
||||
if s.httpUpstream == nil {
|
||||
client := &http.Client{Timeout: timeout}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
var accountConcurrency int
|
||||
proxyURL := ""
|
||||
if s.accountRepo != nil && accountID > 0 {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account != nil {
|
||||
accountConcurrency = account.Concurrency
|
||||
if account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
}
|
||||
}
|
||||
enableTLS := false
|
||||
if s.cfg != nil {
|
||||
enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled
|
||||
}
|
||||
return s.httpUpstream.DoWithTLS(req, proxyURL, accountID, accountConcurrency, enableTLS)
|
||||
}
|
||||
|
||||
func deriveFileName(rawURL string) string {
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
name := path.Base(parsed.Path)
|
||||
if name == "/" || name == "." {
|
||||
return ""
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func sanitizeFileName(name string) string {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return ""
|
||||
}
|
||||
sanitized := strings.Map(func(r rune) rune {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
return r
|
||||
case r >= 'A' && r <= 'Z':
|
||||
return r
|
||||
case r >= '0' && r <= '9':
|
||||
return r
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
return r
|
||||
case r == ' ': // 空格替换为下划线
|
||||
return '_'
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
}, name)
|
||||
return strings.TrimLeft(sanitized, ".")
|
||||
}
|
||||
28
backend/internal/service/sora_cache_utils.go
Normal file
28
backend/internal/service/sora_cache_utils.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func dirSize(root string) (int64, error) {
|
||||
var size int64
|
||||
err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
info, err := d.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
size += info.Size()
|
||||
return nil
|
||||
})
|
||||
if err != nil && os.IsNotExist(err) {
|
||||
return 0, nil
|
||||
}
|
||||
return size, err
|
||||
}
|
||||
853
backend/internal/service/sora_gateway_service.go
Normal file
853
backend/internal/service/sora_gateway_service.go
Normal file
@@ -0,0 +1,853 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/sora"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
|
||||
const (
|
||||
soraErrorDisableThreshold = 5
|
||||
maxImageDownloadSize = 20 * 1024 * 1024 // 20MB
|
||||
maxVideoDownloadSize = 200 * 1024 * 1024 // 200MB
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSoraAccountMissingToken = errors.New("sora account missing access token")
|
||||
ErrSoraAccountNotEligible = errors.New("sora account not eligible")
|
||||
)
|
||||
|
||||
// SoraGenerationRequest 表示 Sora 生成请求。
|
||||
type SoraGenerationRequest struct {
|
||||
Model string
|
||||
Prompt string
|
||||
Image string
|
||||
Video string
|
||||
RemixTargetID string
|
||||
Stream bool
|
||||
UserID int64
|
||||
}
|
||||
|
||||
// SoraGenerationResult 表示 Sora 生成结果。
|
||||
type SoraGenerationResult struct {
|
||||
Content string
|
||||
MediaType string
|
||||
ResultURLs []string
|
||||
TaskID string
|
||||
}
|
||||
|
||||
// SoraGatewayService 处理 Sora 生成流程。
|
||||
type SoraGatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
soraAccountRepo SoraAccountRepository
|
||||
usageRepo SoraUsageStatRepository
|
||||
taskRepo SoraTaskRepository
|
||||
cacheService *SoraCacheService
|
||||
settingService *SettingService
|
||||
concurrency *ConcurrencyService
|
||||
cfg *config.Config
|
||||
httpUpstream HTTPUpstream
|
||||
}
|
||||
|
||||
// NewSoraGatewayService 创建 SoraGatewayService。
|
||||
func NewSoraGatewayService(
|
||||
accountRepo AccountRepository,
|
||||
soraAccountRepo SoraAccountRepository,
|
||||
usageRepo SoraUsageStatRepository,
|
||||
taskRepo SoraTaskRepository,
|
||||
cacheService *SoraCacheService,
|
||||
settingService *SettingService,
|
||||
concurrencyService *ConcurrencyService,
|
||||
cfg *config.Config,
|
||||
httpUpstream HTTPUpstream,
|
||||
) *SoraGatewayService {
|
||||
return &SoraGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
soraAccountRepo: soraAccountRepo,
|
||||
usageRepo: usageRepo,
|
||||
taskRepo: taskRepo,
|
||||
cacheService: cacheService,
|
||||
settingService: settingService,
|
||||
concurrency: concurrencyService,
|
||||
cfg: cfg,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
// ListModels 返回 Sora 模型列表。
|
||||
func (s *SoraGatewayService) ListModels() []sora.ModelListItem {
|
||||
return sora.ListModels()
|
||||
}
|
||||
|
||||
// Generate 执行 Sora 生成流程。
|
||||
func (s *SoraGatewayService) Generate(ctx context.Context, account *Account, req SoraGenerationRequest) (*SoraGenerationResult, error) {
|
||||
client, cfg := s.getClient(ctx)
|
||||
if client == nil {
|
||||
return nil, errors.New("sora client is not configured")
|
||||
}
|
||||
modelCfg, ok := sora.ModelConfigs[req.Model]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported model: %s", req.Model)
|
||||
}
|
||||
accessToken, soraAcc, err := s.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if soraAcc != nil && soraAcc.SoraCooldownUntil != nil && time.Now().Before(*soraAcc.SoraCooldownUntil) {
|
||||
return nil, ErrSoraAccountNotEligible
|
||||
}
|
||||
if modelCfg.RequirePro && !isSoraProAccount(soraAcc) {
|
||||
return nil, ErrSoraAccountNotEligible
|
||||
}
|
||||
if modelCfg.Type == "video" && soraAcc != nil {
|
||||
if !soraAcc.VideoEnabled || !soraAcc.SoraSupported || soraAcc.IsExpired {
|
||||
return nil, ErrSoraAccountNotEligible
|
||||
}
|
||||
}
|
||||
if modelCfg.Type == "image" && soraAcc != nil {
|
||||
if !soraAcc.ImageEnabled || soraAcc.IsExpired {
|
||||
return nil, ErrSoraAccountNotEligible
|
||||
}
|
||||
}
|
||||
|
||||
opts := sora.RequestOptions{
|
||||
AccountID: account.ID,
|
||||
AccountConcurrency: account.Concurrency,
|
||||
AccessToken: accessToken,
|
||||
}
|
||||
if account.Proxy != nil {
|
||||
opts.ProxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
releaseFunc, err := s.acquireSoraSlots(ctx, account, soraAcc, modelCfg.Type == "video")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if releaseFunc != nil {
|
||||
defer releaseFunc()
|
||||
}
|
||||
|
||||
if modelCfg.Type == "prompt_enhance" {
|
||||
content, err := client.EnhancePrompt(ctx, opts, req.Prompt, modelCfg.ExpansionLevel, modelCfg.DurationS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SoraGenerationResult{Content: content, MediaType: "text"}, nil
|
||||
}
|
||||
|
||||
var mediaID string
|
||||
if req.Image != "" {
|
||||
data, err := s.loadImageBytes(ctx, opts, req.Image)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mediaID, err = client.UploadImage(ctx, opts, data, "image.png")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if req.Video != "" && modelCfg.Type != "video" {
|
||||
return nil, errors.New("视频输入仅支持视频模型")
|
||||
}
|
||||
if req.Video != "" && req.Image != "" {
|
||||
return nil, errors.New("不能同时传入 image 与 video")
|
||||
}
|
||||
|
||||
var cleanupCharacter func()
|
||||
if req.Video != "" && req.RemixTargetID == "" {
|
||||
username, characterID, err := s.createCharacter(ctx, client, opts, req.Video)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
return &SoraGenerationResult{
|
||||
Content: fmt.Sprintf("角色创建成功,角色名@%s", username),
|
||||
MediaType: "text",
|
||||
}, nil
|
||||
}
|
||||
if username != "" {
|
||||
req.Prompt = fmt.Sprintf("@%s %s", username, strings.TrimSpace(req.Prompt))
|
||||
}
|
||||
if characterID != "" {
|
||||
cleanupCharacter = func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
_ = client.DeleteCharacter(ctx, opts, characterID)
|
||||
}
|
||||
}
|
||||
}
|
||||
if cleanupCharacter != nil {
|
||||
defer cleanupCharacter()
|
||||
}
|
||||
|
||||
var taskID string
|
||||
if modelCfg.Type == "image" {
|
||||
taskID, err = client.GenerateImage(ctx, opts, req.Prompt, modelCfg.Width, modelCfg.Height, mediaID)
|
||||
} else {
|
||||
orientation := modelCfg.Orientation
|
||||
if orientation == "" {
|
||||
orientation = "landscape"
|
||||
}
|
||||
modelName := modelCfg.Model
|
||||
if modelName == "" {
|
||||
modelName = "sy_8"
|
||||
}
|
||||
size := modelCfg.Size
|
||||
if size == "" {
|
||||
size = "small"
|
||||
}
|
||||
if req.RemixTargetID != "" {
|
||||
taskID, err = client.RemixVideo(ctx, opts, req.RemixTargetID, req.Prompt, orientation, modelCfg.NFrames, "")
|
||||
} else if sora.IsStoryboardPrompt(req.Prompt) {
|
||||
formatted := sora.FormatStoryboardPrompt(req.Prompt)
|
||||
taskID, err = client.GenerateStoryboard(ctx, opts, formatted, orientation, modelCfg.NFrames, mediaID, "")
|
||||
} else {
|
||||
taskID, err = client.GenerateVideo(ctx, opts, req.Prompt, orientation, modelCfg.NFrames, mediaID, "", modelName, size)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.taskRepo != nil {
|
||||
_ = s.taskRepo.Create(ctx, &SoraTask{
|
||||
TaskID: taskID,
|
||||
AccountID: account.ID,
|
||||
Model: req.Model,
|
||||
Prompt: req.Prompt,
|
||||
Status: "processing",
|
||||
Progress: 0,
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
result, err := s.pollResult(ctx, client, cfg, opts, taskID, modelCfg.Type == "video", req)
|
||||
if err != nil {
|
||||
if s.taskRepo != nil {
|
||||
_ = s.taskRepo.UpdateStatus(ctx, taskID, "failed", 0, "", err.Error(), timePtr(time.Now()))
|
||||
}
|
||||
consecutive := 0
|
||||
if s.usageRepo != nil {
|
||||
consecutive, _ = s.usageRepo.RecordError(ctx, account.ID)
|
||||
}
|
||||
if consecutive >= soraErrorDisableThreshold {
|
||||
_ = s.accountRepo.SetError(ctx, account.ID, "Sora 连续错误次数过多,已自动禁用")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.taskRepo != nil {
|
||||
payload, _ := json.Marshal(result.ResultURLs)
|
||||
_ = s.taskRepo.UpdateStatus(ctx, taskID, "completed", 100, string(payload), "", timePtr(time.Now()))
|
||||
}
|
||||
if s.usageRepo != nil {
|
||||
_ = s.usageRepo.RecordSuccess(ctx, account.ID, modelCfg.Type == "video")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) pollResult(ctx context.Context, client *sora.Client, cfg config.SoraConfig, opts sora.RequestOptions, taskID string, isVideo bool, req SoraGenerationRequest) (*SoraGenerationResult, error) {
|
||||
if taskID == "" {
|
||||
return nil, errors.New("missing task id")
|
||||
}
|
||||
pollInterval := 2 * time.Second
|
||||
if cfg.PollInterval > 0 {
|
||||
pollInterval = time.Duration(cfg.PollInterval*1000) * time.Millisecond
|
||||
}
|
||||
timeout := 300 * time.Second
|
||||
if cfg.Timeout > 0 {
|
||||
timeout = time.Duration(cfg.Timeout) * time.Second
|
||||
}
|
||||
deadline := time.Now().Add(timeout)
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
if isVideo {
|
||||
pending, err := client.GetPendingTasks(ctx, opts)
|
||||
if err == nil {
|
||||
for _, task := range pending {
|
||||
if stringFromMap(task, "id") == taskID {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
drafts, err := client.GetVideoDrafts(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items, _ := drafts["items"].([]any)
|
||||
for _, item := range items {
|
||||
entry, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if stringFromMap(entry, "task_id") != taskID {
|
||||
continue
|
||||
}
|
||||
url := firstNonEmpty(stringFromMap(entry, "downloadable_url"), stringFromMap(entry, "url"))
|
||||
reason := stringFromMap(entry, "reason_str")
|
||||
if url == "" {
|
||||
if reason == "" {
|
||||
reason = "视频生成失败"
|
||||
}
|
||||
return nil, errors.New(reason)
|
||||
}
|
||||
finalURL, err := s.handleWatermark(ctx, client, cfg, opts, url, entry, req, opts.AccountID, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SoraGenerationResult{
|
||||
Content: buildVideoMarkdown(finalURL),
|
||||
MediaType: "video",
|
||||
ResultURLs: []string{finalURL},
|
||||
TaskID: taskID,
|
||||
}, nil
|
||||
}
|
||||
} else {
|
||||
resp, err := client.GetImageTasks(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tasks, _ := resp["task_responses"].([]any)
|
||||
for _, item := range tasks {
|
||||
entry, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if stringFromMap(entry, "id") != taskID {
|
||||
continue
|
||||
}
|
||||
status := stringFromMap(entry, "status")
|
||||
switch status {
|
||||
case "succeeded":
|
||||
urls := extractImageURLs(entry)
|
||||
if len(urls) == 0 {
|
||||
return nil, errors.New("image urls empty")
|
||||
}
|
||||
content := buildImageMarkdown(urls)
|
||||
return &SoraGenerationResult{
|
||||
Content: content,
|
||||
MediaType: "image",
|
||||
ResultURLs: urls,
|
||||
TaskID: taskID,
|
||||
}, nil
|
||||
case "failed":
|
||||
message := stringFromMap(entry, "error_message")
|
||||
if message == "" {
|
||||
message = "image generation failed"
|
||||
}
|
||||
return nil, errors.New(message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(pollInterval)
|
||||
}
|
||||
return nil, errors.New("generation timeout")
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) handleWatermark(ctx context.Context, client *sora.Client, cfg config.SoraConfig, opts sora.RequestOptions, url string, entry map[string]any, req SoraGenerationRequest, accountID int64, taskID string) (string, error) {
|
||||
if !cfg.WatermarkFree.Enabled {
|
||||
return s.cacheVideo(ctx, url, req, accountID, taskID), nil
|
||||
}
|
||||
generationID := stringFromMap(entry, "id")
|
||||
if generationID == "" {
|
||||
return s.cacheVideo(ctx, url, req, accountID, taskID), nil
|
||||
}
|
||||
postID, err := client.PostVideoForWatermarkFree(ctx, opts, generationID)
|
||||
if err != nil {
|
||||
if cfg.WatermarkFree.FallbackOnFailure {
|
||||
return s.cacheVideo(ctx, url, req, accountID, taskID), nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if postID == "" {
|
||||
if cfg.WatermarkFree.FallbackOnFailure {
|
||||
return s.cacheVideo(ctx, url, req, accountID, taskID), nil
|
||||
}
|
||||
return "", errors.New("watermark-free post id empty")
|
||||
}
|
||||
var parsedURL string
|
||||
if cfg.WatermarkFree.ParseMethod == "custom" {
|
||||
if cfg.WatermarkFree.CustomParseURL == "" || cfg.WatermarkFree.CustomParseToken == "" {
|
||||
return "", errors.New("custom parse 未配置")
|
||||
}
|
||||
parsedURL, err = s.fetchCustomWatermarkURL(ctx, cfg.WatermarkFree.CustomParseURL, cfg.WatermarkFree.CustomParseToken, postID)
|
||||
if err != nil {
|
||||
if cfg.WatermarkFree.FallbackOnFailure {
|
||||
return s.cacheVideo(ctx, url, req, accountID, taskID), nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
parsedURL = fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID)
|
||||
}
|
||||
cached := s.cacheVideo(ctx, parsedURL, req, accountID, taskID)
|
||||
_ = client.DeletePost(ctx, opts, postID)
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) cacheVideo(ctx context.Context, url string, req SoraGenerationRequest, accountID int64, taskID string) string {
|
||||
if s.cacheService == nil {
|
||||
return url
|
||||
}
|
||||
file, err := s.cacheService.CacheVideo(ctx, accountID, req.UserID, taskID, url)
|
||||
if err != nil || file == nil {
|
||||
return url
|
||||
}
|
||||
return file.CacheURL
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) getAccessToken(ctx context.Context, account *Account) (string, *SoraAccount, error) {
|
||||
if account == nil {
|
||||
return "", nil, errors.New("account is nil")
|
||||
}
|
||||
var soraAcc *SoraAccount
|
||||
if s.soraAccountRepo != nil {
|
||||
soraAcc, _ = s.soraAccountRepo.GetByAccountID(ctx, account.ID)
|
||||
}
|
||||
if soraAcc != nil && soraAcc.AccessToken != "" {
|
||||
return soraAcc.AccessToken, soraAcc, nil
|
||||
}
|
||||
if account.Credentials != nil {
|
||||
if v, ok := account.Credentials["access_token"].(string); ok && v != "" {
|
||||
return v, soraAcc, nil
|
||||
}
|
||||
if v, ok := account.Credentials["token"].(string); ok && v != "" {
|
||||
return v, soraAcc, nil
|
||||
}
|
||||
}
|
||||
return "", soraAcc, ErrSoraAccountMissingToken
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) getClient(ctx context.Context) (*sora.Client, config.SoraConfig) {
|
||||
cfg := s.getSoraConfig(ctx)
|
||||
if s.httpUpstream == nil {
|
||||
return nil, cfg
|
||||
}
|
||||
baseURL := strings.TrimSpace(cfg.BaseURL)
|
||||
if baseURL == "" {
|
||||
return nil, cfg
|
||||
}
|
||||
timeout := time.Duration(cfg.Timeout) * time.Second
|
||||
if cfg.Timeout <= 0 {
|
||||
timeout = 120 * time.Second
|
||||
}
|
||||
enableTLS := false
|
||||
if s.cfg != nil {
|
||||
enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled
|
||||
}
|
||||
return sora.NewClient(baseURL, timeout, s.httpUpstream, enableTLS), cfg
|
||||
}
|
||||
|
||||
func decodeBase64(raw string) ([]byte, error) {
|
||||
data := raw
|
||||
if idx := strings.Index(raw, "base64,"); idx != -1 {
|
||||
data = raw[idx+7:]
|
||||
}
|
||||
return base64.StdEncoding.DecodeString(data)
|
||||
}
|
||||
|
||||
func extractImageURLs(entry map[string]any) []string {
|
||||
generations, _ := entry["generations"].([]any)
|
||||
urls := make([]string, 0, len(generations))
|
||||
for _, gen := range generations {
|
||||
m, ok := gen.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if url, ok := m["url"].(string); ok && url != "" {
|
||||
urls = append(urls, url)
|
||||
}
|
||||
}
|
||||
return urls
|
||||
}
|
||||
|
||||
func buildImageMarkdown(urls []string) string {
|
||||
parts := make([]string, 0, len(urls))
|
||||
for _, u := range urls {
|
||||
parts = append(parts, fmt.Sprintf("", u))
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
|
||||
func buildVideoMarkdown(url string) string {
|
||||
return fmt.Sprintf("```html\n<video src='%s' controls></video>\n```", url)
|
||||
}
|
||||
|
||||
func stringFromMap(m map[string]any, key string) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := m[key].(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, v := range values {
|
||||
if strings.TrimSpace(v) != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func isSoraProAccount(acc *SoraAccount) bool {
|
||||
if acc == nil {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(acc.PlanType, "chatgpt_pro")
|
||||
}
|
||||
|
||||
func timePtr(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
|
||||
// fetchCustomWatermarkURL 使用自定义解析服务获取无水印视频 URL
|
||||
func (s *SoraGatewayService) fetchCustomWatermarkURL(ctx context.Context, parseURL, parseToken, postID string) (string, error) {
|
||||
// 使用项目的 URL 校验器验证 parseURL 格式,防止 SSRF 攻击
|
||||
if _, err := urlvalidator.ValidateHTTPSURL(parseURL, urlvalidator.ValidationOptions{}); err != nil {
|
||||
return "", fmt.Errorf("无效的解析服务地址: %w", err)
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"url": fmt.Sprintf("https://sora.chatgpt.com/p/%s", postID),
|
||||
"token": parseToken,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", strings.TrimRight(parseURL, "/")+"/get-sora-link", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 复用 httpUpstream,遵守代理和 TLS 配置
|
||||
enableTLS := false
|
||||
if s.cfg != nil {
|
||||
enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled
|
||||
}
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, "", 0, 1, enableTLS)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("custom parse failed: %d", resp.StatusCode)
|
||||
}
|
||||
var parsed map[string]any
|
||||
if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if errMsg, ok := parsed["error"].(string); ok && errMsg != "" {
|
||||
return "", errors.New(errMsg)
|
||||
}
|
||||
if link, ok := parsed["download_link"].(string); ok {
|
||||
return link, nil
|
||||
}
|
||||
return "", errors.New("custom parse response missing download_link")
|
||||
}
|
||||
|
||||
const (
|
||||
soraSlotImageLock int64 = 1
|
||||
soraSlotImageLimit int64 = 2
|
||||
soraSlotVideoLimit int64 = 3
|
||||
soraDefaultUsername = "character"
|
||||
)
|
||||
|
||||
func (s *SoraGatewayService) CallLogicMode(ctx context.Context) string {
|
||||
return strings.TrimSpace(s.getSoraConfig(ctx).CallLogicMode)
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) getSoraConfig(ctx context.Context) config.SoraConfig {
|
||||
if s.settingService != nil {
|
||||
return s.settingService.GetSoraConfig(ctx)
|
||||
}
|
||||
if s.cfg != nil {
|
||||
return s.cfg.Sora
|
||||
}
|
||||
return config.SoraConfig{}
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) acquireSoraSlots(ctx context.Context, account *Account, soraAcc *SoraAccount, isVideo bool) (func(), error) {
|
||||
if s.concurrency == nil || account == nil || soraAcc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
releases := make([]func(), 0, 2)
|
||||
appendRelease := func(release func()) {
|
||||
if release != nil {
|
||||
releases = append(releases, release)
|
||||
}
|
||||
}
|
||||
// 错误时释放所有已获取的槽位
|
||||
releaseAll := func() {
|
||||
for _, r := range releases {
|
||||
r()
|
||||
}
|
||||
}
|
||||
|
||||
if isVideo {
|
||||
if soraAcc.VideoConcurrency > 0 {
|
||||
release, err := s.acquireSoraSlot(ctx, account.ID, soraAcc.VideoConcurrency, soraSlotVideoLimit)
|
||||
if err != nil {
|
||||
releaseAll()
|
||||
return nil, err
|
||||
}
|
||||
appendRelease(release)
|
||||
}
|
||||
} else {
|
||||
release, err := s.acquireSoraSlot(ctx, account.ID, 1, soraSlotImageLock)
|
||||
if err != nil {
|
||||
releaseAll()
|
||||
return nil, err
|
||||
}
|
||||
appendRelease(release)
|
||||
if soraAcc.ImageConcurrency > 0 {
|
||||
release, err := s.acquireSoraSlot(ctx, account.ID, soraAcc.ImageConcurrency, soraSlotImageLimit)
|
||||
if err != nil {
|
||||
releaseAll() // 释放已获取的 soraSlotImageLock
|
||||
return nil, err
|
||||
}
|
||||
appendRelease(release)
|
||||
}
|
||||
}
|
||||
|
||||
if len(releases) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return func() {
|
||||
for _, release := range releases {
|
||||
release()
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) acquireSoraSlot(ctx context.Context, accountID int64, maxConcurrency int, slotType int64) (func(), error) {
|
||||
if s.concurrency == nil || maxConcurrency <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
derivedID := soraConcurrencyAccountID(accountID, slotType)
|
||||
result, err := s.concurrency.AcquireAccountSlot(ctx, derivedID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !result.Acquired {
|
||||
return nil, ErrSoraAccountNotEligible
|
||||
}
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
func soraConcurrencyAccountID(accountID int64, slotType int64) int64 {
|
||||
if accountID < 0 {
|
||||
accountID = -accountID
|
||||
}
|
||||
return -(accountID*10 + slotType)
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) createCharacter(ctx context.Context, client *sora.Client, opts sora.RequestOptions, rawVideo string) (string, string, error) {
|
||||
videoBytes, err := s.loadVideoBytes(ctx, opts, rawVideo)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
cameoID, err := client.UploadCharacterVideo(ctx, opts, videoBytes)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
status, err := s.pollCameoStatus(ctx, client, opts, cameoID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
username := processCharacterUsername(stringFromMap(status, "username_hint"))
|
||||
if username == "" {
|
||||
username = soraDefaultUsername
|
||||
}
|
||||
displayName := stringFromMap(status, "display_name_hint")
|
||||
if displayName == "" {
|
||||
displayName = "Character"
|
||||
}
|
||||
profileURL := stringFromMap(status, "profile_asset_url")
|
||||
if profileURL == "" {
|
||||
return "", "", errors.New("profile asset url missing")
|
||||
}
|
||||
avatarData, err := client.DownloadCharacterImage(ctx, opts, profileURL)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
assetPointer, err := client.UploadCharacterImage(ctx, opts, avatarData)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
characterID, err := client.FinalizeCharacter(ctx, opts, cameoID, username, displayName, assetPointer)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
if err := client.SetCharacterPublic(ctx, opts, cameoID); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return username, characterID, nil
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, client *sora.Client, opts sora.RequestOptions, cameoID string) (map[string]any, error) {
|
||||
if cameoID == "" {
|
||||
return nil, errors.New("cameo id empty")
|
||||
}
|
||||
timeout := 600 * time.Second
|
||||
pollInterval := 5 * time.Second
|
||||
deadline := time.Now().Add(timeout)
|
||||
consecutiveErrors := 0
|
||||
maxConsecutiveErrors := 3
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
time.Sleep(pollInterval)
|
||||
status, err := client.GetCameoStatus(ctx, opts, cameoID)
|
||||
if err != nil {
|
||||
consecutiveErrors++
|
||||
if consecutiveErrors >= maxConsecutiveErrors {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
consecutiveErrors = 0
|
||||
statusValue := stringFromMap(status, "status")
|
||||
statusMessage := stringFromMap(status, "status_message")
|
||||
if statusValue == "failed" {
|
||||
if statusMessage == "" {
|
||||
statusMessage = "角色创建失败"
|
||||
}
|
||||
return nil, fmt.Errorf("角色创建失败: %s", statusMessage)
|
||||
}
|
||||
if strings.EqualFold(statusMessage, "Completed") || strings.EqualFold(statusValue, "finalized") {
|
||||
return status, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("角色创建超时")
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) loadVideoBytes(ctx context.Context, opts sora.RequestOptions, rawVideo string) ([]byte, error) {
|
||||
trimmed := strings.TrimSpace(rawVideo)
|
||||
if trimmed == "" {
|
||||
return nil, errors.New("video data is empty")
|
||||
}
|
||||
if looksLikeURL(trimmed) {
|
||||
if err := s.validateMediaURL(trimmed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.downloadMedia(ctx, opts, trimmed, maxVideoDownloadSize)
|
||||
}
|
||||
return decodeBase64(trimmed)
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) loadImageBytes(ctx context.Context, opts sora.RequestOptions, rawImage string) ([]byte, error) {
|
||||
trimmed := strings.TrimSpace(rawImage)
|
||||
if trimmed == "" {
|
||||
return nil, errors.New("image data is empty")
|
||||
}
|
||||
if looksLikeURL(trimmed) {
|
||||
if err := s.validateMediaURL(trimmed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.downloadMedia(ctx, opts, trimmed, maxImageDownloadSize)
|
||||
}
|
||||
return decodeBase64(trimmed)
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) validateMediaURL(rawURL string) error {
|
||||
cfg := s.cfg
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
if cfg.Security.URLAllowlist.Enabled {
|
||||
_, err := urlvalidator.ValidateHTTPSURL(rawURL, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("媒体地址不合法: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if _, err := urlvalidator.ValidateURLFormat(rawURL, cfg.Security.URLAllowlist.AllowInsecureHTTP); err != nil {
|
||||
return fmt.Errorf("媒体地址不合法: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) downloadMedia(ctx context.Context, opts sora.RequestOptions, mediaURL string, maxSize int64) ([]byte, error) {
|
||||
if s.httpUpstream == nil {
|
||||
return nil, errors.New("upstream is nil")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", mediaURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
|
||||
enableTLS := false
|
||||
if s.cfg != nil {
|
||||
enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled
|
||||
}
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, opts.AccountConcurrency, enableTLS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("下载失败: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 使用 LimitReader 限制最大读取大小,防止 DoS 攻击
|
||||
limitedReader := io.LimitReader(resp.Body, maxSize+1)
|
||||
data, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否超过大小限制
|
||||
if int64(len(data)) > maxSize {
|
||||
return nil, fmt.Errorf("媒体文件过大 (最大 %d 字节, 实际 %d 字节)", maxSize, len(data))
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func processCharacterUsername(usernameHint string) string {
|
||||
trimmed := strings.TrimSpace(usernameHint)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
base := trimmed
|
||||
if idx := strings.LastIndex(trimmed, "."); idx != -1 && idx+1 < len(trimmed) {
|
||||
base = trimmed[idx+1:]
|
||||
}
|
||||
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
return fmt.Sprintf("%s%d", base, rng.Intn(900)+100)
|
||||
}
|
||||
|
||||
func looksLikeURL(value string) bool {
|
||||
trimmed := strings.ToLower(strings.TrimSpace(value))
|
||||
return strings.HasPrefix(trimmed, "http://") || strings.HasPrefix(trimmed, "https://")
|
||||
}
|
||||
113
backend/internal/service/sora_repository.go
Normal file
113
backend/internal/service/sora_repository.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
// SoraAccount 表示 Sora 账号扩展信息。
|
||||
type SoraAccount struct {
|
||||
AccountID int64
|
||||
AccessToken string
|
||||
SessionToken string
|
||||
RefreshToken string
|
||||
ClientID string
|
||||
Email string
|
||||
Username string
|
||||
Remark string
|
||||
UseCount int
|
||||
PlanType string
|
||||
PlanTitle string
|
||||
SubscriptionEnd *time.Time
|
||||
SoraSupported bool
|
||||
SoraInviteCode string
|
||||
SoraRedeemedCount int
|
||||
SoraRemainingCount int
|
||||
SoraTotalCount int
|
||||
SoraCooldownUntil *time.Time
|
||||
CooledUntil *time.Time
|
||||
ImageEnabled bool
|
||||
VideoEnabled bool
|
||||
ImageConcurrency int
|
||||
VideoConcurrency int
|
||||
IsExpired bool
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// SoraUsageStat 表示 Sora 调用统计。
|
||||
type SoraUsageStat struct {
|
||||
AccountID int64
|
||||
ImageCount int
|
||||
VideoCount int
|
||||
ErrorCount int
|
||||
LastErrorAt *time.Time
|
||||
TodayImageCount int
|
||||
TodayVideoCount int
|
||||
TodayErrorCount int
|
||||
TodayDate *time.Time
|
||||
ConsecutiveErrorCount int
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// SoraTask 表示 Sora 任务记录。
|
||||
type SoraTask struct {
|
||||
TaskID string
|
||||
AccountID int64
|
||||
Model string
|
||||
Prompt string
|
||||
Status string
|
||||
Progress float64
|
||||
ResultURLs string
|
||||
ErrorMessage string
|
||||
RetryCount int
|
||||
CreatedAt time.Time
|
||||
CompletedAt *time.Time
|
||||
}
|
||||
|
||||
// SoraCacheFile 表示 Sora 缓存文件记录。
|
||||
type SoraCacheFile struct {
|
||||
ID int64
|
||||
TaskID string
|
||||
AccountID int64
|
||||
UserID int64
|
||||
MediaType string
|
||||
OriginalURL string
|
||||
CachePath string
|
||||
CacheURL string
|
||||
SizeBytes int64
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// SoraAccountRepository 定义 Sora 账号仓储接口。
|
||||
type SoraAccountRepository interface {
|
||||
GetByAccountID(ctx context.Context, accountID int64) (*SoraAccount, error)
|
||||
GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*SoraAccount, error)
|
||||
Upsert(ctx context.Context, accountID int64, updates map[string]any) error
|
||||
}
|
||||
|
||||
// SoraUsageStatRepository 定义 Sora 调用统计仓储接口。
|
||||
type SoraUsageStatRepository interface {
|
||||
RecordSuccess(ctx context.Context, accountID int64, isVideo bool) error
|
||||
RecordError(ctx context.Context, accountID int64) (int, error)
|
||||
ResetConsecutiveErrors(ctx context.Context, accountID int64) error
|
||||
GetByAccountID(ctx context.Context, accountID int64) (*SoraUsageStat, error)
|
||||
GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*SoraUsageStat, error)
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]*SoraUsageStat, *pagination.PaginationResult, error)
|
||||
}
|
||||
|
||||
// SoraTaskRepository 定义 Sora 任务仓储接口。
|
||||
type SoraTaskRepository interface {
|
||||
Create(ctx context.Context, task *SoraTask) error
|
||||
UpdateStatus(ctx context.Context, taskID string, status string, progress float64, resultURLs string, errorMessage string, completedAt *time.Time) error
|
||||
}
|
||||
|
||||
// SoraCacheFileRepository 定义 Sora 缓存文件仓储接口。
|
||||
type SoraCacheFileRepository interface {
|
||||
Create(ctx context.Context, file *SoraCacheFile) error
|
||||
ListOldest(ctx context.Context, limit int) ([]*SoraCacheFile, error)
|
||||
DeleteByIDs(ctx context.Context, ids []int64) error
|
||||
}
|
||||
313
backend/internal/service/sora_token_refresh_service.go
Normal file
313
backend/internal/service/sora_token_refresh_service.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
const defaultSoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
|
||||
|
||||
// SoraTokenRefreshService handles Sora access token refresh.
|
||||
type SoraTokenRefreshService struct {
|
||||
accountRepo AccountRepository
|
||||
soraAccountRepo SoraAccountRepository
|
||||
settingService *SettingService
|
||||
httpUpstream HTTPUpstream
|
||||
cfg *config.Config
|
||||
stopCh chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
func NewSoraTokenRefreshService(
|
||||
accountRepo AccountRepository,
|
||||
soraAccountRepo SoraAccountRepository,
|
||||
settingService *SettingService,
|
||||
httpUpstream HTTPUpstream,
|
||||
cfg *config.Config,
|
||||
) *SoraTokenRefreshService {
|
||||
return &SoraTokenRefreshService{
|
||||
accountRepo: accountRepo,
|
||||
soraAccountRepo: soraAccountRepo,
|
||||
settingService: settingService,
|
||||
httpUpstream: httpUpstream,
|
||||
cfg: cfg,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraTokenRefreshService) Start() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
go s.refreshLoop()
|
||||
}
|
||||
|
||||
func (s *SoraTokenRefreshService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SoraTokenRefreshService) refreshLoop() {
|
||||
for {
|
||||
wait := s.nextRunDelay()
|
||||
timer := time.NewTimer(wait)
|
||||
select {
|
||||
case <-timer.C:
|
||||
s.refreshOnce()
|
||||
case <-s.stopCh:
|
||||
timer.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraTokenRefreshService) refreshOnce() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
if !s.isEnabled(ctx) {
|
||||
log.Println("[SoraTokenRefresh] disabled by settings")
|
||||
return
|
||||
}
|
||||
if s.accountRepo == nil || s.soraAccountRepo == nil {
|
||||
log.Println("[SoraTokenRefresh] repository not configured")
|
||||
return
|
||||
}
|
||||
|
||||
accounts, err := s.accountRepo.ListByPlatform(ctx, PlatformSora)
|
||||
if err != nil {
|
||||
log.Printf("[SoraTokenRefresh] list accounts failed: %v", err)
|
||||
return
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
log.Println("[SoraTokenRefresh] no sora accounts")
|
||||
return
|
||||
}
|
||||
ids := make([]int64, 0, len(accounts))
|
||||
accountMap := make(map[int64]*Account, len(accounts))
|
||||
for i := range accounts {
|
||||
acc := accounts[i]
|
||||
ids = append(ids, acc.ID)
|
||||
accountMap[acc.ID] = &acc
|
||||
}
|
||||
accountExtras, err := s.soraAccountRepo.GetByAccountIDs(ctx, ids)
|
||||
if err != nil {
|
||||
log.Printf("[SoraTokenRefresh] load sora accounts failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
success := 0
|
||||
failed := 0
|
||||
skipped := 0
|
||||
for accountID, account := range accountMap {
|
||||
extra := accountExtras[accountID]
|
||||
if extra == nil {
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
result, err := s.refreshForAccount(ctx, account, extra)
|
||||
if err != nil {
|
||||
failed++
|
||||
log.Printf("[SoraTokenRefresh] account %d refresh failed: %v", accountID, err)
|
||||
continue
|
||||
}
|
||||
if result == nil {
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
|
||||
updates := map[string]any{
|
||||
"access_token": result.AccessToken,
|
||||
}
|
||||
if result.RefreshToken != "" {
|
||||
updates["refresh_token"] = result.RefreshToken
|
||||
}
|
||||
if result.Email != "" {
|
||||
updates["email"] = result.Email
|
||||
}
|
||||
if err := s.soraAccountRepo.Upsert(ctx, accountID, updates); err != nil {
|
||||
failed++
|
||||
log.Printf("[SoraTokenRefresh] account %d update failed: %v", accountID, err)
|
||||
continue
|
||||
}
|
||||
success++
|
||||
}
|
||||
log.Printf("[SoraTokenRefresh] done: success=%d failed=%d skipped=%d", success, failed, skipped)
|
||||
}
|
||||
|
||||
func (s *SoraTokenRefreshService) refreshForAccount(ctx context.Context, account *Account, extra *SoraAccount) (*soraRefreshResult, error) {
|
||||
if extra == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if strings.TrimSpace(extra.SessionToken) == "" && strings.TrimSpace(extra.RefreshToken) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if extra.SessionToken != "" {
|
||||
result, err := s.refreshWithSessionToken(ctx, account, extra.SessionToken)
|
||||
if err == nil && result != nil && result.AccessToken != "" {
|
||||
return result, nil
|
||||
}
|
||||
if strings.TrimSpace(extra.RefreshToken) == "" {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
clientID := strings.TrimSpace(extra.ClientID)
|
||||
if clientID == "" {
|
||||
clientID = defaultSoraClientID
|
||||
}
|
||||
return s.refreshWithRefreshToken(ctx, account, extra.RefreshToken, clientID)
|
||||
}
|
||||
|
||||
type soraRefreshResult struct {
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
Email string
|
||||
}
|
||||
|
||||
type soraSessionResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
User struct {
|
||||
Email string `json:"email"`
|
||||
} `json:"user"`
|
||||
}
|
||||
|
||||
type soraRefreshResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
func (s *SoraTokenRefreshService) refreshWithSessionToken(ctx context.Context, account *Account, sessionToken string) (*soraRefreshResult, error) {
|
||||
if s.httpUpstream == nil {
|
||||
return nil, fmt.Errorf("upstream not configured")
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://sora.chatgpt.com/api/auth/session", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Origin", "https://sora.chatgpt.com")
|
||||
req.Header.Set("Referer", "https://sora.chatgpt.com/")
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
|
||||
|
||||
enableTLS := false
|
||||
if s.cfg != nil {
|
||||
enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled
|
||||
}
|
||||
proxyURL := ""
|
||||
accountConcurrency := 0
|
||||
accountID := int64(0)
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
accountConcurrency = account.Concurrency
|
||||
if account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
}
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, accountID, accountConcurrency, enableTLS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("session refresh failed: %d", resp.StatusCode)
|
||||
}
|
||||
var payload soraSessionResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if payload.AccessToken == "" {
|
||||
return nil, errors.New("session refresh missing access token")
|
||||
}
|
||||
return &soraRefreshResult{AccessToken: payload.AccessToken, Email: payload.User.Email}, nil
|
||||
}
|
||||
|
||||
func (s *SoraTokenRefreshService) refreshWithRefreshToken(ctx context.Context, account *Account, refreshToken, clientID string) (*soraRefreshResult, error) {
|
||||
if s.httpUpstream == nil {
|
||||
return nil, fmt.Errorf("upstream not configured")
|
||||
}
|
||||
payload := map[string]any{
|
||||
"client_id": clientID,
|
||||
"grant_type": "refresh_token",
|
||||
"redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback",
|
||||
"refresh_token": refreshToken,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", "https://auth.openai.com/oauth/token", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
|
||||
|
||||
enableTLS := false
|
||||
if s.cfg != nil {
|
||||
enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled
|
||||
}
|
||||
proxyURL := ""
|
||||
accountConcurrency := 0
|
||||
accountID := int64(0)
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
accountConcurrency = account.Concurrency
|
||||
if account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
}
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, accountID, accountConcurrency, enableTLS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("refresh token failed: %d", resp.StatusCode)
|
||||
}
|
||||
var payloadResp soraRefreshResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payloadResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if payloadResp.AccessToken == "" {
|
||||
return nil, errors.New("refresh token missing access token")
|
||||
}
|
||||
return &soraRefreshResult{AccessToken: payloadResp.AccessToken, RefreshToken: payloadResp.RefreshToken}, nil
|
||||
}
|
||||
|
||||
func (s *SoraTokenRefreshService) nextRunDelay() time.Duration {
|
||||
location := time.Local
|
||||
if s.cfg != nil && strings.TrimSpace(s.cfg.Timezone) != "" {
|
||||
if tz, err := time.LoadLocation(strings.TrimSpace(s.cfg.Timezone)); err == nil {
|
||||
location = tz
|
||||
}
|
||||
}
|
||||
now := time.Now().In(location)
|
||||
next := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, location).Add(24 * time.Hour)
|
||||
return time.Until(next)
|
||||
}
|
||||
|
||||
func (s *SoraTokenRefreshService) isEnabled(ctx context.Context) bool {
|
||||
if s.settingService == nil {
|
||||
return s.cfg != nil && s.cfg.Sora.TokenRefresh.Enabled
|
||||
}
|
||||
cfg := s.settingService.GetSoraConfig(ctx)
|
||||
return cfg.TokenRefresh.Enabled
|
||||
}
|
||||
@@ -51,6 +51,30 @@ func ProvideTokenRefreshService(
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideSoraTokenRefreshService creates and starts SoraTokenRefreshService.
|
||||
func ProvideSoraTokenRefreshService(
|
||||
accountRepo AccountRepository,
|
||||
soraAccountRepo SoraAccountRepository,
|
||||
settingService *SettingService,
|
||||
httpUpstream HTTPUpstream,
|
||||
cfg *config.Config,
|
||||
) *SoraTokenRefreshService {
|
||||
svc := NewSoraTokenRefreshService(accountRepo, soraAccountRepo, settingService, httpUpstream, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideSoraCacheCleanupService creates and starts SoraCacheCleanupService.
|
||||
func ProvideSoraCacheCleanupService(
|
||||
cacheRepo SoraCacheFileRepository,
|
||||
settingService *SettingService,
|
||||
cfg *config.Config,
|
||||
) *SoraCacheCleanupService {
|
||||
svc := NewSoraCacheCleanupService(cacheRepo, settingService, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideDashboardAggregationService 创建并启动仪表盘聚合服务
|
||||
func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
|
||||
svc := NewDashboardAggregationService(repo, timingWheel, cfg)
|
||||
@@ -222,6 +246,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewAdminService,
|
||||
NewGatewayService,
|
||||
NewOpenAIGatewayService,
|
||||
NewSoraCacheService,
|
||||
NewSoraGatewayService,
|
||||
NewOAuthService,
|
||||
NewOpenAIOAuthService,
|
||||
NewGeminiOAuthService,
|
||||
@@ -255,6 +281,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewCRSSyncService,
|
||||
ProvideUpdateService,
|
||||
ProvideTokenRefreshService,
|
||||
ProvideSoraTokenRefreshService,
|
||||
ProvideSoraCacheCleanupService,
|
||||
ProvideAccountExpiryService,
|
||||
ProvideTimingWheelService,
|
||||
ProvideDashboardAggregationService,
|
||||
|
||||
Reference in New Issue
Block a user