feat(sora): 新增 Sora 平台支持并修复高危安全和性能问题

新增功能:
- 新增 Sora 账号管理和 OAuth 认证
- 新增 Sora 视频/图片生成 API 网关
- 新增 Sora 任务调度和缓存机制
- 新增 Sora 使用统计和计费支持
- 前端增加 Sora 平台配置界面

安全修复(代码审核):
- [SEC-001] 限制媒体下载响应体大小(图片 20MB、视频 200MB),防止 DoS 攻击
- [SEC-002] 限制 SDK API 响应大小(1MB),防止内存耗尽
- [SEC-003] 修复 SSRF 风险,添加 URL 验证并强制使用代理配置

BUG 修复(代码审核):
- [BUG-001] 修复 for 循环内 defer 累积导致的资源泄漏
- [BUG-002] 修复图片并发槽位获取失败时已持有锁未释放的永久泄漏

性能优化(代码审核):
- [PERF-001] 添加 Sentinel Token 缓存(3 分钟有效期),减少 PoW 计算开销

技术细节:
- 使用 io.LimitReader 限制所有外部输入的大小
- 添加 urlvalidator 验证防止 SSRF 攻击
- 使用 sync.Map 实现线程安全的包级缓存
- 优化并发槽位管理,添加 releaseAll 模式防止泄漏

影响范围:
- 后端:新增 Sora 相关数据模型、服务、网关和管理接口
- 前端:新增 Sora 平台配置、账号管理和监控界面
- 配置:新增 Sora 相关配置项和环境变量

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-01-29 16:18:38 +08:00
parent bece1b5201
commit 13262a5698
97 changed files with 29541 additions and 68 deletions

View File

