feat(sora): 新增 Sora 平台支持并修复高危安全和性能问题
新增功能: - 新增 Sora 账号管理和 OAuth 认证 - 新增 Sora 视频/图片生成 API 网关 - 新增 Sora 任务调度和缓存机制 - 新增 Sora 使用统计和计费支持 - 前端增加 Sora 平台配置界面 安全修复(代码审核): - [SEC-001] 限制媒体下载响应体大小(图片 20MB、视频 200MB),防止 DoS 攻击 - [SEC-002] 限制 SDK API 响应大小(1MB),防止内存耗尽 - [SEC-003] 修复 SSRF 风险,添加 URL 验证并强制使用代理配置 BUG 修复(代码审核): - [BUG-001] 修复 for 循环内 defer 累积导致的资源泄漏 - [BUG-002] 修复图片并发槽位获取失败时已持有锁未释放的永久泄漏 性能优化(代码审核): - [PERF-001] 添加 Sentinel Token 缓存(3 分钟有效期),减少 PoW 计算开销 技术细节: - 使用 io.LimitReader 限制所有外部输入的大小 - 添加 urlvalidator 验证防止 SSRF 攻击 - 使用 sync.Map 实现线程安全的包级缓存 - 优化并发槽位管理,添加 releaseAll 模式防止泄漏 影响范围: - 后端:新增 Sora 相关数据模型、服务、网关和管理接口 - 前端:新增 Sora 平台配置、账号管理和监控界面 - 配置:新增 Sora 相关配置项和环境变量 Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
@@ -49,7 +49,7 @@ type CreateGroupRequest struct {
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
|
||||
@@ -79,6 +79,23 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
FallbackModelAntigravity: settings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: settings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: settings.IdentityPatchPrompt,
|
||||
SoraBaseURL: settings.SoraBaseURL,
|
||||
SoraTimeout: settings.SoraTimeout,
|
||||
SoraMaxRetries: settings.SoraMaxRetries,
|
||||
SoraPollInterval: settings.SoraPollInterval,
|
||||
SoraCallLogicMode: settings.SoraCallLogicMode,
|
||||
SoraCacheEnabled: settings.SoraCacheEnabled,
|
||||
SoraCacheBaseDir: settings.SoraCacheBaseDir,
|
||||
SoraCacheVideoDir: settings.SoraCacheVideoDir,
|
||||
SoraCacheMaxBytes: settings.SoraCacheMaxBytes,
|
||||
SoraCacheAllowedHosts: settings.SoraCacheAllowedHosts,
|
||||
SoraCacheUserDirEnabled: settings.SoraCacheUserDirEnabled,
|
||||
SoraWatermarkFreeEnabled: settings.SoraWatermarkFreeEnabled,
|
||||
SoraWatermarkFreeParseMethod: settings.SoraWatermarkFreeParseMethod,
|
||||
SoraWatermarkFreeCustomParseURL: settings.SoraWatermarkFreeCustomParseURL,
|
||||
SoraWatermarkFreeCustomParseToken: settings.SoraWatermarkFreeCustomParseToken,
|
||||
SoraWatermarkFreeFallbackOnFailure: settings.SoraWatermarkFreeFallbackOnFailure,
|
||||
SoraTokenRefreshEnabled: settings.SoraTokenRefreshEnabled,
|
||||
OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled,
|
||||
OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
|
||||
OpsQueryModeDefault: settings.OpsQueryModeDefault,
|
||||
@@ -138,6 +155,25 @@ type UpdateSettingsRequest struct {
|
||||
EnableIdentityPatch bool `json:"enable_identity_patch"`
|
||||
IdentityPatchPrompt string `json:"identity_patch_prompt"`
|
||||
|
||||
// Sora configuration
|
||||
SoraBaseURL string `json:"sora_base_url"`
|
||||
SoraTimeout int `json:"sora_timeout"`
|
||||
SoraMaxRetries int `json:"sora_max_retries"`
|
||||
SoraPollInterval float64 `json:"sora_poll_interval"`
|
||||
SoraCallLogicMode string `json:"sora_call_logic_mode"`
|
||||
SoraCacheEnabled bool `json:"sora_cache_enabled"`
|
||||
SoraCacheBaseDir string `json:"sora_cache_base_dir"`
|
||||
SoraCacheVideoDir string `json:"sora_cache_video_dir"`
|
||||
SoraCacheMaxBytes int64 `json:"sora_cache_max_bytes"`
|
||||
SoraCacheAllowedHosts []string `json:"sora_cache_allowed_hosts"`
|
||||
SoraCacheUserDirEnabled bool `json:"sora_cache_user_dir_enabled"`
|
||||
SoraWatermarkFreeEnabled bool `json:"sora_watermark_free_enabled"`
|
||||
SoraWatermarkFreeParseMethod string `json:"sora_watermark_free_parse_method"`
|
||||
SoraWatermarkFreeCustomParseURL string `json:"sora_watermark_free_custom_parse_url"`
|
||||
SoraWatermarkFreeCustomParseToken string `json:"sora_watermark_free_custom_parse_token"`
|
||||
SoraWatermarkFreeFallbackOnFailure bool `json:"sora_watermark_free_fallback_on_failure"`
|
||||
SoraTokenRefreshEnabled bool `json:"sora_token_refresh_enabled"`
|
||||
|
||||
// Ops monitoring (vNext)
|
||||
OpsMonitoringEnabled *bool `json:"ops_monitoring_enabled"`
|
||||
OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"`
|
||||
@@ -227,6 +263,32 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Sora 参数校验与清理
|
||||
req.SoraBaseURL = strings.TrimSpace(req.SoraBaseURL)
|
||||
if req.SoraBaseURL == "" {
|
||||
req.SoraBaseURL = previousSettings.SoraBaseURL
|
||||
}
|
||||
if req.SoraBaseURL != "" {
|
||||
if err := config.ValidateAbsoluteHTTPURL(req.SoraBaseURL); err != nil {
|
||||
response.BadRequest(c, "Sora Base URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
}
|
||||
if req.SoraTimeout <= 0 {
|
||||
req.SoraTimeout = previousSettings.SoraTimeout
|
||||
}
|
||||
if req.SoraMaxRetries < 0 {
|
||||
req.SoraMaxRetries = previousSettings.SoraMaxRetries
|
||||
}
|
||||
if req.SoraPollInterval <= 0 {
|
||||
req.SoraPollInterval = previousSettings.SoraPollInterval
|
||||
}
|
||||
if req.SoraCacheMaxBytes < 0 {
|
||||
req.SoraCacheMaxBytes = 0
|
||||
}
|
||||
req.SoraCacheAllowedHosts = normalizeStringList(req.SoraCacheAllowedHosts)
|
||||
req.SoraWatermarkFreeCustomParseURL = strings.TrimSpace(req.SoraWatermarkFreeCustomParseURL)
|
||||
|
||||
// Ops metrics collector interval validation (seconds).
|
||||
if req.OpsMetricsIntervalSeconds != nil {
|
||||
v := *req.OpsMetricsIntervalSeconds
|
||||
@@ -240,40 +302,57 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
SoraBaseURL: req.SoraBaseURL,
|
||||
SoraTimeout: req.SoraTimeout,
|
||||
SoraMaxRetries: req.SoraMaxRetries,
|
||||
SoraPollInterval: req.SoraPollInterval,
|
||||
SoraCallLogicMode: req.SoraCallLogicMode,
|
||||
SoraCacheEnabled: req.SoraCacheEnabled,
|
||||
SoraCacheBaseDir: req.SoraCacheBaseDir,
|
||||
SoraCacheVideoDir: req.SoraCacheVideoDir,
|
||||
SoraCacheMaxBytes: req.SoraCacheMaxBytes,
|
||||
SoraCacheAllowedHosts: req.SoraCacheAllowedHosts,
|
||||
SoraCacheUserDirEnabled: req.SoraCacheUserDirEnabled,
|
||||
SoraWatermarkFreeEnabled: req.SoraWatermarkFreeEnabled,
|
||||
SoraWatermarkFreeParseMethod: req.SoraWatermarkFreeParseMethod,
|
||||
SoraWatermarkFreeCustomParseURL: req.SoraWatermarkFreeCustomParseURL,
|
||||
SoraWatermarkFreeCustomParseToken: req.SoraWatermarkFreeCustomParseToken,
|
||||
SoraWatermarkFreeFallbackOnFailure: req.SoraWatermarkFreeFallbackOnFailure,
|
||||
SoraTokenRefreshEnabled: req.SoraTokenRefreshEnabled,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@@ -349,6 +428,23 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
|
||||
SoraBaseURL: updatedSettings.SoraBaseURL,
|
||||
SoraTimeout: updatedSettings.SoraTimeout,
|
||||
SoraMaxRetries: updatedSettings.SoraMaxRetries,
|
||||
SoraPollInterval: updatedSettings.SoraPollInterval,
|
||||
SoraCallLogicMode: updatedSettings.SoraCallLogicMode,
|
||||
SoraCacheEnabled: updatedSettings.SoraCacheEnabled,
|
||||
SoraCacheBaseDir: updatedSettings.SoraCacheBaseDir,
|
||||
SoraCacheVideoDir: updatedSettings.SoraCacheVideoDir,
|
||||
SoraCacheMaxBytes: updatedSettings.SoraCacheMaxBytes,
|
||||
SoraCacheAllowedHosts: updatedSettings.SoraCacheAllowedHosts,
|
||||
SoraCacheUserDirEnabled: updatedSettings.SoraCacheUserDirEnabled,
|
||||
SoraWatermarkFreeEnabled: updatedSettings.SoraWatermarkFreeEnabled,
|
||||
SoraWatermarkFreeParseMethod: updatedSettings.SoraWatermarkFreeParseMethod,
|
||||
SoraWatermarkFreeCustomParseURL: updatedSettings.SoraWatermarkFreeCustomParseURL,
|
||||
SoraWatermarkFreeCustomParseToken: updatedSettings.SoraWatermarkFreeCustomParseToken,
|
||||
SoraWatermarkFreeFallbackOnFailure: updatedSettings.SoraWatermarkFreeFallbackOnFailure,
|
||||
SoraTokenRefreshEnabled: updatedSettings.SoraTokenRefreshEnabled,
|
||||
OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled,
|
||||
OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
|
||||
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
|
||||
@@ -477,6 +573,57 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.IdentityPatchPrompt != after.IdentityPatchPrompt {
|
||||
changed = append(changed, "identity_patch_prompt")
|
||||
}
|
||||
if before.SoraBaseURL != after.SoraBaseURL {
|
||||
changed = append(changed, "sora_base_url")
|
||||
}
|
||||
if before.SoraTimeout != after.SoraTimeout {
|
||||
changed = append(changed, "sora_timeout")
|
||||
}
|
||||
if before.SoraMaxRetries != after.SoraMaxRetries {
|
||||
changed = append(changed, "sora_max_retries")
|
||||
}
|
||||
if before.SoraPollInterval != after.SoraPollInterval {
|
||||
changed = append(changed, "sora_poll_interval")
|
||||
}
|
||||
if before.SoraCallLogicMode != after.SoraCallLogicMode {
|
||||
changed = append(changed, "sora_call_logic_mode")
|
||||
}
|
||||
if before.SoraCacheEnabled != after.SoraCacheEnabled {
|
||||
changed = append(changed, "sora_cache_enabled")
|
||||
}
|
||||
if before.SoraCacheBaseDir != after.SoraCacheBaseDir {
|
||||
changed = append(changed, "sora_cache_base_dir")
|
||||
}
|
||||
if before.SoraCacheVideoDir != after.SoraCacheVideoDir {
|
||||
changed = append(changed, "sora_cache_video_dir")
|
||||
}
|
||||
if before.SoraCacheMaxBytes != after.SoraCacheMaxBytes {
|
||||
changed = append(changed, "sora_cache_max_bytes")
|
||||
}
|
||||
if strings.Join(before.SoraCacheAllowedHosts, ",") != strings.Join(after.SoraCacheAllowedHosts, ",") {
|
||||
changed = append(changed, "sora_cache_allowed_hosts")
|
||||
}
|
||||
if before.SoraCacheUserDirEnabled != after.SoraCacheUserDirEnabled {
|
||||
changed = append(changed, "sora_cache_user_dir_enabled")
|
||||
}
|
||||
if before.SoraWatermarkFreeEnabled != after.SoraWatermarkFreeEnabled {
|
||||
changed = append(changed, "sora_watermark_free_enabled")
|
||||
}
|
||||
if before.SoraWatermarkFreeParseMethod != after.SoraWatermarkFreeParseMethod {
|
||||
changed = append(changed, "sora_watermark_free_parse_method")
|
||||
}
|
||||
if before.SoraWatermarkFreeCustomParseURL != after.SoraWatermarkFreeCustomParseURL {
|
||||
changed = append(changed, "sora_watermark_free_custom_parse_url")
|
||||
}
|
||||
if before.SoraWatermarkFreeCustomParseToken != after.SoraWatermarkFreeCustomParseToken {
|
||||
changed = append(changed, "sora_watermark_free_custom_parse_token")
|
||||
}
|
||||
if before.SoraWatermarkFreeFallbackOnFailure != after.SoraWatermarkFreeFallbackOnFailure {
|
||||
changed = append(changed, "sora_watermark_free_fallback_on_failure")
|
||||
}
|
||||
if before.SoraTokenRefreshEnabled != after.SoraTokenRefreshEnabled {
|
||||
changed = append(changed, "sora_token_refresh_enabled")
|
||||
}
|
||||
if before.OpsMonitoringEnabled != after.OpsMonitoringEnabled {
|
||||
changed = append(changed, "ops_monitoring_enabled")
|
||||
}
|
||||
@@ -492,6 +639,19 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
return changed
|
||||
}
|
||||
|
||||
func normalizeStringList(values []string) []string {
|
||||
if len(values) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
normalized := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
// TestSMTPRequest 测试SMTP连接请求
|
||||
type TestSMTPRequest struct {
|
||||
SMTPHost string `json:"smtp_host" binding:"required"`
|
||||
|
||||
355
backend/internal/handler/admin/sora_account_handler.go
Normal file
355
backend/internal/handler/admin/sora_account_handler.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SoraAccountHandler Sora 账号扩展管理
|
||||
// 提供 Sora 扩展表的查询与更新能力。
|
||||
type SoraAccountHandler struct {
|
||||
adminService service.AdminService
|
||||
soraAccountRepo service.SoraAccountRepository
|
||||
usageRepo service.SoraUsageStatRepository
|
||||
}
|
||||
|
||||
// NewSoraAccountHandler 创建 SoraAccountHandler
|
||||
func NewSoraAccountHandler(adminService service.AdminService, soraAccountRepo service.SoraAccountRepository, usageRepo service.SoraUsageStatRepository) *SoraAccountHandler {
|
||||
return &SoraAccountHandler{
|
||||
adminService: adminService,
|
||||
soraAccountRepo: soraAccountRepo,
|
||||
usageRepo: usageRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// SoraAccountUpdateRequest 更新/创建 Sora 账号扩展请求
|
||||
// 使用指针类型区分未提供与设置为空值。
|
||||
type SoraAccountUpdateRequest struct {
|
||||
AccessToken *string `json:"access_token"`
|
||||
SessionToken *string `json:"session_token"`
|
||||
RefreshToken *string `json:"refresh_token"`
|
||||
ClientID *string `json:"client_id"`
|
||||
Email *string `json:"email"`
|
||||
Username *string `json:"username"`
|
||||
Remark *string `json:"remark"`
|
||||
UseCount *int `json:"use_count"`
|
||||
PlanType *string `json:"plan_type"`
|
||||
PlanTitle *string `json:"plan_title"`
|
||||
SubscriptionEnd *int64 `json:"subscription_end"`
|
||||
SoraSupported *bool `json:"sora_supported"`
|
||||
SoraInviteCode *string `json:"sora_invite_code"`
|
||||
SoraRedeemedCount *int `json:"sora_redeemed_count"`
|
||||
SoraRemainingCount *int `json:"sora_remaining_count"`
|
||||
SoraTotalCount *int `json:"sora_total_count"`
|
||||
SoraCooldownUntil *int64 `json:"sora_cooldown_until"`
|
||||
CooledUntil *int64 `json:"cooled_until"`
|
||||
ImageEnabled *bool `json:"image_enabled"`
|
||||
VideoEnabled *bool `json:"video_enabled"`
|
||||
ImageConcurrency *int `json:"image_concurrency"`
|
||||
VideoConcurrency *int `json:"video_concurrency"`
|
||||
IsExpired *bool `json:"is_expired"`
|
||||
}
|
||||
|
||||
// SoraAccountBatchRequest 批量导入请求
|
||||
// accounts 支持批量 upsert。
|
||||
type SoraAccountBatchRequest struct {
|
||||
Accounts []SoraAccountBatchItem `json:"accounts"`
|
||||
}
|
||||
|
||||
// SoraAccountBatchItem 批量导入条目
|
||||
type SoraAccountBatchItem struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
SoraAccountUpdateRequest
|
||||
}
|
||||
|
||||
// SoraAccountBatchResult 批量导入结果
|
||||
// 仅返回成功/失败数量与明细。
|
||||
type SoraAccountBatchResult struct {
|
||||
Success int `json:"success"`
|
||||
Failed int `json:"failed"`
|
||||
Results []SoraAccountBatchItemResult `json:"results"`
|
||||
}
|
||||
|
||||
// SoraAccountBatchItemResult 批量导入单条结果
|
||||
type SoraAccountBatchItemResult struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// List 获取 Sora 账号扩展列表
|
||||
// GET /api/v1/admin/sora/accounts
|
||||
func (h *SoraAccountHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
search := strings.TrimSpace(c.Query("search"))
|
||||
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, service.PlatformSora, "", "", search)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
accountIDs := make([]int64, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
accountIDs = append(accountIDs, accounts[i].ID)
|
||||
}
|
||||
|
||||
soraMap := map[int64]*service.SoraAccount{}
|
||||
if h.soraAccountRepo != nil {
|
||||
soraMap, _ = h.soraAccountRepo.GetByAccountIDs(c.Request.Context(), accountIDs)
|
||||
}
|
||||
|
||||
usageMap := map[int64]*service.SoraUsageStat{}
|
||||
if h.usageRepo != nil {
|
||||
usageMap, _ = h.usageRepo.GetByAccountIDs(c.Request.Context(), accountIDs)
|
||||
}
|
||||
|
||||
result := make([]dto.SoraAccount, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
acc := accounts[i]
|
||||
item := dto.SoraAccountFromService(&acc, soraMap[acc.ID], usageMap[acc.ID])
|
||||
if item != nil {
|
||||
result = append(result, *item)
|
||||
}
|
||||
}
|
||||
|
||||
response.Paginated(c, result, total, page, pageSize)
|
||||
}
|
||||
|
||||
// Get 获取单个 Sora 账号扩展
|
||||
// GET /api/v1/admin/sora/accounts/:id
|
||||
func (h *SoraAccountHandler) Get(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "账号 ID 无效")
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if account.Platform != service.PlatformSora {
|
||||
response.BadRequest(c, "账号不是 Sora 平台")
|
||||
return
|
||||
}
|
||||
|
||||
var soraAcc *service.SoraAccount
|
||||
if h.soraAccountRepo != nil {
|
||||
soraAcc, _ = h.soraAccountRepo.GetByAccountID(c.Request.Context(), accountID)
|
||||
}
|
||||
var usage *service.SoraUsageStat
|
||||
if h.usageRepo != nil {
|
||||
usage, _ = h.usageRepo.GetByAccountID(c.Request.Context(), accountID)
|
||||
}
|
||||
|
||||
response.Success(c, dto.SoraAccountFromService(account, soraAcc, usage))
|
||||
}
|
||||
|
||||
// Upsert 更新或创建 Sora 账号扩展
|
||||
// PUT /api/v1/admin/sora/accounts/:id
|
||||
func (h *SoraAccountHandler) Upsert(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "账号 ID 无效")
|
||||
return
|
||||
}
|
||||
|
||||
var req SoraAccountUpdateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求参数无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if account.Platform != service.PlatformSora {
|
||||
response.BadRequest(c, "账号不是 Sora 平台")
|
||||
return
|
||||
}
|
||||
|
||||
updates := buildSoraAccountUpdates(&req)
|
||||
if h.soraAccountRepo != nil && len(updates) > 0 {
|
||||
if err := h.soraAccountRepo.Upsert(c.Request.Context(), accountID, updates); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var soraAcc *service.SoraAccount
|
||||
if h.soraAccountRepo != nil {
|
||||
soraAcc, _ = h.soraAccountRepo.GetByAccountID(c.Request.Context(), accountID)
|
||||
}
|
||||
var usage *service.SoraUsageStat
|
||||
if h.usageRepo != nil {
|
||||
usage, _ = h.usageRepo.GetByAccountID(c.Request.Context(), accountID)
|
||||
}
|
||||
|
||||
response.Success(c, dto.SoraAccountFromService(account, soraAcc, usage))
|
||||
}
|
||||
|
||||
// BatchUpsert 批量导入 Sora 账号扩展
|
||||
// POST /api/v1/admin/sora/accounts/import
|
||||
func (h *SoraAccountHandler) BatchUpsert(c *gin.Context) {
|
||||
var req SoraAccountBatchRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "请求参数无效: "+err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.Accounts) == 0 {
|
||||
response.BadRequest(c, "accounts 不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
ids := make([]int64, 0, len(req.Accounts))
|
||||
for _, item := range req.Accounts {
|
||||
if item.AccountID > 0 {
|
||||
ids = append(ids, item.AccountID)
|
||||
}
|
||||
}
|
||||
|
||||
accountMap := make(map[int64]*service.Account, len(ids))
|
||||
if len(ids) > 0 {
|
||||
accounts, _ := h.adminService.GetAccountsByIDs(c.Request.Context(), ids)
|
||||
for _, acc := range accounts {
|
||||
if acc != nil {
|
||||
accountMap[acc.ID] = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := SoraAccountBatchResult{
|
||||
Results: make([]SoraAccountBatchItemResult, 0, len(req.Accounts)),
|
||||
}
|
||||
|
||||
for _, item := range req.Accounts {
|
||||
entry := SoraAccountBatchItemResult{AccountID: item.AccountID}
|
||||
acc := accountMap[item.AccountID]
|
||||
if acc == nil {
|
||||
entry.Error = "账号不存在"
|
||||
result.Results = append(result.Results, entry)
|
||||
result.Failed++
|
||||
continue
|
||||
}
|
||||
if acc.Platform != service.PlatformSora {
|
||||
entry.Error = "账号不是 Sora 平台"
|
||||
result.Results = append(result.Results, entry)
|
||||
result.Failed++
|
||||
continue
|
||||
}
|
||||
updates := buildSoraAccountUpdates(&item.SoraAccountUpdateRequest)
|
||||
if h.soraAccountRepo != nil && len(updates) > 0 {
|
||||
if err := h.soraAccountRepo.Upsert(c.Request.Context(), item.AccountID, updates); err != nil {
|
||||
entry.Error = err.Error()
|
||||
result.Results = append(result.Results, entry)
|
||||
result.Failed++
|
||||
continue
|
||||
}
|
||||
}
|
||||
entry.Success = true
|
||||
result.Results = append(result.Results, entry)
|
||||
result.Success++
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ListUsage 获取 Sora 调用统计
|
||||
// GET /api/v1/admin/sora/usage
|
||||
func (h *SoraAccountHandler) ListUsage(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
if h.usageRepo == nil {
|
||||
response.Paginated(c, []dto.SoraUsageStat{}, 0, page, pageSize)
|
||||
return
|
||||
}
|
||||
stats, paginationResult, err := h.usageRepo.List(c.Request.Context(), params)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
result := make([]dto.SoraUsageStat, 0, len(stats))
|
||||
for _, stat := range stats {
|
||||
item := dto.SoraUsageStatFromService(stat)
|
||||
if item != nil {
|
||||
result = append(result, *item)
|
||||
}
|
||||
}
|
||||
response.Paginated(c, result, paginationResult.Total, paginationResult.Page, paginationResult.PageSize)
|
||||
}
|
||||
|
||||
func buildSoraAccountUpdates(req *SoraAccountUpdateRequest) map[string]any {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
updates := make(map[string]any)
|
||||
setString := func(key string, value *string) {
|
||||
if value == nil {
|
||||
return
|
||||
}
|
||||
updates[key] = strings.TrimSpace(*value)
|
||||
}
|
||||
setString("access_token", req.AccessToken)
|
||||
setString("session_token", req.SessionToken)
|
||||
setString("refresh_token", req.RefreshToken)
|
||||
setString("client_id", req.ClientID)
|
||||
setString("email", req.Email)
|
||||
setString("username", req.Username)
|
||||
setString("remark", req.Remark)
|
||||
setString("plan_type", req.PlanType)
|
||||
setString("plan_title", req.PlanTitle)
|
||||
setString("sora_invite_code", req.SoraInviteCode)
|
||||
|
||||
if req.UseCount != nil {
|
||||
updates["use_count"] = *req.UseCount
|
||||
}
|
||||
if req.SoraSupported != nil {
|
||||
updates["sora_supported"] = *req.SoraSupported
|
||||
}
|
||||
if req.SoraRedeemedCount != nil {
|
||||
updates["sora_redeemed_count"] = *req.SoraRedeemedCount
|
||||
}
|
||||
if req.SoraRemainingCount != nil {
|
||||
updates["sora_remaining_count"] = *req.SoraRemainingCount
|
||||
}
|
||||
if req.SoraTotalCount != nil {
|
||||
updates["sora_total_count"] = *req.SoraTotalCount
|
||||
}
|
||||
if req.ImageEnabled != nil {
|
||||
updates["image_enabled"] = *req.ImageEnabled
|
||||
}
|
||||
if req.VideoEnabled != nil {
|
||||
updates["video_enabled"] = *req.VideoEnabled
|
||||
}
|
||||
if req.ImageConcurrency != nil {
|
||||
updates["image_concurrency"] = *req.ImageConcurrency
|
||||
}
|
||||
if req.VideoConcurrency != nil {
|
||||
updates["video_concurrency"] = *req.VideoConcurrency
|
||||
}
|
||||
if req.IsExpired != nil {
|
||||
updates["is_expired"] = *req.IsExpired
|
||||
}
|
||||
if req.SubscriptionEnd != nil && *req.SubscriptionEnd > 0 {
|
||||
updates["subscription_end"] = time.Unix(*req.SubscriptionEnd, 0).UTC()
|
||||
}
|
||||
if req.SoraCooldownUntil != nil && *req.SoraCooldownUntil > 0 {
|
||||
updates["sora_cooldown_until"] = time.Unix(*req.SoraCooldownUntil, 0).UTC()
|
||||
}
|
||||
if req.CooledUntil != nil && *req.CooledUntil > 0 {
|
||||
updates["cooled_until"] = time.Unix(*req.CooledUntil, 0).UTC()
|
||||
}
|
||||
return updates
|
||||
}
|
||||
@@ -287,6 +287,72 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
|
||||
}
|
||||
}
|
||||
|
||||
func SoraUsageStatFromService(stat *service.SoraUsageStat) *SoraUsageStat {
|
||||
if stat == nil {
|
||||
return nil
|
||||
}
|
||||
return &SoraUsageStat{
|
||||
AccountID: stat.AccountID,
|
||||
ImageCount: stat.ImageCount,
|
||||
VideoCount: stat.VideoCount,
|
||||
ErrorCount: stat.ErrorCount,
|
||||
LastErrorAt: timeToUnixSeconds(stat.LastErrorAt),
|
||||
TodayImageCount: stat.TodayImageCount,
|
||||
TodayVideoCount: stat.TodayVideoCount,
|
||||
TodayErrorCount: stat.TodayErrorCount,
|
||||
TodayDate: timeToUnixSeconds(stat.TodayDate),
|
||||
ConsecutiveErrorCount: stat.ConsecutiveErrorCount,
|
||||
CreatedAt: stat.CreatedAt,
|
||||
UpdatedAt: stat.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func SoraAccountFromService(account *service.Account, soraAcc *service.SoraAccount, usage *service.SoraUsageStat) *SoraAccount {
|
||||
if account == nil {
|
||||
return nil
|
||||
}
|
||||
out := &SoraAccount{
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
AccountStatus: account.Status,
|
||||
AccountType: account.Type,
|
||||
AccountConcurrency: account.Concurrency,
|
||||
ProxyID: account.ProxyID,
|
||||
Usage: SoraUsageStatFromService(usage),
|
||||
CreatedAt: account.CreatedAt,
|
||||
UpdatedAt: account.UpdatedAt,
|
||||
}
|
||||
if soraAcc == nil {
|
||||
return out
|
||||
}
|
||||
out.AccessToken = soraAcc.AccessToken
|
||||
out.SessionToken = soraAcc.SessionToken
|
||||
out.RefreshToken = soraAcc.RefreshToken
|
||||
out.ClientID = soraAcc.ClientID
|
||||
out.Email = soraAcc.Email
|
||||
out.Username = soraAcc.Username
|
||||
out.Remark = soraAcc.Remark
|
||||
out.UseCount = soraAcc.UseCount
|
||||
out.PlanType = soraAcc.PlanType
|
||||
out.PlanTitle = soraAcc.PlanTitle
|
||||
out.SubscriptionEnd = timeToUnixSeconds(soraAcc.SubscriptionEnd)
|
||||
out.SoraSupported = soraAcc.SoraSupported
|
||||
out.SoraInviteCode = soraAcc.SoraInviteCode
|
||||
out.SoraRedeemedCount = soraAcc.SoraRedeemedCount
|
||||
out.SoraRemainingCount = soraAcc.SoraRemainingCount
|
||||
out.SoraTotalCount = soraAcc.SoraTotalCount
|
||||
out.SoraCooldownUntil = timeToUnixSeconds(soraAcc.SoraCooldownUntil)
|
||||
out.CooledUntil = timeToUnixSeconds(soraAcc.CooledUntil)
|
||||
out.ImageEnabled = soraAcc.ImageEnabled
|
||||
out.VideoEnabled = soraAcc.VideoEnabled
|
||||
out.ImageConcurrency = soraAcc.ImageConcurrency
|
||||
out.VideoConcurrency = soraAcc.VideoConcurrency
|
||||
out.IsExpired = soraAcc.IsExpired
|
||||
out.CreatedAt = soraAcc.CreatedAt
|
||||
out.UpdatedAt = soraAcc.UpdatedAt
|
||||
return out
|
||||
}
|
||||
|
||||
func ProxyAccountSummaryFromService(a *service.ProxyAccountSummary) *ProxyAccountSummary {
|
||||
if a == nil {
|
||||
return nil
|
||||
|
||||
@@ -46,6 +46,25 @@ type SystemSettings struct {
|
||||
EnableIdentityPatch bool `json:"enable_identity_patch"`
|
||||
IdentityPatchPrompt string `json:"identity_patch_prompt"`
|
||||
|
||||
// Sora configuration
|
||||
SoraBaseURL string `json:"sora_base_url"`
|
||||
SoraTimeout int `json:"sora_timeout"`
|
||||
SoraMaxRetries int `json:"sora_max_retries"`
|
||||
SoraPollInterval float64 `json:"sora_poll_interval"`
|
||||
SoraCallLogicMode string `json:"sora_call_logic_mode"`
|
||||
SoraCacheEnabled bool `json:"sora_cache_enabled"`
|
||||
SoraCacheBaseDir string `json:"sora_cache_base_dir"`
|
||||
SoraCacheVideoDir string `json:"sora_cache_video_dir"`
|
||||
SoraCacheMaxBytes int64 `json:"sora_cache_max_bytes"`
|
||||
SoraCacheAllowedHosts []string `json:"sora_cache_allowed_hosts"`
|
||||
SoraCacheUserDirEnabled bool `json:"sora_cache_user_dir_enabled"`
|
||||
SoraWatermarkFreeEnabled bool `json:"sora_watermark_free_enabled"`
|
||||
SoraWatermarkFreeParseMethod string `json:"sora_watermark_free_parse_method"`
|
||||
SoraWatermarkFreeCustomParseURL string `json:"sora_watermark_free_custom_parse_url"`
|
||||
SoraWatermarkFreeCustomParseToken string `json:"sora_watermark_free_custom_parse_token"`
|
||||
SoraWatermarkFreeFallbackOnFailure bool `json:"sora_watermark_free_fallback_on_failure"`
|
||||
SoraTokenRefreshEnabled bool `json:"sora_token_refresh_enabled"`
|
||||
|
||||
// Ops monitoring (vNext)
|
||||
OpsMonitoringEnabled bool `json:"ops_monitoring_enabled"`
|
||||
OpsRealtimeMonitoringEnabled bool `json:"ops_realtime_monitoring_enabled"`
|
||||
|
||||
@@ -141,6 +141,56 @@ type Account struct {
|
||||
Groups []*Group `json:"groups,omitempty"`
|
||||
}
|
||||
|
||||
type SoraUsageStat struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
ImageCount int `json:"image_count"`
|
||||
VideoCount int `json:"video_count"`
|
||||
ErrorCount int `json:"error_count"`
|
||||
LastErrorAt *int64 `json:"last_error_at"`
|
||||
TodayImageCount int `json:"today_image_count"`
|
||||
TodayVideoCount int `json:"today_video_count"`
|
||||
TodayErrorCount int `json:"today_error_count"`
|
||||
TodayDate *int64 `json:"today_date"`
|
||||
ConsecutiveErrorCount int `json:"consecutive_error_count"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type SoraAccount struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
AccountName string `json:"account_name"`
|
||||
AccountStatus string `json:"account_status"`
|
||||
AccountType string `json:"account_type"`
|
||||
AccountConcurrency int `json:"account_concurrency"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
SessionToken string `json:"session_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ClientID string `json:"client_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Remark string `json:"remark"`
|
||||
UseCount int `json:"use_count"`
|
||||
PlanType string `json:"plan_type"`
|
||||
PlanTitle string `json:"plan_title"`
|
||||
SubscriptionEnd *int64 `json:"subscription_end"`
|
||||
SoraSupported bool `json:"sora_supported"`
|
||||
SoraInviteCode string `json:"sora_invite_code"`
|
||||
SoraRedeemedCount int `json:"sora_redeemed_count"`
|
||||
SoraRemainingCount int `json:"sora_remaining_count"`
|
||||
SoraTotalCount int `json:"sora_total_count"`
|
||||
SoraCooldownUntil *int64 `json:"sora_cooldown_until"`
|
||||
CooledUntil *int64 `json:"cooled_until"`
|
||||
ImageEnabled bool `json:"image_enabled"`
|
||||
VideoEnabled bool `json:"video_enabled"`
|
||||
ImageConcurrency int `json:"image_concurrency"`
|
||||
VideoConcurrency int `json:"video_concurrency"`
|
||||
IsExpired bool `json:"is_expired"`
|
||||
Usage *SoraUsageStat `json:"usage,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type AccountGroup struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/sora"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -508,6 +509,13 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if platform == service.PlatformSora {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": sora.ListModels(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
|
||||
@@ -17,6 +17,7 @@ type AdminHandlers struct {
|
||||
Proxy *admin.ProxyHandler
|
||||
Redeem *admin.RedeemHandler
|
||||
Promo *admin.PromoHandler
|
||||
SoraAccount *admin.SoraAccountHandler
|
||||
Setting *admin.SettingHandler
|
||||
Ops *admin.OpsHandler
|
||||
System *admin.SystemHandler
|
||||
@@ -36,6 +37,7 @@ type Handlers struct {
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
SoraGateway *SoraGatewayHandler
|
||||
Setting *SettingHandler
|
||||
}
|
||||
|
||||
|
||||
@@ -814,6 +814,8 @@ func guessPlatformFromPath(path string) string {
|
||||
return service.PlatformAntigravity
|
||||
case strings.HasPrefix(p, "/v1beta/"):
|
||||
return service.PlatformGemini
|
||||
case strings.Contains(p, "/chat/completions"):
|
||||
return service.PlatformSora
|
||||
case strings.Contains(p, "/responses"):
|
||||
return service.PlatformOpenAI
|
||||
default:
|
||||
|
||||
364
backend/internal/handler/sora_gateway_handler.go
Normal file
364
backend/internal/handler/sora_gateway_handler.go
Normal file
@@ -0,0 +1,364 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/sora"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SoraGatewayHandler handles Sora OpenAI compatible endpoints.
|
||||
type SoraGatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
soraGatewayService *service.SoraGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
}
|
||||
|
||||
// NewSoraGatewayHandler creates a new SoraGatewayHandler.
|
||||
func NewSoraGatewayHandler(
|
||||
gatewayService *service.GatewayService,
|
||||
soraGatewayService *service.SoraGatewayService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
cfg *config.Config,
|
||||
) *SoraGatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
maxAccountSwitches := 3
|
||||
if cfg != nil {
|
||||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
||||
if cfg.Gateway.MaxAccountSwitches > 0 {
|
||||
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
|
||||
}
|
||||
}
|
||||
return &SoraGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
soraGatewayService: soraGatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatCompletions handles Sora OpenAI-compatible chat completions endpoint.
|
||||
// POST /v1/chat/completions
|
||||
func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
model, _ := reqBody["model"].(string)
|
||||
if strings.TrimSpace(model) == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
stream, _ := reqBody["stream"].(bool)
|
||||
|
||||
prompt, imageData, videoData, remixID, err := parseSoraPrompt(reqBody)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error())
|
||||
return
|
||||
}
|
||||
if remixID == "" {
|
||||
remixID = sora.ExtractRemixID(prompt)
|
||||
}
|
||||
if remixID != "" {
|
||||
prompt = strings.ReplaceAll(prompt, remixID, "")
|
||||
}
|
||||
|
||||
if apiKey.Group != nil && apiKey.Group.Platform != service.PlatformSora {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "当前分组不支持 Sora 平台")
|
||||
return
|
||||
}
|
||||
|
||||
streamStarted := false
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
if err == nil && !canWait {
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
waitCounted = false
|
||||
}
|
||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
maxSwitches := h.maxAccountSwitches
|
||||
if mode := h.soraGatewayService.CallLogicMode(c.Request.Context()); strings.EqualFold(mode, "native") {
|
||||
maxSwitches = 1
|
||||
}
|
||||
|
||||
for switchCount := 0; switchCount < maxSwitches; switchCount++ {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, "", model, failedAccountIDs, "")
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "server_error", err.Error())
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
releaseFunc := selection.ReleaseFunc
|
||||
|
||||
result, err := h.soraGatewayService.Generate(c.Request.Context(), account, service.SoraGenerationRequest{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
Image: imageData,
|
||||
Video: videoData,
|
||||
RemixTargetID: remixID,
|
||||
Stream: stream,
|
||||
UserID: subject.UserID,
|
||||
})
|
||||
if err != nil {
|
||||
// 失败路径:立即释放槽位,而非 defer
|
||||
if releaseFunc != nil {
|
||||
releaseFunc()
|
||||
}
|
||||
|
||||
if errors.Is(err, service.ErrSoraAccountMissingToken) || errors.Is(err, service.ErrSoraAccountNotEligible) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
continue
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "server_error", err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 成功路径:使用 defer 在函数退出时释放
|
||||
if releaseFunc != nil {
|
||||
defer releaseFunc()
|
||||
}
|
||||
|
||||
h.respondCompletion(c, model, result, stream)
|
||||
return
|
||||
}
|
||||
|
||||
h.handleFailoverExhausted(c, http.StatusServiceUnavailable, streamStarted)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) respondCompletion(c *gin.Context, model string, result *service.SoraGenerationResult, stream bool) {
|
||||
if result == nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Empty response")
|
||||
return
|
||||
}
|
||||
if stream {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
first := buildSoraStreamChunk(model, "", true, "")
|
||||
if _, err := c.Writer.WriteString(first); err != nil {
|
||||
return
|
||||
}
|
||||
final := buildSoraStreamChunk(model, result.Content, false, "stop")
|
||||
if _, err := c.Writer.WriteString(final); err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = c.Writer.WriteString("data: [DONE]\n\n")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, buildSoraNonStreamResponse(model, result.Content))
|
||||
}
|
||||
|
||||
func buildSoraStreamChunk(model, content string, isFirst bool, finishReason string) string {
|
||||
chunkID := fmt.Sprintf("chatcmpl-%d", time.Now().UnixMilli())
|
||||
delta := map[string]any{}
|
||||
if isFirst {
|
||||
delta["role"] = "assistant"
|
||||
}
|
||||
if content != "" {
|
||||
delta["content"] = content
|
||||
} else {
|
||||
delta["content"] = nil
|
||||
}
|
||||
response := map[string]any{
|
||||
"id": chunkID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
"finish_reason": finishReason,
|
||||
},
|
||||
},
|
||||
}
|
||||
payload, _ := json.Marshal(response)
|
||||
return "data: " + string(payload) + "\n\n"
|
||||
}
|
||||
|
||||
func buildSoraNonStreamResponse(model, content string) map[string]any {
|
||||
return map[string]any{
|
||||
"id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixMilli()),
|
||||
"object": "chat.completion",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"message": map[string]any{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func parseSoraPrompt(req map[string]any) (prompt, imageData, videoData, remixID string, err error) {
|
||||
messages, ok := req["messages"].([]any)
|
||||
if !ok || len(messages) == 0 {
|
||||
return "", "", "", "", fmt.Errorf("messages is required")
|
||||
}
|
||||
last := messages[len(messages)-1]
|
||||
msg, ok := last.(map[string]any)
|
||||
if !ok {
|
||||
return "", "", "", "", fmt.Errorf("invalid message format")
|
||||
}
|
||||
content, ok := msg["content"]
|
||||
if !ok {
|
||||
return "", "", "", "", fmt.Errorf("content is required")
|
||||
}
|
||||
|
||||
if v, ok := req["image"].(string); ok && v != "" {
|
||||
imageData = v
|
||||
}
|
||||
if v, ok := req["video"].(string); ok && v != "" {
|
||||
videoData = v
|
||||
}
|
||||
if v, ok := req["remix_target_id"].(string); ok {
|
||||
remixID = v
|
||||
}
|
||||
|
||||
switch value := content.(type) {
|
||||
case string:
|
||||
prompt = value
|
||||
case []any:
|
||||
for _, item := range value {
|
||||
part, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch part["type"] {
|
||||
case "text":
|
||||
if text, ok := part["text"].(string); ok {
|
||||
prompt = text
|
||||
}
|
||||
case "image_url":
|
||||
if image, ok := part["image_url"].(map[string]any); ok {
|
||||
if url, ok := image["url"].(string); ok {
|
||||
imageData = url
|
||||
}
|
||||
}
|
||||
case "video_url":
|
||||
if video, ok := part["video_url"].(map[string]any); ok {
|
||||
if url, ok := video["url"].(string); ok {
|
||||
videoData = url
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
return "", "", "", "", fmt.Errorf("invalid content format")
|
||||
}
|
||||
if strings.TrimSpace(prompt) == "" && strings.TrimSpace(videoData) == "" {
|
||||
return "", "", "", "", fmt.Errorf("prompt is required")
|
||||
}
|
||||
return prompt, imageData, videoData, remixID, nil
|
||||
}
|
||||
|
||||
func looksLikeURL(value string) bool {
|
||||
trimmed := strings.ToLower(strings.TrimSpace(value))
|
||||
return strings.HasPrefix(trimmed, "http://") || strings.HasPrefix(trimmed, "https://")
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", err.Error(), true)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
message := "No available Sora accounts"
|
||||
h.handleStreamingAwareError(c, statusCode, "server_error", message, streamStarted)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
payload := map[string]any{"error": map[string]any{"message": message, "type": errType, "param": nil, "code": nil}}
|
||||
data, _ := json.Marshal(payload)
|
||||
_, _ = c.Writer.WriteString("data: " + string(data) + "\n\n")
|
||||
_, _ = c.Writer.WriteString("data: [DONE]\n\n")
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"message": message,
|
||||
"type": errType,
|
||||
"param": nil,
|
||||
"code": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -20,6 +20,7 @@ func ProvideAdminHandlers(
|
||||
proxyHandler *admin.ProxyHandler,
|
||||
redeemHandler *admin.RedeemHandler,
|
||||
promoHandler *admin.PromoHandler,
|
||||
soraAccountHandler *admin.SoraAccountHandler,
|
||||
settingHandler *admin.SettingHandler,
|
||||
opsHandler *admin.OpsHandler,
|
||||
systemHandler *admin.SystemHandler,
|
||||
@@ -39,6 +40,7 @@ func ProvideAdminHandlers(
|
||||
Proxy: proxyHandler,
|
||||
Redeem: redeemHandler,
|
||||
Promo: promoHandler,
|
||||
SoraAccount: soraAccountHandler,
|
||||
Setting: settingHandler,
|
||||
Ops: opsHandler,
|
||||
System: systemHandler,
|
||||
@@ -69,6 +71,7 @@ func ProvideHandlers(
|
||||
adminHandlers *AdminHandlers,
|
||||
gatewayHandler *GatewayHandler,
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
soraGatewayHandler *SoraGatewayHandler,
|
||||
settingHandler *SettingHandler,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
@@ -81,6 +84,7 @@ func ProvideHandlers(
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
SoraGateway: soraGatewayHandler,
|
||||
Setting: settingHandler,
|
||||
}
|
||||
}
|
||||
@@ -96,6 +100,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewSubscriptionHandler,
|
||||
NewGatewayHandler,
|
||||
NewOpenAIGatewayHandler,
|
||||
NewSoraGatewayHandler,
|
||||
ProvideSettingHandler,
|
||||
|
||||
// Admin handlers
|
||||
@@ -110,6 +115,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewProxyHandler,
|
||||
admin.NewRedeemHandler,
|
||||
admin.NewPromoHandler,
|
||||
admin.NewSoraAccountHandler,
|
||||
admin.NewSettingHandler,
|
||||
admin.NewOpsHandler,
|
||||
ProvideSystemHandler,
|
||||
|
||||
Reference in New Issue
Block a user