feat(sora): 新增 Sora 平台支持并修复高危安全和性能问题
新增功能: - 新增 Sora 账号管理和 OAuth 认证 - 新增 Sora 视频/图片生成 API 网关 - 新增 Sora 任务调度和缓存机制 - 新增 Sora 使用统计和计费支持 - 前端增加 Sora 平台配置界面 安全修复(代码审核): - [SEC-001] 限制媒体下载响应体大小(图片 20MB、视频 200MB),防止 DoS 攻击 - [SEC-002] 限制 SDK API 响应大小(1MB),防止内存耗尽 - [SEC-003] 修复 SSRF 风险,添加 URL 验证并强制使用代理配置 BUG 修复(代码审核): - [BUG-001] 修复 for 循环内 defer 累积导致的资源泄漏 - [BUG-002] 修复图片并发槽位获取失败时已持有锁未释放的永久泄漏 性能优化(代码审核): - [PERF-001] 添加 Sentinel Token 缓存(3 分钟有效期),减少 PoW 计算开销 技术细节: - 使用 io.LimitReader 限制所有外部输入的大小 - 添加 urlvalidator 验证防止 SSRF 攻击 - 使用 sync.Map 实现线程安全的包级缓存 - 优化并发槽位管理,添加 releaseAll 模式防止泄漏 影响范围: - 后端:新增 Sora 相关数据模型、服务、网关和管理接口 - 前端:新增 Sora 平台配置、账号管理和监控界面 - 配置:新增 Sora 相关配置项和环境变量 Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
364
backend/internal/handler/sora_gateway_handler.go
Normal file
364
backend/internal/handler/sora_gateway_handler.go
Normal file
@@ -0,0 +1,364 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/sora"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// SoraGatewayHandler handles Sora OpenAI compatible endpoints.
|
||||
type SoraGatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
soraGatewayService *service.SoraGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
}
|
||||
|
||||
// NewSoraGatewayHandler creates a new SoraGatewayHandler.
|
||||
func NewSoraGatewayHandler(
|
||||
gatewayService *service.GatewayService,
|
||||
soraGatewayService *service.SoraGatewayService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
cfg *config.Config,
|
||||
) *SoraGatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
maxAccountSwitches := 3
|
||||
if cfg != nil {
|
||||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
||||
if cfg.Gateway.MaxAccountSwitches > 0 {
|
||||
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
|
||||
}
|
||||
}
|
||||
return &SoraGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
soraGatewayService: soraGatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatCompletions handles Sora OpenAI-compatible chat completions endpoint.
|
||||
// POST /v1/chat/completions
|
||||
func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
apiKey, ok := middleware.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
model, _ := reqBody["model"].(string)
|
||||
if strings.TrimSpace(model) == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
stream, _ := reqBody["stream"].(bool)
|
||||
|
||||
prompt, imageData, videoData, remixID, err := parseSoraPrompt(reqBody)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error())
|
||||
return
|
||||
}
|
||||
if remixID == "" {
|
||||
remixID = sora.ExtractRemixID(prompt)
|
||||
}
|
||||
if remixID != "" {
|
||||
prompt = strings.ReplaceAll(prompt, remixID, "")
|
||||
}
|
||||
|
||||
if apiKey.Group != nil && apiKey.Group.Platform != service.PlatformSora {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "当前分组不支持 Sora 平台")
|
||||
return
|
||||
}
|
||||
|
||||
streamStarted := false
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
if err == nil && !canWait {
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
waitCounted = false
|
||||
}
|
||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
maxSwitches := h.maxAccountSwitches
|
||||
if mode := h.soraGatewayService.CallLogicMode(c.Request.Context()); strings.EqualFold(mode, "native") {
|
||||
maxSwitches = 1
|
||||
}
|
||||
|
||||
for switchCount := 0; switchCount < maxSwitches; switchCount++ {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, "", model, failedAccountIDs, "")
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "server_error", err.Error())
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
releaseFunc := selection.ReleaseFunc
|
||||
|
||||
result, err := h.soraGatewayService.Generate(c.Request.Context(), account, service.SoraGenerationRequest{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
Image: imageData,
|
||||
Video: videoData,
|
||||
RemixTargetID: remixID,
|
||||
Stream: stream,
|
||||
UserID: subject.UserID,
|
||||
})
|
||||
if err != nil {
|
||||
// 失败路径:立即释放槽位,而非 defer
|
||||
if releaseFunc != nil {
|
||||
releaseFunc()
|
||||
}
|
||||
|
||||
if errors.Is(err, service.ErrSoraAccountMissingToken) || errors.Is(err, service.ErrSoraAccountNotEligible) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
continue
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "server_error", err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 成功路径:使用 defer 在函数退出时释放
|
||||
if releaseFunc != nil {
|
||||
defer releaseFunc()
|
||||
}
|
||||
|
||||
h.respondCompletion(c, model, result, stream)
|
||||
return
|
||||
}
|
||||
|
||||
h.handleFailoverExhausted(c, http.StatusServiceUnavailable, streamStarted)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) respondCompletion(c *gin.Context, model string, result *service.SoraGenerationResult, stream bool) {
|
||||
if result == nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Empty response")
|
||||
return
|
||||
}
|
||||
if stream {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
first := buildSoraStreamChunk(model, "", true, "")
|
||||
if _, err := c.Writer.WriteString(first); err != nil {
|
||||
return
|
||||
}
|
||||
final := buildSoraStreamChunk(model, result.Content, false, "stop")
|
||||
if _, err := c.Writer.WriteString(final); err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = c.Writer.WriteString("data: [DONE]\n\n")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, buildSoraNonStreamResponse(model, result.Content))
|
||||
}
|
||||
|
||||
func buildSoraStreamChunk(model, content string, isFirst bool, finishReason string) string {
|
||||
chunkID := fmt.Sprintf("chatcmpl-%d", time.Now().UnixMilli())
|
||||
delta := map[string]any{}
|
||||
if isFirst {
|
||||
delta["role"] = "assistant"
|
||||
}
|
||||
if content != "" {
|
||||
delta["content"] = content
|
||||
} else {
|
||||
delta["content"] = nil
|
||||
}
|
||||
response := map[string]any{
|
||||
"id": chunkID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
"finish_reason": finishReason,
|
||||
},
|
||||
},
|
||||
}
|
||||
payload, _ := json.Marshal(response)
|
||||
return "data: " + string(payload) + "\n\n"
|
||||
}
|
||||
|
||||
func buildSoraNonStreamResponse(model, content string) map[string]any {
|
||||
return map[string]any{
|
||||
"id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixMilli()),
|
||||
"object": "chat.completion",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"message": map[string]any{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func parseSoraPrompt(req map[string]any) (prompt, imageData, videoData, remixID string, err error) {
|
||||
messages, ok := req["messages"].([]any)
|
||||
if !ok || len(messages) == 0 {
|
||||
return "", "", "", "", fmt.Errorf("messages is required")
|
||||
}
|
||||
last := messages[len(messages)-1]
|
||||
msg, ok := last.(map[string]any)
|
||||
if !ok {
|
||||
return "", "", "", "", fmt.Errorf("invalid message format")
|
||||
}
|
||||
content, ok := msg["content"]
|
||||
if !ok {
|
||||
return "", "", "", "", fmt.Errorf("content is required")
|
||||
}
|
||||
|
||||
if v, ok := req["image"].(string); ok && v != "" {
|
||||
imageData = v
|
||||
}
|
||||
if v, ok := req["video"].(string); ok && v != "" {
|
||||
videoData = v
|
||||
}
|
||||
if v, ok := req["remix_target_id"].(string); ok {
|
||||
remixID = v
|
||||
}
|
||||
|
||||
switch value := content.(type) {
|
||||
case string:
|
||||
prompt = value
|
||||
case []any:
|
||||
for _, item := range value {
|
||||
part, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch part["type"] {
|
||||
case "text":
|
||||
if text, ok := part["text"].(string); ok {
|
||||
prompt = text
|
||||
}
|
||||
case "image_url":
|
||||
if image, ok := part["image_url"].(map[string]any); ok {
|
||||
if url, ok := image["url"].(string); ok {
|
||||
imageData = url
|
||||
}
|
||||
}
|
||||
case "video_url":
|
||||
if video, ok := part["video_url"].(map[string]any); ok {
|
||||
if url, ok := video["url"].(string); ok {
|
||||
videoData = url
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
return "", "", "", "", fmt.Errorf("invalid content format")
|
||||
}
|
||||
if strings.TrimSpace(prompt) == "" && strings.TrimSpace(videoData) == "" {
|
||||
return "", "", "", "", fmt.Errorf("prompt is required")
|
||||
}
|
||||
return prompt, imageData, videoData, remixID, nil
|
||||
}
|
||||
|
||||
func looksLikeURL(value string) bool {
|
||||
trimmed := strings.ToLower(strings.TrimSpace(value))
|
||||
return strings.HasPrefix(trimmed, "http://") || strings.HasPrefix(trimmed, "https://")
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", err.Error(), true)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
message := "No available Sora accounts"
|
||||
h.handleStreamingAwareError(c, statusCode, "server_error", message, streamStarted)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
payload := map[string]any{"error": map[string]any{"message": message, "type": errType, "param": nil, "code": nil}}
|
||||
data, _ := json.Marshal(payload)
|
||||
_, _ = c.Writer.WriteString("data: " + string(data) + "\n\n")
|
||||
_, _ = c.Writer.WriteString("data: [DONE]\n\n")
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"message": message,
|
||||
"type": errType,
|
||||
"param": nil,
|
||||
"code": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user