@@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
type CreateGroupRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
@@ -49,7 +49,7 @@ type CreateGroupRequest struct {
type UpdateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`

View File

@@ -79,6 +79,23 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
FallbackModelAntigravity: settings.FallbackModelAntigravity,
EnableIdentityPatch: settings.EnableIdentityPatch,
IdentityPatchPrompt: settings.IdentityPatchPrompt,
SoraBaseURL: settings.SoraBaseURL,
SoraTimeout: settings.SoraTimeout,
SoraMaxRetries: settings.SoraMaxRetries,
SoraPollInterval: settings.SoraPollInterval,
SoraCallLogicMode: settings.SoraCallLogicMode,
SoraCacheEnabled: settings.SoraCacheEnabled,
SoraCacheBaseDir: settings.SoraCacheBaseDir,
SoraCacheVideoDir: settings.SoraCacheVideoDir,
SoraCacheMaxBytes: settings.SoraCacheMaxBytes,
SoraCacheAllowedHosts: settings.SoraCacheAllowedHosts,
SoraCacheUserDirEnabled: settings.SoraCacheUserDirEnabled,
SoraWatermarkFreeEnabled: settings.SoraWatermarkFreeEnabled,
SoraWatermarkFreeParseMethod: settings.SoraWatermarkFreeParseMethod,
SoraWatermarkFreeCustomParseURL: settings.SoraWatermarkFreeCustomParseURL,
SoraWatermarkFreeCustomParseToken: settings.SoraWatermarkFreeCustomParseToken,
SoraWatermarkFreeFallbackOnFailure: settings.SoraWatermarkFreeFallbackOnFailure,
SoraTokenRefreshEnabled: settings.SoraTokenRefreshEnabled,
OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled,
OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: settings.OpsQueryModeDefault,
@@ -138,6 +155,25 @@ type UpdateSettingsRequest struct {
EnableIdentityPatch bool `json:"enable_identity_patch"`
IdentityPatchPrompt string `json:"identity_patch_prompt"`
// Sora configuration
SoraBaseURL string `json:"sora_base_url"`
SoraTimeout int `json:"sora_timeout"`
SoraMaxRetries int `json:"sora_max_retries"`
SoraPollInterval float64 `json:"sora_poll_interval"`
SoraCallLogicMode string `json:"sora_call_logic_mode"`
SoraCacheEnabled bool `json:"sora_cache_enabled"`
SoraCacheBaseDir string `json:"sora_cache_base_dir"`
SoraCacheVideoDir string `json:"sora_cache_video_dir"`
SoraCacheMaxBytes int64 `json:"sora_cache_max_bytes"`
SoraCacheAllowedHosts []string `json:"sora_cache_allowed_hosts"`
SoraCacheUserDirEnabled bool `json:"sora_cache_user_dir_enabled"`
SoraWatermarkFreeEnabled bool `json:"sora_watermark_free_enabled"`
SoraWatermarkFreeParseMethod string `json:"sora_watermark_free_parse_method"`
SoraWatermarkFreeCustomParseURL string `json:"sora_watermark_free_custom_parse_url"`
SoraWatermarkFreeCustomParseToken string `json:"sora_watermark_free_custom_parse_token"`
SoraWatermarkFreeFallbackOnFailure bool `json:"sora_watermark_free_fallback_on_failure"`
SoraTokenRefreshEnabled bool `json:"sora_token_refresh_enabled"`
// Ops monitoring (vNext)
OpsMonitoringEnabled *bool `json:"ops_monitoring_enabled"`
OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"`
@@ -227,6 +263,32 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// Sora 参数校验与清理
req.SoraBaseURL = strings.TrimSpace(req.SoraBaseURL)
if req.SoraBaseURL == "" {
req.SoraBaseURL = previousSettings.SoraBaseURL
}
if req.SoraBaseURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.SoraBaseURL); err != nil {
response.BadRequest(c, "Sora Base URL must be an absolute http(s) URL")
return
}
}
if req.SoraTimeout <= 0 {
req.SoraTimeout = previousSettings.SoraTimeout
}
if req.SoraMaxRetries < 0 {
req.SoraMaxRetries = previousSettings.SoraMaxRetries
}
if req.SoraPollInterval <= 0 {
req.SoraPollInterval = previousSettings.SoraPollInterval
}
if req.SoraCacheMaxBytes < 0 {
req.SoraCacheMaxBytes = 0
}
req.SoraCacheAllowedHosts = normalizeStringList(req.SoraCacheAllowedHosts)
req.SoraWatermarkFreeCustomParseURL = strings.TrimSpace(req.SoraWatermarkFreeCustomParseURL)
// Ops metrics collector interval validation (seconds).
if req.OpsMetricsIntervalSeconds != nil {
v := *req.OpsMetricsIntervalSeconds
@@ -240,40 +302,57 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
SoraBaseURL: req.SoraBaseURL,
SoraTimeout: req.SoraTimeout,
SoraMaxRetries: req.SoraMaxRetries,
SoraPollInterval: req.SoraPollInterval,
SoraCallLogicMode: req.SoraCallLogicMode,
SoraCacheEnabled: req.SoraCacheEnabled,
SoraCacheBaseDir: req.SoraCacheBaseDir,
SoraCacheVideoDir: req.SoraCacheVideoDir,
SoraCacheMaxBytes: req.SoraCacheMaxBytes,
SoraCacheAllowedHosts: req.SoraCacheAllowedHosts,
SoraCacheUserDirEnabled: req.SoraCacheUserDirEnabled,
SoraWatermarkFreeEnabled: req.SoraWatermarkFreeEnabled,
SoraWatermarkFreeParseMethod: req.SoraWatermarkFreeParseMethod,
SoraWatermarkFreeCustomParseURL: req.SoraWatermarkFreeCustomParseURL,
SoraWatermarkFreeCustomParseToken: req.SoraWatermarkFreeCustomParseToken,
SoraWatermarkFreeFallbackOnFailure: req.SoraWatermarkFreeFallbackOnFailure,
SoraTokenRefreshEnabled: req.SoraTokenRefreshEnabled,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@@ -349,6 +428,23 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
SoraBaseURL: updatedSettings.SoraBaseURL,
SoraTimeout: updatedSettings.SoraTimeout,
SoraMaxRetries: updatedSettings.SoraMaxRetries,
SoraPollInterval: updatedSettings.SoraPollInterval,
SoraCallLogicMode: updatedSettings.SoraCallLogicMode,
SoraCacheEnabled: updatedSettings.SoraCacheEnabled,
SoraCacheBaseDir: updatedSettings.SoraCacheBaseDir,
SoraCacheVideoDir: updatedSettings.SoraCacheVideoDir,
SoraCacheMaxBytes: updatedSettings.SoraCacheMaxBytes,
SoraCacheAllowedHosts: updatedSettings.SoraCacheAllowedHosts,
SoraCacheUserDirEnabled: updatedSettings.SoraCacheUserDirEnabled,
SoraWatermarkFreeEnabled: updatedSettings.SoraWatermarkFreeEnabled,
SoraWatermarkFreeParseMethod: updatedSettings.SoraWatermarkFreeParseMethod,
SoraWatermarkFreeCustomParseURL: updatedSettings.SoraWatermarkFreeCustomParseURL,
SoraWatermarkFreeCustomParseToken: updatedSettings.SoraWatermarkFreeCustomParseToken,
SoraWatermarkFreeFallbackOnFailure: updatedSettings.SoraWatermarkFreeFallbackOnFailure,
SoraTokenRefreshEnabled: updatedSettings.SoraTokenRefreshEnabled,
OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled,
OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
@@ -477,6 +573,57 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.IdentityPatchPrompt != after.IdentityPatchPrompt {
changed = append(changed, "identity_patch_prompt")
}
if before.SoraBaseURL != after.SoraBaseURL {
changed = append(changed, "sora_base_url")
}
if before.SoraTimeout != after.SoraTimeout {
changed = append(changed, "sora_timeout")
}
if before.SoraMaxRetries != after.SoraMaxRetries {
changed = append(changed, "sora_max_retries")
}
if before.SoraPollInterval != after.SoraPollInterval {
changed = append(changed, "sora_poll_interval")
}
if before.SoraCallLogicMode != after.SoraCallLogicMode {
changed = append(changed, "sora_call_logic_mode")
}
if before.SoraCacheEnabled != after.SoraCacheEnabled {
changed = append(changed, "sora_cache_enabled")
}
if before.SoraCacheBaseDir != after.SoraCacheBaseDir {
changed = append(changed, "sora_cache_base_dir")
}
if before.SoraCacheVideoDir != after.SoraCacheVideoDir {
changed = append(changed, "sora_cache_video_dir")
}
if before.SoraCacheMaxBytes != after.SoraCacheMaxBytes {
changed = append(changed, "sora_cache_max_bytes")
}
if strings.Join(before.SoraCacheAllowedHosts, ",") != strings.Join(after.SoraCacheAllowedHosts, ",") {
changed = append(changed, "sora_cache_allowed_hosts")
}
if before.SoraCacheUserDirEnabled != after.SoraCacheUserDirEnabled {
changed = append(changed, "sora_cache_user_dir_enabled")
}
if before.SoraWatermarkFreeEnabled != after.SoraWatermarkFreeEnabled {
changed = append(changed, "sora_watermark_free_enabled")
}
if before.SoraWatermarkFreeParseMethod != after.SoraWatermarkFreeParseMethod {
changed = append(changed, "sora_watermark_free_parse_method")
}
if before.SoraWatermarkFreeCustomParseURL != after.SoraWatermarkFreeCustomParseURL {
changed = append(changed, "sora_watermark_free_custom_parse_url")
}
if before.SoraWatermarkFreeCustomParseToken != after.SoraWatermarkFreeCustomParseToken {
changed = append(changed, "sora_watermark_free_custom_parse_token")
}
if before.SoraWatermarkFreeFallbackOnFailure != after.SoraWatermarkFreeFallbackOnFailure {
changed = append(changed, "sora_watermark_free_fallback_on_failure")
}
if before.SoraTokenRefreshEnabled != after.SoraTokenRefreshEnabled {
changed = append(changed, "sora_token_refresh_enabled")
}
if before.OpsMonitoringEnabled != after.OpsMonitoringEnabled {
changed = append(changed, "ops_monitoring_enabled")
}
@@ -492,6 +639,19 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
return changed
}
func normalizeStringList(values []string) []string {
if len(values) == 0 {
return []string{}
}
normalized := make([]string, 0, len(values))
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
normalized = append(normalized, trimmed)
}
}
return normalized
}
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"`

