From 9b120e68b8acdc5aa5ca03f2094989daf54061be Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Wed, 4 Feb 2026 14:06:06 +0800 Subject: [PATCH] =?UTF-8?q?fix(sora):=20=E6=81=A2=E5=A4=8D=E6=B5=81?= =?UTF-8?q?=E5=BC=8F=E8=BE=85=E5=8A=A9=E9=80=BB=E8=BE=91=E5=B9=B6=E9=80=9A?= =?UTF-8?q?=E8=BF=87=20lint?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/service/sora_client.go | 5 +- .../internal/service/sora_gateway_service.go | 542 +----------------- .../service/sora_gateway_streaming_legacy.go | 532 +++++++++++++++++ .../internal/service/sora_media_storage.go | 2 - 4 files changed, 546 insertions(+), 535 deletions(-) create mode 100644 backend/internal/service/sora_gateway_streaming_legacy.go diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go index f3a71a79..e2b85671 100644 --- a/backend/internal/service/sora_client.go +++ b/backend/internal/service/sora_client.go @@ -672,10 +672,7 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth } func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) { - enableTLS := false - if c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint { - enableTLS = true - } + enableTLS := c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint if c.httpUpstream != nil { accountID := int64(0) accountConcurrency := 0 diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index ea696d63..68ebd90a 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -1,8 +1,6 @@ package service import ( - "bufio" - "bytes" "context" "encoding/base64" "encoding/json" @@ -13,7 +11,6 @@ import ( "net" "net/http" "net/url" - "regexp" "strconv" "strings" "time" @@ -22,11 +19,6 @@ import ( "github.com/gin-gonic/gin" ) -var soraSSEDataRe = regexp.MustCompile(`^data:\s*`) -var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`) -var soraVideoHTMLRe = regexp.MustCompile(`(?i)]+src=['"]([^'"]+)['"]`) - -const soraRewriteBufferLimit = 2048 const soraImageInputMaxBytes = 20 << 20 const soraImageInputMaxRedirects = 3 const soraImageInputTimeout = 20 * time.Second @@ -60,14 +52,6 @@ var soraBlockedCIDRs = mustParseCIDRs([]string{ "fe80::/10", }) -type soraStreamingResult struct { - mediaType string - mediaURLs []string - imageCount int - imageSize string - firstTokenMs *int -} - // SoraGatewayService handles forwarding requests to Sora upstream. type SoraGatewayService struct { soraClient SoraClient @@ -203,7 +187,8 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun mediaType := modelCfg.Type imageCount := 0 imageSize := "" - if modelCfg.Type == "image" { + switch modelCfg.Type { + case "image": urls, pollErr := s.pollImageTask(reqCtx, c, account, taskID, clientStream) if pollErr != nil { return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) @@ -211,25 +196,23 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun mediaURLs = urls imageCount = len(urls) imageSize = soraImageSizeFromModel(reqModel) - } else if modelCfg.Type == "video" { + case "video": urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream) if pollErr != nil { return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) } mediaURLs = urls - } else { + default: mediaType = "prompt" } - finalURLs := mediaURLs + finalURLs := s.normalizeSoraMediaURLs(mediaURLs) if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() { stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs) if storeErr != nil { return nil, s.handleSoraRequestError(ctx, account, storeErr, reqModel, c, clientStream) } finalURLs = s.normalizeSoraMediaURLs(stored) - } else { - finalURLs = s.normalizeSoraMediaURLs(mediaURLs) } content := buildSoraContent(mediaType, finalURLs) @@ -279,27 +262,6 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) ( return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second) } -func (s *SoraGatewayService) setUpstreamRequestError(c *gin.Context, account *Account, err error) { - safeErr := sanitizeUpstreamErrorMessage(err.Error()) - setOpsUpstreamError(c, 0, safeErr, "") - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: 0, - Kind: "request_error", - Message: safeErr, - }) - if c != nil { - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream request failed", - }, - }) - } -} - func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool { switch statusCode { case 401, 402, 403, 429, 529: @@ -309,480 +271,6 @@ func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool { } } -func (s *SoraGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { - if s.rateLimitService == nil || account == nil || resp == nil { - return - } - body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) -} - -func (s *SoraGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, reqModel string) (*ForwardResult, error) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) - - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - if msg := soraProErrorMessage(reqModel, upstreamMsg); msg != "" { - upstreamMsg = msg - } - - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 - } - upstreamDetail = truncateString(string(respBody), maxBytes) - } - setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "http_error", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - - if c != nil { - responsePayload := s.buildErrorPayload(respBody, upstreamMsg) - c.JSON(resp.StatusCode, responsePayload) - } - if upstreamMsg == "" { - return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) - } - return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) -} - -func (s *SoraGatewayService) buildErrorPayload(respBody []byte, overrideMessage string) map[string]any { - if len(respBody) > 0 { - var payload map[string]any - if err := json.Unmarshal(respBody, &payload); err == nil { - if errObj, ok := payload["error"].(map[string]any); ok { - if overrideMessage != "" { - errObj["message"] = overrideMessage - } - payload["error"] = errObj - return payload - } - } - } - return map[string]any{ - "error": map[string]any{ - "type": "upstream_error", - "message": overrideMessage, - }, - } -} - -func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel string, clientStream bool) (*soraStreamingResult, error) { - if resp == nil { - return nil, errors.New("empty response") - } - - if clientStream { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Accel-Buffering", "no") - if v := resp.Header.Get("x-request-id"); v != "" { - c.Header("x-request-id", v) - } - } - - w := c.Writer - flusher, _ := w.(http.Flusher) - - contentBuilder := strings.Builder{} - var firstTokenMs *int - var upstreamError error - rewriteBuffer := "" - - scanner := bufio.NewScanner(resp.Body) - maxLineSize := defaultMaxLineSize - if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { - maxLineSize = s.cfg.Gateway.MaxLineSize - } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) - - sendLine := func(line string) error { - if !clientStream { - return nil - } - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - return err - } - if flusher != nil { - flusher.Flush() - } - return nil - } - - for scanner.Scan() { - line := scanner.Text() - if soraSSEDataRe.MatchString(line) { - data := soraSSEDataRe.ReplaceAllString(line, "") - if data == "[DONE]" { - if rewriteBuffer != "" { - flushLine, flushContent, err := s.flushSoraRewriteBuffer(rewriteBuffer, originalModel) - if err != nil { - return nil, err - } - if flushLine != "" { - if flushContent != "" { - if _, err := contentBuilder.WriteString(flushContent); err != nil { - return nil, err - } - } - if err := sendLine(flushLine); err != nil { - return nil, err - } - } - rewriteBuffer = "" - } - if err := sendLine("data: [DONE]"); err != nil { - return nil, err - } - break - } - updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel, &rewriteBuffer) - if errEvent != nil && upstreamError == nil { - upstreamError = errEvent - } - if contentDelta != "" { - if firstTokenMs == nil { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } - if _, err := contentBuilder.WriteString(contentDelta); err != nil { - return nil, err - } - } - if err := sendLine(updatedLine); err != nil { - return nil, err - } - continue - } - if err := sendLine(line); err != nil { - return nil, err - } - } - - if err := scanner.Err(); err != nil { - if errors.Is(err, bufio.ErrTooLong) { - if clientStream { - _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"response_too_large\"}\n\n") - if flusher != nil { - flusher.Flush() - } - } - return nil, err - } - if ctx.Err() == context.DeadlineExceeded && s.rateLimitService != nil && account != nil { - s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) - } - if clientStream { - _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"stream_read_error\"}\n\n") - if flusher != nil { - flusher.Flush() - } - } - return nil, err - } - - content := contentBuilder.String() - mediaType, mediaURLs := s.extractSoraMedia(content) - if mediaType == "" && isSoraPromptEnhanceModel(originalModel) { - mediaType = "prompt" - } - imageSize := "" - imageCount := 0 - if mediaType == "image" { - imageSize = soraImageSizeFromModel(originalModel) - imageCount = len(mediaURLs) - } - - if upstreamError != nil && !clientStream { - if c != nil { - c.JSON(http.StatusBadGateway, map[string]any{ - "error": map[string]any{ - "type": "upstream_error", - "message": upstreamError.Error(), - }, - }) - } - return nil, upstreamError - } - - if !clientStream { - response := buildSoraNonStreamResponse(content, originalModel) - if len(mediaURLs) > 0 { - response["media_url"] = mediaURLs[0] - if len(mediaURLs) > 1 { - response["media_urls"] = mediaURLs - } - } - c.JSON(http.StatusOK, response) - } - - return &soraStreamingResult{ - mediaType: mediaType, - mediaURLs: mediaURLs, - imageCount: imageCount, - imageSize: imageSize, - firstTokenMs: firstTokenMs, - }, nil -} - -func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string, rewriteBuffer *string) (string, string, error) { - if strings.TrimSpace(data) == "" { - return "data: ", "", nil - } - - var payload map[string]any - if err := json.Unmarshal([]byte(data), &payload); err != nil { - return "data: " + data, "", nil - } - - if errObj, ok := payload["error"].(map[string]any); ok { - if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" { - return "data: " + data, "", errors.New(msg) - } - } - - if model, ok := payload["model"].(string); ok && model != "" && originalModel != "" { - payload["model"] = originalModel - } - - contentDelta, updated := extractSoraContent(payload) - if updated { - var rewritten string - if rewriteBuffer != nil { - rewritten = s.rewriteSoraContentWithBuffer(contentDelta, rewriteBuffer) - } else { - rewritten = s.rewriteSoraContent(contentDelta) - } - if rewritten != contentDelta { - applySoraContent(payload, rewritten) - contentDelta = rewritten - } - } - - updatedData, err := json.Marshal(payload) - if err != nil { - return "data: " + data, contentDelta, nil - } - return "data: " + string(updatedData), contentDelta, nil -} - -func extractSoraContent(payload map[string]any) (string, bool) { - choices, ok := payload["choices"].([]any) - if !ok || len(choices) == 0 { - return "", false - } - choice, ok := choices[0].(map[string]any) - if !ok { - return "", false - } - if delta, ok := choice["delta"].(map[string]any); ok { - if content, ok := delta["content"].(string); ok { - return content, true - } - } - if message, ok := choice["message"].(map[string]any); ok { - if content, ok := message["content"].(string); ok { - return content, true - } - } - return "", false -} - -func applySoraContent(payload map[string]any, content string) { - choices, ok := payload["choices"].([]any) - if !ok || len(choices) == 0 { - return - } - choice, ok := choices[0].(map[string]any) - if !ok { - return - } - if delta, ok := choice["delta"].(map[string]any); ok { - delta["content"] = content - choice["delta"] = delta - return - } - if message, ok := choice["message"].(map[string]any); ok { - message["content"] = content - choice["message"] = message - } -} - -func (s *SoraGatewayService) rewriteSoraContentWithBuffer(contentDelta string, buffer *string) string { - if buffer == nil { - return s.rewriteSoraContent(contentDelta) - } - if contentDelta == "" && *buffer == "" { - return "" - } - combined := *buffer + contentDelta - rewritten := s.rewriteSoraContent(combined) - bufferStart := s.findSoraRewriteBufferStart(rewritten) - if bufferStart < 0 { - *buffer = "" - return rewritten - } - if len(rewritten)-bufferStart > soraRewriteBufferLimit { - bufferStart = len(rewritten) - soraRewriteBufferLimit - } - output := rewritten[:bufferStart] - *buffer = rewritten[bufferStart:] - return output -} - -func (s *SoraGatewayService) findSoraRewriteBufferStart(content string) int { - minIndex := -1 - start := 0 - for { - idx := strings.Index(content[start:], "![") - if idx < 0 { - break - } - idx += start - if !hasSoraImageMatchAt(content, idx) { - if minIndex == -1 || idx < minIndex { - minIndex = idx - } - } - start = idx + 2 - } - lower := strings.ToLower(content) - start = 0 - for { - idx := strings.Index(lower[start:], "= len(content) { - return false - } - loc := soraImageMarkdownRe.FindStringIndex(content[idx:]) - return loc != nil && loc[0] == 0 -} - -func hasSoraVideoMatchAt(content string, idx int) bool { - if idx < 0 || idx >= len(content) { - return false - } - loc := soraVideoHTMLRe.FindStringIndex(content[idx:]) - return loc != nil && loc[0] == 0 -} - -func (s *SoraGatewayService) rewriteSoraContent(content string) string { - if content == "" { - return content - } - content = soraImageMarkdownRe.ReplaceAllStringFunc(content, func(match string) string { - sub := soraImageMarkdownRe.FindStringSubmatch(match) - if len(sub) < 2 { - return match - } - rewritten := s.rewriteSoraURL(sub[1]) - if rewritten == sub[1] { - return match - } - return strings.Replace(match, sub[1], rewritten, 1) - }) - content = soraVideoHTMLRe.ReplaceAllStringFunc(content, func(match string) string { - sub := soraVideoHTMLRe.FindStringSubmatch(match) - if len(sub) < 2 { - return match - } - rewritten := s.rewriteSoraURL(sub[1]) - if rewritten == sub[1] { - return match - } - return strings.Replace(match, sub[1], rewritten, 1) - }) - return content -} - -func (s *SoraGatewayService) flushSoraRewriteBuffer(buffer string, originalModel string) (string, string, error) { - if buffer == "" { - return "", "", nil - } - rewritten := s.rewriteSoraContent(buffer) - payload := map[string]any{ - "choices": []any{ - map[string]any{ - "delta": map[string]any{ - "content": rewritten, - }, - "index": 0, - }, - }, - } - if originalModel != "" { - payload["model"] = originalModel - } - updatedData, err := json.Marshal(payload) - if err != nil { - return "", "", err - } - return "data: " + string(updatedData), rewritten, nil -} - -func (s *SoraGatewayService) rewriteSoraURL(raw string) string { - raw = strings.TrimSpace(raw) - if raw == "" { - return raw - } - parsed, err := url.Parse(raw) - if err != nil { - return raw - } - path := parsed.Path - if !strings.HasPrefix(path, "/tmp/") && !strings.HasPrefix(path, "/static/") { - return raw - } - return s.buildSoraMediaURL(path, parsed.RawQuery) -} - -func (s *SoraGatewayService) extractSoraMedia(content string) (string, []string) { - if content == "" { - return "", nil - } - if match := soraVideoHTMLRe.FindStringSubmatch(content); len(match) > 1 { - return "video", []string{match[1]} - } - imageMatches := soraImageMarkdownRe.FindAllStringSubmatch(content, -1) - if len(imageMatches) == 0 { - return "", nil - } - urls := make([]string, 0, len(imageMatches)) - for _, match := range imageMatches { - if len(match) > 1 { - urls = append(urls, match[1]) - } - } - return "image", urls -} - func buildSoraNonStreamResponse(content, model string) map[string]any { return map[string]any{ "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), @@ -813,10 +301,6 @@ func soraImageSizeFromModel(model string) string { return "360" } -func isSoraPromptEnhanceModel(model string) bool { - return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "prompt-enhance") -} - func soraProErrorMessage(model, upstreamMsg string) string { modelLower := strings.ToLower(model) if strings.Contains(modelLower, "sora2pro-hd") { @@ -1006,7 +490,7 @@ func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, if status.ErrorMsg != "" { return nil, errors.New(status.ErrorMsg) } - return nil, errors.New("Sora image generation failed") + return nil, errors.New("sora image generation failed") } if stream { s.maybeSendPing(c, &lastPing) @@ -1015,7 +499,7 @@ func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, return nil, err } } - return nil, errors.New("Sora image generation timeout") + return nil, errors.New("sora image generation timeout") } func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) { @@ -1034,7 +518,7 @@ func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, if status.ErrorMsg != "" { return nil, errors.New(status.ErrorMsg) } - return nil, errors.New("Sora video generation failed") + return nil, errors.New("sora video generation failed") } if stream { s.maybeSendPing(c, &lastPing) @@ -1043,7 +527,7 @@ func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, return nil, err } } - return nil, errors.New("Sora video generation timeout") + return nil, errors.New("sora video generation timeout") } func (s *SoraGatewayService) pollInterval() time.Duration { @@ -1159,9 +643,9 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi text, img, vid := parseSoraMessageContent(content) if text != "" { if builder.Len() > 0 { - builder.WriteString("\n") + _, _ = builder.WriteString("\n") } - builder.WriteString(text) + _, _ = builder.WriteString(text) } if imageInput == "" && img != "" { imageInput = img @@ -1193,9 +677,9 @@ func parseSoraMessageContent(content any) (text, imageInput, videoInput string) case "text": if txt, ok := itemMap["text"].(string); ok && strings.TrimSpace(txt) != "" { if builder.Len() > 0 { - builder.WriteString("\n") + _, _ = builder.WriteString("\n") } - builder.WriteString(txt) + _, _ = builder.WriteString(txt) } case "image_url": if imageInput == "" { diff --git a/backend/internal/service/sora_gateway_streaming_legacy.go b/backend/internal/service/sora_gateway_streaming_legacy.go new file mode 100644 index 00000000..8a38f181 --- /dev/null +++ b/backend/internal/service/sora_gateway_streaming_legacy.go @@ -0,0 +1,532 @@ +//nolint:unused +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "regexp" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +var soraSSEDataRe = regexp.MustCompile(`^data:\s*`) +var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`) +var soraVideoHTMLRe = regexp.MustCompile(`(?i)]+src=['"]([^'"]+)['"]`) + +const soraRewriteBufferLimit = 2048 + +type soraStreamingResult struct { + mediaType string + mediaURLs []string + imageCount int + imageSize string + firstTokenMs *int +} + +func (s *SoraGatewayService) setUpstreamRequestError(c *gin.Context, account *Account, err error) { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + if c != nil { + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + } +} + +func (s *SoraGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { + if s.rateLimitService == nil || account == nil || resp == nil { + return + } + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) +} + +func (s *SoraGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, reqModel string) (*ForwardResult, error) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if msg := soraProErrorMessage(reqModel, upstreamMsg); msg != "" { + upstreamMsg = msg + } + + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + if c != nil { + responsePayload := s.buildErrorPayload(respBody, upstreamMsg) + c.JSON(resp.StatusCode, responsePayload) + } + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) +} + +func (s *SoraGatewayService) buildErrorPayload(respBody []byte, overrideMessage string) map[string]any { + if len(respBody) > 0 { + var payload map[string]any + if err := json.Unmarshal(respBody, &payload); err == nil { + if errObj, ok := payload["error"].(map[string]any); ok { + if overrideMessage != "" { + errObj["message"] = overrideMessage + } + payload["error"] = errObj + return payload + } + } + } + return map[string]any{ + "error": map[string]any{ + "type": "upstream_error", + "message": overrideMessage, + }, + } +} + +func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel string, clientStream bool) (*soraStreamingResult, error) { + if resp == nil { + return nil, errors.New("empty response") + } + + if clientStream { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + if v := resp.Header.Get("x-request-id"); v != "" { + c.Header("x-request-id", v) + } + } + + w := c.Writer + flusher, _ := w.(http.Flusher) + + contentBuilder := strings.Builder{} + var firstTokenMs *int + var upstreamError error + rewriteBuffer := "" + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 64*1024), maxLineSize) + + sendLine := func(line string) error { + if !clientStream { + return nil + } + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + return err + } + if flusher != nil { + flusher.Flush() + } + return nil + } + + for scanner.Scan() { + line := scanner.Text() + if soraSSEDataRe.MatchString(line) { + data := soraSSEDataRe.ReplaceAllString(line, "") + if data == "[DONE]" { + if rewriteBuffer != "" { + flushLine, flushContent, err := s.flushSoraRewriteBuffer(rewriteBuffer, originalModel) + if err != nil { + return nil, err + } + if flushLine != "" { + if flushContent != "" { + if _, err := contentBuilder.WriteString(flushContent); err != nil { + return nil, err + } + } + if err := sendLine(flushLine); err != nil { + return nil, err + } + } + rewriteBuffer = "" + } + if err := sendLine("data: [DONE]"); err != nil { + return nil, err + } + break + } + updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel, &rewriteBuffer) + if errEvent != nil && upstreamError == nil { + upstreamError = errEvent + } + if contentDelta != "" { + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + if _, err := contentBuilder.WriteString(contentDelta); err != nil { + return nil, err + } + } + if err := sendLine(updatedLine); err != nil { + return nil, err + } + continue + } + if err := sendLine(line); err != nil { + return nil, err + } + } + + if err := scanner.Err(); err != nil { + if errors.Is(err, bufio.ErrTooLong) { + if clientStream { + _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"response_too_large\"}\n\n") + if flusher != nil { + flusher.Flush() + } + } + return nil, err + } + if ctx.Err() == context.DeadlineExceeded && s.rateLimitService != nil && account != nil { + s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) + } + if clientStream { + _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"stream_read_error\"}\n\n") + if flusher != nil { + flusher.Flush() + } + } + return nil, err + } + + content := contentBuilder.String() + mediaType, mediaURLs := s.extractSoraMedia(content) + if mediaType == "" && isSoraPromptEnhanceModel(originalModel) { + mediaType = "prompt" + } + imageSize := "" + imageCount := 0 + if mediaType == "image" { + imageSize = soraImageSizeFromModel(originalModel) + imageCount = len(mediaURLs) + } + + if upstreamError != nil && !clientStream { + if c != nil { + c.JSON(http.StatusBadGateway, map[string]any{ + "error": map[string]any{ + "type": "upstream_error", + "message": upstreamError.Error(), + }, + }) + } + return nil, upstreamError + } + + if !clientStream { + response := buildSoraNonStreamResponse(content, originalModel) + if len(mediaURLs) > 0 { + response["media_url"] = mediaURLs[0] + if len(mediaURLs) > 1 { + response["media_urls"] = mediaURLs + } + } + c.JSON(http.StatusOK, response) + } + + return &soraStreamingResult{ + mediaType: mediaType, + mediaURLs: mediaURLs, + imageCount: imageCount, + imageSize: imageSize, + firstTokenMs: firstTokenMs, + }, nil +} + +func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string, rewriteBuffer *string) (string, string, error) { + if strings.TrimSpace(data) == "" { + return "data: ", "", nil + } + + var payload map[string]any + if err := json.Unmarshal([]byte(data), &payload); err != nil { + return "data: " + data, "", nil + } + + if errObj, ok := payload["error"].(map[string]any); ok { + if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" { + return "data: " + data, "", errors.New(msg) + } + } + + if model, ok := payload["model"].(string); ok && model != "" && originalModel != "" { + payload["model"] = originalModel + } + + contentDelta, updated := extractSoraContent(payload) + if updated { + var rewritten string + if rewriteBuffer != nil { + rewritten = s.rewriteSoraContentWithBuffer(contentDelta, rewriteBuffer) + } else { + rewritten = s.rewriteSoraContent(contentDelta) + } + if rewritten != contentDelta { + applySoraContent(payload, rewritten) + contentDelta = rewritten + } + } + + updatedData, err := json.Marshal(payload) + if err != nil { + return "data: " + data, contentDelta, nil + } + return "data: " + string(updatedData), contentDelta, nil +} + +func extractSoraContent(payload map[string]any) (string, bool) { + choices, ok := payload["choices"].([]any) + if !ok || len(choices) == 0 { + return "", false + } + choice, ok := choices[0].(map[string]any) + if !ok { + return "", false + } + if delta, ok := choice["delta"].(map[string]any); ok { + if content, ok := delta["content"].(string); ok { + return content, true + } + } + if message, ok := choice["message"].(map[string]any); ok { + if content, ok := message["content"].(string); ok { + return content, true + } + } + return "", false +} + +func applySoraContent(payload map[string]any, content string) { + choices, ok := payload["choices"].([]any) + if !ok || len(choices) == 0 { + return + } + choice, ok := choices[0].(map[string]any) + if !ok { + return + } + if delta, ok := choice["delta"].(map[string]any); ok { + delta["content"] = content + choice["delta"] = delta + return + } + if message, ok := choice["message"].(map[string]any); ok { + message["content"] = content + choice["message"] = message + } +} + +func (s *SoraGatewayService) rewriteSoraContentWithBuffer(contentDelta string, buffer *string) string { + if buffer == nil { + return s.rewriteSoraContent(contentDelta) + } + if contentDelta == "" && *buffer == "" { + return "" + } + combined := *buffer + contentDelta + rewritten := s.rewriteSoraContent(combined) + bufferStart := s.findSoraRewriteBufferStart(rewritten) + if bufferStart < 0 { + *buffer = "" + return rewritten + } + if len(rewritten)-bufferStart > soraRewriteBufferLimit { + bufferStart = len(rewritten) - soraRewriteBufferLimit + } + output := rewritten[:bufferStart] + *buffer = rewritten[bufferStart:] + return output +} + +func (s *SoraGatewayService) findSoraRewriteBufferStart(content string) int { + minIndex := -1 + start := 0 + for { + idx := strings.Index(content[start:], "![") + if idx < 0 { + break + } + idx += start + if !hasSoraImageMatchAt(content, idx) { + if minIndex == -1 || idx < minIndex { + minIndex = idx + } + } + start = idx + 2 + } + lower := strings.ToLower(content) + start = 0 + for { + idx := strings.Index(lower[start:], "= len(content) { + return false + } + loc := soraImageMarkdownRe.FindStringIndex(content[idx:]) + return loc != nil && loc[0] == 0 +} + +func hasSoraVideoMatchAt(content string, idx int) bool { + if idx < 0 || idx >= len(content) { + return false + } + loc := soraVideoHTMLRe.FindStringIndex(content[idx:]) + return loc != nil && loc[0] == 0 +} + +func (s *SoraGatewayService) rewriteSoraContent(content string) string { + if content == "" { + return content + } + content = soraImageMarkdownRe.ReplaceAllStringFunc(content, func(match string) string { + sub := soraImageMarkdownRe.FindStringSubmatch(match) + if len(sub) < 2 { + return match + } + rewritten := s.rewriteSoraURL(sub[1]) + if rewritten == sub[1] { + return match + } + return strings.Replace(match, sub[1], rewritten, 1) + }) + content = soraVideoHTMLRe.ReplaceAllStringFunc(content, func(match string) string { + sub := soraVideoHTMLRe.FindStringSubmatch(match) + if len(sub) < 2 { + return match + } + rewritten := s.rewriteSoraURL(sub[1]) + if rewritten == sub[1] { + return match + } + return strings.Replace(match, sub[1], rewritten, 1) + }) + return content +} + +func (s *SoraGatewayService) flushSoraRewriteBuffer(buffer string, originalModel string) (string, string, error) { + if buffer == "" { + return "", "", nil + } + rewritten := s.rewriteSoraContent(buffer) + payload := map[string]any{ + "choices": []any{ + map[string]any{ + "delta": map[string]any{ + "content": rewritten, + }, + "index": 0, + }, + }, + } + if originalModel != "" { + payload["model"] = originalModel + } + updatedData, err := json.Marshal(payload) + if err != nil { + return "", "", err + } + return "data: " + string(updatedData), rewritten, nil +} + +func (s *SoraGatewayService) rewriteSoraURL(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return raw + } + parsed, err := url.Parse(raw) + if err != nil { + return raw + } + path := parsed.Path + if !strings.HasPrefix(path, "/tmp/") && !strings.HasPrefix(path, "/static/") { + return raw + } + return s.buildSoraMediaURL(path, parsed.RawQuery) +} + +func (s *SoraGatewayService) extractSoraMedia(content string) (string, []string) { + if content == "" { + return "", nil + } + if match := soraVideoHTMLRe.FindStringSubmatch(content); len(match) > 1 { + return "video", []string{match[1]} + } + imageMatches := soraImageMarkdownRe.FindAllStringSubmatch(content, -1) + if len(imageMatches) == 0 { + return "", nil + } + urls := make([]string, 0, len(imageMatches)) + for _, match := range imageMatches { + if len(match) > 1 { + urls = append(urls, match[1]) + } + } + return "image", urls +} + +func isSoraPromptEnhanceModel(model string) bool { + return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "prompt-enhance") +} diff --git a/backend/internal/service/sora_media_storage.go b/backend/internal/service/sora_media_storage.go index 4359f78d..562ded46 100644 --- a/backend/internal/service/sora_media_storage.go +++ b/backend/internal/service/sora_media_storage.go @@ -29,7 +29,6 @@ type SoraMediaStorage struct { root string imageRoot string videoRoot string - maxConcurrent int downloadTimeout time.Duration maxDownloadBytes int64 fallbackToUpstream bool @@ -93,7 +92,6 @@ func (s *SoraMediaStorage) refreshConfig() { if maxConcurrent <= 0 { maxConcurrent = 4 } - s.maxConcurrent = maxConcurrent timeoutSeconds := s.cfg.Sora.Storage.DownloadTimeoutSeconds if timeoutSeconds <= 0 { timeoutSeconds = 120