fix(sora): 恢复流式辅助逻辑并通过 lint
This commit is contained in:
@@ -672,10 +672,7 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
|
func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
|
||||||
enableTLS := false
|
enableTLS := c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint
|
||||||
if c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint {
|
|
||||||
enableTLS = true
|
|
||||||
}
|
|
||||||
if c.httpUpstream != nil {
|
if c.httpUpstream != nil {
|
||||||
accountID := int64(0)
|
accountID := int64(0)
|
||||||
accountConcurrency := 0
|
accountConcurrency := 0
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -13,7 +11,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -22,11 +19,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
var soraSSEDataRe = regexp.MustCompile(`^data:\s*`)
|
|
||||||
var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
|
|
||||||
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
|
|
||||||
|
|
||||||
const soraRewriteBufferLimit = 2048
|
|
||||||
const soraImageInputMaxBytes = 20 << 20
|
const soraImageInputMaxBytes = 20 << 20
|
||||||
const soraImageInputMaxRedirects = 3
|
const soraImageInputMaxRedirects = 3
|
||||||
const soraImageInputTimeout = 20 * time.Second
|
const soraImageInputTimeout = 20 * time.Second
|
||||||
@@ -60,14 +52,6 @@ var soraBlockedCIDRs = mustParseCIDRs([]string{
|
|||||||
"fe80::/10",
|
"fe80::/10",
|
||||||
})
|
})
|
||||||
|
|
||||||
type soraStreamingResult struct {
|
|
||||||
mediaType string
|
|
||||||
mediaURLs []string
|
|
||||||
imageCount int
|
|
||||||
imageSize string
|
|
||||||
firstTokenMs *int
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoraGatewayService handles forwarding requests to Sora upstream.
|
// SoraGatewayService handles forwarding requests to Sora upstream.
|
||||||
type SoraGatewayService struct {
|
type SoraGatewayService struct {
|
||||||
soraClient SoraClient
|
soraClient SoraClient
|
||||||
@@ -203,7 +187,8 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
|||||||
mediaType := modelCfg.Type
|
mediaType := modelCfg.Type
|
||||||
imageCount := 0
|
imageCount := 0
|
||||||
imageSize := ""
|
imageSize := ""
|
||||||
if modelCfg.Type == "image" {
|
switch modelCfg.Type {
|
||||||
|
case "image":
|
||||||
urls, pollErr := s.pollImageTask(reqCtx, c, account, taskID, clientStream)
|
urls, pollErr := s.pollImageTask(reqCtx, c, account, taskID, clientStream)
|
||||||
if pollErr != nil {
|
if pollErr != nil {
|
||||||
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
|
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
|
||||||
@@ -211,25 +196,23 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
|||||||
mediaURLs = urls
|
mediaURLs = urls
|
||||||
imageCount = len(urls)
|
imageCount = len(urls)
|
||||||
imageSize = soraImageSizeFromModel(reqModel)
|
imageSize = soraImageSizeFromModel(reqModel)
|
||||||
} else if modelCfg.Type == "video" {
|
case "video":
|
||||||
urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream)
|
urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream)
|
||||||
if pollErr != nil {
|
if pollErr != nil {
|
||||||
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
|
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
|
||||||
}
|
}
|
||||||
mediaURLs = urls
|
mediaURLs = urls
|
||||||
} else {
|
default:
|
||||||
mediaType = "prompt"
|
mediaType = "prompt"
|
||||||
}
|
}
|
||||||
|
|
||||||
finalURLs := mediaURLs
|
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
|
||||||
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
|
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
|
||||||
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
|
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
|
||||||
if storeErr != nil {
|
if storeErr != nil {
|
||||||
return nil, s.handleSoraRequestError(ctx, account, storeErr, reqModel, c, clientStream)
|
return nil, s.handleSoraRequestError(ctx, account, storeErr, reqModel, c, clientStream)
|
||||||
}
|
}
|
||||||
finalURLs = s.normalizeSoraMediaURLs(stored)
|
finalURLs = s.normalizeSoraMediaURLs(stored)
|
||||||
} else {
|
|
||||||
finalURLs = s.normalizeSoraMediaURLs(mediaURLs)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
content := buildSoraContent(mediaType, finalURLs)
|
content := buildSoraContent(mediaType, finalURLs)
|
||||||
@@ -279,27 +262,6 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (
|
|||||||
return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
|
return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SoraGatewayService) setUpstreamRequestError(c *gin.Context, account *Account, err error) {
|
|
||||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
|
||||||
setOpsUpstreamError(c, 0, safeErr, "")
|
|
||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|
||||||
Platform: account.Platform,
|
|
||||||
AccountID: account.ID,
|
|
||||||
AccountName: account.Name,
|
|
||||||
UpstreamStatusCode: 0,
|
|
||||||
Kind: "request_error",
|
|
||||||
Message: safeErr,
|
|
||||||
})
|
|
||||||
if c != nil {
|
|
||||||
c.JSON(http.StatusBadGateway, gin.H{
|
|
||||||
"error": gin.H{
|
|
||||||
"type": "upstream_error",
|
|
||||||
"message": "Upstream request failed",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||||
switch statusCode {
|
switch statusCode {
|
||||||
case 401, 402, 403, 429, 529:
|
case 401, 402, 403, 429, 529:
|
||||||
@@ -309,480 +271,6 @@ func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SoraGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
|
||||||
if s.rateLimitService == nil || account == nil || resp == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, reqModel string) (*ForwardResult, error) {
|
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
|
||||||
|
|
||||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
|
||||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
|
||||||
if msg := soraProErrorMessage(reqModel, upstreamMsg); msg != "" {
|
|
||||||
upstreamMsg = msg
|
|
||||||
}
|
|
||||||
|
|
||||||
upstreamDetail := ""
|
|
||||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
|
||||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
|
||||||
if maxBytes <= 0 {
|
|
||||||
maxBytes = 2048
|
|
||||||
}
|
|
||||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
|
||||||
}
|
|
||||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
|
||||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
|
||||||
Platform: account.Platform,
|
|
||||||
AccountID: account.ID,
|
|
||||||
AccountName: account.Name,
|
|
||||||
UpstreamStatusCode: resp.StatusCode,
|
|
||||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
|
||||||
Kind: "http_error",
|
|
||||||
Message: upstreamMsg,
|
|
||||||
Detail: upstreamDetail,
|
|
||||||
})
|
|
||||||
|
|
||||||
if c != nil {
|
|
||||||
responsePayload := s.buildErrorPayload(respBody, upstreamMsg)
|
|
||||||
c.JSON(resp.StatusCode, responsePayload)
|
|
||||||
}
|
|
||||||
if upstreamMsg == "" {
|
|
||||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) buildErrorPayload(respBody []byte, overrideMessage string) map[string]any {
|
|
||||||
if len(respBody) > 0 {
|
|
||||||
var payload map[string]any
|
|
||||||
if err := json.Unmarshal(respBody, &payload); err == nil {
|
|
||||||
if errObj, ok := payload["error"].(map[string]any); ok {
|
|
||||||
if overrideMessage != "" {
|
|
||||||
errObj["message"] = overrideMessage
|
|
||||||
}
|
|
||||||
payload["error"] = errObj
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return map[string]any{
|
|
||||||
"error": map[string]any{
|
|
||||||
"type": "upstream_error",
|
|
||||||
"message": overrideMessage,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel string, clientStream bool) (*soraStreamingResult, error) {
|
|
||||||
if resp == nil {
|
|
||||||
return nil, errors.New("empty response")
|
|
||||||
}
|
|
||||||
|
|
||||||
if clientStream {
|
|
||||||
c.Header("Content-Type", "text/event-stream")
|
|
||||||
c.Header("Cache-Control", "no-cache")
|
|
||||||
c.Header("Connection", "keep-alive")
|
|
||||||
c.Header("X-Accel-Buffering", "no")
|
|
||||||
if v := resp.Header.Get("x-request-id"); v != "" {
|
|
||||||
c.Header("x-request-id", v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
w := c.Writer
|
|
||||||
flusher, _ := w.(http.Flusher)
|
|
||||||
|
|
||||||
contentBuilder := strings.Builder{}
|
|
||||||
var firstTokenMs *int
|
|
||||||
var upstreamError error
|
|
||||||
rewriteBuffer := ""
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
maxLineSize := defaultMaxLineSize
|
|
||||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
|
||||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
|
||||||
}
|
|
||||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
|
||||||
|
|
||||||
sendLine := func(line string) error {
|
|
||||||
if !clientStream {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
if soraSSEDataRe.MatchString(line) {
|
|
||||||
data := soraSSEDataRe.ReplaceAllString(line, "")
|
|
||||||
if data == "[DONE]" {
|
|
||||||
if rewriteBuffer != "" {
|
|
||||||
flushLine, flushContent, err := s.flushSoraRewriteBuffer(rewriteBuffer, originalModel)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if flushLine != "" {
|
|
||||||
if flushContent != "" {
|
|
||||||
if _, err := contentBuilder.WriteString(flushContent); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := sendLine(flushLine); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rewriteBuffer = ""
|
|
||||||
}
|
|
||||||
if err := sendLine("data: [DONE]"); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel, &rewriteBuffer)
|
|
||||||
if errEvent != nil && upstreamError == nil {
|
|
||||||
upstreamError = errEvent
|
|
||||||
}
|
|
||||||
if contentDelta != "" {
|
|
||||||
if firstTokenMs == nil {
|
|
||||||
ms := int(time.Since(startTime).Milliseconds())
|
|
||||||
firstTokenMs = &ms
|
|
||||||
}
|
|
||||||
if _, err := contentBuilder.WriteString(contentDelta); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := sendLine(updatedLine); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := sendLine(line); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
if errors.Is(err, bufio.ErrTooLong) {
|
|
||||||
if clientStream {
|
|
||||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"response_too_large\"}\n\n")
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if ctx.Err() == context.DeadlineExceeded && s.rateLimitService != nil && account != nil {
|
|
||||||
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
|
||||||
}
|
|
||||||
if clientStream {
|
|
||||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"stream_read_error\"}\n\n")
|
|
||||||
if flusher != nil {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
content := contentBuilder.String()
|
|
||||||
mediaType, mediaURLs := s.extractSoraMedia(content)
|
|
||||||
if mediaType == "" && isSoraPromptEnhanceModel(originalModel) {
|
|
||||||
mediaType = "prompt"
|
|
||||||
}
|
|
||||||
imageSize := ""
|
|
||||||
imageCount := 0
|
|
||||||
if mediaType == "image" {
|
|
||||||
imageSize = soraImageSizeFromModel(originalModel)
|
|
||||||
imageCount = len(mediaURLs)
|
|
||||||
}
|
|
||||||
|
|
||||||
if upstreamError != nil && !clientStream {
|
|
||||||
if c != nil {
|
|
||||||
c.JSON(http.StatusBadGateway, map[string]any{
|
|
||||||
"error": map[string]any{
|
|
||||||
"type": "upstream_error",
|
|
||||||
"message": upstreamError.Error(),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return nil, upstreamError
|
|
||||||
}
|
|
||||||
|
|
||||||
if !clientStream {
|
|
||||||
response := buildSoraNonStreamResponse(content, originalModel)
|
|
||||||
if len(mediaURLs) > 0 {
|
|
||||||
response["media_url"] = mediaURLs[0]
|
|
||||||
if len(mediaURLs) > 1 {
|
|
||||||
response["media_urls"] = mediaURLs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusOK, response)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &soraStreamingResult{
|
|
||||||
mediaType: mediaType,
|
|
||||||
mediaURLs: mediaURLs,
|
|
||||||
imageCount: imageCount,
|
|
||||||
imageSize: imageSize,
|
|
||||||
firstTokenMs: firstTokenMs,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string, rewriteBuffer *string) (string, string, error) {
|
|
||||||
if strings.TrimSpace(data) == "" {
|
|
||||||
return "data: ", "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var payload map[string]any
|
|
||||||
if err := json.Unmarshal([]byte(data), &payload); err != nil {
|
|
||||||
return "data: " + data, "", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if errObj, ok := payload["error"].(map[string]any); ok {
|
|
||||||
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
|
||||||
return "data: " + data, "", errors.New(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if model, ok := payload["model"].(string); ok && model != "" && originalModel != "" {
|
|
||||||
payload["model"] = originalModel
|
|
||||||
}
|
|
||||||
|
|
||||||
contentDelta, updated := extractSoraContent(payload)
|
|
||||||
if updated {
|
|
||||||
var rewritten string
|
|
||||||
if rewriteBuffer != nil {
|
|
||||||
rewritten = s.rewriteSoraContentWithBuffer(contentDelta, rewriteBuffer)
|
|
||||||
} else {
|
|
||||||
rewritten = s.rewriteSoraContent(contentDelta)
|
|
||||||
}
|
|
||||||
if rewritten != contentDelta {
|
|
||||||
applySoraContent(payload, rewritten)
|
|
||||||
contentDelta = rewritten
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedData, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return "data: " + data, contentDelta, nil
|
|
||||||
}
|
|
||||||
return "data: " + string(updatedData), contentDelta, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractSoraContent(payload map[string]any) (string, bool) {
|
|
||||||
choices, ok := payload["choices"].([]any)
|
|
||||||
if !ok || len(choices) == 0 {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
choice, ok := choices[0].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
if delta, ok := choice["delta"].(map[string]any); ok {
|
|
||||||
if content, ok := delta["content"].(string); ok {
|
|
||||||
return content, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if message, ok := choice["message"].(map[string]any); ok {
|
|
||||||
if content, ok := message["content"].(string); ok {
|
|
||||||
return content, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
func applySoraContent(payload map[string]any, content string) {
|
|
||||||
choices, ok := payload["choices"].([]any)
|
|
||||||
if !ok || len(choices) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
choice, ok := choices[0].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if delta, ok := choice["delta"].(map[string]any); ok {
|
|
||||||
delta["content"] = content
|
|
||||||
choice["delta"] = delta
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if message, ok := choice["message"].(map[string]any); ok {
|
|
||||||
message["content"] = content
|
|
||||||
choice["message"] = message
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) rewriteSoraContentWithBuffer(contentDelta string, buffer *string) string {
|
|
||||||
if buffer == nil {
|
|
||||||
return s.rewriteSoraContent(contentDelta)
|
|
||||||
}
|
|
||||||
if contentDelta == "" && *buffer == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
combined := *buffer + contentDelta
|
|
||||||
rewritten := s.rewriteSoraContent(combined)
|
|
||||||
bufferStart := s.findSoraRewriteBufferStart(rewritten)
|
|
||||||
if bufferStart < 0 {
|
|
||||||
*buffer = ""
|
|
||||||
return rewritten
|
|
||||||
}
|
|
||||||
if len(rewritten)-bufferStart > soraRewriteBufferLimit {
|
|
||||||
bufferStart = len(rewritten) - soraRewriteBufferLimit
|
|
||||||
}
|
|
||||||
output := rewritten[:bufferStart]
|
|
||||||
*buffer = rewritten[bufferStart:]
|
|
||||||
return output
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) findSoraRewriteBufferStart(content string) int {
|
|
||||||
minIndex := -1
|
|
||||||
start := 0
|
|
||||||
for {
|
|
||||||
idx := strings.Index(content[start:], "![")
|
|
||||||
if idx < 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
idx += start
|
|
||||||
if !hasSoraImageMatchAt(content, idx) {
|
|
||||||
if minIndex == -1 || idx < minIndex {
|
|
||||||
minIndex = idx
|
|
||||||
}
|
|
||||||
}
|
|
||||||
start = idx + 2
|
|
||||||
}
|
|
||||||
lower := strings.ToLower(content)
|
|
||||||
start = 0
|
|
||||||
for {
|
|
||||||
idx := strings.Index(lower[start:], "<video")
|
|
||||||
if idx < 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
idx += start
|
|
||||||
if !hasSoraVideoMatchAt(content, idx) {
|
|
||||||
if minIndex == -1 || idx < minIndex {
|
|
||||||
minIndex = idx
|
|
||||||
}
|
|
||||||
}
|
|
||||||
start = idx + len("<video")
|
|
||||||
}
|
|
||||||
return minIndex
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasSoraImageMatchAt(content string, idx int) bool {
|
|
||||||
if idx < 0 || idx >= len(content) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
loc := soraImageMarkdownRe.FindStringIndex(content[idx:])
|
|
||||||
return loc != nil && loc[0] == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasSoraVideoMatchAt(content string, idx int) bool {
|
|
||||||
if idx < 0 || idx >= len(content) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
loc := soraVideoHTMLRe.FindStringIndex(content[idx:])
|
|
||||||
return loc != nil && loc[0] == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) rewriteSoraContent(content string) string {
|
|
||||||
if content == "" {
|
|
||||||
return content
|
|
||||||
}
|
|
||||||
content = soraImageMarkdownRe.ReplaceAllStringFunc(content, func(match string) string {
|
|
||||||
sub := soraImageMarkdownRe.FindStringSubmatch(match)
|
|
||||||
if len(sub) < 2 {
|
|
||||||
return match
|
|
||||||
}
|
|
||||||
rewritten := s.rewriteSoraURL(sub[1])
|
|
||||||
if rewritten == sub[1] {
|
|
||||||
return match
|
|
||||||
}
|
|
||||||
return strings.Replace(match, sub[1], rewritten, 1)
|
|
||||||
})
|
|
||||||
content = soraVideoHTMLRe.ReplaceAllStringFunc(content, func(match string) string {
|
|
||||||
sub := soraVideoHTMLRe.FindStringSubmatch(match)
|
|
||||||
if len(sub) < 2 {
|
|
||||||
return match
|
|
||||||
}
|
|
||||||
rewritten := s.rewriteSoraURL(sub[1])
|
|
||||||
if rewritten == sub[1] {
|
|
||||||
return match
|
|
||||||
}
|
|
||||||
return strings.Replace(match, sub[1], rewritten, 1)
|
|
||||||
})
|
|
||||||
return content
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) flushSoraRewriteBuffer(buffer string, originalModel string) (string, string, error) {
|
|
||||||
if buffer == "" {
|
|
||||||
return "", "", nil
|
|
||||||
}
|
|
||||||
rewritten := s.rewriteSoraContent(buffer)
|
|
||||||
payload := map[string]any{
|
|
||||||
"choices": []any{
|
|
||||||
map[string]any{
|
|
||||||
"delta": map[string]any{
|
|
||||||
"content": rewritten,
|
|
||||||
},
|
|
||||||
"index": 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if originalModel != "" {
|
|
||||||
payload["model"] = originalModel
|
|
||||||
}
|
|
||||||
updatedData, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", err
|
|
||||||
}
|
|
||||||
return "data: " + string(updatedData), rewritten, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) rewriteSoraURL(raw string) string {
|
|
||||||
raw = strings.TrimSpace(raw)
|
|
||||||
if raw == "" {
|
|
||||||
return raw
|
|
||||||
}
|
|
||||||
parsed, err := url.Parse(raw)
|
|
||||||
if err != nil {
|
|
||||||
return raw
|
|
||||||
}
|
|
||||||
path := parsed.Path
|
|
||||||
if !strings.HasPrefix(path, "/tmp/") && !strings.HasPrefix(path, "/static/") {
|
|
||||||
return raw
|
|
||||||
}
|
|
||||||
return s.buildSoraMediaURL(path, parsed.RawQuery)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SoraGatewayService) extractSoraMedia(content string) (string, []string) {
|
|
||||||
if content == "" {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
if match := soraVideoHTMLRe.FindStringSubmatch(content); len(match) > 1 {
|
|
||||||
return "video", []string{match[1]}
|
|
||||||
}
|
|
||||||
imageMatches := soraImageMarkdownRe.FindAllStringSubmatch(content, -1)
|
|
||||||
if len(imageMatches) == 0 {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
urls := make([]string, 0, len(imageMatches))
|
|
||||||
for _, match := range imageMatches {
|
|
||||||
if len(match) > 1 {
|
|
||||||
urls = append(urls, match[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "image", urls
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildSoraNonStreamResponse(content, model string) map[string]any {
|
func buildSoraNonStreamResponse(content, model string) map[string]any {
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
|
"id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
|
||||||
@@ -813,10 +301,6 @@ func soraImageSizeFromModel(model string) string {
|
|||||||
return "360"
|
return "360"
|
||||||
}
|
}
|
||||||
|
|
||||||
func isSoraPromptEnhanceModel(model string) bool {
|
|
||||||
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "prompt-enhance")
|
|
||||||
}
|
|
||||||
|
|
||||||
func soraProErrorMessage(model, upstreamMsg string) string {
|
func soraProErrorMessage(model, upstreamMsg string) string {
|
||||||
modelLower := strings.ToLower(model)
|
modelLower := strings.ToLower(model)
|
||||||
if strings.Contains(modelLower, "sora2pro-hd") {
|
if strings.Contains(modelLower, "sora2pro-hd") {
|
||||||
@@ -1006,7 +490,7 @@ func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context,
|
|||||||
if status.ErrorMsg != "" {
|
if status.ErrorMsg != "" {
|
||||||
return nil, errors.New(status.ErrorMsg)
|
return nil, errors.New(status.ErrorMsg)
|
||||||
}
|
}
|
||||||
return nil, errors.New("Sora image generation failed")
|
return nil, errors.New("sora image generation failed")
|
||||||
}
|
}
|
||||||
if stream {
|
if stream {
|
||||||
s.maybeSendPing(c, &lastPing)
|
s.maybeSendPing(c, &lastPing)
|
||||||
@@ -1015,7 +499,7 @@ func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context,
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, errors.New("Sora image generation timeout")
|
return nil, errors.New("sora image generation timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
|
func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
|
||||||
@@ -1034,7 +518,7 @@ func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context,
|
|||||||
if status.ErrorMsg != "" {
|
if status.ErrorMsg != "" {
|
||||||
return nil, errors.New(status.ErrorMsg)
|
return nil, errors.New(status.ErrorMsg)
|
||||||
}
|
}
|
||||||
return nil, errors.New("Sora video generation failed")
|
return nil, errors.New("sora video generation failed")
|
||||||
}
|
}
|
||||||
if stream {
|
if stream {
|
||||||
s.maybeSendPing(c, &lastPing)
|
s.maybeSendPing(c, &lastPing)
|
||||||
@@ -1043,7 +527,7 @@ func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context,
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, errors.New("Sora video generation timeout")
|
return nil, errors.New("sora video generation timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SoraGatewayService) pollInterval() time.Duration {
|
func (s *SoraGatewayService) pollInterval() time.Duration {
|
||||||
@@ -1159,9 +643,9 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
|
|||||||
text, img, vid := parseSoraMessageContent(content)
|
text, img, vid := parseSoraMessageContent(content)
|
||||||
if text != "" {
|
if text != "" {
|
||||||
if builder.Len() > 0 {
|
if builder.Len() > 0 {
|
||||||
builder.WriteString("\n")
|
_, _ = builder.WriteString("\n")
|
||||||
}
|
}
|
||||||
builder.WriteString(text)
|
_, _ = builder.WriteString(text)
|
||||||
}
|
}
|
||||||
if imageInput == "" && img != "" {
|
if imageInput == "" && img != "" {
|
||||||
imageInput = img
|
imageInput = img
|
||||||
@@ -1193,9 +677,9 @@ func parseSoraMessageContent(content any) (text, imageInput, videoInput string)
|
|||||||
case "text":
|
case "text":
|
||||||
if txt, ok := itemMap["text"].(string); ok && strings.TrimSpace(txt) != "" {
|
if txt, ok := itemMap["text"].(string); ok && strings.TrimSpace(txt) != "" {
|
||||||
if builder.Len() > 0 {
|
if builder.Len() > 0 {
|
||||||
builder.WriteString("\n")
|
_, _ = builder.WriteString("\n")
|
||||||
}
|
}
|
||||||
builder.WriteString(txt)
|
_, _ = builder.WriteString(txt)
|
||||||
}
|
}
|
||||||
case "image_url":
|
case "image_url":
|
||||||
if imageInput == "" {
|
if imageInput == "" {
|
||||||
|
|||||||
532
backend/internal/service/sora_gateway_streaming_legacy.go
Normal file
532
backend/internal/service/sora_gateway_streaming_legacy.go
Normal file
@@ -0,0 +1,532 @@
|
|||||||
|
//nolint:unused
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
var soraSSEDataRe = regexp.MustCompile(`^data:\s*`)
|
||||||
|
var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
|
||||||
|
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
|
||||||
|
|
||||||
|
const soraRewriteBufferLimit = 2048
|
||||||
|
|
||||||
|
type soraStreamingResult struct {
|
||||||
|
mediaType string
|
||||||
|
mediaURLs []string
|
||||||
|
imageCount int
|
||||||
|
imageSize string
|
||||||
|
firstTokenMs *int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) setUpstreamRequestError(c *gin.Context, account *Account, err error) {
|
||||||
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
|
setOpsUpstreamError(c, 0, safeErr, "")
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: 0,
|
||||||
|
Kind: "request_error",
|
||||||
|
Message: safeErr,
|
||||||
|
})
|
||||||
|
if c != nil {
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Upstream request failed",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||||
|
if s.rateLimitService == nil || account == nil || resp == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, reqModel string) (*ForwardResult, error) {
|
||||||
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||||
|
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
if msg := soraProErrorMessage(reqModel, upstreamMsg); msg != "" {
|
||||||
|
upstreamMsg = msg
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||||
|
}
|
||||||
|
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Kind: "http_error",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
})
|
||||||
|
|
||||||
|
if c != nil {
|
||||||
|
responsePayload := s.buildErrorPayload(respBody, upstreamMsg)
|
||||||
|
c.JSON(resp.StatusCode, responsePayload)
|
||||||
|
}
|
||||||
|
if upstreamMsg == "" {
|
||||||
|
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) buildErrorPayload(respBody []byte, overrideMessage string) map[string]any {
|
||||||
|
if len(respBody) > 0 {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(respBody, &payload); err == nil {
|
||||||
|
if errObj, ok := payload["error"].(map[string]any); ok {
|
||||||
|
if overrideMessage != "" {
|
||||||
|
errObj["message"] = overrideMessage
|
||||||
|
}
|
||||||
|
payload["error"] = errObj
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return map[string]any{
|
||||||
|
"error": map[string]any{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": overrideMessage,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel string, clientStream bool) (*soraStreamingResult, error) {
|
||||||
|
if resp == nil {
|
||||||
|
return nil, errors.New("empty response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if clientStream {
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("X-Accel-Buffering", "no")
|
||||||
|
if v := resp.Header.Get("x-request-id"); v != "" {
|
||||||
|
c.Header("x-request-id", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
w := c.Writer
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
|
||||||
|
contentBuilder := strings.Builder{}
|
||||||
|
var firstTokenMs *int
|
||||||
|
var upstreamError error
|
||||||
|
rewriteBuffer := ""
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
maxLineSize := defaultMaxLineSize
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||||
|
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||||
|
}
|
||||||
|
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||||
|
|
||||||
|
sendLine := func(line string) error {
|
||||||
|
if !clientStream {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if flusher != nil {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if soraSSEDataRe.MatchString(line) {
|
||||||
|
data := soraSSEDataRe.ReplaceAllString(line, "")
|
||||||
|
if data == "[DONE]" {
|
||||||
|
if rewriteBuffer != "" {
|
||||||
|
flushLine, flushContent, err := s.flushSoraRewriteBuffer(rewriteBuffer, originalModel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if flushLine != "" {
|
||||||
|
if flushContent != "" {
|
||||||
|
if _, err := contentBuilder.WriteString(flushContent); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := sendLine(flushLine); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rewriteBuffer = ""
|
||||||
|
}
|
||||||
|
if err := sendLine("data: [DONE]"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel, &rewriteBuffer)
|
||||||
|
if errEvent != nil && upstreamError == nil {
|
||||||
|
upstreamError = errEvent
|
||||||
|
}
|
||||||
|
if contentDelta != "" {
|
||||||
|
if firstTokenMs == nil {
|
||||||
|
ms := int(time.Since(startTime).Milliseconds())
|
||||||
|
firstTokenMs = &ms
|
||||||
|
}
|
||||||
|
if _, err := contentBuilder.WriteString(contentDelta); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := sendLine(updatedLine); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := sendLine(line); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
if errors.Is(err, bufio.ErrTooLong) {
|
||||||
|
if clientStream {
|
||||||
|
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"response_too_large\"}\n\n")
|
||||||
|
if flusher != nil {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if ctx.Err() == context.DeadlineExceeded && s.rateLimitService != nil && account != nil {
|
||||||
|
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
||||||
|
}
|
||||||
|
if clientStream {
|
||||||
|
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"stream_read_error\"}\n\n")
|
||||||
|
if flusher != nil {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
content := contentBuilder.String()
|
||||||
|
mediaType, mediaURLs := s.extractSoraMedia(content)
|
||||||
|
if mediaType == "" && isSoraPromptEnhanceModel(originalModel) {
|
||||||
|
mediaType = "prompt"
|
||||||
|
}
|
||||||
|
imageSize := ""
|
||||||
|
imageCount := 0
|
||||||
|
if mediaType == "image" {
|
||||||
|
imageSize = soraImageSizeFromModel(originalModel)
|
||||||
|
imageCount = len(mediaURLs)
|
||||||
|
}
|
||||||
|
|
||||||
|
if upstreamError != nil && !clientStream {
|
||||||
|
if c != nil {
|
||||||
|
c.JSON(http.StatusBadGateway, map[string]any{
|
||||||
|
"error": map[string]any{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": upstreamError.Error(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return nil, upstreamError
|
||||||
|
}
|
||||||
|
|
||||||
|
if !clientStream {
|
||||||
|
response := buildSoraNonStreamResponse(content, originalModel)
|
||||||
|
if len(mediaURLs) > 0 {
|
||||||
|
response["media_url"] = mediaURLs[0]
|
||||||
|
if len(mediaURLs) > 1 {
|
||||||
|
response["media_urls"] = mediaURLs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &soraStreamingResult{
|
||||||
|
mediaType: mediaType,
|
||||||
|
mediaURLs: mediaURLs,
|
||||||
|
imageCount: imageCount,
|
||||||
|
imageSize: imageSize,
|
||||||
|
firstTokenMs: firstTokenMs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string, rewriteBuffer *string) (string, string, error) {
|
||||||
|
if strings.TrimSpace(data) == "" {
|
||||||
|
return "data: ", "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(data), &payload); err != nil {
|
||||||
|
return "data: " + data, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if errObj, ok := payload["error"].(map[string]any); ok {
|
||||||
|
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||||||
|
return "data: " + data, "", errors.New(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if model, ok := payload["model"].(string); ok && model != "" && originalModel != "" {
|
||||||
|
payload["model"] = originalModel
|
||||||
|
}
|
||||||
|
|
||||||
|
contentDelta, updated := extractSoraContent(payload)
|
||||||
|
if updated {
|
||||||
|
var rewritten string
|
||||||
|
if rewriteBuffer != nil {
|
||||||
|
rewritten = s.rewriteSoraContentWithBuffer(contentDelta, rewriteBuffer)
|
||||||
|
} else {
|
||||||
|
rewritten = s.rewriteSoraContent(contentDelta)
|
||||||
|
}
|
||||||
|
if rewritten != contentDelta {
|
||||||
|
applySoraContent(payload, rewritten)
|
||||||
|
contentDelta = rewritten
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedData, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return "data: " + data, contentDelta, nil
|
||||||
|
}
|
||||||
|
return "data: " + string(updatedData), contentDelta, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractSoraContent(payload map[string]any) (string, bool) {
|
||||||
|
choices, ok := payload["choices"].([]any)
|
||||||
|
if !ok || len(choices) == 0 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
choice, ok := choices[0].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
if delta, ok := choice["delta"].(map[string]any); ok {
|
||||||
|
if content, ok := delta["content"].(string); ok {
|
||||||
|
return content, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if message, ok := choice["message"].(map[string]any); ok {
|
||||||
|
if content, ok := message["content"].(string); ok {
|
||||||
|
return content, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
func applySoraContent(payload map[string]any, content string) {
|
||||||
|
choices, ok := payload["choices"].([]any)
|
||||||
|
if !ok || len(choices) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
choice, ok := choices[0].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if delta, ok := choice["delta"].(map[string]any); ok {
|
||||||
|
delta["content"] = content
|
||||||
|
choice["delta"] = delta
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if message, ok := choice["message"].(map[string]any); ok {
|
||||||
|
message["content"] = content
|
||||||
|
choice["message"] = message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) rewriteSoraContentWithBuffer(contentDelta string, buffer *string) string {
|
||||||
|
if buffer == nil {
|
||||||
|
return s.rewriteSoraContent(contentDelta)
|
||||||
|
}
|
||||||
|
if contentDelta == "" && *buffer == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
combined := *buffer + contentDelta
|
||||||
|
rewritten := s.rewriteSoraContent(combined)
|
||||||
|
bufferStart := s.findSoraRewriteBufferStart(rewritten)
|
||||||
|
if bufferStart < 0 {
|
||||||
|
*buffer = ""
|
||||||
|
return rewritten
|
||||||
|
}
|
||||||
|
if len(rewritten)-bufferStart > soraRewriteBufferLimit {
|
||||||
|
bufferStart = len(rewritten) - soraRewriteBufferLimit
|
||||||
|
}
|
||||||
|
output := rewritten[:bufferStart]
|
||||||
|
*buffer = rewritten[bufferStart:]
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) findSoraRewriteBufferStart(content string) int {
|
||||||
|
minIndex := -1
|
||||||
|
start := 0
|
||||||
|
for {
|
||||||
|
idx := strings.Index(content[start:], "![")
|
||||||
|
if idx < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
idx += start
|
||||||
|
if !hasSoraImageMatchAt(content, idx) {
|
||||||
|
if minIndex == -1 || idx < minIndex {
|
||||||
|
minIndex = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
start = idx + 2
|
||||||
|
}
|
||||||
|
lower := strings.ToLower(content)
|
||||||
|
start = 0
|
||||||
|
for {
|
||||||
|
idx := strings.Index(lower[start:], "<video")
|
||||||
|
if idx < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
idx += start
|
||||||
|
if !hasSoraVideoMatchAt(content, idx) {
|
||||||
|
if minIndex == -1 || idx < minIndex {
|
||||||
|
minIndex = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
start = idx + len("<video")
|
||||||
|
}
|
||||||
|
return minIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasSoraImageMatchAt(content string, idx int) bool {
|
||||||
|
if idx < 0 || idx >= len(content) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
loc := soraImageMarkdownRe.FindStringIndex(content[idx:])
|
||||||
|
return loc != nil && loc[0] == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasSoraVideoMatchAt(content string, idx int) bool {
|
||||||
|
if idx < 0 || idx >= len(content) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
loc := soraVideoHTMLRe.FindStringIndex(content[idx:])
|
||||||
|
return loc != nil && loc[0] == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) rewriteSoraContent(content string) string {
|
||||||
|
if content == "" {
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
content = soraImageMarkdownRe.ReplaceAllStringFunc(content, func(match string) string {
|
||||||
|
sub := soraImageMarkdownRe.FindStringSubmatch(match)
|
||||||
|
if len(sub) < 2 {
|
||||||
|
return match
|
||||||
|
}
|
||||||
|
rewritten := s.rewriteSoraURL(sub[1])
|
||||||
|
if rewritten == sub[1] {
|
||||||
|
return match
|
||||||
|
}
|
||||||
|
return strings.Replace(match, sub[1], rewritten, 1)
|
||||||
|
})
|
||||||
|
content = soraVideoHTMLRe.ReplaceAllStringFunc(content, func(match string) string {
|
||||||
|
sub := soraVideoHTMLRe.FindStringSubmatch(match)
|
||||||
|
if len(sub) < 2 {
|
||||||
|
return match
|
||||||
|
}
|
||||||
|
rewritten := s.rewriteSoraURL(sub[1])
|
||||||
|
if rewritten == sub[1] {
|
||||||
|
return match
|
||||||
|
}
|
||||||
|
return strings.Replace(match, sub[1], rewritten, 1)
|
||||||
|
})
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) flushSoraRewriteBuffer(buffer string, originalModel string) (string, string, error) {
|
||||||
|
if buffer == "" {
|
||||||
|
return "", "", nil
|
||||||
|
}
|
||||||
|
rewritten := s.rewriteSoraContent(buffer)
|
||||||
|
payload := map[string]any{
|
||||||
|
"choices": []any{
|
||||||
|
map[string]any{
|
||||||
|
"delta": map[string]any{
|
||||||
|
"content": rewritten,
|
||||||
|
},
|
||||||
|
"index": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if originalModel != "" {
|
||||||
|
payload["model"] = originalModel
|
||||||
|
}
|
||||||
|
updatedData, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
return "data: " + string(updatedData), rewritten, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) rewriteSoraURL(raw string) string {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
parsed, err := url.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
path := parsed.Path
|
||||||
|
if !strings.HasPrefix(path, "/tmp/") && !strings.HasPrefix(path, "/static/") {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
return s.buildSoraMediaURL(path, parsed.RawQuery)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) extractSoraMedia(content string) (string, []string) {
|
||||||
|
if content == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
if match := soraVideoHTMLRe.FindStringSubmatch(content); len(match) > 1 {
|
||||||
|
return "video", []string{match[1]}
|
||||||
|
}
|
||||||
|
imageMatches := soraImageMarkdownRe.FindAllStringSubmatch(content, -1)
|
||||||
|
if len(imageMatches) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
urls := make([]string, 0, len(imageMatches))
|
||||||
|
for _, match := range imageMatches {
|
||||||
|
if len(match) > 1 {
|
||||||
|
urls = append(urls, match[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "image", urls
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSoraPromptEnhanceModel(model string) bool {
|
||||||
|
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "prompt-enhance")
|
||||||
|
}
|
||||||
@@ -29,7 +29,6 @@ type SoraMediaStorage struct {
|
|||||||
root string
|
root string
|
||||||
imageRoot string
|
imageRoot string
|
||||||
videoRoot string
|
videoRoot string
|
||||||
maxConcurrent int
|
|
||||||
downloadTimeout time.Duration
|
downloadTimeout time.Duration
|
||||||
maxDownloadBytes int64
|
maxDownloadBytes int64
|
||||||
fallbackToUpstream bool
|
fallbackToUpstream bool
|
||||||
@@ -93,7 +92,6 @@ func (s *SoraMediaStorage) refreshConfig() {
|
|||||||
if maxConcurrent <= 0 {
|
if maxConcurrent <= 0 {
|
||||||
maxConcurrent = 4
|
maxConcurrent = 4
|
||||||
}
|
}
|
||||||
s.maxConcurrent = maxConcurrent
|
|
||||||
timeoutSeconds := s.cfg.Sora.Storage.DownloadTimeoutSeconds
|
timeoutSeconds := s.cfg.Sora.Storage.DownloadTimeoutSeconds
|
||||||
if timeoutSeconds <= 0 {
|
if timeoutSeconds <= 0 {
|
||||||
timeoutSeconds = 120
|
timeoutSeconds = 120
|
||||||
|
|||||||
Reference in New Issue
Block a user