View File

@@ -0,0 +1,355 @@
package admin
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// SoraAccountHandler Sora 账号扩展管理
// 提供 Sora 扩展表的查询与更新能力。
type SoraAccountHandler struct {
adminService service.AdminService
soraAccountRepo service.SoraAccountRepository
usageRepo service.SoraUsageStatRepository
}
// NewSoraAccountHandler 创建 SoraAccountHandler
func NewSoraAccountHandler(adminService service.AdminService, soraAccountRepo service.SoraAccountRepository, usageRepo service.SoraUsageStatRepository) *SoraAccountHandler {
return &SoraAccountHandler{
adminService: adminService,
soraAccountRepo: soraAccountRepo,
usageRepo: usageRepo,
}
}
// SoraAccountUpdateRequest 更新/创建 Sora 账号扩展请求
// 使用指针类型区分未提供与设置为空值。
type SoraAccountUpdateRequest struct {
AccessToken *string `json:"access_token"`
SessionToken *string `json:"session_token"`
RefreshToken *string `json:"refresh_token"`
ClientID *string `json:"client_id"`
Email *string `json:"email"`
Username *string `json:"username"`
Remark *string `json:"remark"`
UseCount *int `json:"use_count"`
PlanType *string `json:"plan_type"`
PlanTitle *string `json:"plan_title"`
SubscriptionEnd *int64 `json:"subscription_end"`
SoraSupported *bool `json:"sora_supported"`
SoraInviteCode *string `json:"sora_invite_code"`
SoraRedeemedCount *int `json:"sora_redeemed_count"`
SoraRemainingCount *int `json:"sora_remaining_count"`
SoraTotalCount *int `json:"sora_total_count"`
SoraCooldownUntil *int64 `json:"sora_cooldown_until"`
CooledUntil *int64 `json:"cooled_until"`
ImageEnabled *bool `json:"image_enabled"`
VideoEnabled *bool `json:"video_enabled"`
ImageConcurrency *int `json:"image_concurrency"`
VideoConcurrency *int `json:"video_concurrency"`
IsExpired *bool `json:"is_expired"`
}
// SoraAccountBatchRequest 批量导入请求
// accounts 支持批量 upsert。
type SoraAccountBatchRequest struct {
Accounts []SoraAccountBatchItem `json:"accounts"`
}
// SoraAccountBatchItem 批量导入条目
type SoraAccountBatchItem struct {
AccountID int64 `json:"account_id"`
SoraAccountUpdateRequest
}
// SoraAccountBatchResult 批量导入结果
// 仅返回成功/失败数量与明细。
type SoraAccountBatchResult struct {
Success int `json:"success"`
Failed int `json:"failed"`
Results []SoraAccountBatchItemResult `json:"results"`
}
// SoraAccountBatchItemResult 批量导入单条结果
type SoraAccountBatchItemResult struct {
AccountID int64 `json:"account_id"`
Success bool `json:"success"`
Error string `json:"error,omitempty"`
}
// List 获取 Sora 账号扩展列表
// GET /api/v1/admin/sora/accounts
func (h *SoraAccountHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
search := strings.TrimSpace(c.Query("search"))
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, service.PlatformSora, "", "", search)
if err != nil {
response.ErrorFrom(c, err)
return
}
accountIDs := make([]int64, 0, len(accounts))
for i := range accounts {
accountIDs = append(accountIDs, accounts[i].ID)
}
soraMap := map[int64]*service.SoraAccount{}
if h.soraAccountRepo != nil {
soraMap, _ = h.soraAccountRepo.GetByAccountIDs(c.Request.Context(), accountIDs)
}
usageMap := map[int64]*service.SoraUsageStat{}
if h.usageRepo != nil {
usageMap, _ = h.usageRepo.GetByAccountIDs(c.Request.Context(), accountIDs)
}
result := make([]dto.SoraAccount, 0, len(accounts))
for i := range accounts {
acc := accounts[i]
item := dto.SoraAccountFromService(&acc, soraMap[acc.ID], usageMap[acc.ID])
if item != nil {
result = append(result, *item)
}
}
response.Paginated(c, result, total, page, pageSize)
}
// Get 获取单个 Sora 账号扩展
// GET /api/v1/admin/sora/accounts/:id
func (h *SoraAccountHandler) Get(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "账号 ID 无效")
return
}
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if account.Platform != service.PlatformSora {
response.BadRequest(c, "账号不是 Sora 平台")
return
}
var soraAcc *service.SoraAccount
if h.soraAccountRepo != nil {
soraAcc, _ = h.soraAccountRepo.GetByAccountID(c.Request.Context(), accountID)
}
var usage *service.SoraUsageStat
if h.usageRepo != nil {
usage, _ = h.usageRepo.GetByAccountID(c.Request.Context(), accountID)
}
response.Success(c, dto.SoraAccountFromService(account, soraAcc, usage))
}
// Upsert 更新或创建 Sora 账号扩展
// PUT /api/v1/admin/sora/accounts/:id
func (h *SoraAccountHandler) Upsert(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "账号 ID 无效")
return
}
var req SoraAccountUpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求参数无效: "+err.Error())
return
}
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if account.Platform != service.PlatformSora {
response.BadRequest(c, "账号不是 Sora 平台")
return
}
updates := buildSoraAccountUpdates(&req)
if h.soraAccountRepo != nil && len(updates) > 0 {
if err := h.soraAccountRepo.Upsert(c.Request.Context(), accountID, updates); err != nil {
response.ErrorFrom(c, err)
return
}
}
var soraAcc *service.SoraAccount
if h.soraAccountRepo != nil {
soraAcc, _ = h.soraAccountRepo.GetByAccountID(c.Request.Context(), accountID)
}
var usage *service.SoraUsageStat
if h.usageRepo != nil {
usage, _ = h.usageRepo.GetByAccountID(c.Request.Context(), accountID)
}
response.Success(c, dto.SoraAccountFromService(account, soraAcc, usage))
}
// BatchUpsert 批量导入 Sora 账号扩展
// POST /api/v1/admin/sora/accounts/import
func (h *SoraAccountHandler) BatchUpsert(c *gin.Context) {
var req SoraAccountBatchRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求参数无效: "+err.Error())
return
}
if len(req.Accounts) == 0 {
response.BadRequest(c, "accounts 不能为空")
return
}
ids := make([]int64, 0, len(req.Accounts))
for _, item := range req.Accounts {
if item.AccountID > 0 {
ids = append(ids, item.AccountID)
}
}
accountMap := make(map[int64]*service.Account, len(ids))
if len(ids) > 0 {
accounts, _ := h.adminService.GetAccountsByIDs(c.Request.Context(), ids)
for _, acc := range accounts {
if acc != nil {
accountMap[acc.ID] = acc
}
}
}
result := SoraAccountBatchResult{
Results: make([]SoraAccountBatchItemResult, 0, len(req.Accounts)),
}
for _, item := range req.Accounts {
entry := SoraAccountBatchItemResult{AccountID: item.AccountID}
acc := accountMap[item.AccountID]
if acc == nil {
entry.Error = "账号不存在"
result.Results = append(result.Results, entry)
result.Failed++
continue
}
if acc.Platform != service.PlatformSora {
entry.Error = "账号不是 Sora 平台"
result.Results = append(result.Results, entry)
result.Failed++
continue
}
updates := buildSoraAccountUpdates(&item.SoraAccountUpdateRequest)
if h.soraAccountRepo != nil && len(updates) > 0 {
if err := h.soraAccountRepo.Upsert(c.Request.Context(), item.AccountID, updates); err != nil {
entry.Error = err.Error()
result.Results = append(result.Results, entry)
result.Failed++
continue
}
}
entry.Success = true
result.Results = append(result.Results, entry)
result.Success++
}
response.Success(c, result)
}
// ListUsage 获取 Sora 调用统计
// GET /api/v1/admin/sora/usage
func (h *SoraAccountHandler) ListUsage(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
if h.usageRepo == nil {
response.Paginated(c, []dto.SoraUsageStat{}, 0, page, pageSize)
return
}
stats, paginationResult, err := h.usageRepo.List(c.Request.Context(), params)
if err != nil {
response.ErrorFrom(c, err)
return
}
result := make([]dto.SoraUsageStat, 0, len(stats))
for _, stat := range stats {
item := dto.SoraUsageStatFromService(stat)
if item != nil {
result = append(result, *item)
}
}
response.Paginated(c, result, paginationResult.Total, paginationResult.Page, paginationResult.PageSize)
}
func buildSoraAccountUpdates(req *SoraAccountUpdateRequest) map[string]any {
if req == nil {
return nil
}
updates := make(map[string]any)
setString := func(key string, value *string) {
if value == nil {
return
}
updates[key] = strings.TrimSpace(*value)
}
setString("access_token", req.AccessToken)
setString("session_token", req.SessionToken)
setString("refresh_token", req.RefreshToken)
setString("client_id", req.ClientID)
setString("email", req.Email)
setString("username", req.Username)
setString("remark", req.Remark)
setString("plan_type", req.PlanType)
setString("plan_title", req.PlanTitle)
setString("sora_invite_code", req.SoraInviteCode)
if req.UseCount != nil {
updates["use_count"] = *req.UseCount
}
if req.SoraSupported != nil {
updates["sora_supported"] = *req.SoraSupported
}
if req.SoraRedeemedCount != nil {
updates["sora_redeemed_count"] = *req.SoraRedeemedCount
}
if req.SoraRemainingCount != nil {
updates["sora_remaining_count"] = *req.SoraRemainingCount
}
if req.SoraTotalCount != nil {
updates["sora_total_count"] = *req.SoraTotalCount
}
if req.ImageEnabled != nil {
updates["image_enabled"] = *req.ImageEnabled
}
if req.VideoEnabled != nil {
updates["video_enabled"] = *req.VideoEnabled
}
if req.ImageConcurrency != nil {
updates["image_concurrency"] = *req.ImageConcurrency
}
if req.VideoConcurrency != nil {
updates["video_concurrency"] = *req.VideoConcurrency
}
if req.IsExpired != nil {
updates["is_expired"] = *req.IsExpired
}
if req.SubscriptionEnd != nil && *req.SubscriptionEnd > 0 {
updates["subscription_end"] = time.Unix(*req.SubscriptionEnd, 0).UTC()
}
if req.SoraCooldownUntil != nil && *req.SoraCooldownUntil > 0 {
updates["sora_cooldown_until"] = time.Unix(*req.SoraCooldownUntil, 0).UTC()
}
if req.CooledUntil != nil && *req.CooledUntil > 0 {
updates["cooled_until"] = time.Unix(*req.CooledUntil, 0).UTC()
}
return updates
}