package handler import ( "context" "encoding/json" "errors" "fmt" "io" "net/http" "net/http/httptest" "strconv" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" ) const ( // 上游模型缓存 TTL modelCacheTTL = 1 * time.Hour // 上游获取成功 modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地) ) // SoraClientHandler 处理 Sora 客户端 API 请求。 type SoraClientHandler struct { genService *service.SoraGenerationService quotaService *service.SoraQuotaService s3Storage *service.SoraS3Storage soraGatewayService *service.SoraGatewayService gatewayService *service.GatewayService mediaStorage *service.SoraMediaStorage apiKeyService *service.APIKeyService // 上游模型缓存 modelCacheMu sync.RWMutex cachedFamilies []service.SoraModelFamily modelCacheTime time.Time modelCacheUpstream bool // 是否来自上游(决定 TTL) } // NewSoraClientHandler 创建 Sora 客户端 Handler。 func NewSoraClientHandler( genService *service.SoraGenerationService, quotaService *service.SoraQuotaService, s3Storage *service.SoraS3Storage, soraGatewayService *service.SoraGatewayService, gatewayService *service.GatewayService, mediaStorage *service.SoraMediaStorage, apiKeyService *service.APIKeyService, ) *SoraClientHandler { return &SoraClientHandler{ genService: genService, quotaService: quotaService, s3Storage: s3Storage, soraGatewayService: soraGatewayService, gatewayService: gatewayService, mediaStorage: mediaStorage, apiKeyService: apiKeyService, } } // GenerateRequest 生成请求。 type GenerateRequest struct { Model string `json:"model" binding:"required"` Prompt string `json:"prompt" binding:"required"` MediaType string `json:"media_type"` // video / image,默认 video VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3) ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL) APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID } // Generate 异步生成 — 创建 pending 记录后立即返回。 // POST /api/v1/sora/generate func (h *SoraClientHandler) Generate(c *gin.Context) { userID := getUserIDFromContext(c) if userID == 0 { response.Error(c, http.StatusUnauthorized, "未登录") return } var req GenerateRequest if err := c.ShouldBindJSON(&req); err != nil { response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error()) return } if req.MediaType == "" { req.MediaType = "video" } req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount) // 并发数检查(最多 3 个) activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID) if err != nil { response.ErrorFrom(c, err) return } if activeCount >= 3 { response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个") return } // 配额检查(粗略检查,实际文件大小在上传后才知道) if h.quotaService != nil { if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil { var quotaErr *service.QuotaExceededError if errors.As(err, "aErr) { response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间") return } response.Error(c, http.StatusForbidden, err.Error()) return } } // 获取 API Key ID 和 Group ID var apiKeyID *int64 var groupID *int64 if req.APIKeyID != nil && h.apiKeyService != nil { // 前端传递了 api_key_id,需要校验 apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID) if err != nil { response.Error(c, http.StatusBadRequest, "API Key 不存在") return } if apiKey.UserID != userID { response.Error(c, http.StatusForbidden, "API Key 不属于当前用户") return } if apiKey.Status != service.StatusAPIKeyActive { response.Error(c, http.StatusForbidden, "API Key 不可用") return } apiKeyID = &apiKey.ID groupID = apiKey.GroupID } else if id, ok := c.Get("api_key_id"); ok { // 兼容 API Key 认证路径(/sora/v1/ 网关路由) if v, ok := id.(int64); ok { apiKeyID = &v } } gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType) if err != nil { if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) { response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个") return } response.ErrorFrom(c, err) return } // 启动后台异步生成 goroutine go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount) response.Success(c, gin.H{ "generation_id": gen.ID, "status": gen.Status, }) } // processGeneration 后台异步执行 Sora 生成任务。 // 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。 func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) defer cancel() // 标记为生成中 if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil { if errors.Is(err, service.ErrSoraGenerationStateConflict) { logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID) return } logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err) return } logger.LegacyPrintf( "handler.sora_client", "[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d", genID, userID, groupIDForLog(groupID), model, mediaType, videoCount, strings.TrimSpace(imageInput) != "", len(strings.TrimSpace(prompt)), ) // 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底 if groupID == nil { ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) } if h.gatewayService == nil { _ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化") return } // 选择 Sora 账号 account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model) if err != nil { logger.LegacyPrintf( "handler.sora_client", "[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v", genID, userID, groupIDForLog(groupID), model, err, ) _ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error()) return } logger.LegacyPrintf( "handler.sora_client", "[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s", genID, userID, groupIDForLog(groupID), model, account.ID, account.Name, account.Platform, account.Type, ) // 构建 chat completions 请求体(非流式) body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount)) if h.soraGatewayService == nil { _ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化") return } // 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL) recorder := httptest.NewRecorder() mockGinCtx, _ := gin.CreateTestContext(recorder) mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil) // 调用 Forward(非流式) result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false) if err != nil { logger.LegacyPrintf( "handler.sora_client", "[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v", genID, account.ID, model, recorder.Code, trimForLog(recorder.Body.String(), 400), err, ) // 检查是否已取消 gen, _ := h.genService.GetByID(ctx, genID, userID) if gen != nil && gen.Status == service.SoraGenStatusCancelled { return } _ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error()) return } // 提取媒体 URL(优先从 ForwardResult,其次从响应体解析) mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder) if mediaURL == "" { logger.LegacyPrintf( "handler.sora_client", "[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s", genID, account.ID, model, recorder.Code, trimForLog(recorder.Body.String(), 400), ) _ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL") return } // 检查任务是否已被取消 gen, _ := h.genService.GetByID(ctx, genID, userID) if gen != nil && gen.Status == service.SoraGenStatusCancelled { logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID) return } // 三层降级存储:S3 → 本地 → 上游临时 URL storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs) usageAdded := false if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil { if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil { h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) var quotaErr *service.QuotaExceededError if errors.As(err, "aErr) { _ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间") return } _ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error()) return } usageAdded = true } // 存储完成后再做一次取消检查,防止取消被 completed 覆盖。 gen, _ = h.genService.GetByID(ctx, genID, userID) if gen != nil && gen.Status == service.SoraGenStatusCancelled { logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID) h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) if usageAdded && h.quotaService != nil { _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) } return } // 标记完成 if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil { if errors.Is(err, service.ErrSoraGenerationStateConflict) { h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) if usageAdded && h.quotaService != nil { _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) } return } logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err) return } logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize) } // storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。 func (h *SoraClientHandler) storeMediaWithDegradation( ctx context.Context, userID int64, mediaType string, mediaURL string, mediaURLs []string, ) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) { urls := mediaURLs if len(urls) == 0 { urls = []string{mediaURL} } // 第一层:尝试 S3 if h.s3Storage != nil && h.s3Storage.Enabled(ctx) { keys := make([]string, 0, len(urls)) var totalSize int64 allOK := true for _, u := range urls { key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u) if err != nil { logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err) allOK = false // 清理已上传的文件 if len(keys) > 0 { _ = h.s3Storage.DeleteObjects(ctx, keys) } break } keys = append(keys, key) totalSize += size } if allOK && len(keys) > 0 { accessURLs := make([]string, 0, len(keys)) for _, key := range keys { accessURL, err := h.s3Storage.GetAccessURL(ctx, key) if err != nil { logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err) _ = h.s3Storage.DeleteObjects(ctx, keys) allOK = false break } accessURLs = append(accessURLs, accessURL) } if allOK && len(accessURLs) > 0 { return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize } } } // 第二层:尝试本地存储 if h.mediaStorage != nil && h.mediaStorage.Enabled() { storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls) if err == nil && len(storedPaths) > 0 { firstPath := storedPaths[0] totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths) if sizeErr != nil { logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr) } return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize } logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err) } // 第三层:保留上游临时 URL return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0 } // buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。 func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte { body := map[string]any{ "model": model, "messages": []map[string]string{ {"role": "user", "content": prompt}, }, "stream": false, } if imageInput != "" { body["image_input"] = imageInput } if videoCount > 1 { body["video_count"] = videoCount } b, _ := json.Marshal(body) return b } func normalizeVideoCount(mediaType string, videoCount int) int { if mediaType != "video" { return 1 } if videoCount <= 0 { return 1 } if videoCount > 3 { return 3 } return videoCount } // extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。 // OAuth 路径:ForwardResult.MediaURL 已填充。 // APIKey 路径:需从响应体解析 media_url / media_urls 字段。 func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) { // 优先从 ForwardResult 获取(OAuth 路径) if result != nil && result.MediaURL != "" { // 尝试从响应体获取完整 URL 列表 if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { return urls[0], urls } return result.MediaURL, []string{result.MediaURL} } // 从响应体解析(APIKey 路径) if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { return urls[0], urls } return "", nil } // parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。 func parseMediaURLsFromBody(body []byte) []string { if len(body) == 0 { return nil } var resp map[string]any if err := json.Unmarshal(body, &resp); err != nil { return nil } // 优先 media_urls(多图数组) if rawURLs, ok := resp["media_urls"]; ok { if arr, ok := rawURLs.([]any); ok && len(arr) > 0 { urls := make([]string, 0, len(arr)) for _, item := range arr { if s, ok := item.(string); ok && s != "" { urls = append(urls, s) } } if len(urls) > 0 { return urls } } } // 回退到 media_url(单个 URL) if url, ok := resp["media_url"].(string); ok && url != "" { return []string{url} } return nil } // ListGenerations 查询生成记录列表。 // GET /api/v1/sora/generations func (h *SoraClientHandler) ListGenerations(c *gin.Context) { userID := getUserIDFromContext(c) if userID == 0 { response.Error(c, http.StatusUnauthorized, "未登录") return } page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) params := service.SoraGenerationListParams{ UserID: userID, Status: c.Query("status"), StorageType: c.Query("storage_type"), MediaType: c.Query("media_type"), Page: page, PageSize: pageSize, } gens, total, err := h.genService.List(c.Request.Context(), params) if err != nil { response.ErrorFrom(c, err) return } // 为 S3 记录动态生成预签名 URL for _, gen := range gens { _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) } response.Success(c, gin.H{ "data": gens, "total": total, "page": page, }) } // GetGeneration 查询生成记录详情。 // GET /api/v1/sora/generations/:id func (h *SoraClientHandler) GetGeneration(c *gin.Context) { userID := getUserIDFromContext(c) if userID == 0 { response.Error(c, http.StatusUnauthorized, "未登录") return } id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { response.Error(c, http.StatusBadRequest, "无效的 ID") return } gen, err := h.genService.GetByID(c.Request.Context(), id, userID) if err != nil { response.Error(c, http.StatusNotFound, err.Error()) return } _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) response.Success(c, gen) } // DeleteGeneration 删除生成记录。 // DELETE /api/v1/sora/generations/:id func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) { userID := getUserIDFromContext(c) if userID == 0 { response.Error(c, http.StatusUnauthorized, "未登录") return } id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { response.Error(c, http.StatusBadRequest, "无效的 ID") return } gen, err := h.genService.GetByID(c.Request.Context(), id, userID) if err != nil { response.Error(c, http.StatusNotFound, err.Error()) return } // 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。 if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil { paths := gen.MediaURLs if len(paths) == 0 && gen.MediaURL != "" { paths = []string{gen.MediaURL} } if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil { logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err) } } if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil { response.Error(c, http.StatusNotFound, err.Error()) return } response.Success(c, gin.H{"message": "已删除"}) } // GetQuota 查询用户存储配额。 // GET /api/v1/sora/quota func (h *SoraClientHandler) GetQuota(c *gin.Context) { userID := getUserIDFromContext(c) if userID == 0 { response.Error(c, http.StatusUnauthorized, "未登录") return } if h.quotaService == nil { response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"}) return } quota, err := h.quotaService.GetQuota(c.Request.Context(), userID) if err != nil { response.ErrorFrom(c, err) return } response.Success(c, quota) } // CancelGeneration 取消生成任务。 // POST /api/v1/sora/generations/:id/cancel func (h *SoraClientHandler) CancelGeneration(c *gin.Context) { userID := getUserIDFromContext(c) if userID == 0 { response.Error(c, http.StatusUnauthorized, "未登录") return } id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { response.Error(c, http.StatusBadRequest, "无效的 ID") return } // 权限校验 gen, err := h.genService.GetByID(c.Request.Context(), id, userID) if err != nil { response.Error(c, http.StatusNotFound, err.Error()) return } _ = gen if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil { if errors.Is(err, service.ErrSoraGenerationNotActive) { response.Error(c, http.StatusConflict, "任务已结束,无法取消") return } response.Error(c, http.StatusBadRequest, err.Error()) return } response.Success(c, gin.H{"message": "已取消"}) } // SaveToStorage 手动保存 upstream 记录到 S3。 // POST /api/v1/sora/generations/:id/save func (h *SoraClientHandler) SaveToStorage(c *gin.Context) { userID := getUserIDFromContext(c) if userID == 0 { response.Error(c, http.StatusUnauthorized, "未登录") return } id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { response.Error(c, http.StatusBadRequest, "无效的 ID") return } gen, err := h.genService.GetByID(c.Request.Context(), id, userID) if err != nil { response.Error(c, http.StatusNotFound, err.Error()) return } if gen.StorageType != service.SoraStorageTypeUpstream { response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存") return } if gen.MediaURL == "" { response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期") return } if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) { response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员") return } sourceURLs := gen.MediaURLs if len(sourceURLs) == 0 && gen.MediaURL != "" { sourceURLs = []string{gen.MediaURL} } if len(sourceURLs) == 0 { response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期") return } uploadedKeys := make([]string, 0, len(sourceURLs)) accessURLs := make([]string, 0, len(sourceURLs)) var totalSize int64 for _, sourceURL := range sourceURLs { objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL) if uploadErr != nil { if len(uploadedKeys) > 0 { _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) } var upstreamErr *service.UpstreamDownloadError if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) { response.Error(c, http.StatusGone, "媒体链接已过期,无法保存") return } response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error()) return } accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey) if err != nil { uploadedKeys = append(uploadedKeys, objectKey) _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error()) return } uploadedKeys = append(uploadedKeys, objectKey) accessURLs = append(accessURLs, accessURL) totalSize += fileSize } usageAdded := false if totalSize > 0 && h.quotaService != nil { if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil { _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) var quotaErr *service.QuotaExceededError if errors.As(err, "aErr) { response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间") return } response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error()) return } usageAdded = true } if err := h.genService.UpdateStorageForCompleted( c.Request.Context(), id, accessURLs[0], accessURLs, service.SoraStorageTypeS3, uploadedKeys, totalSize, ); err != nil { _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) if usageAdded && h.quotaService != nil { _ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize) } response.ErrorFrom(c, err) return } response.Success(c, gin.H{ "message": "已保存到 S3", "object_key": uploadedKeys[0], "object_keys": uploadedKeys, }) } // GetStorageStatus 返回存储状态。 // GET /api/v1/sora/storage-status func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) { s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context()) s3Healthy := false if s3Enabled { s3Healthy = h.s3Storage.IsHealthy(c.Request.Context()) } localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled() response.Success(c, gin.H{ "s3_enabled": s3Enabled, "s3_healthy": s3Healthy, "local_enabled": localEnabled, }) } func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) { switch storageType { case service.SoraStorageTypeS3: if h.s3Storage != nil && len(s3Keys) > 0 { if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil { logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err) } } case service.SoraStorageTypeLocal: if h.mediaStorage != nil && len(localPaths) > 0 { if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil { logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err) } } } } // getUserIDFromContext 从 gin 上下文中提取用户 ID。 func getUserIDFromContext(c *gin.Context) int64 { if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 { return subject.UserID } if id, ok := c.Get("user_id"); ok { switch v := id.(type) { case int64: return v case float64: return int64(v) case string: n, _ := strconv.ParseInt(v, 10, 64) return n } } // 尝试从 JWT claims 获取 if id, ok := c.Get("userID"); ok { if v, ok := id.(int64); ok { return v } } return 0 } func groupIDForLog(groupID *int64) int64 { if groupID == nil { return 0 } return *groupID } func trimForLog(raw string, maxLen int) string { trimmed := strings.TrimSpace(raw) if maxLen <= 0 || len(trimmed) <= maxLen { return trimmed } return trimmed[:maxLen] + "...(truncated)" } // GetModels 获取可用 Sora 模型家族列表。 // 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。 // GET /api/v1/sora/models func (h *SoraClientHandler) GetModels(c *gin.Context) { families := h.getModelFamilies(c.Request.Context()) response.Success(c, families) } // getModelFamilies 获取模型家族列表(带缓存)。 func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily { // 读锁检查缓存 h.modelCacheMu.RLock() ttl := modelCacheTTL if !h.modelCacheUpstream { ttl = modelCacheFailedTTL } if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { families := h.cachedFamilies h.modelCacheMu.RUnlock() return families } h.modelCacheMu.RUnlock() // 写锁更新缓存 h.modelCacheMu.Lock() defer h.modelCacheMu.Unlock() // double-check ttl = modelCacheTTL if !h.modelCacheUpstream { ttl = modelCacheFailedTTL } if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { return h.cachedFamilies } // 尝试从上游获取 families, err := h.fetchUpstreamModels(ctx) if err != nil { logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err) families = service.BuildSoraModelFamilies() h.cachedFamilies = families h.modelCacheTime = time.Now() h.modelCacheUpstream = false return families } logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families)) h.cachedFamilies = families h.modelCacheTime = time.Now() h.modelCacheUpstream = true return families } // fetchUpstreamModels 从上游 Sora API 获取模型列表。 func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) { if h.gatewayService == nil { return nil, fmt.Errorf("gatewayService 未初始化") } // 设置 ForcePlatform 用于 Sora 账号选择 ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) // 选择一个 Sora 账号 account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s") if err != nil { return nil, fmt.Errorf("选择 Sora 账号失败: %w", err) } // 仅支持 API Key 类型账号 if account.Type != service.AccountTypeAPIKey { return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type) } apiKey := account.GetCredential("api_key") if apiKey == "" { return nil, fmt.Errorf("账号缺少 api_key") } baseURL := account.GetBaseURL() if baseURL == "" { return nil, fmt.Errorf("账号缺少 base_url") } // 构建上游模型列表请求 modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models" reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil) if err != nil { return nil, fmt.Errorf("创建请求失败: %w", err) } req.Header.Set("Authorization", "Bearer "+apiKey) client := &http.Client{Timeout: 10 * time.Second} resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("请求上游失败: %w", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode) } body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024)) if err != nil { return nil, fmt.Errorf("读取响应失败: %w", err) } // 解析 OpenAI 格式的模型列表 var modelsResp struct { Data []struct { ID string `json:"id"` } `json:"data"` } if err := json.Unmarshal(body, &modelsResp); err != nil { return nil, fmt.Errorf("解析响应失败: %w", err) } if len(modelsResp.Data) == 0 { return nil, fmt.Errorf("上游返回空模型列表") } // 提取模型 ID modelIDs := make([]string, 0, len(modelsResp.Data)) for _, m := range modelsResp.Data { modelIDs = append(modelIDs, m.ID) } // 转换为模型家族 families := service.BuildSoraModelFamiliesFromIDs(modelIDs) if len(families) == 0 { return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族") } return families, nil }