feat(sync): full code sync from release

This commit is contained in:
yangjianbo
2026-02-28 15:01:20 +08:00
parent bfc7b339f7
commit bb664d9bbf
338 changed files with 54513 additions and 2011 deletions

View File

@@ -0,0 +1,979 @@
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
// 上游模型缓存 TTL
modelCacheTTL = 1 * time.Hour // 上游获取成功
modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地)
)
// SoraClientHandler 处理 Sora 客户端 API 请求。
type SoraClientHandler struct {
genService *service.SoraGenerationService
quotaService *service.SoraQuotaService
s3Storage *service.SoraS3Storage
soraGatewayService *service.SoraGatewayService
gatewayService *service.GatewayService
mediaStorage *service.SoraMediaStorage
apiKeyService *service.APIKeyService
// 上游模型缓存
modelCacheMu sync.RWMutex
cachedFamilies []service.SoraModelFamily
modelCacheTime time.Time
modelCacheUpstream bool // 是否来自上游(决定 TTL
}
// NewSoraClientHandler 创建 Sora 客户端 Handler。
func NewSoraClientHandler(
genService *service.SoraGenerationService,
quotaService *service.SoraQuotaService,
s3Storage *service.SoraS3Storage,
soraGatewayService *service.SoraGatewayService,
gatewayService *service.GatewayService,
mediaStorage *service.SoraMediaStorage,
apiKeyService *service.APIKeyService,
) *SoraClientHandler {
return &SoraClientHandler{
genService: genService,
quotaService: quotaService,
s3Storage: s3Storage,
soraGatewayService: soraGatewayService,
gatewayService: gatewayService,
mediaStorage: mediaStorage,
apiKeyService: apiKeyService,
}
}
// GenerateRequest 生成请求。
type GenerateRequest struct {
Model string `json:"model" binding:"required"`
Prompt string `json:"prompt" binding:"required"`
MediaType string `json:"media_type"` // video / image默认 video
VideoCount int `json:"video_count,omitempty"` // 视频数量1-3
ImageInput string `json:"image_input,omitempty"` // 参考图base64 或 URL
APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID
}
// Generate 异步生成 — 创建 pending 记录后立即返回。
// POST /api/v1/sora/generate
func (h *SoraClientHandler) Generate(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
var req GenerateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error())
return
}
if req.MediaType == "" {
req.MediaType = "video"
}
req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount)
// 并发数检查(最多 3 个)
activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if activeCount >= 3 {
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
return
}
// 配额检查(粗略检查,实际文件大小在上传后才知道)
if h.quotaService != nil {
if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil {
var quotaErr *service.QuotaExceededError
if errors.As(err, &quotaErr) {
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
return
}
response.Error(c, http.StatusForbidden, err.Error())
return
}
}
// 获取 API Key ID 和 Group ID
var apiKeyID *int64
var groupID *int64
if req.APIKeyID != nil && h.apiKeyService != nil {
// 前端传递了 api_key_id需要校验
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID)
if err != nil {
response.Error(c, http.StatusBadRequest, "API Key 不存在")
return
}
if apiKey.UserID != userID {
response.Error(c, http.StatusForbidden, "API Key 不属于当前用户")
return
}
if apiKey.Status != service.StatusAPIKeyActive {
response.Error(c, http.StatusForbidden, "API Key 不可用")
return
}
apiKeyID = &apiKey.ID
groupID = apiKey.GroupID
} else if id, ok := c.Get("api_key_id"); ok {
// 兼容 API Key 认证路径(/sora/v1/ 网关路由)
if v, ok := id.(int64); ok {
apiKeyID = &v
}
}
gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType)
if err != nil {
if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) {
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
return
}
response.ErrorFrom(c, err)
return
}
// 启动后台异步生成 goroutine
go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount)
response.Success(c, gin.H{
"generation_id": gen.ID,
"status": gen.Status,
})
}
// processGeneration 后台异步执行 Sora 生成任务。
// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储S3 → 本地 → 上游)→ 更新记录。
func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
// 标记为生成中
if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil {
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID)
return
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err)
return
}
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d",
genID,
userID,
groupIDForLog(groupID),
model,
mediaType,
videoCount,
strings.TrimSpace(imageInput) != "",
len(strings.TrimSpace(prompt)),
)
// 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底
if groupID == nil {
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
}
if h.gatewayService == nil {
_ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化")
return
}
// 选择 Sora 账号
account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model)
if err != nil {
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v",
genID,
userID,
groupIDForLog(groupID),
model,
err,
)
_ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error())
return
}
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s",
genID,
userID,
groupIDForLog(groupID),
model,
account.ID,
account.Name,
account.Platform,
account.Type,
)
// 构建 chat completions 请求体(非流式)
body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount))
if h.soraGatewayService == nil {
_ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化")
return
}
// 创建 mock gin 上下文用于 Forward捕获响应以提取媒体 URL
recorder := httptest.NewRecorder()
mockGinCtx, _ := gin.CreateTestContext(recorder)
mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil)
// 调用 Forward非流式
result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false)
if err != nil {
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v",
genID,
account.ID,
model,
recorder.Code,
trimForLog(recorder.Body.String(), 400),
err,
)
// 检查是否已取消
gen, _ := h.genService.GetByID(ctx, genID, userID)
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
return
}
_ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error())
return
}
// 提取媒体 URL优先从 ForwardResult其次从响应体解析
mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder)
if mediaURL == "" {
logger.LegacyPrintf(
"handler.sora_client",
"[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s",
genID,
account.ID,
model,
recorder.Code,
trimForLog(recorder.Body.String(), 400),
)
_ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL")
return
}
// 检查任务是否已被取消
gen, _ := h.genService.GetByID(ctx, genID, userID)
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID)
return
}
// 三层降级存储S3 → 本地 → 上游临时 URL
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs)
usageAdded := false
if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil {
if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil {
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
var quotaErr *service.QuotaExceededError
if errors.As(err, &quotaErr) {
_ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间")
return
}
_ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error())
return
}
usageAdded = true
}
// 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
gen, _ = h.genService.GetByID(ctx, genID, userID)
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID)
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
if usageAdded && h.quotaService != nil {
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
}
return
}
// 标记完成
if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil {
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
if usageAdded && h.quotaService != nil {
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
}
return
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err)
return
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize)
}
// storeMediaWithDegradation 实现三层降级存储链S3 → 本地 → 上游。
func (h *SoraClientHandler) storeMediaWithDegradation(
ctx context.Context, userID int64, mediaType string,
mediaURL string, mediaURLs []string,
) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) {
urls := mediaURLs
if len(urls) == 0 {
urls = []string{mediaURL}
}
// 第一层:尝试 S3
if h.s3Storage != nil && h.s3Storage.Enabled(ctx) {
keys := make([]string, 0, len(urls))
var totalSize int64
allOK := true
for _, u := range urls {
key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u)
if err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err)
allOK = false
// 清理已上传的文件
if len(keys) > 0 {
_ = h.s3Storage.DeleteObjects(ctx, keys)
}
break
}
keys = append(keys, key)
totalSize += size
}
if allOK && len(keys) > 0 {
accessURLs := make([]string, 0, len(keys))
for _, key := range keys {
accessURL, err := h.s3Storage.GetAccessURL(ctx, key)
if err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err)
_ = h.s3Storage.DeleteObjects(ctx, keys)
allOK = false
break
}
accessURLs = append(accessURLs, accessURL)
}
if allOK && len(accessURLs) > 0 {
return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize
}
}
}
// 第二层:尝试本地存储
if h.mediaStorage != nil && h.mediaStorage.Enabled() {
storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls)
if err == nil && len(storedPaths) > 0 {
firstPath := storedPaths[0]
totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths)
if sizeErr != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr)
}
return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err)
}
// 第三层:保留上游临时 URL
return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0
}
// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。
func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte {
body := map[string]any{
"model": model,
"messages": []map[string]string{
{"role": "user", "content": prompt},
},
"stream": false,
}
if imageInput != "" {
body["image_input"] = imageInput
}
if videoCount > 1 {
body["video_count"] = videoCount
}
b, _ := json.Marshal(body)
return b
}
func normalizeVideoCount(mediaType string, videoCount int) int {
if mediaType != "video" {
return 1
}
if videoCount <= 0 {
return 1
}
if videoCount > 3 {
return 3
}
return videoCount
}
// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。
// OAuth 路径ForwardResult.MediaURL 已填充。
// APIKey 路径:需从响应体解析 media_url / media_urls 字段。
func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) {
// 优先从 ForwardResult 获取OAuth 路径)
if result != nil && result.MediaURL != "" {
// 尝试从响应体获取完整 URL 列表
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
return urls[0], urls
}
return result.MediaURL, []string{result.MediaURL}
}
// 从响应体解析APIKey 路径)
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
return urls[0], urls
}
return "", nil
}
// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。
func parseMediaURLsFromBody(body []byte) []string {
if len(body) == 0 {
return nil
}
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
return nil
}
// 优先 media_urls多图数组
if rawURLs, ok := resp["media_urls"]; ok {
if arr, ok := rawURLs.([]any); ok && len(arr) > 0 {
urls := make([]string, 0, len(arr))
for _, item := range arr {
if s, ok := item.(string); ok && s != "" {
urls = append(urls, s)
}
}
if len(urls) > 0 {
return urls
}
}
}
// 回退到 media_url单个 URL
if url, ok := resp["media_url"].(string); ok && url != "" {
return []string{url}
}
return nil
}
// ListGenerations 查询生成记录列表。
// GET /api/v1/sora/generations
func (h *SoraClientHandler) ListGenerations(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
params := service.SoraGenerationListParams{
UserID: userID,
Status: c.Query("status"),
StorageType: c.Query("storage_type"),
MediaType: c.Query("media_type"),
Page: page,
PageSize: pageSize,
}
gens, total, err := h.genService.List(c.Request.Context(), params)
if err != nil {
response.ErrorFrom(c, err)
return
}
// 为 S3 记录动态生成预签名 URL
for _, gen := range gens {
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
}
response.Success(c, gin.H{
"data": gens,
"total": total,
"page": page,
})
}
// GetGeneration 查询生成记录详情。
// GET /api/v1/sora/generations/:id
func (h *SoraClientHandler) GetGeneration(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
response.Success(c, gen)
}
// DeleteGeneration 删除生成记录。
// DELETE /api/v1/sora/generations/:id
func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
// 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。
if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil {
paths := gen.MediaURLs
if len(paths) == 0 && gen.MediaURL != "" {
paths = []string{gen.MediaURL}
}
if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err)
}
}
if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
response.Success(c, gin.H{"message": "已删除"})
}
// GetQuota 查询用户存储配额。
// GET /api/v1/sora/quota
func (h *SoraClientHandler) GetQuota(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
if h.quotaService == nil {
response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"})
return
}
quota, err := h.quotaService.GetQuota(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, quota)
}
// CancelGeneration 取消生成任务。
// POST /api/v1/sora/generations/:id/cancel
func (h *SoraClientHandler) CancelGeneration(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
// 权限校验
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
_ = gen
if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil {
if errors.Is(err, service.ErrSoraGenerationNotActive) {
response.Error(c, http.StatusConflict, "任务已结束,无法取消")
return
}
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, gin.H{"message": "已取消"})
}
// SaveToStorage 手动保存 upstream 记录到 S3。
// POST /api/v1/sora/generations/:id/save
func (h *SoraClientHandler) SaveToStorage(c *gin.Context) {
userID := getUserIDFromContext(c)
if userID == 0 {
response.Error(c, http.StatusUnauthorized, "未登录")
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.Error(c, http.StatusBadRequest, "无效的 ID")
return
}
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
if err != nil {
response.Error(c, http.StatusNotFound, err.Error())
return
}
if gen.StorageType != service.SoraStorageTypeUpstream {
response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存")
return
}
if gen.MediaURL == "" {
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
return
}
if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) {
response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员")
return
}
sourceURLs := gen.MediaURLs
if len(sourceURLs) == 0 && gen.MediaURL != "" {
sourceURLs = []string{gen.MediaURL}
}
if len(sourceURLs) == 0 {
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
return
}
uploadedKeys := make([]string, 0, len(sourceURLs))
accessURLs := make([]string, 0, len(sourceURLs))
var totalSize int64
for _, sourceURL := range sourceURLs {
objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL)
if uploadErr != nil {
if len(uploadedKeys) > 0 {
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
}
var upstreamErr *service.UpstreamDownloadError
if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) {
response.Error(c, http.StatusGone, "媒体链接已过期,无法保存")
return
}
response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error())
return
}
accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey)
if err != nil {
uploadedKeys = append(uploadedKeys, objectKey)
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error())
return
}
uploadedKeys = append(uploadedKeys, objectKey)
accessURLs = append(accessURLs, accessURL)
totalSize += fileSize
}
usageAdded := false
if totalSize > 0 && h.quotaService != nil {
if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil {
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
var quotaErr *service.QuotaExceededError
if errors.As(err, &quotaErr) {
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
return
}
response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error())
return
}
usageAdded = true
}
if err := h.genService.UpdateStorageForCompleted(
c.Request.Context(),
id,
accessURLs[0],
accessURLs,
service.SoraStorageTypeS3,
uploadedKeys,
totalSize,
); err != nil {
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
if usageAdded && h.quotaService != nil {
_ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize)
}
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"message": "已保存到 S3",
"object_key": uploadedKeys[0],
"object_keys": uploadedKeys,
})
}
// GetStorageStatus 返回存储状态。
// GET /api/v1/sora/storage-status
func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) {
s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context())
s3Healthy := false
if s3Enabled {
s3Healthy = h.s3Storage.IsHealthy(c.Request.Context())
}
localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled()
response.Success(c, gin.H{
"s3_enabled": s3Enabled,
"s3_healthy": s3Healthy,
"local_enabled": localEnabled,
})
}
func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) {
switch storageType {
case service.SoraStorageTypeS3:
if h.s3Storage != nil && len(s3Keys) > 0 {
if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err)
}
}
case service.SoraStorageTypeLocal:
if h.mediaStorage != nil && len(localPaths) > 0 {
if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err)
}
}
}
}
// getUserIDFromContext 从 gin 上下文中提取用户 ID。
func getUserIDFromContext(c *gin.Context) int64 {
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
return subject.UserID
}
if id, ok := c.Get("user_id"); ok {
switch v := id.(type) {
case int64:
return v
case float64:
return int64(v)
case string:
n, _ := strconv.ParseInt(v, 10, 64)
return n
}
}
// 尝试从 JWT claims 获取
if id, ok := c.Get("userID"); ok {
if v, ok := id.(int64); ok {
return v
}
}
return 0
}
func groupIDForLog(groupID *int64) int64 {
if groupID == nil {
return 0
}
return *groupID
}
func trimForLog(raw string, maxLen int) string {
trimmed := strings.TrimSpace(raw)
if maxLen <= 0 || len(trimmed) <= maxLen {
return trimmed
}
return trimmed[:maxLen] + "...(truncated)"
}
// GetModels 获取可用 Sora 模型家族列表。
// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。
// GET /api/v1/sora/models
func (h *SoraClientHandler) GetModels(c *gin.Context) {
families := h.getModelFamilies(c.Request.Context())
response.Success(c, families)
}
// getModelFamilies 获取模型家族列表(带缓存)。
func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily {
// 读锁检查缓存
h.modelCacheMu.RLock()
ttl := modelCacheTTL
if !h.modelCacheUpstream {
ttl = modelCacheFailedTTL
}
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
families := h.cachedFamilies
h.modelCacheMu.RUnlock()
return families
}
h.modelCacheMu.RUnlock()
// 写锁更新缓存
h.modelCacheMu.Lock()
defer h.modelCacheMu.Unlock()
// double-check
ttl = modelCacheTTL
if !h.modelCacheUpstream {
ttl = modelCacheFailedTTL
}
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
return h.cachedFamilies
}
// 尝试从上游获取
families, err := h.fetchUpstreamModels(ctx)
if err != nil {
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err)
families = service.BuildSoraModelFamilies()
h.cachedFamilies = families
h.modelCacheTime = time.Now()
h.modelCacheUpstream = false
return families
}
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families))
h.cachedFamilies = families
h.modelCacheTime = time.Now()
h.modelCacheUpstream = true
return families
}
// fetchUpstreamModels 从上游 Sora API 获取模型列表。
func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) {
if h.gatewayService == nil {
return nil, fmt.Errorf("gatewayService 未初始化")
}
// 设置 ForcePlatform 用于 Sora 账号选择
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
// 选择一个 Sora 账号
account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s")
if err != nil {
return nil, fmt.Errorf("选择 Sora 账号失败: %w", err)
}
// 仅支持 API Key 类型账号
if account.Type != service.AccountTypeAPIKey {
return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type)
}
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return nil, fmt.Errorf("账号缺少 api_key")
}
baseURL := account.GetBaseURL()
if baseURL == "" {
return nil, fmt.Errorf("账号缺少 base_url")
}
// 构建上游模型列表请求
modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models"
reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Authorization", "Bearer "+apiKey)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("请求上游失败: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
// 解析 OpenAI 格式的模型列表
var modelsResp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
if err := json.Unmarshal(body, &modelsResp); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
if len(modelsResp.Data) == 0 {
return nil, fmt.Errorf("上游返回空模型列表")
}
// 提取模型 ID
modelIDs := make([]string, 0, len(modelsResp.Data))
for _, m := range modelsResp.Data {
modelIDs = append(modelIDs, m.ID)
}
// 转换为模型家族
families := service.BuildSoraModelFamiliesFromIDs(modelIDs)
if len(families) == 0 {
return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族")
}
return families, nil
}